tinygrad 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/kernel.py +230 -190
- tinygrad/codegen/linearizer.py +278 -384
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +132 -275
- tinygrad/dtype.py +53 -37
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +28 -14
- tinygrad/helpers.py +72 -43
- tinygrad/lazy.py +141 -240
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +179 -8
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +86 -17
- tinygrad/ops.py +70 -44
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +299 -206
- tinygrad/renderer/llvmir.py +118 -123
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +130 -38
- tinygrad/runtime/ops_disk.py +45 -42
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +42 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +41 -105
- tinygrad/shape/symbolic.py +98 -95
- tinygrad/shape/view.py +137 -35
- tinygrad/tensor.py +2367 -442
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
tinygrad/codegen/uops.py
ADDED
@@ -0,0 +1,415 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable
|
3
|
+
import functools, itertools, heapq
|
4
|
+
from collections import defaultdict
|
5
|
+
from enum import Enum, auto
|
6
|
+
from dataclasses import dataclass
|
7
|
+
from tinygrad.dtype import dtypes, DType
|
8
|
+
from tinygrad.shape.symbolic import sint, Variable
|
9
|
+
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
|
10
|
+
from tinygrad.helpers import prod, DEBUG, getenv
|
11
|
+
|
12
|
+
# the order of these UOps controls the order of the toposort
|
13
|
+
class UOps(Enum):
  """Kinds of node in the UOp graph.

  Declaration order is significant: ``auto()`` numbers members in the order they appear, and
  ``UOp.cmp_tuple`` sorts nodes by ``uop.value`` first, so moving a member changes the toposort order.
  """
  # ---- ops that aren't rendered ----
  SINK = auto()
  DEFINE_GLOBAL = auto()
  DEFINE_VAR = auto()
  DEFINE_LOCAL = auto()
  DEFINE_ACC = auto()
  CONST = auto()
  SPECIAL = auto()
  NOOP = auto()
  UNMUL = auto()
  GEP = auto()
  # ---- math ops ----
  CAST = auto()
  BITCAST = auto()
  ALU = auto()
  WMMA = auto()
  # ---- memory / assignment ops ----
  LOAD = auto()
  STORE = auto()
  PHI = auto()
  # ---- control flow ops ----
  BARRIER = auto()
  IF = auto()
  RANGE = auto()
  # ---- not graph nodes: only appended by the linearizer to close scopes ----
  ENDRANGE = auto()
  ENDIF = auto()
|
28
|
+
|
29
|
+
@dataclass(eq=False)
class UOp:
  """One node of the micro-op graph: an op kind, an optional output dtype, its inputs (``vin``) and an op-specific ``arg``.

  ``eq=False`` keeps Python's default identity-based ``__eq__``/``__hash__``: two structurally identical nodes are
  distinct dict/set keys. ``tuple()`` returns the structural key when one is needed.
  """
  uop: UOps
  dtype: Optional[DType] = None
  vin: Tuple[UOp, ...] = tuple()
  arg: Any = None
  def tuple(self):
    """Structural key ``(uop, dtype, vin, arg)`` of this node."""
    return (self.uop, self.dtype, self.vin, self.arg)
  @functools.cached_property
  def cmp_tuple(self):
    """Sort key used by ``__lt__``. Cached, so it must be invalidated whenever ``vin`` is reassigned."""
    if self.uop is UOps.ALU:
      # ALU args come from different Enum classes (UnaryOps/BinaryOps/TernaryOps); neither Enum members nor the
      # classes themselves are orderable, so key on (class name, value) to make the op take part in the ordering.
      arg_key: Any = (type(self.arg).__name__, self.arg.value)
    elif self.uop is UOps.DEFINE_VAR:
      # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
      arg_key = self.arg.expr
    else:
      arg_key = self.arg
    return (self.uop.value, arg_key, self.dtype, self.vin)
  def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
  def __repr__(self):
    return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
  def cast(self, dtype): return UOp(UOps.CAST, dtype, (self,))
  def __neg__(self): return UOp.alu(UnaryOps.NEG, self)
  def __add__(self, x): return UOp.alu(BinaryOps.ADD, self, x)
  def __sub__(self, x): return UOp.alu(BinaryOps.SUB, self, x)
  def __mul__(self, x): return UOp.alu(BinaryOps.MUL, self, x)
  @staticmethod
  def max(x, y): return UOp.alu(BinaryOps.MAX, x, y)
  # min is expressed through MAX: min(x, y) == -max(-x, -y)
  @staticmethod
  def min(x, y): return -UOp.alu(BinaryOps.MAX, -x, -y)
  @staticmethod
  def const(dtype, val): return UOp(UOps.CONST, dtype, arg=dtypes.as_const(val, dtype))
  # comparisons produce bool; every other ALU op takes the dtype of its last input
  @staticmethod
  def alu(arg, *vin:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else vin[-1].dtype, vin, arg)
  # all transitive inputs of this node (the node itself is not included); cached, invalidate when vin changes
  @functools.cached_property
  def parents(self) -> Set[UOp]: return set.union(set(self.vin), *[x.parents for x in self.vin])
|
59
|
+
|
60
|
+
def uop_alu_resolve(u:UOp) -> sint:
  """Resolve an index-style UOp expression to a (possibly symbolic) integer.

  CONST and DEFINE_VAR resolve to their ``arg``; SPECIAL resolves to ``arg[2]-1`` (presumably the largest index of
  a dimension of size ``arg[2]`` -- TODO confirm); ALU ADD/MUL nodes are resolved recursively, left input first.
  Anything else raises RuntimeError.
  """
  kind = u.uop
  if kind in (UOps.CONST, UOps.DEFINE_VAR): return u.arg
  if kind is UOps.SPECIAL: return u.arg[2]-1
  if kind is UOps.ALU and u.arg in (BinaryOps.MUL, BinaryOps.ADD):
    lhs = uop_alu_resolve(u.vin[0])
    rhs = uop_alu_resolve(u.vin[1])
    return lhs * rhs if u.arg is BinaryOps.MUL else lhs + rhs
  raise RuntimeError(f"ALU resolve fail @ {u.uop}")
|
67
|
+
|
68
|
+
# *** simplification logic ***
|
69
|
+
|
70
|
+
def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool:
|
71
|
+
for k,v in pattern.items():
|
72
|
+
if k == "__name__":
|
73
|
+
if v in store and store[v] != uop: return False
|
74
|
+
store[v] = uop
|
75
|
+
elif k == "arg":
|
76
|
+
if uop.arg != v: return False
|
77
|
+
elif k == "dtype":
|
78
|
+
if isinstance(v, set):
|
79
|
+
if uop.dtype not in v: return False
|
80
|
+
elif uop.dtype != v: return False
|
81
|
+
elif k == "uop":
|
82
|
+
if isinstance(v, set):
|
83
|
+
if uop.uop not in v: return False
|
84
|
+
elif uop.uop != v: return False
|
85
|
+
elif k == "vin":
|
86
|
+
# only one if it's a tuple
|
87
|
+
# try all permutations if it's a list
|
88
|
+
# repeat if it's a dict
|
89
|
+
for vp in itertools.permutations(v) if isinstance(v, list) else ([v] if isinstance(v, tuple) else [(v,)*len(uop.vin)]):
|
90
|
+
if len(uop.vin) != len(vp) and (len(uop.vin) not in pattern.get('__allow_len__', [])): return False
|
91
|
+
new_store = store.copy()
|
92
|
+
if all(_match(uu, vv, new_store) for uu, vv in zip(uop.vin, vp)):
|
93
|
+
for k,v in new_store.items(): store[k] = v
|
94
|
+
return True
|
95
|
+
return False
|
96
|
+
return True
|
97
|
+
|
98
|
+
class PatternMatcher:
  """A list of ``(pattern, fxn)`` rewrite rules, bucketed for fast lookup.

  Rules are indexed by ``(pattern["uop"], pattern.get("arg"))``; a set-valued "uop" registers the rule under each
  member. A pattern without "arg" lands in the ``(uop, None)`` bucket and applies to every arg.
  """
  def __init__(self, patterns:List[Tuple[Dict[str, Any], Callable]]):
    self.patterns = patterns
    self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[Dict[str, Any], Callable]]] = defaultdict(list)
    # uop is required, arg is optional
    for p,fxn in self.patterns:
      uops = p["uop"]
      for uop in (uops if isinstance(uops, set) else (uops,)): self.pdict[(uop, p.get("arg", None))].append((p, fxn))

  def rewrite(self, uop:UOp) -> Optional[UOp]:
    """Return the replacement produced by the first applicable rule for ``uop``, or None if no rule rewrites it.

    Rules registered for the exact ``(uop, arg)`` are tried before the generic ``(uop, None)`` ones, each group in
    registration order. A matching rule whose function returns None declines, and the search continues.
    """
    # .get instead of indexing: indexing the defaultdict would insert an empty bucket for every (uop, arg) ever seen
    # (e.g. one per distinct CONST value), growing pdict without bound
    for p,fxn in itertools.chain(self.pdict.get((uop.uop, uop.arg), ()), self.pdict.get((uop.uop, None), ())):
      store: Dict[str, UOp] = {}
      if _match(uop, p, store) and (ret := fxn(**store)) is not None: return ret
    return None
|
115
|
+
|
116
|
+
def sum_collapse(phi_input, loop, val1, val2):
  """Rewrite ``PHI(acc, val1 + val2)`` inside ``loop`` when one addend does not depend on the loop.

  The loop-invariant addend ``v`` is multiplied by ``loop.vin[1] - loop.vin[0]`` (cast to ``v``'s dtype) and added
  outside the PHI, which keeps only the other addend: ``PHI(acc, other) + v*(end-start)``.
  Returns None (no rewrite) when both addends depend on the loop.
  """
  for v1,v2 in [(val1, val2), (val2, val1)]:
    # ``parents`` does not include the node itself, so the RANGE used directly as an addend must be excluded explicitly
    if v1 is loop or loop in v1.parents: continue
    loop_range = loop.vin[1]-loop.vin[0]
    ret = v1*loop_range.cast(v1.dtype)
    return UOp(UOps.PHI, phi_input.dtype, (phi_input, v2))+ret
  return None
|
123
|
+
|
124
|
+
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst):
  """Replace an arange-style ``WHERE((idx + mval*RANGE(loop_start, loop_end)) < compval, multconst, 0)`` with a
  loop-independent expression.

  The result is ``UNMUL(clamped*multconst, loop_end-loop_start)``. A later ``c*UNMUL(a, c)`` rule folds it back to
  ``a`` when multiplied by the same constant. Only negative ``mval`` with ``loop_start == 0`` is handled; any
  other case returns None (and warns at DEBUG>=1).
  """
  unsupported = mval.arg >= 0 or loop_start.arg != 0
  if unsupported:
    # TODO: support and test this with other mvals and loop_starts
    if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mval:{mval.arg} loop_start:{loop_start.arg}")
    return None
  shifted = UOp.alu(BinaryOps.DIV, idx-compval-mval, mval) + (loop_end-loop_start)
  clamped = UOp.min(loop_end, UOp.max(shifted, loop_start))
  scaled = clamped.cast(multconst.dtype) * multconst
  return UOp(UOps.UNMUL, multconst.dtype, (scaled, loop_end-loop_start))
|
131
|
+
|
132
|
+
# this is symbolic 2.0
# Graph-level simplifier: each entry is (pattern, rewrite_fn). Patterns use the _match dialect (a tuple "vin" matches
# in order, a list "vin" in any order, a dict "vin" applies to every input) and named captures ("__name__") are passed
# to rewrite_fn as keyword arguments. A rewrite_fn returning None means that rule does not rewrite the node.
# Rules registered for the same (uop, arg) are tried in list order, so their order below matters.
constant_folder = PatternMatcher([
  # arange loop folding (early)
  # WHERE((idx + mval*RANGE(loop_start, loop_end)) < compval, multconst, 0)
  ({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": (
    {"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin":
      [{"__name__": "idx"}, {"uop": UOps.ALU, "arg": BinaryOps.MUL,
      "vin": [{"__name__": "mval", "uop": UOps.CONST}, {"uop": UOps.RANGE, "vin": ({"__name__": "loop_start"}, {"__name__": "loop_end"})}]}]},
    {"__name__": "compval", "uop": UOps.CONST})}, {"__name__": "multconst", "uop": UOps.CONST}, {"uop": UOps.CONST, "arg": 0})}, loop_collapse),
  # sum collapse to mul (with possible GEP)
  # PHI(acc, val1 + val2) where acc is a DEFINE_ACC over a RANGE, or a GEP of such a DEFINE_ACC
  ({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.DEFINE_ACC, "vin": ({"uop": UOps.RANGE, "__name__": "loop"},)},
    {"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
  ({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.GEP,
    "vin": ({"uop": UOps.DEFINE_ACC, "vin":({"uop": UOps.RANGE, "__name__": "loop"},)},)},
    {"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
  # deal with UNMUL
  # c*UNMUL(c, v) -> v (either input order); multiplying by any other constant is left alone
  ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"uop": UOps.CONST, "__name__": "c1"},
    {"uop": UOps.UNMUL, "vin": [{"uop": UOps.CONST, "__name__": "c2"}, {"__name__": "v"}]}]},
    lambda c1,c2,v: v if c1.arg == c2.arg else None),
  # UNMUL whose first input is the constant 0 -> that 0
  ({"uop": UOps.UNMUL, "vin": ({"uop": UOps.CONST, "__name__": "zero", "arg": 0}, {})}, lambda zero: zero),
  # CAST(UNMUL(a, b)) -> UNMUL(CAST(a), b): push the cast inside
  ({"__name__": "root", "uop": UOps.CAST, "vin": ({"uop": UOps.UNMUL, "__name__": "unmul"},)},
    lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.vin[0].cast(root.dtype), unmul.vin[1]))),
  # max on special can go away (TODO: special should be variable, same thing applies)
  # MAX(c, SPECIAL) -> c when the SPECIAL's largest value (arg[2]-1) is <= c
  ({"uop": UOps.ALU, "arg": BinaryOps.MAX, "vin": [{"__name__": "c", "uop": UOps.CONST}, {"__name__": "s", "uop": UOps.SPECIAL}]},
    lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
  # const rules
  # GEP of a CONST -> CONST; CAST whose inputs are all the same CONST node -> CONST of the cast dtype
  ({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "c", "uop": UOps.CONST},)}, lambda root, c: UOp.const(root.dtype, c.arg)),
  ({"__name__": "root", "uop": UOps.CAST, "vin": {"__name__": "c", "uop": UOps.CONST}}, lambda root, c: UOp.const(root.dtype, c.arg)),
  # a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
  # PHI(acc, acc) -> the accumulator's initial value acc.arg[0]
  ({"uop": UOps.PHI, "vin": ({"uop": UOps.DEFINE_ACC, "__name__": "acc"}, {"__name__": "acc"})}, lambda acc: UOp.const(acc.dtype, acc.arg[0])),
  ({"uop": UOps.PHI, "vin": ({"uop": UOps.DEFINE_ACC, "vin": tuple()}, {"__name__": "x"})}, lambda x: x),
  ({"uop": UOps.PHI, "vin": ({"uop": UOps.CONST}, {"__name__": "x"})}, lambda x: x),
  # a DEFINE_ACC without inputs is a const + GEP on a const is the const
  ({"__name__": "root", "uop": UOps.DEFINE_ACC, "vin": tuple()}, lambda root: UOp.const(root.dtype, root.arg[0])),
  ({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "x", "uop": UOps.CONST},)}, lambda root,x: UOp.const(root.dtype, x.arg)),
  # max -2147483648
  # MAX(x, INT32_MIN) -> x for int
  ({"uop": UOps.ALU, "arg": BinaryOps.MAX, "dtype": dtypes.int, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -2147483648}]}, lambda x: x),
  # -(-x) -> x
  ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"__name__": "x"},)})}, lambda x: x),
  # x+-y -> x-y
  ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "x"}, {"__name__": "my", "uop": UOps.ALU, "arg": UnaryOps.NEG})},
    lambda x, my: x-my.vin[0]),
  # -1*x -> -x
  ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -1}]}, lambda x: -x),
  # bool < False is always false, True < bool is always false
  ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({}, {"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": False})}, lambda x: x),
  ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": True}, {})},
    lambda x: UOp.const(dtypes.bool, False)),
  # a conditional with the same results either way is a noop, also fold const conditionals
  ({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({}, {"__name__": "val"}, {"__name__": "val"})}, lambda val: val),
  ({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"__name__": "gate", "uop": UOps.CONST}, {"__name__": "c0"}, {"__name__": "c1"})},
    lambda gate, c0, c1: c0 if gate.arg else c1),
  # ** constant folding **
  # any ALU whose inputs are all CONST is evaluated at compile time with exec_alu
  ({"__name__": "root", "uop": UOps.ALU, "vin": {"uop": UOps.CONST}},
    lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.vin]))),
  # ** self folding **
  ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 0}]}, lambda x: x), # x+0 -> x or 0+x -> x
  ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 1}]}, lambda x: x), # x*1 -> x or 1*x -> x
  ({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 0})}, lambda x: x), # x-0 -> x
  ({"uop": UOps.ALU, "arg": BinaryOps.DIV, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 1})}, lambda x: x), # x/1 -> x
  ({"uop": UOps.ALU, "arg": BinaryOps.DIV, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": -1})}, lambda x: -x), # x/-1 -> -x
  # ** zero folding **
  # NOTE(review): for floats these two are not IEEE-exact (NaN/inf*0 and NaN-NaN are NaN, not 0)
  ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{}, {"__name__": "c", "uop": UOps.CONST, "arg": 0}]}, lambda c: c), # x*0 -> 0 or 0*x -> 0
  ({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"__name__": "x"})}, lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
  # ** load/store folding **
  # storing back the value just loaded from the same (buf, idx) does nothing
  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"},
    {"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})}, lambda buf, idx: UOp(UOps.NOOP)),
  # ** two stage add/sub folding **
  # (x+c1)+c2 -> x+(c1+c2)
  ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"uop": UOps.ALU, "arg": BinaryOps.ADD,
    "vin": [{"__name__": "x"}, {"__name__": "c1", "uop": UOps.CONST}]}, {"__name__": "c2", "uop": UOps.CONST}]},
    lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
  # (x-c1)+c2 -> x+(c2-c1)
  ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"uop": UOps.ALU, "arg": BinaryOps.SUB,
    "vin": ({"__name__": "x"}, {"__name__": "c1", "uop": UOps.CONST})}, {"__name__": "c2", "uop": UOps.CONST}]},
    lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.SUB, x.dtype, [c2.arg, c1.arg]))),
  # TODO: can do the invert of this (flip alt/load) when we fix double ops
  # STORE(buf, idx, WHERE(gate, alt, LOAD(buf, idx))) -> STORE(buf, idx, alt, gate): a gated store
  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.ALU, "arg": TernaryOps.WHERE,
    "vin": ({"__name__": "gate"}, {"__name__": "alt"}, {"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})})},
    lambda buf, idx, gate, alt: UOp(UOps.STORE, None, (buf, idx, alt, gate))),
  # store float4/float2 directly (remove CAST/GEP)
  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.CAST, "vin":
    tuple({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i} for i in range(4))})},
    lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))),
  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.CAST, "vin":
    tuple({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i} for i in range(2))})},
    lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))),
  # CAST-PHI-GEP -> PHI-CAST
  ({"__name__": "root", "uop": UOps.CAST, "vin":
    tuple({"uop": UOps.PHI, "vin": ({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i}, {"__name__": f"v{i}"})} for i in range(4))},
    lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
  ({"__name__": "root", "uop": UOps.CAST, "vin":
    tuple({"uop": UOps.PHI, "vin": ({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i}, {"__name__": f"v{i}"})} for i in range(2))},
    lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))),
  # NEG/CMPLT -> CMPLT
  # -x < c  ->  -c < x   (int constants only)
  ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"__name__": "x"},)},
    {"__name__": "c", "uop": UOps.CONST, "dtype": dtypes.int})},
    lambda c,x: UOp(UOps.ALU, dtypes.bool, (UOp.const(c.dtype, -c.arg), x), BinaryOps.CMPLT)),
  # cast NOOP (NOTE: it's str to deal with PtrDType)
  ({"__name__": "root", "uop": UOps.CAST}, lambda root: root.vin[0] if str(root.dtype) == str(root.vin[0].dtype) else None),
])
|
230
|
+
|
231
|
+
# *** uop graph ***
|
232
|
+
|
233
|
+
class UOpGraph:
  """A hash-consed graph of UOps that can be simplified and linearized into a program.

  Nodes are created through ``add``, which deduplicates them structurally. ``linearize`` simplifies the graph with
  ``constant_folder`` and orders it into ``uops``; that list is built lazily on first access.
  """
  def __init__(self):
    # structural key (uop, dtype, vin, arg) -> the canonical node with that key
    self.nodes: Dict[Tuple, UOp] = {}
    # linearized program; None until linearize() runs
    self._uops: Optional[List[UOp]] = None

  def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
  def __getitem__(self, index) -> UOp: return self.uops[index]

  def vars(self) -> List[Variable]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_VAR]
  def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_GLOBAL]

  @property
  def uops(self):
    """The linearized uop list (linearizes on first access)."""
    if self._uops is None: self.linearize()
    return self._uops

  def graph(self):
    """Pass the linearized uops to ``tinygrad.engine.graph.graph_uops``."""
    from tinygrad.engine.graph import graph_uops
    graph_uops(self.uops)

  def print(self):
    """Print one line per linearized uop: position, op, dtype, positions of its inputs, and arg."""
    for i,u in enumerate(self):
      print(f"{i:4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.vin]):32s} {u.arg}")

  def graph_rewrite(self, sink, pm):
    """Rewrite everything reachable from ``sink`` with ``pm`` until a whole pass changes nothing; return the new sink.

    Each pass is a memoized depth-first walk from the sink. At every node the rules are applied repeatedly (before
    the node's inputs are visited), then its inputs are rewritten, and finally the node is replaced by the node with
    the same structural key in ``self.nodes`` if one exists. Nodes are mutated in place (``vin`` is reassigned).
    ``UOPS_REWRITE=0`` disables rewriting entirely.
    """
    # recursive rewrite
    changed = getenv("UOPS_REWRITE", 1)
    run_cnt = 0
    while changed:
      changed = 0
      # memoized per pass; UOp hashing is identity-based (eq=False), so each node object is rewritten once per pass
      @functools.lru_cache
      def rewrite(u:UOp) -> UOp:
        nonlocal changed
        recurse_cnt = 0
        up = u
        # locally recursively rewrite
        while (rewritten := pm.rewrite(up)):
          assert recurse_cnt < 100, f"recursive_rewrite looped {up} <--> {rewritten}"
          up = rewritten
          recurse_cnt += 1
        changed += recurse_cnt
        # NOTE: this changes UOp, so we have to delete caches
        up.vin = tuple(rewrite(x) for x in up.vin)
        # cached_property stores its value in the instance __dict__; pop it there. hasattr() would *evaluate* the
        # property first, which for ``parents`` rebuilds the whole transitive input set only to throw it away.
        up.__dict__.pop("parents", None)
        up.__dict__.pop("cmp_tuple", None)
        # replace with cached nodes
        if found:=self.nodes.get(key:=up.tuple()): return found
        else: self.nodes[key] = up
        return up
      sink = rewrite(sink)
      run_cnt += 1
      assert run_cnt < 100, "exceeded 100 rewrite loops!"
    return sink

  def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
    """Simplify the graph and order it into ``self._uops``.

    Roots are every STORE in ``self.nodes`` plus the inputs of any SINK. The graph is rewritten with
    ``constant_folder`` (and then with ``constant_folder`` + ``extra_pm`` when given). It is then toposorted with a
    min-heap whose priority favors uops inside loops. An ENDRANGE is emitted right after the last uop of each loop
    body. DEFINE_ACCs that have inputs are placed just before their earliest input. Every IF is closed by an ENDIF
    at the very end.

    extra_pm: extra rules, applied in a second rewrite together with ``constant_folder``.
    type_verify: if True, run ``self.type_verify()`` on the result.
    """
    # NOTE: relinearizering should be okay
    #assert self._uops is None, "already linearized"

    # get sink
    _sinks: List[UOp] = []
    for u in self.nodes.values():
      if u.uop is UOps.STORE: _sinks.append(u)
      if u.uop is UOps.SINK: _sinks.extend(u.vin)
    sink = UOp(UOps.SINK, None, tuple(_sinks))
    del _sinks

    sink = self.graph_rewrite(sink, constant_folder)
    if extra_pm: sink = self.graph_rewrite(sink, PatternMatcher(constant_folder.patterns+extra_pm.patterns))

    # filter nodes that don't link to a sink
    # BFS toposort
    # graph: node -> consumers; in_degree: number of inputs not yet emitted; nodes: insertion-ordered set of
    # everything reachable from the sink
    graph: DefaultDict[UOp, List[UOp]] = defaultdict(list)
    in_degree: DefaultDict[UOp, int] = defaultdict(int)
    loops = []
    ifs = []
    nodes: Dict[UOp, None] = {}
    def add_parents(u:UOp):
      if u in nodes: return
      nodes[u] = None
      for x in u.vin:
        add_parents(x)
        in_degree[u] += 1
        graph[x].append(u)
      if u.uop is UOps.RANGE: loops.append(u)
      if u.uop is UOps.IF: ifs.append(u)
    # stores folded away into NOOPs are dropped from the roots
    sink = UOp(UOps.SINK, None, tuple(x for x in sink.vin if x.uop is not UOps.NOOP))
    add_parents(sink)

    # all transitive consumers of x (SINK excluded). A PHI is included but its own consumers are not followed.
    @functools.lru_cache(None)
    def get_recursive_children(x:UOp, include_self=False) -> Set[UOp]:
      if x.uop is UOps.SINK: return set()
      return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, True) for u in graph[x]] if x.uop is not UOps.PHI else []))
    # loop -> uops still to be emitted inside it (shrinks below); also the order ENDRANGEs are emitted in when several
    # loops close at the same uop
    loops_children = {l:get_recursive_children(l) for l in loops[::-1]}

    queue: List = []
    def push(u):
      priority = 0
      # prefer uops that are loop children (lower priority pops first; ties are broken by UOp.__lt__)
      # NOTE(review): RANGE arg[0]/arg[1] presumably encode loop order/nesting -- TODO confirm
      for l, ss in loops_children.items():
        if u in ss: priority -= l.arg[0]*1000 + l.arg[1]
      heapq.heappush(queue, (priority, u))

    for u in nodes:
      if in_degree[u] == 0: push(u)

    # test-only hook: hand the dependency graph to the uop-order fuzzer
    if getenv("FUZZ_UOPS", 0):
      from test.external.fuzz_uops import fuzz_uops
      self.fuzz_paths = fuzz_uops(graph, in_degree.copy(), loops_children)

    self._uops = []
    while queue:
      p,x = heapq.heappop(queue)
      if DEBUG >= 7: print(p,x)
      if x.uop is UOps.DEFINE_ACC and len(x.vin):
        # place the accumulator just before the earliest of its (already emitted) inputs
        idx = min([self._uops.index(l) for l in x.vin])
        self._uops.insert(idx, x)
      else:
        self._uops.append(x)
      for u, ss in loops_children.items():
        if x in ss:
          ss.remove(x)
          # last uop of this loop body emitted -> close the loop
          if len(ss) == 0: self._uops.append(UOp(UOps.ENDRANGE, None, (u,)))
      for u in graph[x]:
        in_degree[u] -= 1
        if in_degree[u] == 0: push(u)

    assert self._uops[-1].uop is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
    self._uops = self._uops[:-1]

    # TODO: ifs should be removed and just the store should be gated
    for u in ifs[::-1]: self._uops.append(UOp(UOps.ENDIF, None, (u,)))

    if type_verify: self.type_verify()

  def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None) -> UOp:
    """Return the node with exactly these fields, creating and registering it if it doesn't exist yet."""
    if found:=self.nodes.get(key:=(uop, dtype, vin, arg)): return found
    self.nodes[key] = ret = UOp(*key)
    return ret

  # *** checker functions ***

  def flops_mem(self) -> Tuple[sint, sint]:
    """Estimate (flops, bytes of memory traffic) of the linearized program.

    Each op is weighted by the product of the resolved end bounds (``vin[1]``) of the RANGEs enclosing it.
    NOTE(review): only the end bound is used, so this presumably assumes ranges start at 0 -- TODO confirm.
    ALU ops count 1 flop (MULACC counts 2) and WMMA counts 2*prod(arg[1])//32. LOAD counts its dtype's itemsize
    in bytes, and STORE counts the itemsize of the stored value.
    """
    flops: sint = 0
    mem: sint = 0
    mults: sint = 1
    # multiplier in effect outside each currently open RANGE
    mult_stack = []
    for u in self.uops:
      if u.uop is UOps.RANGE:
        mult_stack.append(mults)
        mults *= uop_alu_resolve(u.vin[1])
      elif u.uop is UOps.ENDRANGE:
        mults = mult_stack.pop(-1)
      elif u.uop is UOps.ALU:
        flops += mults * (2 if u.arg == TernaryOps.MULACC else 1)
      elif u.uop is UOps.LOAD:
        assert u.dtype is not None
        mem += u.dtype.itemsize * mults
      elif u.uop is UOps.STORE:
        assert u.vin[2].dtype is not None
        mem += u.vin[2].dtype.itemsize * mults
      elif u.uop is UOps.WMMA:
        assert u.arg[1] is not None
        flops += 2 * prod(u.arg[1]) // 32 * mults
    return flops, mem

  def type_verify(self):
    """Assert that every linearized uop is dtype-consistent.

    CONST args (and the DEFINE_ACC initial value ``arg[0]``) must have the Python type ``dtypes.as_const`` produces
    for their dtype, and CAST/BITCAST carry no arg. Unary and binary ALU ops must keep their inputs' dtype.
    Comparisons must output bool from equal-dtype inputs, and WHERE needs a bool selector and equal-dtype choices.
    """
    for u in self.uops:
      uop, arg, vin, dtype = u.uop, u.arg, u.vin, u.dtype
      if uop in {UOps.CONST, UOps.DEFINE_ACC}:
        if uop is UOps.DEFINE_ACC: arg = arg[0]
        assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
      if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
      if uop is UOps.ALU:
        if arg in UnaryOps:
          assert dtype == vin[0].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=}"
        elif arg in (BinaryOps.CMPLT, BinaryOps.CMPEQ):
          assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
          assert vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
        elif arg in BinaryOps:
          assert dtype == vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
        elif arg == TernaryOps.WHERE:
          assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}"
          assert dtype == vin[1].dtype == vin[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {vin[1].dtype=} != {vin[2].dtype=}"
|