tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/device.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
from dataclasses import dataclass, replace
|
3
3
|
from collections import defaultdict
|
4
|
-
from typing import
|
5
|
-
import
|
6
|
-
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv,
|
7
|
-
|
8
|
-
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes
|
4
|
+
from typing import Any, Generic, TypeVar, Iterator
|
5
|
+
import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal, time
|
6
|
+
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, \
|
7
|
+
Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup
|
8
|
+
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
|
9
9
|
from tinygrad.renderer import Renderer
|
10
10
|
|
11
11
|
# **************** Device ****************
|
@@ -15,18 +15,18 @@ class _Device:
|
|
15
15
|
def __init__(self) -> None:
|
16
16
|
self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
|
17
17
|
self._opened_devices:set[str] = set()
|
18
|
-
@functools.
|
18
|
+
@functools.cache # this class is a singleton, pylint: disable=method-cache-max-size-none
|
19
19
|
def _canonicalize(self, device:str) -> str: return re.sub(r":0$", "", (d:=device.split(":", 1)[0].upper()) + device[len(d):])
|
20
20
|
# NOTE: you can't cache canonicalize in case Device.DEFAULT changes
|
21
|
-
def canonicalize(self, device:
|
21
|
+
def canonicalize(self, device:str|None) -> str: return self._canonicalize(device if device is not None else Device.DEFAULT)
|
22
22
|
def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
|
23
|
-
@functools.
|
23
|
+
@functools.cache # this class is a singleton, pylint: disable=method-cache-max-size-none
|
24
24
|
def __get_canonicalized_item(self, ix:str) -> Compiled:
|
25
|
-
|
26
|
-
|
27
|
-
x = ix.split(":")[0].
|
28
|
-
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'
|
29
|
-
if (cname.lower() == x
|
25
|
+
assert ALLOW_DEVICE_USAGE or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"usage of device {ix} disallowed"
|
26
|
+
base = (__package__ or __name__).split('.')[0] # tinygrad
|
27
|
+
x = ix.split(":")[0].lower()
|
28
|
+
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{base}.runtime.ops_{x}')) \
|
29
|
+
if (cname.lower() == x + "device")][0](ix)
|
30
30
|
if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
|
31
31
|
self._opened_devices.add(ix)
|
32
32
|
return ret
|
@@ -37,7 +37,10 @@ class _Device:
|
|
37
37
|
with contextlib.suppress(Exception): yield self[device].device
|
38
38
|
@functools.cached_property
|
39
39
|
def DEFAULT(self) -> str:
|
40
|
-
if (
|
40
|
+
dev = [dev] if (dev:=getenv("DEV", "").upper()) else []
|
41
|
+
from_env = dedup(dev + [d for d in self._devices if d not in ["DISK", "NPY"] and getenv(d) == 1])
|
42
|
+
assert len(from_env) < 2, f"multiple devices set in env: {from_env}"
|
43
|
+
if len(from_env) == 1: return from_env[0]
|
41
44
|
try:
|
42
45
|
device = next(self.get_available_devices())
|
43
46
|
os.environ[device] = "1" # we set this in environment for spawned children
|
@@ -48,14 +51,12 @@ atexit.register(lambda: [Device[dn].finalize() for dn in Device._opened_devices]
|
|
48
51
|
|
49
52
|
# **************** Profile ****************
|
50
53
|
|
51
|
-
class ProfileEvent: pass
|
52
|
-
|
53
54
|
@dataclass(frozen=True)
|
54
55
|
class ProfileDeviceEvent(ProfileEvent):
|
55
56
|
device:str; comp_tdiff:decimal.Decimal=decimal.Decimal(0); copy_tdiff:decimal.Decimal=decimal.Decimal(0) # noqa: E702
|
56
57
|
|
57
58
|
@dataclass(frozen=True)
|
58
|
-
class
|
59
|
+
class ProfileProgramEvent(ProfileEvent): device:str; name:str; lib:bytes|None; base:int|None # noqa: E702
|
59
60
|
|
60
61
|
@dataclass(frozen=True)
|
61
62
|
class ProfileGraphEntry: device:str; name:str; st_id:int; en_id:int; is_copy:bool # noqa: E702
|
@@ -63,39 +64,42 @@ class ProfileGraphEntry: device:str; name:str; st_id:int; en_id:int; is_copy:boo
|
|
63
64
|
@dataclass(frozen=True)
|
64
65
|
class ProfileGraphEvent(ProfileEvent): ents:list[ProfileGraphEntry]; deps:list[list[int]]; sigs:list[decimal.Decimal] # noqa: E702
|
65
66
|
|
66
|
-
@dataclass
|
67
|
-
class ProfileResult: st:Optional[int]=None; en:Optional[int]=None # noqa: E702
|
68
|
-
|
69
|
-
@contextlib.contextmanager
|
70
|
-
def cpu_profile(name, device="CPU", is_copy=False, display=True) -> Generator[ProfileResult, None, None]:
|
71
|
-
yield (res:=ProfileResult(st:=time.perf_counter_ns()))
|
72
|
-
res.en = en = time.perf_counter_ns()
|
73
|
-
if PROFILE and display:
|
74
|
-
Compiled.profile_events += [ProfileRangeEvent(device, name, decimal.Decimal(st) / 1000, decimal.Decimal(en) / 1000, is_copy=is_copy)]
|
75
|
-
|
76
67
|
# **************** Buffer + Allocators ****************
|
77
68
|
|
78
|
-
|
79
69
|
@dataclass(frozen=True, eq=True)
|
80
70
|
class BufferSpec:
|
81
71
|
# TODO: move device, size, dtype here?
|
82
|
-
image:
|
72
|
+
image: ImageDType|None = None
|
83
73
|
uncached: bool = False
|
84
74
|
cpu_access: bool = False
|
85
75
|
host: bool = False
|
86
76
|
nolru: bool = False
|
87
|
-
external_ptr:
|
77
|
+
external_ptr: int|None = None
|
78
|
+
|
79
|
+
class MultiBuffer:
|
80
|
+
def __init__(self, device:tuple[str, ...], size:int, dtype:DType):
|
81
|
+
self.bufs = [Buffer(d, size, dtype) for d in device]
|
82
|
+
@property
|
83
|
+
def size(self): return self.bufs[0].size
|
84
|
+
@property
|
85
|
+
def dtype(self): return self.bufs[0].dtype
|
86
|
+
def ref(self, cnt):
|
87
|
+
for b in self.bufs: b.ref(cnt)
|
88
|
+
return self
|
89
|
+
def is_allocated(self): return all(x.is_allocated() for x in self.bufs)
|
90
|
+
def __repr__(self): return f"<multibuf real:{self.is_allocated()} device:{tuple(x.device for x in self.bufs)} size:{self.size} dtype:{self.dtype}>"
|
88
91
|
|
89
92
|
class Buffer:
|
90
|
-
|
91
|
-
|
93
|
+
profile_events:list[ProfileEvent] = []
|
94
|
+
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:BufferSpec|None=None, initial_value:bytes|None=None,
|
95
|
+
uop_refcount=0, base:Buffer|None=None, offset:int=0, preallocate=False):
|
92
96
|
if isinstance(dtype, ImageDType): options = BufferSpec(image=dtype) # TODO: image hack shouldn't be here. where should it be?
|
93
97
|
else: assert isinstance(dtype, DType) and not isinstance(dtype, PtrDType)
|
94
|
-
self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
|
98
|
+
self.device, self.size, self.dtype, self.options, self.offset, self.allocated_views = device, size, dtype, options, offset, 0
|
95
99
|
if base is None:
|
96
100
|
assert offset == 0, "base buffers can't have offset"
|
97
101
|
self._base = None
|
98
|
-
self.
|
102
|
+
self._uop_refcount = uop_refcount
|
99
103
|
if opaque is not None: self.allocate(opaque)
|
100
104
|
if initial_value is not None:
|
101
105
|
self.allocate()
|
@@ -108,60 +112,86 @@ class Buffer:
|
|
108
112
|
@property
|
109
113
|
def base(self) -> Buffer: return self._base if self._base is not None else self
|
110
114
|
@property
|
111
|
-
def
|
112
|
-
def ref(self, cnt):
|
113
|
-
|
114
|
-
|
115
|
+
def uop_refcount(self): return self.base._uop_refcount
|
116
|
+
def ref(self, cnt):
|
117
|
+
self.base._uop_refcount += cnt
|
118
|
+
return self
|
119
|
+
# check if the underlying buffer is allocated and the current buffer/view is initialized
|
120
|
+
def is_initialized(self) -> bool: return self.is_allocated() and hasattr(self, '_buf')
|
121
|
+
# check if the underlying buffer is allocated, possibly from the base object
|
122
|
+
def is_allocated(self) -> bool: return self.base.is_allocated() if self._base is not None else hasattr(self, '_buf')
|
123
|
+
def ensure_allocated(self) -> Buffer: return self.allocate() if not self.is_initialized() else self
|
115
124
|
def allocate(self, opaque=None, external_ptr=None) -> Buffer:
|
116
|
-
assert not self.
|
125
|
+
assert not self.is_initialized(), "can't allocate already allocated buffer"
|
126
|
+
if DEBUG >= 7: print(f"buffer: allocate {self.nbytes} bytes on {self.device}")
|
127
|
+
if MAX_BUFFER_SIZE > 0 and self.size > MAX_BUFFER_SIZE: raise RuntimeError(f"buffer of size {self.size/1e6:.2f}M is too large")
|
117
128
|
self.allocator:Allocator = Device[self.device].allocator
|
118
129
|
if external_ptr is not None:
|
119
130
|
self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferSpec(external_ptr=external_ptr)
|
120
131
|
if self._base is not None:
|
121
132
|
self._base.ensure_allocated()
|
133
|
+
self._base.allocated_views += 1
|
122
134
|
assert hasattr(self.allocator, "_offset"), "offset function required for view"
|
123
135
|
self._buf: Any = self.allocator._offset(self.base._buf, self.nbytes, self.offset)
|
124
136
|
else:
|
125
137
|
self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
|
126
138
|
if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
|
139
|
+
if PROFILE:
|
140
|
+
self._prof_num = num = len(Buffer.profile_events)
|
141
|
+
ts = decimal.Decimal(time.perf_counter_ns())/1000
|
142
|
+
Buffer.profile_events.append(ProfilePointEvent(self.device, "alloc", ts, num, {"dtype":str(self.dtype),"sz":self.size,"nbytes":self.nbytes}))
|
127
143
|
return self
|
128
144
|
def deallocate(self):
|
129
|
-
assert self
|
145
|
+
assert hasattr(self, '_buf'), "buffer must be allocated to deallocate"
|
146
|
+
if DEBUG is not None and DEBUG >= 7: print(f"buffer: deallocate {self.nbytes} bytes on {self.device}")
|
130
147
|
if self._base is None and (self.options is None or self.options.external_ptr is None):
|
131
|
-
if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
|
148
|
+
if GlobalCounters is not None and not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
|
149
|
+
if PROFILE: Buffer.profile_events.append(ProfilePointEvent(self.device, "free", decimal.Decimal(time.perf_counter_ns())/1000, self._prof_num))
|
132
150
|
self.allocator.free(self._buf, self.nbytes, self.options)
|
151
|
+
elif self._base is not None: self._base.allocated_views -= 1
|
133
152
|
del self._buf
|
134
153
|
def __reduce__(self):
|
135
154
|
buf = None
|
136
155
|
if self._base is not None:
|
137
156
|
return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, self.is_allocated())
|
138
|
-
if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.
|
157
|
+
if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.uop_refcount)
|
139
158
|
if self.is_allocated():
|
140
159
|
buf = bytearray(self.nbytes)
|
141
160
|
self.copyout(memoryview(buf))
|
142
|
-
return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.
|
161
|
+
return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.uop_refcount)
|
143
162
|
@property
|
144
163
|
def nbytes(self): return self.size*self.dtype.itemsize
|
145
|
-
def __del__(self): (not self
|
164
|
+
def __del__(self): (not hasattr(self, '_buf')) or self.deallocate()
|
146
165
|
def __repr__(self):
|
147
166
|
return f"<buf real:{self.is_allocated()} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
|
148
|
-
(f" offset:{self.offset}" if
|
167
|
+
(f" offset:{self.offset}" if self._base is not None else "") + (f" {self.options=}" if self.options is not None else "") + ">"
|
168
|
+
def as_dmaref(self) -> DMARef:
|
169
|
+
assert hasattr(self.allocator, "_as_dmaref"), f"Device {self.device} doesn't support DMA"
|
170
|
+
return self.allocator._as_dmaref(self._buf)
|
149
171
|
def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
|
150
172
|
# zero copy with as_buffer (disabled by default due to use after free)
|
151
173
|
if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, '_as_buffer') and (self.options is None or self.options.image is None):
|
152
174
|
return self.allocator._as_buffer(self._buf)
|
153
175
|
assert not force_zero_copy, "force zero copy was passed, but copy is required"
|
154
176
|
return self.copyout(memoryview(bytearray(self.nbytes)))
|
177
|
+
def as_typed_buffer(self, shape=None, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
|
178
|
+
assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
|
179
|
+
assert self.dtype.base.fmt != "e" or sys.version_info >= (3, 12)
|
180
|
+
return self.as_buffer(allow_zero_copy, force_zero_copy).cast(self.dtype.base.fmt, shape if shape is not None else (self.size,))
|
181
|
+
def numpy(self) -> 'np.ndarray': # type: ignore [name-defined] # noqa: F821
|
182
|
+
import numpy as np
|
183
|
+
assert _to_np_dtype(self.dtype.base) is not None, f"no np dtype for {self.dtype.base}"
|
184
|
+
return np.frombuffer(self.as_buffer(), dtype=_to_np_dtype(self.dtype.base))
|
155
185
|
def copyin(self, mv:memoryview):
|
156
186
|
mv = flat_mv(mv)
|
157
187
|
assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
|
158
|
-
assert self.
|
188
|
+
assert self.is_initialized(), "can't copyin to unallocated buffer"
|
159
189
|
self.allocator._copyin(self._buf, mv)
|
160
190
|
return self
|
161
191
|
def copyout(self, mv:memoryview) -> memoryview:
|
162
192
|
mv = flat_mv(mv)
|
163
193
|
assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
|
164
|
-
assert self.
|
194
|
+
assert self.is_initialized(), "can't copyout unallocated buffer"
|
165
195
|
self.allocator._copyout(mv, self._buf)
|
166
196
|
return mv
|
167
197
|
def view(self, size:int, dtype:DType, offset:int) -> Buffer:
|
@@ -169,13 +199,33 @@ class Buffer:
|
|
169
199
|
if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset)
|
170
200
|
return Buffer(self.device, size, dtype, base=self, offset=offset)
|
171
201
|
|
202
|
+
@dataclass(frozen=True)
|
203
|
+
class DMACPURef:
|
204
|
+
addr: int
|
205
|
+
size: int
|
206
|
+
|
207
|
+
@dataclass(frozen=True)
|
208
|
+
class DMAFdRef:
|
209
|
+
fd: int
|
210
|
+
offset: int
|
211
|
+
size: int
|
212
|
+
|
213
|
+
DMARef = DMACPURef|DMAFdRef
|
214
|
+
|
215
|
+
DeviceType = TypeVar('DeviceType', bound='Compiled')
|
216
|
+
|
172
217
|
# TODO: size, dest, src are the same type. can we enforce this?
|
173
|
-
class Allocator:
|
218
|
+
class Allocator(Generic[DeviceType]):
|
219
|
+
def __init__(self, dev:DeviceType):
|
220
|
+
self.dev: DeviceType = dev
|
221
|
+
self.default_buffer_spec: BufferSpec = BufferSpec()
|
222
|
+
self.supports_copy_from_disk: bool = True
|
174
223
|
# overridden in LRUAllocator
|
175
|
-
def alloc(self, size:int, options:
|
224
|
+
def alloc(self, size:int, options:BufferSpec|None=None):
|
176
225
|
assert size > 0, f"alloc size must be positive, getting {size}"
|
177
|
-
return self._alloc(size, options if options is not None else
|
178
|
-
def free(self, opaque, size:int, options:
|
226
|
+
return self._alloc(size, options if options is not None else self.default_buffer_spec)
|
227
|
+
def free(self, opaque, size:int, options:BufferSpec|None=None):
|
228
|
+
self._free(opaque, options if options is not None else self.default_buffer_spec)
|
179
229
|
|
180
230
|
# implemented by the runtime
|
181
231
|
def _alloc(self, size:int, options:BufferSpec): raise NotImplementedError("need alloc")
|
@@ -186,13 +236,15 @@ class Allocator:
|
|
186
236
|
# def _offset(self, buf, size:int, offset:int):
|
187
237
|
# def _transfer(self, dest, src, sz:int, src_dev, dest_dev):
|
188
238
|
|
189
|
-
class LRUAllocator(Allocator):
|
239
|
+
class LRUAllocator(Allocator, Generic[DeviceType]):
|
190
240
|
"""
|
191
241
|
The LRU Allocator is responsible for caching buffers.
|
192
242
|
It ensures that buffers are not freed until it is absolutely necessary, optimizing performance.
|
193
243
|
"""
|
194
|
-
def __init__(self
|
195
|
-
|
244
|
+
def __init__(self, dev:DeviceType):
|
245
|
+
self.cache: dict[tuple[int, BufferSpec|None], Any] = defaultdict(list)
|
246
|
+
super().__init__(dev)
|
247
|
+
def alloc(self, size:int, options:BufferSpec|None=None):
|
196
248
|
if len(c := self.cache[(size, options)]): return c.pop()
|
197
249
|
try: return super().alloc(size, options)
|
198
250
|
except (RuntimeError, MemoryError):
|
@@ -202,84 +254,16 @@ class LRUAllocator(Allocator):
|
|
202
254
|
for (sz,options),opaques in self.cache.items():
|
203
255
|
for opaque in opaques: super().free(opaque, sz, options)
|
204
256
|
opaques.clear()
|
205
|
-
def free(self, opaque:Any, size:int, options:
|
257
|
+
def free(self, opaque:Any, size:int, options:BufferSpec|None=None):
|
206
258
|
if LRU and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
|
207
259
|
else: super().free(opaque, size, options)
|
208
260
|
|
209
|
-
class _MallocAllocator(LRUAllocator):
|
210
|
-
def _alloc(self, size:int, options:BufferSpec):
|
211
|
-
# must be aligned to 0x20 for 256-bit ymm registers
|
212
|
-
# TODO: investigate if this is the cause of nondeterminism in speed
|
213
|
-
alignment = 0x1000 if size >= 0x1000 else 0x20
|
214
|
-
return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else self._alloc_aligned(size, alignment)
|
215
|
-
def _alloc_aligned(self, size:int, alignment:int):
|
216
|
-
buffer = (ctypes.c_uint8 * (size + alignment))()
|
217
|
-
offset = round_up(ctypes.addressof(buffer), alignment) - ctypes.addressof(buffer)
|
218
|
-
return (ctypes.c_uint8 * size).from_buffer(buffer, offset)
|
219
|
-
def _as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
|
220
|
-
def _copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
|
221
|
-
def _copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
|
222
|
-
def _offset(self, buf, size:int, offset:int): return from_mv(self._as_buffer(buf)[offset:offset+size])
|
223
|
-
|
224
|
-
MallocAllocator = _MallocAllocator()
|
225
|
-
|
226
|
-
# NOTE: MAP_JIT is added to mmap module in python 3.13
|
227
|
-
MAP_JIT = 0x0800
|
228
|
-
|
229
|
-
# CPUProgram is a jit/shellcode program that can be just mmapped and jumped to
|
230
|
-
class CPUProgram:
|
231
|
-
rt_lib = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32') if OSX or sys.platform == "win32" else 'libgcc_s.so.1')
|
232
|
-
atomic_lib = ctypes.CDLL(ctypes.util.find_library('atomic')) if sys.platform == "linux" else None
|
233
|
-
|
234
|
-
def __init__(self, name:str, lib:bytes):
|
235
|
-
if sys.platform == "win32":
|
236
|
-
PAGE_EXECUTE_READWRITE = 0x40
|
237
|
-
MEM_COMMIT = 0x1000
|
238
|
-
MEM_RESERVE = 0x2000
|
239
|
-
ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_void_p
|
240
|
-
self.mem = ctypes.windll.kernel32.VirtualAlloc(ctypes.c_void_p(0), ctypes.c_size_t(len(lib)), MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE)
|
241
|
-
ctypes.memmove(self.mem, lib, len(lib))
|
242
|
-
ctypes.windll.kernel32.GetCurrentProcess.restype = ctypes.c_void_p
|
243
|
-
proc = ctypes.windll.kernel32.GetCurrentProcess()
|
244
|
-
ctypes.windll.kernel32.FlushInstructionCache(ctypes.c_void_p(proc), ctypes.c_void_p(self.mem), ctypes.c_size_t(len(lib)))
|
245
|
-
self.fxn = ctypes.CFUNCTYPE(None)(self.mem)
|
246
|
-
else:
|
247
|
-
from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE
|
248
|
-
# On apple silicon with SPRR enabled (it always is in macos) RWX pages are unrepresentable: https://blog.svenpeter.dev/posts/m1_sprr_gxf/
|
249
|
-
# MAP_JIT allows us to easily flip pages from RW- to R-X and vice versa. It is a noop on intel cpus. (man pthread_jit_write_protect_np)
|
250
|
-
self.mem = mmap(-1, len(lib), MAP_ANON | MAP_PRIVATE | (MAP_JIT if OSX else 0), PROT_READ | PROT_WRITE | PROT_EXEC)
|
251
|
-
|
252
|
-
if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(False)
|
253
|
-
self.mem.write(lib)
|
254
|
-
if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(True)
|
255
|
-
|
256
|
-
# __clear_cache isn't a normal libc function, but a compiler support routine found in libgcc_s for gcc and compiler-rt for clang.
|
257
|
-
# libgcc_s comes as shared library but compiler-rt is only a bunch of static library archives which we can't directly load, but fortunately
|
258
|
-
# it somehow found its way into libSystem on macos (likely because it used __builtin_clear_cache) and libgcc_s is ~always present on linux
|
259
|
-
# Using ["name"] instead of .name because otherwise name is getting mangled: https://docs.python.org/3.12/reference/expressions.html#index-5
|
260
|
-
CPUProgram.rt_lib["__clear_cache"](ctypes.c_void_p(mv_address(self.mem)), ctypes.c_void_p(mv_address(self.mem) + len(lib)))
|
261
|
-
|
262
|
-
self.fxn = ctypes.CFUNCTYPE(None)(mv_address(self.mem))
|
263
|
-
|
264
|
-
def __call__(self, *bufs, vals=(), wait=False):
|
265
|
-
args = list(bufs) + list(vals)
|
266
|
-
# NOTE: replace this by --target={host's triple}-elf in clang args once we only support macos sequoia and later.
|
267
|
-
# Apple relaxes abi requirement for stack arguments to always be at least 8 byte aligned on arm64
|
268
|
-
# https://developer.apple.com/documentation/xcode/writing-arm64-code-for-apple-platforms
|
269
|
-
# This hack is required because clang/llvm bug doesn't allow us to just use {host's triple}+'-elf' (relocation failures)
|
270
|
-
# The bug was fixed in https://github.com/llvm/llvm-project/commit/454cc36630296262cdb6360b60f90a64a97f7f1a but was only backported to xcode 16+
|
271
|
-
if platform.machine() == "arm64" and OSX: args = args[:8] + [ctypes.c_int64(a) if isinstance(a, int) else a for a in args[8:]]
|
272
|
-
return cpu_time_execution(lambda: self.fxn(*args), enable=wait)
|
273
|
-
|
274
|
-
def __del__(self):
|
275
|
-
if sys.platform == 'win32': ctypes.windll.kernel32.VirtualFree(ctypes.c_void_p(self.mem), ctypes.c_size_t(0), 0x8000) #0x8000 - MEM_RELEASE
|
276
|
-
|
277
261
|
# **************** for Compiled Devices ****************
|
278
262
|
|
279
263
|
class CompileError(Exception): pass
|
280
264
|
|
281
265
|
class Compiler:
|
282
|
-
def __init__(self, cachekey:
|
266
|
+
def __init__(self, cachekey:str|None=None): self.cachekey = None if DISABLE_COMPILER_CACHE else cachekey
|
283
267
|
def compile(self, src:str) -> bytes: return src.encode() # NOTE: empty compiler is the default
|
284
268
|
def compile_cached(self, src:str) -> bytes:
|
285
269
|
if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:
|
@@ -292,9 +276,9 @@ class Compiler:
|
|
292
276
|
class Compiled:
|
293
277
|
profile_events:list[ProfileEvent] = [ProfileDeviceEvent("CPU")] # NOTE: CPU is the default device.
|
294
278
|
|
295
|
-
def __init__(self, device:str, allocator:Allocator, renderer:
|
279
|
+
def __init__(self, device:str, allocator:Allocator, renderer:Renderer|None, compiler:Compiler|None, runtime, graph=None, group_id=None):
|
296
280
|
self.device, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
|
297
|
-
self.renderer = renderer or Renderer()
|
281
|
+
self.renderer, self.group_id = renderer or Renderer(), group_id
|
298
282
|
def synchronize(self):
|
299
283
|
"""
|
300
284
|
Synchronize all pending operations on the device.
|
@@ -314,11 +298,16 @@ class Compiled:
|
|
314
298
|
# override this in your device implementation
|
315
299
|
|
316
300
|
# TODO: move this to each Device
|
317
|
-
def is_dtype_supported(dtype:DType, device:
|
301
|
+
def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
|
318
302
|
if device is None: device = Device.DEFAULT
|
319
303
|
if dtype == dtypes.bfloat16:
|
320
|
-
|
321
|
-
|
304
|
+
if device == "METAL": return not CI
|
305
|
+
if device in {"CUDA", "NV"}: return not CI and not getenv("PTX")
|
306
|
+
if device in {"CPU", "LLVM"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"}
|
307
|
+
return device == "AMD"
|
308
|
+
if dtype in dtypes.fp8s:
|
309
|
+
# not supported yet - in progress
|
310
|
+
return False
|
322
311
|
if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
|
323
312
|
dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
|
324
313
|
# for CI GPU and OSX, cl_khr_fp16 isn't supported
|
@@ -340,10 +329,11 @@ if PROFILE:
|
|
340
329
|
for dev in devs: dev.synchronize()
|
341
330
|
for dev in devs: dev._at_profile_finalize()
|
342
331
|
|
343
|
-
with open(fn:=temp("profile.pkl", append_user=True), "wb") as f: pickle.dump(Compiled.profile_events, f)
|
332
|
+
with open(fn:=temp("profile.pkl", append_user=True), "wb") as f: pickle.dump(cpu_events+Compiled.profile_events+Buffer.profile_events, f)
|
344
333
|
|
345
|
-
|
346
|
-
|
334
|
+
if not getenv("SQTT", 0):
|
335
|
+
from tinygrad.uop.ops import launch_viz
|
336
|
+
launch_viz(PROFILE, fn)
|
347
337
|
|
348
338
|
if __name__ == "__main__":
|
349
339
|
for device in ALL_DEVICES:
|