tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +248 -115
- tinygrad/codegen/lowerer.py +215 -0
- tinygrad/codegen/transcendental.py +310 -0
- tinygrad/codegen/uopgraph.py +622 -0
- tinygrad/codegen/uops.py +235 -393
- tinygrad/device.py +428 -69
- tinygrad/dtype.py +18 -4
- tinygrad/engine/graph.py +19 -32
- tinygrad/engine/jit.py +148 -70
- tinygrad/engine/realize.py +127 -51
- tinygrad/engine/schedule.py +259 -216
- tinygrad/engine/search.py +29 -22
- tinygrad/function.py +9 -0
- tinygrad/helpers.py +87 -49
- tinygrad/lazy.py +34 -35
- tinygrad/multi.py +41 -36
- tinygrad/nn/__init__.py +39 -22
- tinygrad/nn/state.py +3 -3
- tinygrad/ops.py +63 -62
- tinygrad/renderer/__init__.py +43 -21
- tinygrad/renderer/assembly.py +104 -106
- tinygrad/renderer/cstyle.py +87 -60
- tinygrad/renderer/llvmir.py +21 -30
- tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/kfd.py +32 -0
- tinygrad/runtime/autogen/libc.py +4260 -0
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/graph/clang.py +2 -2
- tinygrad/runtime/graph/cuda.py +8 -11
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +18 -15
- tinygrad/runtime/ops_amd.py +197 -305
- tinygrad/runtime/ops_clang.py +2 -2
- tinygrad/runtime/ops_cuda.py +36 -94
- tinygrad/runtime/ops_disk.py +3 -7
- tinygrad/runtime/ops_gpu.py +4 -2
- tinygrad/runtime/ops_hip.py +70 -0
- tinygrad/runtime/ops_metal.py +38 -27
- tinygrad/runtime/ops_nv.py +283 -363
- tinygrad/runtime/ops_python.py +26 -30
- tinygrad/runtime/support/compiler_cuda.py +78 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/shape/shapetracker.py +5 -14
- tinygrad/shape/symbolic.py +4 -8
- tinygrad/shape/view.py +34 -22
- tinygrad/tensor.py +399 -97
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
- tinygrad-0.9.2.dist-info/RECORD +70 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/runtime/{driver → support}/__init__.py +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/codegen/uops.py
CHANGED
```diff
@@ -1,24 +1,24 @@
 from __future__ import annotations
-from typing import
-import functools, itertools,
+from typing import Optional, Tuple, Any, Set, cast, List, Union, DefaultDict, Callable, Dict
+import functools, itertools, math
 from collections import defaultdict
 from enum import Enum, auto
-from dataclasses import dataclass
+from dataclasses import dataclass
 from tinygrad.dtype import ConstType, dtypes, DType
 from tinygrad.shape.symbolic import sint, Variable
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
-from tinygrad.helpers import
+from tinygrad.helpers import merge_dicts, prod, pretty_print
 
 # the order of these UOps controls the order of the toposort
 class UOps(Enum):
   # ops that aren't rendered
-  SINK = auto();
+  SINK = auto(); EXPAND = auto(); CONTRACT = auto() # noqa: E702
   DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
   CONST = auto(); SPECIAL = auto() # noqa: E702
-  NOOP = auto();
+  NOOP = auto(); GEP = auto() # noqa: E702
   # math ops
-  CAST = auto(); BITCAST = auto() # noqa: E702
-  ALU = auto(); WMMA = auto() # noqa: E702
+  CAST = auto(); BITCAST = auto(); VECTORIZE = auto() # noqa: E702
+  ALU = auto(); REDUCE = auto(); WMMA = auto() # noqa: E702
   # memory/assignment ops
   LOAD = auto(); STORE = auto(); PHI = auto() # noqa: E702
   # control flow ops
@@ -26,7 +26,8 @@ class UOps(Enum):
   # these two are not graph nodes
   ENDRANGE = auto(); ENDIF = auto() # noqa: E702
 
-
+END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
+
 @dataclass(frozen=True, eq=False)
 class UOp:
   op: UOps
@@ -34,418 +35,259 @@ class UOp:
   src: Tuple[UOp, ...] = tuple()
   arg: Any = None
   def commutative(self) -> bool:
-    return self.op is UOps.ALU and
+    return (self.op is UOps.ALU and \
+      self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR})
   @functools.cached_property
   def cmp_tuple(self):
     # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
     return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
-
+            self.arg.value, self.dtype, self.src)
   def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
-  def __repr__(self):
-
-  def
-  def
-  def
-  def
-  def
-  def
-  def
-  def
-  def
-  def
-  def
-  def
-  def
-  def
-  def
+  def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.arg}, src=(%s))")
+  # *** uop syntactic sugar
+  def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x
+  def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
+  def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))
+  def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar() if self.dtype is not None else None, (self,), i)
+  def __neg__(self): return self.alu(UnaryOps.NEG)
+  def __add__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x))
+  def __radd__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x))
+  def __sub__(self, x): return self.alu(BinaryOps.ADD, self.ufix(-x))
+  def __mul__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x))
+  def __rmul__(self, x): return self.ufix(x).alu(BinaryOps.MUL, self)
+  def __floordiv__(self, x): return self.alu(BinaryOps.IDIV, self.ufix(x))
+  def __truediv__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x).alu(UnaryOps.RECIP))
+  def __mod__(self, x): return self.alu(BinaryOps.MOD, self.ufix(x))
+  def __xor__(self, x): return self.alu(BinaryOps.XOR, self.ufix(x))
+  def __and__(self, x): return self.alu(BinaryOps.AND, self.ufix(x))
+  def __or__(self, x): return self.alu(BinaryOps.OR, self.ufix(x))
+  def ne(self, x): return self.alu(BinaryOps.CMPNE, self.ufix(x))
+  def eq(self, x): return -self.ne(x)
+  def lt(self, x): return self.alu(BinaryOps.CMPLT, self.ufix(x))
+  def ge(self, x): return (-self).lt(-x+1)
+  def max(self, x): return self.alu(BinaryOps.MAX, x)
+  def min(self, x): return -(-self).max(-x)
+  def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y)
+  def recip(self): return self.alu(UnaryOps.RECIP)
+  def const(self:Union[UOp, DType, None], b:ConstType|Variable): return UOp._const(self.dtype if isinstance(self, UOp) else self, b)
+  def sconst(self:Union[UOp, DType, None], b:ConstType|Variable):
+    return UOp._const(cast(DType, self.dtype if isinstance(self, UOp) else self).scalar() if self is not None else self, b)
   @staticmethod
   @functools.lru_cache(maxsize=None)
-  def
-
+  def _const(dtype:Optional[DType], b:ConstType|Variable):
+    # TODO: fix dtype of b.max after Variable is just an UOp
+    if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (UOp.const(dtypes.int, b.min), UOp.const(dtypes.int, cast(int,b.max))), b)
+    if dtype is not None and dtype != (sdtype := dtype.scalar()):
+      return UOp(UOps.VECTORIZE, dtype, src=tuple(UOp(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
     return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
+  def alu(self, arg, *src:UOp):
+    return type(self)(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else (self, *src)[-1].dtype, (self,)+src, arg)
   @staticmethod
-  def
-  @staticmethod
-  def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.LOAD, dtype, tuple(src)+tuple(kwargs.values()))
+  def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return type(src[0])(UOps.LOAD, dtype, tuple(src)+tuple(kwargs.values()))
   @staticmethod
-  def store(*src:UOp, **kwargs): return
-  @staticmethod
-  def var(name:Optional[str]=None, dtype:Optional[DType]=None): return UOp(UOps.VAR, dtype=dtype, arg=name)
-  @staticmethod
-  def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return UOp(UOps.CONST, dtype=dtype).name(name)
+  def store(*src:UOp, **kwargs): return type((src:=(*src, *kwargs.values()))[0])(UOps.STORE, None, src)
   @functools.cached_property
-  def parents(self) ->
+  def parents(self) -> Dict[UOp, None]: return merge_dicts([{x:None for x in self.src}]+[x.parents for x in self.src])
   @property # parents with self
-  def sparents(self) ->
-  def vars(self) -> Set[UOp]: return set([x for x in
-
-
-
-
-
-
-
-
-
+  def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
+  def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
+  def const_factor(self) -> int:
+    """largest known int that divides self"""
+    if self.op is UOps.CONST: return self.arg
+    if self.op is UOps.ALU:
+      if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[0].const_factor())
+      if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is UOps.CONST else self.src[1].arg if self.src[1].op is UOps.CONST else 1
+    return 1
+  def divides(self, v) -> Optional[UOp]:
+    if v==1: return self
+    if self.op is UOps.CONST: return self.const(self.arg//v) if self.arg%v == 0 else None
+    if self.op is UOps.ALU:
+      if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
+      if self.arg is BinaryOps.MUL:
+        if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
+        if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
+    return None # generic None if we aren't sure
+  @functools.cached_property
+  def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.min(cast(DType, self.dtype)))
+  @functools.cached_property
+  def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.max(cast(DType, self.dtype)))
+  @functools.cached_property
+  def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]:
+    # NOTE: returned UOp is assumed to be CONST
+    if self.op is UOps.DEFINE_VAR: return self.src[0], self.src[1] if isinstance(self.src[1].arg, int) else None
+    if self.op is UOps.RANGE: return self.src[0], self.const(self.src[1].arg-1) if isinstance(self.src[1].arg, int) else None
+    # TODO: UOps.SPECIAL is UOps.DEFINE_VAR
+    if self.op is UOps.SPECIAL: return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else None
+    if self.op is UOps.CONST: return self, self
+    if self.op is UOps.ALU and cast(DType, self.dtype).count == 1:
+      s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
+      if self.arg is UnaryOps.NEG and self.dtype != dtypes.bool and not dtypes.is_unsigned(cast(DType, self.dtype)):
+        return self.sconst(-s0.vmax.arg), self.sconst(-s0.vmin.arg)
+      if self.arg is BinaryOps.ADD: return self.sconst(s0.vmin.arg+s1.vmin.arg), self.sconst(s0.vmax.arg+s1.vmax.arg)
+      if self.arg is BinaryOps.MUL and (s0.vmin.arg >= 0 or s1.vmin.arg >= 0):
+        # handle at lease one is non-negative
+        Lmin, Lmax = (s0.vmin.arg, s0.vmax.arg) if s1.vmin.arg >= 0 else (s0.vmax.arg, s0.vmin.arg)
+        Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg)
+        assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}"
+        return self.sconst(Lmin*Rmin), self.sconst(Lmax*Rmax)
+      if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.sconst(0), self.sconst(s1.arg-1)
+      if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
+        if s1.arg > 0: return self.sconst(s0.vmin.arg//s1.arg), self.sconst(s0.vmax.arg//s1.arg)
+        if s1.arg < 0: return self.sconst(-(s0.vmax.arg//-s1.arg)), self.sconst(-(s0.vmin.arg//-s1.arg))
+      if self.arg is BinaryOps.MAX: return self.sconst(max(s0.vmin.arg, s1.vmin.arg)), self.sconst(max(s0.vmax.arg, s1.vmax.arg))
+      if self.arg is BinaryOps.CMPLT: return (UOp.sconst(dtypes.bool, True), UOp.sconst(dtypes.bool, True)) if s0.vmax.arg < s1.vmin.arg else \
+        (UOp.sconst(dtypes.bool, False), UOp.sconst(dtypes.bool, False)) if s0.vmin.arg >= s1.vmax.arg else (None, None)
+    return None, None
+
+@dataclass(frozen=True, repr=False) # reuse repr from UOp
+class NOp(UOp):
+  name:Optional[str] = None
+  src:Tuple[NOp, ...] = tuple()
+  allow_any_len:bool = False
+  @staticmethod
+  def var(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.NOOP, dtype=dtype, name=name)
+  @staticmethod
+  def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.CONST, dtype=dtype, name=name)
+  def const(self:Union[UOp, DType, None], b:ConstType|Variable): return NOp((x:=UOp.const(self, b)).op, x.dtype, x.src, x.arg)
 
-
+  def compile(self: NOp, name:Optional[str]=None) -> UPat:
+    return UPat(name=self.name, dtype=self.dtype) if self.op is UOps.NOOP else UPat(self.op, self.arg, (list if self.commutative()
+      else tuple)(src.compile() for src in self.src) or None, self.name or name, self.dtype, self.allow_any_len)
 
-@dataclass(frozen=True)
 class UPat:
-  op:
-
-
-
-
-
-
-
-
-
-
-
-
-
+  def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None,
+               name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False):
+    self.op: Optional[Tuple[UOps, ...]] = None if op is None else (tuple(op) if isinstance(op, set) else (op,))
+    self.dtype: Optional[Tuple[DType, ...]] = None if dtype is None else (tuple(dtype) if isinstance(dtype, set) else (dtype,))
+    self.arg, self.name = arg, name
+    self.src: Any = None
+    # try all permutations if it's a list
+    if isinstance(src, list): self.src = list(itertools.permutations(src))
+    # only one if it's a tuple
+    elif isinstance(src, tuple): self.src = [src]
+    # repeat if it's a UPat
+    elif isinstance(src, UPat): self.src = [itertools.repeat(src)]
+
+    self.allowed_len: int = 0 if allow_any_len or isinstance(src, UPat) or src is None else len(src)
 
-  def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+  def __repr__(self):
+    def rep(x):
+      form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)"
+      return form % (None if x.op is None else ('(%s)'%', '.join(map(str, x.op))), x.arg, repr(x.name),
+                     set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src)>1 else "(%s)")
+    return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])
+
+def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
+  if (pat.name is not None and store.setdefault(pat.name, uop) is not uop) or \
+     (pat.dtype is not None and uop.dtype not in pat.dtype) or \
+     (pat.arg is not None and pat.arg != uop.arg) or \
+     (pat.op is not None and uop.op not in pat.op): return []
+  if pat.src is None: return [store]
+  res: List[Dict[str, UOp]] = []
+  for vp in pat.src:
+    if pat.allowed_len != 0 and len(uop.src) != pat.allowed_len: return []
+    new_stores = [store.copy()]
+    for uu, vv in zip(uop.src, vp): new_stores = [rstore for nstore in new_stores for rstore in _match(uu, vv, nstore)]
+    res.extend(new_stores)
+  return res
 
 class PatternMatcher:
-  def __init__(self, patterns:List[Tuple[Union[UPat,
+  def __init__(self, patterns:List[Tuple[Union[UPat, NOp], Callable]]):
     self.patterns = patterns
     self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable]]] = defaultdict(list)
     # uop is required, arg is optional
     for p,fxn in self.patterns:
-      if isinstance(p,
+      if isinstance(p, NOp): p = p.compile()
       assert p.op is not None
-
-
-
-
+      for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn))
+
+  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+  def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
 
   def rewrite(self, uop:UOp) -> Optional[UOp]:
     for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
-
-      if _match(uop, p, store): return fxn(**store)
+      if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
     return None
 
-def
-  for
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-  # deal with UNMUL
-  (UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"), UPat(UOps.UNMUL, src=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]),
-   lambda c1,c2,v: v if c1.arg == c2.arg else None),
-  (UOp(UOps.UNMUL, src=(UOp.const(None, 0).name('zero'), UOp.var())), lambda zero: zero),
-  (UOp(UOps.UNMUL).name('unmul').cast().name('root'), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.src[0].cast(root.dtype), unmul.src[1]))),
-  # max on special can go away (TODO: special should be variable, same thing applies)
-  (UOp.max(UOp.cvar('c'), UOp(UOps.SPECIAL).name('s')), lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
-  # const rules
-  (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
-  (UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
-  # a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
-  (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
-  (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST),)), UPat(name="x"))), lambda x: x),
-  (UPat(UOps.PHI, src=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x),
-  # a DEFINE_ACC without inputs is a const + GEP on a const is the const
-  (UPat(UOps.DEFINE_ACC, name="root", src=(UPat(UOps.CONST),)), lambda root: UOp.cast(root.src[0], root.dtype)),
-  (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)),
-  # max -2147483648
-  (UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x),
-  # bool < False is always false, True < bool is always false
-  (UOp.var().lt(UOp.const(dtypes.bool, False)), lambda: UOp.const(dtypes.bool, False)),
-  (UOp.const(dtypes.bool, True).lt(UOp.var()), lambda: UOp.const(dtypes.bool, False)),
-  # a conditional with the same results either way is a noop, also fold const conditionals
-  (UOp.alu(TernaryOps.WHERE, UOp.var(), UOp.var("val"), UOp.var("val")), lambda val: val),
-  (UOp.alu(TernaryOps.WHERE, UOp.cvar('gate'), UOp.var('c0'), UOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
-  # ** constant folding **
-  (UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
-  # ** self folding **
-  (-(-UOp.var('x')), lambda x: x), # -(-x) -> x
-  (UOp.var('x') + 0, lambda x: x), # x+0 -> x
-  (UOp.var('x') - 0, lambda x: x), # x-0 -> x
-  (UOp.var('x') * 1, lambda x: x), # x*1 -> x
-  (UOp.var('x') * -1, lambda x: -x), # x*-1 -> -x
-  (UOp.var('x') // UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x//x -> 1
-  (UOp.var('x') // 1, lambda x: x), # x//1 -> x
-  (UOp.var('x') // -1, lambda x: -x), # x//-1 -> -x
-  (UOp.var('x') / UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x/x -> 1
-  (UOp.var('x') / UOp.cvar('c'), lambda x,c: x*exec_alu(UnaryOps.RECIP, c.dtype, [c.arg])), # x/c -> x*(1/c)
-  (UOp.var('x', dtype=dtypes.bool).max(UOp.const(dtypes.bool, False)), lambda x: x), # max(x, False) -> x
-  # ** zero folding **
-  #x*0 -> 0 or 0*x -> 0
-  #if x is nan or inf it should render the nan value.
-  # NOTE: this can be wrong for loaded NaN
-  (UOp.var('x') * 0, lambda x: UOp.const(x.dtype, float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
-  (UOp.var('x') - UOp.var('x'), lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
-  # ** load/store folding **
-  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.load(UOp.var("buf"), UOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
-  # ** two stage add/sub folding **
-  ((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
-  ((UOp.var('x') - UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c2.arg, -c1.arg]))),
-  # *** rules from symbolic ***
-  # two stage mul, (x*c1)*c2 = x*(c1*c2)
-  ((UOp.var("x") * UOp.cvar("c1")) * UOp.cvar("c2"), lambda x,c1,c2: x*UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c1.arg, c2.arg]))),
-  # x%1 -> 0
-  (UOp.var("x") % UOp.const(None, 1), lambda x: UOp.const(x.dtype, 0)),
-  # (x*c0)+(x*c1) -> x*(c0+c1)
-  (UOp.var("x") * UOp.cvar("c0") + UOp.var("x") * UOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])),
-  # (x*c0)+(y*c0) -> (x+y)*c0
-  #((UOp.var("x") * UOp.cvar("c0")) + (UOp.var("y") * UOp.cvar("c0")), lambda x,y,c0: c0*(x+y)),
-  # (x*c0)//c0 -> x
-  ((UOp.var("x") * UOp.cvar("c0")) // UOp.cvar("c0"), lambda x,c0: x if c0.arg != 0 else None),
-  # (x*x2)/x2 -> x
-  ((UOp.var("x") * UOp.var("x2")) / UOp.var("x2"), lambda x,x2: x),
-  # (x//c0)//c1 -> x//(c0*c1)
-  ((UOp.var("x") // UOp.cvar("c0")) // UOp.cvar("c1"), lambda x,c0,c1: x//UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c0.arg, c1.arg]))),
-  # (x/x1)/x2 -> x/(x1*x2)
-  ((UOp.var("x") / UOp.var("x2")) / UOp.var("x3"), lambda x,x2,x3: x/(x2*x3)),
-  # c0 + x < c1 -> x < c1 - c0
-  ((UOp.cvar("c0") + UOp.var("x")).lt(UOp.cvar("c1")),
-   lambda x,c0,c1: UOp.lt(x, UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),
-  # (x+x*c0)-> x*(c0+1)
-  (UOp.var("x") + UOp.var("x") * UOp.cvar("c0"), lambda x,c0: x*UOp.const(x.dtype, c0.arg+1)),
-  # TODO: can do the invert of this (flip alt/load) when we fix double ops
-  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
-   lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
-  # store float4/float2 directly (remove CAST/GEP)
-  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(4)))), UOp.store),
-  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(2)))), UOp.store),
-  # CAST-PHI-GEP -> PHI-CAST
-  (UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
-   lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
-  (UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
-   lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))),
-  # NEG/CMPLT -> CMPLT
-  (UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)),
-  # cast NOOP (NOTE: it's str to deal with PtrDType)
-  (UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
-  # fold gated LOAD/STORE
-  (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
-  (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var"), UOp.var("barrier")),
-   lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)),
-  (UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var")), lambda var: var),
-  (UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var"), UOp.var()), lambda var: var),
-  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(None, 1)), UOp.store),
-  (UOp.store(UOp.var(), UOp.var(), UOp.var(), UOp.const(None, 0)), lambda: UOp(UOps.NOOP)),
-  # remove NOOPs from SINK
-  (UPat(UOps.SINK, name="root"),
-   lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None)
-])
-
-# *** uop graph ***
+def type_verify(uops):
+  for u in uops:
+    uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
+    if uop in {UOps.CONST, UOps.DEFINE_ACC}:
+      if uop is UOps.CONST:
+        assert dtype is not None and dtype == dtype.scalar(), f"consts should be scalar, got {dtype}"
+        assert type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
+      if uop is UOps.DEFINE_ACC: assert dtype is not None and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}"
+    if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
+    if uop is UOps.CAST: assert dtype.count == 1 and len(src) == 1
+    if uop is UOps.VECTORIZE:
+      assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}"
+      assert all(dtype == x.dtype.vec(len(src)) for x in src), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
+    if uop is UOps.LOAD and len(src) > 3 and src[3].op is UOps.ALU: assert src[3].dtype == dtypes.bool and src[2].dtype == dtype
+    if uop is UOps.GEP: assert dtype == src[0].dtype.scalar(), f"GEP of {src[0].dtype=} should be {src[0].dtype.scalar()} != {dtype}"
+    if uop is UOps.STORE:
+      assert dtype is None, f"{uop} dtype must be None, got {dtype}"
+      if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
+    if uop is UOps.ALU:
+      if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
+      elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
+        assert dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), f"{arg} output dtype mismatch {dtype=} != {bd=}"
+        assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
+      elif arg is BinaryOps.IDIV:
+        assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), f"input dtype is not int {src[0].dtype=}, {src[1].dtype=}"
+        assert dtypes.is_int(dtype), f"output dtype is not int {dtype=}"
+      elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
+        # the distance to shift isn't typechecked
+        assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
+      elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
+      elif arg == TernaryOps.WHERE:
+        assert src[0].dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), \
+          f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}"
+        assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
 
-def
-  if u in
-
-
-    get_children_dfs(x, children, in_degree)
-    children[x].append(u)
-  in_degree[u] = len(u.src)
-
-def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
-  nodes: Dict[Tuple, UOp] = {}
-  replace: Dict[UOp, UOp] = {}
-  def __inner_rewrite(n:UOp) -> UOp:
-    if n in replace: return replace[n]
-    replace_source = (n.op, n.dtype, tuple(__inner_rewrite(y) for y in n.src), n.arg)
-    if found := nodes.get(replace_source): replace[n] = found
-    else: nodes[replace_source] = replace[n] = __inner_rewrite(new_x) if (new_x := pm.rewrite(x:=UOp(*replace_source))) else x
-    return replace[n]
-  return __inner_rewrite(sink)
-
-class UOpGraph:
-  def __init__(self, sinks:List[UOp]):
-    self.sinks: List[UOp] = sinks
-    # used by linearizer
-    self._uops: Optional[List[UOp]] = None
-
-  def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
-  def __getitem__(self, index) -> UOp: return self.uops[index]
-
-  def vars(self) -> List[Variable]: return sorted([x.arg for x in self.uops if x.op is UOps.DEFINE_VAR], key=lambda v: v.expr)
-  def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.op is UOps.DEFINE_GLOBAL]
-
-  @property
-  def uops(self):
-    if self._uops is None: self.linearize()
-    return self._uops
-
-  def graph(self):
-    from tinygrad.engine.graph import graph_uops
-    graph_uops(self.uops)
-
-  def print(self):
-    for i,u in enumerate(self):
-      print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.src]):32s} {u.arg}")
-
-  def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
-    # NOTE: relinearizering should be okay
-    #assert self._uops is None, "already linearized"
-    sink = UOp(UOps.SINK, None, tuple(self.sinks))
-
-    # dedup all nodes and do graph rewrite
-    sink = graph_rewrite(sink, constant_folder)
-    if extra_pm: sink = graph_rewrite(sink, PatternMatcher(constant_folder.patterns+extra_pm.patterns))
-
-    # filter nodes that don't link to a sink
-    # BFS toposort
-    children: Dict[UOp, List[UOp]] = {}
-    in_degree: Dict[UOp, int] = {}
-    get_children_dfs(sink, children, in_degree)
-
-    @functools.lru_cache(None)
-    def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
-      if x.op is UOps.SINK: return set()
-      return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))
-
-    # scope children impact the toposort and END* insertion
-    end_for_uop = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
-    loops, ifs = [x for x in in_degree if x.op is UOps.RANGE], [x for x in in_degree if x.op is UOps.IF]
-    scope_children = {p:get_recursive_children(p, end_for_uop[p.op][0]) for p in (loops+ifs)[::-1]}
-
-    queue:List[Tuple[int, UOp]] = []
-    def push(u:UOp):
-      priority = 0
-      # prefer uops that are loop children
-      for l, ss in scope_children.items():
-        if l.op is UOps.RANGE and u in ss: priority -= l.arg[0]*1000 + l.arg[1]
-      heapq.heappush(queue, (priority, u))
-
-    for u in children:
-      if in_degree[u] == 0: push(u)
-
-    if getenv("FUZZ_UOPS", 0):
-      from test.external.fuzz_uops import fuzz_uops
-      self.fuzz_paths = fuzz_uops(children, in_degree.copy(), scope_children)
-
-    self._uops = []
-    while queue:
-      p,x = heapq.heappop(queue)
-      if DEBUG >= 7: print(p,x)
-      if x.op is UOps.DEFINE_ACC and len(x.src) > 1:
-        idx = min([self._uops.index(l) for l in x.src if l.op is UOps.RANGE])
-        self._uops.insert(idx, x)
-      else:
-        self._uops.append(x)
-      for u, ss in scope_children.items():
-        if x in ss:
-          ss.remove(x)
-          if len(ss) == 0: self._uops.append(UOp(end_for_uop[u.op][1], None, (u,)))
-      for u in children[x]:
-        in_degree[u] -= 1
-        if in_degree[u] == 0: push(u)
-
-    assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
-    self._uops = self._uops[:-1]
-
-    if type_verify: self.type_verify()
-
-  # *** checker functions ***
+def uop_alu_resolve(u:UOp) -> sint:
+  if u.op in {UOps.CONST, UOps.DEFINE_VAR}: return u.arg
+  if u.op is UOps.ALU: return exec_alu(u.arg, cast(DType,u.dtype), tuple(map(uop_alu_resolve, u.src)))
+  raise RuntimeError(f"ALU resolve fail @ {u.op}")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-      mult_stack.append(mults)
-      mults *= uop_alu_resolve(u.src[1])
-    elif u.op is UOps.ENDRANGE:
-      mults = mult_stack.pop(-1)
-    elif u.op is UOps.LOAD:
-      assert u.dtype is not None
-      mem += u.dtype.itemsize * mults
+def print_uops(uops:List[UOp]):
+  for i,u in enumerate(uops):
+    formatted_parents = [uops.index(x) if x.op is not UOps.CONST else f"{x.arg}" for x in u.src]
+    print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str(formatted_parents):32s} {u.arg}")
+
+def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
+  flops: sint = 0
+  mem: sint = 0
+  mults: sint = 1
+  mult_stack: List[sint] = []
+  dont_count: Set[UOp] = set()
+  if ignore_indexing:
+    for u in uops:
+      if u.op is UOps.LOAD:
+        dont_count = dont_count.union(u.src[1].sparents)
+        if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents)
       elif u.op is UOps.STORE:
-
-
-    elif u.op is UOps.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-      if
-
-
-
-
-        assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
-      elif arg is BinaryOps.IDIV:
-        assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), \
-          f"input dtype mismatch {dtypes.int} != {src[0].dtype=} != {src[1].dtype=}"
-        assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
-      elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
-        # the distance to shift isn't typechecked
-        assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
-      elif arg in BinaryOps:
-        assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
-      elif arg == TernaryOps.WHERE:
-        assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
-        assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
+        dont_count = dont_count.union(u.src[1].sparents)
+        if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
+      elif u.op is UOps.IF:
+        dont_count = dont_count.union(u.src[0].sparents)
+  for u in uops:
+    if u.op is UOps.RANGE:
+      mult_stack.append(mults)
+      mults *= uop_alu_resolve(u.src[1] - u.src[0])
+    elif u.op is UOps.ENDRANGE:
+      mults = mult_stack.pop(-1)
+    elif u.op is UOps.SPECIAL:
+      mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
+    elif u.op is UOps.LOAD:
+      assert u.dtype is not None
+      mem += u.dtype.itemsize * mults
+    elif u.op is UOps.STORE:
+      assert u.src[2].dtype is not None
+      mem += u.src[2].dtype.itemsize * mults
+    elif u.op is UOps.ALU and u not in dont_count:
+      assert u.dtype is not None
+      flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count
+    elif u.op is UOps.WMMA and u not in dont_count:
+      assert u.arg[1] is not None
+      flops += 2 * prod(u.arg[1]) // 32 * mults
+  return flops, mem
```

(Note: deleted 0.9.1 lines whose content the diff viewer truncated are shown with whatever fragment survived; bare `-` lines mark removed lines whose text was not preserved.)