tinygrad 0.9.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that public registry.
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +78 -90
- tinygrad/codegen/linearizer.py +237 -169
- tinygrad/codegen/uops.py +278 -242
- tinygrad/device.py +147 -10
- tinygrad/dtype.py +7 -7
- tinygrad/engine/graph.py +16 -16
- tinygrad/engine/jit.py +39 -36
- tinygrad/engine/realize.py +6 -5
- tinygrad/engine/schedule.py +15 -7
- tinygrad/engine/search.py +6 -3
- tinygrad/function.py +17 -23
- tinygrad/helpers.py +77 -8
- tinygrad/lazy.py +26 -26
- tinygrad/multi.py +13 -9
- tinygrad/nn/__init__.py +1 -1
- tinygrad/nn/datasets.py +2 -1
- tinygrad/nn/state.py +3 -4
- tinygrad/ops.py +49 -16
- tinygrad/renderer/__init__.py +8 -4
- tinygrad/renderer/assembly.py +93 -100
- tinygrad/renderer/cstyle.py +47 -42
- tinygrad/renderer/llvmir.py +30 -30
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +11504 -1
- tinygrad/runtime/autogen/comgr.py +36 -10
- tinygrad/runtime/autogen/hsa.py +146 -14
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/nv_gpu.py +269 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +20 -11
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +3 -2
- tinygrad/runtime/graph/cuda.py +2 -2
- tinygrad/runtime/graph/hcq.py +122 -78
- tinygrad/runtime/ops_amd.py +302 -316
- tinygrad/runtime/ops_cuda.py +3 -3
- tinygrad/runtime/ops_disk.py +70 -5
- tinygrad/runtime/ops_gpu.py +2 -2
- tinygrad/runtime/ops_metal.py +5 -6
- tinygrad/runtime/ops_npy.py +1 -1
- tinygrad/runtime/ops_nv.py +161 -166
- tinygrad/runtime/ops_python.py +20 -16
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +5 -2
- tinygrad/shape/symbolic.py +1 -3
- tinygrad/shape/view.py +34 -19
- tinygrad/tensor.py +219 -135
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/runtime/driver/hsa.py +0 -143
- tinygrad/runtime/graph/hsa.py +0 -171
- tinygrad/runtime/ops_hsa.py +0 -278
- tinygrad-0.9.0.dist-info/RECORD +0 -60
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/device.py
CHANGED
@@ -2,9 +2,9 @@ from __future__ import annotations
 import multiprocessing
 from dataclasses import dataclass
 from collections import defaultdict
-from typing import List, Optional, Dict, Tuple, Any
-import importlib, inspect, functools, pathlib, os, ctypes
-from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
+from typing import List, Optional, Dict, Tuple, Any, cast
+import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib
+from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
 from tinygrad.dtype import DType, ImageDType
 from tinygrad.renderer import Renderer

@@ -19,15 +19,17 @@ class _Device:
   def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
   @functools.lru_cache(maxsize=None)  # this class is a singleton, pylint: disable=method-cache-max-size-none
   def __get_canonicalized_item(self, ix:str) -> Compiled:
-
-
+    assert ((cpn:=multiprocessing.current_process().name) == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], \
+      f"can only open device {ix} from parent, not {cpn}"
     x = ix.split(":")[0].upper()
-
+    ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix)  # noqa: E501
+    if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
+    return ret
   @functools.cached_property
   def DEFAULT(self) -> str:
     device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._devices, None)  # type: ignore
     if device_from_env: return device_from_env
-    for device in ["METAL", "
+    for device in ["METAL", "AMD", "CUDA", "GPU", "CLANG", "LLVM"]:
       try:
         if self[device]:
           os.environ[device] = "1"  # we set this in environment for spawned children
@@ -171,13 +173,148 @@ class Compiler:
   def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
   def compile_cached(self, src:str) -> bytes:
     if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:
-      assert not getenv("ASSERT_COMPILE"), "tried to compile with ASSERT_COMPILE set"
+      assert not getenv("ASSERT_COMPILE"), f"tried to compile with ASSERT_COMPILE set\n{src}"
       lib = self.compile(src)
       if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
     return lib

 class Compiled:
   def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
-    self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler
-    self.renderer = renderer
+    self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
+    self.renderer = renderer or Renderer()
   def synchronize(self): pass  # override this in your device
+
+# **************** for HCQ Compatible Devices ****************
+
+@contextlib.contextmanager
+def hcq_profile(dev, queue_type, enabled, desc):
+  st, en = (dev._get_signal(), dev._get_signal()) if enabled else (None, None)
+  if enabled: queue_type().timestamp(st).submit(dev)
+  try: yield (st, en)
+  finally:
+    if enabled: queue_type().timestamp(en).submit(dev)
+    if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t))
+
+class HCQCompatCompiled(Compiled):
+  def __init__(self, device:str, allocator:Allocator, renderer:Renderer, compiler:Compiler, runtime, comp_queue_t, copy_queue_t, timeline_signals):
+    self.hw_compute_queue_t, self.hw_copy_queue_t = comp_queue_t, copy_queue_t
+    self.timeline_value: int = 1
+    self.timeline_signal, self._shadow_timeline_signal = timeline_signals
+    self.sig_prof_records: List[Tuple[Any, Any, str, bool]] = []
+    self.raw_prof_records: List[Tuple[int, int, str, bool]] = []
+    if PROFILE: self._prof_setup()
+
+    from tinygrad.runtime.graph.hcq import HCQGraph
+    super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
+
+  @classmethod
+  def _read_signal(self, sig): raise NotImplementedError("need _read_signal")  # reads a value for a signal
+
+  @classmethod
+  def _read_timestamp(self, sig): raise NotImplementedError("need _read_timestamp")  # reads a timestamp for a signal
+
+  @classmethod
+  def _set_signal(self, sig, value): raise NotImplementedError("need _set_signal")  # sets a value for a signal
+
+  @classmethod
+  def _get_signal(self, value=0, **kwargs): raise NotImplementedError("need _get_signal")  # allocates a new signal
+
+  @classmethod
+  def _wait_signal(self, signal, value=0, timeout=10000): raise NotImplementedError("need _wait_signal")  # waits for a signal value
+
+  def _gpu2cpu_time(self, gpu_time, is_copy): raise NotImplementedError("need _gpu2cpu_time")
+
+  def _prof_setup(self):
+    self.profile_logger = ProfileLogger()
+
+    def _sync_queue(q_t):
+      q_t().timestamp(self.timeline_signal).signal(self.timeline_signal, self.timeline_value).submit(self)
+      self.timeline_value += 1
+      cpu_start_time = time.perf_counter_ns() / 1e3
+      self._wait_signal(self.timeline_signal, self.timeline_value - 1)
+      return cpu_start_time, self._read_timestamp(self.timeline_signal)
+    self.cpu_start_time, self.gpu_start_time = _sync_queue(self.hw_compute_queue_t)
+    self.copy_cpu_start_time, self.copy_gpu_start_time = _sync_queue(self.hw_copy_queue_t)
+
+    atexit.register(self._prof_finalize)
+
+  def _prof_process_events(self):
+    self.raw_prof_records += [(self._read_timestamp(st), self._read_timestamp(en), name, is_cp) for st, en, name, is_cp in self.sig_prof_records]
+    for st, en, _, _ in self.sig_prof_records: self.signals_pool += [st, en]  # type: ignore
+    self.sig_prof_records = []
+
+  def _prof_finalize(self):
+    for st, en, name, is_cp in self.raw_prof_records:
+      self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, ["COMPUTE", "DMA"][is_cp])]
+    del self.profile_logger
+
+  def _wrap_timeline_signal(self):
+    self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
+    self._set_signal(self.timeline_signal, 0)
+    cast(HCQCompatAllocator, self.allocator).b_timeline = [0] * len(cast(HCQCompatAllocator, self.allocator).b)
+
+class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
+  def __init__(self, device, batch_size=(2 << 20), batch_cnt=32):
+    self.device = device
+    self.b = [self._alloc(batch_size, BufferOptions(host=True)) for _ in range(batch_cnt)]
+    self.b_timeline, self.b_next = [0] * len(self.b), 0
+    super().__init__()
+
+  def copyin(self, dest, src: memoryview):
+    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"CPU -> {self.device.dname}", enabled=PROFILE):
+      for i in range(0, src.nbytes, self.b[0].size):
+        self.b_next = (self.b_next + 1) % len(self.b)
+        self.device._wait_signal(self.device.timeline_signal, self.b_timeline[self.b_next])
+        ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
+        self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
+                                     .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
+                                     .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
+        self.b_timeline[self.b_next] = self.device.timeline_value
+        self.device.timeline_value += 1
+
+  def copy_from_disk(self, dest, src, size):
+    def _get_temp_buf():
+      # Check if the next buffer is safe to be used (its signal has passed) and reserve it.
+      if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device._read_signal(self.device.timeline_signal):
+        self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
+        return (self.b[self.b_next].va_addr, self.b_next)
+      return None
+
+    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"DISK -> {self.device.dname}", enabled=PROFILE):
+      for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
+        self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
+                                     .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
+                                     .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
+        self.b_timeline[batch_info[1]] = self.device.timeline_value
+        self.device.timeline_value += 1
+
+  def copyout(self, dest:memoryview, src):
+    self.device.synchronize()
+
+    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):
+      for i in range(0, dest.nbytes, self.b[0].size):
+        self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
+                                     .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
+                                     .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
+        self.device._wait_signal(self.device.timeline_signal, self.device.timeline_value)
+        self.device.timeline_value += 1
+
+        ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
+
+  def transfer(self, dest, src, sz: int, src_dev, dest_dev):
+    src_dev._gpu_map(dest)
+
+    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE):
+      src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
+                               .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
+                               .copy(dest.va_addr, src.va_addr, sz) \
+                               .signal(src_dev.timeline_signal, src_dev.timeline_value).submit(src_dev)
+      src_dev.timeline_value += 1
+
+      if src_dev != dest_dev:
+        dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
+                                     .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
+                                     .signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
+        dest_dev.timeline_value += 1
+
+  def offset(self, buf, size:int, offset:int): return type(buf)(base=buf.base + offset, va_addr=buf.va_addr + offset, length=size, size=size)
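The new HCQ layer above wraps queue submissions in hcq_profile, which allocates two signals, stamps one before and one after the profiled work, and records the pair when PROFILE is set. Below is a minimal sketch of that control flow, assuming the import path from this diff; FakeDev and FakeQueue are hypothetical stand-ins that only implement what hcq_profile touches (real backends subclass HCQCompatCompiled):

# Hedged sketch, not from the diff: stand-in objects to make the two timestamp submissions visible.
from tinygrad.device import hcq_profile

class FakeQueue:
  def timestamp(self, sig): sig.append("ts"); return self   # pretend to enqueue a timestamp write
  def submit(self, dev): dev.submitted += 1; return self     # pretend to submit the queue to the device

class FakeDev:
  hw_copy_queue_t = FakeQueue
  def __init__(self): self.submitted, self.sig_prof_records = 0, []
  def _get_signal(self): return []                           # a "signal" is just a list here

dev = FakeDev()
with hcq_profile(dev, FakeQueue, enabled=True, desc="example") as (st, en):
  pass  # the profiled work runs here, bracketed by the two timestamp submissions
assert dev.submitted == 2 and st == ["ts"] and en == ["ts"]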
tinygrad/dtype.py
CHANGED
@@ -1,6 +1,5 @@
 from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union
 from dataclasses import dataclass
-import numpy as np  # TODO: remove numpy
 import functools
 from tinygrad.helpers import getenv

@@ -18,9 +17,6 @@ class DType:
     assert sz > 1 and self.count == 1, f"can't vectorize {self} with size {sz}"
     return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
   def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
-  # TODO: someday this will be removed with the "remove numpy" project
-  @property
-  def np(self) -> Optional[type]: return np.dtype(self.fmt).type if self.fmt is not None else None

 # dependent typing?
 @dataclass(frozen=True, repr=False)
@@ -47,9 +43,13 @@ class dtypes:
   @staticmethod
   def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
   @staticmethod
-  def
-
-
+  def from_py(x) -> DType:
+    if x.__class__ is float: return dtypes.default_float
+    if x.__class__ is int: return dtypes.default_int
+    if x.__class__ is bool: return dtypes.bool
+    # put this in the last is faster because there are more items than lists/tuples to check
+    if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
+    raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
   @staticmethod
   def as_const(val: ConstType, dtype:DType): return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
   @staticmethod
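With DType.np gone, Python values are mapped to dtypes by the new dtypes.from_py shown above. A hedged example of its behavior; default_int/default_float are int32/float32 unless overridden at runtime:

# Hedged example of dtypes.from_py from the hunk above.
from tinygrad.dtype import dtypes

assert dtypes.from_py(True) == dtypes.bool
assert dtypes.from_py(3) == dtypes.default_int           # plain ints map to default_int
assert dtypes.from_py(3.0) == dtypes.default_float       # plain floats map to default_float
assert dtypes.from_py([1, 2.0]) == dtypes.default_float  # containers take the max dtype of their elements
assert dtypes.from_py([]) == dtypes.default_float        # empty lists/tuples fall back to default_float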
tinygrad/engine/graph.py
CHANGED
@@ -1,15 +1,14 @@
-import os, atexit, functools
+import os, atexit, functools, contextlib
 from collections import defaultdict
-from typing import List, Any, DefaultDict
+from typing import List, Any, DefaultDict, Union
 from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LoadOps, BufferOps, TernaryOps, LazyOp
 from tinygrad.device import Device
 from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters, getenv
-from tinygrad.codegen.
+from tinygrad.codegen.uops import UOps, UOp, UPat
 from tinygrad.shape.symbolic import NumNode
 from tinygrad.lazy import LazyBuffer

-
-except ImportError: pass
+with contextlib.suppress(ImportError): import networkx as nx

 # **** debugging and graphing ****

@@ -61,7 +60,7 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
     for idx,x in enumerate(lb.srcs):
       if nm(x) not in G.nodes: log_lazybuffer(x)
       if x.base.realized is None and x.base.op is LoadOps.CONST:
-        label_append.append(f"\nCONST{idx} {x.base.arg}")
+        label_append.append(f"\nCONST{idx} {x.base.arg:g}")
       else:
         G.add_edge(nm(x), nm(lb), color='#a0a0a0')
     label = '"' + \
@@ -75,18 +74,19 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
     # realized but unseen?
     G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080")

-def _tree(
+def _tree(dag:Union[LazyOp, UOp, UPat], cycles, cnt):
   cnt[0] += 1
-  if
-  if (
-
+  src = dag.src if isinstance(dag.src, (list, tuple)) else [] if dag.src is None else [dag.src]
+  if len(src) == 0: return [f"━━ {dag.op} {dag.arg}"]
+  if (lid := id(dag)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
+    return [f"━⬆︎ goto {cycles[id(dag)][0]}: {dag.op}"]
   cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
-  lines = [f"━┳ {
-  childs = [_tree(c, cycles, cnt) for c in
+  lines = [f"━┳ {dag.op} {dag.arg}"]
+  childs = [_tree(c, cycles, cnt) for c in src]
   for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
   return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]]

-def print_tree(
+def print_tree(dag:Union[LazyOp, UOp, UPat]): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(dag, {}, [-1]))]))

 def graph_uops(uops:List[UOp]):
   colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
@@ -94,7 +94,7 @@ def graph_uops(uops:List[UOp]):
             UOps.RANGE: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
   G = nx.DiGraph()
   for u in uops:
-    if u.
-    G.add_node(uops.index(u), label=f"{str(u.
-    for v in u.
+    if u.op in {UOps.ENDRANGE, UOps.ENDIF}: continue
+    G.add_node(uops.index(u), label=f"{str(u.op)[5:]}{(' '+str(u.arg).replace(':', '')) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.op, "#ffffff"))  # noqa: E501
+    for v in u.src: G.add_edge(uops.index(v), uops.index(u))
   save_graph(G, f'{GRAPHPATH}.uops', '-Grankdir=LR')
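print_tree and _tree now take a LazyOp, UOp, or UPat and walk whatever .src holds (a tuple, a single node, or None). A small hedged sketch with a hypothetical Node stand-in, since the walker only relies on .op/.arg/.src duck typing; real callers pass a LazyOp or UOp:

# Hedged sketch: Node is a hypothetical stand-in, not a tinygrad class.
from tinygrad.engine.graph import print_tree

class Node:
  def __init__(self, op, arg=None, src=()): self.op, self.arg, self.src = op, arg, src

print_tree(Node("ADD", src=(Node("CONST", 1), Node("CONST", 2))))
# prints roughly:
#   0 ━┳ ADD None
#   1  ┣━━ CONST 1
#   2  ┗━━ CONST 2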
tinygrad/engine/jit.py
CHANGED
@@ -3,7 +3,7 @@ from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, O
 import functools, itertools, collections
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import LazyBuffer
-from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException, colored, JIT
+from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, ContextVar, GRAPH, BEAM, getenv, all_int, GraphException, colored, JIT
 from tinygrad.device import Buffer, Compiled, Device
 from tinygrad.dtype import DType
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -41,7 +41,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
     if ji.prg.__class__ in {EmptyOp, ViewOp}: continue
     ji_graph_dev: Optional[Compiled] = None  # device on which the ji will be graphed. Not graphed if None.
     if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
-    elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"
+    elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
       ji_graph_dev = Device[ji.bufs[0].device]

     graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None #type: ignore
@@ -82,7 +82,7 @@ class GraphRunner(Runner):  # pylint: disable=abstract-method
       if ji.prg.p.vars: self.jc_idx_with_updatable_var_vals.append(j)
       if (ji.prg.p.global_size and not all_int(ji.prg.p.global_size)) or (ji.prg.p.local_size and not all_int(ji.prg.p.local_size)):
         self.jc_idx_with_updatable_launch_dims.append(j)
-    self.vars =
+    self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
     super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0], op_estimate, mem_estimate)

 class MultiGraphRunner(GraphRunner):  # pylint: disable=abstract-method
@@ -97,15 +97,16 @@ class MultiGraphRunner(GraphRunner):  # pylint: disable=abstract-method
     wait_nodes = []

     for rawbuf in read + write:
-      if id(rawbuf._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf._buf)])
+      if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)])
     for rawbuf in write:
-      if id(rawbuf._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf._buf)))
+      if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))

-    for rawbuf in read: self.r_dependency_map[id(rawbuf._buf)].append(new_dependency)
-    for rawbuf in write: self.w_dependency_map[id(rawbuf._buf)] = new_dependency
+    for rawbuf in read: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)
+    for rawbuf in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
     return list({id(x):x for x in wait_nodes}.values())

 ReturnType = TypeVar('ReturnType')
+IN_JIT = ContextVar('IN_JIT', 0)
 class TinyJit(Generic[ReturnType]):
   def __init__(self, fxn:Callable[..., ReturnType]):
     self.fxn = fxn
@@ -134,25 +135,26 @@ class TinyJit(Generic[ReturnType]):

   def __call__(self, *args, **kwargs) -> ReturnType:
     input_tensors: List[Tuple[Union[int, str], Tensor]] = \
-      [(cast(Union[int, str],
-    if
-
-
-
-
-
-
-
-    if self.cnt == 0:
+      [(cast(Union[int, str], name),t) for name,t in itertools.chain(enumerate(args), sorted(kwargs.items())) if t.__class__ is Tensor]
+    if input_tensors: Tensor.realize(*[t for _,t in input_tensors])
+    names: List[Union[int, str]] = [name for name,_ in input_tensors]
+    lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for _,t in input_tensors])
+    st_varvals_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs]
+    input_buffers: List[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
+    assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
+    var_vals: Dict[Variable, int] = merge_dicts([varvals for _,varvals,_,_ in st_varvals_dtype_device] + \
+      [dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))])
+    st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
+    if not JIT or self.cnt == 0:
+      if IN_JIT: raise RuntimeError("having TinyJit inside another TinyJit is not supported")
       # jit ignore
-      with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
+      with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value, IN_JIT=1):
         self.ret = self.fxn(*args, **kwargs)
         if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
     elif self.cnt == 1:
       # jit capture
-      self.expected_names: List[Union[int, str]] =
-      self.
+      self.expected_names: List[Union[int, str]] = names
+      self.expected_st_vars_dtype_device: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]] = st_vars_dtype_device
       with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
         capturing.append(self)
         self.ret = self.fxn(*args, **kwargs)
@@ -160,31 +162,32 @@ class TinyJit(Generic[ReturnType]):
         capturing.clear()
       del self.buffer_replace
       assert len(self.jit_cache), "didn't JIT anything!"
-      if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(
+      if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_buffers)} inputs")

       # track inputs that are views of buffers
-      for
-      for b in
-      if b is not None and b._base is not None and b._base in
-
-      self.extra_view_inputs.append((
+      for item in self.jit_cache:
+        for b in item.bufs:
+          if b is not None and b._base is not None and b._base in input_buffers:
+            input_buffers.append(b)
+            self.extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))

       # memory planning (optional)
-      assigned = _internal_memory_planner([cast(List[Buffer],
-      self.jit_cache = [ExecItem(
+      assigned = _internal_memory_planner([cast(List[Buffer], item.bufs) for item in self.jit_cache], debug_prefix="JIT ")
+      self.jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in self.jit_cache]

       # Condense the items into a graph executor.
-      if JIT < 2: self.jit_cache = apply_graph_to_jit(self.jit_cache,
+      if JIT < 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals)

-      self.input_replace = get_input_replace(self.jit_cache,
-      if DEBUG >= 1 and len(set(self.input_replace.values())) != len(
+      self.input_replace = get_input_replace(self.jit_cache, input_buffers)
+      if DEBUG >= 1 and len(set(self.input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
     elif self.cnt >= 2:
       # jit exec
-      assert self.expected_names ==
-      assert self.
+      assert self.expected_names == names, f"args mismatch in JIT: {self.expected_names=} != {names}"
+      assert self.expected_st_vars_dtype_device == st_vars_dtype_device, \
+        f"args mismatch in JIT: {self.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
       for idx, offset, device, size, dtype in self.extra_view_inputs:
-
-      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].bufs[i] =
+        input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
+      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].bufs[i] = input_buffers[input_idx]
       if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
     for ei in self.jit_cache: ei.run(var_vals, jit=True)

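TinyJit now realizes its tensor inputs up front, keys replays on a per-input (ShapeTracker, variables, dtype, device) signature, and the new IN_JIT ContextVar makes a TinyJit called inside another TinyJit raise RuntimeError. A hedged usage sketch of the call-count behavior:

# Hedged usage sketch of TinyJit after this change.
from tinygrad import Tensor
from tinygrad.engine.jit import TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x * 2 + 1).realize()

for _ in range(3):              # call 0 runs eagerly, call 1 captures, calls >= 2 replay the captured kernels
  out = step(Tensor.rand(4, 4))
print(out.numpy().shape)        # (4, 4); shapes, dtypes and devices must match the captured signature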
tinygrad/engine/realize.py
CHANGED
@@ -1,7 +1,7 @@
 from typing import List, Dict, Optional, cast, Generator, Tuple
 import time
 from dataclasses import dataclass, replace
-from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int
+from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING
 from tinygrad.ops import BufferOps, LoadOps, LazyOp
 from tinygrad.device import Device, Buffer
 from tinygrad.shape.symbolic import Variable, sym_infer, sint
@@ -38,7 +38,7 @@ def get_linearizer(renderer:Renderer, ast:Tuple[LazyOp, ...]) -> Linearizer:
       if logkerns is not None and logkerns_level > 1: logkerns.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
   # TODO: check the correctness inline once compare_linearizer is in core
   if logkerns is not None: logkerns.writelines([f"{(k.ast, k.applied_opts)}\n"])
-  if DEBUG >=
+  if DEBUG >= 5: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
   return k

 # **************** Runners ****************
@@ -101,8 +101,9 @@ class BufferCopy(Runner):
     else: name = f"{type(self).__name__[6:].lower()} {total_sz:8d}, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
     super().__init__(colored(name, "yellow"), dest_device, 0, total_sz)
   def copy(self, dest, src):
-
-
+    disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.device, 'io_uring') and hasattr(src.allocator.device, 'fd')
+    if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_disk') and disk_supports_fast_copyout and src.nbytes >= 4096:
+      dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
     elif src.device.startswith("DISK") and hasattr(dest.allocator, 'as_buffer'):
       # fast(ish) path, uses readinto in diskbuffers
      src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
@@ -187,5 +188,5 @@ capturing: List = []  # put classes with an add method in here

 def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None, do_update_stats=True):
   for ei in lower_schedule(schedule):
-    if len(capturing): capturing[0].add(ei)
+    if len(capturing) and CAPTURING: capturing[0].add(ei)
     ei.run(var_vals, do_update_stats=do_update_stats)
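run_schedule now only hands ExecItems to a capturing JIT when the CAPTURING ContextVar is set, which is how _time_program in search.py runs its cache-clearing kernel without polluting a capture. A hedged sketch of the same idea, assuming Context accepts the CAPTURING ContextVar defined in tinygrad.helpers:

# Hedged sketch: work run under Context(CAPTURING=0) executes but is not recorded by an enclosing capture.
from tinygrad import Tensor
from tinygrad.engine.jit import TinyJit
from tinygrad.helpers import Context

@TinyJit
def step(x: Tensor) -> Tensor:
  with Context(CAPTURING=0):
    Tensor.ones(64, 64).contiguous().realize()  # runs, but is not added to the JIT cache
  return (x + 1).realize()

for _ in range(3): out = step(Tensor.rand(64, 64))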
tinygrad/engine/schedule.py
CHANGED
@@ -1,12 +1,12 @@
 import sys, pickle, atexit
 from collections import defaultdict, deque
 from dataclasses import dataclass
-from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union
+from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union, get_args
 from tinygrad.ops import LoadOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
-from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, prod, dedup, all_int, merge_dicts, getenv
+from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv
 from tinygrad.shape.symbolic import Variable
-from tinygrad.dtype import ImageDType, dtypes, DType
+from tinygrad.dtype import ConstType, ImageDType, dtypes, DType
 from tinygrad.lazy import LazyBuffer
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Buffer
@@ -56,8 +56,13 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[Laz
   if buf.op is LoadOps.CONST:
     unbound_st, st_var_vals = st.simplify().unbind()
     var_vals.update(st_var_vals)
-    if isinstance(buf.arg, Variable):
-
+    if isinstance(buf.arg, Variable):
+      val, var_val = buf.arg.unbind()
+      var_vals.__setitem__(val, var_val)
+    else:
+      assert isinstance(buf.arg, get_args(ConstType)), f"cannot create ConstBuffer with value {buf.arg}"
+      val = buf.arg
+    return LazyOp(BufferOps.CONST, (), ConstBuffer(val, buf.dtype, unbound_st))

   # if we aren't fusing it, it's a load and we add it to the inputs
   if buf.realized is not None or (buf in realizes and buf not in outputs):
@@ -69,7 +74,8 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[Laz
       # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
       if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
           ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
-        raise RuntimeError(
+        raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+                           +colored("   - a += a.T\n", "red")+colored("   + a += a.T.contiguous()", "green"))
     return LazyOp(BufferOps.LOAD, (), MemBuffer(outputs.index(assign_targets[buf]), buf.dtype, unbound_st))
   if buf not in inputs: inputs.append(buf)
   return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outputs)+inputs.index(buf), buf.dtype, unbound_st))
@@ -138,6 +144,8 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
        pass  # don't realize image to image casts. this is part of a larger problem
      else:
        realizes[buf.base] = None
+    # check all other pads for safe fusion
+    elif any(v.mask is not None for v in buf.st.views): simple_pads.add(buf.base)
    return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
  # base
  allbufs[buf] = None
@@ -308,7 +316,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
   if SAVE_SCHEDULE:
     def _save():
       print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl"))
-
+      with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
     if len(SCHEDULES) == 0: atexit.register(_save)
     SCHEDULES.extend((ps.ast for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)])
   # confirm everything was scheduled correctly
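The scheduler now raises a descriptive error when the self operand of an augmented assign is loaded through a non-contiguous view, suggesting .contiguous() as the fix. A hedged reproduction of the message's own example (the error surfaces when the kernel is scheduled, not at the += itself):

# Hedged example of the new augmented-assign error path above.
from tinygrad import Tensor

a = Tensor.rand(3, 3).realize()
try:
  a += a.T      # lazy; scheduling this kernel raises
  a.realize()
except RuntimeError as e:
  print(e)      # self operand of augmented assign must be contiguous. help: consider using .contiguous(): ...

b = Tensor.rand(3, 3).realize()
b += b.T.contiguous()   # the fix the error message suggests
b.realize()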
tinygrad/engine/search.py
CHANGED
@@ -45,7 +45,7 @@ def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_
   input_bufs = [rawbufs[i] for i,_ in car.p.globals]
   for _ in range(cnt):
     if clear_l2:
-      with Context(DEBUG=0, BEAM=0,
+      with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
     tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
     if early_stop is not None and early_stop < tms[-1]: break
   return tms
@@ -70,6 +70,9 @@ def _try_compile_linearized_w_idx(x:Tuple[int,Linearizer], compiler:Compiler) ->
     ret = None
   except TimeoutException:
     ret = None
+  except Exception as e:
+    if getenv("BEAM_STRICT_MODE"): raise e
+    ret = None
   finally:
     signal.alarm(0)
   return x[0], ret
@@ -115,7 +118,7 @@ beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
 def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=True) -> Linearizer:
   global beam_pool
   key = {"ast": lin.ast[0].key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
-  if (val:=diskcache_get("beam_search", key)) is not None
+  if not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
     ret = lin.copy()
     for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
     return ret
@@ -123,7 +126,7 @@ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=T
   beam: List[Tuple[Linearizer, float]] = [(lin, float("inf"))]
   seen_libs = set()

-  default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "
+  default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV"} else 0
   if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
     beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))

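Beam search can now skip its disk cache with IGNORE_BEAM_CACHE, surface compile errors with BEAM_STRICT_MODE, and defaults to parallel compilation on CUDA, AMD and NV; all of these are read with getenv. A hedged sketch of triggering the search itself through the BEAM ContextVar:

# Hedged sketch: running a kernel under BEAM=2 makes lowering call beam_search above;
# winning opts are stored in the beam_search disk cache unless IGNORE_BEAM_CACHE/CACHELEVEL disable it.
from tinygrad import Tensor
from tinygrad.helpers import Context

a, b = Tensor.rand(256, 256), Tensor.rand(256, 256)
with Context(BEAM=2):   # search 2 candidates per step
  (a @ b).realize()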