tinygrad 0.10.1-py3-none-any.whl → 0.10.2-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +35 -37
- tinygrad/codegen/linearize.py +19 -10
- tinygrad/codegen/lowerer.py +31 -8
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +10 -0
- tinygrad/device.py +28 -11
- tinygrad/dtype.py +12 -3
- tinygrad/engine/jit.py +3 -2
- tinygrad/engine/multi.py +0 -1
- tinygrad/engine/realize.py +7 -4
- tinygrad/engine/schedule.py +227 -255
- tinygrad/engine/search.py +20 -27
- tinygrad/gradient.py +3 -0
- tinygrad/helpers.py +7 -4
- tinygrad/nn/state.py +2 -2
- tinygrad/ops.py +64 -329
- tinygrad/renderer/__init__.py +19 -3
- tinygrad/renderer/cstyle.py +39 -18
- tinygrad/renderer/llvmir.py +55 -18
- tinygrad/renderer/ptx.py +6 -2
- tinygrad/renderer/wgsl.py +20 -12
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/{libpciaccess.py → pci.py} +25 -715
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/metal.py +28 -29
- tinygrad/runtime/ops_amd.py +37 -34
- tinygrad/runtime/{ops_clang.py → ops_cpu.py} +4 -2
- tinygrad/runtime/ops_disk.py +1 -1
- tinygrad/runtime/ops_dsp.py +59 -33
- tinygrad/runtime/ops_llvm.py +14 -12
- tinygrad/runtime/ops_metal.py +78 -62
- tinygrad/runtime/ops_nv.py +9 -6
- tinygrad/runtime/ops_python.py +5 -5
- tinygrad/runtime/ops_webgpu.py +200 -38
- tinygrad/runtime/support/am/amdev.py +23 -11
- tinygrad/runtime/support/am/ip.py +10 -10
- tinygrad/runtime/support/elf.py +2 -0
- tinygrad/runtime/support/hcq.py +7 -5
- tinygrad/runtime/support/llvm.py +8 -14
- tinygrad/shape/shapetracker.py +3 -2
- tinygrad/shape/view.py +2 -3
- tinygrad/spec.py +21 -20
- tinygrad/tensor.py +150 -90
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/METADATA +20 -8
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- tinygrad/codegen/rewriter.py +0 -516
- tinygrad-0.10.1.dist-info/RECORD +0 -86
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +0 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
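The most visible consequence of this release for downstream code is the codegen reshuffle reflected in the file list and import hunks below: tinygrad/codegen/rewriter.py is removed, full_graph_rewrite is now imported from tinygrad.codegen.devectorizer, and Opt/OptOps are imported from tinygrad.renderer instead of being defined in tinygrad.codegen.kernel. A minimal, hypothetical compatibility shim is sketched here; the 0.10.2 paths come from the import changes in this diff, while the 0.10.1 fallback for full_graph_rewrite is an assumption (the removed import line is truncated in this view).

# hypothetical shim for code that imported these internal symbols directly
try:
  from tinygrad.renderer import Opt, OptOps                      # 0.10.2: Opt/OptOps moved to tinygrad.renderer
  from tinygrad.codegen.devectorizer import full_graph_rewrite   # 0.10.2: new devectorizer module
except ImportError:
  from tinygrad.codegen.kernel import Opt, OptOps                # 0.10.1: defined in kernel.py (removed below)
  from tinygrad.codegen.rewriter import full_graph_rewrite       # assumed 0.10.1 location; rewriter.py is deleted in 0.10.2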
tinygrad/codegen/devectorizer.py
ADDED
@@ -0,0 +1,247 @@
+from typing import Optional, Any, Callable
+import functools, operator
+from collections import defaultdict
+from tinygrad.dtype import dtypes, ImageDType, PtrDType
+from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve
+from tinygrad.ops import graph_rewrite, GroupOp
+from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, mulacc_unrolled
+from tinygrad.helpers import getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
+from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
+from tinygrad.renderer import Renderer
+
+# ***** float4/image store handling *****
+
+def fold_expanded(ex, buf):
+  new_srcs = dedup(list(ex.src))
+  old_new_srcs = new_srcs[:]
+  is_load, is_image = new_srcs[0].op is Ops.LOAD, isinstance(buf.dtype, ImageDType)
+
+  # TODO: get the device from the buffer somehow
+  # NOTE: this can't be Device.DEFAULT because it opens devices
+  if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): return None
+  lengths = [4] if is_image else ([8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]))
+
+  # first, extract all the relevant offsets
+  offsets_rootsrc: defaultdict[Any, dict] = defaultdict(dict)
+  for i,s in enumerate(new_srcs):
+    idx = s.src[0].src[1]
+    if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue
+    if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
+    elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
+    else: root_src, arg = idx, 0
+    # add gates for gated
+    if len(s.src[0].src) == 3: root_src = (s.src[0].src[2], root_src)
+    assert arg not in offsets_rootsrc[root_src], f"{offsets_rootsrc[root_src][arg]} != {i} with {len(s.src)} sources"
+    offsets_rootsrc[root_src][arg] = i
+
+  # then rewrite everything we can
+  used: set[tuple[UOp, UOp]] = set()
+  for rootsrc, offsets in offsets_rootsrc.items():
+    for o in offsets:
+      for fold_length in lengths:
+        if all((rootsrc,o+i) not in used and o+i in offsets for i in range(fold_length)):
+          load_1 = new_srcs[offsets[o]]
+          new_src = list(load_1.src)
+          oidx = new_src[0].src[1]
+          if oidx.divides(fold_length) is None: continue
+          if is_image:
+            # for images, we rewrite the index. it must evenly divide 4 from the above check
+            new_src[0] = buf.index(
+              UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
+              rootsrc[0] if isinstance(rootsrc, tuple) else None)
+          else:
+            # for non image, we upcast the index pointer
+            new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(size=new_src[0].dtype.size, local=new_src[0].dtype.local))
+          # generate the folded new_srcs
+          if is_load:
+            new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
+            for i in range(fold_length): new_srcs[offsets[o+i]] = new_load.gep(i)
+          else: # vectorize the store
+            new_src[1] = UOp(Ops.VECTORIZE, new_src[1].dtype.vec(fold_length), tuple(new_srcs[offsets[o+i]].src[1] for i in range(fold_length)))
+            for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(Ops.STORE, dtypes.void, tuple(new_src)) if i == 0 else None
+          used.update((rootsrc,o+i) for i in range(fold_length))
+
+  # dedup expand for LOAD
+  if is_load and len(old_new_srcs) != len(ex.src): new_srcs = [new_srcs[old_new_srcs.index(s)] for s in ex.src]
+  # remove Nones for STORE
+  return UOp(ex.op, ex.dtype, tuple(x for x in new_srcs if x is not None), ex.arg) if len(used) else None
+
+def fix_unfoldable_image_load(load:UOp, buf:UOp):
+  if not isinstance(buf.dtype, ImageDType) or (oidx:=load.src[0].src[1]).dtype.count == 2: return None
+  id4 = oidx % 4
+  new_src = list(load.src)
+  # TODO: copied logic from above
+  new_src[0] = load.src[0].src[0].index(
+    UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
+    load.src[0].src[2] if len(load.src[0].src) == 3 else None)
+  vec_load = UOp(Ops.LOAD, load.dtype.vec(4), tuple(new_src))
+  return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), load.const_like(float('nan')))
+
+buf_idx_pat = UPat(Ops.INDEX, src=(UPat.var("buf"),), allow_any_len=True)
+float4_folding = PatternMatcher([
+  (UPat(Ops.VECTORIZE, src=UPat(Ops.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
+  (UPat((Ops.BARRIER, Ops.SINK), src=UPat(Ops.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
+])
+
+# ***** image load valid simplification *****
+
+def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
+  if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.const_like(0)
+  if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)
+
+  # wait for it to be image indexed before running simplification
+  if start_idx.dtype.count != 2: return None
+
+  # can drop valid if idx is out of bound when valid is False
+  drop_stmt = []
+  for stmt in split_uop(valid, Ops.AND):
+    X, is_upper_bound, c = parse_valid(stmt)
+
+    # for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
+    if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, Ops.ADD)):
+      testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, Ops.ADD), idx)
+      testidx = testidx.simplify()
+      if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0:
+        drop_stmt.append(stmt)
+        continue
+
+    # if X <= c, check if it's out of bound when X = c+1
+    # if X >= c, check if it's out of bound when X = c-1
+    test_value = c + 1 if is_upper_bound else c - 1
+    for i,b in zip(idx.src, (buf.dtype.shape[1], buf.dtype.shape[0])):
+      if i.is_increasing():
+        rw = i.substitute({X:X.const_like(test_value)}).simplify()
+        if rw.vmin >= b or rw.vmax < 0:
+          drop_stmt.append(stmt)
+          break
+
+  if not drop_stmt and idx is start_idx: return None
+  new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None
+  return buf.index(idx, new_valid)
+
+# ***** optional patterns *****
+
+powers_of_two = {2**i:i for i in range(64)}
+@functools.lru_cache(None)
+def get_late_rewrite_patterns(ops, force_transcendental=False):
+  pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
+    ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
+  # rewrite SQRT to xpow 0.5
+  if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
+  # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
+  if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
+  # rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
+  if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)]
+  if Ops.SHR in ops:
+    pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) and resolve(x>=0,False) else None)]
+  if Ops.NEG in ops:
+    pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
+    if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
+  if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
+  return PatternMatcher(pat)
+
+
+# *** uop expander ***
+
+# TODO: there's a lot shared with gep_through_wmma here
+def no_vectorized_wmma(wmma:UOp):
+  out_sz = prod(x[1] for x in wmma.arg[6][-1])
+  if wmma.dtype.count == out_sz: return None
+  tsrcs = []
+  for s,sz in zip(wmma.src, wmma.arg[6]):
+    ssz = prod(x[1] for x in sz)
+    tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)])
+  wmmas = [UOp(Ops.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
+  wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
+  return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex))
+
+def no_vectorized_alu(alu):
+  if alu.dtype.vcount == 1: return None
+  alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
+  return UOp(Ops.VECTORIZE, alu.dtype, alus)
+
+def no_vectorized_load_store(ls:UOp):
+  idx = ls.src[0]
+  assert isinstance(idx.dtype, PtrDType)
+  if idx.dtype.v == 1: return None
+  tv = [UOp(ls.op, ls.dtype.scalar(), tuple(j.gep(i) for j in ls.src)) for i in range(idx.dtype.v)]
+  return UOp(Ops.VECTORIZE, ls.dtype, tuple(tv))
+
+def no_vectorized_acc(acc:UOp):
+  if acc.dtype.count == 1: return None
+  alus = tuple(UOp(acc.op, acc.dtype.scalar(),
+    tuple(s.gep(i) if j == 0 else s for j,s in enumerate(acc.src)), acc.arg+(i,)) for i in range(acc.dtype.count))
+  return UOp(Ops.VECTORIZE, acc.dtype, alus)
+
+devectorize = PatternMatcher([
+  # no ALU on vectorized dtypes
+  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu),
+  (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
+  (UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc),
+  (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
+])
+
+devectorize_load_store = PatternMatcher([
+  # TODO: add vectorized support to transcendental
+  (UPat((Ops.INDEX, Ops.EXP2, Ops.LOG2, Ops.SIN), name="alu"), no_vectorized_alu),
+  (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
+])
+
+def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
+  if store_gate not in [gate.src[0] for gate in val.toposort if gate.op is Ops.IF]: return None
+  # remove the gate from the index
+  return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val)
+
+load_store_indexing = PatternMatcher([
+  # late fixup of unfoldable image loads
+  (UPat(Ops.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
+  # simplify valid
+  (UPat(Ops.AND, name="valid"), simplify_valid),
+  # image load valid idx simplification
+  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
+  # delete_redundant_gates (after expand)
+  (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
+    UPat.var("val"))), delete_redundant_gates),
+])
+
+def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:UOp|None=None) -> UOp:
+  # this moves the mask from the indexing to the load/store op for rendering
+  nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx)
+  return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is Ops.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])
+
+pm_render = PatternMatcher([
+  # for rendering, we use explicit VECTORIZE
+  (UPat(Ops.CONST, name='c'),
+    lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
+  (UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
+  (UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
+  (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
+  # move masks of loads/stores
+  (UPat((Ops.LOAD, Ops.STORE), src=(UPat.any(masked_index:=UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat.var("mask"))),
+    masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
+  # gate any stores that aren't gated with ifs
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
+    lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
+])
+
+# *** uop graph ***
+
+def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
+  assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
+  supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
+  extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
+
+  if DEVECTORIZE:
+    # devectorize + load_store_indexing + mulacc_unrolled, mulacc_unrolled must be last because it can break loop_collapse
+    sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing+
+      mulacc_unrolled)
+  else:
+    # new devectorize only for load/store
+    sink = graph_rewrite(sink, sym+devectorize_load_store+mulacc_unrolled)
+
+  # optional pre matcher
+  if opts is not None and opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher)
+
+  # final rules for the renderer (without sym)
+  sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher)
+  return sink
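A quick way to see what the optional patterns in get_late_rewrite_patterns above are allowed to do is to check the underlying integer identities directly. The following standalone snippet (plain Python, no tinygrad imports, illustrative values only) verifies the MOD→AND, MUL→SHL and IDIV→SHR rewrites for powers of two:

# standalone check of the power-of-two identities the late rewrite patterns rely on
powers_of_two = {2**i: i for i in range(64)}
for x in range(0, 1000, 7):
  for c in (2, 4, 8, 16, 64):
    v = powers_of_two[c]              # shift amount for this power of two
    assert x % c == (x & (c - 1))     # the Ops.MOD  -> Ops.AND rewrite
    assert x * c == (x << v)          # the Ops.MUL  -> Ops.SHL rewrite
    assert x // c == (x >> v)         # the Ops.IDIV -> Ops.SHR rewrite (x >= 0 only, hence the resolve(x>=0) guard)
print("power-of-two rewrites hold for the sampled values")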
tinygrad/codegen/expander.py
ADDED
@@ -0,0 +1,121 @@
+# this converts a lowerer program into a vectorized program
+
+import functools, itertools, operator
+from tinygrad.helpers import AMX, dedup, flatten, all_same, prod
+from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, graph_rewrite
+from tinygrad.codegen.symbolic import sym
+
+def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) -> int:
+  idx, mul = 0, 1
+  for axis,m in args[::-1]:
+    idx += rpk[axis] * mul
+    mul *= m
+  return idx
+
+def _choices_from_args(args:tuple[tuple[int, int], ...]) -> list[dict[int, int]]:
+  return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]
+
+@functools.lru_cache(None)
+def _swizzle_args(cargs:tuple[tuple[int, int], ...], eargs:tuple[tuple[int, int], ...], exclude_args:tuple[int, ...]) -> list[int]:
+  return [_expand_arg_to_idx(eargs, {**rpk, **{x:0 for x in exclude_args}} if exclude_args else rpk) for rpk in _choices_from_args(cargs)]
+
+def do_expand(root:UOp):
+  expands = [x for x in root.src if x.op is Ops.UNROLL]
+  if len(expands) == 0: return None
+  # NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
+  exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is Ops.WMMA else ()
+  if all_same(expands_args:=[x.arg for x in expands]) and len(exclude_args) == 0:
+    # if there's only one expand arg, it's okay to use it (optimization)
+    expand_args = expands[0].arg
+  else:
+    # otherwise, we sort them and GEP
+    expand_args = tuple(x for x in sorted(dedup(flatten(expands_args))) if x[0] not in exclude_args)
+  expand_sz = prod([x[1] for x in expand_args])
+  new_srcs = []
+  for i,src in enumerate(root.src):
+    if src.op is Ops.UNROLL:
+      if root.op is Ops.IF and i == 0:
+        # IF means OR on first arg to IF
+        new_srcs.append(functools.reduce(operator.__or__, [src.src[0].gep(i) for i in range(expand_sz)]))
+      elif expand_args == src.arg:
+        # just remove the expand
+        new_srcs.append(src.src[0])
+      else:
+        lst = _swizzle_args(expand_args, src.arg, exclude_args)
+        # if the base dtype is > 1, put those at the end
+        if src.dtype.count > 1: lst = flatten([[i*src.dtype.count+j for j in range(src.dtype.count)] for i in lst])
+        new_srcs.append(src.src[0].gep(tuple(lst)))
+    else:
+      # non-UNROLL input
+      if root.op is Ops.IF:
+        # for the first arg of IF, just pass them through ignoring UNROLLS
+        new_srcs.append(src)
+      elif src.dtype.count > 1:
+        # put any input dtype > 1 grouped together
+        new_srcs.append(UOp(Ops.CAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz))
+      else:
+        # repeat the arg
+        new_srcs.append(src.broadcast(expand_sz))
+
+  new_arg = root.arg
+  if root.op is Ops.GEP:
+    assert root.dtype.count == 1
+    # is this right?
+    new_arg = tuple(range(root.arg[0], new_srcs[0].dtype.count, new_srcs[0].dtype.count // expand_sz))
+  nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
+  return UOp(Ops.UNROLL, root.dtype, (nsrc,), expand_args)
+
+def do_contract(con:UOp):
+  ex = con.src[0]
+  # CONTRACT without UNROLL repeats the element VECTORIZED
+  if ex.op is not Ops.UNROLL: return UOp(Ops.VECTORIZE, con.dtype, con.src*con.dtype.count)
+  # CONTRACT may remove several axes from UNROLL
+  assert con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
+  idxs = []
+  for rpk in _choices_from_args(new_ex_args:=tuple(x for x in ex.arg if x not in con.arg)):
+    idxs += [_expand_arg_to_idx(ex.arg, {**rpk, **lrpk}) for lrpk in _choices_from_args(con.arg)]
+  return UOp(Ops.UNROLL, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)
+
+expander = PatternMatcher([
+  # double expand
+  (UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
+    lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
+  # do expansion
+  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
+    Ops.VECTORIZE, Ops.IF), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
+  (UPat(Ops.CONTRACT, name="con"), do_contract),
+  # vectorize DEFINE_ACC
+  (UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"),
+    lambda acc,v: acc.replace(dtype=v.dtype, src=(acc.src[0].broadcast(v.dtype.count),)+acc.src[1:])),
+  # BARRIERs aren't actually expanded
+  (UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
+    lambda ex: UOp(Ops.UNROLL, src=(UOp(Ops.BARRIER, src=ex.src),)*len(ex.src), arg=ex.arg)),
+  # empty UNROLL is NOOP
+  (UPat(Ops.UNROLL, src=(UPat.var('x'),), arg=()), lambda x: x),
+  # UNROLL GEP (needed for WMMA, generalize this) -> vectorized ALU
+  (UPat(Ops.UNROLL, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
+    lambda ex,x,y: UOp(Ops.UNROLL, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
+])
+
+def create_gate(root:UOp) -> UOp|None:
+  @functools.lru_cache(None)
+  def _gate_srcs(u:UOp, gate:UOp) -> UOp:
+    if u.op is Ops.BARRIER: return u
+    if u.op is Ops.LOAD and u.src[-1].op is Ops.BARRIER:
+      return UOp(u.op, u.dtype, u.src[:-1]+(UOp(Ops.IF, src=(gate, u.src[-1])),), arg=u.arg)
+    return u if (replace_source:=tuple(_gate_srcs(x, gate) for x in u.src)) == u.src else UOp(u.op, u.dtype, replace_source, u.arg)
+  idx = root.src[0]
+  if idx.op is Ops.CAST: idx = idx.src[0]
+  return None if idx.op is not Ops.INDEX or len(idx.src) == 2 or (ret:=_gate_srcs(root, idx.src[2])) is root else ret
+
+migrate_indexing = PatternMatcher([
+  # create gate MUST BE BEFORE expander
+  (UPat(Ops.STORE, name="root"), create_gate),
+])
+
+def expand_rewrite(sink:UOp) -> UOp:
+  # initial symbolic + migrate indexing (remove this)
+  sink = graph_rewrite(sink, sym+migrate_indexing)
+
+  # expand
+  return graph_rewrite(sink, sym+expander)
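The expander's bookkeeping is easiest to see in isolation: _expand_arg_to_idx and _choices_from_args defined above are pure functions, so they can be exercised standalone. The axis/size values in this trace are made up purely for illustration:

import itertools

# copied from expander.py above: args is ((axis, size), ...) describing the UNROLL axes
def _expand_arg_to_idx(args, rpk):
  idx, mul = 0, 1
  for axis, m in args[::-1]:
    idx += rpk[axis] * mul
    mul *= m
  return idx

def _choices_from_args(args):
  return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis, m in args])]

# illustrative values: axis 0 unrolled 2 ways, axis 1 unrolled 3 ways -> 6 vector lanes
args = ((0, 2), (1, 3))
for rpk in _choices_from_args(args):
  print(rpk, "-> lane", _expand_arg_to_idx(args, rpk))
# prints {0: 0, 1: 0} -> lane 0 through {0: 1, 1: 2} -> lane 5; the last axis is the fastest-moving one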
tinygrad/codegen/kernel.py
CHANGED
@@ -3,43 +3,26 @@ import itertools, functools, math
 from dataclasses import dataclass
 from collections import defaultdict
 from typing import Optional, cast, Final, Callable, Sequence
-from enum import Enum, auto
 
 from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, view_left, print_uops
+from tinygrad.ops import PatternMatcher
 from tinygrad.spec import type_verify, shape_spec
 from tinygrad.device import Device
-from tinygrad.renderer import Renderer, TensorCore, ProgramSpec
+from tinygrad.renderer import Renderer, TensorCore, ProgramSpec, Opt, OptOps
 from tinygrad.dtype import ImageDType
 from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap, ContextVar
 from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, USE_TC, AMX, CAPTURE_PROCESS_REPLAY
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import strides_for_shape
 from tinygrad.codegen.linearize import linearize_uop
-from tinygrad.codegen.
+from tinygrad.codegen.devectorizer import full_graph_rewrite
 from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction
 
-class OptOps(Enum):
-  TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
-  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
-  def __lt__(self, x:OptOps): return self.value < x.value
-
 class KernelOptError(Exception): pass
 
 def check(cond:bool, msg:str=""):
   if not cond: raise KernelOptError(msg)
 
-@dataclass(frozen=True, order=True)
-class Opt:
-  op: OptOps
-  axis: Optional[int] = None
-  arg: Optional[int | tuple] = None
-  def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"
-  def real_axis(self, k:Kernel):
-    if self.axis is None: return -1
-    if self.op is OptOps.UNROLL: return k.first_reduce+self.axis
-    if self.op in {OptOps.GROUP, OptOps.GROUPTOP}: return k.first_reduce+k.group_for_reduces+self.axis
-    return self.axis
-
 @dataclass
 class TensorCoreOptions:
   axes: tuple[int, ...] # the location of the original N and M axes if still in the shape
@@ -325,8 +308,8 @@ class Kernel:
       -1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes)
      [0-N]: uses only the n'th tensor core available; useful for search
    tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
-      0: applies to only kernels with a single reduce axis and direct
-      1: allows kernels with multiple reduce axes and also multiplication of
+      0: applies to only kernels with a single reduce axis and direct Ops.LOAD into Ops.MUL
+      1: allows kernels with multiple reduce axes and also multiplication of Ops.CAST'd buffers
      2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
    """
    if tc_select is None: tc_select = TC_SELECT.value
@@ -339,7 +322,7 @@ class Kernel:
     if extra_opts is not None:
       for opt in extra_opts: self.apply_opt(opt)
     else:
-      if
+      if AMX: return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
       # hand-coded TC opts
       for tc_dim in [tc_dim for tc_dim in [1,0] if tc_opts.axes_exist[tc_dim]]: # attempt to upcast M and N
         szs = [sz for sz in [5,4,3,2] if self.full_shape[tc_opts.axes[tc_dim]] % sz == 0]
@@ -351,6 +334,12 @@ class Kernel:
     except KernelOptError:
       return False
 
+  def real_axis(self, opt:Opt):
+    if opt.axis is None: return -1
+    if opt.op is OptOps.UNROLL: return self.first_reduce+opt.axis
+    if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.first_reduce+self.group_for_reduces+opt.axis
+    return opt.axis
+
   def apply_opt(self, opt:Opt, append_opt:bool=True):
     if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}, "not using locals")
 
@@ -365,7 +354,7 @@ class Kernel:
       self.applied_opts.append(opt)
       return
 
-    axis =
+    axis = self.real_axis(opt)
     check(axis < len(self.full_shape), "invalid axis")
 
     if opt.op is OptOps.SWAP: amt = cast(int, opt.arg) # arg is an axis in the SWAPs
@@ -385,6 +374,8 @@ class Kernel:
       check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
 
     if opt.op is OptOps.LOCAL: # cyan
+      # NOTE: LLVM/CPU can use locals too, but they are treated the same as globals (still helpful for L1 cache)
+      # it's disabled for now since it makes BEAM slow for little gain
       check(self.opts.has_local, "target does not support local")
       check(axis < self.global_dims, "local is for globals")
       self.shift_to(axis, amt, insert_before=self.first_reduce)
@@ -409,7 +400,7 @@ class Kernel:
     elif opt.op is OptOps.UPCAST: # yellow
       check(axis < self.first_reduce, "upcast is for non-reduce")
       check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.get_local_axes())), "can't upcast TC locals")
-      check(amt <= 16, "don't upcast more than 16")
+      check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16")
      self.shift_to(axis, amt, insert_before=None)
      self.upcast()
    elif opt.op is OptOps.NOLOCALS:
@@ -425,7 +416,7 @@ class Kernel:
      check(not self.vars, "does not work with symbolic shape")
      check(axis < self.first_upcast, "cannot pad upcasted")
      # ok to pad SUM if all parent ALU ops have f(0) = 0
-      if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r, {}, {}), f"cannot pad {r}")
+      if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r, {}, cache={}), f"cannot pad {r}")
      padded = False
      for i,st in enumerate(self.sts):
        if (s:=st.shape[axis]) == 1: continue # reduced
@@ -512,7 +503,7 @@ class Kernel:
    for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
 
    # potentially do more upcasts of non reduce axes based on a heuristic
-    upcasted_axis = set()
+    upcasted_axis: set[int] = set()
    while resolve(prod(self.sts[0].shape[:self.first_reduce]) >= 1024):
      xb_choices = []
      for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
@@ -582,7 +573,7 @@ class Kernel:
    num = f"n{Kernel.kernel_cnt[function_name]-1}" if Kernel.kernel_cnt[function_name] > 1 else ""
    return name + colored(num, 'BLACK')
 
-  def get_optimized_ast(self) -> UOp:
+  def get_optimized_ast(self, name_override:Optional[str]=None) -> UOp:
    @functools.lru_cache(None)
    def fixup_ast(op:UOp) -> UOp:
      ret = op.replace(src=tuple(fixup_ast(x) for x in op.src))
@@ -592,7 +583,9 @@ class Kernel:
      if op.op is Ops.CONST and any(v.mask is not None for v in unwrap(st_uop.st).views): return op.valid(unwrap(st_uop.st))
      # otherwise we just replace the VIEW source
      return ret.replace(src=(st_uop,)) if len(op.src) == 1 else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
-      if op.op is Ops.SINK:
+      if op.op is Ops.SINK:
+        return ret.replace(arg = KernelInfo(to_function_name(self.name) if name_override is None else name_override,
+                                            self.local_dims, self.upcasted, self.dont_use_locals))
      if op.op is Ops.REDUCE_AXIS:
        reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
 
@@ -662,13 +655,17 @@ class Kernel:
  # **** this is the lowerer ****
 
  @track_rewrites()
-  def linearize(self) -> Kernel:
-
+  def linearize(self, name_override:Optional[str]=None) -> Kernel:
+    # display the AST
+    if getenv("VIZ"): graph_rewrite(self.ast, PatternMatcher([]), name="View Base AST")
+
+    modified_ast = self.get_optimized_ast(name_override)
 
    if DEBUG >= 3:
      print(self.name)
      if getenv("RAWAST"): print(self.ast)
-
+      for i,(buf,st) in enumerate([(buf,st) for buf,st in zip(self.bufs, self.sts) if buf.op not in {Ops.CONST, Ops.VALID}]):
+        print(f"{i:2d}: {str(st.shape):25s} {str(buf.src[0].dtype).replace('dtypes.',''):20s}", st.real_strides())
      print(self.applied_opts)
    # verify AST matches the spec after applying opts
    if __debug__: type_verify(list(modified_ast.toposort))
@@ -680,16 +677,17 @@ class Kernel:
    return self
 
  def to_program(self, name_override:Optional[str]=None) -> ProgramSpec:
-    self.linearize()
-
+    self.linearize(name_override)
+    assert self.uops[0].op is Ops.NAME, "first uop must be name"
+    src = self.opts.render(self.uops)
 
    if CAPTURE_PROCESS_REPLAY:
-      diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts,
+      diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, self.uops[0].arg, ContextVar._cache, src))
 
    # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
    # TODO: these max and min don't work on symbolic, and results are very wrong.
    mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
      for _, group in itertools.groupby([x for x in self.ast.toposort if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
                                        key=lambda x: (x.op, x.src[0].arg)))
-    return ProgramSpec(
-
+    return ProgramSpec(self.name if not name_override else name_override, src, self.opts.device, self.ast, self.uops, self.applied_opts, mem_bytes,
+                       global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
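One behavioral note from the kernel.py hunks above: real_axis used to be a method on Opt taking a Kernel and is now a method on Kernel taking an Opt, with Opt/OptOps themselves imported from tinygrad.renderer. Here is a minimal sketch of the axis remapping; FakeKernel is a stand-in with made-up dimensions, and OptOps/Opt are simplified copies for illustration, not the real tinygrad classes:

from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional

class OptOps(Enum):
  TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto()
  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto()

@dataclass(frozen=True)
class Opt:
  op: OptOps
  axis: Optional[int] = None
  arg: Optional[int] = None

class FakeKernel:
  first_reduce, group_for_reduces = 2, 1  # made-up shape bookkeeping, not real Kernel state
  def real_axis(self, opt:Opt):           # same logic as the Kernel.real_axis added above
    if opt.axis is None: return -1
    if opt.op is OptOps.UNROLL: return self.first_reduce+opt.axis
    if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.first_reduce+self.group_for_reduces+opt.axis
    return opt.axis

k = FakeKernel()
print(k.real_axis(Opt(OptOps.UPCAST, 1)))  # 1: non-reduce axes pass through unchanged
print(k.real_axis(Opt(OptOps.UNROLL, 0)))  # 2: UNROLL axes are offset by first_reduce
print(k.real_axis(Opt(OptOps.GROUP, 0)))   # 3: GROUP axes also skip group_for_reduces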
tinygrad/codegen/linearize.py
CHANGED
@@ -6,7 +6,7 @@ from tinygrad.spec import type_verify
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.helpers import dedup, flatten, partition
 
-DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST, *GroupOp.Block}
+DONT_PLACE_IN_BLOCK = {Ops.NAME, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST, *GroupOp.Block}
 
 def disp(y:UOp) -> str:
   if y.op is Ops.BLOCKSTART: return "w"+disp(y.src[0])
@@ -70,7 +70,8 @@ def append_to_block(ctx:tuple[dict[UOp, tuple[UOp, ...]], dict[UOp, list[UOp]]],
   return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(list(old_blocks.values())+new_srcs)), BasicBlock(x.arg.ctx, tuple(to_append)+x.arg.lst))
 
 make_basic_blocks = PatternMatcher([
-  (UPat(Ops.SINK, name="x"),
+  (UPat(Ops.SINK, name="x"),
+   lambda x: UOp(Ops.BLOCK, src=x.src+((UOp(Ops.NAME, arg=x.arg.name),) if x.arg is not None else ()), arg=BasicBlock((), (x,)))),
   (UPat(Ops.BLOCK, name="x"), append_to_block),
 ])
 
@@ -112,6 +113,17 @@ def block_merge(ctx, x:UOp):
 
 pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),])
 
+def block_finalize(block:UOp):
+  if len(block.src) == 0: return None
+  _uops = sorted(dedup(block.src), key=lambda x: x.tuplize)
+  assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops)
+  _uops += block.arg.lst
+  # strip the SINK
+  assert _uops[-1].op is Ops.SINK, "doesn't end with SINK"
+  return UOp(Ops.BLOCK, arg=BasicBlock((), tuple(_uops[:-1])))
+
+pm_block_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="block"), block_finalize)])
+
 # NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
 def block_reorder(in_block:UOp):
   in_this_block = set(in_block.arg.lst)
@@ -212,14 +224,11 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
   # final rewrite to merge all blocks into one
   sink = graph_rewrite(sink, pm_block_merge, ctx=children)
 
-  # there should just be one block left, with a few parents with 0 srcs
-
-  _uops = sorted(dedup(sink.src), key=lambda x: x.tuplize)
-  assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops)
-  _uops += sink.arg.lst
+  # there should just be one block left, with a few parents with 0 srcs (now done in a rewriter)
+  sink = graph_rewrite(sink, pm_block_finalize)
 
   # sanity checks (NOTE: these can cause things to be skipped in BEAM)
-  if not skip_check: type_verify(
+  if not skip_check: type_verify(sink.arg.lst)
 
-  #
-  return
+  # return the list. TODO: refactor to return the UOp
+  return list(sink.arg.lst)
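The linearize.py change above follows the same refactor pattern as much of this release: an imperative tail (sorting the remaining block parents and stripping the SINK) becomes a rewrite rule (block_finalize / pm_block_finalize) applied via graph_rewrite. A toy, tinygrad-free mock of that pattern, with a made-up block representation of (parents, lst):

# conceptual mock, not tinygrad's API: a rule that fires once and a driver that applies rules to a fixpoint
def rewrite_until_stable(node, rules):
  while True:
    for rule in rules:
      if (new := rule(node)) is not None and new != node:
        node = new
        break
    else:
      return node

def block_finalize(node):
  parents, lst = node
  if not parents: return None                  # nothing left to fold in, rule doesn't fire
  merged = sorted(set(parents)) + list(lst)    # dedup + deterministic order, then the block body
  assert merged[-1] == "SINK", "doesn't end with SINK"
  return ((), tuple(merged[:-1]))              # strip the SINK, parents are now empty

block = (("def_global_1", "def_global_0"), ("load", "alu", "store", "SINK"))
print(rewrite_until_stable(block, [block_finalize]))
# ((), ('def_global_0', 'def_global_1', 'load', 'alu', 'store'))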