tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/device.py
CHANGED
```diff
@@ -1,41 +1,44 @@
 from __future__ import annotations
-import …
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from collections import defaultdict
-from typing import …
-import importlib, inspect, functools, pathlib, os, ctypes, …
-from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
-from tinygrad.dtype import DType, ImageDType
+from typing import Optional, Dict, Tuple, Any, Iterator
+import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib, sys
+from tinygrad.helpers import CI, OSX, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
+from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes
 from tinygrad.renderer import Renderer
 
 # **************** Device ****************
 
 class _Device:
-  def __init__(self) -> None: …
+  def __init__(self) -> None:
+    self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
   @functools.lru_cache(maxsize=None)  # this class is a singleton, pylint: disable=method-cache-max-size-none
-  def _canonicalize(self, device:str) -> str: return (device.split(":", 1)[0].upper() + …
+  def _canonicalize(self, device:str) -> str: return ((d:=device.split(":", 1)[0].upper()) + device[len(d):]).replace(":0", "")
   # NOTE: you can't cache canonicalize in case Device.DEFAULT changes
   def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device) if device is not None else Device.DEFAULT
   def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
   @functools.lru_cache(maxsize=None)  # this class is a singleton, pylint: disable=method-cache-max-size-none
   def __get_canonicalized_item(self, ix:str) -> Compiled:
-    …
+    cpn = multiprocessing.current_process().name
+    assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], f"can only open device {ix} from parent, not {cpn}"
     x = ix.split(":")[0].upper()
-    ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'…
+    ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{__name__.split(".")[0]}.runtime.ops_{x.lower()}')) \
+      if (cname.lower() == x.lower() + "device")][0](ix)
     if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
     return ret
+  @property
+  def default(self) -> Compiled: return self[self.DEFAULT]
+  def get_available_devices(self) -> Iterator[str]:
+    for device in ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM"]:
+      with contextlib.suppress(Exception): yield self[device].dname
   @functools.cached_property
   def DEFAULT(self) -> str:
-    …
-      return device
-    except Exception: pass
-    raise RuntimeError("no usable devices")
+    if (from_env:=next((d for d in self._devices if d not in ["DISK", "NPY"] and getenv(d) == 1), None)): return from_env
+    try:
+      device = next(self.get_available_devices())
+      os.environ[device] = "1"  # we set this in environment for spawned children
+      return device
+    except StopIteration as exc: raise RuntimeError("no usable devices") from exc
 Device = _Device()
 
 # **************** Buffer + Allocators ****************
```
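The new `_canonicalize` uppercases only the part before the first `:`, keeps any device index, and strips a redundant `:0`. A standalone sketch of the same expression (nothing beyond the one-liner above is assumed):

```python
def canonicalize(device: str) -> str:
    # uppercase the part before the first ":", keep the rest verbatim, drop a ":0" suffix
    d = device.split(":", 1)[0].upper()
    return (d + device[len(d):]).replace(":0", "")

assert canonicalize("gpu") == "GPU"
assert canonicalize("cuda:0") == "CUDA"    # ":0" is the implicit default index
assert canonicalize("cuda:1") == "CUDA:1"  # other indices are preserved
```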
```diff
@@ -47,12 +50,13 @@ class BufferOptions:
   cpu_access: bool = False
   host: bool = False
   nolru: bool = False
+  external_ptr: Optional[int] = None
 
 class Buffer:
   def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
                initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
-    assert isinstance(dtype, DType)
     if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype)  # TODO: image hack shouldn't be here. where should it be?
+    else: assert isinstance(dtype, DType) and not isinstance(dtype, PtrDType)
     self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
     if base is None:
       assert offset == 0, "base buffers can't have offset"
```
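`BufferOptions` is a frozen dataclass, so the new `external_ptr` field can't be set by mutation; the `allocate` hunk below derives a new options object with `dataclasses.replace` instead. A minimal sketch of that stdlib pattern (the field names mirror the diff; nothing tinygrad-specific here):

```python
from dataclasses import dataclass, replace
from typing import Optional

@dataclass(frozen=True)
class BufferOptions:
    host: bool = False
    external_ptr: Optional[int] = None

opts = BufferOptions(host=True)
opts2 = replace(opts, external_ptr=0x7F00DEAD)  # new frozen instance, one field overridden
assert opts.external_ptr is None and opts2.host and opts2.external_ptr == 0x7F00DEAD
```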
```diff
@@ -73,10 +77,12 @@ class Buffer:
   def lb_refcount(self): return self.base._lb_refcount
   def ref(self, cnt): self.base._lb_refcount += cnt
   def is_allocated(self) -> bool: return hasattr(self, '_buf')
-  def ensure_allocated(self) -> Buffer: return self.allocate() if not …
-  def allocate(self, opaque=None) -> Buffer:
-    assert not …
+  def ensure_allocated(self) -> Buffer: return self.allocate() if not self.is_allocated() else self
+  def allocate(self, opaque=None, external_ptr=None) -> Buffer:
+    assert not self.is_allocated(), "can't allocate already allocated buffer"
     self.allocator = Device[self.device].allocator
+    if external_ptr is not None:
+      self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferOptions(external_ptr=external_ptr)
     if self._base is not None:
       self._base.ensure_allocated()
       assert hasattr(self.allocator, "offset"), "offset function required for view"
```
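Given the `allocate(self, opaque=None, external_ptr=None)` signature above, wrapping caller-owned memory might look like the following sketch (an assumption-laden illustration: it presumes a CPU-style device such as CLANG, where `external_ptr` is a plain host address, and the import paths shown in this diff):

```python
import ctypes
from tinygrad.device import Buffer
from tinygrad.dtype import dtypes

backing = (ctypes.c_uint8 * 16)(*range(16))  # memory the caller owns
buf = Buffer("CLANG", 16, dtypes.uint8).allocate(external_ptr=ctypes.addressof(backing))
# per the __del__ change below, buffers carrying external_ptr are never freed by
# the allocator: ownership stays with the caller
```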
```diff
@@ -88,7 +94,7 @@ class Buffer:
   def __reduce__(self):
     buf = None
     if self._base is not None:
-      return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, …
+      return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, self.is_allocated())
     if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
     if self.is_allocated():
       buf = bytearray(self.nbytes)
```
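`__reduce__` returns a `(callable, args)` pair that pickle uses to rebuild the object; the change above threads `self.is_allocated()` through as the tenth constructor argument (the `preallocate` flag), so an allocated view is re-allocated on unpickle. The protocol in isolation:

```python
import pickle

class Box:
    def __init__(self, value, preallocate=False):
        self.value, self.ready = value, preallocate
    def __reduce__(self):
        # pickle stores (Box, (value, ready)) and rebuilds via Box(*args) on load
        return self.__class__, (self.value, self.ready)

b = pickle.loads(pickle.dumps(Box(42, preallocate=True)))
assert (b.value, b.ready) == (42, True)
```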
```diff
@@ -97,17 +103,17 @@ class Buffer:
   @property
   def nbytes(self): return self.size*self.dtype.itemsize
   def __del__(self):
-    if not …
-    if self._base is None:
+    if not self.is_allocated(): return
+    if self._base is None and (self.options is None or self.options.external_ptr is None):
       if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
       self.allocator.free(self._buf, self.nbytes, self.options)
   def __repr__(self):
-    return f"<buf real:{ …
-           (f" offset:{self.offset}" if hasattr(self, "base") else "") + …
-           (">" if self.options is None else f" {self.options=}>")
+    return f"<buf real:{self.is_allocated()} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
+           (f" offset:{self.offset}" if hasattr(self, "base") else "") + (f" {self.options=}" if self.options is not None else "") + ">"
   def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
     # zero copy with as_buffer (disabled by default due to use after free)
-    if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer') …
+    if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer') and (self.options is None or self.options.image is None):
+      return self.allocator.as_buffer(self._buf)
     assert not force_zero_copy, "force zero copy was passed, but copy is required"
     return self.copyout(memoryview(bytearray(self.nbytes)))
   def copyin(self, mv:memoryview):
```
```diff
@@ -133,13 +139,16 @@ class Allocator:
     assert not isinstance(size, int) or size > 0, f"alloc size must be positve, getting {size}"
     return self._alloc(size, options if options is not None else BufferOptions())
   def _alloc(self, size:int, options:BufferOptions): raise NotImplementedError("need alloc")
-  def free(self, opaque, size:int, options:Optional[BufferOptions]=None):
-    self._free(opaque, options if options is not None else BufferOptions())
+  def free(self, opaque, size:int, options:Optional[BufferOptions]=None): self._free(opaque, options if options is not None else BufferOptions())
   def _free(self, opaque, options:BufferOptions): pass  # if opaque is a Python object, you don't need a free
   def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
   def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
 
 class LRUAllocator(Allocator):  # pylint: disable=abstract-method
+  """
+  The LRU Allocator is responsible for caching buffers.
+  It ensures that buffers are not freed until it is absolutely necessary, optimizing performance.
+  """
   def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list)
   def alloc(self, size:int, options:Optional[BufferOptions]=None):
     if len(c := self.cache[(size, options)]): return c.pop()
```
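The docstring added above describes the caching scheme: `free` parks an allocation in a per-`(size, options)` bucket, and a later `alloc` of the same shape pops it instead of calling the driver again. A toy version of that bucket logic (standalone, not the tinygrad class):

```python
from collections import defaultdict

class ToyCache:
    def __init__(self):
        self.cache = defaultdict(list)  # (size, options) -> list of free buffers
        self.driver_allocs = 0
    def alloc(self, size, options=None):
        if (c := self.cache[(size, options)]): return c.pop()  # reuse, no driver call
        self.driver_allocs += 1
        return bytearray(size)  # stand-in for a real driver allocation
    def free(self, buf, size, options=None):
        self.cache[(size, options)].append(buf)  # cache instead of freeing

cache = ToyCache()
b1 = cache.alloc(1024); cache.free(b1, 1024)
assert cache.alloc(1024) is b1 and cache.driver_allocs == 1
```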
```diff
@@ -156,7 +165,8 @@ class LRUAllocator(Allocator):  # pylint: disable=abstract-method
     else: super().free(opaque, size, options)
 
 class _MallocAllocator(LRUAllocator):
-  def _alloc(self, size:int, options:BufferOptions): …
+  def _alloc(self, size:int, options:BufferOptions):
+    return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else (ctypes.c_uint8 * size)()
   def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
   def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
   def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
```
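The ctypes trick doing the work in `_alloc`: `(ctypes.c_uint8 * size).from_address(ptr)` wraps existing memory with no copy, while `(ctypes.c_uint8 * size)()` allocates fresh zero-initialized storage that ctypes owns. Demonstrated on a plain ctypes array:

```python
import ctypes

backing = (ctypes.c_uint8 * 4)(1, 2, 3, 4)
view = (ctypes.c_uint8 * 4).from_address(ctypes.addressof(backing))
view[0] = 99                    # writes through to the same bytes
assert backing[0] == 99         # no copy was made
fresh = (ctypes.c_uint8 * 4)()  # owned allocation, zeroed
assert bytes(fresh) == b"\x00" * 4
```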
```diff
@@ -170,151 +180,42 @@ class CompileError(Exception): pass
 
 class Compiler:
   def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
-  def compile(self, src:str) -> bytes: …
+  def compile(self, src:str) -> bytes: return src.encode()  # NOTE: empty compiler is the default
   def compile_cached(self, src:str) -> bytes:
     if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:
       assert not getenv("ASSERT_COMPILE"), f"tried to compile with ASSERT_COMPILE set\n{src}"
       lib = self.compile(src)
       if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
     return lib
+  def disassemble(self, lib:bytes): pass
 
 class Compiled:
   def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
     self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
     self.renderer = renderer or Renderer()
-  def synchronize(self):
-    …
-  def _read_signal(self, sig): raise NotImplementedError("need _read_signal")  # reads a value for a signal
-
-  @classmethod
-  def _read_timestamp(self, sig): raise NotImplementedError("need _read_timestamp")  # reads a timestamp for a signal
-
-  @classmethod
-  def _set_signal(self, sig, value): raise NotImplementedError("need _set_signal")  # sets a value for a signal
-
-  @classmethod
-  def _get_signal(self, value=0, **kwargs): raise NotImplementedError("need _get_signal")  # allocates a new signal
-
-  @classmethod
-  def _wait_signal(self, signal, value=0, timeout=10000): raise NotImplementedError("need _wait_signal")  # waits for a signal value
-
-  def _gpu2cpu_time(self, gpu_time, is_copy): raise NotImplementedError("need _gpu2cpu_time")
-
-  def _prof_setup(self):
-    self.profile_logger = ProfileLogger()
-
-    def _sync_queue(q_t):
-      q_t().timestamp(self.timeline_signal).signal(self.timeline_signal, self.timeline_value).submit(self)
-      self.timeline_value += 1
-      cpu_start_time = time.perf_counter_ns() / 1e3
-      self._wait_signal(self.timeline_signal, self.timeline_value - 1)
-      return cpu_start_time, self._read_timestamp(self.timeline_signal)
-    self.cpu_start_time, self.gpu_start_time = _sync_queue(self.hw_compute_queue_t)
-    self.copy_cpu_start_time, self.copy_gpu_start_time = _sync_queue(self.hw_copy_queue_t)
-
-    atexit.register(self._prof_finalize)
-
-  def _prof_process_events(self):
-    self.raw_prof_records += [(self._read_timestamp(st), self._read_timestamp(en), name, is_cp) for st, en, name, is_cp in self.sig_prof_records]
-    for st, en, _, _ in self.sig_prof_records: self.signals_pool += [st, en]  # type: ignore
-    self.sig_prof_records = []
-
-  def _prof_finalize(self):
-    for st, en, name, is_cp in self.raw_prof_records:
-      self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, ["COMPUTE", "DMA"][is_cp])]
-    del self.profile_logger
-
-  def _wrap_timeline_signal(self):
-    self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
-    self._set_signal(self.timeline_signal, 0)
-    cast(HCQCompatAllocator, self.allocator).b_timeline = [0] * len(cast(HCQCompatAllocator, self.allocator).b)
-
-class HCQCompatAllocator(LRUAllocator):  # pylint: disable=abstract-method
-  def __init__(self, device, batch_size=(2 << 20), batch_cnt=32):
-    self.device = device
-    self.b = [self._alloc(batch_size, BufferOptions(host=True)) for _ in range(batch_cnt)]
-    self.b_timeline, self.b_next = [0] * len(self.b), 0
-    super().__init__()
-
-  def copyin(self, dest, src: memoryview):
-    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"CPU -> {self.device.dname}", enabled=PROFILE):
-      for i in range(0, src.nbytes, self.b[0].size):
-        self.b_next = (self.b_next + 1) % len(self.b)
-        self.device._wait_signal(self.device.timeline_signal, self.b_timeline[self.b_next])
-        ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
-        self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
-                                     .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
-                                     .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
-        self.b_timeline[self.b_next] = self.device.timeline_value
-        self.device.timeline_value += 1
-
-  def copy_from_disk(self, dest, src, size):
-    def _get_temp_buf():
-      # Check if the next buffer is safe to be used (its signal has passed) and reserve it.
-      if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device._read_signal(self.device.timeline_signal):
-        self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
-        return (self.b[self.b_next].va_addr, self.b_next)
-      return None
-
-    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"DISK -> {self.device.dname}", enabled=PROFILE):
-      for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
-        self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
-                                     .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
-                                     .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
-        self.b_timeline[batch_info[1]] = self.device.timeline_value
-        self.device.timeline_value += 1
-
-  def copyout(self, dest:memoryview, src):
-    self.device.synchronize()
-
-    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):
-      for i in range(0, dest.nbytes, self.b[0].size):
-        self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
-                                     .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
-                                     .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
-        self.device._wait_signal(self.device.timeline_signal, self.device.timeline_value)
-        self.device.timeline_value += 1
-
-        ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
-
-  def transfer(self, dest, src, sz: int, src_dev, dest_dev):
-    src_dev._gpu_map(dest)
-
-    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE):
-      src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
-                               .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
-                               .copy(dest.va_addr, src.va_addr, sz) \
-                               .signal(src_dev.timeline_signal, src_dev.timeline_value).submit(src_dev)
-      src_dev.timeline_value += 1
-
-      if src_dev != dest_dev:
-        dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
-                                     .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
-                                     .signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
-        dest_dev.timeline_value += 1
-
-  def offset(self, buf, size:int, offset:int): return type(buf)(base=buf.base + offset, va_addr=buf.va_addr + offset, length=size, size=size)
+  def synchronize(self):
+    """
+    Synchronize all pending operations on the device.
+
+    This method ensures that all previously queued operations on the device have been completed before proceeding.
+    """
+    # override this in your device implementation
+
+# TODO: move this to each Device
+def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
+  if device is None: device = Device.DEFAULT
+  if dtype == dtypes.bfloat16:
+    # NOTE: this requires bf16 buffer support
+    return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
+  if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
+  # for CI GPU and OSX, cl_khr_fp16 isn't supported
+  # for CI LLVM, it segfaults because it can't link to the casting function
+  # CI CUDA architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
+  # PYTHON supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751
+  if dtype == dtypes.half:
+    if device == "GPU": return not CI and not OSX
+    if device in ["CUDA", "NV"]: return not CI
+    if device == "LLVM": return OSX
+    if device == "PYTHON": return sys.version_info >= (3, 12)
+  if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
+  return True
```
tinygrad/dtype.py
CHANGED
```diff
@@ -1,47 +1,79 @@
-from …
-from …
-import functools
+from __future__ import annotations
+from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable
+import math, struct, ctypes, functools
+from dataclasses import dataclass, fields
 from tinygrad.helpers import getenv
 
 ConstType = Union[float, int, bool]
 
-
-class …
+# all DTypes should only be created once
+class DTypeMetaClass(type):
+  dcache: Dict[Tuple, DType] = {}
+  def __call__(cls, *args, **kwargs):
+    if (ret:=DTypeMetaClass.dcache.get(args, None)) is not None: return ret
+    DTypeMetaClass.dcache[args] = ret = super().__call__(*args)
+    return ret
+
+@dataclass(frozen=True, eq=False)
+class DType(metaclass=DTypeMetaClass):
   priority: int  # this determines when things get upcasted
   itemsize: int
   name: str
   fmt: Optional[str]
   count: int
-  …
+  _scalar: Optional[DType]
+  @staticmethod
+  def new(priority:int, itemsize:int, name:str, fmt:Optional[str]): return DType(priority, itemsize, name, fmt, 1, None)
+  def __reduce__(self): return type(self), tuple(getattr(self, f.name) for f in fields(self))
+  def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count > 1 else "")
+  def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)
+  @property
+  def base(self): return self
+  @property
+  def vcount(self): return self.count
+  @functools.lru_cache(None)  # pylint: disable=method-cache-max-size-none
+  def vec(self, sz:int) -> DType:
+    assert self.count == 1, f"can't vectorize {self} with size {sz}"
+    if sz == 1 or self == dtypes.void: return self  # void doesn't vectorize, and sz=1 is scalar
+    return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
+  def ptr(self, local=False) -> PtrDType: return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, local, 1)
+  def scalar(self) -> DType: return self._scalar if self._scalar is not None else self
 
+@dataclass(frozen=True, eq=False)
 class PtrDType(DType):
-  …
-  def …
+  _base: DType
+  local: bool
+  v: int
+  @property
+  def base(self): return self._base
+  @functools.lru_cache(None)  # pylint: disable=method-cache-max-size-none
+  def vec(self, sz:int) -> DType:
+    assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
+    if sz == 1: return self  # sz=1 is a scalar
+    return type(self)(*tuple(sz if f.name == 'v' else (self if f.name == '_scalar' else getattr(self, f.name)) for f in fields(self)))
+  def ptr(self, local=False): raise RuntimeError("can't make a pointer from a pointer")
+  @property
+  def vcount(self): return self.v
+  def __repr__(self): return f"{self.base.__repr__()}.ptr({'local=True' if self.local else ''})" + (f'.vec({self.v})' if self.v != 1 else '')
+
+@dataclass(frozen=True, eq=False)
+class ImageDType(PtrDType):
+  shape: Tuple[int, ...] = ()  # shape of the Image
+  def ptr(self, local=False) -> PtrDType:
+    assert not local, "images can't be local"
+    return self
+  def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '')
 
 class dtypes:
   @staticmethod
-  def …
+  @functools.lru_cache(None)
+  def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats or isinstance(x, ImageDType)
   @staticmethod  # static methds on top, or bool in the type info will refer to dtypes.bool
-  def …
+  @functools.lru_cache(None)
+  def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints
   @staticmethod
-  def …
+  @functools.lru_cache(None)
+  def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints
   @staticmethod
   def from_py(x) -> DType:
     if x.__class__ is float: return dtypes.default_float
```
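`DTypeMetaClass` interns instances: constructing a `DType` with the same arguments returns the identical object, which is why the dataclass can use `eq=False` and identity comparisons stay safe. The pattern in isolation (this sketch keys the cache on `(cls, args)`, a slightly more general variant of the diff's args-only cache):

```python
class InternMeta(type):
    _cache = {}
    def __call__(cls, *args):
        # return the existing instance for these args, or build and remember one
        if (ret := InternMeta._cache.get((cls, args))) is not None: return ret
        InternMeta._cache[(cls, args)] = ret = super().__call__(*args)
        return ret

class Point(metaclass=InternMeta):
    def __init__(self, x, y): self.x, self.y = x, y

assert Point(1, 2) is Point(1, 2)  # identity, not merely equality
```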
```diff
@@ -51,23 +83,44 @@ class dtypes:
     if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
     raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
   @staticmethod
-  def as_const(val: ConstType, dtype:DType): …
+  def as_const(val: Tuple[ConstType, ...]|ConstType, dtype:DType):
+    if isinstance(val, tuple):
+      assert len(val) == dtype.count, f"mismatch {val} {dtype}"
+      return tuple(dtypes.as_const(x, dtype) for x in val)
+    # TODO: should truncate here
+    return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
+  @staticmethod
+  @functools.lru_cache(None)
+  def min(dtype:DType):
+    if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
+    return -float("inf") if dtypes.is_float(dtype) else False
+  @staticmethod
+  @functools.lru_cache(None)
+  def max(dtype:DType):
+    if dtypes.is_int(dtype): return (2**(dtype.itemsize*8-(0 if dtypes.is_unsigned(dtype) else 1)))-1
+    return float("inf") if dtypes.is_float(dtype) else True
+  @staticmethod
+  def finfo(dtype:DType) -> Tuple[int, int]:
+    """(exponent, mantissa)"""
+    if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type")
+    return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52)}[dtype]
   @staticmethod
   def fields() -> Dict[str, DType]: return DTYPES_DICT
-  …
+  void: Final[DType] = DType.new(-1, 0, "void", None)
+  bool: Final[DType] = DType.new(0, 1, "bool", '?')
+  int8: Final[DType] = DType.new(1, 1, "char", 'b')
+  uint8: Final[DType] = DType.new(2, 1, "unsigned char", 'B')
+  int16: Final[DType] = DType.new(3, 2, "short", 'h')
+  uint16: Final[DType] = DType.new(4, 2, "unsigned short", 'H')
+  int32: Final[DType] = DType.new(5, 4, "int", 'i')
+  uint32: Final[DType] = DType.new(6, 4, "unsigned int", 'I')
+  int64: Final[DType] = DType.new(7, 8, "long", 'q')
+  uint64: Final[DType] = DType.new(8, 8, "unsigned long", 'Q')
+  float16: Final[DType] = DType.new(9, 2, "half", 'e')
   # bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
-  bfloat16: Final[DType] = DType(10, 2, "__bf16", None …
-  float32: Final[DType] = DType(11, 4, "float", 'f' …
-  float64: Final[DType] = DType(12, 8, "double", 'd' …
+  bfloat16: Final[DType] = DType.new(10, 2, "__bf16", None)
+  float32: Final[DType] = DType.new(11, 4, "float", 'f')
+  float64: Final[DType] = DType.new(12, 8, "double", 'd')
 
   # dtype aliases
   half = float16; float = float32; double = float64  # noqa: E702
```
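The new `min`/`max` are the standard two's-complement and IEEE bounds, and `finfo` returns bit counts rather than epsilons. Spot-checking values that follow directly from the formulas above:

```python
from tinygrad.dtype import dtypes

assert dtypes.min(dtypes.int8) == -128 and dtypes.max(dtypes.int8) == 127
assert dtypes.min(dtypes.uint16) == 0 and dtypes.max(dtypes.uint16) == 65535
assert dtypes.min(dtypes.float32) == -float("inf")
exp, mant = dtypes.finfo(dtypes.float16)  # (exponent bits, mantissa bits)
assert 1 + exp + mant == 16               # sign + exponent + mantissa
```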
```diff
@@ -76,17 +129,25 @@ class dtypes:
 
   # NOTE: these are image dtypes
   @staticmethod
-  def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, …
+  def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, False, 1, shp)
   @staticmethod
-  def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, …
+  def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, False, 1, shp)
 
   default_float: ClassVar[DType] = float32
   default_int: ClassVar[DType] = int32
 
+  floats = (float16, bfloat16, float32, float64)
+  uints = (uint8, uint16, uint32, uint64)
+  sints = (int8, int16, int32, int64)
+  ints = uints + sints
+
 if (env_default_float := getenv("DEFAULT_FLOAT", "")):
   dtypes.default_float = getattr(dtypes, env_default_float.lower())
   assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
 
+DTypeLike = Union[str, DType]
+def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype)
+
 # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
 # we don't support weak type and complex type
 promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64], …
```
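`to_dtype` lets any `DTypeLike` argument be an attribute-name string, and because dtypes are interned by the metaclass, both spellings resolve to the identical object:

```python
from tinygrad.dtype import dtypes, to_dtype

assert to_dtype("float32") is dtypes.float32       # getattr path
assert to_dtype(dtypes.float32) is dtypes.float32  # pass-through path
assert to_dtype("half") is dtypes.float16          # aliases are class attributes too
```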
```diff
@@ -103,11 +164,25 @@ def least_upper_dtype(*ds:DType) -> DType:
 def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
 
 # HACK: staticmethods are not callable in 3.8 so we have to compare the class
-DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default')) …
+DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'void'))
+                                                                or v.__class__ is staticmethod or isinstance(v, tuple))}
 INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
+INVERSE_DTYPES_DICT['void'] = 'void'
 
 def sum_acc_dtype(dt:DType):
   # default acc dtype for sum
   if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint)
   if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
-  return least_upper_dtype(dt, dtypes.float)
+  return least_upper_dtype(dt, dtypes.float)
+
+def truncate_fp16(x):
+  try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
+  except OverflowError: return math.copysign(math.inf, x)
+
+truncate: Dict[DType, Callable] = {dtypes.bool: bool,
+  # TODO: bfloat16
+  dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
+  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
+  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
+  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value,
+  dtypes.int64: lambda x: ctypes.c_int64(x).value}
```