tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +248 -115
- tinygrad/codegen/lowerer.py +215 -0
- tinygrad/codegen/transcendental.py +310 -0
- tinygrad/codegen/uopgraph.py +622 -0
- tinygrad/codegen/uops.py +235 -393
- tinygrad/device.py +428 -69
- tinygrad/dtype.py +18 -4
- tinygrad/engine/graph.py +19 -32
- tinygrad/engine/jit.py +148 -70
- tinygrad/engine/realize.py +127 -51
- tinygrad/engine/schedule.py +259 -216
- tinygrad/engine/search.py +29 -22
- tinygrad/function.py +9 -0
- tinygrad/helpers.py +87 -49
- tinygrad/lazy.py +34 -35
- tinygrad/multi.py +41 -36
- tinygrad/nn/__init__.py +39 -22
- tinygrad/nn/state.py +3 -3
- tinygrad/ops.py +63 -62
- tinygrad/renderer/__init__.py +43 -21
- tinygrad/renderer/assembly.py +104 -106
- tinygrad/renderer/cstyle.py +87 -60
- tinygrad/renderer/llvmir.py +21 -30
- tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/kfd.py +32 -0
- tinygrad/runtime/autogen/libc.py +4260 -0
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/graph/clang.py +2 -2
- tinygrad/runtime/graph/cuda.py +8 -11
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +18 -15
- tinygrad/runtime/ops_amd.py +197 -305
- tinygrad/runtime/ops_clang.py +2 -2
- tinygrad/runtime/ops_cuda.py +36 -94
- tinygrad/runtime/ops_disk.py +3 -7
- tinygrad/runtime/ops_gpu.py +4 -2
- tinygrad/runtime/ops_hip.py +70 -0
- tinygrad/runtime/ops_metal.py +38 -27
- tinygrad/runtime/ops_nv.py +283 -363
- tinygrad/runtime/ops_python.py +26 -30
- tinygrad/runtime/support/compiler_cuda.py +78 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/shape/shapetracker.py +5 -14
- tinygrad/shape/symbolic.py +4 -8
- tinygrad/shape/view.py +34 -22
- tinygrad/tensor.py +399 -97
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
- tinygrad-0.9.2.dist-info/RECORD +70 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/runtime/{driver → support}/__init__.py +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,622 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
from typing import Iterator, Optional, Tuple, Dict, List, Set, Union, cast, TYPE_CHECKING, Any, DefaultDict, Callable
|
3
|
+
import functools, itertools, heapq, math, operator
|
4
|
+
from collections import defaultdict
|
5
|
+
from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
|
6
|
+
from tinygrad.ops import UnaryOps, BinaryOps, exec_alu
|
7
|
+
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, prod, CI, all_same, partition
|
8
|
+
from tinygrad.codegen.uops import UOp, NOp, UOps, UPat, PatternMatcher, END_FOR_UOP, type_verify, print_uops
|
9
|
+
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
|
10
|
+
if TYPE_CHECKING: from tinygrad.renderer import Renderer
|
11
|
+
|
12
|
+
# ***** float4/image store handling *****
|
13
|
+
|
14
|
+
def fold_expanded(ex, buf):
  """Fold runs of scalar LOADs/STOREs at consecutive constant offsets into one vectorized LOAD/STORE.

  Pattern-matcher callback (see float4_folding): `ex` is an EXPAND of LOADs or a BARRIER/SINK of STOREs,
  `buf` is the buffer operand of those memory ops.
  Returns None when buf is not a float/half pointer or an ImageDType, or when nothing could be folded;
  otherwise returns a copy of `ex` whose sources are rewritten.
  """
  if buf.dtype != PtrDType(dtypes.float) and buf.dtype != PtrDType(dtypes.half) and not isinstance(buf.dtype, ImageDType): return None
  new_srcs = dedup(list(ex.src))
  # keep the pre-rewrite order so LOAD results can be mapped back onto the original (non-deduped) ex.src
  old_new_srcs = new_srcs[:]
  is_load, is_image = new_srcs[0].op is UOps.LOAD, isinstance(buf.dtype, ImageDType)

  # first, extract all the relevant offsets
  # offsets_rootsrc[root index expression (+ image idx/idy, + gate)][constant offset] -> position in new_srcs
  offsets_rootsrc: DefaultDict[Any, dict] = defaultdict(dict)
  for i,s in enumerate(new_srcs):
    # skip already-vectorized ops, and image accesses whose index isn't an int3
    if (s.dtype is not None and s.dtype.count != 1) or (is_image and s.src[1].dtype != dtypes.int.vec(3)): continue
    idx = s.src[1] if not is_image else s.src[1].src[2]  # only id4 for image
    # split idx into (root, constant offset): root+c, a bare constant, or anything else with offset 0
    if idx.arg is BinaryOps.ADD and idx.src[1].op is UOps.CONST: root_src, arg = idx.src[0], idx.src[1].arg
    elif idx.op is UOps.CONST: root_src, arg = "CONST", idx.arg
    else: root_src, arg = idx, 0
    # add idx and idy for image
    if is_image: root_src = (s.src[1].src[0:2], root_src)
    # add gates for gated
    if len(s.src) >= 4: root_src = (s.src[3], root_src)
    assert arg not in offsets_rootsrc[root_src]
    offsets_rootsrc[root_src][arg] = i

  # then rewrite everything we can
  # used holds (rootsrc, offset) pairs already consumed by a fold
  used = set()
  for rootsrc, offsets in offsets_rootsrc.items():
    for o in offsets:
      # widest first: images fold by 4; half may fold by 8 when ALLOW_HALF8 is set
      for fold_length in [4] if is_image else ([8,4,2] if buf.dtype == PtrDType(dtypes.half) and getenv("ALLOW_HALF8") else [4,2]):
        if all((rootsrc,o+i) not in used and o+i in offsets for i in range(fold_length)):
          load_1 = new_srcs[offsets[o]]
          new_src = list(load_1.src)
          # presumably an alignment check: the start index must be a multiple of fold_length (TODO confirm UOp.divides)
          if not is_image and not new_src[1].divides(fold_length): continue
          # for images, we rewrite the index
          if is_image: new_src[1] = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (new_src[1].src[0], new_src[1].src[1]))
          # vectorize the store/loadconst
          if not is_load or len(new_src) >= 4:
            new_src[2] = UOp(UOps.VECTORIZE, new_src[2].dtype.vec(fold_length), tuple(new_srcs[offsets[o+i]].src[2] for i in range(fold_length)))
          # generate the folded new_srcs
          if is_load:
            # one wide LOAD; each former scalar load becomes a GEP lane of it
            new_load = UOp(UOps.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
            for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(UOps.GEP, load_1.dtype, (new_load,), i)
          else:
            # one wide STORE in the first slot; the other slots become None and are dropped below
            for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(UOps.STORE, None, tuple(new_src)) if i == 0 else None
          for i in range(fold_length): used.add((rootsrc,o+i))

  # dedup expand for LOAD
  # NOTE(review): list.index makes this O(len(ex.src) * len(new_srcs))
  if is_load and len(old_new_srcs) != len(ex.src): new_srcs = [new_srcs[old_new_srcs.index(s)] for s in ex.src]
  # remove Nones for STORE
  return UOp(ex.op, ex.dtype, tuple(x for x in new_srcs if x is not None), ex.arg) if len(used) else None
|
61
|
+
|
62
|
+
def vectorize_reduce(vec:UOp):
  """Turn a VECTORIZE of REDUCEs into a single REDUCE over a VECTORIZE of their inputs.

  Only fires when the lanes are not all identical and every lane reduces over the same
  trailing sources (ranges) with the same reduce op; otherwise returns None.
  """
  lanes = vec.src
  # don't REDUCE the same thing multiple times
  if all_same(lanes): return None
  # every lane must share its trailing sources and its reduce op
  if not all_same([(lane.src[1:], lane.arg) for lane in lanes]): return None
  first = lanes[0]
  vectorized_input = UOp(UOps.VECTORIZE, vec.dtype, tuple(lane.src[0] for lane in lanes))
  return UOp(UOps.REDUCE, vec.dtype, (vectorized_input,) + first.src[1:], first.arg)
|
66
|
+
|
67
|
+
def vectorize_alu(vec:UOp):
  """Turn a VECTORIZE of same-arg ALU/CAST/BITCAST uops into one op applied to VECTORIZEd operands.

  Operand i of the result is a VECTORIZE of operand i from every lane.
  Returns None when the lanes' args differ.
  """
  lanes = vec.src
  if not all_same([lane.arg for lane in lanes]): return None
  template = lanes[0]
  operands = []
  for i, operand in enumerate(template.src):
    wide_dtype = cast(DType, operand.dtype).vec(cast(DType, vec.dtype).count)
    operands.append(UOp(UOps.VECTORIZE, wide_dtype, tuple(lane.src[i] for lane in lanes)))
  return UOp(template.op, vec.dtype, tuple(operands), template.arg)
|
71
|
+
|
72
|
+
# Rewrite rules that vectorize memory access and arithmetic:
#  - an EXPAND of LOADs, or a BARRIER/SINK of STOREs, goes to fold_expanded to merge adjacent scalar accesses
#  - a VECTORIZE of REDUCEs becomes one REDUCE (vectorize_reduce)
#  - a VECTORIZE of ALU/CAST/BITCAST becomes one op on VECTORIZEd operands (vectorize_alu)
# NOTE(review): `src=UPat(...)` (a single UPat, not a tuple) presumably means "every source matches"; confirm in UPat
float4_folding = PatternMatcher([
  (UPat(UOps.EXPAND, src=UPat(UOps.LOAD, src=(UPat(name="buf"), UPat()), allow_any_len=True), name="ex"), fold_expanded),
  (UPat({UOps.BARRIER, UOps.SINK}, src=UPat(UOps.STORE, src=(UPat(name="buf"), UPat(), UPat()), allow_any_len=True), name="ex"), fold_expanded),
  (UPat(UOps.VECTORIZE, src=UPat(UOps.REDUCE), name="vec"), vectorize_reduce),
  (UPat(UOps.VECTORIZE, src=UPat({UOps.ALU, UOps.CAST, UOps.BITCAST}), name="vec"), vectorize_alu),
])
|
78
|
+
|
79
|
+
# ***** mod *****
|
80
|
+
|
81
|
+
def _get_add_chain(x:UOp):
  """Yield the leaf terms of the nested ALU-ADD tree rooted at x, left to right.

  A uop that is not an ALU ADD is yielded as-is (a chain of one term).
  """
  pending = [x]
  while pending:
    node = pending.pop()
    if node.op is UOps.ALU and node.arg is BinaryOps.ADD:
      # push children reversed so src[0] is visited first
      pending.extend(reversed(node.src))
    else:
      yield node
|
85
|
+
|
86
|
+
def mod_folding(x:UOp, c:int) -> Optional[UOp]:
  """Simplify the dividend x of ``x % c``.

  Every ADD term of x whose const_factor() is not already smaller than c (in the ``f % c == f`` sense)
  is rewritten as ``term.divides(f) * (f % c)``. Returns the rebuilt sum, or None when no term changed.
  """
  terms: List[UOp] = []
  changed = False
  for term in _get_add_chain(x):
    factor = term.const_factor()
    reduced = factor % c
    if reduced == factor:
      terms.append(term)
      continue
    terms.append(term.divides(factor) * reduced)
    changed = True
  if not changed: return None
  return functools.reduce(operator.add, terms) if terms else x.const(0)
|
97
|
+
|
98
|
+
def div_folding(x:UOp, c:int) -> Optional[UOp]:
  """Try to simplify ``x // c`` for a constant divisor c; returns None when nothing changed.

  x is split into its ADD terms (via _get_add_chain):
    - CONST terms are summed into rem_const,
    - terms whose const_factor() is a multiple of c move to the quotient as ``u.divides(c)``,
    - all other terms stay in the remainder, while tracking gcd(c, their factors) and the smallest
      MUL factor that divides c.
  The result is rebuilt as ``quotient + (remainder // div) // (c // div)``, recursing for ``remainder // div``.
  """
  # simplify x // c, None means no change
  # simple cancel div case: x provably in [0, c)
  if 0 <= x.vmin.arg and x.vmax.arg < c: return x.const(0)

  quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1
  for u in _get_add_chain(x):
    if u.op is UOps.CONST:
      # add all const together first
      # a second CONST term means the constants were combined, which counts as a change
      if rem_const != 0: something_changed = True
      rem_const += u.arg
    elif (factor:=u.const_factor())%c == 0:
      # a term with factor 0 contributes nothing to the quotient
      if factor: quotient.append(u.divides(c))
      something_changed = True
    else:
      # divisor is the smallest common divisor of all MULs
      if u.op is UOps.ALU and u.arg is BinaryOps.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor
      remainder.append(u)
      gcd = math.gcd(gcd, factor)

  # handle the const
  if rem_const%c != rem_const:
    something_changed = True
    quotient.append(x.const(rem_const//c))
    rem_const = rem_const%c
  if rem_const != 0: remainder.append(x.const(rem_const))

  # x // c -> quotient + (remainder // div) // (c // div)
  div = gcd if gcd > 1 else divisor

  # no direct change: try a two-step division x // div // (c // div)
  if not something_changed: return newx//(c//div) if 1 < div < c and (newx:=div_folding(x, div)) is not None else None
  rem:Optional[UOp] = functools.reduce(operator.add, remainder) if remainder else None
  quo:Optional[UOp] = functools.reduce(operator.add, quotient) if quotient else None
  if quo is None: return x.const(0) if rem is None else cast(UOp, div_folding(rem, div))//(c//div)
  return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo
|
133
|
+
|
134
|
+
# ***** transcendental *****
|
135
|
+
|
136
|
+
def transcendental_folding(ops):
  """Build a PatternMatcher that rewrites EXP2/LOG2/SIN ALUs into xexp2/xlog2/xsin.

  A rule is produced only for the unary ops NOT contained in ``ops``
  (presumably the set of ops the backend renders natively — confirm with callers).
  Rules match only dtypes in TRANSCENDENTAL_SUPPORTED_DTYPES.
  """
  replacements = ((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin))
  rules = []
  for op, approx in replacements:
    if op in ops: continue
    pattern = UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat(name="d"),), arg=op)
    rules.append((pattern, cast(Callable, approx)))
  return PatternMatcher(rules)
|
139
|
+
|
140
|
+
# ***** threefry *****
|
141
|
+
|
142
|
+
def threefry2x32(x: UOp, seed: UOp):
  """Lower a THREEFRY ALU into uint32 add/xor/rotate uops.

  x is split into its low and high 32-bit words (the matching rule in constant_folder requires a uint64
  result; x is presumably uint64 too). The words are mixed for 5 rounds of 4 steps each, with keys derived
  from the uint32-cast seed, then recombined into a uint64 whose high word is the second lane.
  """
  # split x into two uint32 words
  lo = (x & 0xffffffff).cast(dtypes.uint32)
  hi = ((x // 2**32) & 0xffffffff).cast(dtypes.uint32)

  rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
  key = seed.cast(dtypes.uint32)
  ks = [0x0, key ^ 0x1BD11BDA, key]

  a, b = lo + ks[-1], hi + ks[0]
  for rnd in range(5):
    for r in rotations[rnd % 2]:
      # a' = a + b ; b' = a' ^ rot(b, r), with the rotate built from mul/div by powers of two
      # (relies on uint32 wraparound — presumably how the backend renders it)
      mixed = a + b
      a, b = mixed, mixed ^ ((b * 2**r) + (b // 2**(32 - r)))
    # key injection after every 4 mixing steps
    a, b = (a + ks[rnd % 3]), (b + ks[(rnd + 1) % 3] + rnd + 1)

  return b.cast(dtypes.uint64) * 2**32 | a.cast(dtypes.uint64)
|
154
|
+
|
155
|
+
# ***** main rewriter *****
|
156
|
+
|
157
|
+
def reduce_before_expand(reduce, expand, x):
  """Move a REDUCE of an EXPAND of GEPs of x underneath the EXPAND.

  Returns an EXPAND (same op/dtype/arg as ``expand``) whose lanes are GEPs of a single REDUCE of x.
  Returns None when the REDUCE's own EXPAND sources cover any axis of ``expand``.
  """
  # if the expand is being reduced, you can't push it through
  # NOTE: could do a partial push here in some cases
  reduced_axes = flatten([s.arg for s in reduce.src[1:] if s.op is UOps.EXPAND])
  for axis in expand.arg:
    if axis in reduced_axes: return None
  red = UOp(UOps.REDUCE, x.dtype, (x,)+reduce.src[1:], reduce.arg)
  lanes = tuple(UOp(UOps.GEP, reduce.dtype, (red,), lane) for lane in range(x.dtype.count))
  return UOp(expand.op, expand.dtype, lanes, expand.arg)
|
164
|
+
|
165
|
+
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None, extra=None):
  """Replace an arange-style REDUCE(ADD) of ``where(idx + mval*rng < compval, multconst, 0)`` with a closed form.

  The loop over ``rng`` is dropped from the REDUCE and replaced by
  ``min(loop_end, max((idx-compval-mval)//mval + (loop_end-loop_start), loop_start)) * multconst``
  (presumably the number of iterations where the comparison holds).
  Only negative ``mval`` with ``loop_start.arg == 0`` is handled. ``idx2``/``idx3`` are extra addends
  folded into idx; ``extra`` is a term that is still reduced normally.
  Returns None when disabled via DISABLE_LOOP_COLLAPSE, when rng is not a source of ``reduce``,
  or for unsupported mval/loop_start.
  """
  # must be the right REDUCE
  if getenv("DISABLE_LOOP_COLLAPSE") or rng not in reduce.src: return None
  if mval.arg >= 0 or loop_start.arg != 0:
    # TODO: support and test this with other mvals and loop_starts
    if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mval:{mval.arg} loop_start:{loop_start.arg}")
    return None
  for addend in (idx2, idx3):
    if addend is not None: idx = idx + addend
  comprange = UOp.min(loop_end, UOp.max((idx-compval-mval)//mval + (loop_end-loop_start), loop_start))
  new_reduce_op = comprange.cast(multconst.dtype) * multconst
  remaining_ranges = tuple(s for s in reduce.src[1:] if s is not rng)
  ret = UOp(UOps.REDUCE, reduce.dtype, (new_reduce_op,) + remaining_ranges, reduce.arg)
  if extra is not None:
    ret = ret + UOp(UOps.REDUCE, reduce.dtype, (extra,) + reduce.src[1:], reduce.arg)
  return ret
|
178
|
+
|
179
|
+
def index_collapse(idx,rng,buf,add,mul,ld,reduce):
  """Collapse a one-hot gather ``sum over rng of (idx == rng) * LOAD(buf, add + mul*rng)``.

  The result is a REDUCE of a single ``LOAD(buf, add + mul*idx)``, with rng removed from the REDUCE sources.
  NOTE(review): no bounds check on idx against rng's range is done here.
  Returns None when rng is not a source of ``reduce``.
  """
  if rng not in reduce.src: return None
  direct_load = UOp(ld.op, ld.dtype, (buf, add+mul*idx))
  other_ranges = tuple(s for s in reduce.src[1:] if s is not rng)
  return UOp(reduce.op, reduce.dtype, (direct_load,) + other_ranges, reduce.arg)
|
183
|
+
|
184
|
+
# this is symbolic 2.0
|
185
|
+
# Algebraic simplification rules for the uop graph. Rules are listed in priority order;
# for a given uop, the first matching rule that returns a non-None UOp wins.
constant_folder = PatternMatcher([
  # VECTORIZE/GEP
  (NOp(UOps.GEP, src=(NOp(UOps.VECTORIZE, name="cast"),), name="gep"), lambda gep, cast: cast.src[gep.arg]),
  # VECTORIZE of GEP(x, 0..i-1) in order is just x
  *[(NOp(UOps.VECTORIZE, dtypes.float.vec(i), tuple(NOp(UOps.GEP, dtypes.float,
                         src=(NOp.var('x', dtype=dtypes.float.vec(i)),), arg=j) for j in range(i))), lambda x: x) for i in [2, 4, 8, 16]],
  *[(NOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(NOp(UOps.GEP, dtypes.half,
                         src=(NOp.var('x', dtype=dtypes.half.vec(i)),), arg=j) for j in range(i))), lambda x: x) for i in [2, 4, 8, 16]],
  # tensor core with a 0 input is acc
  *[(NOp(UOps.WMMA, src=(NOp(UOps.VECTORIZE, src=tuple(NOp.const(None, 0.0) for _ in range(i))), NOp.var(), NOp.var('acc'))),
     lambda acc: acc) for i in [2, 4, 8]],
  *[(NOp(UOps.WMMA, src=(NOp.var(), NOp(UOps.VECTORIZE, src=tuple(NOp.const(None, 0.0) for _ in range(i))), NOp.var('acc'))),
     lambda acc: acc) for i in [2, 4, 8]],
  # tensor core cleanups
  *[(NOp(UOps.REDUCE, src=(NOp(UOps.EXPAND, src=tuple(NOp(UOps.GEP, dtypes.float, src=(NOp.var('x'),), arg=i) for i in range(j)), name="expand"),)
    ,name="reduce", allow_any_len=True), reduce_before_expand) for j in [2,4,8]],
  # add + WMMA -> WMMA with the add folded into its accumulator
  (NOp.var("add") + NOp(UOps.WMMA, name="wmma"),
    lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
  # threefry
  (NOp(UOps.ALU, dtype=dtypes.uint64, src=(NOp.var("x"), NOp.var("seed")), arg=BinaryOps.THREEFRY), threefry2x32),
  # extra arange loop folding because we don't fold adds. TODO: fold adds
  (NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng") +
                          NOp.var("idx2") + NOp.var("idx3"))
   .lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
  (NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng") +
                          NOp.var("idx2"))
   .lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
  # arange loop folding (reduce)
  (NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng"))
   .lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
  (NOp(UOps.REDUCE, src=((NOp.var("idx") - NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng"))
   .lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True),
   lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
  # arange loop folding (unrolled)
  (NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng"))
   .lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)) + NOp.var("extra"),),
   arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
  # indexing (with a multiply offset)!
  (NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).cast()*
    NOp(UOps.LOAD, src=(NOp.var("buf"), NOp.var('add')+NOp.var('mul')*NOp(UOps.RANGE, name="rng")), name="ld"),),
    arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
  (NOp(UOps.REDUCE, src=(NOp.var('idx').ne(NOp(UOps.RANGE, name="rng")).__neg__().cast()*
    NOp(UOps.LOAD, src=(NOp.var("buf"), NOp(UOps.RANGE, name="rng")), name="ld"),),
    arg=BinaryOps.ADD, name="reduce", allow_any_len=True),
    lambda **kwargs: index_collapse(add=UOp.const(dtypes.int, 0), mul=UOp.const(dtypes.int, 1), **kwargs)),
  (NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).where(
    NOp(UOps.LOAD, src=(NOp.var("buf"), NOp.var('add')+NOp.var('mul')*NOp(UOps.RANGE, name="rng")), name="ld"), NOp.const(None, 0.0)),),
    arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
  # other arange folders
  (NOp.cvar("c1") - (NOp.var("x") + NOp.cvar("c2")), lambda c1, c2, x: (c1-c2)-x),  # c1 - (x + c2) -> (c1-c2) - x
  (-(NOp.var("x") * NOp.cvar("c1")), lambda x, c1: x*-c1),
  # max folding
  (NOp.max(NOp.var('x'), NOp.var('y')), lambda x,y: x if x.vmin.arg >= y.vmax.arg else y if x.vmax.arg <= y.vmin.arg else None),
  # const rules
  # GEP on a const is the const (this rule always fires, so no second copy of it is needed below)
  (NOp(UOps.GEP, src=(NOp.cvar("c"),), name="root"), lambda root, c: root.const(c.arg)),
  (UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const(c.arg)),
  # a REDUCE without ranges is a NOOP
  (NOp(UOps.REDUCE, src=(NOp.var('x'),)), lambda x: x),
  # a conditional with the same results either way is a noop, also fold const conditionals
  (NOp.var().where(NOp.var("val"), NOp.var("val")), lambda val: val),
  (NOp.cvar('gate').where(NOp.var('c0'), NOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
  # ** constant folding **
  (UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: root.const(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
  # ** self folding **
  (-(-NOp.var('x')), lambda x: x),    # -(-x) -> x
  (NOp.var('x') + 0, lambda x: x),    # x+0 -> x
  (NOp.var('x') * 1, lambda x: x),    # x*1 -> x
  (NOp.var('x') * -1, lambda x: -x),  # x*-1 -> -x
  (NOp.var('x') // NOp.var('x'), lambda x: x.const(1)), # x//x -> 1
  (NOp.var('x') // 1, lambda x: x),   # x//1 -> x
  (NOp.var('x') // -1, lambda x: -x), # x//-1 -> -x
  (NOp.var('x') / NOp.var('x'), lambda x: x.const(1)), # x/x -> 1
  (NOp.var('x') / NOp.cvar('c'), lambda x,c: x*exec_alu(UnaryOps.RECIP, c.dtype, [c.arg])), # x/c -> x*(1/c)
  # ** zero folding **
  # x*0 -> 0 or 0*x -> 0
  # if x is nan or inf it should render the nan value.
  # NOTE: this can be wrong for loaded NaN
  (NOp.var('x') * 0, lambda x: x.const(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
  # x-x -> 0
  (NOp.var('x') - NOp.var('x'), lambda x: x.const(0)),
  # any ALU whose value range is a single point is that constant
  (UPat(UOps.ALU, name='x'), lambda x: x.const(x.vmin.arg) if x.vmin.arg == x.vmax.arg else None),
  # ** load/store folding **
  (NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.load(NOp.var("buf"), NOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
  # ** two stage add/mul folding **
  ((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x+x.const(exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
  ((NOp.var("x") * NOp.cvar("c1")) * NOp.cvar("c2"), lambda x,c1,c2: x*x.const(exec_alu(BinaryOps.MUL, x.dtype, [c1.arg, c2.arg]))),
  # *** rules from symbolic ***
  # ** lt **
  # c0*x<c1 for positive int c0,c1 -> x < ceil(c1/c0)
  # integer ceil division: exact for any int size, unlike math.ceil(c1/c0) which goes through a float
  ((NOp.cvar('c0')*NOp.var('x')).lt(NOp.cvar('c1')),
   lambda x,c0,c1: x.lt(-(-c1.arg//c0.arg)) if dtypes.is_int(x.dtype) and c0.arg > 0 and c1.arg > 0 else None),
  # mul add lt
  (((NOp.cvar('c0')*NOp.var('x'))+NOp.var('x2')).lt(NOp.cvar('c1')),
   lambda x,x2,c0,c1: x.lt(c1.arg//c0.arg) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax.arg and x2.vmin.arg >= 0 else None),
  # generic lt folding (use div)
  (NOp.var('x').lt(NOp.cvar('c')), lambda x,c: newx.src[0].lt(newx.src[1]) if 0 < c.arg and dtypes.is_int(x.dtype) and \
    not dtypes.is_unsigned(x.dtype) and (newx:=div_folding(x,c.arg)) is not None and newx.op is UOps.ALU and newx.arg is BinaryOps.IDIV else None),
  # ** div **
  # # div folding
  (NOp.var('x') // NOp.cvar('c'), lambda x,c:
   newx if 0 < c.arg and not dtypes.is_unsigned(x.dtype) and (newx:=div_folding(x,c.arg)) is not None else None),
  # mul add div
  (((NOp.cvar('c0')*NOp.var('x'))+NOp.var('x2')) // NOp.cvar('c1'), lambda x,x2,c0,c1:\
   x*(c0.arg//g)//(c1.arg//g) if c0.arg > 0 and c1.arg > 0 and (g:=math.gcd(c0.arg,c1.arg)) > 1 and g > x2.vmax.arg and x2.vmin.arg >= 0 else None),
  # ** mod **
  # apply mod to mod input
  (NOp.var('x') % NOp.cvar('c'), lambda x,c: newx%c if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
  # remove mod
  (NOp.var('x') % NOp.cvar('c'), lambda x,c:\
   x-(x.vmin.arg//c.arg)*c.arg if 0 < c.arg and 0 <= x.vmin.arg and x.vmin.arg//c.arg == x.vmax.arg//c.arg else None),
  # mul mod
  ((NOp.cvar('c0')*NOp.var('x')) % NOp.cvar('c1'), lambda x,c0,c1: (x%(c1.arg//c0.arg))*c0 if c1.arg%c0.arg == 0 else None),
  # mod mod
  ((NOp.var('x') % NOp.cvar('c0')) % NOp.cvar('c1'), lambda x,c0,c1: x % c1 if c0.arg % c1.arg == 0 else None),
  # (x%c)+(x//c)*c = x
  (NOp.var('x')%NOp.cvar('c')+(NOp.var('x')//NOp.cvar('c'))*NOp.cvar('c'), lambda x,c: x),
  # ** combine terms **
  # -(x+y) -> -x + -y
  (-(NOp.var("x") + NOp.var("y")), lambda x,y: (-x)+(-y)),
  # (x+c0)*c1 -> x*c1+c0*c1. only for signed int, float have inf*0=nan issue
  ((NOp.var("x") + NOp.cvar("c0")) * NOp.cvar("c1"), lambda x,c0,c1:
   x*c1+c0.arg*c1.arg if dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None),
  # (x*c0)+(x*c1) -> x*(c0+c1)
  (NOp.var("x") * NOp.cvar("c0") + NOp.var("x") * NOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])),
  # (x*c0)+(y*c0) -> (x+y)*c0
  #((NOp.var("x") * NOp.cvar("c0")) + (NOp.var("y") * NOp.cvar("c0")), lambda x,y,c0: c0*(x+y)),
  # (x*x2)/x2 -> x
  ((NOp.var("x") * NOp.var("x2")) / NOp.var("x2"), lambda x,x2: x),
  # (x//c0)//c1 -> x//(c0*c1)
  ((NOp.var("x") // NOp.cvar("c0")) // NOp.cvar("c1"), lambda x,c0,c1: x//x.const(exec_alu(BinaryOps.MUL, x.dtype, [c0.arg, c1.arg]))),
  # (x/x1)/x2 -> x/(x1*x2)
  ((NOp.var("x") / NOp.var("x2")) / NOp.var("x3"), lambda x,x2,x3: x/(x2*x3)),
  # c0 + x < c1 -> x < c1 - c0
  ((NOp.cvar("c0") + NOp.var("x")).lt(NOp.cvar("c1")), lambda x,c0,c1: UOp.lt(x, x.const(exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),
  # (x+x*c0)-> x*(c0+1)
  (NOp.var("x") + NOp.var("x") * NOp.cvar("c0"), lambda x,c0: x*(c0.arg+1)),
  # x!=0 -> (bool)x
  (NOp.var("x").ne(0), lambda x: x.cast(dtypes.bool)),
  # bool != 1 -> not bool
  (NOp.var("x", dtype=dtypes.bool).ne(1), lambda x: -x),
  # TODO: can do the invert of this (flip alt/load) when we fix double ops
  (NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.var("gate").where(NOp.var("alt"), NOp.load(NOp.var("buf"), NOp.var("idx")))),
   lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
  # VECTORIZE-PHI-GEP -> PHI-VECTORIZE
  (NOp(UOps.VECTORIZE, src=tuple(NOp(UOps.PHI, src=(NOp(UOps.GEP, src=(NOp.var("val"),), arg=i), NOp.var(f"v{i}"))) for i in range(4)), name="root"),
   lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3))))),
  (NOp(UOps.VECTORIZE, src=tuple(NOp(UOps.PHI, src=(NOp(UOps.GEP, src=(NOp.var("val"),), arg=i), NOp.var(f"v{i}"))) for i in range(2)), name="root"),
   lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1))))),
  # cast NOOP (NOTE: it's str to deal with PtrDType)
  (NOp(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
  (NOp(UOps.VECTORIZE, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
  # fold gated LOAD/STORE
  (NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.var("var"), NOp.const(dtypes.bool, True)), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
  (NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.var("var"), NOp.const(dtypes.bool, True), NOp.var("barrier")),
   lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)),
  (NOp.load(NOp.var(), NOp.var(), NOp.var("var"), NOp.const(dtypes.bool, False)), lambda var: var),
  (NOp.load(NOp.var(), NOp.var(), NOp.var("var"), NOp.const(dtypes.bool, False), NOp.var()), lambda var: var),
  (NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.var("val"), NOp.const(dtypes.bool, True)), UOp.store),
  (NOp.store(NOp.var(), NOp.var(), NOp.var(), NOp.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)),
  # remove NOOPs from SINK
  (NOp(UOps.SINK, name="root"),
    lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None),
  # ** move add consts to end (NOTE: this is still happening before constant folding) **
  (UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(UOps.CONST, name='c1'), UPat(name='x'))), lambda c1,x: x+c1 if x.op is not UOps.CONST else None),
  (UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name='x'), UPat(UOps.CONST, name='c1'))), UPat(name='y')]),
    lambda x,c1,y: (x+y)+c1),
])
|
353
|
+
|
354
|
+
# *** uop expander ***
|
355
|
+
|
356
|
+
def _expand_arg_to_idx(args:Tuple[Tuple[int, int], ...], rpk:Dict[int, int]) -> int:
|
357
|
+
idx, mul = 0, 1
|
358
|
+
for axis,m in args[::-1]:
|
359
|
+
idx += rpk[axis] * mul
|
360
|
+
mul *= m
|
361
|
+
return idx
|
362
|
+
|
363
|
+
def _choices_from_args(args:Tuple[Tuple[int, int], ...]) -> List[Dict[int, int]]:
|
364
|
+
return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]
|
365
|
+
|
366
|
+
def do_expand(root:UOp):
  """Push the EXPAND sources of ``root`` outward.

  Returns an EXPAND over the sorted union of the source EXPAND axes. Each of its sources is a copy of
  ``root`` with every EXPAND source replaced by the element chosen for that axis combination.
  For WMMA, axes listed in ``root.arg[-1]`` or ``flatten(root.arg[-2])`` are not expanded here.
  An EXPAND root is merged with its source EXPANDs; IF roots have their conditions OR-ed together.
  Returns None when ``root`` has no EXPAND source.
  """
  expands = [x for x in root.src if x.op is UOps.EXPAND]
  if len(expands) == 0: return None
  expand_args = tuple(sorted(dedup(flatten([x.arg for x in expands]))))
  if root.op is UOps.WMMA:
    # both the reduce and upcast args are not expanded here
    dont_expand_args = tuple(x for x in expand_args if x[0] in root.arg[-1] or x[0] in [y[0] for y in flatten(root.arg[-2])])
    expand_args = tuple(x for x in expand_args if x not in dont_expand_args)
  else:
    dont_expand_args = ()
  new_srcs: List[UOp] = []
  lrpks = _choices_from_args(dont_expand_args)
  for rpk in _choices_from_args(expand_args):
    new_src: List[UOp] = []
    for src in root.src:
      if src.op is UOps.EXPAND:
        lnew_src = tuple(src.src[_expand_arg_to_idx(src.arg, {**rpk, **lrpk})] for lrpk in lrpks)
        # TODO: is this right for UOps.WMMA? when there's more than one, all lnew_src should be the same
        new_src.append(lnew_src[0] if len(lnew_src) == 1 or root.op is UOps.WMMA else UOp(UOps.EXPAND, root.dtype, lnew_src, dont_expand_args))
      else:
        new_src.append(src)
    new_srcs.append(UOp(root.op, root.dtype, tuple(new_src), root.arg))
  if root.op is UOps.EXPAND:
    # merge two expands
    expand_args, old_args = tuple(sorted(root.arg+expand_args)), expand_args
    assert len(expand_args) == (len(old_args) + len(root.arg))
    new_srcs = [new_srcs[_expand_arg_to_idx(old_args, rpk)].src[_expand_arg_to_idx(root.arg, rpk)] for rpk in _choices_from_args(expand_args)]
  if root.op is UOps.IF:
    # merge ifs into an or
    # NOTE(review): functools.reduce raises TypeError if every condition is a CONST (empty sequence)
    conditions = functools.reduce(lambda x,y: x|y, dedup(x.src[0] for x in new_srcs if x.src[0].op is not UOps.CONST))
    # dedup keeps first-seen order; a set's iteration order follows object hashes and made the graph nondeterministic
    barriers = tuple(dedup(x.src[1] for x in new_srcs))
    new_srcs = [UOp(UOps.IF, src=(conditions,)+barriers) for _ in new_srcs]
  assert prod([x[1] for x in expand_args]) == len(new_srcs)
  return UOp(UOps.EXPAND, root.dtype, tuple(new_srcs), expand_args)
|
400
|
+
|
401
|
+
# running id handed to each DEFINE_ACC created by do_reduce
acc_number = 0
def do_reduce(root:UOp):
  """Lower a REDUCE uop.

  Ranges the reduced value depends on (those in ``root.src[0].parents``) become a DEFINE_ACC + PHI
  accumulation, with initial value 0 for ADD and ``dtypes.min`` otherwise. For ADD, each range the value
  does not depend on multiplies the result by ``src[1] - src[0]`` of that range (presumably its trip count,
  given RANGE src is (start, end)). For any other op those ranges are ignored.
  """
  global acc_number
  body, ranges = root.src[0], root.src[1:]
  inner_ranges, outer_ranges = partition(ranges, lambda rng: rng in body.parents)
  result = body
  if inner_ranges:
    assert root.dtype is not None
    identity = 0 if root.arg is BinaryOps.ADD else dtypes.min(root.dtype.scalar())
    acc = UOp(UOps.DEFINE_ACC, root.dtype, (UOp.const(root.dtype, identity),) + tuple(inner_ranges), (acc_number,))
    acc_number += 1
    result = UOp(UOps.PHI, root.dtype, (acc, acc.alu(root.arg, result)))
  # for MAX, we can just ignore the unparented
  if root.arg is BinaryOps.ADD:
    for rng in outer_ranges:
      result = result * (rng.src[1]-rng.src[0]).cast(result.dtype)
  return result
|
416
|
+
|
417
|
+
def do_contract(con:UOp):
  """Lower a CONTRACT uop into VECTORIZEs.

  If its source is not an EXPAND, the source is repeated ``con.dtype.count`` times in one VECTORIZE.
  Otherwise, the axes in ``con.arg`` are gathered into VECTORIZE lanes. Any remaining EXPAND axes stay
  in an outer EXPAND, which is dropped when only one vector results.
  """
  ex = con.src[0]
  assert con.dtype is not None
  # CONTRACT without EXPAND repeats the element VECTORIZED
  if ex.op is not UOps.EXPAND: return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count)
  # CONTRACT may remove several axes from EXPAND
  assert con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
  remaining_args = tuple(arg for arg in ex.arg if arg not in con.arg)
  contracted_choices = _choices_from_args(con.arg)
  vectors = []
  for outer in _choices_from_args(remaining_args):
    lanes = tuple(ex.src[_expand_arg_to_idx(ex.arg, {**outer, **inner})] for inner in contracted_choices)
    vectors.append(UOp(UOps.VECTORIZE, con.dtype, lanes))
  if len(vectors) == 1: return vectors[0]
  return UOp(UOps.EXPAND, con.dtype, tuple(vectors), remaining_args)
|
429
|
+
|
430
|
+
def no_vectorized_alu(alu):
  """Scalarize a vector ALU: one scalar ALU per lane over GEPs of its sources, gathered by a VECTORIZE.

  Returns None when ``alu`` is already scalar (dtype count 1).
  """
  width = alu.dtype.count
  if width == 1: return None
  scalar_alus = []
  for lane in range(width):
    lane_srcs = tuple(UOp(UOps.GEP, s.dtype.scalar(), (s,), lane) for s in alu.src)
    scalar_alus.append(UOp(alu.op, alu.dtype.scalar(), lane_srcs, alu.arg))
  return UOp(UOps.VECTORIZE, alu.dtype, tuple(scalar_alus))
|
435
|
+
|
436
|
+
def create_gate(root:UOp) -> Optional[UOp]:
  """Push the gate of a gated STORE into the LOADs beneath it that end in a BARRIER.

  A STORE with 3 sources has no gate, so nothing is done. Otherwise root.src[3] is the gate. Every
  LOAD in the source graph whose last source is a BARRIER gets that BARRIER replaced by
  IF(gate, BARRIER). Returns None when nothing in the graph changed.
  """
  if len(root.src) == 3: return None
  gate = root.src[3]
  # memoized per call so shared sub-graphs are rebuilt only once
  @functools.lru_cache(None)
  def rewrite(u:UOp) -> UOp:
    if u.op is UOps.LOAD and u.src[-1].op is UOps.BARRIER:
      return UOp(u.op, u.dtype, u.src[:-1]+(UOp(UOps.IF, None, (gate, u.src[-1])),), u.arg)
    new_src = tuple(rewrite(s) for s in u.src)
    return u if new_src == u.src else UOp(u.op, u.dtype, new_src, u.arg)
  out = rewrite(root)
  return None if out is root else out
|
442
|
+
|
443
|
+
def _sink_flatten_expand(root):
  # replace every EXPAND feeding the SINK by its own sources; no rewrite if the source count is unchanged
  flat = tuple(flatten(x.src if x.op is UOps.EXPAND else (x,) for x in root.src))
  if len(flat) == len(root.src): return None
  return UOp(UOps.SINK, root.dtype, flat, root.arg)

def _barrier_of_expand(ex):
  # BARRIERs are not expanded: one BARRIER over all expanded sources, repeated once per EXPAND lane
  barrier = UOp(UOps.BARRIER, None, ex.src)
  return UOp(UOps.EXPAND, None, (barrier,)*len(ex.src), ex.arg)

def _expand_gep_add(ex, x, y):
  # EXPAND of the 8 lane-wise x.gep(i)+y.gep(i) becomes EXPAND of (x+y).gep(i), i.e. a vectorized add
  return UOp(UOps.EXPAND, ex.dtype, tuple((x+y).gep(i) for i in range(8)), ex.arg)

expander = PatternMatcher([
  # gate creation has to be tried before the expansion rules (pattern order matters)
  (NOp(UOps.STORE, name="root"), create_gate),
  # expand these ops across EXPAND sources
  (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE,
         UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF}, name="root"), do_expand),
  (NOp(UOps.CONTRACT, name="con"), do_contract),
  (NOp(UOps.SINK, name="root"), _sink_flatten_expand),
  (NOp(UOps.BARRIER, src=(NOp(UOps.EXPAND, name="ex"),)), _barrier_of_expand),
  # an EXPAND over a single source with no axes is a no-op
  (NOp(UOps.EXPAND, src=(NOp.var('x'),), arg=()), lambda x: x),
  # needed for WMMA (the 8-lane special case should be generalized)
  (NOp(UOps.EXPAND, name="ex", src=tuple(NOp.var('x').gep(i)+NOp.var('y').gep(i) for i in range(8))), _expand_gep_add),
])
|
462
|
+
|
463
|
+
def delete_redundant_gates(root:UOp) -> Optional[UOp]:
  """Remove a STORE's gate (src[3]) when the graph under it is already guarded by the same condition.

  The first IF found in a depth-first walk of the STORE's sources (in source order) is compared by
  identity of its condition (IF.src[0]) against the gate. Returns the ungated 3-source STORE on a
  match, and None otherwise or when the STORE has no gate.
  """
  if len(root.src) == 3: return None
  @functools.lru_cache(None)
  def first_if(x:UOp) -> Optional[UOp]:
    if x.op is UOps.IF: return x
    for s in x.src:
      if (found := first_if(s)) is not None: return found
    return None
  gate = first_if(root)
  if gate is None or gate.src[0] is not root.src[3]: return None
  return UOp(UOps.STORE, root.dtype, root.src[:3], root.arg)
|
470
|
+
|
471
|
+
reducer = PatternMatcher([
  # REDUCE -> DEFINE_ACC / PHI accumulation (see do_reduce)
  (NOp(UOps.REDUCE, name="root"), do_reduce),
  # any ALU/CAST/BITCAST still carrying a vector dtype is split into scalar lanes
  (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST}, name="alu"), no_vectorized_alu),
  # ungate STOREs already guarded by an IF on the same condition (may be unnecessary after expand)
  (NOp(UOps.STORE, name="root"), delete_redundant_gates),
])
|
478
|
+
|
479
|
+
# *** uop graph ***
|
480
|
+
|
481
|
+
def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[UOp, None]], in_degree:Dict[UOp, int]):
  """Depth-first walk from u that fills the maps used by the toposort in UOpGraph.linearize.

  children[x] lists the uops that use x, with one entry per use. in_degree[u] is len(u.src).
  srcs[u] is an insertion-ordered set (a dict with None values) of the RANGE uops with a truthy
  arg[1] that u depends on, directly or transitively. Each uop is visited once. Returns srcs[u].
  """
  if u in children: return srcs[u]
  deps: Dict[UOp, None] = {}
  srcs[u] = deps
  children[u] = []
  for s in u.src:
    deps.update(get_children_dfs(s, children, srcs, in_degree))
    if s.op is UOps.RANGE and s.arg[1]: deps[s] = None
    children[s].append(u)
  in_degree[u] = len(u.src)
  return deps
|
491
|
+
|
492
|
+
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
  """Rebuild the graph under sink bottom-up, applying pm to every rebuilt node.

  Sources are rewritten first. Nodes that are identical after rebuilding (same op, dtype,
  rewritten sources, arg) are shared via a cache. Whenever pm returns a replacement, that
  replacement is rewritten again, so each node is processed until pm yields nothing.
  Returns the rewritten sink.
  """
  canon: Dict[Tuple, UOp] = {}
  memo: Dict[UOp, UOp] = {}
  def rewrite(n:UOp) -> UOp:
    if n in memo: return memo[n]
    key = (n.op, n.dtype, tuple(rewrite(s) for s in n.src), n.arg)
    hit = canon.get(key)
    if hit:
      memo[n] = hit
      return hit
    rebuilt = UOp(*key)
    matched = pm.rewrite(rebuilt)
    result = rewrite(matched) if matched else rebuilt
    canon[key] = memo[n] = result
    return result
  return rewrite(sink)
|
502
|
+
|
503
|
+
class UOpGraph:
  """Holds a SINK-rooted UOp graph and lazily linearizes it into an ordered list of uops.

  Accessing `uops` (or iterating / indexing the graph) runs `linearize()` the first time.
  The resulting list, without the trailing SINK, is cached in `_uops`.
  """
  def __init__(self, sink:Union[UOp, List[UOp]], opts:Optional[Renderer]=None):
    # a list of uops is wrapped into a SINK; a single UOp must already be a SINK
    self.sink: UOp = sink if isinstance(sink, UOp) else UOp(UOps.SINK, None, tuple(sink))
    assert self.sink.op is UOps.SINK, f"sink isn't sink, it's {self.sink.op}"
    # used by linearizer
    self._uops: Optional[List[UOp]] = None
    self.opts = opts
    # transcendental_folding gets {} when TRANSCENDENTAL >= 2 or there is no renderer, else opts.code_for_op.keys()
    # NOTE(review): presumably the ops the renderer emits natively are left alone -- confirm in transcendental_folding
    self.folder = constant_folder + transcendental_folding({} if TRANSCENDENTAL >= 2 or opts is None else opts.code_for_op.keys())

  # pickling rebuilds from (sink, opts) only; any cached linearization is not carried over
  def __reduce__(self): return self.__class__, (self.sink, self.opts)
  def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
  def __getitem__(self, index) -> UOp: return self.uops[index]

  @property
  def uops(self) -> List[UOp]:
    """The linearized uop list; runs linearize() on first access."""
    if self._uops is None: self.linearize()
    return cast(List[UOp], self._uops)

  def graph(self):
    """Draw the linearized uops via tinygrad.engine.graph.graph_uops (imported here, on demand)."""
    from tinygrad.engine.graph import graph_uops
    graph_uops(self.uops)

  def print(self): print_uops(self.uops)

  # number of linearize() calls across all UOpGraph instances; expansion is skipped on the call where it equals DEBUG_EXPAND
  cnt = 0
  def linearize(self, extra_pm:Optional[PatternMatcher]=None, skip_check=False) -> UOpGraph:
    """Rewrite the graph and schedule it into self._uops; returns self.

    Steps:
      1. Reset the DEFINE_ACC counter.
      2. Run the folder.
      3. Retype pyint CONST/ALU/SPECIAL/RANGE uops to int32.
      4. Run expander (plus float4_folding if the renderer supports float4), then expander+reducer.
         This step is skipped on the call where cnt equals DEBUG_EXPAND.
      5. Apply extra_pm if given.
      6. Toposort with a priority heap and insert END ops after each scope.
      7. Run the sanity checks unless skip_check.
      8. Drop the trailing SINK.

    extra_pm: additional PatternMatcher applied after expansion (the code notes it is for PTX only).
    skip_check: skip the sanity assertions below.
    Raises AssertionError from the sanity checks, after printing the uops and (unless CI) graphing them.
    """
    global acc_number
    acc_number = 0

    # NOTE: relinearizering should be okay
    #assert self._uops is None, "already linearized"

    # do graph rewrite
    sink = graph_rewrite(self.sink, self.folder)

    # rewrite pyint to int32
    sink = graph_rewrite(sink, PatternMatcher([(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE}, dtype=dtypes.pyint, name="x"),
      lambda x: UOp(x.op, dtypes.int32, x.src, x.arg))]))

    # expand
    UOpGraph.cnt += 1
    if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0):
      sink = graph_rewrite(sink, self.folder+expander+float4_folding if self.opts is not None and self.opts.supports_float4 else self.folder+expander)
      sink = graph_rewrite(sink, self.folder+expander+reducer)

    # for PTX only
    if extra_pm: sink = graph_rewrite(sink, self.folder+extra_pm)

    # filter nodes that don't link to a sink
    # BFS toposort
    # only uops reachable from sink are recorded; in_degree's insertion order is DFS post-order (sources first)
    children: Dict[UOp, List[UOp]] = {}
    range_srcs: Dict[UOp, Dict[UOp, None]] = {}
    in_degree: Dict[UOp, int] = {}
    get_children_dfs(sink, children, range_srcs, in_degree)

    # all uops reachable from x through children, excluding SINK; uops whose op is `end` are included but not walked past.
    # x itself is included only when include_self is True.
    @functools.lru_cache(None)
    def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
      if x.op is UOps.SINK: return set()
      return set.union({x} if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))

    # scope children impact the toposort and END* insertion
    # scope_children[p]: for each scope-opening uop (op in END_FOR_UOP), the uops inside its scope
    # range_phi[r]: the PHI uops inside the scope of RANGE r
    scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}
    range_phi = {r:[p for p in scope_children[r] if p.op is UOps.PHI] for r in scope_children if r.op is UOps.RANGE}

    # min-heap of (priority, uop): the lowest priority is scheduled first
    queue:List[Tuple[int, UOp]] = []
    def push(u:UOp):
      priority = 0
      # prefer ranges that depend on the least number of independent ranges
      # (base arg[0], plus 10000 per range in a scoped PHI's range_srcs that shares no PHI with u)
      if u.op is UOps.RANGE and u.arg[1]:
        priority += u.arg[0]
        for p in range_phi[u]:
          priority += 10000*len([r for r in range_srcs[p] if not any(i in range_phi[u] for i in range_phi[r])])
      # prefer uops that are loop children
      # (each enclosing RANGE lowers the priority by (arg[0]+1) + 1000*arg[1], so nested uops pop earlier)
      else:
        priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is UOps.RANGE and u in ss])
      heapq.heappush(queue, (priority, u))

    # seed the heap with uops that have no sources
    for u in children:
      if in_degree[u] == 0: push(u)

    # scope_end[opener]: the last uop emitted inside opener's scope (the opener itself until a scope member is emitted)
    scope_end: Dict[UOp, UOp] = {}
    self._uops = []
    while queue:
      p,x = heapq.heappop(queue)
      if DEBUG >= 7: print(f"{p:5d}",x)
      if x in scope_children: scope_end[x] = x
      # DEFINE_ACC is placed before the earliest already-emitted RANGE among its sources instead of at the end
      if x.op is UOps.DEFINE_ACC:
        idx = min([self._uops.index(l) for l in x.src if l.op is UOps.RANGE])
        self._uops.insert(idx, x)
      else: self._uops.append(x)
      # NOTE: this mutates the sets stored in scope_children, which are the objects cached by get_recursive_children
      for u, ss in scope_children.items():
        if x in ss:
          ss.remove(x)
          if len(ss) == 0: scope_end[u] = x
      # release users whose sources have now all been emitted
      for u in children[x]:
        in_degree[u] -= 1
        if in_degree[u] == 0: push(u)

    # end scopes in toposort order
    # the END op for each opener (END_FOR_UOP[op][1]) goes right after the last uop of its scope
    for u, x in scope_end.items(): self._uops.insert(self._uops.index(x)+1, UOp(END_FOR_UOP[u.op][1], None, (u,)))

    # sanity checks (NOTE: these can cause things to be skipped in BEAM)
    if not skip_check:
      bad_ops = dedup([x.op for x in self._uops if x.op in {UOps.EXPAND, UOps.CONTRACT, UOps.REDUCE}])
      try:
        type_verify(self.uops)
        assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
        assert len(bad_ops) == 0, f"bad UOps left in list: {bad_ops}"
        # TODO: this should be enabled, and the valid clause should be removed
        # NOTE: multiple identical stores to DEFINE_LOCAL is okay
        # stores are compared on src[0:2] plus any gate, ignoring src[2] (presumably the stored value)
        assert len(all_stores := [x.src[0:2]+x.src[3:] for x in self._uops if x.op is UOps.STORE and x.src[0].op is not UOps.DEFINE_LOCAL]) \
          == len(dedup(all_stores)), "repeated stores in uops"
      except AssertionError as e:
        self.print()
        if not CI: self.graph()
        raise e

    # strip the SINK
    self._uops = self._uops[:-1]
    return self
|