tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/codegen/lowerer.py
CHANGED
@@ -1,161 +1,114 @@
 # the job of the lowerer is to do indexing
-import functools,
-from dataclasses import dataclass
+import functools, operator
 from typing import cast
-from
-from tinygrad.
-from tinygrad.
-from tinygrad.helpers import
-from tinygrad.codegen.expander import expand_rewrite
-
-# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
-def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
-  acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
-  try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
-  except ValueError: return None
-  return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
+from dataclasses import dataclass
+from tinygrad.dtype import dtypes, AddrSpace, PtrDType
+from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite
+from tinygrad.helpers import prod, partition, flatten
 
 # ***** indexing *****
-def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
-  # TODO: symbolic shape
-  if not all_int(dims): return dims
-  while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
-    for i,m in enumerate(max_sizes):
-      if i < (len(dims)-1) and dims[i] * dims[i+1] <= m:
-        dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
-        break
-    else: return None
-  return dims
-
-def _split_dims(dims, max_sizes):
-  if all(d <= m for d,m in zip(dims, max_sizes)): return dims
-  _dims = list(dims) + [1]*(3-len(dims))
-  for i in range(len(_dims)):
-    while _dims[i] > max_sizes[i]:
-      div = next((d for d in range(2, math.ceil(math.sqrt(_dims[i])) + 1) if (_dims[i] % d) == 0), 1)
-      if div == 1: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
-      _dims[i], _dims[(i+1)%len(_dims)] = _dims[i]//div, _dims[(i+1)%len(_dims)]*div
-  return tuple(_dims[:2] if _dims[2] == 1 else _dims[0] if _dims[1:3] == [1,1] else _dims)
-
-def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|None, reverse=False) -> list[UOp]:
-  if reverse: dims = dims[::-1]
-  # try to group first: (a, b, c, d) -> (ab, c, d)
-  limited = (grouped if (grouped := _group_dims(dims, max_sizes)) else dims) if max_sizes is not None else dims
-  # check if grouping failed
-  if max_sizes is not None and len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
-  # try to split up dims: (a,) -> (b, c)
-  if limited == dims: limited = _split_dims(dims, max_sizes) if max_sizes is not None else dims
-  ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
-  if len(limited) < len(dims):
-    ret = []
-    if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}")
-    for idx, contraction_group in zip(raw_idxs, contraction):
-      for c in contraction_group[:-1]:
-        ret.append(idx % dims[c])
-        idx //= dims[c]
-      ret.append(idx)
-  elif len(limited) > len(dims):
-    a, b = len(limited), len(dims)
-    if a == 2 and b == 1: ret = [raw_idxs[0] * limited[1] + raw_idxs[1]]
-    if a == 3 and b == 1: ret = [raw_idxs[0] * (limited[1] * limited[2]) + raw_idxs[1] * limited[2] + raw_idxs[2]]
-    if a == 3 and b == 2: ret = [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
-  return ret[::-1] if reverse else ret
 
 @dataclass
 class IndexContext:
+  axis_types: tuple[AxisType, ...]
   idxs: list[UOp]
-
-
-
-
-
-
-
-
-
-  first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.toposort if x.op is Ops.REDUCE_AXIS))
-  local_loads = [x for x in ast.toposort if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
-  # NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
-  group_for_reduces = sum([any(l.st_arg.shape[i]!=ast.src[0].st_arg.shape[i] for l in local_loads) for i in range(first_reduce,first_upcasted)])
-  global_dims = first_reduce-ki.local_dims
-
-  if opts.has_local:
-    if ki.dont_use_locals:
-      assert ki.local_dims == 0, "can't use locals if there's no local dims"
-      idxs = get_grouped_dims("idx", full_shape[:global_dims], opts.global_max, reverse=True)
+  start: int = 0
+
+def shape_to_idx(s, axis_types, start=0):
+  # indexes
+  idxs = []
+  for i, (s, at) in enumerate(zip(s, axis_types)):
+    if at in (AxisType.UPCAST, AxisType.UNROLL):
+      assert isinstance(s, int), "needs to be int to upcast/unroll"
+      idxs.append(UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(s), tuple(range(s))),), ((i,s),), tag=1))
     else:
-      #
-      idxs
-
-  else:
-    # all loops are RANGES
-    idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i) for i,g in enumerate(full_shape[:first_reduce])]
+      # all others are RANGES
+      idxs.append(UOp(Ops.RANGE, dtypes.int, (sint_to_uop(s),), start+i))
+  return idxs
 
-
-
-
+def get_index(ast:UOp) -> IndexContext:
+  axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else ()
+  if len(ast.full_shape) != len(axis_types): axis_types = (AxisType.LOOP,)*len(ast.full_shape)
+  return IndexContext(axis_types, [], 0)
 
-
-  for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
-    assert isinstance(g, int), "needs to be int to upcast/unroll"
-    idxs.append(UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
+# ***** lowering (given index) *****
 
-
-
-
-
+def subblock(ctx: IndexContext, full_new_idx: list[UOp], src: UOp):
+  lc = IndexContext(ctx.axis_types, full_new_idx, ctx.start+1000)
+  ctx.start = lc.start
+  return graph_rewrite(src, pm_lowerer, lc, name="subblock", bottom_up=True)
 
-
+def lower_reduce_axis(ctx: IndexContext, x: UOp):
+  new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
+  full_new_idx = list(ctx.idxs)
+  for a in x.axis_arg: full_new_idx[a] = new_idxs[a]
 
-
+  ret = subblock(ctx, full_new_idx, x.src[0])
 
-def lower_reduce_axis(ctx: IndexContext, x: UOp):
   # NOTE: always using ridxs is fine here
-  reduce_range, reduce_expand = partition([
+  reduce_range, reduce_expand = partition([full_new_idx[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
   assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
-  alu_op: Ops = x.arg[0]
-  ret = x.src[0]
   if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
-    ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
-
-
-
-
-
-
-
-
-  idx, valid = x.st_arg.to_indexed_uops(
-
-
-
-
-
-
-
-
-
-
-
-  if
-
-
-
-
-
-
-
+    ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis), tag=1)
+  # REDUCE supports both "horizontal" reduction and range reduction. the horizontal elements are taken in the nearest group
+  return UOp(Ops.REDUCE, x.dtype, (ret,)+tuple(reduce_range), x.arg[0])
+
+def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
+  # TODO: reenable after REDUCE_AXIS is fixed
+  #assert x.src[1].shape == x.src[0].shape, f"shape mismatch on store {x.src[1].shape} != {x.src[0].shape}"
+
+  new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
+  idx, valid = x.st_arg.to_indexed_uops(new_idxs)
+  used_idxs = [x for x in UOp.sink(idx, valid).toposort() if x in new_idxs]
+  real_new_idxs = []
+  for i in range(len(x.src[0].shape)):
+    if new_idxs[i] in used_idxs or len(ctx.idxs) <= i: real_new_idxs.append(new_idxs[i])
+    else: real_new_idxs.append(ctx.idxs[i])
+
+  stored = subblock(ctx, real_new_idxs, x.src[1])
+  used_ranges = [x for x in used_idxs if x.op is Ops.RANGE]
+  ret = buf.index(idx, valid).store(stored, *used_ranges)
+
+  # insert BARRIER if we are ending a LOCAL, IF if we are ending a GROUP_REDUCE
+  if cast(PtrDType, buf.dtype).addrspace == AddrSpace.LOCAL and \
+      any(ctx.axis_types[x.arg%1000] in {AxisType.GROUP_REDUCE, AxisType.LOCAL} for x in used_ranges):
+    ret = ret.barrier()
+    range_gates = [x.eq(0) for x in used_ranges if ctx.axis_types[x.arg%1000] == AxisType.GROUP_REDUCE]
+    if len(range_gates): ret = UOp(Ops.IF, src=(functools.reduce(operator.and_, range_gates), ret))
+  return ret
+
+def fixup_wmma(ctx:IndexContext, x:UOp):
+  if x.tag is not None: return None
+  new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
+  full_new_idx = list(ctx.idxs)
+  for a in x.arg[-1]: full_new_idx[a] = new_idxs[a]
+
+  srcs = subblock(ctx, full_new_idx, UOp.sink(*x.src)).src
+
+  # NOTE: this assumes these are expanded. which now shouldn't change anything
+  new_x_arg_m2 = tuple([tuple([(full_new_idx[a].arg[0][0], sz) for a,sz in v]) for v in x.arg[-2]])
+  new_x_arg_m1 = tuple([full_new_idx[a].arg[0][0] for a in x.arg[-1]])
+  return x.replace(src=srcs, arg=x.arg[:-2]+(new_x_arg_m2, new_x_arg_m1), tag=1)
 
 pm_lowerer = PatternMatcher([
+  # TODO: remove these hacks
+  # hack for old style CONST(VIEW) (now it's just VIEW(CONST))
+  (UPat((Ops.DEFINE_VAR, Ops.CONST), src=(UPat(Ops.VIEW, name="v"),), name="c"), lambda c,v: c.replace(src=()).view(v.arg)),
+  # hack for old style VALID (now it's just VIEW(CONST))
+  (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c"), UPat(Ops.CONST, arg=0)), lambda c,v: c.replace(src=()).view(v.arg)),
+
+  # consts and loads
+  (UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),), name="view"),
+   lambda ctx,view,c: c if all(x.mask is None for x in view.arg.views) else view.arg.to_indexed_uops(ctx.idxs)[1].where(c, c.const_like(0))),
+  (UPat(Ops.LOAD, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"),
+   lambda ctx,buf,x: UOp(Ops.LOAD, x.dtype, (buf.index(*x.st_arg.to_indexed_uops(ctx.idxs)),)+x.src[1:])),
+
+  # reduce/view_const
   (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
-  (UPat(
-  (UPat(Ops.
-  # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
-  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
-  (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
-])
+  (UPat(Ops.STORE, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"), lower_store),
+  (UPat(Ops.WMMA, name="x"), fixup_wmma),
 
-
-
-
-
+  # axis fixups for WMMA
+  (UPat((Ops.CONTRACT, Ops.UNROLL), name="x"),
+   lambda ctx,x: x.replace(tag=1, arg=tuple([(ctx.idxs[a].arg[0][0], sz) for a,sz in x.arg])) if x.tag is None else None),
+])
tinygrad/codegen/opt/__init__.py
ADDED
@@ -0,0 +1,38 @@
+# opt opinionatedly transforms an ast into an optimized ast using either heuristics or beam search
+
+from tinygrad.codegen.opt.kernel import Kernel
+from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
+from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops
+from tinygrad.helpers import NOOPT, BEAM, USE_TC, getenv
+from tinygrad.renderer import Renderer
+from tinygrad.uop.spec import type_verify
+
+def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp:
+  """
+  Optimize an AST based on heuristics or BEAM search.
+
+  Args:
+    ast: The Ops.SINK rooted AST
+    renderer: The renderer used to generate the code
+
+  Returns:
+    The Ops.SINK rooted AST transformed to apply the opts and with a KernelInfo in the arg.
+  """
+
+  k = Kernel(ast, opts=renderer)
+  if ast.arg is not None and ast.arg.opts_to_apply is not None: k.apply_opts(ast.arg.opts_to_apply)
+  elif not NOOPT:
+    if not k.apply_tensor_cores(USE_TC.value): k.apply_opts(hand_coded_optimizations(k))
+    if BEAM >= 1:
+      from tinygrad.codegen.opt.search import beam_search, bufs_from_lin
+      kb = Kernel(ast, opts=renderer)
+      rawbufs = bufs_from_lin(kb, allocate=False)
+      k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
+  ret = k.get_optimized_ast()
+  if __debug__: type_verify(list(ret.toposort()))
+  return ret
+
+pm_optimize = PatternMatcher([
+  (UPat(Ops.SINK, name="ast"), lambda ctx,ast:
+    get_optimized_ast(ast, ctx) if (ast.arg is None or ast.arg.opts_to_apply is not None) and ast.src[0].st is not None else None),
+])
tinygrad/codegen/opt/heuristic.py
ADDED
@@ -0,0 +1,125 @@
+import itertools
+from tinygrad.codegen.opt.kernel import Kernel, Opt, OptOps, KernelOptError, AxisType
+from tinygrad.helpers import getenv, DEBUG, prod, NOLOCALS
+from tinygrad.dtype import ImageDType
+from tinygrad.uop.ops import Ops, resolve
+
+def hand_coded_optimizations(k:Kernel) -> list[Opt]:
+  # make a copy so it does not mutate the input
+  k = k.copy()
+
+  # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
+  MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
+  if k.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
+      k.reduceop is not None and k.reduceop.arg[0] is Ops.ADD and len(k.full_shape) >= 2 and k.opts.has_shared and \
+      (mulop:=k.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
+    st0, st1 = k.sts[k.bufs.index(mulop.src[0])], k.sts[k.bufs.index(mulop.src[1])]
+    strides0, strides1 = st0.real_strides(), st1.real_strides()
+    def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
+    if strides0[first_reduce:=(k.axes_of(AxisType.REDUCE)[0])] == 1 and \
+        not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)):
+      for global_idx in k.axes_of(AxisType.GLOBAL):
+        if k.full_shape[first_reduce]%MV_THREADS_PER_ROW == 0 and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
+          if DEBUG >= 3:
+            print(f"MATVEC: {k.full_shape=} {first_reduce=} {strides0=} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
+          if MV_THREADS_PER_ROW > 1: k.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
+          if MV_BLOCKSIZE > 1: k.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
+          if MV_ROWS_PER_THREAD > 1: k.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
+          return k.applied_opts
+
+  # are we grouping? (requires local shape support)
+  if resolve(prod(k.sts[0].shape[i] for i in k.upcastable_dims) <= 2048, False):
+    for sz in [16]:
+      try:
+        k.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
+        break
+      except KernelOptError: pass
+
+  # upcast float4 images
+  for buf_index,buf in enumerate(k.bufs):
+    if isinstance(buf.src[0].dtype, ImageDType):
+      if (unit_stride_axes_mul_4 := [i for i in k.sts[buf_index].unit_stride_axes(ignore_valid=True) if k.sts[buf_index].shape[i]%4 == 0]):
+        if (axis:=unit_stride_axes_mul_4[0]) in k.upcastable_dims:
+          k.apply_opt(Opt(OptOps.UPCAST, axis, 4))
+        elif axis in k.unrollable_dims:
+          k.apply_opt(Opt(OptOps.UNROLL, k.unrollable_dims.index(axis), 4))
+
+  # no more opt if we are grouping
+  if k.group_for_reduces: return k.applied_opts
+
+  # **** below this line need to be optional and benchmarked ****
+
+  # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
+  to_upcast: list[int] = []
+  # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
+  for axis in k.upcastable_dims:
+    if k.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in k.sts) and \
+        prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
+      if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
+      to_upcast.append(axis)
+  for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0))
+
+  # potentially do more upcasts of non reduce axes based on a heuristic
+  is_dsp = k.opts is not None and k.opts.device == "DSP"
+  upcasted_axis: set[int] = set()
+  while resolve(prod(k.sts[0].shape[i] for i in k.upcastable_dims) >= 1024):
+    xb_choices = []
+    # consider all upcastable axes with 3 or 4 upcast (128 on the DSP)
+    for axis, upcast_amount in itertools.product(k.upcastable_dims, ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
+      # if we haven't upcasted it, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
+      if axis in upcasted_axis or k.full_shape[axis]%upcast_amount != 0: continue
+      if any(st.views[-1].strides[axis] == 0 and \
+          all(x != 0 for t,x in zip(k.axis_types, st.real_strides()) if t in (AxisType.UPCAST, AxisType.UNROLL)) for st in k.sts):
+        xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in k.sts),
+                           sum(st.views[-1].strides[axis] for st in k.sts), axis, upcast_amount))
+    if xb_choices:
+      xb_choices = sorted(xb_choices)
+      if DEBUG >= 4: print(f"more upcast axis : {xb_choices}")
+      k.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
+      upcasted_axis.add(xb_choices[0][2])
+    else: break
+
+  # if last reduce dim is small(ish), loop unroll the reduce
+  # NOTE: this can fail on multireduce with mismatching dimensions, this is okay
+  try:
+    upcast_size = prod(k.full_shape[a] for a in k.axes_of(AxisType.UPCAST, AxisType.UNROLL))
+    if k.unrollable_dims and (upcast_size <= 4 or not k.axes_of(AxisType.UNROLL)) and (upcast_size < 64):
+      if (s:=k.full_shape[k.unrollable_dims[-1]]) <= 32:
+        k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, 0))
+        # if it's small, upcast a second reduce dimension too
+        if k.unrollable_dims and s <= 3 and k.full_shape[k.unrollable_dims[-1]] <= 3:
+          k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, 0))
+      else:
+        for splits in [4]:
+          if k.full_shape[axis:=k.unrollable_dims[-1]]%splits == 0:
+            k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, splits))
+            break
+  except KernelOptError: pass
+
+  # if nothing at all is upcasted and it's easy to, do an upcast
+  for splits in [4]:
+    # TODO: somehow this never hits a reduce
+    if not k.upcasted and k.upcastable_dims and k.full_shape[k.upcastable_dims[-1]] % splits == 0:
+      k.apply_opt(Opt(OptOps.UPCAST, k.upcastable_dims[-1], splits))
+
+  # **** local groups ****
+
+  if k.opts.has_local:
+    if NOLOCALS:
+      k.apply_opt(Opt(OptOps.NOLOCALS))
+    else:
+      # prioritize making expand axes local
+      local_axis_ranking = [(any(st.views[-1].strides[axis] == 0 for st in k.sts), axis) for axis in k.axes_of(AxisType.GLOBAL, AxisType.LOOP)]
+      to_local: list[tuple[int, int]] = []
+      for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
+        local_size = prod(sz for _, sz in to_local)
+        local_sz: int|None = next((x for x in ([32] * (axis == 0) + [16,8,4,3,2]) if k.full_shape[axis] % x == 0 and local_size * x <= 128), None)
+        if local_sz is not None: to_local.append((axis, local_sz))
+      deleted_shape = 0
+      for axis, local_sz in sorted(to_local[:3]):
+        axis = axis - deleted_shape
+        will_delete_shape = local_sz == k.full_shape[axis]
+        k.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
+        if will_delete_shape: deleted_shape += 1
+
+  return k.applied_opts