tinygrad-0.10.0-py3-none-any.whl → tinygrad-0.10.1-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (72)
  1. tinygrad/codegen/kernel.py +114 -172
  2. tinygrad/codegen/linearize.py +211 -81
  3. tinygrad/codegen/lowerer.py +30 -35
  4. tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
  5. tinygrad/codegen/transcendental.py +12 -13
  6. tinygrad/device.py +170 -47
  7. tinygrad/dtype.py +28 -26
  8. tinygrad/engine/jit.py +80 -63
  9. tinygrad/engine/memory.py +4 -5
  10. tinygrad/engine/multi.py +162 -0
  11. tinygrad/engine/realize.py +58 -107
  12. tinygrad/engine/schedule.py +381 -314
  13. tinygrad/engine/search.py +40 -44
  14. tinygrad/gradient.py +70 -0
  15. tinygrad/helpers.py +77 -58
  16. tinygrad/nn/__init__.py +30 -32
  17. tinygrad/nn/datasets.py +1 -2
  18. tinygrad/nn/optim.py +22 -26
  19. tinygrad/nn/state.py +89 -64
  20. tinygrad/ops.py +562 -446
  21. tinygrad/renderer/__init__.py +79 -36
  22. tinygrad/renderer/cstyle.py +70 -84
  23. tinygrad/renderer/llvmir.py +32 -20
  24. tinygrad/renderer/ptx.py +79 -99
  25. tinygrad/renderer/wgsl.py +87 -0
  26. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  27. tinygrad/runtime/autogen/comgr.py +2 -0
  28. tinygrad/runtime/autogen/kfd.py +4 -3
  29. tinygrad/runtime/autogen/kgsl.py +1 -1
  30. tinygrad/runtime/autogen/libpciaccess.py +2023 -0
  31. tinygrad/runtime/autogen/llvm.py +11379 -0
  32. tinygrad/runtime/autogen/vfio.py +891 -0
  33. tinygrad/runtime/graph/cuda.py +8 -9
  34. tinygrad/runtime/graph/hcq.py +84 -79
  35. tinygrad/runtime/graph/metal.py +19 -21
  36. tinygrad/runtime/ops_amd.py +488 -327
  37. tinygrad/runtime/ops_clang.py +15 -28
  38. tinygrad/runtime/ops_cloud.py +34 -34
  39. tinygrad/runtime/ops_cuda.py +30 -27
  40. tinygrad/runtime/ops_disk.py +62 -63
  41. tinygrad/runtime/ops_dsp.py +129 -38
  42. tinygrad/runtime/ops_gpu.py +30 -30
  43. tinygrad/runtime/ops_hip.py +29 -31
  44. tinygrad/runtime/ops_llvm.py +45 -40
  45. tinygrad/runtime/ops_metal.py +93 -73
  46. tinygrad/runtime/ops_npy.py +2 -2
  47. tinygrad/runtime/ops_nv.py +232 -270
  48. tinygrad/runtime/ops_python.py +51 -46
  49. tinygrad/runtime/ops_qcom.py +129 -157
  50. tinygrad/runtime/ops_webgpu.py +63 -0
  51. tinygrad/runtime/support/allocator.py +94 -0
  52. tinygrad/runtime/support/am/__init__.py +0 -0
  53. tinygrad/runtime/support/am/amdev.py +384 -0
  54. tinygrad/runtime/support/am/ip.py +463 -0
  55. tinygrad/runtime/support/compiler_cuda.py +4 -2
  56. tinygrad/runtime/support/elf.py +26 -4
  57. tinygrad/runtime/support/hcq.py +254 -324
  58. tinygrad/runtime/support/llvm.py +32 -0
  59. tinygrad/shape/shapetracker.py +84 -53
  60. tinygrad/shape/view.py +103 -138
  61. tinygrad/spec.py +154 -0
  62. tinygrad/tensor.py +744 -496
  63. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
  64. tinygrad-0.10.1.dist-info/RECORD +86 -0
  65. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
  66. tinygrad/engine/lazy.py +0 -228
  67. tinygrad/function.py +0 -212
  68. tinygrad/multi.py +0 -177
  69. tinygrad/runtime/graph/clang.py +0 -39
  70. tinygrad-0.10.0.dist-info/RECORD +0 -77
  71. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
  72. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/device.py CHANGED
@@ -1,36 +1,40 @@
  from __future__ import annotations
  from dataclasses import dataclass, replace
  from collections import defaultdict
- from typing import Optional, Dict, Tuple, Any, Iterator
- import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib, sys
- from tinygrad.helpers import CI, OSX, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
+ from typing import Optional, Any, Iterator, Generator
+ import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time
+ from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \
+ cpu_time_execution, colored, Context, round_up
  from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes
  from tinygrad.renderer import Renderer

  # **************** Device ****************

+ ALL_DEVICES = ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM", "DSP", "WEBGPU"]
  class _Device:
  def __init__(self) -> None:
  self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
+ self._opened_devices:set[str] = set()
  @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
- def _canonicalize(self, device:str) -> str: return ((d:=device.split(":", 1)[0].upper()) + device[len(d):]).replace(":0", "")
+ def _canonicalize(self, device:str) -> str: return re.sub(r":0$", "", (d:=device.split(":", 1)[0].upper()) + device[len(d):])
  # NOTE: you can't cache canonicalize in case Device.DEFAULT changes
  def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device) if device is not None else Device.DEFAULT
  def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
  @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
  def __get_canonicalized_item(self, ix:str) -> Compiled:
  cpn = multiprocessing.current_process().name
- assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], f"can only open device {ix} from parent, not {cpn}"
+ assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"can only open device {ix} from parent, not {cpn}"
  x = ix.split(":")[0].upper()
- ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{__name__.split(".")[0]}.runtime.ops_{x.lower()}')) \
+ ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) \
  if (cname.lower() == x.lower() + "device")][0](ix)
  if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
+ self._opened_devices.add(ix)
  return ret
  @property
  def default(self) -> Compiled: return self[self.DEFAULT]
  def get_available_devices(self) -> Iterator[str]:
- for device in ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM"]:
- with contextlib.suppress(Exception): yield self[device].dname
+ for device in ALL_DEVICES:
+ with contextlib.suppress(Exception): yield self[device].device
  @functools.cached_property
  def DEFAULT(self) -> str:
  if (from_env:=next((d for d in self._devices if d not in ["DISK", "NPY"] and getenv(d) == 1), None)): return from_env
@@ -41,10 +45,39 @@ class _Device:
  except StopIteration as exc: raise RuntimeError("no usable devices") from exc
  Device = _Device()

+ # **************** Profile ****************
+
+ class ProfileEvent: pass
+
+ @dataclass(frozen=True)
+ class ProfileDeviceEvent(ProfileEvent):
+ device:str; comp_tdiff:decimal.Decimal=decimal.Decimal(0); copy_tdiff:decimal.Decimal=decimal.Decimal(0) # noqa: E702
+
+ @dataclass(frozen=True)
+ class ProfileRangeEvent(ProfileEvent): device:str; name:str; st:decimal.Decimal; en:decimal.Decimal; is_copy:bool # noqa: E702
+
+ @dataclass(frozen=True)
+ class ProfileGraphEntry: device:str; name:str; st_id:int; en_id:int; is_copy:bool # noqa: E702
+
+ @dataclass(frozen=True)
+ class ProfileGraphEvent(ProfileEvent): ents:list[ProfileGraphEntry]; deps:list[list[int]]; sigs:list[decimal.Decimal] # noqa: E702
+
+ @dataclass
+ class ProfileResult: st:Optional[int]=None; en:Optional[int]=None # noqa: E702
+
+ @contextlib.contextmanager
+ def cpu_profile(name, device="CPU", is_copy=False, display=True) -> Generator[ProfileResult, None, None]:
+ yield (res:=ProfileResult(st:=time.perf_counter_ns()))
+ res.en = en = time.perf_counter_ns()
+ if PROFILE and display:
+ Compiled.profile_events += [ProfileRangeEvent(device, name, decimal.Decimal(st) / 1000, decimal.Decimal(en) / 1000, is_copy=is_copy)]
+
  # **************** Buffer + Allocators ****************

+
  @dataclass(frozen=True, eq=True)
- class BufferOptions:
+ class BufferSpec:
+ # TODO: move device, size, dtype here?
  image: Optional[ImageDType] = None
  uncached: bool = False
  cpu_access: bool = False
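
The hunk above adds the profiling event types and a cpu_profile context manager. As a quick orientation, here is a minimal usage sketch; the block being timed is arbitrary, and an event is only appended to Compiled.profile_events when PROFILE is enabled:

  from tinygrad.device import cpu_profile, Compiled

  # res.st / res.en are raw time.perf_counter_ns() timestamps filled in by the context manager
  with cpu_profile("my_setup", device="CPU") as res:
    sum(range(10_000))
  print(res.st, res.en)
  print(Compiled.profile_events[-1])  # a ProfileRangeEvent only if PROFILE (and display) is set
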
@@ -53,9 +86,9 @@ class BufferOptions:
  external_ptr: Optional[int] = None

  class Buffer:
- def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
- initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
- if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
+ def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferSpec]=None, initial_value:Optional[bytes]=None,
+ lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
+ if isinstance(dtype, ImageDType): options = BufferSpec(image=dtype) # TODO: image hack shouldn't be here. where should it be?
  else: assert isinstance(dtype, DType) and not isinstance(dtype, PtrDType)
  self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
  if base is None:
@@ -80,17 +113,23 @@ class Buffer:
  def ensure_allocated(self) -> Buffer: return self.allocate() if not self.is_allocated() else self
  def allocate(self, opaque=None, external_ptr=None) -> Buffer:
  assert not self.is_allocated(), "can't allocate already allocated buffer"
- self.allocator = Device[self.device].allocator
+ self.allocator:Allocator = Device[self.device].allocator
  if external_ptr is not None:
- self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferOptions(external_ptr=external_ptr)
+ self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferSpec(external_ptr=external_ptr)
  if self._base is not None:
  self._base.ensure_allocated()
- assert hasattr(self.allocator, "offset"), "offset function required for view"
- self._buf: Any = self.allocator.offset(self.base._buf, self.nbytes, self.offset)
+ assert hasattr(self.allocator, "_offset"), "offset function required for view"
+ self._buf: Any = self.allocator._offset(self.base._buf, self.nbytes, self.offset)
  else:
  self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
  if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
  return self
+ def deallocate(self):
+ assert self.is_allocated(), "buffer must be allocated to deallocate"
+ if self._base is None and (self.options is None or self.options.external_ptr is None):
+ if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
+ self.allocator.free(self._buf, self.nbytes, self.options)
+ del self._buf
  def __reduce__(self):
  buf = None
  if self._base is not None:
@@ -102,31 +141,27 @@ class Buffer:
  return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
  @property
  def nbytes(self): return self.size*self.dtype.itemsize
- def __del__(self):
- if not self.is_allocated(): return
- if self._base is None and (self.options is None or self.options.external_ptr is None):
- if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
- self.allocator.free(self._buf, self.nbytes, self.options)
+ def __del__(self): (not self.is_allocated()) or self.deallocate()
  def __repr__(self):
  return f"<buf real:{self.is_allocated()} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
  (f" offset:{self.offset}" if hasattr(self, "base") else "") + (f" {self.options=}" if self.options is not None else "") + ">"
  def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
  # zero copy with as_buffer (disabled by default due to use after free)
- if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer') and (self.options is None or self.options.image is None):
- return self.allocator.as_buffer(self._buf)
+ if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, '_as_buffer') and (self.options is None or self.options.image is None):
+ return self.allocator._as_buffer(self._buf)
  assert not force_zero_copy, "force zero copy was passed, but copy is required"
  return self.copyout(memoryview(bytearray(self.nbytes)))
  def copyin(self, mv:memoryview):
  mv = flat_mv(mv)
  assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
  assert self.is_allocated(), "can't copyin to unallocated buffer"
- self.allocator.copyin(self._buf, mv)
+ self.allocator._copyin(self._buf, mv)
  return self
  def copyout(self, mv:memoryview) -> memoryview:
  mv = flat_mv(mv)
  assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
  assert self.is_allocated(), "can't copyout unallocated buffer"
- self.allocator.copyout(mv, self._buf)
+ self.allocator._copyout(mv, self._buf)
  return mv
  def view(self, size:int, dtype:DType, offset:int) -> Buffer:
  assert offset < self.nbytes, "offset must be less than nbytes"
@@ -135,22 +170,28 @@ class Buffer:

  # TODO: size, dest, src are the same type. can we enforce this?
  class Allocator:
- def alloc(self, size:int, options:Optional[BufferOptions]=None):
- assert not isinstance(size, int) or size > 0, f"alloc size must be positve, getting {size}"
- return self._alloc(size, options if options is not None else BufferOptions())
- def _alloc(self, size:int, options:BufferOptions): raise NotImplementedError("need alloc")
- def free(self, opaque, size:int, options:Optional[BufferOptions]=None): self._free(opaque, options if options is not None else BufferOptions())
- def _free(self, opaque, options:BufferOptions): pass # if opaque is a Python object, you don't need a free
- def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
- def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
-
- class LRUAllocator(Allocator): # pylint: disable=abstract-method
+ # overridden in LRUAllocator
+ def alloc(self, size:int, options:Optional[BufferSpec]=None):
+ assert size > 0, f"alloc size must be positive, getting {size}"
+ return self._alloc(size, options if options is not None else BufferSpec())
+ def free(self, opaque, size:int, options:Optional[BufferSpec]=None): self._free(opaque, options if options is not None else BufferSpec())
+
+ # implemented by the runtime
+ def _alloc(self, size:int, options:BufferSpec): raise NotImplementedError("need alloc")
+ def _free(self, opaque, options:BufferSpec): pass # if opaque is a Python object, you don't need a free
+ def _copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
+ def _copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
+ # def _as_buffer(self, src) -> memoryview:
+ # def _offset(self, buf, size:int, offset:int):
+ # def _transfer(self, dest, src, sz:int, src_dev, dest_dev):
+
+ class LRUAllocator(Allocator):
  """
  The LRU Allocator is responsible for caching buffers.
  It ensures that buffers are not freed until it is absolutely necessary, optimizing performance.
  """
- def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list)
- def alloc(self, size:int, options:Optional[BufferOptions]=None):
+ def __init__(self): self.cache: dict[tuple[int, Optional[BufferSpec]], Any] = defaultdict(list)
+ def alloc(self, size:int, options:Optional[BufferSpec]=None):
  if len(c := self.cache[(size, options)]): return c.pop()
  try: return super().alloc(size, options)
  except (RuntimeError, MemoryError):
@@ -160,20 +201,67 @@ class LRUAllocator(Allocator): # pylint: disable=abstract-method
  for (sz,options),opaques in self.cache.items():
  for opaque in opaques: super().free(opaque, sz, options)
  opaques.clear()
- def free(self, opaque:Any, size:int, options:Optional[BufferOptions]=None):
- if getenv("LRU", 1) and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
+ def free(self, opaque:Any, size:int, options:Optional[BufferSpec]=None):
+ if LRU and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
  else: super().free(opaque, size, options)

  class _MallocAllocator(LRUAllocator):
- def _alloc(self, size:int, options:BufferOptions):
- return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else (ctypes.c_uint8 * size)()
- def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
- def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
- def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
- def offset(self, buf, size:int, offset:int): return from_mv(self.as_buffer(buf)[offset:offset+size])
+ def _alloc(self, size:int, options:BufferSpec):
+ return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else self._alloc_aligned(size, 16)
+ def _alloc_aligned(self, size:int, alignment:int):
+ buffer = (ctypes.c_uint8 * (size + alignment))()
+ offset = round_up(ctypes.addressof(buffer), alignment) - ctypes.addressof(buffer)
+ return (ctypes.c_uint8 * size).from_buffer(buffer, offset)
+ def _as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
+ def _copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
+ def _copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
+ def _offset(self, buf, size:int, offset:int): return from_mv(self._as_buffer(buf)[offset:offset+size])

  MallocAllocator = _MallocAllocator()

+ # NOTE: MAP_JIT is added to mmap module in python 3.13
+ MAP_JIT = 0x0800
+
+ # CPUProgram is a jit/shellcode program that can be just mmapped and jumped to
+ class CPUProgram:
+ helper_handle = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32') if OSX or sys.platform == "win32" else 'libgcc_s.so.1')
+ def __init__(self, name:str, lib:bytes):
+ if sys.platform == "win32":
+ PAGE_EXECUTE_READWRITE = 0x40
+ MEM_COMMIT = 0x1000
+ MEM_RESERVE = 0x2000
+ ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_uint64
+ ptr = ctypes.windll.kernel32.VirtualAlloc(ctypes.c_int(0), ctypes.c_int(len(lib)), MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE)
+ ctypes.memmove(ptr, lib, len(lib))
+ self.fxn = ctypes.CFUNCTYPE(None)(ptr)
+ else:
+ from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE
+ # On apple silicon with SPRR enabled (it always is in macos) RWX pages are unrepresentable: https://blog.svenpeter.dev/posts/m1_sprr_gxf/
+ # MAP_JIT allows us to easily flip pages from RW- to R-X and vice versa. It is a noop on intel cpus. (man pthread_jit_write_protect_np)
+ self.mem = mmap(-1, len(lib), MAP_ANON | MAP_PRIVATE | (MAP_JIT if OSX else 0), PROT_READ | PROT_WRITE | PROT_EXEC)
+
+ if OSX: CPUProgram.helper_handle.pthread_jit_write_protect_np(False)
+ self.mem.write(lib)
+ if OSX: CPUProgram.helper_handle.pthread_jit_write_protect_np(True)
+
+ # __clear_cache isn't a normal libc function, but a compiler support routine found in libgcc_s for gcc and compiler-rt for clang.
+ # libgcc_s comes as shared library but compiler-rt is only a bunch of static library archives which we can't directly load, but fortunately
+ # it somehow found its way into libSystem on macos (likely because it used __builtin_clear_cache) and libgcc_s is ~always present on linux
+ # Using ["name"] instead of .name because otherwise name is getting mangled: https://docs.python.org/3.12/reference/expressions.html#index-5
+ CPUProgram.helper_handle["__clear_cache"](ctypes.c_void_p(mv_address(self.mem)), ctypes.c_void_p(mv_address(self.mem) + len(lib)))
+
+ self.fxn = ctypes.CFUNCTYPE(None)(mv_address(self.mem))
+
+ def __call__(self, *bufs, vals=(), wait=False):
+ args = list(bufs) + list(vals)
+ # NOTE: replace this by --target={host's triple}-elf in clang args once we only support macos sequoia and later.
+ # Apple relaxes abi requirement for stack arguments to always be at least 8 byte aligned on arm64
+ # https://developer.apple.com/documentation/xcode/writing-arm64-code-for-apple-platforms
+ # This hack is required because clang/llvm bug doesn't allow us to just use {host's triple}+'-elf' (relocation failures)
+ # The bug was fixed in https://github.com/llvm/llvm-project/commit/454cc36630296262cdb6360b60f90a64a97f7f1a but was only backported to xcode 16+
+ if platform.machine() == "arm64" and OSX: args = args[:8] + [ctypes.c_int64(a) if isinstance(a, int) else a for a in args[8:]]
+ return cpu_time_execution(lambda: self.fxn(*args), enable=wait)
+
  # **************** for Compiled Devices ****************

  class CompileError(Exception): pass
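
The two hunks above rename BufferOptions to BufferSpec and split Allocator into public alloc/free wrappers plus underscore hooks (_alloc, _free, _copyin, _copyout, with optional _as_buffer/_offset/_transfer) that each runtime implements. A minimal sketch of a host-memory allocator written against the renamed interface; the class name and the ctypes-backed storage are illustrative, not part of tinygrad:

  import ctypes
  from tinygrad.device import Allocator, BufferSpec
  from tinygrad.helpers import from_mv

  class HostAllocator(Allocator):
    # runtime-side hooks: note the leading underscores and the BufferSpec argument
    def _alloc(self, size:int, options:BufferSpec): return (ctypes.c_uint8 * size)()
    def _free(self, opaque, options:BufferSpec): pass  # plain Python object, nothing to free
    def _copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
    def _copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))

  a = HostAllocator()
  opaque = a.alloc(16)   # public alloc() builds a default BufferSpec and dispatches to _alloc
  a.free(opaque, 16)     # public free() dispatches to _free
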
@@ -190,8 +278,10 @@ class Compiler:
  def disassemble(self, lib:bytes): pass

  class Compiled:
+ profile_events:list[ProfileEvent] = [ProfileDeviceEvent("CPU")] # NOTE: CPU is the default device.
+
  def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
- self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
+ self.device, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
  self.renderer = renderer or Renderer()
  def synchronize(self):
  """
@@ -200,6 +290,11 @@ class Compiled:
  This method ensures that all previously queued operations on the device have been completed before proceeding.
  """
  # override this in your device implementation
+ def _at_profile_finalize(self):
+ """
+ Called at the end of profiling to allow the device to finalize any profiling.
+ """
+ # override this in your device implementation

  # TODO: move this to each Device
  def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
@@ -207,7 +302,8 @@ def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
  if dtype == dtypes.bfloat16:
  # NOTE: this requires bf16 buffer support
  return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
- if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
+ if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
+ dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32]
  # for CI GPU and OSX, cl_khr_fp16 isn't supported
  # for CI LLVM, it segfaults because it can't link to the casting function
  # CI CUDA architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
@@ -219,3 +315,30 @@ def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
  if device == "PYTHON": return sys.version_info >= (3, 12)
  if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
  return True
+
+ if PROFILE:
+ @atexit.register
+ def finalize_profile():
+ devs = [Device[d] for d in Device._opened_devices]
+ for dev in devs: dev.synchronize()
+ for dev in devs: dev._at_profile_finalize()
+
+ with open(fn:=temp("profile.pkl", append_user=True), "wb") as f: pickle.dump(Compiled.profile_events, f)
+
+ from tinygrad.ops import launch_viz
+ launch_viz("PROFILE", fn)
+
+ if __name__ == "__main__":
+ for device in ALL_DEVICES:
+ try:
+ _ = Device[device].device
+ try:
+ from tinygrad import Tensor
+ with Context(CACHELEVEL=0): test = (Tensor([1,2,3], device=device) * 2).tolist()
+ if test != [2,4,6]: raise ValueError(f"got {test} instead of [2, 4, 6]")
+ result = colored("PASS", "green")
+ except Exception as e:
+ result = f"{colored('FAIL', 'yellow')} {e}"
+ except Exception as e:
+ result = f"{colored('FAIL', 'red')} {e}"
+ print(f"{'*' if device == Device.DEFAULT else ' '} {device:10s}: {result}")
tinygrad/dtype.py CHANGED
@@ -1,14 +1,16 @@
  from __future__ import annotations
- from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable
+ from typing import Final, Optional, ClassVar, Union, Callable, Literal
  import math, struct, ctypes, functools
  from dataclasses import dataclass, fields
- from tinygrad.helpers import getenv
+ from tinygrad.helpers import getenv, prod

  ConstType = Union[float, int, bool]

+ FmtStr = Literal['?', 'b', 'B', 'h', 'H', 'i', 'I', 'q', 'Q', 'e', 'f', 'd']
+
  # all DTypes should only be created once
  class DTypeMetaClass(type):
- dcache: Dict[Tuple, DType] = {}
+ dcache: dict[tuple, DType] = {}
  def __call__(cls, *args, **kwargs):
  if (ret:=DTypeMetaClass.dcache.get(args, None)) is not None: return ret
  DTypeMetaClass.dcache[args] = ret = super().__call__(*args)
@@ -19,11 +21,11 @@ class DType(metaclass=DTypeMetaClass):
  priority: int # this determines when things get upcasted
  itemsize: int
  name: str
- fmt: Optional[str]
+ fmt: Optional[FmtStr]
  count: int
  _scalar: Optional[DType]
  @staticmethod
- def new(priority:int, itemsize:int, name:str, fmt:Optional[str]): return DType(priority, itemsize, name, fmt, 1, None)
+ def new(priority:int, itemsize:int, name:str, fmt:Optional[FmtStr]): return DType(priority, itemsize, name, fmt, 1, None)
  def __reduce__(self): return type(self), tuple(getattr(self, f.name) for f in fields(self))
  def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count > 1 else "")
  def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)
@@ -36,7 +38,8 @@ class DType(metaclass=DTypeMetaClass):
  assert self.count == 1, f"can't vectorize {self} with size {sz}"
  if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
  return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
- def ptr(self, local=False) -> PtrDType: return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, local, 1)
+ def ptr(self, size=-1, local=False) -> PtrDType:
+ return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, local, 1, size)
  def scalar(self) -> DType: return self._scalar if self._scalar is not None else self

  @dataclass(frozen=True, eq=False)
@@ -44,22 +47,24 @@ class PtrDType(DType):
  _base: DType
  local: bool
  v: int
+ size: int = -1 # -1 is unlimited size
  @property
  def base(self): return self._base
  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
  def vec(self, sz:int) -> DType:
  assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
  if sz == 1: return self # sz=1 is a scalar
- return type(self)(*tuple(sz if f.name == 'v' else (self if f.name == '_scalar' else getattr(self, f.name)) for f in fields(self)))
- def ptr(self, local=False): raise RuntimeError("can't make a pointer from a pointer")
+ return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz)
+ def ptr(self, size=-1, local=False): raise RuntimeError("can't make a pointer from a pointer")
  @property
  def vcount(self): return self.v
- def __repr__(self): return f"{self.base.__repr__()}.ptr({'local=True' if self.local else ''})" + (f'.vec({self.v})' if self.v != 1 else '')
+ def __repr__(self):
+ return f"{self.base.__repr__()}.ptr({self.size}{', local=True' if self.local else ''})" + (f'.vec({self.v})' if self.v != 1 else '')

  @dataclass(frozen=True, eq=False)
  class ImageDType(PtrDType):
- shape: Tuple[int, ...] = () # shape of the Image
- def ptr(self, local=False) -> PtrDType:
+ shape: tuple[int, ...] = () # shape of the Image
+ def ptr(self, size=-1, local=False) -> PtrDType:
  assert not local, "images can't be local"
  return self
  def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '')
@@ -68,7 +73,7 @@ class dtypes:
  @staticmethod
  @functools.lru_cache(None)
  def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats or isinstance(x, ImageDType)
- @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
+ @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
  @functools.lru_cache(None)
  def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints
  @staticmethod
@@ -83,7 +88,7 @@ class dtypes:
  if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
  raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
  @staticmethod
- def as_const(val: Tuple[ConstType, ...]|ConstType, dtype:DType):
+ def as_const(val: tuple[ConstType, ...]|ConstType, dtype:DType):
  if isinstance(val, tuple):
  assert len(val) == dtype.count, f"mismatch {val} {dtype}"
  return tuple(dtypes.as_const(x, dtype) for x in val)
@@ -97,18 +102,18 @@ class dtypes:
  @staticmethod
  @functools.lru_cache(None)
  def max(dtype:DType):
- if dtypes.is_int(dtype): return (2**(dtype.itemsize*8-(0 if dtypes.is_unsigned(dtype) else 1)))-1
+ if dtypes.is_int(dtype): return 2**(dtype.itemsize*8)-1+dtypes.min(dtype)
  return float("inf") if dtypes.is_float(dtype) else True
  @staticmethod
- def finfo(dtype:DType) -> Tuple[int, int]:
+ def finfo(dtype:DType) -> tuple[int, int]:
  """(exponent, mantissa)"""
  if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type")
  return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52)}[dtype]
  @staticmethod
- def fields() -> Dict[str, DType]: return DTYPES_DICT
+ def fields() -> dict[str, DType]: return DTYPES_DICT
  void: Final[DType] = DType.new(-1, 0, "void", None)
  bool: Final[DType] = DType.new(0, 1, "bool", '?')
- int8: Final[DType] = DType.new(1, 1, "char", 'b')
+ int8: Final[DType] = DType.new(1, 1, "signed char", 'b')
  uint8: Final[DType] = DType.new(2, 1, "unsigned char", 'B')
  int16: Final[DType] = DType.new(3, 2, "short", 'h')
  uint16: Final[DType] = DType.new(4, 2, "unsigned short", 'H')
@@ -129,9 +134,9 @@ class dtypes:

  # NOTE: these are image dtypes
  @staticmethod
- def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, False, 1, shp)
+ def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, False, 1, prod(shp), shp)
  @staticmethod
- def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, False, 1, shp)
+ def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, False, 1, prod(shp), shp)

  default_float: ClassVar[DType] = float32
  default_int: ClassVar[DType] = int32
@@ -156,18 +161,15 @@ promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes
  dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }

  @functools.lru_cache(None)
- def _get_recursive_parents(dtype:DType) -> Set[DType]:
+ def _get_recursive_parents(dtype:DType) -> set[DType]:
  return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
  @functools.lru_cache(None)
  def least_upper_dtype(*ds:DType) -> DType:
  return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
  def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)

- # HACK: staticmethods are not callable in 3.8 so we have to compare the class
- DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'void'))
- or v.__class__ is staticmethod or isinstance(v, tuple))}
- INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
- INVERSE_DTYPES_DICT['void'] = 'void'
+ DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void"))}
+ INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void"}

  def sum_acc_dtype(dt:DType):
  # default acc dtype for sum
@@ -179,7 +181,7 @@ def truncate_fp16(x):
  try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
  except OverflowError: return math.copysign(math.inf, x)

- truncate: Dict[DType, Callable] = {dtypes.bool: bool,
+ truncate: dict[DType, Callable] = {dtypes.bool: bool,
  # TODO: bfloat16
  dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
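
The dtype.py changes above make PtrDType carry an explicit size (default -1 for unlimited) and reformulate dtypes.max in terms of dtypes.min. A small sketch of the resulting behavior; the printed repr text is indicative only, based on the new __repr__ shown in the diff:

  from tinygrad.dtype import dtypes

  p = dtypes.float32.ptr(16)          # ptr() now takes an optional size
  print(p.size, p.local)              # 16 False
  print(p)                            # repr now includes the size, e.g. dtypes.float.ptr(16)

  # max = 2**bits - 1 + min, so the signed/unsigned limits come out as usual
  print(dtypes.max(dtypes.int8), dtypes.min(dtypes.int8))     # 127 -128
  print(dtypes.max(dtypes.uint8), dtypes.min(dtypes.uint8))   # 255 0
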