tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
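The file list suggests most of the churn in this release is a reorganization of tinygrad's internals: the UOp machinery moves from tinygrad/ops.py and tinygrad/codegen/symbolic.py into the new tinygrad/uop/ package, kernel optimization moves under tinygrad/codegen/opt/, and scheduling moves under tinygrad/schedule/. As a hedged illustration only (paths taken from the import hunks shown below, not a full migration guide), code that reached into these internals would update its imports roughly like this:

# tinygrad 0.10.2 internal import paths (removed in the hunks below)
# from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp
# from tinygrad.codegen.symbolic import sym

# tinygrad 0.11.0 internal import paths (added in the hunks below)
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp
from tinygrad.uop.symbolic import sym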
tinygrad/codegen/devectorizer.py
CHANGED
@@ -1,88 +1,13 @@
-from typing import
-import functools, operator
+from typing import Any, cast
+import functools, operator, itertools
 from collections import defaultdict
-from
-from tinygrad.
-from tinygrad.ops import graph_rewrite, GroupOp
-from tinygrad.
-from tinygrad.helpers import getenv, flatten,
-from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
+from dataclasses import dataclass
+from tinygrad.dtype import dtypes, ImageDType, PtrDType, DType, AddrSpace
+from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
+from tinygrad.uop.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
+from tinygrad.helpers import getenv, flatten, AMX, prod, partition
 from tinygrad.renderer import Renderer
 
-# ***** float4/image store handling *****
-
-def fold_expanded(ex, buf):
-  new_srcs = dedup(list(ex.src))
-  old_new_srcs = new_srcs[:]
-  is_load, is_image = new_srcs[0].op is Ops.LOAD, isinstance(buf.dtype, ImageDType)
-
-  # TODO: get the device from the buffer somehow
-  # NOTE: this can't be Device.DEFAULT because it opens devices
-  if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): return None
-  lengths = [4] if is_image else ([8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]))
-
-  # first, extract all the relevant offsets
-  offsets_rootsrc: defaultdict[Any, dict] = defaultdict(dict)
-  for i,s in enumerate(new_srcs):
-    idx = s.src[0].src[1]
-    if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue
-    if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
-    elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
-    else: root_src, arg = idx, 0
-    # add gates for gated
-    if len(s.src[0].src) == 3: root_src = (s.src[0].src[2], root_src)
-    assert arg not in offsets_rootsrc[root_src], f"{offsets_rootsrc[root_src][arg]} != {i} with {len(s.src)} sources"
-    offsets_rootsrc[root_src][arg] = i
-
-  # then rewrite everything we can
-  used: set[tuple[UOp, UOp]] = set()
-  for rootsrc, offsets in offsets_rootsrc.items():
-    for o in offsets:
-      for fold_length in lengths:
-        if all((rootsrc,o+i) not in used and o+i in offsets for i in range(fold_length)):
-          load_1 = new_srcs[offsets[o]]
-          new_src = list(load_1.src)
-          oidx = new_src[0].src[1]
-          if oidx.divides(fold_length) is None: continue
-          if is_image:
-            # for images, we rewrite the index. it must evenly divide 4 from the above check
-            new_src[0] = buf.index(
-              UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
-              rootsrc[0] if isinstance(rootsrc, tuple) else None)
-          else:
-            # for non image, we upcast the index pointer
-            new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(size=new_src[0].dtype.size, local=new_src[0].dtype.local))
-          # generate the folded new_srcs
-          if is_load:
-            new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
-            for i in range(fold_length): new_srcs[offsets[o+i]] = new_load.gep(i)
-          else: # vectorize the store
-            new_src[1] = UOp(Ops.VECTORIZE, new_src[1].dtype.vec(fold_length), tuple(new_srcs[offsets[o+i]].src[1] for i in range(fold_length)))
-            for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(Ops.STORE, dtypes.void, tuple(new_src)) if i == 0 else None
-          used.update((rootsrc,o+i) for i in range(fold_length))
-
-  # dedup expand for LOAD
-  if is_load and len(old_new_srcs) != len(ex.src): new_srcs = [new_srcs[old_new_srcs.index(s)] for s in ex.src]
-  # remove Nones for STORE
-  return UOp(ex.op, ex.dtype, tuple(x for x in new_srcs if x is not None), ex.arg) if len(used) else None
-
-def fix_unfoldable_image_load(load:UOp, buf:UOp):
-  if not isinstance(buf.dtype, ImageDType) or (oidx:=load.src[0].src[1]).dtype.count == 2: return None
-  id4 = oidx % 4
-  new_src = list(load.src)
-  # TODO: copied logic from above
-  new_src[0] = load.src[0].src[0].index(
-    UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
-    load.src[0].src[2] if len(load.src[0].src) == 3 else None)
-  vec_load = UOp(Ops.LOAD, load.dtype.vec(4), tuple(new_src))
-  return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), load.const_like(float('nan')))
-
-buf_idx_pat = UPat(Ops.INDEX, src=(UPat.var("buf"),), allow_any_len=True)
-float4_folding = PatternMatcher([
-  (UPat(Ops.VECTORIZE, src=UPat(Ops.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
-  (UPat((Ops.BARRIER, Ops.SINK), src=UPat(Ops.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
-])
-
 # ***** image load valid simplification *****
 
 def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
@@ -95,7 +20,8 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
   # can drop valid if idx is out of bound when valid is False
   drop_stmt = []
   for stmt in split_uop(valid, Ops.AND):
-    X, is_upper_bound, c = parse_valid(stmt)
+    try: X, is_upper_bound, c = parse_valid(stmt)
+    except ValueError: return None
 
   # for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
   if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, Ops.ADD)):
@@ -119,27 +45,173 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
   new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None
   return buf.index(idx, new_valid)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-  pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
-  if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
-  if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
-  return PatternMatcher(pat)
+def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
+  if store_gate not in [gate.src[0] for gate in val.toposort() if gate.op is Ops.IF]: return None
+  # remove the gate from the index
+  return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val, *store.src[2:])
+
+load_store_indexing = PatternMatcher([
+  # simplify valid
+  (UPat(Ops.AND, name="valid"), simplify_valid),
+  # image load valid idx simplification
+  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
+  # index True is just Index
+  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat(Ops.CONST, arg=True))), lambda buf,start_idx: buf.index(start_idx)),
+  # delete_redundant_gates (after expand)
+  (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
+        UPat.var("val")), name="store", allow_any_len=True), delete_redundant_gates),
+])
 
+# ***** load/store grouping *****
+
+def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
+  if getenv("UNSAFE_DISABLE_MASK", 0): mask = None
+  # generate the individual indexes
+  midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
+                       symbolic_flat+load_store_indexing, name=f"index_buf_{buf.arg}")
+  # extract all the relevant offsets
+  offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
+  for i in range(vec.dtype.count):
+    idx: Any = midx.src[i].src[1]
+    if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
+    elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
+    elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
+    else: root_src, arg = idx, 0
+    if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src)
+    offsets_rootsrc[root_src].setdefault(arg, []).append(i)
+
+  # the buf.dtype is always a pointer
+  ptrdtype = cast(PtrDType, buf.dtype)
+
+  # then rewrite everything we can into groups
+  ret = []
+  idxs: list[int|None] = [None]*vec.dtype.count
+  global_offset = 0
+  for offsets in offsets_rootsrc.values():
+    grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
+    for grp in grouped_offsets:
+      # get the index offset for this element. using [0] is okay, because they are the same
+      lidx = midx.src[offsets[grp[0]][0]]
+      if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace))
+      # set the idxs of the output
+      for i,g in enumerate(grp):
+        for oo in offsets[g]: idxs[oo] = global_offset+i
+      # add this lidx to the CAT
+      ret.append(lidx)
+      global_offset += len(grp)
+  assert None not in idxs, f"some idxs are missing {idxs}"
+  # this base thing is for image, we want the CAT to be a normal pointer
+  post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace).vec(vec.dtype.count), tuple(ret))
+  return post_cat.gep(tuple(cast(list[int], idxs)))
+
+def cat_after_store(cat:UOp, data:UOp, sto:UOp):
+  # TODO: this is written in many places
+  offset = 0
+  ret: list[UOp] = []
+  for s in cat.src:
+    ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count))), *sto.src[2:]))
+    offset += s.dtype.count
+  return UOp(Ops.NOOP, src=tuple(ret))
+
+def gep_on_store(gep:UOp, st:UOp, sto:UOp):
+  # NOTE: we need to invert the gep here, but it may be an expanding gep
+  # fake argsort. TODO: handle duplicates
+  a = {}
+  for i,x in enumerate(gep.arg): a[x] = i
+  new_arg = tuple(x[1] for x in sorted(a.items()))
+  return gep.src[0].store(st.gep(new_arg), *sto.src[2:])
+
+load_store_folding = PatternMatcher([
+  (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"))), expand_index),
+  (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"),
+        UPat.var("mask"))), expand_index),
+  # GEP after LOAD
+  (UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
+   lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
+  # GEP on data of STORE
+  (UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), allow_any_len=True, name="sto"), gep_on_store),
+  # put PTRCAT after LOAD
+  (UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
+   lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
+  # put PTRCAT after STORE
+  (UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), allow_any_len=True, name="sto"), cat_after_store),
+])
+
+# *** correct load/store ***
+
+def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
+  # this splits loads and stores into multiple chunks
+
+  # if there's only one element to load/store, no splitting needed
+  if (sz:=ls.src[0].dtype.count) == 1: return None
+  buf = idx.src[0]
+
+  # determine fold lengths
+  lengths = []
+  must_divide = True
+  if ctx is not None and ctx.device == "DSP":
+    lengths = [128,64,32,16,8,4]
+    must_divide = False
+  elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
+    pass
+  elif cast(PtrDType, buf.dtype).addrspace == AddrSpace.REG:
+    pass
+  elif isinstance(buf.dtype, ImageDType):
+    lengths = [4]
+  elif ctx is not None and ctx.supports_float4:
+    # TODO: a better way to get this than ctx
+    lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2])
+  lengths.append(1) # worst case, it's not folded
+
+  # filter fold lengths that don't divide
+  if must_divide: lengths = [x for x in lengths if idx.src[1].divides(x) is not None]
+
+  # split based on the fold lengths
+  global_offset = 0
+  ret = []
+  ptrdtype = cast(PtrDType, buf.dtype)
+  while global_offset < sz:
+    # with 1 at the end of the lengths list, this will always hit
+    for fold_length in lengths:
+      if global_offset+fold_length > sz: continue
+      lidx = buf.index(idx.src[1] + global_offset, idx.src[2] if len(idx.src) > 2 else None)
+      if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace))
+      if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
+      else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
+      global_offset += fold_length
+      break
+
+  # if it wasn't split, we return None. otherwise we CAT them
+  if len(ret) <= 1: return None
+  return UOp(Ops.CAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp(Ops.NOOP, src=tuple(ret))
+
+def image_fixup(ls:UOp):
+  # normal image load or store, with the CAST from expand_index
+  if ls.src[0].op is Ops.CAST and isinstance(image_dtype:=ls.src[0].src[0].dtype, ImageDType):
+    assert ls.src[0].dtype.count == 4, "image must be casted to 4"
+    idx = ls.src[0].src[0]
+    oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
+    idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
+    return ls.replace(src=(idx,)+ls.src[1:])
+
+  # this is an unprocessed image without a cast, aka unfoldable image load. this doesn't work for stores
+  if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].dtype != dtypes.int.vec(2):
+    assert ls.op is Ops.LOAD, "if an image store isn't upcasted to 4, we can't store it"
+    idx = ls.src[0]
+    id4 = idx.src[1] % 4
+    oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
+    idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
+    vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:])
+    return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), ls.const_like(float('nan')))
+
+  return None
+
+correct_load_store = PatternMatcher([
+  # split LOAD/STORE
+  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(Ops.INDEX, name="idx").cast(),), name="ls", allow_any_len=True), split_load_store),
+  # image indexing, including unfoldable images
+  (UPat((Ops.LOAD, Ops.STORE), name="ls"), image_fixup),
+])
 
 # *** uop expander ***
 
@@ -155,93 +227,164 @@ def no_vectorized_wmma(wmma:UOp):
   wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
   return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex))
 
-def no_vectorized_alu(alu):
+def no_vectorized_alu(alu:UOp):
   if alu.dtype.vcount == 1: return None
   alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
   return UOp(Ops.VECTORIZE, alu.dtype, alus)
 
-def
-  idx = ls.src[0]
-  assert isinstance(idx.dtype, PtrDType)
-  if idx.dtype.v == 1: return None
-  tv = [UOp(ls.op, ls.dtype.scalar(), tuple(j.gep(i) for j in ls.src)) for i in range(idx.dtype.v)]
-  return UOp(Ops.VECTORIZE, ls.dtype, tuple(tv))
-
-def no_vectorized_acc(acc:UOp):
+def no_vectorized_acc(acc:UOp, c:UOp):
   if acc.dtype.count == 1: return None
-
-
-  return UOp(Ops.
+  assert c.arg == 0, "this only supports index 0"
+  new_acc = acc.replace(dtype=acc.dtype.base.scalar().ptr(acc.dtype.count, cast(PtrDType, acc.dtype).addrspace))
+  return UOp(Ops.PTRCAT, acc.dtype, tuple([new_acc.index(UOp.const(dtypes.int, i)) for i in range(acc.dtype.count)]))
 
 devectorize = PatternMatcher([
   # no ALU on vectorized dtypes
-  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST
+  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
   (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
-  (UPat(Ops.
-  (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
+  (UPat(Ops.DEFINE_REG, name="acc").index(UPat.cvar("c")), no_vectorized_acc),
 ])
 
-devectorize_load_store = PatternMatcher([
-  # TODO: add vectorized support to transcendental
-  (UPat((Ops.INDEX, Ops.EXP2, Ops.LOG2, Ops.SIN), name="alu"), no_vectorized_alu),
-  (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
-])
-
-def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
-  if store_gate not in [gate.src[0] for gate in val.toposort if gate.op is Ops.IF]: return None
-  # remove the gate from the index
-  return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val)
-
-load_store_indexing = PatternMatcher([
-  # late fixup of unfoldable image loads
-  (UPat(Ops.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
-  # simplify valid
-  (UPat(Ops.AND, name="valid"), simplify_valid),
-  # image load valid idx simplification
-  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
-  # delete_redundant_gates (after expand)
-  (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
-        UPat.var("val"))), delete_redundant_gates),
-])
-
-def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:UOp|None=None) -> UOp:
-  # this moves the mask from the indexing to the load/store op for rendering
-  nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx)
-  return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is Ops.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])
-
 pm_render = PatternMatcher([
   # for rendering, we use explicit VECTORIZE
   (UPat(Ops.CONST, name='c'),
    lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
   (UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
   (UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
+  (UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
   (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
-  #
-  (UPat(
-
+  # give any loads that are masked an alt value
+  (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
+   lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op is Ops.CUSTOM else None),
   # gate any stores that aren't gated with ifs
-  (UPat(Ops.STORE,
-   lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(
+  (UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store", allow_any_len=True),
+   lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src[:2]+(UOp(Ops.IF, src=(idx.src[2],)),)+store.src[2:]) if \
+     len(store.src) <= 2 or store.src[2].op != Ops.IF else None),
 ])
 
-# ***
-
-
-
-
-
-
-if
-
-
-
-
-
-
-
-
-
-
-#
-
-
+# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
+
+@dataclass
+class ReduceContext:
+  acc_num: int = 0
+
+def horizontal_reduce(inp:UOp, out_dtype:DType) -> list[UOp]:
+  # if this has a horizontal reduction component, do that first
+  if inp.dtype != out_dtype:
+    # NOTE: [0 1 2 3 4 5 6 7] -> [0+4, 1+5, 2+6, 3+7]
+    horizontal_amount = inp.dtype.count//out_dtype.count
+    return [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) for i in range(0, horizontal_amount)]
+  return [inp]
+
+def reduce_to_acc(ctx:ReduceContext, red:UOp):
+  inp, reduce_range = red.src[0], red.src[1:]
+  lst = horizontal_reduce(inp, red.dtype)
+  assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"
+  # if we have a range
+  if len(reduce_range) != 0:
+    topo = inp.toposort()
+    stored_ranges = flatten([x.src[2:] for x in topo if x.op is Ops.STORE])
+    input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in stored_ranges])
+    identity = red.const(red.dtype, identity_element(red.arg, red.dtype.scalar()))
+    acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
+    do_store = acc.store(identity, UOp(Ops.NOOP, src=input_ranges)) if len(input_ranges) else acc.store(identity)
+    lst = [acc.load(do_store, *reduce_range)] + lst # put acc as the first element
+    ctx.acc_num += 1
+  ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
+  return acc.load(acc.store(ret, *reduce_range)) if len(reduce_range) != 0 else ret
+
+def no_vectorized_reduce(inp:UOp, red:UOp):
+  if inp.dtype != red.dtype:
+    red = red.replace(src=(functools.reduce(lambda x,y: x.alu(red.arg, y), horizontal_reduce(inp, red.dtype)),)+red.src[1:])
+    if red.dtype.vcount == 1: return red
+  # no_vectorize_alu ignoring ranges
+  if red.dtype.vcount == 1: return None
+  alus = tuple(UOp(red.op, red.dtype.scalar(), (red.src[0].gep(i),)+red.src[1:], red.arg) for i in range(red.dtype.vcount))
+  return UOp(Ops.VECTORIZE, red.dtype, alus)
+
+def reduce_rangeless(red:UOp):
+  # TODO: share code with reduce_unparented
+  if red.arg not in {Ops.ADD, Ops.MAX}: return None
+  if red.src[0].dtype != red.dtype: return None
+  if any(x.op in {Ops.RANGE} for x in red.src[0].toposort()): return None
+  ret = red.src[0]
+  if red.arg is Ops.ADD:
+    for r in red.src[1:]:
+      ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
+  return ret
+
+def no_range(u:UOp) -> bool: return not any(x.op is Ops.RANGE for x in u.sparents)
+
+pm_reduce_collapse = PatternMatcher([
+  # lift x+y out of reduce on lt
+  ((UPat.var("x")+UPat.var("y")) < UPat.var("c"), lambda x,y,c: (x < (c-y)) if no_range(y) and no_range(c) else None),
+  # lift x*y out of reduce
+  ((UPat.var("x")*UPat.var("y")) < UPat.var("c"),
+   lambda x,y,c: (x < ((c+y-1) // y)) if no_range(y) and no_range(c) and y.vmin > 0 else None),
+  # lift x+y out of reduce on ne
+  ((UPat.var("x")+UPat.var("y")) != UPat.var("c"), lambda x,y,c: (x != (c-y)) if no_range(y) and no_range(c) else None),
+  # fold the range
+  ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(0, UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True),
+   lambda r,cut,val: (r.src[0]-cut).maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
+  ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True),
+   lambda r,cut,val: cut.maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
+  # REDUCE on ADD
+  ((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
+   lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:],arg=Ops.ADD)),
+  # MUL casted bool
+  ((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast().or_broadcasted(name="b")),
+   lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)),
+  # WHERE on LOAD (works on max too)
+  (UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD, allow_any_len=True),
+   lambda buf,idx,gate: buf.index(idx, gate).load()),
+  (UPat.var("gate").where(0, UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True),
+   lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()),
+  # INDEX on RANGE / gated RANGE
+  (UPat.var("buf").index(UPat.var("expr"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted())),
+   lambda buf,r,idx,expr: buf.index(expr.substitute({r:idx.cast(r.dtype)}), (idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0]))),
+  # AND on WHERE
+  ((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \
+   .where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
+   lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype)),
+  # remove REDUCEs that no longer have a RANGE in the src
+  (UPat(Ops.REDUCE, name="red"), reduce_rangeless),
+  # devectorize REDUCE
+  (UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce),
+  # index/load/where. TODO: this is more aggressive than needed
+  (UPat((Ops.INDEX, Ops.LOAD, Ops.WHERE), name="alu"), no_vectorized_alu),
+])+sym
+
+def reduce_collapse(red:UOp):
+  included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:]))
+  if any(x.op in {Ops.STORE, Ops.REDUCE} for x in included): return None
+  replaces: dict[UOp, UOp] = {}
+  for u in included:
+    for s in u.src:
+      if s in not_included and s not in replaces and s.op not in {Ops.CONST, Ops.VCONST, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}:
+        replaces[s] = UOp(Ops.DEFINE_VAR, dtype=s.dtype, arg=(f'in{len(replaces)}', s.vmin, s.vmax))
+  collapse_fxn = red.substitute(replaces)
+  sink = graph_rewrite(collapse_fxn, pm_reduce_collapse, name="reduce_collapse")
+  # TODO: why is REDUCE needed here and just RANGE isn't enough?
+  if any(x.op in {Ops.REDUCE, Ops.RANGE} for x in sink.toposort()): return None
+  return sink.substitute({v:k for k,v in replaces.items()})
+
+def reduce_unparented(red:UOp):
+  if red.arg not in {Ops.ADD, Ops.MAX}: return None
+  reduce_parented, reduce_unparented = partition(red.src[1:], lambda x: x in red.src[0].sparents)
+  if len(reduce_unparented) == 0: return None
+  ret = red.replace(src=(red.src[0],)+tuple(reduce_parented)) if len(reduce_parented) or red.dtype != red.src[0].dtype else red.src[0]
+  if red.arg is Ops.ADD:
+    for r in reduce_unparented: ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
+  return ret
+
+pm_reduce = PatternMatcher([
+  # remove any ranges from a REDUCE that aren't referenced in the reduce source
+  (UPat(Ops.REDUCE, name="red"), reduce_unparented),
+  # remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range
+  (UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse),
+  # REDUCE -> DEFINE_ACC+ASSIGN
+  (UPat(Ops.REDUCE, name="red"), reduce_to_acc),
+  # tensor core built in accumulate
+  (UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
+   lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
+])+sym
tinygrad/codegen/expander.py
CHANGED
@@ -2,8 +2,7 @@
 
 import functools, itertools, operator
 from tinygrad.helpers import AMX, dedup, flatten, all_same, prod
-from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp
-from tinygrad.codegen.symbolic import sym
+from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp
 
 def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) -> int:
   idx, mul = 0, 1
@@ -15,7 +14,7 @@ def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) ->
 def _choices_from_args(args:tuple[tuple[int, int], ...]) -> list[dict[int, int]]:
   return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]
 
-@functools.
+@functools.cache
 def _swizzle_args(cargs:tuple[tuple[int, int], ...], eargs:tuple[tuple[int, int], ...], exclude_args:tuple[int, ...]) -> list[int]:
   return [_expand_arg_to_idx(eargs, {**rpk, **{x:0 for x in exclude_args}} if exclude_args else rpk) for rpk in _choices_from_args(cargs)]
 
@@ -50,6 +49,9 @@ def do_expand(root:UOp):
     if root.op is Ops.IF:
       # for the first arg of IF, just pass them through ignoring UNROLLS
       new_srcs.append(src)
+    elif root.op in {Ops.REDUCE, Ops.STORE} and src.op is Ops.RANGE:
+      # for any range args of REDUCE, pass them through
+      new_srcs.append(src)
     elif src.dtype.count > 1:
       # put any input dtype > 1 grouped together
       new_srcs.append(UOp(Ops.CAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz))
@@ -81,12 +83,9 @@ expander = PatternMatcher([
   (UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
    lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
   # do expansion
-  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX,
-        Ops.VECTORIZE, Ops.IF), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
+  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX,
+        Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
   (UPat(Ops.CONTRACT, name="con"), do_contract),
-  # vectorize DEFINE_ACC
-  (UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"),
-   lambda acc,v: acc.replace(dtype=v.dtype, src=(acc.src[0].broadcast(v.dtype.count),)+acc.src[1:])),
   # BARRIERs aren't actually expanded
   (UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
    lambda ex: UOp(Ops.UNROLL, src=(UOp(Ops.BARRIER, src=ex.src),)*len(ex.src), arg=ex.arg)),
@@ -98,7 +97,7 @@ expander = PatternMatcher([
 ])
 
 def create_gate(root:UOp) -> UOp|None:
-  @functools.
+  @functools.cache
   def _gate_srcs(u:UOp, gate:UOp) -> UOp:
     if u.op is Ops.BARRIER: return u
     if u.op is Ops.LOAD and u.src[-1].op is Ops.BARRIER:
@@ -112,10 +111,3 @@ migrate_indexing = PatternMatcher([
   # create gate MUST BE BEFORE expander
   (UPat(Ops.STORE, name="root"), create_gate),
 ])
-
-def expand_rewrite(sink:UOp) -> UOp:
-  # initial symbolic + migrate indexing (remove this)
-  sink = graph_rewrite(sink, sym+migrate_indexing)
-
-  # expand
-  return graph_rewrite(sink, sym+expander)