tinygrad 0.9.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +78 -90
- tinygrad/codegen/linearizer.py +237 -169
- tinygrad/codegen/uops.py +278 -242
- tinygrad/device.py +147 -10
- tinygrad/dtype.py +7 -7
- tinygrad/engine/graph.py +16 -16
- tinygrad/engine/jit.py +39 -36
- tinygrad/engine/realize.py +6 -5
- tinygrad/engine/schedule.py +15 -7
- tinygrad/engine/search.py +6 -3
- tinygrad/function.py +17 -23
- tinygrad/helpers.py +77 -8
- tinygrad/lazy.py +26 -26
- tinygrad/multi.py +13 -9
- tinygrad/nn/__init__.py +1 -1
- tinygrad/nn/datasets.py +2 -1
- tinygrad/nn/state.py +3 -4
- tinygrad/ops.py +49 -16
- tinygrad/renderer/__init__.py +8 -4
- tinygrad/renderer/assembly.py +93 -100
- tinygrad/renderer/cstyle.py +47 -42
- tinygrad/renderer/llvmir.py +30 -30
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +11504 -1
- tinygrad/runtime/autogen/comgr.py +36 -10
- tinygrad/runtime/autogen/hsa.py +146 -14
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/nv_gpu.py +269 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +20 -11
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +3 -2
- tinygrad/runtime/graph/cuda.py +2 -2
- tinygrad/runtime/graph/hcq.py +122 -78
- tinygrad/runtime/ops_amd.py +302 -316
- tinygrad/runtime/ops_cuda.py +3 -3
- tinygrad/runtime/ops_disk.py +70 -5
- tinygrad/runtime/ops_gpu.py +2 -2
- tinygrad/runtime/ops_metal.py +5 -6
- tinygrad/runtime/ops_npy.py +1 -1
- tinygrad/runtime/ops_nv.py +161 -166
- tinygrad/runtime/ops_python.py +20 -16
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +5 -2
- tinygrad/shape/symbolic.py +1 -3
- tinygrad/shape/view.py +34 -19
- tinygrad/tensor.py +219 -135
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/runtime/driver/hsa.py +0 -143
- tinygrad/runtime/graph/hsa.py +0 -171
- tinygrad/runtime/ops_hsa.py +0 -278
- tinygrad-0.9.0.dist-info/RECORD +0 -60
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/codegen/uops.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable
|
3
|
-
import functools, itertools, heapq
|
2
|
+
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast, TypeVar
|
3
|
+
import functools, itertools, heapq, math
|
4
4
|
from collections import defaultdict
|
5
5
|
from enum import Enum, auto
|
6
|
-
from dataclasses import dataclass
|
7
|
-
from tinygrad.dtype import dtypes, DType
|
6
|
+
from dataclasses import dataclass, field
|
7
|
+
from tinygrad.dtype import ConstType, dtypes, DType
|
8
8
|
from tinygrad.shape.symbolic import sint, Variable
|
9
9
|
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
|
10
10
|
from tinygrad.helpers import prod, DEBUG, getenv
|
@@ -12,7 +12,7 @@ from tinygrad.helpers import prod, DEBUG, getenv
|
|
12
12
|
# the order of these UOps controls the order of the toposort
|
13
13
|
class UOps(Enum):
|
14
14
|
# ops that aren't rendered
|
15
|
-
SINK = auto()
|
15
|
+
SINK = auto(); VAR = auto() # noqa: E702
|
16
16
|
DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
|
17
17
|
CONST = auto(); SPECIAL = auto() # noqa: E702
|
18
18
|
NOOP = auto(); UNMUL = auto(); GEP = auto() # noqa: E702
|
@@ -26,89 +26,119 @@ class UOps(Enum):
|
|
26
26
|
# these two are not graph nodes
|
27
27
|
ENDRANGE = auto(); ENDIF = auto() # noqa: E702
|
28
28
|
|
29
|
-
|
29
|
+
def ufix(dtype: Optional[DType], x): return UOp.const(dtype, x) if not isinstance(x, UOp) else x
|
30
|
+
@dataclass(frozen=True, eq=False)
|
30
31
|
class UOp:
|
31
|
-
|
32
|
+
op: UOps
|
32
33
|
dtype: Optional[DType] = None
|
33
|
-
|
34
|
+
src: Tuple[UOp, ...] = tuple()
|
34
35
|
arg: Any = None
|
35
|
-
def
|
36
|
+
def commutative(self) -> bool:
|
37
|
+
return self.op is UOps.ALU and self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR}
|
36
38
|
@functools.cached_property
|
37
39
|
def cmp_tuple(self):
|
38
40
|
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
|
39
|
-
return (self.
|
40
|
-
(type(self.
|
41
|
+
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
|
42
|
+
(type(self.op), self.op.value), self.dtype, self.src)
|
41
43
|
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
|
42
44
|
def __repr__(self):
|
43
|
-
return f"{str(self.
|
44
|
-
def cast(self, dtype): return UOp(UOps.CAST, dtype, (self,))
|
45
|
+
return f"{str(self.op):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.op for x in self.src]):32s} {self.arg}"
|
46
|
+
def cast(self, dtype=None): return UOp(UOps.CAST, dtype, (self,))
|
47
|
+
def name(self, name:Optional[str]): return UOp(UOps.VAR, src=(self,), arg=name)
|
45
48
|
def __neg__(self): return UOp.alu(UnaryOps.NEG, self)
|
46
|
-
def __add__(self, x): return UOp.alu(BinaryOps.ADD, self, x)
|
47
|
-
def
|
48
|
-
def
|
49
|
+
def __add__(self, x): return UOp.alu(BinaryOps.ADD, self, ufix(self.dtype, x))
|
50
|
+
def __radd__(self, x): return UOp.alu(BinaryOps.ADD, ufix(self.dtype, x), self)
|
51
|
+
def __sub__(self, x): return UOp.alu(BinaryOps.ADD, self, -ufix(self.dtype, x))
|
52
|
+
def __mul__(self, x): return UOp.alu(BinaryOps.MUL, self, ufix(self.dtype, x))
|
53
|
+
def __rmul__(self, x): return UOp.alu(BinaryOps.MUL, ufix(self.dtype, x), self)
|
54
|
+
def __floordiv__(self, x): return UOp.alu(BinaryOps.IDIV, self, ufix(self.dtype, x))
|
55
|
+
def __truediv__(self, x): return UOp.alu(BinaryOps.MUL, self, UOp.alu(UnaryOps.RECIP, ufix(self.dtype, x)))
|
56
|
+
def __mod__(self, x): return UOp.alu(BinaryOps.MOD, self, ufix(self.dtype, x))
|
57
|
+
def lt(self, x): return UOp.alu(BinaryOps.CMPLT, self, ufix(self.dtype, x))
|
58
|
+
def ge(self, x): return -self.lt(x)
|
59
|
+
def max(self, x): return UOp.alu(BinaryOps.MAX, self, x)
|
60
|
+
def min(self, x): return -UOp.alu(BinaryOps.MAX, -self, -x)
|
49
61
|
@staticmethod
|
50
|
-
|
62
|
+
@functools.lru_cache(maxsize=None)
|
63
|
+
def const(dtype:Optional[DType], b:ConstType|Variable):
|
64
|
+
if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (), b)
|
65
|
+
return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
|
51
66
|
@staticmethod
|
52
|
-
def
|
67
|
+
def alu(arg, *src:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else src[-1].dtype, src, arg)
|
53
68
|
@staticmethod
|
54
|
-
def
|
69
|
+
def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.LOAD, dtype, tuple(src)+tuple(kwargs.values()))
|
55
70
|
@staticmethod
|
56
|
-
def
|
71
|
+
def store(*src:UOp, **kwargs): return UOp(UOps.STORE, None, tuple(src)+tuple(kwargs.values()))
|
72
|
+
@staticmethod
|
73
|
+
def var(name:Optional[str]=None, dtype:Optional[DType]=None): return UOp(UOps.VAR, dtype=dtype, arg=name)
|
74
|
+
@staticmethod
|
75
|
+
def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return UOp(UOps.CONST, dtype=dtype).name(name)
|
57
76
|
@functools.cached_property
|
58
|
-
def parents(self) -> Set[UOp]: return set.union(set(self.
|
77
|
+
def parents(self) -> Set[UOp]: return set.union(set(self.src), *[x.parents for x in self.src])
|
78
|
+
@property # parents with self
|
79
|
+
def sparents(self) -> Set[UOp]: return set([self]).union(self.parents)
|
80
|
+
def vars(self) -> Set[UOp]: return set([x for x in set.union(set([self]), self.parents) if x.op is UOps.DEFINE_VAR])
|
59
81
|
|
60
82
|
def uop_alu_resolve(u:UOp) -> sint:
|
61
|
-
if u.
|
62
|
-
if u.
|
63
|
-
if u.
|
64
|
-
if u.
|
65
|
-
if u.
|
66
|
-
|
83
|
+
if u.op is UOps.CONST: return u.arg
|
84
|
+
if u.op is UOps.DEFINE_VAR: return u.arg
|
85
|
+
if u.op is UOps.SPECIAL: return u.arg[2]-1
|
86
|
+
if u.op is UOps.ALU and u.arg is BinaryOps.MUL: return uop_alu_resolve(u.src[0]) * uop_alu_resolve(u.src[1])
|
87
|
+
if u.op is UOps.ALU and u.arg is BinaryOps.SHL: return uop_alu_resolve(u.src[0]) * (2**cast(int, uop_alu_resolve(u.src[1])))
|
88
|
+
if u.op is UOps.ALU and u.arg is BinaryOps.ADD: return uop_alu_resolve(u.src[0]) + uop_alu_resolve(u.src[1])
|
89
|
+
raise RuntimeError(f"ALU resolve fail @ {u.op}")
|
67
90
|
|
68
91
|
# *** simplification logic ***
|
69
92
|
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
93
|
+
@dataclass(frozen=True)
|
94
|
+
class UPat:
|
95
|
+
op: Optional[Union[UOps, Set[UOps]]] = None
|
96
|
+
arg: Any = None
|
97
|
+
src: Optional[Union[Tuple[UPat, ...], List[UPat], UPat]] = None
|
98
|
+
name: Optional[str] = None
|
99
|
+
dtype: Optional[Union[DType, Set[DType]]] = None
|
100
|
+
allow_len: Set[int] = field(default_factory=set)
|
101
|
+
|
102
|
+
@staticmethod
|
103
|
+
def compile(u: UOp, name:Optional[str]=None) -> UPat:
|
104
|
+
if u.op is UOps.VAR: return UPat(name=name or u.arg, dtype=u.dtype) if len(u.src) == 0 else UPat.compile(u.src[0], name or u.arg)
|
105
|
+
return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(src) for src in u.src]) if u.src != () else None, name, u.dtype)
|
106
|
+
|
107
|
+
T = TypeVar("T")
|
108
|
+
def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool: return m2 not in m1 if isinstance(m1, set) else m2 != m1
|
109
|
+
|
110
|
+
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
|
111
|
+
if pat.name is not None and store.setdefault(pat.name, uop) is not uop: return False
|
112
|
+
if pat.arg is not None and __unmatch(pat.arg, uop.arg): return False
|
113
|
+
if pat.dtype is not None and uop.dtype is not None and __unmatch(pat.dtype, uop.dtype): return False
|
114
|
+
if pat.op is not None and __unmatch(pat.op, uop.op): return False
|
115
|
+
if pat.src is None: return True
|
116
|
+
# only one if it's a tuple
|
117
|
+
# try all permutations if it's a list
|
118
|
+
# repeat if it's a UPat
|
119
|
+
for vp in itertools.permutations(pat.src) if isinstance(pat.src,list) else ([pat.src] if isinstance(pat.src,tuple) else [(pat.src,)*len(uop.src)]):
|
120
|
+
if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len): return False
|
121
|
+
new_store = store.copy()
|
122
|
+
if all(_match(uu, vv, new_store) for uu, vv in zip(uop.src, vp)):
|
123
|
+
store.update(new_store)
|
124
|
+
return True
|
125
|
+
return False
|
97
126
|
|
98
127
|
class PatternMatcher:
|
99
|
-
def __init__(self, patterns:List[Tuple[
|
128
|
+
def __init__(self, patterns:List[Tuple[Union[UPat, UOp], Callable]]):
|
100
129
|
self.patterns = patterns
|
101
|
-
self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[
|
130
|
+
self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable]]] = defaultdict(list)
|
102
131
|
# uop is required, arg is optional
|
103
132
|
for p,fxn in self.patterns:
|
104
|
-
|
105
|
-
|
106
|
-
|
133
|
+
if isinstance(p, UOp): p = UPat.compile(p)
|
134
|
+
assert p.op is not None
|
135
|
+
if isinstance(p.op, set):
|
136
|
+
for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn))
|
107
137
|
else:
|
108
|
-
self.pdict[(
|
138
|
+
self.pdict[(p.op, p.arg)].append((p, fxn))
|
109
139
|
|
110
140
|
def rewrite(self, uop:UOp) -> Optional[UOp]:
|
111
|
-
for p,fxn in itertools.chain(self.pdict[(uop.
|
141
|
+
for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
|
112
142
|
store: Dict[str, UOp] = {}
|
113
143
|
if _match(uop, p, store): return fxn(**store)
|
114
144
|
return None
|
@@ -116,130 +146,169 @@ class PatternMatcher:
|
|
116
146
|
def sum_collapse(phi_input, loop, val1, val2):
|
117
147
|
for v1,v2 in [(val1, val2), (val2, val1)]:
|
118
148
|
if loop not in v1.parents:
|
119
|
-
loop_range = loop.
|
149
|
+
loop_range = loop.src[1]-loop.src[0]
|
120
150
|
ret = v1*loop_range.cast(v1.dtype)
|
121
151
|
return UOp(UOps.PHI, phi_input.dtype, (phi_input, v2))+ret
|
122
152
|
return None
|
123
153
|
|
124
|
-
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst):
|
154
|
+
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng):
|
155
|
+
if not rng.arg[2]: return None # must be a reduce
|
125
156
|
if mval.arg >= 0 or loop_start.arg != 0:
|
126
157
|
# TODO: support and test this with other mvals and loop_starts
|
127
158
|
if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mval:{mval.arg} loop_start:{loop_start.arg}")
|
128
159
|
return None
|
129
|
-
comprange = UOp.min(loop_end, UOp.max(UOp.alu(BinaryOps.
|
160
|
+
comprange = UOp.min(loop_end, UOp.max(UOp.alu(BinaryOps.IDIV, idx-compval-mval, mval) + (loop_end-loop_start), loop_start))
|
130
161
|
return UOp(UOps.UNMUL, multconst.dtype, (comprange.cast(multconst.dtype) * multconst, loop_end-loop_start))
|
131
162
|
|
132
163
|
# this is symbolic 2.0
|
133
164
|
constant_folder = PatternMatcher([
|
134
165
|
# arange loop folding (early)
|
135
|
-
(
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
166
|
+
(UPat(UOps.ALU, TernaryOps.WHERE, src=(UPat(UOps.ALU, BinaryOps.CMPLT, src=(
|
167
|
+
UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL, src=[UPat(UOps.CONST, name="mval"),
|
168
|
+
UPat(UOps.RANGE, name="rng", src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
|
169
|
+
UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))), loop_collapse),
|
170
|
+
(UPat(UOps.ALU, TernaryOps.WHERE, src=(UPat(UOps.ALU, BinaryOps.CMPLT, src=(
|
171
|
+
UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="idx"), UPat(UOps.ALU, UnaryOps.NEG, src=[
|
172
|
+
UPat(UOps.RANGE, name="rng", src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
|
173
|
+
UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))),
|
174
|
+
lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
|
140
175
|
# sum collapse to mul (with possible GEP)
|
141
|
-
(
|
142
|
-
|
143
|
-
(
|
144
|
-
|
145
|
-
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
|
176
|
+
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),
|
177
|
+
UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
|
178
|
+
(UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", src=(UPat(UOps.DEFINE_ACC, src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),)),
|
179
|
+
UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
|
146
180
|
# deal with UNMUL
|
147
|
-
(
|
148
|
-
|
149
|
-
|
150
|
-
(
|
151
|
-
({"__name__": "root", "uop": UOps.CAST, "vin": ({"uop": UOps.UNMUL, "__name__": "unmul"},)},
|
152
|
-
lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.vin[0].cast(root.dtype), unmul.vin[1]))),
|
181
|
+
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"), UPat(UOps.UNMUL, src=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]),
|
182
|
+
lambda c1,c2,v: v if c1.arg == c2.arg else None),
|
183
|
+
(UOp(UOps.UNMUL, src=(UOp.const(None, 0).name('zero'), UOp.var())), lambda zero: zero),
|
184
|
+
(UOp(UOps.UNMUL).name('unmul').cast().name('root'), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.src[0].cast(root.dtype), unmul.src[1]))),
|
153
185
|
# max on special can go away (TODO: special should be variable, same thing applies)
|
154
|
-
(
|
155
|
-
lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
|
186
|
+
(UOp.max(UOp.cvar('c'), UOp(UOps.SPECIAL).name('s')), lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
|
156
187
|
# const rules
|
157
|
-
(
|
158
|
-
(
|
188
|
+
(UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
|
189
|
+
(UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
|
159
190
|
# a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
|
160
|
-
(
|
161
|
-
(
|
162
|
-
(
|
191
|
+
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
|
192
|
+
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST),)), UPat(name="x"))), lambda x: x),
|
193
|
+
(UPat(UOps.PHI, src=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x),
|
163
194
|
# a DEFINE_ACC without inputs is a const + GEP on a const is the const
|
164
|
-
(
|
165
|
-
(
|
195
|
+
(UPat(UOps.DEFINE_ACC, name="root", src=(UPat(UOps.CONST),)), lambda root: UOp.cast(root.src[0], root.dtype)),
|
196
|
+
(UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)),
|
166
197
|
# max -2147483648
|
167
|
-
(
|
168
|
-
# -(-x) -> x
|
169
|
-
({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"__name__": "x"},)})}, lambda x: x),
|
170
|
-
# x+-y -> x-y
|
171
|
-
({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "x"}, {"__name__": "my", "uop": UOps.ALU, "arg": UnaryOps.NEG})},
|
172
|
-
lambda x, my: x-my.vin[0]),
|
173
|
-
# -1*x -> -x
|
174
|
-
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -1}]}, lambda x: -x),
|
198
|
+
(UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x),
|
175
199
|
# bool < False is always false, True < bool is always false
|
176
|
-
(
|
177
|
-
(
|
178
|
-
lambda x: UOp.const(dtypes.bool, False)),
|
200
|
+
(UOp.var().lt(UOp.const(dtypes.bool, False)), lambda: UOp.const(dtypes.bool, False)),
|
201
|
+
(UOp.const(dtypes.bool, True).lt(UOp.var()), lambda: UOp.const(dtypes.bool, False)),
|
179
202
|
# a conditional with the same results either way is a noop, also fold const conditionals
|
180
|
-
(
|
181
|
-
(
|
182
|
-
lambda gate, c0, c1: c0 if gate.arg else c1),
|
203
|
+
(UOp.alu(TernaryOps.WHERE, UOp.var(), UOp.var("val"), UOp.var("val")), lambda val: val),
|
204
|
+
(UOp.alu(TernaryOps.WHERE, UOp.cvar('gate'), UOp.var('c0'), UOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
|
183
205
|
# ** constant folding **
|
184
|
-
(
|
185
|
-
lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.vin]))),
|
206
|
+
(UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
|
186
207
|
# ** self folding **
|
187
|
-
(
|
188
|
-
(
|
189
|
-
(
|
190
|
-
(
|
191
|
-
(
|
208
|
+
(-(-UOp.var('x')), lambda x: x), # -(-x) -> x
|
209
|
+
(UOp.var('x') + 0, lambda x: x), # x+0 -> x
|
210
|
+
(UOp.var('x') - 0, lambda x: x), # x-0 -> x
|
211
|
+
(UOp.var('x') * 1, lambda x: x), # x*1 -> x
|
212
|
+
(UOp.var('x') * -1, lambda x: -x), # x*-1 -> -x
|
213
|
+
(UOp.var('x') // UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x//x -> 1
|
214
|
+
(UOp.var('x') // 1, lambda x: x), # x//1 -> x
|
215
|
+
(UOp.var('x') // -1, lambda x: -x), # x//-1 -> -x
|
216
|
+
(UOp.var('x') / UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x/x -> 1
|
217
|
+
(UOp.var('x') / UOp.cvar('c'), lambda x,c: x*exec_alu(UnaryOps.RECIP, c.dtype, [c.arg])), # x/c -> x*(1/c)
|
218
|
+
(UOp.var('x', dtype=dtypes.bool).max(UOp.const(dtypes.bool, False)), lambda x: x), # max(x, False) -> x
|
192
219
|
# ** zero folding **
|
193
|
-
|
194
|
-
|
220
|
+
#x*0 -> 0 or 0*x -> 0
|
221
|
+
#if x is nan or inf it should render the nan value.
|
222
|
+
# NOTE: this can be wrong for loaded NaN
|
223
|
+
(UOp.var('x') * 0, lambda x: UOp.const(x.dtype, float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
224
|
+
(UOp.var('x') - UOp.var('x'), lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
|
195
225
|
# ** load/store folding **
|
196
|
-
(
|
197
|
-
{"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})}, lambda buf, idx: UOp(UOps.NOOP)),
|
226
|
+
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.load(UOp.var("buf"), UOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
|
198
227
|
# ** two stage add/sub folding **
|
199
|
-
(
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
228
|
+
((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
|
229
|
+
((UOp.var('x') - UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c2.arg, -c1.arg]))),
|
230
|
+
# *** rules from symbolic ***
|
231
|
+
# two stage mul, (x*c1)*c2 = x*(c1*c2)
|
232
|
+
((UOp.var("x") * UOp.cvar("c1")) * UOp.cvar("c2"), lambda x,c1,c2: x*UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c1.arg, c2.arg]))),
|
233
|
+
# x%1 -> 0
|
234
|
+
(UOp.var("x") % UOp.const(None, 1), lambda x: UOp.const(x.dtype, 0)),
|
235
|
+
# (x*c0)+(x*c1) -> x*(c0+c1)
|
236
|
+
(UOp.var("x") * UOp.cvar("c0") + UOp.var("x") * UOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])),
|
237
|
+
# (x*c0)+(y*c0) -> (x+y)*c0
|
238
|
+
#((UOp.var("x") * UOp.cvar("c0")) + (UOp.var("y") * UOp.cvar("c0")), lambda x,y,c0: c0*(x+y)),
|
239
|
+
# (x*c0)//c0 -> x
|
240
|
+
((UOp.var("x") * UOp.cvar("c0")) // UOp.cvar("c0"), lambda x,c0: x if c0.arg != 0 else None),
|
241
|
+
# (x*x2)/x2 -> x
|
242
|
+
((UOp.var("x") * UOp.var("x2")) / UOp.var("x2"), lambda x,x2: x),
|
243
|
+
# (x//c0)//c1 -> x//(c0*c1)
|
244
|
+
((UOp.var("x") // UOp.cvar("c0")) // UOp.cvar("c1"), lambda x,c0,c1: x//UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c0.arg, c1.arg]))),
|
245
|
+
# (x/x1)/x2 -> x/(x1*x2)
|
246
|
+
((UOp.var("x") / UOp.var("x2")) / UOp.var("x3"), lambda x,x2,x3: x/(x2*x3)),
|
247
|
+
# c0 + x < c1 -> x < c1 - c0
|
248
|
+
((UOp.cvar("c0") + UOp.var("x")).lt(UOp.cvar("c1")),
|
249
|
+
lambda x,c0,c1: UOp.lt(x, UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),
|
250
|
+
# (x+x*c0)-> x*(c0+1)
|
251
|
+
(UOp.var("x") + UOp.var("x") * UOp.cvar("c0"), lambda x,c0: x*UOp.const(x.dtype, c0.arg+1)),
|
205
252
|
# TODO: can do the invert of this (flip alt/load) when we fix double ops
|
206
|
-
(
|
207
|
-
|
208
|
-
lambda buf, idx, gate, alt: UOp(UOps.STORE, None, (buf, idx, alt, gate))),
|
253
|
+
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
|
254
|
+
lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
|
209
255
|
# store float4/float2 directly (remove CAST/GEP)
|
210
|
-
(
|
211
|
-
|
212
|
-
lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))),
|
213
|
-
({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.CAST, "vin":
|
214
|
-
tuple({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i} for i in range(2))})},
|
215
|
-
lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))),
|
256
|
+
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(4)))), UOp.store),
|
257
|
+
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(2)))), UOp.store),
|
216
258
|
# CAST-PHI-GEP -> PHI-CAST
|
217
|
-
(
|
218
|
-
tuple({"uop": UOps.PHI, "vin": ({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i}, {"__name__": f"v{i}"})} for i in range(4))},
|
259
|
+
(UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
|
219
260
|
lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
|
220
|
-
(
|
221
|
-
tuple({"uop": UOps.PHI, "vin": ({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i}, {"__name__": f"v{i}"})} for i in range(2))},
|
261
|
+
(UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
|
222
262
|
lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))),
|
223
263
|
# NEG/CMPLT -> CMPLT
|
224
|
-
(
|
225
|
-
{"__name__": "c", "uop": UOps.CONST, "dtype": dtypes.int})},
|
226
|
-
lambda c,x: UOp(UOps.ALU, dtypes.bool, (UOp.const(c.dtype, -c.arg), x), BinaryOps.CMPLT)),
|
264
|
+
(UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)),
|
227
265
|
# cast NOOP (NOTE: it's str to deal with PtrDType)
|
228
|
-
(
|
266
|
+
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
|
267
|
+
# fold gated LOAD/STORE
|
268
|
+
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
|
269
|
+
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var"), UOp.var("barrier")),
|
270
|
+
lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)),
|
271
|
+
(UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var")), lambda var: var),
|
272
|
+
(UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var"), UOp.var()), lambda var: var),
|
273
|
+
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(None, 1)), UOp.store),
|
274
|
+
(UOp.store(UOp.var(), UOp.var(), UOp.var(), UOp.const(None, 0)), lambda: UOp(UOps.NOOP)),
|
275
|
+
# remove NOOPs from SINK
|
276
|
+
(UPat(UOps.SINK, name="root"),
|
277
|
+
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None)
|
229
278
|
])
|
230
279
|
|
231
280
|
# *** uop graph ***
|
232
281
|
|
282
|
+
def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int]):
|
283
|
+
if u in children: return
|
284
|
+
children[u] = []
|
285
|
+
for x in u.src:
|
286
|
+
get_children_dfs(x, children, in_degree)
|
287
|
+
children[x].append(u)
|
288
|
+
in_degree[u] = len(u.src)
|
289
|
+
|
290
|
+
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
|
291
|
+
nodes: Dict[Tuple, UOp] = {}
|
292
|
+
replace: Dict[UOp, UOp] = {}
|
293
|
+
def __inner_rewrite(n:UOp) -> UOp:
|
294
|
+
if n in replace: return replace[n]
|
295
|
+
replace_source = (n.op, n.dtype, tuple(__inner_rewrite(y) for y in n.src), n.arg)
|
296
|
+
if found := nodes.get(replace_source): replace[n] = found
|
297
|
+
else: nodes[replace_source] = replace[n] = __inner_rewrite(new_x) if (new_x := pm.rewrite(x:=UOp(*replace_source))) else x
|
298
|
+
return replace[n]
|
299
|
+
return __inner_rewrite(sink)
|
300
|
+
|
233
301
|
class UOpGraph:
|
234
|
-
def __init__(self):
|
235
|
-
self.
|
302
|
+
def __init__(self, sinks:List[UOp]):
|
303
|
+
self.sinks: List[UOp] = sinks
|
304
|
+
# used by linearizer
|
236
305
|
self._uops: Optional[List[UOp]] = None
|
237
306
|
|
238
307
|
def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
|
239
308
|
def __getitem__(self, index) -> UOp: return self.uops[index]
|
240
309
|
|
241
|
-
def vars(self) -> List[Variable]: return [x.arg for x in self.uops if x.
|
242
|
-
def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.
|
310
|
+
def vars(self) -> List[Variable]: return sorted([x.arg for x in self.uops if x.op is UOps.DEFINE_VAR], key=lambda v: v.expr)
|
311
|
+
def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.op is UOps.DEFINE_GLOBAL]
|
243
312
|
|
244
313
|
@property
|
245
314
|
def uops(self):
|
@@ -252,164 +321,131 @@ class UOpGraph:
|
|
252
321
|
|
253
322
|
def print(self):
|
254
323
|
for i,u in enumerate(self):
|
255
|
-
print(f"{i:4d} {str(u.
|
256
|
-
|
257
|
-
def graph_rewrite(self, sink, pm):
|
258
|
-
# recursive rewrite
|
259
|
-
changed = getenv("UOPS_REWRITE", 1)
|
260
|
-
run_cnt = 0
|
261
|
-
while changed:
|
262
|
-
changed = 0
|
263
|
-
@functools.lru_cache
|
264
|
-
def rewrite(u:UOp) -> UOp:
|
265
|
-
nonlocal changed
|
266
|
-
recurse_cnt = 0
|
267
|
-
up = u
|
268
|
-
# locally recursively rewrite
|
269
|
-
while (rewritten := pm.rewrite(up)):
|
270
|
-
assert recurse_cnt < 100, f"recursive_rewrite looped {up} <--> {rewritten}"
|
271
|
-
up = rewritten
|
272
|
-
recurse_cnt += 1
|
273
|
-
changed += recurse_cnt
|
274
|
-
# NOTE: this changes UOp, so we have to delete caches
|
275
|
-
up.vin = tuple(rewrite(x) for x in up.vin)
|
276
|
-
if hasattr(up, "parents"): del up.parents
|
277
|
-
if hasattr(up, "cmp_tuple"): del up.cmp_tuple
|
278
|
-
# replace with cached nodes
|
279
|
-
if found:=self.nodes.get(key:=up.tuple()): return found
|
280
|
-
else: self.nodes[key] = up
|
281
|
-
return up
|
282
|
-
sink = rewrite(sink)
|
283
|
-
run_cnt += 1
|
284
|
-
assert run_cnt < 100, "exceeded 100 rewrite loops!"
|
285
|
-
return sink
|
324
|
+
print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.src]):32s} {u.arg}")
|
286
325
|
|
287
326
|
def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
|
288
327
|
# NOTE: relinearizering should be okay
|
289
328
|
#assert self._uops is None, "already linearized"
|
329
|
+
sink = UOp(UOps.SINK, None, tuple(self.sinks))
|
290
330
|
|
291
|
-
#
|
292
|
-
|
293
|
-
|
294
|
-
if u.uop is UOps.STORE: _sinks.append(u)
|
295
|
-
if u.uop is UOps.SINK: _sinks.extend(u.vin)
|
296
|
-
sink = UOp(UOps.SINK, None, tuple(_sinks))
|
297
|
-
del _sinks
|
298
|
-
|
299
|
-
sink = self.graph_rewrite(sink, constant_folder)
|
300
|
-
if extra_pm: sink = self.graph_rewrite(sink, PatternMatcher(constant_folder.patterns+extra_pm.patterns))
|
331
|
+
# dedup all nodes and do graph rewrite
|
332
|
+
sink = graph_rewrite(sink, constant_folder)
|
333
|
+
if extra_pm: sink = graph_rewrite(sink, PatternMatcher(constant_folder.patterns+extra_pm.patterns))
|
301
334
|
|
302
335
|
# filter nodes that don't link to a sink
|
303
336
|
# BFS toposort
|
304
|
-
|
305
|
-
in_degree:
|
306
|
-
|
307
|
-
ifs = []
|
308
|
-
nodes: Dict[UOp, None] = {}
|
309
|
-
def add_parents(u:UOp):
|
310
|
-
if u in nodes: return
|
311
|
-
nodes[u] = None
|
312
|
-
for x in u.vin:
|
313
|
-
add_parents(x)
|
314
|
-
in_degree[u] += 1
|
315
|
-
graph[x].append(u)
|
316
|
-
if u.uop is UOps.RANGE: loops.append(u)
|
317
|
-
if u.uop is UOps.IF: ifs.append(u)
|
318
|
-
sink = UOp(UOps.SINK, None, tuple(x for x in sink.vin if x.uop is not UOps.NOOP))
|
319
|
-
add_parents(sink)
|
337
|
+
children: Dict[UOp, List[UOp]] = {}
|
338
|
+
in_degree: Dict[UOp, int] = {}
|
339
|
+
get_children_dfs(sink, children, in_degree)
|
320
340
|
|
321
341
|
@functools.lru_cache(None)
|
322
|
-
def get_recursive_children(x:UOp, include_self=False) -> Set[UOp]:
|
323
|
-
if x.
|
324
|
-
return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, True) for u in
|
325
|
-
|
342
|
+
def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
|
343
|
+
if x.op is UOps.SINK: return set()
|
344
|
+
return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))
|
345
|
+
|
346
|
+
# scope children impact the toposort and END* insertion
|
347
|
+
end_for_uop = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
|
348
|
+
loops, ifs = [x for x in in_degree if x.op is UOps.RANGE], [x for x in in_degree if x.op is UOps.IF]
|
349
|
+
scope_children = {p:get_recursive_children(p, end_for_uop[p.op][0]) for p in (loops+ifs)[::-1]}
|
326
350
|
|
327
|
-
queue:
|
328
|
-
def push(u):
|
351
|
+
queue:List[Tuple[int, UOp]] = []
|
352
|
+
def push(u:UOp):
|
329
353
|
priority = 0
|
330
354
|
# prefer uops that are loop children
|
331
|
-
for l, ss in
|
332
|
-
if u in ss: priority -= l.arg[0]*1000 + l.arg[1]
|
355
|
+
for l, ss in scope_children.items():
|
356
|
+
if l.op is UOps.RANGE and u in ss: priority -= l.arg[0]*1000 + l.arg[1]
|
333
357
|
heapq.heappush(queue, (priority, u))
|
334
358
|
|
335
|
-
for u in
|
359
|
+
for u in children:
|
336
360
|
if in_degree[u] == 0: push(u)
|
337
361
|
|
338
362
|
if getenv("FUZZ_UOPS", 0):
|
339
363
|
from test.external.fuzz_uops import fuzz_uops
|
340
|
-
self.fuzz_paths = fuzz_uops(
|
364
|
+
self.fuzz_paths = fuzz_uops(children, in_degree.copy(), scope_children)
|
341
365
|
|
342
366
|
self._uops = []
|
343
367
|
while queue:
|
344
368
|
p,x = heapq.heappop(queue)
|
345
369
|
if DEBUG >= 7: print(p,x)
|
346
|
-
if x.
|
347
|
-
idx = min([self._uops.index(l) for l in x.
|
370
|
+
if x.op is UOps.DEFINE_ACC and len(x.src) > 1:
|
371
|
+
idx = min([self._uops.index(l) for l in x.src if l.op is UOps.RANGE])
|
348
372
|
self._uops.insert(idx, x)
|
349
373
|
else:
|
350
374
|
self._uops.append(x)
|
351
|
-
for u, ss in
|
375
|
+
for u, ss in scope_children.items():
|
352
376
|
if x in ss:
|
353
377
|
ss.remove(x)
|
354
|
-
if len(ss) == 0: self._uops.append(UOp(
|
355
|
-
for u in
|
378
|
+
if len(ss) == 0: self._uops.append(UOp(end_for_uop[u.op][1], None, (u,)))
|
379
|
+
for u in children[x]:
|
356
380
|
in_degree[u] -= 1
|
357
381
|
if in_degree[u] == 0: push(u)
|
358
382
|
|
359
|
-
assert self._uops[-1].
|
383
|
+
assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
|
360
384
|
self._uops = self._uops[:-1]
|
361
385
|
|
362
|
-
# TODO: ifs should be removed and just the store should be gated
|
363
|
-
for u in ifs[::-1]: self._uops.append(UOp(UOps.ENDIF, None, (u,)))
|
364
|
-
|
365
386
|
if type_verify: self.type_verify()
|
366
387
|
|
367
|
-
def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None) -> UOp:
|
368
|
-
if found:=self.nodes.get(key:=(uop, dtype, vin, arg)): return found
|
369
|
-
self.nodes[key] = ret = UOp(*key)
|
370
|
-
return ret
|
371
|
-
|
372
388
|
# *** checker functions ***
|
373
389
|
|
374
|
-
def flops_mem(self) -> Tuple[sint, sint]:
|
390
|
+
def flops_mem(self, ignore_indexing=False) -> Tuple[sint, sint]:
|
375
391
|
flops: sint = 0
|
376
392
|
mem: sint = 0
|
377
393
|
mults: sint = 1
|
378
394
|
mult_stack = []
|
395
|
+
dont_count: Set[UOp] = set()
|
396
|
+
if ignore_indexing:
|
397
|
+
for u in self.uops:
|
398
|
+
if u.op is UOps.LOAD:
|
399
|
+
dont_count = dont_count.union(u.src[1].sparents)
|
400
|
+
if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents)
|
401
|
+
elif u.op is UOps.STORE:
|
402
|
+
dont_count = dont_count.union(u.src[1].sparents)
|
403
|
+
if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
|
379
404
|
for u in self.uops:
|
380
|
-
if u.
|
405
|
+
if u.op is UOps.RANGE:
|
381
406
|
mult_stack.append(mults)
|
382
|
-
mults *= uop_alu_resolve(u.
|
383
|
-
elif u.
|
407
|
+
mults *= uop_alu_resolve(u.src[1])
|
408
|
+
elif u.op is UOps.ENDRANGE:
|
384
409
|
mults = mult_stack.pop(-1)
|
385
|
-
elif u.
|
386
|
-
flops += mults * (2 if u.arg == TernaryOps.MULACC else 1)
|
387
|
-
elif u.uop is UOps.LOAD:
|
410
|
+
elif u.op is UOps.LOAD:
|
388
411
|
assert u.dtype is not None
|
389
412
|
mem += u.dtype.itemsize * mults
|
390
|
-
elif u.
|
391
|
-
assert u.
|
392
|
-
mem += u.
|
393
|
-
elif u.
|
413
|
+
elif u.op is UOps.STORE:
|
414
|
+
assert u.src[2].dtype is not None
|
415
|
+
mem += u.src[2].dtype.itemsize * mults
|
416
|
+
elif u.op is UOps.ALU and u not in dont_count:
|
417
|
+
flops += mults * (2 if u.arg == TernaryOps.MULACC else 1)
|
418
|
+
elif u.op is UOps.WMMA and u not in dont_count:
|
394
419
|
assert u.arg[1] is not None
|
395
420
|
flops += 2 * prod(u.arg[1]) // 32 * mults
|
396
421
|
return flops, mem
|
397
422
|
|
398
423
|
def type_verify(self):
|
399
424
|
for u in self.uops:
|
400
|
-
uop, arg,
|
401
|
-
if uop in
|
402
|
-
if uop is UOps.DEFINE_ACC:
|
425
|
+
uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
|
426
|
+
if uop in (UOps.CONST, UOps.DEFINE_ACC):
|
427
|
+
if uop is UOps.DEFINE_ACC:
|
428
|
+
assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
|
429
|
+
arg = src[0].arg
|
403
430
|
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
|
404
431
|
if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
|
432
|
+
if uop is UOps.LOAD and len(src) > 2 and src[2].op not in {UOps.IF, UOps.BARRIER}: assert src[2].dtype == dtypes.bool
|
433
|
+
if uop is UOps.STORE and len(src) == 4: assert src[3].dtype == dtypes.bool
|
405
434
|
if uop is UOps.ALU:
|
406
435
|
if arg in UnaryOps:
|
407
|
-
assert dtype ==
|
408
|
-
elif arg in (BinaryOps.CMPLT, BinaryOps.
|
436
|
+
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
437
|
+
elif arg in (BinaryOps.CMPLT, BinaryOps.CMPNE):
|
409
438
|
assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
|
410
|
-
assert
|
439
|
+
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
440
|
+
elif arg is BinaryOps.IDIV:
|
441
|
+
assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), \
|
442
|
+
f"input dtype mismatch {dtypes.int} != {src[0].dtype=} != {src[1].dtype=}"
|
443
|
+
assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
|
444
|
+
elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
|
445
|
+
# the distance to shift isn't typechecked
|
446
|
+
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
411
447
|
elif arg in BinaryOps:
|
412
|
-
assert dtype ==
|
448
|
+
assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
413
449
|
elif arg == TernaryOps.WHERE:
|
414
|
-
assert
|
415
|
-
assert dtype ==
|
450
|
+
assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
|
451
|
+
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
|