tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. tinygrad/codegen/kernel.py +248 -115
  2. tinygrad/codegen/lowerer.py +215 -0
  3. tinygrad/codegen/transcendental.py +310 -0
  4. tinygrad/codegen/uopgraph.py +622 -0
  5. tinygrad/codegen/uops.py +235 -393
  6. tinygrad/device.py +428 -69
  7. tinygrad/dtype.py +18 -4
  8. tinygrad/engine/graph.py +19 -32
  9. tinygrad/engine/jit.py +148 -70
  10. tinygrad/engine/realize.py +127 -51
  11. tinygrad/engine/schedule.py +259 -216
  12. tinygrad/engine/search.py +29 -22
  13. tinygrad/function.py +9 -0
  14. tinygrad/helpers.py +87 -49
  15. tinygrad/lazy.py +34 -35
  16. tinygrad/multi.py +41 -36
  17. tinygrad/nn/__init__.py +39 -22
  18. tinygrad/nn/state.py +3 -3
  19. tinygrad/ops.py +63 -62
  20. tinygrad/renderer/__init__.py +43 -21
  21. tinygrad/renderer/assembly.py +104 -106
  22. tinygrad/renderer/cstyle.py +87 -60
  23. tinygrad/renderer/llvmir.py +21 -30
  24. tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
  25. tinygrad/runtime/autogen/cuda.py +6 -162
  26. tinygrad/runtime/autogen/kfd.py +32 -0
  27. tinygrad/runtime/autogen/libc.py +4260 -0
  28. tinygrad/runtime/autogen/nvrtc.py +579 -0
  29. tinygrad/runtime/graph/clang.py +2 -2
  30. tinygrad/runtime/graph/cuda.py +8 -11
  31. tinygrad/runtime/graph/hcq.py +120 -107
  32. tinygrad/runtime/graph/metal.py +18 -15
  33. tinygrad/runtime/ops_amd.py +197 -305
  34. tinygrad/runtime/ops_clang.py +2 -2
  35. tinygrad/runtime/ops_cuda.py +36 -94
  36. tinygrad/runtime/ops_disk.py +3 -7
  37. tinygrad/runtime/ops_gpu.py +4 -2
  38. tinygrad/runtime/ops_hip.py +70 -0
  39. tinygrad/runtime/ops_metal.py +38 -27
  40. tinygrad/runtime/ops_nv.py +283 -363
  41. tinygrad/runtime/ops_python.py +26 -30
  42. tinygrad/runtime/support/compiler_cuda.py +78 -0
  43. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
  44. tinygrad/runtime/support/elf.py +38 -0
  45. tinygrad/shape/shapetracker.py +5 -14
  46. tinygrad/shape/symbolic.py +4 -8
  47. tinygrad/shape/view.py +34 -22
  48. tinygrad/tensor.py +399 -97
  49. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
  50. tinygrad-0.9.2.dist-info/RECORD +70 -0
  51. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
  52. tinygrad/codegen/linearizer.py +0 -528
  53. tinygrad-0.9.1.dist-info/RECORD +0 -63
  54. /tinygrad/runtime/{driver → support}/__init__.py +0 -0
  55. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
  56. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/codegen/uopgraph.py (new file)
@@ -0,0 +1,622 @@
1
+ from __future__ import annotations
2
+ from typing import Iterator, Optional, Tuple, Dict, List, Set, Union, cast, TYPE_CHECKING, Any, DefaultDict, Callable
3
+ import functools, itertools, heapq, math, operator
4
+ from collections import defaultdict
5
+ from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
6
+ from tinygrad.ops import UnaryOps, BinaryOps, exec_alu
7
+ from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, prod, CI, all_same, partition
8
+ from tinygrad.codegen.uops import UOp, NOp, UOps, UPat, PatternMatcher, END_FOR_UOP, type_verify, print_uops
9
+ from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
10
+ if TYPE_CHECKING: from tinygrad.renderer import Renderer
11
+
12
+ # ***** float4/image store handling *****
13
+
14
+ def fold_expanded(ex, buf):
15
+ if buf.dtype != PtrDType(dtypes.float) and buf.dtype != PtrDType(dtypes.half) and not isinstance(buf.dtype, ImageDType): return None
16
+ new_srcs = dedup(list(ex.src))
17
+ old_new_srcs = new_srcs[:]
18
+ is_load, is_image = new_srcs[0].op is UOps.LOAD, isinstance(buf.dtype, ImageDType)
19
+
20
+ # first, extract all the relevant offsets
21
+ offsets_rootsrc: DefaultDict[Any, dict] = defaultdict(dict)
22
+ for i,s in enumerate(new_srcs):
23
+ if (s.dtype is not None and s.dtype.count != 1) or (is_image and s.src[1].dtype != dtypes.int.vec(3)): continue
24
+ idx = s.src[1] if not is_image else s.src[1].src[2] # only id4 for image
25
+ if idx.arg is BinaryOps.ADD and idx.src[1].op is UOps.CONST: root_src, arg = idx.src[0], idx.src[1].arg
26
+ elif idx.op is UOps.CONST: root_src, arg = "CONST", idx.arg
27
+ else: root_src, arg = idx, 0
28
+ # add idx and idy for image
29
+ if is_image: root_src = (s.src[1].src[0:2], root_src)
30
+ # add gates for gated
31
+ if len(s.src) >= 4: root_src = (s.src[3], root_src)
32
+ assert arg not in offsets_rootsrc[root_src]
33
+ offsets_rootsrc[root_src][arg] = i
34
+
35
+ # then rewrite everything we can
36
+ used = set()
37
+ for rootsrc, offsets in offsets_rootsrc.items():
38
+ for o in offsets:
39
+ for fold_length in [4] if is_image else ([8,4,2] if buf.dtype == PtrDType(dtypes.half) and getenv("ALLOW_HALF8") else [4,2]):
40
+ if all((rootsrc,o+i) not in used and o+i in offsets for i in range(fold_length)):
41
+ load_1 = new_srcs[offsets[o]]
42
+ new_src = list(load_1.src)
43
+ if not is_image and not new_src[1].divides(fold_length): continue
44
+ # for images, we rewrite the index
45
+ if is_image: new_src[1] = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (new_src[1].src[0], new_src[1].src[1]))
46
+ # vectorize the store/loadconst
47
+ if not is_load or len(new_src) >= 4:
48
+ new_src[2] = UOp(UOps.VECTORIZE, new_src[2].dtype.vec(fold_length), tuple(new_srcs[offsets[o+i]].src[2] for i in range(fold_length)))
49
+ # generate the folded new_srcs
50
+ if is_load:
51
+ new_load = UOp(UOps.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
52
+ for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(UOps.GEP, load_1.dtype, (new_load,), i)
53
+ else:
54
+ for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(UOps.STORE, None, tuple(new_src)) if i == 0 else None
55
+ for i in range(fold_length): used.add((rootsrc,o+i))
56
+
57
+ # dedup expand for LOAD
58
+ if is_load and len(old_new_srcs) != len(ex.src): new_srcs = [new_srcs[old_new_srcs.index(s)] for s in ex.src]
59
+ # remove Nones for STORE
60
+ return UOp(ex.op, ex.dtype, tuple(x for x in new_srcs if x is not None), ex.arg) if len(used) else None
61
+
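fold_expanded above scans the loads/stores feeding an EXPAND (or the stores under a BARRIER/SINK), buckets them by base address, and rewrites runs of accesses whose offsets are contiguous into a single float4/float2 (or image) access. A minimal standalone sketch of that greedy grouping over plain integer offsets — illustrative only, not the tinygrad API; group_offsets and its arguments are hypothetical:

```python
def group_offsets(offsets, fold_lengths=(4, 2)):
    # greedily claim runs of contiguous offsets, trying the longest fold first
    used, groups = set(), []
    for o in sorted(offsets):
        for n in fold_lengths:
            if all(o + i in offsets and (o + i) not in used for i in range(n)):
                groups.append(list(range(o, o + n)))
                used.update(range(o, o + n))
                break
    return groups

assert group_offsets({0, 1, 2, 3, 4, 5}) == [[0, 1, 2, 3], [4, 5]]
```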
62
+ def vectorize_reduce(vec:UOp):
63
+ if all_same(vec.src): return None # don't REDUCE the same thing multiple times
64
+ if not all_same([(x.src[1:], x.arg) for x in vec.src]): return None
65
+ return UOp(UOps.REDUCE, vec.dtype, (UOp(UOps.VECTORIZE, vec.dtype, tuple(x.src[0] for x in vec.src)),) + vec.src[0].src[1:], vec.src[0].arg)
66
+
67
+ def vectorize_alu(vec:UOp):
68
+ if not all_same([x.arg for x in vec.src]): return None
69
+ return UOp(vec.src[0].op, vec.dtype, tuple(UOp(UOps.VECTORIZE, cast(DType, vec.src[0].src[i].dtype).vec(cast(DType, vec.dtype).count),
70
+ tuple(x.src[i] for x in vec.src)) for i in range(len(vec.src[0].src))), vec.src[0].arg)
71
+
72
+ float4_folding = PatternMatcher([
73
+ (UPat(UOps.EXPAND, src=UPat(UOps.LOAD, src=(UPat(name="buf"), UPat()), allow_any_len=True), name="ex"), fold_expanded),
74
+ (UPat({UOps.BARRIER, UOps.SINK}, src=UPat(UOps.STORE, src=(UPat(name="buf"), UPat(), UPat()), allow_any_len=True), name="ex"), fold_expanded),
75
+ (UPat(UOps.VECTORIZE, src=UPat(UOps.REDUCE), name="vec"), vectorize_reduce),
76
+ (UPat(UOps.VECTORIZE, src=UPat({UOps.ALU, UOps.CAST, UOps.BITCAST}), name="vec"), vectorize_alu),
77
+ ])
78
+
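The vectorize_reduce rule relies on REDUCE and VECTORIZE commuting when every lane reduces over the same ranges (the all_same check on src[1:]): reducing each lane separately and then packing gives the same result as packing first and reducing lane-wise. A plain-Python picture of that equivalence:

```python
rows = [[1, 2], [3, 4], [5, 6]]                         # three loop iterations, two lanes
per_lane = [sum(r[i] for r in rows) for i in range(2)]  # REDUCE each lane, then VECTORIZE
packed = [sum(col) for col in zip(*rows)]               # VECTORIZE first, then REDUCE lane-wise
assert per_lane == packed == [9, 12]
```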
79
+ # ***** mod *****
80
+
81
+ def _get_add_chain(x:UOp):
82
+ if x.op is UOps.ALU and x.arg is BinaryOps.ADD:
83
+ for s in x.src: yield from _get_add_chain(s)
84
+ else: yield x
85
+
86
+ def mod_folding(x:UOp, c:int) -> Optional[UOp]:
87
+ # simplify x in x % c
88
+ # None means no change
89
+ remainder, something_changed = [], False
90
+ for u in _get_add_chain(x):
91
+ if (factor:=u.const_factor())%c != factor:
92
+ remainder.append(u.divides(factor)*(factor%c))
93
+ something_changed = True
94
+ else: remainder.append(u)
95
+ if not something_changed: return None
96
+ return functools.reduce(operator.add, remainder) if remainder else x.const(0)
97
+
98
+ def div_folding(x:UOp, c:int) -> Optional[UOp]:
99
+ # simplify x // c, None means no change
100
+ # simple cancel div case
101
+ if 0 <= x.vmin.arg and x.vmax.arg < c: return x.const(0)
102
+
103
+ quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1
104
+ for u in _get_add_chain(x):
105
+ if u.op is UOps.CONST:
106
+ # add all const together first
107
+ if rem_const != 0: something_changed = True
108
+ rem_const += u.arg
109
+ elif (factor:=u.const_factor())%c == 0:
110
+ if factor: quotient.append(u.divides(c))
111
+ something_changed = True
112
+ else:
113
+ # divisor is the smallest common divisor of all MULs
114
+ if u.op is UOps.ALU and u.arg is BinaryOps.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor
115
+ remainder.append(u)
116
+ gcd = math.gcd(gcd, factor)
117
+
118
+ # handle the const
119
+ if rem_const%c != rem_const:
120
+ something_changed = True
121
+ quotient.append(x.const(rem_const//c))
122
+ rem_const = rem_const%c
123
+ if rem_const != 0: remainder.append(x.const(rem_const))
124
+
125
+ # x // c -> quotient + (remainder // div) // (c // div)
126
+ div = gcd if gcd > 1 else divisor
127
+
128
+ if not something_changed: return newx//(c//div) if 1 < div < c and (newx:=div_folding(x, div)) is not None else None
129
+ rem:Optional[UOp] = functools.reduce(operator.add, remainder) if remainder else None
130
+ quo:Optional[UOp] = functools.reduce(operator.add, quotient) if quotient else None
131
+ if quo is None: return x.const(0) if rem is None else cast(UOp, div_folding(rem, div))//(c//div)
132
+ return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo
133
+
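Both folds lean on ordinary integer identities: a term carrying a factor of c contributes nothing under % c and contributes exactly its quotient under // c. A quick numeric spot-check of those identities under Python's floor-division semantics (illustrative only):

```python
for a in range(-30, 30):
    for b in range(-30, 30):
        for c in range(1, 6):
            assert (a * c + b) % c == b % c        # the a*c term vanishes under % c
            assert (a * c + b) // c == a + b // c  # and contributes exactly a under // c
```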
134
+ # ***** transcendental *****
135
+
136
+ def transcendental_folding(ops):
137
+ return PatternMatcher([(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat(name="d"),), arg=k), cast(Callable, v))
138
+ for k,v in ((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if k not in ops])
139
+
140
+ # ***** threefry *****
141
+
142
+ def threefry2x32(x: UOp, seed: UOp):
143
+ # split x into two uint32, since x is a uint64
144
+ x0, x1 = (x & 0xffffffff).cast(dtypes.uint32), ((x // 2**32) & 0xffffffff).cast(dtypes.uint32)
145
+
146
+ rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
147
+ ks = [0x0, (seed := seed.cast(dtypes.uint32)) ^ 0x1BD11BDA, seed]
148
+ xr = [x0 + ks[-1], x1 + ks[0]]
149
+ for i in range(5):
150
+ for r in rotations[i % 2]: xr[0], xr[1] = (x0 := xr[0] + xr[1]), x0 ^ ((xr[1] * 2**r) + (xr[1] // 2**(32 - r)))
151
+ xr = [(xr[0] + ks[i % 3]), (xr[1] + ks[(i + 1) % 3] + i + 1)]
152
+
153
+ return xr[1].cast(dtypes.uint64) * 2**32 | xr[0].cast(dtypes.uint64)
154
+
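threefry2x32 splits the uint64 counter into two 32-bit words, runs 20 add-rotate-xor rounds with key injection every 4 rounds, and recombines the words. The same mixing written over plain Python ints, as a reference sketch that assumes 32-bit wraparound on every add and rotate (the names here are hypothetical, not tinygrad API):

```python
def _rotl32(v, r): return ((v << r) | (v >> (32 - r))) & 0xffffffff

def threefry2x32_ref(x: int, seed: int) -> int:
    x0, x1 = x & 0xffffffff, (x >> 32) & 0xffffffff
    rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
    ks = [0x0, (seed & 0xffffffff) ^ 0x1BD11BDA, seed & 0xffffffff]
    xr = [(x0 + ks[-1]) & 0xffffffff, (x1 + ks[0]) & 0xffffffff]
    for i in range(5):
        for r in rotations[i % 2]:
            s = (xr[0] + xr[1]) & 0xffffffff
            xr[0], xr[1] = s, s ^ _rotl32(xr[1], r)   # xr[1] on the right is the pre-update value
        xr = [(xr[0] + ks[i % 3]) & 0xffffffff, (xr[1] + ks[(i + 1) % 3] + i + 1) & 0xffffffff]
    return (xr[1] << 32) | xr[0]
```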
155
+ # ***** main rewriter *****
156
+
157
+ def reduce_before_expand(reduce, expand, x):
158
+ # if the expand is being reduced, you can't push it through
159
+ # NOTE: could do a partial push here in some cases
160
+ expands = flatten([x.arg for x in reduce.src[1:] if x.op is UOps.EXPAND])
161
+ if any(x in expands for x in expand.arg): return None
162
+ red = UOp(UOps.REDUCE, x.dtype, (x,)+reduce.src[1:], reduce.arg)
163
+ return UOp(expand.op, expand.dtype, tuple(UOp(UOps.GEP, reduce.dtype, (red,), i) for i in range(x.dtype.count)), expand.arg)
164
+
165
+ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None, extra=None):
166
+ if getenv("DISABLE_LOOP_COLLAPSE") or rng not in reduce.src: return None # must be the right REDUCE
167
+ if mval.arg >= 0 or loop_start.arg != 0:
168
+ # TODO: support and test this with other mvals and loop_starts
169
+ if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mval:{mval.arg} loop_start:{loop_start.arg}")
170
+ return None
171
+ if idx2 is not None: idx = idx + idx2
172
+ if idx3 is not None: idx = idx + idx3
173
+ comprange = UOp.min(loop_end, UOp.max((idx-compval-mval)//mval + (loop_end-loop_start), loop_start))
174
+ new_reduce_op = comprange.cast(multconst.dtype) * multconst
175
+ ret = UOp(UOps.REDUCE, reduce.dtype, (new_reduce_op,) + tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)
176
+ if extra is not None: ret = ret + UOp(UOps.REDUCE, reduce.dtype, (extra,) + reduce.src[1:], reduce.arg)
177
+ return ret
178
+
179
+ def index_collapse(idx,rng,buf,add,mul,ld,reduce):
180
+ if rng not in reduce.src: return None
181
+ return UOp(reduce.op, reduce.dtype, (UOp(ld.op, ld.dtype, (buf, add+mul*idx)),)+
182
+ tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)
183
+
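loop_collapse replaces an ADD reduce over a RANGE of a compare-and-select with a clamped trip count times the constant (comprange * multconst), and index_collapse does the analogous thing when the range is only used to pick out one element. A quick numeric check of the closed form for the case loop_collapse actually folds (mval = -1, loop_start = 0); brute and collapsed are hypothetical helper names:

```python
def brute(idx, compval, multconst, loop_end):
    return sum(multconst for rng in range(loop_end) if idx - rng < compval)

def collapsed(idx, compval, multconst, loop_end, mval=-1, loop_start=0):
    comprange = min(loop_end, max((idx - compval - mval) // mval + (loop_end - loop_start), loop_start))
    return comprange * multconst

assert all(brute(i, 5, 3, 10) == collapsed(i, 5, 3, 10) for i in range(-20, 20))
```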
184
+ # this is symbolic 2.0
185
+ constant_folder = PatternMatcher([
186
+ # VECTORIZE/GEP
187
+ (NOp(UOps.GEP, src=(NOp(UOps.VECTORIZE, name="cast"),), name="gep"), lambda gep, cast: cast.src[gep.arg]),
188
+ *[(NOp(UOps.VECTORIZE, dtypes.float.vec(i), tuple(NOp(UOps.GEP, dtypes.float,
189
+ src=(NOp.var('x', dtype=dtypes.float.vec(i)),), arg=j) for j in range(i))), lambda x: x) for i in [2, 4, 8, 16]],
190
+ *[(NOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(NOp(UOps.GEP, dtypes.half,
191
+ src=(NOp.var('x', dtype=dtypes.half.vec(i)),), arg=j) for j in range(i))), lambda x: x) for i in [2, 4, 8, 16]],
192
+ # tensor core with a 0 input is acc
193
+ *[(NOp(UOps.WMMA, src=(NOp(UOps.VECTORIZE, src=tuple(NOp.const(None, 0.0) for _ in range(i))), NOp.var(), NOp.var('acc'))),
194
+ lambda acc: acc) for i in [2, 4, 8]],
195
+ *[(NOp(UOps.WMMA, src=(NOp.var(), NOp(UOps.VECTORIZE, src=tuple(NOp.const(None, 0.0) for _ in range(i))), NOp.var('acc'))),
196
+ lambda acc: acc) for i in [2, 4, 8]],
197
+ # tensor core cleanups
198
+ *[(NOp(UOps.REDUCE, src=(NOp(UOps.EXPAND, src=tuple(NOp(UOps.GEP, dtypes.float, src=(NOp.var('x'),), arg=i) for i in range(j)), name="expand"),)
199
+ ,name="reduce", allow_any_len=True), reduce_before_expand) for j in [2,4,8]],
200
+ (NOp.var("add") + NOp(UOps.WMMA, name="wmma"),
201
+ lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
202
+ # threefry
203
+ (NOp(UOps.ALU, dtype=dtypes.uint64, src=(NOp.var("x"), NOp.var("seed")), arg=BinaryOps.THREEFRY), threefry2x32),
204
+ # extra arange loop folding because we don't fold adds. TODO: fold adds
205
+ (NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng") +
206
+ NOp.var("idx2") + NOp.var("idx3"))
207
+ .lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
208
+ (NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng") +
209
+ NOp.var("idx2"))
210
+ .lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
211
+ # arange loop folding (reduce)
212
+ (NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng"))
213
+ .lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
214
+ (NOp(UOps.REDUCE, src=((NOp.var("idx") - NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng"))
215
+ .lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True),
216
+ lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
217
+ # arange loop folding (unrolled)
218
+ (NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng"))
219
+ .lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)) + NOp.var("extra"),),
220
+ arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
221
+ # indexing (with a multiply offset)!
222
+ (NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).cast()*
223
+ NOp(UOps.LOAD, src=(NOp.var("buf"), NOp.var('add')+NOp.var('mul')*NOp(UOps.RANGE, name="rng")), name="ld"),),
224
+ arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
225
+ (NOp(UOps.REDUCE, src=(NOp.var('idx').ne(NOp(UOps.RANGE, name="rng")).__neg__().cast()*
226
+ NOp(UOps.LOAD, src=(NOp.var("buf"), NOp(UOps.RANGE, name="rng")), name="ld"),),
227
+ arg=BinaryOps.ADD, name="reduce", allow_any_len=True),
228
+ lambda **kwargs: index_collapse(add=UOp.const(dtypes.int, 0), mul=UOp.const(dtypes.int, 1), **kwargs)),
229
+ (NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).where(
230
+ NOp(UOps.LOAD, src=(NOp.var("buf"), NOp.var('add')+NOp.var('mul')*NOp(UOps.RANGE, name="rng")), name="ld"), NOp.const(None, 0.0)),),
231
+ arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
232
+ # other arange folders
233
+ (NOp.cvar("c1") - (NOp.var("x") + NOp.cvar("c2")), lambda c1, c2, x: (c1-c2)-x), # c1 - (x + c2) -> (c1-c2) - x
234
+ (-(NOp.var("x") * NOp.cvar("c1")), lambda x, c1: x*-c1),
235
+ # max folding
236
+ (NOp.max(NOp.var('x'), NOp.var('y')), lambda x,y: x if x.vmin.arg >= y.vmax.arg else y if x.vmax.arg <= y.vmin.arg else None),
237
+ # const rules
238
+ (NOp(UOps.GEP, src=(NOp.cvar("c"),), name="root"), lambda root, c: root.const(c.arg)),
239
+ (UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const(c.arg)),
240
+ # a REDUCE without ranges is a NOOP
241
+ (NOp(UOps.REDUCE, src=(NOp.var('x'),)), lambda x: x),
242
+ # GEP on a const is the const
243
+ (NOp(UOps.GEP, src=(NOp.cvar("x"),), name="root"), lambda root,x: root.const(x.arg)),
244
+ # a conditional with the same results either way is a noop, also fold const conditionals
245
+ (NOp.var().where(NOp.var("val"), NOp.var("val")), lambda val: val),
246
+ (NOp.cvar('gate').where(NOp.var('c0'), NOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
247
+ # ** constant folding **
248
+ (UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: root.const(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
249
+ # ** self folding **
250
+ (-(-NOp.var('x')), lambda x: x), # -(-x) -> x
251
+ (NOp.var('x') + 0, lambda x: x), # x+0 -> x
252
+ (NOp.var('x') * 1, lambda x: x), # x*1 -> x
253
+ (NOp.var('x') * -1, lambda x: -x), # x*-1 -> -x
254
+ (NOp.var('x') // NOp.var('x'), lambda x: x.const(1)), # x//x -> 1
255
+ (NOp.var('x') // 1, lambda x: x), # x//1 -> x
256
+ (NOp.var('x') // -1, lambda x: -x), # x//-1 -> -x
257
+ (NOp.var('x') / NOp.var('x'), lambda x: x.const(1)), # x/x -> 1
258
+ (NOp.var('x') / NOp.cvar('c'), lambda x,c: x*exec_alu(UnaryOps.RECIP, c.dtype, [c.arg])), # x/c -> x*(1/c)
259
+ # ** zero folding **
260
+ # x*0 -> 0 or 0*x -> 0
261
+ # if x is nan or inf it should render the nan value.
262
+ # NOTE: this can be wrong for loaded NaN
263
+ (NOp.var('x') * 0, lambda x: x.const(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
264
+ # x-x -> 0
265
+ (NOp.var('x') - NOp.var('x'), lambda x: x.const(0)),
266
+ (UPat(UOps.ALU, name='x'), lambda x: x.const(x.vmin.arg) if x.vmin.arg == x.vmax.arg else None),
267
+ # ** load/store folding **
268
+ (NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.load(NOp.var("buf"), NOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
269
+ # ** two stage add/mul folding **
270
+ ((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x+x.const(exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
271
+ ((NOp.var("x") * NOp.cvar("c1")) * NOp.cvar("c2"), lambda x,c1,c2: x*x.const(exec_alu(BinaryOps.MUL, x.dtype, [c1.arg, c2.arg]))),
272
+ # *** rules from symbolic ***
273
+ # ** lt **
274
+ # c0*x<c1 for positive int c0,c1
275
+ ((NOp.cvar('c0')*NOp.var('x')).lt(NOp.cvar('c1')),
276
+ lambda x,c0,c1: x.lt(math.ceil(c1.arg/c0.arg)) if dtypes.is_int(x.dtype) and c0.arg > 0 and c1.arg > 0 else None),
277
+ # mul add lt
278
+ (((NOp.cvar('c0')*NOp.var('x'))+NOp.var('x2')).lt(NOp.cvar('c1')),
279
+ lambda x,x2,c0,c1: x.lt(c1.arg//c0.arg) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax.arg and x2.vmin.arg >= 0 else None),
280
+ # generic lt folding (use div)
281
+ (NOp.var('x').lt(NOp.cvar('c')), lambda x,c: newx.src[0].lt(newx.src[1]) if 0 < c.arg and dtypes.is_int(x.dtype) and \
282
+ not dtypes.is_unsigned(x.dtype) and (newx:=div_folding(x,c.arg)) is not None and newx.op is UOps.ALU and newx.arg is BinaryOps.IDIV else None),
283
+ # ** div **
284
+ # div folding
285
+ (NOp.var('x') // NOp.cvar('c'), lambda x,c:
286
+ newx if 0 < c.arg and not dtypes.is_unsigned(x.dtype) and (newx:=div_folding(x,c.arg)) is not None else None),
287
+ # mul add div
288
+ (((NOp.cvar('c0')*NOp.var('x'))+NOp.var('x2')) // NOp.cvar('c1'), lambda x,x2,c0,c1:\
289
+ x*(c0.arg//g)//(c1.arg//g) if c0.arg > 0 and c1.arg > 0 and (g:=math.gcd(c0.arg,c1.arg)) > 1 and g > x2.vmax.arg and x2.vmin.arg >= 0 else None),
290
+ # ** mod **
291
+ # apply mod to mod input
292
+ (NOp.var('x') % NOp.cvar('c'), lambda x,c: newx%c if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
293
+ # remove mod
294
+ (NOp.var('x') % NOp.cvar('c'), lambda x,c:\
295
+ x-(x.vmin.arg//c.arg)*c.arg if 0 < c.arg and 0 <= x.vmin.arg and x.vmin.arg//c.arg == x.vmax.arg//c.arg else None),
296
+ # mul mod
297
+ ((NOp.cvar('c0')*NOp.var('x')) % NOp.cvar('c1'), lambda x,c0,c1: (x%(c1.arg//c0.arg))*c0 if c1.arg%c0.arg == 0 else None),
298
+ # mod mod
299
+ ((NOp.var('x') % NOp.cvar('c0')) % NOp.cvar('c1'), lambda x,c0,c1: x % c1 if c0.arg % c1.arg == 0 else None),
300
+ # (x%c)+(x//c)*c = x
301
+ (NOp.var('x')%NOp.cvar('c')+(NOp.var('x')//NOp.cvar('c'))*NOp.cvar('c'), lambda x,c: x),
302
+ # ** combine terms **
303
+ # -(x+y) -> -x + -y
304
+ (-(NOp.var("x") + NOp.var("y")), lambda x,y: (-x)+(-y)),
305
+ # (x+c0)*c1 -> x*c1+c0*c1. only for signed int, float have inf*0=nan issue
306
+ ((NOp.var("x") + NOp.cvar("c0")) * NOp.cvar("c1"), lambda x,c0,c1:
307
+ x*c1+c0.arg*c1.arg if dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None),
308
+ # (x*c0)+(x*c1) -> x*(c0+c1)
309
+ (NOp.var("x") * NOp.cvar("c0") + NOp.var("x") * NOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])),
310
+ # (x*c0)+(y*c0) -> (x+y)*c0
311
+ #((NOp.var("x") * NOp.cvar("c0")) + (NOp.var("y") * NOp.cvar("c0")), lambda x,y,c0: c0*(x+y)),
312
+ # (x*x2)/x2 -> x
313
+ ((NOp.var("x") * NOp.var("x2")) / NOp.var("x2"), lambda x,x2: x),
314
+ # (x//c0)//c1 -> x//(c0*c1)
315
+ ((NOp.var("x") // NOp.cvar("c0")) // NOp.cvar("c1"), lambda x,c0,c1: x//x.const(exec_alu(BinaryOps.MUL, x.dtype, [c0.arg, c1.arg]))),
316
+ # (x/x1)/x2 -> x/(x1*x2)
317
+ ((NOp.var("x") / NOp.var("x2")) / NOp.var("x3"), lambda x,x2,x3: x/(x2*x3)),
318
+ # c0 + x < c1 -> x < c1 - c0
319
+ ((NOp.cvar("c0") + NOp.var("x")).lt(NOp.cvar("c1")), lambda x,c0,c1: UOp.lt(x, x.const(exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),
320
+ # (x+x*c0)-> x*(c0+1)
321
+ (NOp.var("x") + NOp.var("x") * NOp.cvar("c0"), lambda x,c0: x*(c0.arg+1)),
322
+ # x!=0 -> (bool)x
323
+ (NOp.var("x").ne(0), lambda x: x.cast(dtypes.bool)),
324
+ # bool != 1 -> not bool
325
+ (NOp.var("x", dtype=dtypes.bool).ne(1), lambda x: -x),
326
+ # TODO: can do the invert of this (flip alt/load) when we fix double ops
327
+ (NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.var("gate").where(NOp.var("alt"), NOp.load(NOp.var("buf"), NOp.var("idx")))),
328
+ lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
329
+ # VECTORIZE-PHI-GEP -> PHI-VECTORIZE
330
+ (NOp(UOps.VECTORIZE, src=tuple(NOp(UOps.PHI, src=(NOp(UOps.GEP, src=(NOp.var("val"),), arg=i), NOp.var(f"v{i}"))) for i in range(4)), name="root"),
331
+ lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3))))),
332
+ (NOp(UOps.VECTORIZE, src=tuple(NOp(UOps.PHI, src=(NOp(UOps.GEP, src=(NOp.var("val"),), arg=i), NOp.var(f"v{i}"))) for i in range(2)), name="root"),
333
+ lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1))))),
334
+ # cast NOOP (NOTE: it's str to deal with PtrDType)
335
+ (NOp(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
336
+ (NOp(UOps.VECTORIZE, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
337
+ # fold gated LOAD/STORE
338
+ (NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.var("var"), NOp.const(dtypes.bool, True)), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
339
+ (NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.var("var"), NOp.const(dtypes.bool, True), NOp.var("barrier")),
340
+ lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)),
341
+ (NOp.load(NOp.var(), NOp.var(), NOp.var("var"), NOp.const(dtypes.bool, False)), lambda var: var),
342
+ (NOp.load(NOp.var(), NOp.var(), NOp.var("var"), NOp.const(dtypes.bool, False), NOp.var()), lambda var: var),
343
+ (NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.var("val"), NOp.const(dtypes.bool, True)), UOp.store),
344
+ (NOp.store(NOp.var(), NOp.var(), NOp.var(), NOp.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)),
345
+ # remove NOOPs from SINK
346
+ (NOp(UOps.SINK, name="root"),
347
+ lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None),
348
+ # ** move add consts to end (NOTE: this is still happening before constant folding) **
349
+ (UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(UOps.CONST, name='c1'), UPat(name='x'))), lambda c1,x: x+c1 if x.op is not UOps.CONST else None),
350
+ (UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name='x'), UPat(UOps.CONST, name='c1'))), UPat(name='y')]),
351
+ lambda x,c1,y: (x+y)+c1),
352
+ ])
353
+
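Many of the rules under "rules from symbolic" are plain integer identities written as rewrites. Spot-checking a few of them numerically, with positive divisors as the rules themselves require:

```python
import math

for x in range(-40, 40):
    for c0 in range(1, 7):
        assert (x % c0) + (x // c0) * c0 == x                     # (x%c)+(x//c)*c = x
        for c1 in range(1, 7):
            assert (x // c0) // c1 == x // (c0 * c1)              # (x//c0)//c1 -> x//(c0*c1)
            assert (c0 * x < c1) == (x < math.ceil(c1 / c0))      # c0*x<c1 -> x<ceil(c1/c0)
```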
354
+ # *** uop expander ***
355
+
356
+ def _expand_arg_to_idx(args:Tuple[Tuple[int, int], ...], rpk:Dict[int, int]) -> int:
357
+ idx, mul = 0, 1
358
+ for axis,m in args[::-1]:
359
+ idx += rpk[axis] * mul
360
+ mul *= m
361
+ return idx
362
+
363
+ def _choices_from_args(args:Tuple[Tuple[int, int], ...]) -> List[Dict[int, int]]:
364
+ return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]
365
+
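_expand_arg_to_idx and _choices_from_args are a mixed-radix pair: one enumerates every assignment of the expand axes (last axis varying fastest), the other maps an assignment back to its flat row-major position. A self-contained illustration with hypothetical axes, axis 0 of size 2 and axis 3 of size 3:

```python
import itertools

args = ((0, 2), (3, 3))   # (axis, size) pairs; the axis numbers are arbitrary labels
choices = [dict(x) for x in itertools.product(*[zip(itertools.repeat(a), range(m)) for a, m in args])]

def to_flat_idx(args, rpk):
    idx, mul = 0, 1
    for axis, m in args[::-1]:   # last axis varies fastest
        idx += rpk[axis] * mul
        mul *= m
    return idx

assert [to_flat_idx(args, rpk) for rpk in choices] == list(range(2 * 3))
```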
366
+ def do_expand(root:UOp):
367
+ expands = [x for x in root.src if x.op is UOps.EXPAND]
368
+ if len(expands) == 0: return None
369
+ expand_args = tuple(sorted(dedup(flatten([x.arg for x in expands]))))
370
+ if root.op is UOps.WMMA:
371
+ # neither the reduce nor the upcast args are expanded here
372
+ dont_expand_args = tuple(x for x in expand_args if x[0] in root.arg[-1] or x[0] in [y[0] for y in flatten(root.arg[-2])])
373
+ expand_args = tuple(x for x in expand_args if x not in dont_expand_args)
374
+ else:
375
+ dont_expand_args = ()
376
+ new_srcs: List[UOp] = []
377
+ lrpks = _choices_from_args(dont_expand_args)
378
+ for rpk in _choices_from_args(expand_args):
379
+ new_src: List[UOp] = []
380
+ for src in root.src:
381
+ if src.op is UOps.EXPAND:
382
+ lnew_src = tuple(src.src[_expand_arg_to_idx(src.arg, {**rpk, **lrpk})] for lrpk in lrpks)
383
+ # TODO: is this right for UOps.WMMA? when there's more than one, all lnew_src should be the same
384
+ new_src.append(lnew_src[0] if len(lnew_src) == 1 or root.op is UOps.WMMA else UOp(UOps.EXPAND, root.dtype, lnew_src, dont_expand_args))
385
+ else:
386
+ new_src.append(src)
387
+ new_srcs.append(UOp(root.op, root.dtype, tuple(new_src), root.arg))
388
+ if root.op is UOps.EXPAND:
389
+ # merge two expands
390
+ expand_args, old_args = tuple(sorted(root.arg+expand_args)), expand_args
391
+ assert len(expand_args) == (len(old_args) + len(root.arg))
392
+ new_srcs = [new_srcs[_expand_arg_to_idx(old_args, rpk)].src[_expand_arg_to_idx(root.arg, rpk)] for rpk in _choices_from_args(expand_args)]
393
+ if root.op is UOps.IF:
394
+ # merge ifs into an or
395
+ conditions = functools.reduce(lambda x,y: x|y, dedup(x.src[0] for x in new_srcs if x.src[0].op is not UOps.CONST))
396
+ barriers = tuple(set(x.src[1] for x in new_srcs))
397
+ new_srcs = [UOp(UOps.IF, src=(conditions,)+barriers) for _ in new_srcs]
398
+ assert prod([x[1] for x in expand_args]) == len(new_srcs)
399
+ return UOp(UOps.EXPAND, root.dtype, tuple(new_srcs), expand_args)
400
+
401
+ acc_number = 0
402
+ def do_reduce(root:UOp):
403
+ global acc_number
404
+ reduce_parented, reduce_unparented = partition(root.src[1:], lambda x: x in root.src[0].parents)
405
+ ret = root.src[0]
406
+ if len(reduce_parented):
407
+ assert root.dtype is not None
408
+ const = UOp.const(root.dtype, 0 if root.arg is BinaryOps.ADD else dtypes.min(root.dtype.scalar()))
409
+ acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(reduce_parented), (acc_number,))
410
+ acc_number += 1
411
+ ret = UOp(UOps.PHI, root.dtype, (acc, acc.alu(root.arg, ret)))
412
+ # for MAX, we can just ignore the unparented
413
+ if root.arg is BinaryOps.ADD:
414
+ for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype)
415
+ return ret
416
+
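do_reduce only materializes a DEFINE_ACC/PHI accumulator for ranges the reduced value actually depends on; an ADD reduce over a range the body never reads is just the value times the trip count, and for MAX the unparented range can be ignored entirely. The arithmetic behind that shortcut:

```python
val, loop_start, loop_end = 7, 0, 5
assert sum(val for _ in range(loop_start, loop_end)) == val * (loop_end - loop_start)
assert max(val for _ in range(loop_start, loop_end)) == val
```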
417
+ def do_contract(con:UOp):
418
+ ex = con.src[0]
419
+ assert con.dtype is not None
420
+ # CONTRACT without EXPAND repeats the element VECTORIZED
421
+ if ex.op is not UOps.EXPAND: return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count)
422
+ # CONTRACT may remove several axes from EXPAND
423
+ assert con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
424
+ srcs = []
425
+ for rpk in _choices_from_args(new_ex_args:=tuple(x for x in ex.arg if x not in con.arg)):
426
+ lsrcs = [ex.src[_expand_arg_to_idx(ex.arg, {**rpk, **lrpk})] for lrpk in _choices_from_args(con.arg)]
427
+ srcs.append(UOp(UOps.VECTORIZE, con.dtype, tuple(lsrcs)))
428
+ return srcs[0] if len(srcs) == 1 else UOp(UOps.EXPAND, con.dtype, tuple(srcs), new_ex_args)
429
+
430
+ def no_vectorized_alu(alu):
431
+ if alu.dtype.count == 1: return None
432
+ alus = tuple(UOp(alu.op, alu.dtype.scalar(),
433
+ tuple(UOp(UOps.GEP, s.dtype.scalar(), (s,), i) for s in alu.src), alu.arg) for i in range(alu.dtype.count))
434
+ return UOp(UOps.VECTORIZE, alu.dtype, alus)
435
+
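no_vectorized_alu rewrites an ALU (or CAST/BITCAST) on a vector dtype into one scalar op per lane: GEP extracts each lane, the scalar op runs per lane, and VECTORIZE repacks the results. The lane-wise picture in plain Python:

```python
a, b = [1, 2, 3, 4], [10, 20, 30, 40]              # two length-4 "vectors"
per_lane = [a[i] + b[i] for i in range(4)]         # GEP each lane, scalar ALU per lane
assert per_lane == [x + y for x, y in zip(a, b)]   # same result as the vector ALU, repacked
```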
436
+ def create_gate(root:UOp) -> Optional[UOp]:
437
+ @functools.lru_cache(None)
438
+ def _gate_srcs(u:UOp, gate:UOp) -> UOp:
439
+ if u.op is UOps.LOAD and u.src[-1].op is UOps.BARRIER: return UOp(u.op, u.dtype, u.src[:-1]+(UOp(UOps.IF, None, (gate, u.src[-1])),), u.arg)
440
+ return u if (replace_source:=tuple(_gate_srcs(x, gate) for x in u.src)) == u.src else UOp(u.op, u.dtype, replace_source, u.arg)
441
+ return None if len(root.src) == 3 or (ret:=_gate_srcs(root, root.src[3])) is root else ret
442
+
443
+ expander = PatternMatcher([
444
+ # create gate MUST BE BEFORE expander
445
+ (NOp(UOps.STORE, name="root"), create_gate),
446
+ # do expansion
447
+ (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE,
448
+ UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF}, name="root"), do_expand),
449
+ (NOp(UOps.CONTRACT, name="con"), do_contract),
450
+ # remove EXPANDs from SINK
451
+ (NOp(UOps.SINK, name="root"),
452
+ lambda root: UOp(UOps.SINK, root.dtype, a, root.arg)
453
+ if len(a:=tuple(flatten(x.src if x.op is UOps.EXPAND else (x,) for x in root.src))) != len(root.src) else None),
454
+ # BARRIERs aren't actually expanded
455
+ (NOp(UOps.BARRIER, src=(NOp(UOps.EXPAND, name="ex"),)), lambda ex: UOp(UOps.EXPAND, None, (UOp(UOps.BARRIER, None, ex.src),)*len(ex.src), ex.arg)),
456
+ # empty EXPAND is NOOP
457
+ (NOp(UOps.EXPAND, src=(NOp.var('x'),), arg=()), lambda x: x),
458
+ # EXPAND GEP (needed for WMMA, generalize this) -> vectorized ALU
459
+ (NOp(UOps.EXPAND, name="ex", src=tuple(NOp.var('x').gep(i)+NOp.var('y').gep(i) for i in range(8))),
460
+ lambda ex,x,y: UOp(UOps.EXPAND, ex.dtype, tuple((x+y).gep(i) for i in range(8)), ex.arg)),
461
+ ])
462
+
463
+ def delete_redundant_gates(root:UOp) -> Optional[UOp]:
464
+ @functools.lru_cache(None)
465
+ def find_gate(x:UOp) -> Optional[UOp]:
466
+ if x.op is UOps.IF: return x
467
+ return next((ret for s in x.src if (ret:=find_gate(s)) is not None), None)
468
+ if len(root.src) == 3 or (gate:=find_gate(root)) is None or gate.src[0] is not root.src[3]: return None
469
+ return UOp(UOps.STORE, root.dtype, root.src[:3], root.arg)
470
+
471
+ reducer = PatternMatcher([
472
+ (NOp(UOps.REDUCE, name="root"), do_reduce),
473
+ # no ALU on vectorized dtypes
474
+ (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST}, name="alu"), no_vectorized_alu),
475
+ # delete_redundant_gates (after expand, is this still needed?)
476
+ (NOp(UOps.STORE, name="root"), delete_redundant_gates),
477
+ ])
478
+
479
+ # *** uop graph ***
480
+
481
+ def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[UOp, None]], in_degree:Dict[UOp, int]):
482
+ if u in children: return srcs[u]
483
+ srcs[u] = {}
484
+ children[u] = []
485
+ for x in u.src:
486
+ srcs[u].update(get_children_dfs(x, children, srcs, in_degree))
487
+ if x.op is UOps.RANGE and x.arg[1]: srcs[u][x] = None
488
+ children[x].append(u)
489
+ in_degree[u] = len(u.src)
490
+ return srcs[u]
491
+
492
+ def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
493
+ nodes: Dict[Tuple, UOp] = {}
494
+ replace: Dict[UOp, UOp] = {}
495
+ def __inner_rewrite(n:UOp) -> UOp:
496
+ if n in replace: return replace[n]
497
+ replace_source = (n.op, n.dtype, tuple(__inner_rewrite(y) for y in n.src), n.arg)
498
+ if found := nodes.get(replace_source): replace[n] = found
499
+ else: nodes[replace_source] = replace[n] = found = __inner_rewrite(new_x) if (new_x := pm.rewrite(x:=UOp(*replace_source))) else x
500
+ return found
501
+ return __inner_rewrite(sink)
502
+
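graph_rewrite rebuilds the graph bottom-up, deduplicates structurally identical nodes through the nodes dict, and re-runs the PatternMatcher on whatever a rule produces until nothing more fires. A toy version of the same memoized fixed-point idea on plain tuples (illustrative only, not the tinygrad API):

```python
def rewrite(node, rule, cache=None):
    cache = {} if cache is None else cache
    if node in cache: return cache[node]
    if isinstance(node, tuple):
        node = tuple(rewrite(c, rule, cache) for c in node)            # rewrite children first
    new = rule(node)
    cache[node] = node if new is None else rewrite(new, rule, cache)   # re-run on the rule's output
    return cache[node]

# one rule: ("add", x, 0) -> x
add_zero = lambda n: n[1] if isinstance(n, tuple) and n[0] == "add" and n[2] == 0 else None
assert rewrite(("add", ("add", 3, 0), 0), add_zero) == 3
```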
503
+ class UOpGraph:
504
+ def __init__(self, sink:Union[UOp, List[UOp]], opts:Optional[Renderer]=None):
505
+ self.sink: UOp = sink if isinstance(sink, UOp) else UOp(UOps.SINK, None, tuple(sink))
506
+ assert self.sink.op is UOps.SINK, f"sink isn't sink, it's {self.sink.op}"
507
+ # used by linearizer
508
+ self._uops: Optional[List[UOp]] = None
509
+ self.opts = opts
510
+ self.folder = constant_folder + transcendental_folding({} if TRANSCENDENTAL >= 2 or opts is None else opts.code_for_op.keys())
511
+
512
+ def __reduce__(self): return self.__class__, (self.sink, self.opts)
513
+ def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
514
+ def __getitem__(self, index) -> UOp: return self.uops[index]
515
+
516
+ @property
517
+ def uops(self) -> List[UOp]:
518
+ if self._uops is None: self.linearize()
519
+ return cast(List[UOp], self._uops)
520
+
521
+ def graph(self):
522
+ from tinygrad.engine.graph import graph_uops
523
+ graph_uops(self.uops)
524
+
525
+ def print(self): print_uops(self.uops)
526
+
527
+ cnt = 0
528
+ def linearize(self, extra_pm:Optional[PatternMatcher]=None, skip_check=False) -> UOpGraph:
529
+ global acc_number
530
+ acc_number = 0
531
+
532
+ # NOTE: relinearizing should be okay
533
+ #assert self._uops is None, "already linearized"
534
+
535
+ # do graph rewrite
536
+ sink = graph_rewrite(self.sink, self.folder)
537
+
538
+ # rewrite pyint to int32
539
+ sink = graph_rewrite(sink, PatternMatcher([(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE}, dtype=dtypes.pyint, name="x"),
540
+ lambda x: UOp(x.op, dtypes.int32, x.src, x.arg))]))
541
+
542
+ # expand
543
+ UOpGraph.cnt += 1
544
+ if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0):
545
+ sink = graph_rewrite(sink, self.folder+expander+float4_folding if self.opts is not None and self.opts.supports_float4 else self.folder+expander)
546
+ sink = graph_rewrite(sink, self.folder+expander+reducer)
547
+
548
+ # for PTX only
549
+ if extra_pm: sink = graph_rewrite(sink, self.folder+extra_pm)
550
+
551
+ # filter nodes that don't link to a sink
552
+ # BFS toposort
553
+ children: Dict[UOp, List[UOp]] = {}
554
+ range_srcs: Dict[UOp, Dict[UOp, None]] = {}
555
+ in_degree: Dict[UOp, int] = {}
556
+ get_children_dfs(sink, children, range_srcs, in_degree)
557
+
558
+ @functools.lru_cache(None)
559
+ def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
560
+ if x.op is UOps.SINK: return set()
561
+ return set.union({x} if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))
562
+
563
+ # scope children impact the toposort and END* insertion
564
+ scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}
565
+ range_phi = {r:[p for p in scope_children[r] if p.op is UOps.PHI] for r in scope_children if r.op is UOps.RANGE}
566
+
567
+ queue:List[Tuple[int, UOp]] = []
568
+ def push(u:UOp):
569
+ priority = 0
570
+ # prefer ranges that depend on the least number of independent ranges
571
+ if u.op is UOps.RANGE and u.arg[1]:
572
+ priority += u.arg[0]
573
+ for p in range_phi[u]:
574
+ priority += 10000*len([r for r in range_srcs[p] if not any(i in range_phi[u] for i in range_phi[r])])
575
+ # prefer uops that are loop children
576
+ else:
577
+ priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is UOps.RANGE and u in ss])
578
+ heapq.heappush(queue, (priority, u))
579
+
580
+ for u in children:
581
+ if in_degree[u] == 0: push(u)
582
+
583
+ scope_end: Dict[UOp, UOp] = {}
584
+ self._uops = []
585
+ while queue:
586
+ p,x = heapq.heappop(queue)
587
+ if DEBUG >= 7: print(f"{p:5d}",x)
588
+ if x in scope_children: scope_end[x] = x
589
+ if x.op is UOps.DEFINE_ACC:
590
+ idx = min([self._uops.index(l) for l in x.src if l.op is UOps.RANGE])
591
+ self._uops.insert(idx, x)
592
+ else: self._uops.append(x)
593
+ for u, ss in scope_children.items():
594
+ if x in ss:
595
+ ss.remove(x)
596
+ if len(ss) == 0: scope_end[u] = x
597
+ for u in children[x]:
598
+ in_degree[u] -= 1
599
+ if in_degree[u] == 0: push(u)
600
+
601
+ # end scopes in toposort order
602
+ for u, x in scope_end.items(): self._uops.insert(self._uops.index(x)+1, UOp(END_FOR_UOP[u.op][1], None, (u,)))
603
+
604
+ # sanity checks (NOTE: these can cause things to be skipped in BEAM)
605
+ if not skip_check:
606
+ bad_ops = dedup([x.op for x in self._uops if x.op in {UOps.EXPAND, UOps.CONTRACT, UOps.REDUCE}])
607
+ try:
608
+ type_verify(self.uops)
609
+ assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
610
+ assert len(bad_ops) == 0, f"bad UOps left in list: {bad_ops}"
611
+ # TODO: this should be enabled, and the valid clause should be removed
612
+ # NOTE: multiple identical stores to DEFINE_LOCAL is okay
613
+ assert len(all_stores := [x.src[0:2]+x.src[3:] for x in self._uops if x.op is UOps.STORE and x.src[0].op is not UOps.DEFINE_LOCAL]) \
614
+ == len(dedup(all_stores)), "repeated stores in uops"
615
+ except AssertionError as e:
616
+ self.print()
617
+ if not CI: self.graph()
618
+ raise e
619
+
620
+ # strip the SINK
621
+ self._uops = self._uops[:-1]
622
+ return self
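linearize schedules the rewritten graph with a priority-driven variant of Kahn's algorithm: nodes whose inputs have all been emitted sit in a heap, and the priority nudges RANGEs and their loop bodies into the desired order before the END_* uops are inserted. A minimal sketch of that scheduling pattern on a hypothetical four-node graph:

```python
import heapq

def toposort(children, priority):
    # children: node -> list of child nodes; ready nodes are emitted in priority order
    indeg = {n: 0 for n in children}
    for n, cs in children.items():
        for c in cs: indeg[c] += 1
    heap = [(priority(n), n) for n, d in indeg.items() if d == 0]
    heapq.heapify(heap)
    order = []
    while heap:
        _, n = heapq.heappop(heap)
        order.append(n)
        for c in children[n]:
            indeg[c] -= 1
            if indeg[c] == 0: heapq.heappush(heap, (priority(c), c))
    return order

# "c" is preferred over "b" whenever both are ready
assert toposort({"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": []},
                priority=lambda n: {"b": 1}.get(n, 0)) == ["a", "c", "b", "d"]
```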