tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -0
- tinygrad/codegen/kernel.py +572 -83
- tinygrad/codegen/linearizer.py +415 -395
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +183 -0
- tinygrad/dtype.py +113 -0
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +76 -55
- tinygrad/helpers.py +196 -89
- tinygrad/lazy.py +210 -371
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +202 -22
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +112 -32
- tinygrad/nn/state.py +136 -39
- tinygrad/ops.py +119 -202
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +353 -166
- tinygrad/renderer/llvmir.py +150 -138
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +81 -0
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +75 -0
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +24 -77
- tinygrad/runtime/ops_cuda.py +175 -89
- tinygrad/runtime/ops_disk.py +56 -33
- tinygrad/runtime/ops_gpu.py +92 -95
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +39 -60
- tinygrad/runtime/ops_metal.py +92 -74
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +86 -254
- tinygrad/shape/symbolic.py +166 -141
- tinygrad/shape/view.py +296 -0
- tinygrad/tensor.py +2619 -448
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- tinygrad-0.9.0.dist-info/METADATA +227 -0
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/assembly.py +0 -190
- tinygrad/codegen/optimizer.py +0 -379
- tinygrad/codegen/search.py +0 -72
- tinygrad/graph.py +0 -83
- tinygrad/jit.py +0 -57
- tinygrad/nn/image.py +0 -100
- tinygrad/renderer/assembly_arm64.py +0 -169
- tinygrad/renderer/assembly_ptx.py +0 -98
- tinygrad/renderer/wgsl.py +0 -53
- tinygrad/runtime/lib.py +0 -113
- tinygrad/runtime/ops_cpu.py +0 -51
- tinygrad/runtime/ops_hip.py +0 -82
- tinygrad/runtime/ops_shm.py +0 -29
- tinygrad/runtime/ops_torch.py +0 -30
- tinygrad/runtime/ops_webgpu.py +0 -45
- tinygrad-0.7.0.dist-info/METADATA +0 -212
- tinygrad-0.7.0.dist-info/RECORD +0 -40
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
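The headline restructure is visible in the list above: scheduling, realization, the JIT, and kernel search move into the new tinygrad/engine package, mlops.py is renamed to function.py, and the large runtime/autogen bindings back the new AMD, NV, and HCQ runtimes. The public entry points remain importable from the package root; a quick smoke-test sketch (mine, not part of the diff, assuming the 0.9.0 wheel is installed):

from tinygrad import Tensor, dtypes

x = Tensor.randn(2, 3, dtype=dtypes.float32)
print((x + 1).numpy())  # realizes the lazy graph and copies the result out as a numpy array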
tinygrad/dtype.py
ADDED
@@ -0,0 +1,113 @@
+from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union
+from dataclasses import dataclass
+import numpy as np  # TODO: remove numpy
+import functools
+from tinygrad.helpers import getenv
+
+ConstType = Union[float, int, bool]
+
+@dataclass(frozen=True, order=True)
+class DType:
+  priority: int  # this determines when things get upcasted
+  itemsize: int
+  name: str
+  fmt: Optional[str]
+  count: int
+  def __repr__(self): return f"dtypes.{'_'*(c:=self.count!=1)}{INVERSE_DTYPES_DICT[self.name if not c else self.scalar().name]}{str(self.count)*c}"
+  def vec(self, sz:int):
+    assert sz > 1 and self.count == 1, f"can't vectorize {self} with size {sz}"
+    return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
+  def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
+  # TODO: someday this will be removed with the "remove numpy" project
+  @property
+  def np(self) -> Optional[type]: return np.dtype(self.fmt).type if self.fmt is not None else None
+
+# dependent typing?
+@dataclass(frozen=True, repr=False)
+class ImageDType(DType):
+  shape: Tuple[int, ...]  # arbitrary arg for the dtype, used in image for the shape
+  base: DType
+  def scalar(self): return self.base
+  def vec(self, sz:int): return self.base.vec(sz)
+  def __repr__(self): return f"dtypes.{self.name}({self.shape})"
+
+# @dataclass(frozen=True, init=False, repr=False, eq=False)
+class PtrDType(DType):
+  def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
+  def __repr__(self): return f"ptr.{super().__repr__()}"
+  def __hash__(self): return super().__hash__()
+  def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
+  def __ne__(self, dt): return not (self == dt)
+
+class dtypes:
+  @staticmethod
+  def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64)
+  @staticmethod  # static methds on top, or bool in the type info will refer to dtypes.bool
+  def is_int(x: DType) -> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) or dtypes.is_unsigned(x)
+  @staticmethod
+  def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
+  @staticmethod
+  def from_np(x: type) -> DType: return DTYPES_DICT[np.dtype(x).name]
+  @staticmethod  # NOTE: isinstance(True, int) is True in python
+  def from_py(x) -> DType: return dtypes.default_float if isinstance(x, float) else dtypes.bool if isinstance(x, bool) else dtypes.default_int
+  @staticmethod
+  def as_const(val: ConstType, dtype:DType): return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
+  @staticmethod
+  def fields() -> Dict[str, DType]: return DTYPES_DICT
+  bool: Final[DType] = DType(0, 1, "bool", '?', 1)
+  int8: Final[DType] = DType(1, 1, "char", 'b', 1)
+  uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1)
+  int16: Final[DType] = DType(3, 2, "short", 'h', 1)
+  uint16: Final[DType] = DType(4, 2, "unsigned short", 'H', 1)
+  int32: Final[DType] = DType(5, 4, "int", 'i', 1)
+  uint32: Final[DType] = DType(6, 4, "unsigned int", 'I', 1)
+  int64: Final[DType] = DType(7, 8, "long", 'l', 1)
+  uint64: Final[DType] = DType(8, 8, "unsigned long", 'L', 1)
+  float16: Final[DType] = DType(9, 2, "half", 'e', 1)
+  # bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
+  bfloat16: Final[DType] = DType(10, 2, "__bf16", None, 1)
+  float32: Final[DType] = DType(11, 4, "float", 'f', 1)
+  float64: Final[DType] = DType(12, 8, "double", 'd', 1)
+
+  # dtype aliases
+  half = float16; float = float32; double = float64 # noqa: E702
+  uchar = uint8; ushort = uint16; uint = uint32; ulong = uint64 # noqa: E702
+  char = int8; short = int16; int = int32; long = int64 # noqa: E702
+
+  # NOTE: these are image dtypes
+  @staticmethod
+  def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, shape=shp, base=dtypes.float32)
+  @staticmethod
+  def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, shape=shp, base=dtypes.float32)
+
+  default_float: ClassVar[DType] = float32
+  default_int: ClassVar[DType] = int32
+
+if (env_default_float := getenv("DEFAULT_FLOAT", "")):
+  dtypes.default_float = getattr(dtypes, env_default_float.lower())
+  assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
+
+# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
+# we don't support weak type and complex type
+promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
+  dtypes.int64: [dtypes.float16, dtypes.bfloat16], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
+  dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16],
+  dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }
+
+@functools.lru_cache(None)
+def _get_recursive_parents(dtype:DType) -> Set[DType]:
+  return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
+@functools.lru_cache(None)
+def least_upper_dtype(*ds:DType) -> DType:
+  return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
+def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
+
+# HACK: staticmethods are not callable in 3.8 so we have to compare the class
+DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default')) or v.__class__ is staticmethod)}
+INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
+
+def sum_acc_dtype(dt:DType):
+  # default acc dtype for sum
+  if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint)
+  if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
+  return least_upper_dtype(dt, dtypes.float)
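A short sketch of how the promotion helpers above behave (my example, not part of the diff): least_upper_dtype intersects the recursive parents from promo_lattice and takes the minimum by priority, and sum_acc_dtype picks the accumulator dtype for reductions.

from tinygrad.dtype import dtypes, least_upper_dtype, sum_acc_dtype

assert least_upper_dtype(dtypes.int8, dtypes.uint8) == dtypes.int16       # lowest common parent in the lattice
assert least_upper_dtype(dtypes.int32, dtypes.float16) == dtypes.float16  # int32 reaches float16 through int64
assert sum_acc_dtype(dtypes.int8) == dtypes.int32                         # signed ints accumulate in at least int32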
tinygrad/engine/__init__.py
File without changes
tinygrad/engine/graph.py
ADDED
@@ -0,0 +1,100 @@
+import os, atexit, functools
+from collections import defaultdict
+from typing import List, Any, DefaultDict
+from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LoadOps, BufferOps, TernaryOps, LazyOp
+from tinygrad.device import Device
+from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters, getenv
+from tinygrad.codegen.linearizer import UOps, UOp
+from tinygrad.shape.symbolic import NumNode
+from tinygrad.lazy import LazyBuffer
+
+try: import networkx as nx
+except ImportError: pass
+
+# **** debugging and graphing ****
+
+if DEBUG >= 2:
+  def print_globalcounters():
+    if GlobalCounters.time_sum_s == 0: return
+    print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s",  # noqa: E501
+          f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms")  # noqa: E501
+  atexit.register(print_globalcounters)
+
+def save_graph(G, fn, opt=""):
+  print("saving", G, f"to {fn}.svg")
+  nx.drawing.nx_pydot.write_dot(G, f'{fn}.dot')
+  os.system(f'dot {opt} -Tsvg {fn}.dot -o {fn}.svg')
+
+G:Any = None
+def init_graph():
+  global G
+  if G is not None: return
+  G = nx.DiGraph()
+  atexit.register(functools.partial(save_graph, G, GRAPHPATH))  # -Gnslimit=100 can make it finish, but you won't like results
+
+counts: DefaultDict[type, int] = defaultdict(int)
+def nm(x):
+  if not hasattr(x, 'node_id'):
+    setattr(x, 'node_id', counts[type(x)])
+    counts[type(x)] += 1
+  return x.node_id
+
+def realized_lazybuffer(lb:'LazyBuffer', num):
+  init_graph()
+  G.nodes[nm(lb)]['style'] = '"filled,bold"'
+  G.nodes[nm(lb)]['fillcolor'] = G.nodes[nm(lb)]['fillcolor'][:-2]
+  G.nodes[nm(lb)]['label'] = '"' + G.nodes[nm(lb)]["label"].replace('"', '') + f'\nK:{num}"'
+
+top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0",
+              TernaryOps: "#c0c0c0", BufferOps: '#a0a0ff'}
+def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
+  init_graph()
+  if lb.base.realized is None and lb.base.op is LoadOps.CONST: return
+  if lb.base != lb:
+    offset = lb.st.expr_idxs([NumNode(0)] * len(lb.st.shape))[0]
+    label = f"{lb.st.shape}\n{lb.st.real_strides()}" + (f"\n{offset}" if offset != 0 else "")
+    G.add_node(nm(lb), style='"filled,dashed"', fillcolor="#80ff8080", color="black", label=label)
+    G.add_edge(nm(lb.base), nm(lb), color='#00000060')
+    lb = lb.base
+  if lb.realized is None:
+    label_append = []
+    for idx,x in enumerate(lb.srcs):
+      if nm(x) not in G.nodes: log_lazybuffer(x)
+      if x.base.realized is None and x.base.op is LoadOps.CONST:
+        label_append.append(f"\nCONST{idx} {x.base.arg}")
+      else:
+        G.add_edge(nm(x), nm(lb), color='#a0a0a0')
+    label = '"' + \
+      (str(set(x.shape for x in lb.srcs))+"\n"+str(lb.shape) if lb.op in ReduceOps else str(lb.shape)) + \
+      (f"\n{lb.dtype.name}" if lb.dtype.name != "float" else "")+f"\n{lb.op}"+(f"\n{lb.arg}" if lb.op in {LoadOps.CONST, UnaryOps.CAST} else "") + \
+      (f"\n{lb.device}" if lb.device != Device.DEFAULT else "") + ''.join(label_append) + '"'
+    G.add_node(nm(lb), style='"filled,dashed"', fillcolor=[v for k,v in top_colors.items() if lb.op in k][0] + "80", color="black", label=label)
+    if scheduled: G.nodes[nm(lb)]['shape'] = 'box'
+  else:
+    if nm(lb) not in G.nodes:
+      # realized but unseen?
+      G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
+
+def _tree(lazyop:LazyOp, cycles, cnt, prefix=""):
+  cnt[0] += 1
+  if len(lazyop.src) == 0: return [f"━━ {prefix}{lazyop.op.name} {lazyop.arg if lazyop.arg else ''}"]
+  if (lid := id(lazyop)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
+    return [f"━⬆︎ goto {cycles[id(lazyop)][0]}: {lazyop.op.name}"]
+  cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
+  lines = [f"━┳ {prefix}{lazyop.op.name} {lazyop.arg if lazyop.arg else ''}"]
+  childs = [_tree(c, cycles, cnt) for c in lazyop.src[:]]
+  for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
+  return lines + [" ┗"+childs[-1][0]] + ["  "+l for l in childs[-1][1:]]
+
+def print_tree(lazyop:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazyop, {}, [-1]))]))
+
+def graph_uops(uops:List[UOp]):
+  colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
+            UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
+            UOps.RANGE: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
+  G = nx.DiGraph()
+  for u in uops:
+    if u.uop in {UOps.ENDRANGE, UOps.ENDIF}: continue
+    G.add_node(uops.index(u), label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff"))  # noqa: E501
+    for v in u.vin: G.add_edge(uops.index(v), uops.index(u))
+  save_graph(G, f'{GRAPHPATH}.uops', '-Grankdir=LR')
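For reference, most of this module is driven indirectly: setting GRAPH=1 in the environment makes the scheduler call log_lazybuffer, and save_graph writes the accumulated .dot/.svg at exit (networkx plus graphviz's dot are required, per the guarded import above). print_tree can also be called directly; a hypothetical sketch, assuming the 0.9.0 wheel:

from tinygrad import Tensor
from tinygrad.engine.graph import print_tree

si = Tensor.ones(4, 4).contiguous().sum().schedule()[-1]  # last ScheduleItem for this tensor
for op in si.ast: print_tree(op)                          # dump the LazyOp tree of that kernel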
tinygrad/engine/jit.py
ADDED
@@ -0,0 +1,195 @@
+from __future__ import annotations
+from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional, Any
+import functools, itertools, collections
+from tinygrad.tensor import Tensor
+from tinygrad.lazy import LazyBuffer
+from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException, colored, JIT
+from tinygrad.device import Buffer, Compiled, Device
+from tinygrad.dtype import DType
+from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.shape.symbolic import Variable, sint
+from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner
+from tinygrad.engine.schedule import _internal_memory_planner
+from tinygrad.nn.state import get_parameters
+from weakref import WeakKeyDictionary
+
+def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[ExecItem]:
+  # Split JIT cache into batches for faster graph execution.
+  # This allows the accelerator to run some batches while subsequent graphs are still being updated.
+  max_batch_size = getenv("JIT_BATCH_SIZE", 32)
+  graphed_jit_cache: List[ExecItem] = []
+  current_batch: List[ExecItem] = []
+  current_device: Optional[Compiled] = None
+
+  def flush_batch():
+    nonlocal current_batch, current_device, max_batch_size
+    try:
+      if len(current_batch) <= 1 or current_device is None: raise GraphException("only one kernel doesn't graph")
+      graph_runner = current_device.graph(current_batch, input_rawbuffers, var_vals)
+      # clear jit inputs to allow their memory to be freed/reused
+      for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
+      graphed_jit_cache.append(ExecItem(graph_runner, cast(List[Optional[Buffer]], input_rawbuffers)))
+      max_batch_size *= 2
+      if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
+    except GraphException as e:
+      graphed_jit_cache.extend(current_batch)
+      if DEBUG >= 2: print(f"\tJIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
+    current_batch = []
+    current_device = None
+
+  for ji in jit_cache:
+    if ji.prg.__class__ in {EmptyOp, ViewOp}: continue
+    ji_graph_dev: Optional[Compiled] = None  # device on which the ji will be graphed. Not graphed if None.
+    if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
+    elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"HSA", "CUDA", "NV", "AMD"}:
+      ji_graph_dev = Device[ji.bufs[0].device]
+
+    graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None #type: ignore
+    can_be_graphed = ji_graph_dev and ji_graph_dev.graph
+    can_share_graph = (ji_graph_dev == current_device or (isinstance(graph_class, type) and issubclass(graph_class, MultiGraphRunner)) and
+                       type(ji_graph_dev) == type(current_device))
+    can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and can_share_graph
+    if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
+
+    if can_be_graphed: current_batch.append(ji)
+    else: graphed_jit_cache.append(ji)
+
+    current_device = ji_graph_dev
+
+  if len(current_batch) > 0: flush_batch()
+  return graphed_jit_cache
+
+def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
+  input_replace: Dict[Tuple[int, int], int] = {}
+  for j,ji in enumerate(jit_cache):
+    for i,a in enumerate(ji.bufs):
+      if a in input_rawbuffers:
+        input_replace[(j,i)] = input_rawbuffers.index(a)
+  return input_replace
+
+class GraphRunner(Runner):  # pylint: disable=abstract-method
+  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+    self.jit_cache = jit_cache
+    self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
+    self.jc_idx_with_updatable_launch_dims = []
+    self.jc_idx_with_updatable_var_vals = []
+    op_estimate: sint = 0
+    mem_estimate: sint = 0
+    for j,ji in enumerate(jit_cache):
+      op_estimate += ji.prg.op_estimate
+      mem_estimate += ji.prg.mem_estimate
+      if isinstance(ji.prg, CompiledRunner):
+        if ji.prg.p.vars: self.jc_idx_with_updatable_var_vals.append(j)
+        if (ji.prg.p.global_size and not all_int(ji.prg.p.global_size)) or (ji.prg.p.local_size and not all_int(ji.prg.p.local_size)):
+          self.jc_idx_with_updatable_launch_dims.append(j)
+    self.vars = list(var_vals.keys())
+    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0], op_estimate, mem_estimate)
+
+class MultiGraphRunner(GraphRunner):  # pylint: disable=abstract-method
+  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+    self.w_dependency_map: Dict[Any, Any] = {}
+    self.r_dependency_map: Dict[Any, List[Any]] = collections.defaultdict(list)
+    super().__init__(jit_cache, input_rawbuffers, var_vals)
+
+  def _access_resources(self, read, write, new_dependency:Any):
+    # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
+    # whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
+    wait_nodes = []
+
+    for rawbuf in read + write:
+      if id(rawbuf._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf._buf)])
+    for rawbuf in write:
+      if id(rawbuf._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf._buf)))
+
+    for rawbuf in read: self.r_dependency_map[id(rawbuf._buf)].append(new_dependency)
+    for rawbuf in write: self.w_dependency_map[id(rawbuf._buf)] = new_dependency
+    return list({id(x):x for x in wait_nodes}.values())
+
+ReturnType = TypeVar('ReturnType')
+class TinyJit(Generic[ReturnType]):
+  def __init__(self, fxn:Callable[..., ReturnType]):
+    self.fxn = fxn
+    self.reset()
+
+  def add_buffer(self, b:Buffer) -> Buffer:
+    if found:=self.buffer_replace.get(b, None): return found
+    if b.is_allocated() or b.lb_refcount > 0: return b
+    if b._base is not None:
+      self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.buffer_replace.get(b._base, b._base), offset=b.offset)
+    else:
+      self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
+    return ret
+
+  def add(self, ei:ExecItem):
+    self.jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))
+
+  def reset(self):
+    self.jit_cache: List[ExecItem] = []
+    self.input_replace: Dict[Tuple[int, int], int] = {}
+    self.extra_view_inputs: List[Tuple[int, int, str, int, DType]] = []
+    self.buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
+    self.cnt: int = 0
+
+  def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)  # add support for instance methods
+
+  def __call__(self, *args, **kwargs) -> ReturnType:
+    input_tensors: List[Tuple[Union[int, str], Tensor]] = \
+      [(cast(Union[int, str], k),v) for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor]
+    if len(input_tensors): Tensor.realize(*[x[1] for x in input_tensors])
+    lbs: List[LazyBuffer] = flatten([v.lazydata.lbs for _,v in input_tensors])
+    expected_sts_var_dtype_device = [(*x.st.unbind(), x.dtype, x.device) for x in lbs]
+    input_rawbuffers: List[Buffer] = [v.base.realized for v in lbs if v.base.realized is not None]
+    assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT"
+    var_vals: Dict[Variable, int] = merge_dicts([x[1] for x in expected_sts_var_dtype_device] + \
+                                                [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])
+
+    expected_names, expected_lbs = [x[0] for x in input_tensors], [(x[0], tuple(x[1].keys()), x[2], x[3]) for x in expected_sts_var_dtype_device]
+    if self.cnt == 0:
+      # jit ignore
+      with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
+        self.ret = self.fxn(*args, **kwargs)
+        if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
+    elif self.cnt == 1:
+      # jit capture
+      self.expected_names: List[Union[int, str]] = expected_names
+      self.expected_lbs: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]] = expected_lbs
+      with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
+        capturing.append(self)
+        self.ret = self.fxn(*args, **kwargs)
+        if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
+        capturing.clear()
+      del self.buffer_replace
+      assert len(self.jit_cache), "didn't JIT anything!"
+      if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
+
+      # track inputs that are views of buffers
+      for ji in self.jit_cache:
+        for b in ji.bufs:
+          if b is not None and b._base is not None and b._base in input_rawbuffers:
+            input_rawbuffers.append(b)
+            self.extra_view_inputs.append((input_rawbuffers.index(b.base), b.offset, b.device, b.size, b.dtype))
+
+      # memory planning (optional)
+      assigned = _internal_memory_planner([cast(List[Buffer], x.bufs) for x in self.jit_cache], debug_prefix="JIT ")
+      self.jit_cache = [ExecItem(ei.prg, [assigned.get(x,x).ensure_allocated() for x in ei.bufs if x is not None]) for ei in self.jit_cache]
+
+      # Condense the items into a graph executor.
+      if JIT < 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_rawbuffers, var_vals)
+
+      self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers)
+      if DEBUG >= 1 and len(set(self.input_replace.values())) != len(input_rawbuffers): print("WARNING: some input tensors not found")
+    elif self.cnt >= 2:
+      # jit exec
+      assert self.expected_names == expected_names, f"args mismatch in JIT: {self.expected_names=} != {expected_names}"
+      assert self.expected_lbs == expected_lbs, f"args mismatch in JIT: {self.expected_lbs=} != {expected_lbs=}"
+      for idx, offset, device, size, dtype in self.extra_view_inputs:
+        input_rawbuffers.append(Buffer(device, size, dtype, base=input_rawbuffers[idx], offset=offset).ensure_allocated())
+      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].bufs[i] = input_rawbuffers[input_idx]
+      if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
+      for ei in self.jit_cache: ei.run(var_vals, jit=True)
+
+    # clear jit inputs
+    for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
+
+    self.cnt += 1
+    return self.ret
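The cnt state machine in __call__ is the whole TinyJit contract: call 0 runs the function normally, call 1 captures ExecItems through capturing, and calls 2 and later only patch the input buffers and replay the cache (batched into device graphs by apply_graph_to_jit where the backend provides one). A minimal usage sketch, assuming the 0.9.0 wheel; step is a hypothetical function of mine:

from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x * 2).sum().realize()

for _ in range(4):                 # run 0 ignores, run 1 captures, runs 2+ replay
  out = step(Tensor.randn(8, 8))
print(out.item())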
tinygrad/engine/realize.py
ADDED
@@ -0,0 +1,191 @@
+from typing import List, Dict, Optional, cast, Generator, Tuple
+import time
+from dataclasses import dataclass, replace
+from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int
+from tinygrad.ops import BufferOps, LoadOps, LazyOp
+from tinygrad.device import Device, Buffer
+from tinygrad.shape.symbolic import Variable, sym_infer, sint
+from tinygrad.renderer import Renderer, Program
+from tinygrad.codegen.linearizer import Linearizer
+from tinygrad.engine.schedule import ScheduleItem
+
+# **************** Program Creation ****************
+
+logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
+def get_linearizer(renderer:Renderer, ast:Tuple[LazyOp, ...]) -> Linearizer:
+  if DEBUG >= 3:
+    from tinygrad.engine.graph import print_tree
+    for op in ast: print_tree(op)
+  k = Linearizer(*ast, opts=renderer)
+  k.required_optimizations()
+  if not NOOPT:
+    if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
+    if BEAM >= 1:
+      from tinygrad.engine.search import beam_search, time_linearizer, bufs_from_lin
+      kb, k_opt = Linearizer(*ast, opts=renderer), k
+      kb.required_optimizations()
+      rawbufs = bufs_from_lin(kb, allocate=False)
+      k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
+      if getenv("BEAM_COMPARE", 1):
+        # TODO: move the HC/TC/BEAM compare to beam_search so it can be optionally cached which choice is better
+        lins: List[Tuple[str, Linearizer]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
+        if used_tensor_cores:
+          lins.append(("hc", Linearizer(*ast, opts=renderer)))
+          lins[-1][1].hand_coded_optimizations()
+        timed = sorted([(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
+        if DEBUG >= 1: print("  <  ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
+        k = timed[0][1]
+        if logkerns is not None and logkerns_level > 1: logkerns.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
+  # TODO: check the correctness inline once compare_linearizer is in core
+  if logkerns is not None: logkerns.writelines([f"{(k.ast, k.applied_opts)}\n"])
+  if DEBUG >= 4: print((k.ast, k.applied_opts))  # print here to show final applied_opts for all kernels instead of just in beam_search
+  return k
+
+# **************** Runners ****************
+
+class Runner:
+  def __init__(self, display_name:str, dname:str, op_estimate:sint=0, mem_estimate:sint=0):
+    self.first_run, self.display_name, self.dname, self.op_estimate, self.mem_estimate = True, display_name, dname, op_estimate, mem_estimate
+  @property
+  def device(self): return Device[self.dname]
+  def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
+    return self(rawbufs, {} if var_vals is None else var_vals)
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
+    raise NotImplementedError("override this")
+
+class CompiledRunner(Runner):
+  def __init__(self, p:Program, precompiled:Optional[bytes]=None):
+    if DEBUG >= 4: print(p.src)
+    self.p:Program = p
+    self.lib:bytes = precompiled if precompiled is not None else Device[p.dname].compiler.compile_cached(p.src)
+    self.clprg = Device[p.dname].runtime(p.function_name, self.lib)
+    super().__init__(p.name, p.dname, p.op_estimate, p.mem_estimate)
+
+  def __reduce__(self): return self.__class__, (self.p, self.lib)
+
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
+    global_size, local_size = self.p.launch_dims(var_vals)
+    if global_size is not None and local_size is None and all_int(self.p.global_size):  # type: ignore[arg-type]
+      # TODO: this is copied from get_program
+      from tinygrad.engine.search import optimize_local_size
+      local_size = optimize_local_size(self.clprg, global_size, rawbufs)
+      global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
+      self.p = replace(self.p, global_size=global_size, local_size=local_size)
+    lra = {}
+    if global_size:
+      lra['global_size'] = global_size
+      assert len(global_size) == 3, "global size must have len 3"
+    if local_size:
+      lra['local_size'] = local_size
+      assert len(local_size) == 3, "local size must have len 3"
+    return self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.p.vars), wait=wait)
+
+class CustomOp(Runner):
+  def __init__(self, fxn):
+    self.fxn = fxn
+    super().__init__(self.fxn.__name__, "CUSTOM", 0, 0)
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): self.fxn(*rawbufs)
+
+class EmptyOp(Runner):
+  def __init__(self, buf:Buffer): super().__init__(colored(f"empty {buf.size:10d} {buf.dtype}", "yellow"), buf.device)
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): pass
+
+class ViewOp(Runner):
+  def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
+    assert rawbufs[0]._base is not None and rawbufs[0]._base == rawbufs[1].base, f"must be base {rawbufs}"
+
+class BufferCopy(Runner):
+  def __init__(self, total_sz, dest_device, src_device):
+    if total_sz >= 1e6: name = f"{type(self).__name__[6:].lower()} {total_sz/1e6:7.2f}M, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
+    else: name = f"{type(self).__name__[6:].lower()} {total_sz:8d}, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
+    super().__init__(colored(name, "yellow"), dest_device, 0, total_sz)
+  def copy(self, dest, src):
+    if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_fd') and src.nbytes >= 4096 and hasattr(src.allocator.device, 'fd'):
+      dest.allocator.copy_from_fd(dest._buf, src.allocator.device.fd, src._buf.offset, src.nbytes)
+    elif src.device.startswith("DISK") and hasattr(dest.allocator, 'as_buffer'):
+      # fast(ish) path, uses readinto in diskbuffers
+      src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
+    else:
+      dest.copyin(src.as_buffer(allow_zero_copy=True))  # may allocate a CPU buffer depending on allow_zero_copy
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
+    dest, src = rawbufs[0:2]
+    assert dest.size == src.size and dest.dtype == src.dtype, f"buffer copy mismatch, {dest.size} != {src.size}, {dest.dtype} != {src.dtype}"
+    st = time.perf_counter()
+    self.copy(dest, src)
+    if wait:
+      Device[dest.device].synchronize()
+      return time.perf_counter() - st
+
+class BufferXfer(BufferCopy):
+  def copy(self, dest, src):
+    if hasattr(dest.allocator.device, "track_cross_buffer") and hasattr(src.allocator, "track_cross_device"):
+      dest.allocator.device.track_cross_buffer.append(src)
+      src.allocator.track_cross_device.add(dest.allocator.device)
+    dest.allocator.transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.device, dest_dev=dest.allocator.device)
+
+# **************** method cache ****************
+
+method_cache: Dict[Tuple[str, Tuple[LazyOp, ...], int, bool], CompiledRunner] = {}
+def get_runner(dname:str, ast:Tuple[LazyOp, ...]) -> CompiledRunner:
+  ckey = (dname, ast, BEAM.value, False)
+  if cret:=method_cache.get(ckey): return cret
+  bkey = (dname.split(":")[0], ast, BEAM.value, True)
+  if bret:=method_cache.get(bkey):
+    method_cache[ckey] = ret = CompiledRunner(replace(bret.p, dname=dname), bret.lib)
+  else:
+    prg: Program = get_linearizer(Device[dname].renderer, ast).to_program()
+    if hasattr(prg.uops, "fuzz_paths"):
+      from test.external.fuzz_uops import UOpsFuzzerRunner
+      return UOpsFuzzerRunner(replace(prg, dname=dname))
+    method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, dname=dname))
+  return ret
+
+# **************** lowering functions ****************
+
+@dataclass(frozen=True)
+class ExecItem:
+  prg: Runner
+  bufs: List[Optional[Buffer]]
+  def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
+    bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
+    et = self.prg(bufs, var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2)
+    if do_update_stats:
+      GlobalCounters.kernel_count += 1
+      GlobalCounters.global_ops += (op_estimate:=sym_infer(self.prg.op_estimate, var_vals))
+      GlobalCounters.global_mem += (mem_estimate:=sym_infer(self.prg.mem_estimate, var_vals))
+      if et is not None: GlobalCounters.time_sum_s += et
+      if DEBUG >= 2:
+        ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
+        print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(38-ansilen(self.prg.display_name))} arg {len(self.bufs):3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " +  # noqa: E501
+              (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))  # noqa: E501
+      self.prg.first_run = False
+    return et
+
+def lower_schedule_item(si:ScheduleItem) -> ExecItem:
+  assert len(set(x.device for x in si.bufs)) == 1 or si.ast[0].op is LoadOps.COPY or getenv("USE_COPY_KERNEL")
+  if si.ast[0].op is BufferOps.STORE:
+    runner = get_runner(si.outputs[0].device, si.ast)
+    return ExecItem(runner, [si.bufs[x[0]] for x in runner.p.globals])
+  out, ast = si.outputs[0], si.ast[0]
+  if ast.op is LoadOps.COPY:
+    kernel_type = BufferCopy
+    if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
+      kernel_type = BufferXfer
+    return ExecItem(kernel_type(ast.arg, out.device, si.inputs[0].device), list(si.bufs))
+  if ast.op is LoadOps.CUSTOM: return ExecItem(CustomOp(ast.arg), list(si.bufs))
+  if ast.op is LoadOps.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
+  if ast.op is LoadOps.VIEW: return ExecItem(ViewOp(out), list(si.bufs))
+  raise RuntimeError(f"don't know how to lower {ast}")
+
+def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
+  while len(schedule): yield lower_schedule_item(schedule.pop(0))
+
+# **************** main run function ****************
+
+capturing: List = []  # put classes with an add method in here
+
+def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None, do_update_stats=True):
+  for ei in lower_schedule(schedule):
+    if len(capturing): capturing[0].add(ei)
+    ei.run(var_vals, do_update_stats=do_update_stats)