tinygrad-0.10.0-py3-none-any.whl → tinygrad-0.10.2-py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/device.py
CHANGED
@@ -1,36 +1,40 @@
 from __future__ import annotations
 from dataclasses import dataclass, replace
 from collections import defaultdict
-from typing import Optional,
-import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib, sys
-from tinygrad.helpers import CI, OSX, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
+from typing import Optional, Any, Iterator, Generator
+import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time
+from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \
+  cpu_time_execution, colored, Context, round_up
 from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes
 from tinygrad.renderer import Renderer
 
 # **************** Device ****************
 
+ALL_DEVICES = ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CPU", "LLVM", "DSP", "WEBGPU"]
 class _Device:
   def __init__(self) -> None:
     self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
+    self._opened_devices:set[str] = set()
   @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
-  def _canonicalize(self, device:str) -> str: return ((d:=device.split(":", 1)[0].upper()) + device[len(d):])
+  def _canonicalize(self, device:str) -> str: return re.sub(r":0$", "", (d:=device.split(":", 1)[0].upper()) + device[len(d):])
   # NOTE: you can't cache canonicalize in case Device.DEFAULT changes
   def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device) if device is not None else Device.DEFAULT
   def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
   @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
   def __get_canonicalized_item(self, ix:str) -> Compiled:
     cpn = multiprocessing.current_process().name
-    assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], f"can only open device {ix} from parent, not {cpn}"
+    assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"can only open device {ix} from parent, not {cpn}"
     x = ix.split(":")[0].upper()
-    ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'
+    ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) \
       if (cname.lower() == x.lower() + "device")][0](ix)
     if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
+    self._opened_devices.add(ix)
     return ret
   @property
   def default(self) -> Compiled: return self[self.DEFAULT]
   def get_available_devices(self) -> Iterator[str]:
-    for device in
-      with contextlib.suppress(Exception): yield self[device].
+    for device in ALL_DEVICES:
+      with contextlib.suppress(Exception): yield self[device].device
   @functools.cached_property
   def DEFAULT(self) -> str:
     if (from_env:=next((d for d in self._devices if d not in ["DISK", "NPY"] and getenv(d) == 1), None)): return from_env
@@ -40,11 +44,41 @@ class _Device:
       return device
     except StopIteration as exc: raise RuntimeError("no usable devices") from exc
 Device = _Device()
+atexit.register(lambda: [Device[dn].finalize() for dn in Device._opened_devices])
+
+# **************** Profile ****************
+
+class ProfileEvent: pass
+
+@dataclass(frozen=True)
+class ProfileDeviceEvent(ProfileEvent):
+  device:str; comp_tdiff:decimal.Decimal=decimal.Decimal(0); copy_tdiff:decimal.Decimal=decimal.Decimal(0) # noqa: E702
+
+@dataclass(frozen=True)
+class ProfileRangeEvent(ProfileEvent): device:str; name:str; st:decimal.Decimal; en:decimal.Decimal; is_copy:bool # noqa: E702
+
+@dataclass(frozen=True)
+class ProfileGraphEntry: device:str; name:str; st_id:int; en_id:int; is_copy:bool # noqa: E702
+
+@dataclass(frozen=True)
+class ProfileGraphEvent(ProfileEvent): ents:list[ProfileGraphEntry]; deps:list[list[int]]; sigs:list[decimal.Decimal] # noqa: E702
+
+@dataclass
+class ProfileResult: st:Optional[int]=None; en:Optional[int]=None # noqa: E702
+
+@contextlib.contextmanager
+def cpu_profile(name, device="CPU", is_copy=False, display=True) -> Generator[ProfileResult, None, None]:
+  yield (res:=ProfileResult(st:=time.perf_counter_ns()))
+  res.en = en = time.perf_counter_ns()
+  if PROFILE and display:
+    Compiled.profile_events += [ProfileRangeEvent(device, name, decimal.Decimal(st) / 1000, decimal.Decimal(en) / 1000, is_copy=is_copy)]
 
 # **************** Buffer + Allocators ****************
 
+
 @dataclass(frozen=True, eq=True)
-class BufferOptions:
+class BufferSpec:
+  # TODO: move device, size, dtype here?
   image: Optional[ImageDType] = None
   uncached: bool = False
   cpu_access: bool = False
@@ -53,9 +87,9 @@ class BufferOptions:
   external_ptr: Optional[int] = None
 
 class Buffer:
-  def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[
-
-    if isinstance(dtype, ImageDType): options =
+  def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferSpec]=None, initial_value:Optional[bytes]=None,
+               lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
+    if isinstance(dtype, ImageDType): options = BufferSpec(image=dtype) # TODO: image hack shouldn't be here. where should it be?
     else: assert isinstance(dtype, DType) and not isinstance(dtype, PtrDType)
     self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
     if base is None:
@@ -80,17 +114,23 @@ class Buffer:
   def ensure_allocated(self) -> Buffer: return self.allocate() if not self.is_allocated() else self
   def allocate(self, opaque=None, external_ptr=None) -> Buffer:
     assert not self.is_allocated(), "can't allocate already allocated buffer"
-    self.allocator = Device[self.device].allocator
+    self.allocator:Allocator = Device[self.device].allocator
     if external_ptr is not None:
-      self.options = replace(self.options, external_ptr=external_ptr) if self.options else
+      self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferSpec(external_ptr=external_ptr)
     if self._base is not None:
      self._base.ensure_allocated()
-      assert hasattr(self.allocator, "
-      self._buf: Any = self.allocator.
+      assert hasattr(self.allocator, "_offset"), "offset function required for view"
+      self._buf: Any = self.allocator._offset(self.base._buf, self.nbytes, self.offset)
     else:
       self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
     if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
     return self
+  def deallocate(self):
+    assert self.is_allocated(), "buffer must be allocated to deallocate"
+    if self._base is None and (self.options is None or self.options.external_ptr is None):
+      if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
+      self.allocator.free(self._buf, self.nbytes, self.options)
+    del self._buf
   def __reduce__(self):
     buf = None
     if self._base is not None:
@@ -102,31 +142,27 @@ class Buffer:
     return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
   @property
   def nbytes(self): return self.size*self.dtype.itemsize
-  def __del__(self):
-    if not self.is_allocated(): return
-    if self._base is None and (self.options is None or self.options.external_ptr is None):
-      if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
-      self.allocator.free(self._buf, self.nbytes, self.options)
+  def __del__(self): (not self.is_allocated()) or self.deallocate()
   def __repr__(self):
     return f"<buf real:{self.is_allocated()} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
            (f" offset:{self.offset}" if hasattr(self, "base") else "") + (f" {self.options=}" if self.options is not None else "") + ">"
   def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
     # zero copy with as_buffer (disabled by default due to use after free)
-    if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, '
-      return self.allocator.
+    if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, '_as_buffer') and (self.options is None or self.options.image is None):
+      return self.allocator._as_buffer(self._buf)
     assert not force_zero_copy, "force zero copy was passed, but copy is required"
     return self.copyout(memoryview(bytearray(self.nbytes)))
   def copyin(self, mv:memoryview):
     mv = flat_mv(mv)
     assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
     assert self.is_allocated(), "can't copyin to unallocated buffer"
-    self.allocator.
+    self.allocator._copyin(self._buf, mv)
     return self
   def copyout(self, mv:memoryview) -> memoryview:
     mv = flat_mv(mv)
     assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
     assert self.is_allocated(), "can't copyout unallocated buffer"
-    self.allocator.
+    self.allocator._copyout(mv, self._buf)
     return mv
   def view(self, size:int, dtype:DType, offset:int) -> Buffer:
     assert offset < self.nbytes, "offset must be less than nbytes"
@@ -135,22 +171,28 @@ class Buffer:
 
 # TODO: size, dest, src are the same type. can we enforce this?
 class Allocator:
-
-
-
-
-  def free(self, opaque, size:int, options:Optional[
-
-
-  def
-
-
+  # overridden in LRUAllocator
+  def alloc(self, size:int, options:Optional[BufferSpec]=None):
+    assert size > 0, f"alloc size must be positive, getting {size}"
+    return self._alloc(size, options if options is not None else BufferSpec())
+  def free(self, opaque, size:int, options:Optional[BufferSpec]=None): self._free(opaque, options if options is not None else BufferSpec())
+
+  # implemented by the runtime
+  def _alloc(self, size:int, options:BufferSpec): raise NotImplementedError("need alloc")
+  def _free(self, opaque, options:BufferSpec): pass # if opaque is a Python object, you don't need a free
+  def _copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
+  def _copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
+  # def _as_buffer(self, src) -> memoryview:
+  # def _offset(self, buf, size:int, offset:int):
+  # def _transfer(self, dest, src, sz:int, src_dev, dest_dev):
+
+class LRUAllocator(Allocator):
   """
   The LRU Allocator is responsible for caching buffers.
   It ensures that buffers are not freed until it is absolutely necessary, optimizing performance.
   """
-  def __init__(self): self.cache:
-  def alloc(self, size:int, options:Optional[
+  def __init__(self): self.cache: dict[tuple[int, Optional[BufferSpec]], Any] = defaultdict(list)
+  def alloc(self, size:int, options:Optional[BufferSpec]=None):
     if len(c := self.cache[(size, options)]): return c.pop()
     try: return super().alloc(size, options)
     except (RuntimeError, MemoryError):
@@ -160,20 +202,78 @@ class LRUAllocator(Allocator): # pylint: disable=abstract-method
     for (sz,options),opaques in self.cache.items():
       for opaque in opaques: super().free(opaque, sz, options)
      opaques.clear()
-  def free(self, opaque:Any, size:int, options:Optional[
-    if
+  def free(self, opaque:Any, size:int, options:Optional[BufferSpec]=None):
+    if LRU and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
     else: super().free(opaque, size, options)
 
 class _MallocAllocator(LRUAllocator):
-  def _alloc(self, size:int, options:
-
-
-
-
-  def
+  def _alloc(self, size:int, options:BufferSpec):
+    # must be aligned to 0x20 for 256-bit ymm registers
+    # TODO: investigate if this is the cause of nondeterminism in speed
+    alignment = 0x1000 if size >= 0x1000 else 0x20
+    return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else self._alloc_aligned(size, alignment)
+  def _alloc_aligned(self, size:int, alignment:int):
+    buffer = (ctypes.c_uint8 * (size + alignment))()
+    offset = round_up(ctypes.addressof(buffer), alignment) - ctypes.addressof(buffer)
+    return (ctypes.c_uint8 * size).from_buffer(buffer, offset)
+  def _as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
+  def _copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
+  def _copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
+  def _offset(self, buf, size:int, offset:int): return from_mv(self._as_buffer(buf)[offset:offset+size])
 
 MallocAllocator = _MallocAllocator()
 
+# NOTE: MAP_JIT is added to mmap module in python 3.13
+MAP_JIT = 0x0800
+
+# CPUProgram is a jit/shellcode program that can be just mmapped and jumped to
+class CPUProgram:
+  rt_lib = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32') if OSX or sys.platform == "win32" else 'libgcc_s.so.1')
+  atomic_lib = ctypes.CDLL(ctypes.util.find_library('atomic')) if sys.platform == "linux" else None
+
+  def __init__(self, name:str, lib:bytes):
+    if sys.platform == "win32":
+      PAGE_EXECUTE_READWRITE = 0x40
+      MEM_COMMIT = 0x1000
+      MEM_RESERVE = 0x2000
+      ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_void_p
+      self.mem = ctypes.windll.kernel32.VirtualAlloc(ctypes.c_void_p(0), ctypes.c_size_t(len(lib)), MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE)
+      ctypes.memmove(self.mem, lib, len(lib))
+      ctypes.windll.kernel32.GetCurrentProcess.restype = ctypes.c_void_p
+      proc = ctypes.windll.kernel32.GetCurrentProcess()
+      ctypes.windll.kernel32.FlushInstructionCache(ctypes.c_void_p(proc), ctypes.c_void_p(self.mem), ctypes.c_size_t(len(lib)))
+      self.fxn = ctypes.CFUNCTYPE(None)(self.mem)
+    else:
+      from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE
+      # On apple silicon with SPRR enabled (it always is in macos) RWX pages are unrepresentable: https://blog.svenpeter.dev/posts/m1_sprr_gxf/
+      # MAP_JIT allows us to easily flip pages from RW- to R-X and vice versa. It is a noop on intel cpus. (man pthread_jit_write_protect_np)
+      self.mem = mmap(-1, len(lib), MAP_ANON | MAP_PRIVATE | (MAP_JIT if OSX else 0), PROT_READ | PROT_WRITE | PROT_EXEC)
+
+      if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(False)
+      self.mem.write(lib)
+      if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(True)
+
+      # __clear_cache isn't a normal libc function, but a compiler support routine found in libgcc_s for gcc and compiler-rt for clang.
+      # libgcc_s comes as shared library but compiler-rt is only a bunch of static library archives which we can't directly load, but fortunately
+      # it somehow found its way into libSystem on macos (likely because it used __builtin_clear_cache) and libgcc_s is ~always present on linux
+      # Using ["name"] instead of .name because otherwise name is getting mangled: https://docs.python.org/3.12/reference/expressions.html#index-5
+      CPUProgram.rt_lib["__clear_cache"](ctypes.c_void_p(mv_address(self.mem)), ctypes.c_void_p(mv_address(self.mem) + len(lib)))
+
+      self.fxn = ctypes.CFUNCTYPE(None)(mv_address(self.mem))
+
+  def __call__(self, *bufs, vals=(), wait=False):
+    args = list(bufs) + list(vals)
+    # NOTE: replace this by --target={host's triple}-elf in clang args once we only support macos sequoia and later.
+    # Apple relaxes abi requirement for stack arguments to always be at least 8 byte aligned on arm64
+    # https://developer.apple.com/documentation/xcode/writing-arm64-code-for-apple-platforms
+    # This hack is required because clang/llvm bug doesn't allow us to just use {host's triple}+'-elf' (relocation failures)
+    # The bug was fixed in https://github.com/llvm/llvm-project/commit/454cc36630296262cdb6360b60f90a64a97f7f1a but was only backported to xcode 16+
+    if platform.machine() == "arm64" and OSX: args = args[:8] + [ctypes.c_int64(a) if isinstance(a, int) else a for a in args[8:]]
+    return cpu_time_execution(lambda: self.fxn(*args), enable=wait)
+
+  def __del__(self):
+    if sys.platform == 'win32': ctypes.windll.kernel32.VirtualFree(ctypes.c_void_p(self.mem), ctypes.c_size_t(0), 0x8000) #0x8000 - MEM_RELEASE
+
 # **************** for Compiled Devices ****************
 
 class CompileError(Exception): pass
@@ -190,8 +290,10 @@ class Compiler:
   def disassemble(self, lib:bytes): pass
 
 class Compiled:
+  profile_events:list[ProfileEvent] = [ProfileDeviceEvent("CPU")] # NOTE: CPU is the default device.
+
   def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
-    self.
+    self.device, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
     self.renderer = renderer or Renderer()
   def synchronize(self):
     """
@@ -200,6 +302,16 @@ class Compiled:
     This method ensures that all previously queued operations on the device have been completed before proceeding.
     """
     # override this in your device implementation
+  def _at_profile_finalize(self):
+    """
+    Called at the end of profiling to allow the device to finalize any profiling.
+    """
+    # override this in your device implementation
+  def finalize(self):
+    """
+    Called at the end of process lifetime to allow the device to finalize.
+    """
+    # override this in your device implementation
 
 # TODO: move this to each Device
 def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
@@ -207,7 +319,8 @@ def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
   if dtype == dtypes.bfloat16:
     # NOTE: this requires bf16 buffer support
    return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
-  if device
+  if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
+                                          dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
   # for CI GPU and OSX, cl_khr_fp16 isn't supported
   # for CI LLVM, it segfaults because it can't link to the casting function
   # CI CUDA architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
@@ -219,3 +332,30 @@ def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
   if device == "PYTHON": return sys.version_info >= (3, 12)
   if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
   return True
+
+if PROFILE:
+  @atexit.register
+  def finalize_profile():
+    devs = [Device[d] for d in Device._opened_devices]
+    for dev in devs: dev.synchronize()
+    for dev in devs: dev._at_profile_finalize()
+
+    with open(fn:=temp("profile.pkl", append_user=True), "wb") as f: pickle.dump(Compiled.profile_events, f)
+
+    from tinygrad.ops import launch_viz
+    launch_viz("PROFILE", fn)
+
+if __name__ == "__main__":
+  for device in ALL_DEVICES:
+    try:
+      _ = Device[device].device
+      try:
+        from tinygrad import Tensor
+        with Context(CACHELEVEL=0): test = (Tensor([1,2,3], device=device) * 2).tolist()
+        if test != [2,4,6]: raise ValueError(f"got {test} instead of [2, 4, 6]")
+        result = colored("PASS", "green")
+      except Exception as e:
+        result = f"{colored('FAIL', 'yellow')} {e}"
+    except Exception as e:
+      result = f"{colored('FAIL', 'red')} {e}"
+    print(f"{'*' if device == Device.DEFAULT else ' '} {device:10s}: {result}")
tinygrad/dtype.py
CHANGED
@@ -1,14 +1,16 @@
 from __future__ import annotations
-from typing import Final, Optional, ClassVar,
+from typing import Final, Optional, ClassVar, Union, Callable, Literal
 import math, struct, ctypes, functools
 from dataclasses import dataclass, fields
-from tinygrad.helpers import getenv
+from tinygrad.helpers import getenv, prod
 
 ConstType = Union[float, int, bool]
 
+FmtStr = Literal['?', 'b', 'B', 'h', 'H', 'i', 'I', 'q', 'Q', 'e', 'f', 'd']
+
 # all DTypes should only be created once
 class DTypeMetaClass(type):
-  dcache:
+  dcache: dict[tuple, DType] = {}
   def __call__(cls, *args, **kwargs):
     if (ret:=DTypeMetaClass.dcache.get(args, None)) is not None: return ret
     DTypeMetaClass.dcache[args] = ret = super().__call__(*args)
@@ -19,11 +21,11 @@ class DType(metaclass=DTypeMetaClass):
   priority: int # this determines when things get upcasted
   itemsize: int
   name: str
-  fmt: Optional[
+  fmt: Optional[FmtStr]
   count: int
   _scalar: Optional[DType]
   @staticmethod
-  def new(priority:int, itemsize:int, name:str, fmt:Optional[
+  def new(priority:int, itemsize:int, name:str, fmt:Optional[FmtStr]): return DType(priority, itemsize, name, fmt, 1, None)
   def __reduce__(self): return type(self), tuple(getattr(self, f.name) for f in fields(self))
   def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count > 1 else "")
   def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)
@@ -36,7 +38,8 @@ class DType(metaclass=DTypeMetaClass):
     assert self.count == 1, f"can't vectorize {self} with size {sz}"
     if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
     return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
-  def ptr(self, local=False) -> PtrDType:
+  def ptr(self, size=-1, local=False) -> PtrDType:
+    return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, local, 1, size)
   def scalar(self) -> DType: return self._scalar if self._scalar is not None else self
 
 @dataclass(frozen=True, eq=False)
@@ -44,22 +47,24 @@ class PtrDType(DType):
   _base: DType
   local: bool
   v: int
+  size: int = -1 # -1 is unlimited size
   @property
   def base(self): return self._base
   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
   def vec(self, sz:int) -> DType:
     assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
     if sz == 1: return self # sz=1 is a scalar
-    return type(self)(
-  def ptr(self, local=False): raise RuntimeError("can't make a pointer from a pointer")
+    return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz, self.size)
+  def ptr(self, size=-1, local=False): raise RuntimeError("can't make a pointer from a pointer")
   @property
   def vcount(self): return self.v
-  def __repr__(self):
+  def __repr__(self):
+    return f"{self.base.__repr__()}.ptr({self.size}{', local=True' if self.local else ''})" + (f'.vec({self.v})' if self.v != 1 else '')
 
 @dataclass(frozen=True, eq=False)
 class ImageDType(PtrDType):
-  shape:
-  def ptr(self, local=False) -> PtrDType:
+  shape: tuple[int, ...] = () # shape of the Image
+  def ptr(self, size=-1, local=False) -> PtrDType:
     assert not local, "images can't be local"
     return self
   def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '')
@@ -68,13 +73,15 @@ class dtypes:
   @staticmethod
   @functools.lru_cache(None)
   def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats or isinstance(x, ImageDType)
-  @staticmethod # static
+  @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
   @functools.lru_cache(None)
   def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints
   @staticmethod
   @functools.lru_cache(None)
   def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints
   @staticmethod
+  def is_bool(x: DType) -> bool: return x.scalar() == dtypes.bool
+  @staticmethod
   def from_py(x) -> DType:
     if x.__class__ is float: return dtypes.default_float
     if x.__class__ is int: return dtypes.default_int
@@ -83,7 +90,7 @@ class dtypes:
     if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
     raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
   @staticmethod
-  def as_const(val:
+  def as_const(val: tuple[ConstType, ...]|ConstType, dtype:DType):
     if isinstance(val, tuple):
       assert len(val) == dtype.count, f"mismatch {val} {dtype}"
       return tuple(dtypes.as_const(x, dtype) for x in val)
@@ -97,18 +104,18 @@ class dtypes:
   @staticmethod
   @functools.lru_cache(None)
   def max(dtype:DType):
-    if dtypes.is_int(dtype): return
+    if dtypes.is_int(dtype): return 2**(dtype.itemsize*8)-1+dtypes.min(dtype)
     return float("inf") if dtypes.is_float(dtype) else True
   @staticmethod
-  def finfo(dtype:DType) ->
+  def finfo(dtype:DType) -> tuple[int, int]:
     """(exponent, mantissa)"""
     if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type")
     return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52)}[dtype]
   @staticmethod
-  def fields() ->
+  def fields() -> dict[str, DType]: return DTYPES_DICT
   void: Final[DType] = DType.new(-1, 0, "void", None)
   bool: Final[DType] = DType.new(0, 1, "bool", '?')
-  int8: Final[DType] = DType.new(1, 1, "char", 'b')
+  int8: Final[DType] = DType.new(1, 1, "signed char", 'b')
   uint8: Final[DType] = DType.new(2, 1, "unsigned char", 'B')
   int16: Final[DType] = DType.new(3, 2, "short", 'h')
   uint16: Final[DType] = DType.new(4, 2, "unsigned short", 'H')
@@ -129,9 +136,9 @@ class dtypes:
 
   # NOTE: these are image dtypes
   @staticmethod
-  def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, False, 1, shp)
+  def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, False, 1, prod(shp), shp)
   @staticmethod
-  def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, False, 1, shp)
+  def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, False, 1, prod(shp), shp)
 
   default_float: ClassVar[DType] = float32
   default_int: ClassVar[DType] = int32
@@ -156,18 +163,15 @@ promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes
   dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }
 
 @functools.lru_cache(None)
-def _get_recursive_parents(dtype:DType) ->
+def _get_recursive_parents(dtype:DType) -> set[DType]:
   return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
 @functools.lru_cache(None)
 def least_upper_dtype(*ds:DType) -> DType:
   return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
 def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
 
-
-
-  or v.__class__ is staticmethod or isinstance(v, tuple))}
-INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
-INVERSE_DTYPES_DICT['void'] = 'void'
+DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void"))}
+INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void"}
 
 def sum_acc_dtype(dt:DType):
   # default acc dtype for sum
@@ -179,9 +183,16 @@ def truncate_fp16(x):
   try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
   except OverflowError: return math.copysign(math.inf, x)
 
-
-
-
+def truncate_bf16(x):
+  max_bf16 = struct.unpack('f', struct.pack('I', 0x7f7f0000))[0]
+  if x > max_bf16 or x < -max_bf16: return math.copysign(math.inf, x)
+  f32_int = struct.unpack('I', struct.pack('f', x))[0]
+  bf = struct.unpack('f', struct.pack('I', f32_int & 0xFFFF0000))[0]
+  return bf
+
+truncate: dict[DType, Callable] = {dtypes.bool: bool,
+  dtypes.float16: truncate_fp16, dtypes.bfloat16: truncate_bf16,
+  dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
   dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
   dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
   dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value,
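
The new truncate table maps each DType to a callable that clamps a host-side Python value to what that dtype can actually store, and truncate_bf16 does it for bfloat16 by keeping only the top 16 bits of the float32 encoding (sign, 8 exponent bits, 7 mantissa bits) and overflowing to ±inf past the largest finite bfloat16. A small standalone check of that function as it appears in the diff; the sample values and assertions are ours, not from the package:

import math, struct

def truncate_bf16(x):  # copied from the diff above
  max_bf16 = struct.unpack('f', struct.pack('I', 0x7f7f0000))[0]  # largest finite bf16, ~3.39e38
  if x > max_bf16 or x < -max_bf16: return math.copysign(math.inf, x)
  f32_int = struct.unpack('I', struct.pack('f', x))[0]
  return struct.unpack('f', struct.pack('I', f32_int & 0xFFFF0000))[0]  # drop low 16 mantissa bits

assert truncate_bf16(1.0) == 1.0                     # exactly representable
assert truncate_bf16(3.141592653589793) == 3.140625  # float32 pi 0x40490FDB masked to 0x40490000
assert truncate_bf16(1e39) == math.inf               # past max finite bf16 -> inf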
|