tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,247 @@
|
|
1
|
+
from typing import Optional, Any, Callable
|
2
|
+
import functools, operator
|
3
|
+
from collections import defaultdict
|
4
|
+
from tinygrad.dtype import dtypes, ImageDType, PtrDType
|
5
|
+
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve
|
6
|
+
from tinygrad.ops import graph_rewrite, GroupOp
|
7
|
+
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, mulacc_unrolled
|
8
|
+
from tinygrad.helpers import getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
|
9
|
+
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
|
10
|
+
from tinygrad.renderer import Renderer
|
11
|
+
|
12
|
+
# ***** float4/image store handling *****
|
13
|
+
|
14
|
+
def fold_expanded(ex, buf):
  """Fold scalar LOADs/STOREs at consecutive constant offsets into vector (or image texel) accesses.

  `ex` is the VECTORIZE of LOADs or the BARRIER/SINK of STOREs matched in float4_folding, and `buf` is the
  UOp bound as "buf" inside their INDEX. Sources are grouped by root index (paired with the gate when the
  INDEX is gated); within a group, runs of offsets o, o+1, ..., o+n-1 are replaced by a single access of
  width n, trying the widths in `lengths` in order. Returns the rebuilt `ex`, or None when nothing was folded
  or the buffer dtype is not float/half/image.
  """
  new_srcs = dedup(list(ex.src))
  old_new_srcs = new_srcs[:]
  is_load, is_image = new_srcs[0].op is Ops.LOAD, isinstance(buf.dtype, ImageDType)

  # TODO: get the device from the buffer somehow
  # NOTE: this can't be Device.DEFAULT because it opens devices
  if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): return None
  # candidate fold widths, widest first: images always use 4; half may use 8 with ALLOW_HALF8; AMX allows up to 16
  lengths = [4] if is_image else ([8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]))

  # first, extract all the relevant offsets
  # offsets_rootsrc: root key (root index UOp, the string "CONST", or (gate, root)) -> {constant offset: position in new_srcs}
  offsets_rootsrc: defaultdict[Any, dict] = defaultdict(dict)
  for i,s in enumerate(new_srcs):
    idx = s.src[0].src[1]
    # skip non-scalar accesses, and image accesses whose index is already 2D
    if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue
    if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
    elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
    else: root_src, arg = idx, 0
    # add gates for gated
    if len(s.src[0].src) == 3: root_src = (s.src[0].src[2], root_src)
    assert arg not in offsets_rootsrc[root_src], f"{offsets_rootsrc[root_src][arg]} != {i} with {len(s.src)} sources"
    offsets_rootsrc[root_src][arg] = i

  # then rewrite everything we can
  # used: (root key, offset) pairs already folded into some vector access, so no lane is folded twice
  used: set[tuple[Any, int]] = set()
  for rootsrc, offsets in offsets_rootsrc.items():
    for o in offsets:
      for fold_length in lengths:
        if all((rootsrc,o+i) not in used and o+i in offsets for i in range(fold_length)):
          load_1 = new_srcs[offsets[o]]
          new_src = list(load_1.src)
          oidx = new_src[0].src[1]
          # divides() returning None is treated as "the first lane's index is not aligned to this width"
          if oidx.divides(fold_length) is None: continue
          if is_image:
            # for images, we rewrite the index. it must evenly divide 4 from the above check
            new_src[0] = buf.index(
              UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
              rootsrc[0] if isinstance(rootsrc, tuple) else None)
          else:
            # for non image, we upcast the index pointer
            new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(size=new_src[0].dtype.size, local=new_src[0].dtype.local))
          # generate the folded new_srcs
          if is_load:
            # one vector LOAD; each folded position is replaced by a GEP of its lane
            new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
            for i in range(fold_length): new_srcs[offsets[o+i]] = new_load.gep(i)
          else: # vectorize the store
            new_src[1] = UOp(Ops.VECTORIZE, new_src[1].dtype.vec(fold_length), tuple(new_srcs[offsets[o+i]].src[1] for i in range(fold_length)))
            # the first position holds the vector STORE; the others become None and are dropped at the end
            for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(Ops.STORE, dtypes.void, tuple(new_src)) if i == 0 else None
          used.update((rootsrc,o+i) for i in range(fold_length))

  # dedup expand for LOAD
  # map results back onto the original (possibly duplicated) source order of ex
  if is_load and len(old_new_srcs) != len(ex.src): new_srcs = [new_srcs[old_new_srcs.index(s)] for s in ex.src]
  # remove Nones for STORE
  return UOp(ex.op, ex.dtype, tuple(x for x in new_srcs if x is not None), ex.arg) if len(used) else None
|
68
|
+
|
69
|
+
def fix_unfoldable_image_load(load:UOp, buf:UOp):
  """Turn a scalar image LOAD that still has a 1D index into a 4-wide texel load plus a lane select.

  `buf` is the UOp the load_store_indexing pattern binds as the LOAD's first source; its dtype must be an
  ImageDType. The 1D index `oidx` is mapped to the 2D texel coordinate ((oidx//4) % shape[1], oidx // (4*shape[1])).
  The whole texel is loaded, and lane `oidx % 4` is picked with a chain of WHERE ops whose fallthrough value
  is NaN. Returns None when the buffer is not an image or the index is already 2D.
  """
  if not isinstance(buf.dtype, ImageDType): return None
  oidx = load.src[0].src[1]
  if oidx.dtype.count == 2: return None

  index_uop = load.src[0]
  row_width = buf.dtype.shape[1]
  gate = index_uop.src[2] if len(index_uop.src) == 3 else None
  # TODO: copied logic from above
  texel_coord = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % row_width, (oidx // (4*row_width))))
  texel_load = UOp(Ops.LOAD, load.dtype.vec(4), (index_uop.src[0].index(texel_coord, gate),) + load.src[1:])

  lane = oidx % 4
  picked = load.const_like(float('nan'))
  for k in range(4):
    picked = lane.ne(k).where(picked, texel_load.gep(k))
  return picked
|
79
|
+
|
80
|
+
# an INDEX whose first source is bound as "buf" (any further sources, e.g. index and gate, are allowed)
buf_idx_pat = UPat(Ops.INDEX, src=(UPat.var("buf"),), allow_any_len=True)
# fold groups of scalar LOADs (sources of a VECTORIZE) and scalar STOREs (sources of a BARRIER/SINK) into
# vector accesses via fold_expanded
float4_folding = PatternMatcher([
  (UPat(Ops.VECTORIZE, src=UPat(Ops.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
  (UPat((Ops.BARRIER, Ops.SINK), src=UPat(Ops.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
])
|
85
|
+
|
86
|
+
# ***** image load valid simplification *****
|
87
|
+
|
88
|
+
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
  """Simplify a gated INDEX(buf, start_idx, valid) using its validity condition.

  First the index is simplified under the assumption that `valid` holds (uop_given_valid). If that returns None,
  the whole INDEX is replaced by a zero constant (presumably meaning `valid` can never be true — confirm in
  symbolic.py). For non-image buffers only that index simplification is applied. For image buffers with an
  already 2D index, AND-clauses of `valid` are dropped when the index is provably out of bounds whenever the
  clause is false. Returns None if nothing changed.
  """
  if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.const_like(0)
  if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)

  # wait for it to be image indexed before running simplification
  if start_idx.dtype.count != 2: return None

  # can drop valid if idx is out of bound when valid is False
  drop_stmt = []
  for stmt in split_uop(valid, Ops.AND):
    # parse_valid splits a clause into (expression X, whether it is an upper bound, bound constant c)
    X, is_upper_bound, c = parse_valid(stmt)

    # for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
    if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, Ops.ADD)):
      testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, Ops.ADD), idx)
      testidx = testidx.simplify()
      if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0:
        drop_stmt.append(stmt)
        continue

    # if X <= c, check if it's out of bound when X = c+1
    # if X >= c, check if it's out of bound when X = c-1
    test_value = c + 1 if is_upper_bound else c - 1
    # idx.src are the two coordinates, checked against shape[1] and shape[0] respectively
    for i,b in zip(idx.src, (buf.dtype.shape[1], buf.dtype.shape[0])):
      if i.is_increasing():
        rw = i.substitute({X:X.const_like(test_value)}).simplify()
        if rw.vmin >= b or rw.vmax < 0:
          drop_stmt.append(stmt)
          break

  if not drop_stmt and idx is start_idx: return None
  # AND together the surviving clauses; with none left the INDEX becomes ungated (valid=None)
  new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None
  return buf.index(idx, new_valid)
|
121
|
+
|
122
|
+
# ***** optional patterns *****
|
123
|
+
|
124
|
+
# maps 2**k -> k for k in [0, 64); recognises power-of-two constants in the MOD/MUL/IDIV rewrites below
powers_of_two = {2**i:i for i in range(64)}
@functools.lru_cache(None)
def get_late_rewrite_patterns(ops, force_transcendental=False):
  """Build the renderer-specific late rewrite rules for the supported op set `ops`.

  ops: the ops the renderer supports. full_graph_rewrite passes a tuple of code_for_op keys, which keeps the
    argument hashable for the lru_cache.
  force_transcendental: if True, EXP2/LOG2/SIN are always decomposed with xexp2/xlog2/xsin, even when supported.

  Returns a PatternMatcher that:
    * decomposes unsupported EXP2/LOG2/SIN, and unsupported SQRT as xpow(d, 0.5);
    * strength-reduces integer MOD/MUL/IDIV by powers of two to AND/SHL/SHR when those ops are available
      (MOD and IDIV only when the dividend is provably non-negative);
    * forms NEG, SUB and MULACC when the renderer has them.
  """
  pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
           ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
  # rewrite SQRT to xpow 0.5
  if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
  # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
  # With C-style (truncating) MOD the two disagree for negative x (-3 % 4 == -3 but -3 & 3 == 1), so only rewrite
  # when x is provably non-negative. This is the same guard the IDIV -> SHR rewrite below uses.
  if Ops.AND in ops:
    pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"),
             lambda x,c: x & (c.arg-1) if c.arg in powers_of_two and resolve(x>=0,False) else None)]
  # rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
  # (powers_of_two.get(..., 0) yields 0 for c == 1 and for non powers of two, and both are skipped)
  if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)]
  if Ops.SHR in ops:
    pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) and resolve(x>=0,False) else None)]
  if Ops.NEG in ops:
    pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
    # x + NEG(y) -> SUB(x, y)
    if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
  # a*b+c -> MULACC(a, b, c)
  if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
  return PatternMatcher(pat)
|
142
|
+
|
143
|
+
|
144
|
+
# *** uop expander ***
|
145
|
+
|
146
|
+
# TODO: there's a lot shared with gep_through_wmma here
|
147
|
+
def no_vectorized_wmma(wmma:UOp):
  """Split a WMMA whose dtype is wider than one tile's output into several per-tile WMMAs.

  The per-tile output width is the product of the sizes in the last entry of wmma.arg[6]. Each source is cut
  into consecutive chunks whose width is the product of the sizes in its own arg[6] entry. arg[6] presumably
  holds per-operand upcast (axis, size) pairs — TODO confirm. The i-th chunks of all sources form the i-th
  per-tile WMMA. All lanes are gathered back into one VECTORIZE. Returns None if already tile-sized.
  """
  tile_out = prod(x[1] for x in wmma.arg[6][-1])
  if wmma.dtype.count == tile_out: return None

  chunked_srcs = []
  for operand, upcast_axes in zip(wmma.src, wmma.arg[6]):
    width = prod(x[1] for x in upcast_axes)
    chunked_srcs.append([operand.gep(tuple(range(start, start+width))) for start in range(0, operand.dtype.count, width)])

  tile_dtype = wmma.dtype.scalar().vec(tile_out)
  lanes = []
  for operands in zip(*chunked_srcs):
    tile = UOp(Ops.WMMA, tile_dtype, operands, wmma.arg)
    lanes.extend(tile.gep(j) for j in range(tile_out))
  return UOp(Ops.VECTORIZE, wmma.dtype, tuple(lanes))
|
157
|
+
|
158
|
+
def no_vectorized_alu(alu):
  """Scalarize a vectorized op: build one scalar op per lane from the lane's GEPs, then VECTORIZE the results.

  Returns None when alu.dtype.vcount is already 1.
  """
  width = alu.dtype.vcount
  if width == 1: return None
  lane_dtype = alu.dtype.scalar()
  per_lane = []
  for lane in range(width):
    per_lane.append(UOp(alu.op, lane_dtype, tuple(s.gep(lane) for s in alu.src), alu.arg))
  return UOp(Ops.VECTORIZE, alu.dtype, tuple(per_lane))
|
162
|
+
|
163
|
+
def no_vectorized_load_store(ls:UOp):
  """Split a LOAD/STORE through a vector pointer into one scalar LOAD/STORE per lane.

  The lane count is the pointer's vector width (PtrDType.v of ls.src[0]). Every source is GEP'd per lane.
  The scalar results are gathered with VECTORIZE. Returns None for width-1 pointers.
  """
  ptr = ls.src[0]
  assert isinstance(ptr.dtype, PtrDType)
  width = ptr.dtype.v
  if width == 1: return None
  per_lane = [UOp(ls.op, ls.dtype.scalar(), tuple(s.gep(lane) for s in ls.src)) for lane in range(width)]
  return UOp(Ops.VECTORIZE, ls.dtype, tuple(per_lane))
|
169
|
+
|
170
|
+
def no_vectorized_acc(acc:UOp):
  """Split a vectorized DEFINE_ACC into one scalar accumulator per lane.

  Only the first source (presumably the initial value) is GEP'd per lane; the remaining sources are shared
  unchanged. Each lane's arg is the original arg extended with the lane number. Returns None for scalar
  accumulators.
  """
  width = acc.dtype.count
  if width == 1: return None
  head, rest = acc.src[:1], acc.src[1:]
  lanes = []
  for lane in range(width):
    lanes.append(UOp(acc.op, acc.dtype.scalar(), tuple(s.gep(lane) for s in head) + rest, acc.arg+(lane,)))
  return UOp(Ops.VECTORIZE, acc.dtype, tuple(lanes))
|
175
|
+
|
176
|
+
# full devectorization: every vectorized ALU/CAST/BITCAST/ASSIGN/INDEX, WMMA, DEFINE_ACC and LOAD/STORE is split
# into per-lane scalar ops gathered by VECTORIZE
devectorize = PatternMatcher([
  # no ALU on vectorized dtypes
  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu),
  (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
  (UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc),
  (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
])
|
183
|
+
|
184
|
+
# partial devectorization: only INDEX, the transcendental ops EXP2/LOG2/SIN, and LOAD/STORE are split per lane;
# other ALU ops keep their vector dtypes
devectorize_load_store = PatternMatcher([
  # TODO: add vectorized support to transcendental
  (UPat((Ops.INDEX, Ops.EXP2, Ops.LOG2, Ops.SIN), name="alu"), no_vectorized_alu),
  (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
])
|
189
|
+
|
190
|
+
def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
  """Drop the gate from a gated STORE index when the stored value already depends on an IF on that same gate.

  If some IF in val's toposort has `store_gate` as its first source, the STORE is rebuilt with an ungated
  buf.index(idx), re-applying the CAST when the original index was cast. Otherwise returns None.
  """
  gated_by_if = any(u.op is Ops.IF and u.src[0] == store_gate for u in val.toposort)
  if not gated_by_if: return None
  # remove the gate from the index
  ungated = buf.index(idx)
  if cast is not None: ungated = ungated.cast(cast.dtype)
  return UOp.store(ungated, val)
|
194
|
+
|
195
|
+
# index-level cleanups run with the full devectorizer
load_store_indexing = PatternMatcher([
  # late fixup of unfoldable image loads
  # NOTE: "buf" here binds the LOAD's first source (the INDEX itself), not the underlying buffer
  (UPat(Ops.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
  # simplify valid
  (UPat(Ops.AND, name="valid"), simplify_valid),
  # image load valid idx simplification
  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
  # delete_redundant_gates (after expand)
  # matches a STORE whose index is buf.index(idx, store_gate), optionally wrapped in a CAST
  (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
                                  UPat.var("val"))), delete_redundant_gates),
])
|
206
|
+
|
207
|
+
def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:UOp|None=None) -> UOp:
  """Move the mask from a masked INDEX onto the LOAD/STORE itself, so renderers see an explicit gate.

  The new index is the unmasked buf.index(idx), re-cast when the original index was cast. A LOAD becomes
  LOAD(index, 0, mask, *rest); a STORE becomes STORE(index, value, mask, *rest).
  """
  plain_index = buf.index(idx)
  if cast is not None: plain_index = plain_index.cast(cast.dtype)
  if x.op is not Ops.LOAD:
    return UOp.store(plain_index, x.src[1], mask, *x.src[2:])
  return UOp.load(plain_index, x.const_like(0), mask, *x.src[1:], dtype=x.dtype)
|
211
|
+
|
212
|
+
# final render-form rewrites
pm_render = PatternMatcher([
  # for rendering, we use explicit VECTORIZE
  # a CONST with a vector dtype -> VECTORIZE of the scalar CONST repeated per lane
  (UPat(Ops.CONST, name='c'),
   lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
  # a VCONST -> VECTORIZE of one scalar CONST per element of its arg
  (UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
  # a multi-lane GEP -> VECTORIZE of single-lane GEPs
  (UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
  # a single-source VECTORIZE is just its source
  (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
  # move masks of loads/stores
  (UPat((Ops.LOAD, Ops.STORE), src=(UPat.any(masked_index:=UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat.var("mask"))),
                                               masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
  # gate any stores that aren't gated with ifs
  # a void STORE whose third source is a bool -> wrap that bool in an IF
  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
    lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
])
|
226
|
+
|
227
|
+
# *** uop graph ***
|
228
|
+
|
229
|
+
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
  """Run the late codegen rewrites on a kernel SINK and return the rewritten SINK.

  Passes, in order:
    1. sym plus devectorization. With DEVECTORIZE set this is the full devectorize (with float4_folding when
       the renderer supports float4), then load_store_indexing and mulacc_unrolled. Otherwise it is only
       devectorize_load_store and mulacc_unrolled.
    2. The renderer's pre_matcher, if any.
    3. symbolic_simple, the late patterns for the renderer's supported ops (transcendental decomposition is
       forced when TRANSCENDENTAL>=2), pm_render, and the renderer's extra_matcher.
  With opts=None no ops count as supported and no renderer matchers are applied.
  """
  assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
  if opts is not None:
    supported_ops = tuple(opts.code_for_op.keys())
    extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])
  else: supported_ops, extra_matcher = (), PatternMatcher([])

  if not DEVECTORIZE:
    # new devectorize only for load/store
    sink = graph_rewrite(sink, sym+devectorize_load_store+mulacc_unrolled)
  else:
    # devectorize + load_store_indexing + mulacc_unrolled, mulacc_unrolled must be last because it can break loop_collapse
    use_float4 = opts is not None and opts.supports_float4
    sink = graph_rewrite(sink, sym+(devectorize+float4_folding if use_float4 else devectorize)+load_store_indexing+mulacc_unrolled)

  # optional pre matcher
  if opts is not None and opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher)

  # final rules for the renderer (without sym)
  late_patterns = get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)
  return graph_rewrite(sink, symbolic_simple+late_patterns+pm_render+extra_matcher)
|
@@ -0,0 +1,121 @@
|
|
1
|
+
# this converts a lowerer program into a vectorized program
|
2
|
+
|
3
|
+
import functools, itertools, operator
|
4
|
+
from tinygrad.helpers import AMX, dedup, flatten, all_same, prod
|
5
|
+
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, graph_rewrite
|
6
|
+
from tinygrad.codegen.symbolic import sym
|
7
|
+
|
8
|
+
def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) -> int:
|
9
|
+
idx, mul = 0, 1
|
10
|
+
for axis,m in args[::-1]:
|
11
|
+
idx += rpk[axis] * mul
|
12
|
+
mul *= m
|
13
|
+
return idx
|
14
|
+
|
15
|
+
def _choices_from_args(args:tuple[tuple[int, int], ...]) -> list[dict[int, int]]:
|
16
|
+
return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]
|
17
|
+
|
18
|
+
@functools.lru_cache(None)
def _swizzle_args(cargs:tuple[tuple[int, int], ...], eargs:tuple[tuple[int, int], ...], exclude_args:tuple[int, ...]) -> list[int]:
  """For each lane of the combined expansion `cargs`, return the matching lane index in a source expanded over `eargs`.

  Axes listed in exclude_args are pinned to position 0. Results are cached per argument triple.
  """
  pinned = dict.fromkeys(exclude_args, 0)
  return [_expand_arg_to_idx(eargs, {**choice, **pinned}) for choice in _choices_from_args(cargs)]
|
21
|
+
|
22
|
+
def do_expand(root:UOp):
  """Push UNROLL sources through `root`, producing one wide op wrapped in a single UNROLL.

  The combined expansion axes are the shared UNROLL arg when all UNROLL sources agree and nothing is excluded.
  Otherwise they are the sorted, deduplicated union of all UNROLL args, minus axes excluded for WMMA. Each
  source is then swizzled (GEP), concatenated or broadcast to the combined width `expand_sz`. Returns None
  when no source is an UNROLL.
  """
  expands = [x for x in root.src if x.op is Ops.UNROLL]
  if len(expands) == 0: return None
  # NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
  exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is Ops.WMMA else ()
  if all_same(expands_args:=[x.arg for x in expands]) and len(exclude_args) == 0:
    # if there's only one expand arg, it's okay to use it (optimization)
    expand_args = expands[0].arg
  else:
    # otherwise, we sort them and GEP
    expand_args = tuple(x for x in sorted(dedup(flatten(expands_args))) if x[0] not in exclude_args)
  expand_sz = prod([x[1] for x in expand_args])
  new_srcs = []
  for i,src in enumerate(root.src):
    if src.op is Ops.UNROLL:
      if root.op is Ops.IF and i == 0:
        # IF means OR on first arg to IF
        # NOTE: the comprehension's `i` is its own scope and does not clobber the loop's `i`
        new_srcs.append(functools.reduce(operator.__or__, [src.src[0].gep(i) for i in range(expand_sz)]))
      elif expand_args == src.arg:
        # just remove the expand
        new_srcs.append(src.src[0])
      else:
        lst = _swizzle_args(expand_args, src.arg, exclude_args)
        # if the base dtype is > 1, put those at the end
        if src.dtype.count > 1: lst = flatten([[i*src.dtype.count+j for j in range(src.dtype.count)] for i in lst])
        new_srcs.append(src.src[0].gep(tuple(lst)))
    else:
      # non-UNROLL input
      if root.op is Ops.IF:
        # for the first arg of IF, just pass them through ignoring UNROLLS
        new_srcs.append(src)
      elif src.dtype.count > 1:
        # put any input dtype > 1 grouped together
        new_srcs.append(UOp(Ops.CAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz))
      else:
        # repeat the arg
        new_srcs.append(src.broadcast(expand_sz))

  new_arg = root.arg
  if root.op is Ops.GEP:
    assert root.dtype.count == 1
    # is this right?
    new_arg = tuple(range(root.arg[0], new_srcs[0].dtype.count, new_srcs[0].dtype.count // expand_sz))
  nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
  return UOp(Ops.UNROLL, root.dtype, (nsrc,), expand_args)
|
67
|
+
|
68
|
+
def do_contract(con:UOp):
  """Rewrite a CONTRACT by gathering the contracted axes of its UNROLL source into vector lanes.

  The returned UNROLL keeps only the axes not named in con.arg. Its source is a GEP whose index list walks the
  remaining-axis combinations and, within each, the contracted-axis combinations (contracted axes fastest).
  A CONTRACT whose source is not an UNROLL becomes a VECTORIZE repeating that source con.dtype.count times.
  """
  ex = con.src[0]
  # CONTRACT without UNROLL repeats the element VECTORIZED
  if ex.op is not Ops.UNROLL: return UOp(Ops.VECTORIZE, con.dtype, con.src*con.dtype.count)
  # CONTRACT may remove several axes from UNROLL
  assert con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
  idxs = []
  for rpk in _choices_from_args(new_ex_args:=tuple(x for x in ex.arg if x not in con.arg)):
    idxs += [_expand_arg_to_idx(ex.arg, {**rpk, **lrpk}) for lrpk in _choices_from_args(con.arg)]
  return UOp(Ops.UNROLL, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)
|
78
|
+
|
79
|
+
# rules that push UNROLL/CONTRACT through the graph
expander = PatternMatcher([
  # double expand
  # UNROLL(UNROLL(x)) -> one UNROLL of x with the inner axes followed by the outer axes
  (UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
   lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
  # do expansion
  # custom_early_reject presumably lets the matcher skip nodes with no UNROLL source quickly — TODO confirm
  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
         Ops.VECTORIZE, Ops.IF), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
  (UPat(Ops.CONTRACT, name="con"), do_contract),
  # vectorize DEFINE_ACC
  # VECTORIZE of DEFINE_ACCs -> one DEFINE_ACC with the VECTORIZE's dtype and its first source broadcast
  (UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"),
   lambda acc,v: acc.replace(dtype=v.dtype, src=(acc.src[0].broadcast(v.dtype.count),)+acc.src[1:])),
  # BARRIERs aren't actually expanded
  (UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
   lambda ex: UOp(Ops.UNROLL, src=(UOp(Ops.BARRIER, src=ex.src),)*len(ex.src), arg=ex.arg)),
  # empty UNROLL is NOOP
  (UPat(Ops.UNROLL, src=(UPat.var('x'),), arg=()), lambda x: x),
  # UNROLL GEP (needed for WMMA, generalize this) -> vectorized ALU
  # UNROLL of x.gep(i)+y.gep(i) for i in range(8) (256 with AMX) -> UNROLL of lanes of a single vector add
  (UPat(Ops.UNROLL, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
    lambda ex,x,y: UOp(Ops.UNROLL, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
])
|
99
|
+
|
100
|
+
def create_gate(root:UOp) -> UOp|None:
  """Propagate a STORE's index gate onto LOADs that end in a BARRIER.

  If the STORE's index (looking through one CAST) is a gated INDEX, every LOAD reachable from `root` whose last
  source is a BARRIER gets that BARRIER replaced by IF(gate, BARRIER). Traversal does not descend into BARRIER
  nodes. Returns None if the index is not a gated INDEX or nothing changed.
  """
  # memoized per create_gate call so shared subgraphs are rewritten once (the cache is local to this call)
  @functools.lru_cache(None)
  def _gate_srcs(u:UOp, gate:UOp) -> UOp:
    if u.op is Ops.BARRIER: return u
    if u.op is Ops.LOAD and u.src[-1].op is Ops.BARRIER:
      return UOp(u.op, u.dtype, u.src[:-1]+(UOp(Ops.IF, src=(gate, u.src[-1])),), arg=u.arg)
    # rebuild u only if one of its sources changed
    return u if (replace_source:=tuple(_gate_srcs(x, gate) for x in u.src)) == u.src else UOp(u.op, u.dtype, replace_source, u.arg)
  idx = root.src[0]
  if idx.op is Ops.CAST: idx = idx.src[0]
  return None if idx.op is not Ops.INDEX or len(idx.src) == 2 or (ret:=_gate_srcs(root, idx.src[2])) is root else ret
|
110
|
+
|
111
|
+
# pre-expansion pass: copy STORE index gates onto barrier-dependent LOADs via create_gate
migrate_indexing = PatternMatcher([
  # create gate MUST BE BEFORE expander
  (UPat(Ops.STORE, name="root"), create_gate),
])
|
115
|
+
|
116
|
+
def expand_rewrite(sink:UOp) -> UOp:
  """Run gate creation, then UNROLL/CONTRACT expansion, over `sink`.

  Both passes include the sym rules. Gate creation (migrate_indexing) runs as a separate, earlier pass.
  """
  # initial symbolic + migrate indexing (remove this)
  gated = graph_rewrite(sink, sym+migrate_indexing)
  # expand
  expanded = graph_rewrite(gated, sym+expander)
  return expanded
|