tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88):
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/device.py CHANGED
@@ -1,36 +1,40 @@
1
1
  from __future__ import annotations
2
2
  from dataclasses import dataclass, replace
3
3
  from collections import defaultdict
4
- from typing import Optional, Dict, Tuple, Any, Iterator
5
- import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib, sys
6
- from tinygrad.helpers import CI, OSX, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
4
+ from typing import Optional, Any, Iterator, Generator
5
+ import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time
6
+ from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \
7
+ cpu_time_execution, colored, Context, round_up
7
8
  from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes
8
9
  from tinygrad.renderer import Renderer
9
10
 
10
11
  # **************** Device ****************
11
12
 
13
+ ALL_DEVICES = ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CPU", "LLVM", "DSP", "WEBGPU"]
12
14
  class _Device:
13
15
  def __init__(self) -> None:
14
16
  self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
17
+ self._opened_devices:set[str] = set()
15
18
  @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
16
- def _canonicalize(self, device:str) -> str: return ((d:=device.split(":", 1)[0].upper()) + device[len(d):]).replace(":0", "")
19
+ def _canonicalize(self, device:str) -> str: return re.sub(r":0$", "", (d:=device.split(":", 1)[0].upper()) + device[len(d):])
17
20
  # NOTE: you can't cache canonicalize in case Device.DEFAULT changes
18
21
  def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device) if device is not None else Device.DEFAULT
19
22
  def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
20
23
  @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
21
24
  def __get_canonicalized_item(self, ix:str) -> Compiled:
22
25
  cpn = multiprocessing.current_process().name
23
- assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], f"can only open device {ix} from parent, not {cpn}"
26
+ assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"can only open device {ix} from parent, not {cpn}"
24
27
  x = ix.split(":")[0].upper()
25
- ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{__name__.split(".")[0]}.runtime.ops_{x.lower()}')) \
28
+ ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) \
26
29
  if (cname.lower() == x.lower() + "device")][0](ix)
27
30
  if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
31
+ self._opened_devices.add(ix)
28
32
  return ret
29
33
  @property
30
34
  def default(self) -> Compiled: return self[self.DEFAULT]
31
35
  def get_available_devices(self) -> Iterator[str]:
32
- for device in ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM"]:
33
- with contextlib.suppress(Exception): yield self[device].dname
36
+ for device in ALL_DEVICES:
37
+ with contextlib.suppress(Exception): yield self[device].device
34
38
  @functools.cached_property
35
39
  def DEFAULT(self) -> str:
36
40
  if (from_env:=next((d for d in self._devices if d not in ["DISK", "NPY"] and getenv(d) == 1), None)): return from_env
@@ -40,11 +44,41 @@ class _Device:
40
44
  return device
41
45
  except StopIteration as exc: raise RuntimeError("no usable devices") from exc
42
46
  Device = _Device()
47
+ atexit.register(lambda: [Device[dn].finalize() for dn in Device._opened_devices])
48
+
49
+ # **************** Profile ****************
50
+
51
+ class ProfileEvent: pass
52
+
53
+ @dataclass(frozen=True)
54
+ class ProfileDeviceEvent(ProfileEvent):
55
+ device:str; comp_tdiff:decimal.Decimal=decimal.Decimal(0); copy_tdiff:decimal.Decimal=decimal.Decimal(0) # noqa: E702
56
+
57
+ @dataclass(frozen=True)
58
+ class ProfileRangeEvent(ProfileEvent): device:str; name:str; st:decimal.Decimal; en:decimal.Decimal; is_copy:bool # noqa: E702
59
+
60
+ @dataclass(frozen=True)
61
+ class ProfileGraphEntry: device:str; name:str; st_id:int; en_id:int; is_copy:bool # noqa: E702
62
+
63
+ @dataclass(frozen=True)
64
+ class ProfileGraphEvent(ProfileEvent): ents:list[ProfileGraphEntry]; deps:list[list[int]]; sigs:list[decimal.Decimal] # noqa: E702
65
+
66
+ @dataclass
67
+ class ProfileResult: st:Optional[int]=None; en:Optional[int]=None # noqa: E702
68
+
69
+ @contextlib.contextmanager
70
+ def cpu_profile(name, device="CPU", is_copy=False, display=True) -> Generator[ProfileResult, None, None]:
71
+ yield (res:=ProfileResult(st:=time.perf_counter_ns()))
72
+ res.en = en = time.perf_counter_ns()
73
+ if PROFILE and display:
74
+ Compiled.profile_events += [ProfileRangeEvent(device, name, decimal.Decimal(st) / 1000, decimal.Decimal(en) / 1000, is_copy=is_copy)]
43
75
 
44
76
  # **************** Buffer + Allocators ****************
45
77
 
78
+
46
79
  @dataclass(frozen=True, eq=True)
47
- class BufferOptions:
80
+ class BufferSpec:
81
+ # TODO: move device, size, dtype here?
48
82
  image: Optional[ImageDType] = None
49
83
  uncached: bool = False
50
84
  cpu_access: bool = False
@@ -53,9 +87,9 @@ class BufferOptions:
53
87
  external_ptr: Optional[int] = None
54
88
 
55
89
  class Buffer:
56
- def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
57
- initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
58
- if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
90
+ def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferSpec]=None, initial_value:Optional[bytes]=None,
91
+ lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
92
+ if isinstance(dtype, ImageDType): options = BufferSpec(image=dtype) # TODO: image hack shouldn't be here. where should it be?
59
93
  else: assert isinstance(dtype, DType) and not isinstance(dtype, PtrDType)
60
94
  self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
61
95
  if base is None:
@@ -80,17 +114,23 @@ class Buffer:
80
114
  def ensure_allocated(self) -> Buffer: return self.allocate() if not self.is_allocated() else self
81
115
  def allocate(self, opaque=None, external_ptr=None) -> Buffer:
82
116
  assert not self.is_allocated(), "can't allocate already allocated buffer"
83
- self.allocator = Device[self.device].allocator
117
+ self.allocator:Allocator = Device[self.device].allocator
84
118
  if external_ptr is not None:
85
- self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferOptions(external_ptr=external_ptr)
119
+ self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferSpec(external_ptr=external_ptr)
86
120
  if self._base is not None:
87
121
  self._base.ensure_allocated()
88
- assert hasattr(self.allocator, "offset"), "offset function required for view"
89
- self._buf: Any = self.allocator.offset(self.base._buf, self.nbytes, self.offset)
122
+ assert hasattr(self.allocator, "_offset"), "offset function required for view"
123
+ self._buf: Any = self.allocator._offset(self.base._buf, self.nbytes, self.offset)
90
124
  else:
91
125
  self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
92
126
  if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
93
127
  return self
128
+ def deallocate(self):
129
+ assert self.is_allocated(), "buffer must be allocated to deallocate"
130
+ if self._base is None and (self.options is None or self.options.external_ptr is None):
131
+ if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
132
+ self.allocator.free(self._buf, self.nbytes, self.options)
133
+ del self._buf
94
134
  def __reduce__(self):
95
135
  buf = None
96
136
  if self._base is not None:
@@ -102,31 +142,27 @@ class Buffer:
102
142
  return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
103
143
  @property
104
144
  def nbytes(self): return self.size*self.dtype.itemsize
105
- def __del__(self):
106
- if not self.is_allocated(): return
107
- if self._base is None and (self.options is None or self.options.external_ptr is None):
108
- if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
109
- self.allocator.free(self._buf, self.nbytes, self.options)
145
+ def __del__(self): (not self.is_allocated()) or self.deallocate()
110
146
  def __repr__(self):
111
147
  return f"<buf real:{self.is_allocated()} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
112
148
  (f" offset:{self.offset}" if hasattr(self, "base") else "") + (f" {self.options=}" if self.options is not None else "") + ">"
113
149
  def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
114
150
  # zero copy with as_buffer (disabled by default due to use after free)
115
- if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer') and (self.options is None or self.options.image is None):
116
- return self.allocator.as_buffer(self._buf)
151
+ if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, '_as_buffer') and (self.options is None or self.options.image is None):
152
+ return self.allocator._as_buffer(self._buf)
117
153
  assert not force_zero_copy, "force zero copy was passed, but copy is required"
118
154
  return self.copyout(memoryview(bytearray(self.nbytes)))
119
155
  def copyin(self, mv:memoryview):
120
156
  mv = flat_mv(mv)
121
157
  assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
122
158
  assert self.is_allocated(), "can't copyin to unallocated buffer"
123
- self.allocator.copyin(self._buf, mv)
159
+ self.allocator._copyin(self._buf, mv)
124
160
  return self
125
161
  def copyout(self, mv:memoryview) -> memoryview:
126
162
  mv = flat_mv(mv)
127
163
  assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
128
164
  assert self.is_allocated(), "can't copyout unallocated buffer"
129
- self.allocator.copyout(mv, self._buf)
165
+ self.allocator._copyout(mv, self._buf)
130
166
  return mv
131
167
  def view(self, size:int, dtype:DType, offset:int) -> Buffer:
132
168
  assert offset < self.nbytes, "offset must be less than nbytes"
@@ -135,22 +171,28 @@ class Buffer:
135
171
 
136
172
  # TODO: size, dest, src are the same type. can we enforce this?
137
173
  class Allocator:
138
- def alloc(self, size:int, options:Optional[BufferOptions]=None):
139
- assert not isinstance(size, int) or size > 0, f"alloc size must be positve, getting {size}"
140
- return self._alloc(size, options if options is not None else BufferOptions())
141
- def _alloc(self, size:int, options:BufferOptions): raise NotImplementedError("need alloc")
142
- def free(self, opaque, size:int, options:Optional[BufferOptions]=None): self._free(opaque, options if options is not None else BufferOptions())
143
- def _free(self, opaque, options:BufferOptions): pass # if opaque is a Python object, you don't need a free
144
- def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
145
- def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
146
-
147
- class LRUAllocator(Allocator): # pylint: disable=abstract-method
174
+ # overridden in LRUAllocator
175
+ def alloc(self, size:int, options:Optional[BufferSpec]=None):
176
+ assert size > 0, f"alloc size must be positive, getting {size}"
177
+ return self._alloc(size, options if options is not None else BufferSpec())
178
+ def free(self, opaque, size:int, options:Optional[BufferSpec]=None): self._free(opaque, options if options is not None else BufferSpec())
179
+
180
+ # implemented by the runtime
181
+ def _alloc(self, size:int, options:BufferSpec): raise NotImplementedError("need alloc")
182
+ def _free(self, opaque, options:BufferSpec): pass # if opaque is a Python object, you don't need a free
183
+ def _copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
184
+ def _copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
185
+ # def _as_buffer(self, src) -> memoryview:
186
+ # def _offset(self, buf, size:int, offset:int):
187
+ # def _transfer(self, dest, src, sz:int, src_dev, dest_dev):
188
+
189
+ class LRUAllocator(Allocator):
148
190
  """
149
191
  The LRU Allocator is responsible for caching buffers.
150
192
  It ensures that buffers are not freed until it is absolutely necessary, optimizing performance.
151
193
  """
152
- def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list)
153
- def alloc(self, size:int, options:Optional[BufferOptions]=None):
194
+ def __init__(self): self.cache: dict[tuple[int, Optional[BufferSpec]], Any] = defaultdict(list)
195
+ def alloc(self, size:int, options:Optional[BufferSpec]=None):
154
196
  if len(c := self.cache[(size, options)]): return c.pop()
155
197
  try: return super().alloc(size, options)
156
198
  except (RuntimeError, MemoryError):
@@ -160,20 +202,78 @@ class LRUAllocator(Allocator): # pylint: disable=abstract-method
160
202
  for (sz,options),opaques in self.cache.items():
161
203
  for opaque in opaques: super().free(opaque, sz, options)
162
204
  opaques.clear()
163
- def free(self, opaque:Any, size:int, options:Optional[BufferOptions]=None):
164
- if getenv("LRU", 1) and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
205
+ def free(self, opaque:Any, size:int, options:Optional[BufferSpec]=None):
206
+ if LRU and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
165
207
  else: super().free(opaque, size, options)
166
208
 
167
209
  class _MallocAllocator(LRUAllocator):
168
- def _alloc(self, size:int, options:BufferOptions):
169
- return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else (ctypes.c_uint8 * size)()
170
- def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
171
- def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
172
- def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
173
- def offset(self, buf, size:int, offset:int): return from_mv(self.as_buffer(buf)[offset:offset+size])
210
+ def _alloc(self, size:int, options:BufferSpec):
211
+ # must be aligned to 0x20 for 256-bit ymm registers
212
+ # TODO: investigate if this is the cause of nondeterminism in speed
213
+ alignment = 0x1000 if size >= 0x1000 else 0x20
214
+ return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else self._alloc_aligned(size, alignment)
215
+ def _alloc_aligned(self, size:int, alignment:int):
216
+ buffer = (ctypes.c_uint8 * (size + alignment))()
217
+ offset = round_up(ctypes.addressof(buffer), alignment) - ctypes.addressof(buffer)
218
+ return (ctypes.c_uint8 * size).from_buffer(buffer, offset)
219
+ def _as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
220
+ def _copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
221
+ def _copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
222
+ def _offset(self, buf, size:int, offset:int): return from_mv(self._as_buffer(buf)[offset:offset+size])
174
223
 
175
224
  MallocAllocator = _MallocAllocator()
176
225
 
226
+ # NOTE: MAP_JIT is added to mmap module in python 3.13
227
+ MAP_JIT = 0x0800
228
+
229
+ # CPUProgram is a jit/shellcode program that can be just mmapped and jumped to
230
+ class CPUProgram:
231
+ rt_lib = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32') if OSX or sys.platform == "win32" else 'libgcc_s.so.1')
232
+ atomic_lib = ctypes.CDLL(ctypes.util.find_library('atomic')) if sys.platform == "linux" else None
233
+
234
+ def __init__(self, name:str, lib:bytes):
235
+ if sys.platform == "win32":
236
+ PAGE_EXECUTE_READWRITE = 0x40
237
+ MEM_COMMIT = 0x1000
238
+ MEM_RESERVE = 0x2000
239
+ ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_void_p
240
+ self.mem = ctypes.windll.kernel32.VirtualAlloc(ctypes.c_void_p(0), ctypes.c_size_t(len(lib)), MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE)
241
+ ctypes.memmove(self.mem, lib, len(lib))
242
+ ctypes.windll.kernel32.GetCurrentProcess.restype = ctypes.c_void_p
243
+ proc = ctypes.windll.kernel32.GetCurrentProcess()
244
+ ctypes.windll.kernel32.FlushInstructionCache(ctypes.c_void_p(proc), ctypes.c_void_p(self.mem), ctypes.c_size_t(len(lib)))
245
+ self.fxn = ctypes.CFUNCTYPE(None)(self.mem)
246
+ else:
247
+ from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE
248
+ # On apple silicon with SPRR enabled (it always is in macos) RWX pages are unrepresentable: https://blog.svenpeter.dev/posts/m1_sprr_gxf/
249
+ # MAP_JIT allows us to easily flip pages from RW- to R-X and vice versa. It is a noop on intel cpus. (man pthread_jit_write_protect_np)
250
+ self.mem = mmap(-1, len(lib), MAP_ANON | MAP_PRIVATE | (MAP_JIT if OSX else 0), PROT_READ | PROT_WRITE | PROT_EXEC)
251
+
252
+ if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(False)
253
+ self.mem.write(lib)
254
+ if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(True)
255
+
256
+ # __clear_cache isn't a normal libc function, but a compiler support routine found in libgcc_s for gcc and compiler-rt for clang.
257
+ # libgcc_s comes as shared library but compiler-rt is only a bunch of static library archives which we can't directly load, but fortunately
258
+ # it somehow found its way into libSystem on macos (likely because it used __builtin_clear_cache) and libgcc_s is ~always present on linux
259
+ # Using ["name"] instead of .name because otherwise name is getting mangled: https://docs.python.org/3.12/reference/expressions.html#index-5
260
+ CPUProgram.rt_lib["__clear_cache"](ctypes.c_void_p(mv_address(self.mem)), ctypes.c_void_p(mv_address(self.mem) + len(lib)))
261
+
262
+ self.fxn = ctypes.CFUNCTYPE(None)(mv_address(self.mem))
263
+
264
+ def __call__(self, *bufs, vals=(), wait=False):
265
+ args = list(bufs) + list(vals)
266
+ # NOTE: replace this by --target={host's triple}-elf in clang args once we only support macos sequoia and later.
267
+ # Apple relaxes abi requirement for stack arguments to always be at least 8 byte aligned on arm64
268
+ # https://developer.apple.com/documentation/xcode/writing-arm64-code-for-apple-platforms
269
+ # This hack is required because clang/llvm bug doesn't allow us to just use {host's triple}+'-elf' (relocation failures)
270
+ # The bug was fixed in https://github.com/llvm/llvm-project/commit/454cc36630296262cdb6360b60f90a64a97f7f1a but was only backported to xcode 16+
271
+ if platform.machine() == "arm64" and OSX: args = args[:8] + [ctypes.c_int64(a) if isinstance(a, int) else a for a in args[8:]]
272
+ return cpu_time_execution(lambda: self.fxn(*args), enable=wait)
273
+
274
+ def __del__(self):
275
+ if sys.platform == 'win32': ctypes.windll.kernel32.VirtualFree(ctypes.c_void_p(self.mem), ctypes.c_size_t(0), 0x8000) #0x8000 - MEM_RELEASE
276
+
177
277
  # **************** for Compiled Devices ****************
178
278
 
179
279
  class CompileError(Exception): pass
@@ -190,8 +290,10 @@ class Compiler:
190
290
  def disassemble(self, lib:bytes): pass
191
291
 
192
292
  class Compiled:
293
+ profile_events:list[ProfileEvent] = [ProfileDeviceEvent("CPU")] # NOTE: CPU is the default device.
294
+
193
295
  def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
194
- self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
296
+ self.device, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
195
297
  self.renderer = renderer or Renderer()
196
298
  def synchronize(self):
197
299
  """
@@ -200,6 +302,16 @@ class Compiled:
200
302
  This method ensures that all previously queued operations on the device have been completed before proceeding.
201
303
  """
202
304
  # override this in your device implementation
305
+ def _at_profile_finalize(self):
306
+ """
307
+ Called at the end of profiling to allow the device to finalize any profiling.
308
+ """
309
+ # override this in your device implementation
310
+ def finalize(self):
311
+ """
312
+ Called at the end of process lifetime to allow the device to finalize.
313
+ """
314
+ # override this in your device implementation
203
315
 
204
316
  # TODO: move this to each Device
205
317
  def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
@@ -207,7 +319,8 @@ def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
207
319
  if dtype == dtypes.bfloat16:
208
320
  # NOTE: this requires bf16 buffer support
209
321
  return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
210
- if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
322
+ if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
323
+ dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
211
324
  # for CI GPU and OSX, cl_khr_fp16 isn't supported
212
325
  # for CI LLVM, it segfaults because it can't link to the casting function
213
326
  # CI CUDA architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
@@ -219,3 +332,30 @@ def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
219
332
  if device == "PYTHON": return sys.version_info >= (3, 12)
220
333
  if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
221
334
  return True
335
+
336
+ if PROFILE:
337
+ @atexit.register
338
+ def finalize_profile():
339
+ devs = [Device[d] for d in Device._opened_devices]
340
+ for dev in devs: dev.synchronize()
341
+ for dev in devs: dev._at_profile_finalize()
342
+
343
+ with open(fn:=temp("profile.pkl", append_user=True), "wb") as f: pickle.dump(Compiled.profile_events, f)
344
+
345
+ from tinygrad.ops import launch_viz
346
+ launch_viz("PROFILE", fn)
347
+
348
+ if __name__ == "__main__":
349
+ for device in ALL_DEVICES:
350
+ try:
351
+ _ = Device[device].device
352
+ try:
353
+ from tinygrad import Tensor
354
+ with Context(CACHELEVEL=0): test = (Tensor([1,2,3], device=device) * 2).tolist()
355
+ if test != [2,4,6]: raise ValueError(f"got {test} instead of [2, 4, 6]")
356
+ result = colored("PASS", "green")
357
+ except Exception as e:
358
+ result = f"{colored('FAIL', 'yellow')} {e}"
359
+ except Exception as e:
360
+ result = f"{colored('FAIL', 'red')} {e}"
361
+ print(f"{'*' if device == Device.DEFAULT else ' '} {device:10s}: {result}")
tinygrad/dtype.py CHANGED
@@ -1,14 +1,16 @@
1
1
  from __future__ import annotations
2
- from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable
2
+ from typing import Final, Optional, ClassVar, Union, Callable, Literal
3
3
  import math, struct, ctypes, functools
4
4
  from dataclasses import dataclass, fields
5
- from tinygrad.helpers import getenv
5
+ from tinygrad.helpers import getenv, prod
6
6
 
7
7
  ConstType = Union[float, int, bool]
8
8
 
9
+ FmtStr = Literal['?', 'b', 'B', 'h', 'H', 'i', 'I', 'q', 'Q', 'e', 'f', 'd']
10
+
9
11
  # all DTypes should only be created once
10
12
  class DTypeMetaClass(type):
11
- dcache: Dict[Tuple, DType] = {}
13
+ dcache: dict[tuple, DType] = {}
12
14
  def __call__(cls, *args, **kwargs):
13
15
  if (ret:=DTypeMetaClass.dcache.get(args, None)) is not None: return ret
14
16
  DTypeMetaClass.dcache[args] = ret = super().__call__(*args)
@@ -19,11 +21,11 @@ class DType(metaclass=DTypeMetaClass):
19
21
  priority: int # this determines when things get upcasted
20
22
  itemsize: int
21
23
  name: str
22
- fmt: Optional[str]
24
+ fmt: Optional[FmtStr]
23
25
  count: int
24
26
  _scalar: Optional[DType]
25
27
  @staticmethod
26
- def new(priority:int, itemsize:int, name:str, fmt:Optional[str]): return DType(priority, itemsize, name, fmt, 1, None)
28
+ def new(priority:int, itemsize:int, name:str, fmt:Optional[FmtStr]): return DType(priority, itemsize, name, fmt, 1, None)
27
29
  def __reduce__(self): return type(self), tuple(getattr(self, f.name) for f in fields(self))
28
30
  def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count > 1 else "")
29
31
  def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)
@@ -36,7 +38,8 @@ class DType(metaclass=DTypeMetaClass):
36
38
  assert self.count == 1, f"can't vectorize {self} with size {sz}"
37
39
  if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
38
40
  return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
39
- def ptr(self, local=False) -> PtrDType: return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, local, 1)
41
+ def ptr(self, size=-1, local=False) -> PtrDType:
42
+ return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, local, 1, size)
40
43
  def scalar(self) -> DType: return self._scalar if self._scalar is not None else self
41
44
 
42
45
  @dataclass(frozen=True, eq=False)
@@ -44,22 +47,24 @@ class PtrDType(DType):
44
47
  _base: DType
45
48
  local: bool
46
49
  v: int
50
+ size: int = -1 # -1 is unlimited size
47
51
  @property
48
52
  def base(self): return self._base
49
53
  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
50
54
  def vec(self, sz:int) -> DType:
51
55
  assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
52
56
  if sz == 1: return self # sz=1 is a scalar
53
- return type(self)(*tuple(sz if f.name == 'v' else (self if f.name == '_scalar' else getattr(self, f.name)) for f in fields(self)))
54
- def ptr(self, local=False): raise RuntimeError("can't make a pointer from a pointer")
57
+ return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz, self.size)
58
+ def ptr(self, size=-1, local=False): raise RuntimeError("can't make a pointer from a pointer")
55
59
  @property
56
60
  def vcount(self): return self.v
57
- def __repr__(self): return f"{self.base.__repr__()}.ptr({'local=True' if self.local else ''})" + (f'.vec({self.v})' if self.v != 1 else '')
61
+ def __repr__(self):
62
+ return f"{self.base.__repr__()}.ptr({self.size}{', local=True' if self.local else ''})" + (f'.vec({self.v})' if self.v != 1 else '')
58
63
 
59
64
  @dataclass(frozen=True, eq=False)
60
65
  class ImageDType(PtrDType):
61
- shape: Tuple[int, ...] = () # shape of the Image
62
- def ptr(self, local=False) -> PtrDType:
66
+ shape: tuple[int, ...] = () # shape of the Image
67
+ def ptr(self, size=-1, local=False) -> PtrDType:
63
68
  assert not local, "images can't be local"
64
69
  return self
65
70
  def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '')
@@ -68,13 +73,15 @@ class dtypes:
68
73
  @staticmethod
69
74
  @functools.lru_cache(None)
70
75
  def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats or isinstance(x, ImageDType)
71
- @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
76
+ @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
72
77
  @functools.lru_cache(None)
73
78
  def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints
74
79
  @staticmethod
75
80
  @functools.lru_cache(None)
76
81
  def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints
77
82
  @staticmethod
83
+ def is_bool(x: DType) -> bool: return x.scalar() == dtypes.bool
84
+ @staticmethod
78
85
  def from_py(x) -> DType:
79
86
  if x.__class__ is float: return dtypes.default_float
80
87
  if x.__class__ is int: return dtypes.default_int
@@ -83,7 +90,7 @@ class dtypes:
83
90
  if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
84
91
  raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
85
92
  @staticmethod
86
- def as_const(val: Tuple[ConstType, ...]|ConstType, dtype:DType):
93
+ def as_const(val: tuple[ConstType, ...]|ConstType, dtype:DType):
87
94
  if isinstance(val, tuple):
88
95
  assert len(val) == dtype.count, f"mismatch {val} {dtype}"
89
96
  return tuple(dtypes.as_const(x, dtype) for x in val)
@@ -97,18 +104,18 @@ class dtypes:
97
104
  @staticmethod
98
105
  @functools.lru_cache(None)
99
106
  def max(dtype:DType):
100
- if dtypes.is_int(dtype): return (2**(dtype.itemsize*8-(0 if dtypes.is_unsigned(dtype) else 1)))-1
107
+ if dtypes.is_int(dtype): return 2**(dtype.itemsize*8)-1+dtypes.min(dtype)
101
108
  return float("inf") if dtypes.is_float(dtype) else True
102
109
  @staticmethod
103
- def finfo(dtype:DType) -> Tuple[int, int]:
110
+ def finfo(dtype:DType) -> tuple[int, int]:
104
111
  """(exponent, mantissa)"""
105
112
  if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type")
106
113
  return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52)}[dtype]
107
114
  @staticmethod
108
- def fields() -> Dict[str, DType]: return DTYPES_DICT
115
+ def fields() -> dict[str, DType]: return DTYPES_DICT
109
116
  void: Final[DType] = DType.new(-1, 0, "void", None)
110
117
  bool: Final[DType] = DType.new(0, 1, "bool", '?')
111
- int8: Final[DType] = DType.new(1, 1, "char", 'b')
118
+ int8: Final[DType] = DType.new(1, 1, "signed char", 'b')
112
119
  uint8: Final[DType] = DType.new(2, 1, "unsigned char", 'B')
113
120
  int16: Final[DType] = DType.new(3, 2, "short", 'h')
114
121
  uint16: Final[DType] = DType.new(4, 2, "unsigned short", 'H')
@@ -129,9 +136,9 @@ class dtypes:
129
136
 
130
137
  # NOTE: these are image dtypes
131
138
  @staticmethod
132
- def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, False, 1, shp)
139
+ def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, False, 1, prod(shp), shp)
133
140
  @staticmethod
134
- def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, False, 1, shp)
141
+ def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, False, 1, prod(shp), shp)
135
142
 
136
143
  default_float: ClassVar[DType] = float32
137
144
  default_int: ClassVar[DType] = int32
@@ -156,18 +163,15 @@ promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes
156
163
  dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }
157
164
 
158
165
  @functools.lru_cache(None)
159
- def _get_recursive_parents(dtype:DType) -> Set[DType]:
166
+ def _get_recursive_parents(dtype:DType) -> set[DType]:
160
167
  return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
161
168
  @functools.lru_cache(None)
162
169
  def least_upper_dtype(*ds:DType) -> DType:
163
170
  return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
164
171
  def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
165
172
 
166
- # HACK: staticmethods are not callable in 3.8 so we have to compare the class
167
- DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'void'))
168
- or v.__class__ is staticmethod or isinstance(v, tuple))}
169
- INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
170
- INVERSE_DTYPES_DICT['void'] = 'void'
173
+ DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void"))}
174
+ INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void"}
171
175
 
172
176
  def sum_acc_dtype(dt:DType):
173
177
  # default acc dtype for sum
@@ -179,9 +183,16 @@ def truncate_fp16(x):
179
183
  try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
180
184
  except OverflowError: return math.copysign(math.inf, x)
181
185
 
182
- truncate: Dict[DType, Callable] = {dtypes.bool: bool,
183
- # TODO: bfloat16
184
- dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
186
+ def truncate_bf16(x):
187
+ max_bf16 = struct.unpack('f', struct.pack('I', 0x7f7f0000))[0]
188
+ if x > max_bf16 or x < -max_bf16: return math.copysign(math.inf, x)
189
+ f32_int = struct.unpack('I', struct.pack('f', x))[0]
190
+ bf = struct.unpack('f', struct.pack('I', f32_int & 0xFFFF0000))[0]
191
+ return bf
192
+
193
+ truncate: dict[DType, Callable] = {dtypes.bool: bool,
194
+ dtypes.float16: truncate_fp16, dtypes.bfloat16: truncate_bf16,
195
+ dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
185
196
  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
186
197
  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
187
198
  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value,