tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,89 @@
|
|
1
|
+
import math
|
2
|
+
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType
|
3
|
+
from tinygrad.helpers import all_int
|
4
|
+
from tinygrad.dtype import dtypes
|
5
|
+
from tinygrad.shape.view import get_contraction
|
6
|
+
from tinygrad.renderer import Renderer
|
7
|
+
|
8
|
+
def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
  """Merge adjacent dims until there are no more dims than limits and each dim fits its limit.

  Returns `dims` untouched when `all_int(dims)` is false, the merged tuple on success, or None when
  no adjacent pair can be merged without exceeding the limit at the pair's position.
  """
  # TODO: symbolic shape
  if not all_int(dims): return dims

  def _fits(shape) -> bool:
    return len(shape) <= len(max_sizes) and all(d <= m for d,m in zip(shape, max_sizes))

  while not _fits(dims):
    # first position whose product with its right neighbour stays within that position's limit
    merge_at = next((i for i,m in enumerate(max_sizes) if i < (len(dims)-1) and dims[i] * dims[i+1] <= m), None)
    if merge_at is None: return None
    dims = dims[:merge_at] + (dims[merge_at]*dims[merge_at+1],) + dims[merge_at+2:]
  return dims
|
18
|
+
|
19
|
+
def _split_dims(dims, max_sizes):
|
20
|
+
if all(d <= m for d,m in zip(dims, max_sizes)): return dims
|
21
|
+
_dims = list(dims) + [1]*(3-len(dims))
|
22
|
+
for i in range(len(_dims)):
|
23
|
+
while _dims[i] > max_sizes[i]:
|
24
|
+
div = next((d for d in range(2, math.ceil(math.sqrt(_dims[i])) + 1) if (_dims[i] % d) == 0), 1)
|
25
|
+
if div == 1: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
|
26
|
+
_dims[i], _dims[(i+1)%len(_dims)] = _dims[i]//div, _dims[(i+1)%len(_dims)]*div
|
27
|
+
return tuple(_dims[:2] if _dims[2] == 1 else _dims[0] if _dims[1:3] == [1,1] else _dims)
|
28
|
+
|
29
|
+
def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|None, reverse=False) -> list[UOp]:
  """Build one index expression per entry of `dims` out of Ops.SPECIAL UOps named `{prefix}{i}`.

  When `max_sizes` is given, dims are first grouped with `_group_dims`; if grouping changed nothing they
  are split with `_split_dims`. The SPECIAL UOps are created for the resulting ("limited") sizes and then
  recombined so the returned list has one expression per original dim. With `reverse=True` the dims are
  processed in reversed order and the result is reversed back.

  Raises:
    RuntimeError: when more limited dims remain than `max_sizes` has entries.
    AssertionError: when `get_contraction` cannot relate `dims` to the grouped sizes.
  """
  if reverse: dims = dims[::-1]
  if max_sizes is None:
    limited = dims
  else:
    # try to group first: (a, b, c, d) -> (ab, c, d)
    limited = _group_dims(dims, max_sizes) or dims
    # check if grouping failed
    if len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
    # try to split up dims: (a,) -> (b, c)
    if limited == dims: limited = _split_dims(dims, max_sizes)

  raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
  n_limited, n_dims = len(limited), len(dims)
  if n_limited < n_dims:
    # dims were grouped: peel each original dim back out of its grouped index with mod/div
    if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}")
    ret = []
    for idx, contraction_group in zip(raw_idxs, contraction):
      for c in contraction_group[:-1]:
        ret.append(idx % dims[c])
        idx //= dims[c]
      ret.append(idx)
  elif n_limited > n_dims:
    # dims were split: fold the extra indexes back together (only these three shapes are handled)
    ret = raw_idxs
    if (n_limited, n_dims) == (2, 1):
      ret = [raw_idxs[0] * limited[1] + raw_idxs[1]]
    elif (n_limited, n_dims) == (3, 1):
      ret = [raw_idxs[0] * (limited[1] * limited[2]) + raw_idxs[1] * limited[2] + raw_idxs[2]]
    elif (n_limited, n_dims) == (3, 2):
      ret = [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
  else:
    ret = raw_idxs
  return ret[::-1] if reverse else ret
|
52
|
+
|
53
|
+
def add_gpudims(ctx:Renderer, s:UOp):
  """Replace the RANGE UOps of global/local axes under sink `s` with grouped SPECIAL index expressions.

  Returns None (no rewrite) when `s.arg` is None, when no axis is GLOBAL, LOCAL or GROUP_REDUCE, or when
  the graph already contains an Ops.SPECIAL. Otherwise returns `s.substitute(...)` with the replacements.
  """
  if s.arg is None: return None
  ki: KernelInfo = s.arg
  global_dims, local_dims = [], []
  for i,x in enumerate(ki.axis_types):
    if x is AxisType.GLOBAL: global_dims.append(i)
    if x in (AxisType.LOCAL, AxisType.GROUP_REDUCE): local_dims.append(i)
  if not (global_dims or local_dims): return None
  s_topo = list(s.toposort())
  for x in s_topo:
    if x.op is Ops.SPECIAL: return None

  # get global and local shape
  launch_axes = global_dims+local_dims
  # RANGEs keyed by arg%1000; a later RANGE in toposort order overwrites an earlier one with the same key
  all_ranges = {x.arg%1000:x for x in s_topo if x.op is Ops.RANGE}
  ranges = [all_ranges[r] for r in launch_axes if r in all_ranges]
  global_shape = tuple(ssimplify(r.src[0]) for r in ranges if r.arg%1000 in global_dims)
  local_shape = tuple(ssimplify(r.src[0]) for r in ranges if r.arg%1000 in local_dims)

  # get the idxs
  if ki.dont_use_locals:
    assert not local_dims, "can't use locals if there's no local dims"
    idxs = get_grouped_dims("idx", global_shape, ctx.global_max, reverse=True)
  else:
    # define indexes for GPU-like execution
    gidxs = get_grouped_dims("gidx", global_shape, ctx.global_max, reverse=True)
    idxs = gidxs + get_grouped_dims("lidx", local_shape, ctx.local_max)

  # apply to multiple ranges
  # NOTE(review): idxs is built only from axes that have a RANGE, while the position below counts every
  # axis in launch_axes -- presumably every launch axis always has a RANGE; confirm
  subs = {}
  for r in s_topo:
    if r.op is not Ops.RANGE: continue
    axis = r.arg%1000
    if axis not in launch_axes: continue
    if r.arg < 2000 and ki.axis_types[axis] == AxisType.GROUP_REDUCE: continue
    subs[r] = idxs[launch_axes.index(axis)]
  return s.substitute(subs)
|
86
|
+
|
87
|
+
pm_add_gpudims = PatternMatcher([
|
88
|
+
(UPat(Ops.SINK, name="s"), add_gpudims),
|
89
|
+
])
|
tinygrad/codegen/linearize.py
CHANGED
@@ -1,234 +1,236 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
import
|
3
|
-
from
|
4
|
-
from
|
5
|
-
from tinygrad.
|
6
|
-
from tinygrad.
|
7
|
-
from tinygrad.helpers import dedup, flatten, partition
|
2
|
+
import heapq
|
3
|
+
from collections import defaultdict
|
4
|
+
from dataclasses import dataclass, replace
|
5
|
+
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat, GroupOp
|
6
|
+
from tinygrad.helpers import dedup, all_same, flatten, getenv
|
8
7
|
|
9
|
-
|
8
|
+
# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
|
9
|
+
def block_reorder(lst:list[UOp]) -> list[UOp]:
  """Return the UOps of `lst` in a toposorted order that places LOADs and BARRIERs as early as possible.

  Each UOp gets a priority (0 by default, -1000 for LOAD, -1500 for BARRIER) lowered to the minimum
  priority of its in-block consumers. The UOps are ranked by (priority, tuplize) and then emitted with
  a heap-driven topological sort that always picks the lowest-ranked ready UOp.
  """
  members = set(lst)
  children: defaultdict[UOp, list[UOp]] = defaultdict(list)
  pending: dict[UOp, int] = {}
  prio: dict[UOp, int] = {}

  # get local children and assign priorities
  # NOTE: this requires the lst be locally toposorted
  for u in reversed(lst):
    local_srcs = [s for s in u.src if s in members]
    for s in local_srcs: children[s].append(u)
    pending[u] = len(local_srcs)
    # put loads in the beginning of the block and prevent priority inversion. hack for BARRIER grouping too
    candidates = [0, *(prio[c] for c in children[u])]
    if u.op is Ops.LOAD: candidates.append(-1000)
    if u.op is Ops.BARRIER: candidates.append(-1500)
    prio[u] = min(candidates)

  # number the uops in "ideal" order
  ideal = sorted(lst, key=lambda x: (prio[x],)+x.tuplize)
  rank = {u:i for i,u in enumerate(ideal)}

  # then force them to be toposorted in as close to the ideal order as possible
  heap = [(rank[u],u) for u in lst if pending[u] == 0]
  heapq.heapify(heap)
  newlst = []
  while heap:
    _, u = heapq.heappop(heap)
    newlst.append(u)
    for v in children[u]:
      pending[v] -= 1
      if pending[v] == 0: heapq.heappush(heap, (rank[v],v))

  assert len(newlst) == len(lst), f"len mismatch {len(newlst)} != {len(lst)}"
  return newlst
|
43
|
+
|
44
|
+
# ***** basic block *****
|
10
45
|
|
11
46
|
def disp(y:UOp) -> str:
|
12
|
-
if y.op is Ops.BLOCKSTART: return "w"+disp(y.src[0])
|
13
47
|
if y.op is Ops.IF: return f'IF{id(y)}'
|
14
48
|
if y.op is Ops.RANGE: return str(y.arg)
|
15
49
|
return "<NONE>"
|
16
50
|
|
17
|
-
@dataclass(frozen=True)
|
51
|
+
@dataclass(frozen=True, eq=False)
|
18
52
|
class BasicBlock:
|
19
|
-
ctx: tuple[UOp, ...]
|
20
53
|
lst: tuple[UOp, ...]
|
54
|
+
ctx: tuple[UOp, ...] = ()
|
21
55
|
end: UOp|None = None
|
22
|
-
|
56
|
+
cnt: int = 0
|
57
|
+
child_ctx: tuple[UOp, ...]|None = None
|
58
|
+
def __lt__(self, _:BasicBlock): raise RuntimeError("no comparing basic blocks")
|
23
59
|
def __repr__(self):
|
24
|
-
return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+\
|
25
|
-
f"{[disp(y) for y in self.ctx]} {
|
26
|
-
|
27
|
-
def
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
60
|
+
return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+f'f{self.cnt} '+\
|
61
|
+
f"{[disp(y) for y in self.ctx]} {[disp(y) for y in self.child_ctx] if self.child_ctx is not None else '-'} "+\
|
62
|
+
f"{len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])
|
63
|
+
def last_ctx(self): return self.child_ctx if self.child_ctx is not None else self.ctx
|
64
|
+
|
65
|
+
def _sort_ctx(inp):
  """Return the items of `inp` deduplicated with `dedup` and sorted by their `tuplize` key, as a tuple."""
  unique = dedup(inp)
  ordered = sorted(unique, key=lambda x: x.tuplize)
  return tuple(ordered)
|
66
|
+
|
67
|
+
# ***** block context *****
|
68
|
+
|
69
|
+
@dataclass
class BlockContext:
  # how many times each UOp appears as a source of another UOp (repeats are counted)
  child_count: dict[UOp, int]
  # the context each UOp itself is placed in
  block_ctxs: dict[UOp, tuple[UOp, ...]]
  # the context seen by the users of a UOp, only present where it differs (RANGE/IF/STORE)
  child_ctxs: dict[UOp, tuple[UOp, ...]]

  def last_ctx(self, u):
    """Return the context that users of `u` inherit: its child ctx if recorded, else its block ctx."""
    # the block ctx is looked up unconditionally, so a `u` missing from block_ctxs raises KeyError
    own_ctx = self.block_ctxs[u]
    return self.child_ctxs.get(u, own_ctx)

  @staticmethod
  def from_sink(sink:UOp) -> BlockContext:
    """Walk `sink.toposort()` once and fill in child counts and block/child contexts for every UOp."""
    # get children and all block contexts
    built = BlockContext({}, {}, {})
    for u in sink.toposort():
      built.child_count[u] = 0

      # get children and accumulate the last_ctx
      inherited: list[UOp] = []
      for s in u.src:
        # NOTE: if a parent appears multiple times in the src, it counts multiple times as a child
        built.child_count[s] += 1
        inherited += built.last_ctx(s)

      # save the block ctx. SINK never has anything
      if u.op is Ops.SINK: built.block_ctxs[u] = ()
      else: built.block_ctxs[u] = _sort_ctx(inherited)

      # RANGE/IF add themselves to the ctx of their users
      # STORE removes its own sources from the ctx of its users
      if u.op in {Ops.RANGE, Ops.IF}:
        built.child_ctxs[u] = _sort_ctx(built.block_ctxs[u] + (u,))
      elif u.op is Ops.STORE:
        built.child_ctxs[u] = tuple([y for y in built.block_ctxs[u] if y not in u.src])
    return built
|
97
|
+
|
98
|
+
# ***** make blocks *****
|
99
|
+
|
100
|
+
DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}
|
101
|
+
|
102
|
+
def add_blockends(base_block:UOp, new_ctx:tuple[UOp, ...], current_ctx:tuple[UOp, ...], cnt:int=1) -> UOp:
  """Wrap `base_block` in one BLOCKEND per entry of `new_ctx` that is absent from `current_ctx`.

  Entries are closed from the last-listed to the first. Each BLOCKEND holds a single ENDIF (for an IF) or
  ENDRANGE UOp, carries the context that remains after removing the entry being closed, and lists the
  block it wraps `cnt` times as its source.
  """
  for r in reversed([z for z in new_ctx if z not in current_ctx]):
    new_ctx = tuple(z for z in new_ctx if z is not r)
    closer = Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE
    end_block = BasicBlock((UOp(closer, src=(r,)),), new_ctx, end=r, cnt=cnt)
    base_block = UOp(Ops.BLOCKEND, src=(base_block,)*cnt, arg=end_block)
  return base_block
|
110
|
+
|
111
|
+
def make_block_bottom_up(ctx:BlockContext, x:UOp):
|
112
|
+
if x.op is Ops.BLOCKSTART:
|
113
|
+
current_ctx, child_ctx = x.arg
|
114
|
+
lst = list(x.src)
|
115
|
+
child_count = 1
|
116
|
+
else:
|
117
|
+
current_ctx, child_count, child_ctx = ctx.block_ctxs[x], ctx.child_count[x], ctx.child_ctxs.get(x, None)
|
118
|
+
lst = [x]
|
119
|
+
|
120
|
+
# count of times we've seen this block, or a seed for a new block if we can't merge it
|
121
|
+
unmergable: defaultdict[UOp, int] = defaultdict(int)
|
122
|
+
blockseeds = defaultdict(list)
|
123
|
+
|
124
|
+
# add the srcs of this to the frontier
|
125
|
+
# NOTE: things may be in here multiple times, that's okay
|
126
|
+
frontier_nodes = list(flatten(y.src[::-1] for y in lst))
|
127
|
+
while len(frontier_nodes):
|
128
|
+
u = frontier_nodes.pop(0)
|
129
|
+
if u.op not in DONT_PLACE_IN_BLOCK and ctx.child_count[u] == unmergable[u]+1:
|
130
|
+
# count is correct
|
131
|
+
if (newctx:=ctx.block_ctxs[u]) == current_ctx:
|
132
|
+
# block has same context, merge it, and put the srcs on the frontier
|
133
|
+
lst.append(u)
|
134
|
+
frontier_nodes.extend(u.src[::-1])
|
48
135
|
else:
|
49
|
-
#
|
50
|
-
|
136
|
+
# block has different context, add it to blockseeds
|
137
|
+
blockseeds[(newctx, ctx.child_ctxs.get(u, None))].append(u)
|
138
|
+
del unmergable[u]
|
51
139
|
else:
|
52
|
-
#
|
53
|
-
|
54
|
-
if len(to_append) == 0 and len(new_blocks) == 0: return None
|
55
|
-
|
56
|
-
for rng,lst in new_blocks.items():
|
57
|
-
srcs = flatten(y.src for y in lst)
|
58
|
-
if (old_block:=old_blocks.pop(rng, None)) is not None:
|
59
|
-
# NOTE: order shouldn't matter here
|
60
|
-
srcs.extend(old_block.src)
|
61
|
-
lst.extend(old_block.arg.lst)
|
62
|
-
new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(srcs)), BasicBlock(rng, tuple(lst)))
|
63
|
-
lrng = list(rng)
|
64
|
-
for r in rng[::-1]:
|
65
|
-
if r not in x.arg.ctx and r.op is not Ops.BLOCKSTART:
|
66
|
-
lrng.remove(r)
|
67
|
-
new_block = UOp(Ops.BLOCKEND, src=(new_block,),
|
68
|
-
arg=BasicBlock(tuple(lrng), (UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)),), r))
|
69
|
-
new_srcs.append(new_block)
|
70
|
-
return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(list(old_blocks.values())+new_srcs)), BasicBlock(x.arg.ctx, tuple(to_append)+x.arg.lst))
|
71
|
-
|
72
|
-
make_basic_blocks = PatternMatcher([
|
73
|
-
(UPat(Ops.SINK, name="x"),
|
74
|
-
lambda x: UOp(Ops.BLOCK, src=x.src+((UOp(Ops.NAME, arg=x.arg.name),) if x.arg is not None else ()), arg=BasicBlock((), (x,)))),
|
75
|
-
(UPat(Ops.BLOCK, name="x"), append_to_block),
|
76
|
-
])
|
77
|
-
|
78
|
-
def block_merge(ctx, x:UOp):
|
79
|
-
# ctx is children here
|
80
|
-
if x.op is Ops.BLOCKEND:
|
81
|
-
# if it's a BLOCKEND, see if we are done with placement. if all the children of the range are in here
|
82
|
-
in_this_block = set(x.arg.lst)
|
83
|
-
if len([y for y in ctx[x.arg.end] if y not in in_this_block]) == 0:
|
84
|
-
# find the parent block that has the BLOCKSTART in the ctx
|
85
|
-
parent_blocks = [y for y in x.src if y.op is Ops.BLOCK and UOp(Ops.BLOCKSTART, src=(x.arg.end,)) in y.arg.ctx]
|
86
|
-
assert len(parent_blocks) <= 1, "should never have two parent blocks"
|
87
|
-
if len(parent_blocks) == 1:
|
88
|
-
parent_block = parent_blocks[0]
|
89
|
-
# range needs DEFINE_ACC to be before the range (never in DEFINE_ACC for if)
|
90
|
-
early_ops, late_ops = partition(x.arg.lst, lambda y: y.op is Ops.DEFINE_ACC and x.arg.end in y.src)
|
91
|
-
return UOp(Ops.BLOCK, dtypes.void, tuple(y for y in x.src if y is not parent_block)+parent_block.src,
|
92
|
-
BasicBlock(tuple(y for y in x.arg.ctx if y is not x.arg.end), tuple(early_ops)+parent_block.arg.lst+tuple(late_ops)))
|
93
|
-
|
94
|
-
new_srcs: list[UOp] = []
|
95
|
-
to_append: list[UOp] = []
|
96
|
-
new_ctx = x.arg.ctx
|
97
|
-
placed = set()
|
98
|
-
for u in x.src:
|
99
|
-
if u.op is Ops.BLOCK and (tuple(u.arg.ctx) == tuple(x.arg.ctx) or (x.arg.end is not None and x.arg.end in u.arg.ctx)):
|
100
|
-
# NOTE: this can't appear in srcs twice or it would be a BLOCKFORK
|
101
|
-
new_ctx += tuple(y for y in u.arg.ctx if y not in x.arg.ctx)
|
102
|
-
new_srcs.extend(u.src)
|
103
|
-
to_append.extend(u.arg.lst)
|
104
|
-
elif u.op is Ops.BLOCKFORK and x.src.count(u) == u.arg: # block fork appears # of times in srcs
|
105
|
-
if u not in placed:
|
106
|
-
new_srcs.extend(u.src)
|
107
|
-
placed.add(u)
|
108
|
-
else:
|
109
|
-
# keep it in srcs
|
110
|
-
new_srcs.append(u)
|
111
|
-
if len(to_append) == 0 and len(placed) == 0: return None
|
112
|
-
return UOp(x.op, dtypes.void, tuple(new_srcs), BasicBlock(tuple(sorted(new_ctx, key=lambda x: x.tuplize)), tuple(to_append)+x.arg.lst, x.arg.end))
|
113
|
-
|
114
|
-
pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),])
|
115
|
-
|
116
|
-
def block_finalize(block:UOp):
|
117
|
-
if len(block.src) == 0: return None
|
118
|
-
_uops = sorted(dedup(block.src), key=lambda x: x.tuplize)
|
119
|
-
assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops)
|
120
|
-
_uops += block.arg.lst
|
121
|
-
# strip the SINK
|
122
|
-
assert _uops[-1].op is Ops.SINK, "doesn't end with SINK"
|
123
|
-
return UOp(Ops.BLOCK, arg=BasicBlock((), tuple(_uops[:-1])))
|
140
|
+
# count is incorrect (or it's DONT_PLACE_IN_BLOCK), add it to unmergable
|
141
|
+
unmergable[u] += 1
|
124
142
|
|
125
|
-
|
143
|
+
# add unmergables to sources
|
144
|
+
srcs = []
|
145
|
+
for u,cnt in unmergable.items(): srcs += [add_blockends(u, ctx.block_ctxs[u], current_ctx, cnt=cnt)]*cnt
|
126
146
|
|
127
|
-
#
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
in_degree: collections.defaultdict[UOp, int] = collections.defaultdict(int)
|
132
|
-
priorities:dict[UOp, int] = {}
|
133
|
-
|
134
|
-
# get local children and assign priorities
|
135
|
-
for u in reversed(in_block.arg.lst):
|
136
|
-
for s in u.src:
|
137
|
-
if s in in_this_block:
|
138
|
-
local_children[s].append(u)
|
139
|
-
in_degree[u] += 1
|
140
|
-
# put loads in the beginning of the block and prevent priority inversion
|
141
|
-
priorities[u] = min([-1000 if u.op is Ops.LOAD else 0] + [priorities[x] for x in local_children[u]])
|
142
|
-
|
143
|
-
# placement queue
|
144
|
-
queue:list[tuple[int, tuple, UOp]] = []
|
145
|
-
def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u))
|
147
|
+
# add blockseeds, with blockends as needed
|
148
|
+
for (new_ctx, new_child_ctx), v in blockseeds.items():
|
149
|
+
base_block = UOp(Ops.BLOCKSTART, src=tuple(v), arg=(new_ctx, new_child_ctx))
|
150
|
+
srcs.append(add_blockends(base_block, new_ctx, current_ctx))
|
146
151
|
|
147
|
-
|
148
|
-
|
149
|
-
|
152
|
+
lst = lst[::-1]
|
153
|
+
if getenv("BLOCK_REORDER", 1): lst = block_reorder(lst)
|
154
|
+
bb = BasicBlock(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx)
|
155
|
+
return UOp(Ops.BLOCK, src=tuple(srcs), arg=bb)
|
150
156
|
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
newlst.append(x)
|
155
|
-
for u in local_children[x]:
|
156
|
-
in_degree[u] -= 1
|
157
|
-
if in_degree[u] == 0: push(u)
|
158
|
-
|
159
|
-
assert len(newlst) == len(in_block.arg.lst), f"len mismatch {len(newlst)} != {len(in_block.arg.lst)}"
|
160
|
-
return in_block.replace(arg=BasicBlock(in_block.arg.ctx, tuple(newlst)))
|
161
|
-
|
162
|
-
def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
|
163
|
-
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
|
164
|
-
|
165
|
-
# get children and all block contexts
|
166
|
-
temp_block_ctxs: dict[UOp, list[UOp]] = {}
|
167
|
-
children: dict[UOp, list[UOp]] = {}
|
168
|
-
for u in sink.toposort:
|
169
|
-
this_block_ctx: list[UOp] = []
|
170
|
-
for s in u.src:
|
171
|
-
# save children
|
172
|
-
children.setdefault(s, []).append(u)
|
173
|
-
# compute block ctx
|
174
|
-
if s.op in {Ops.RANGE, Ops.IF}: this_block_ctx.append(s)
|
175
|
-
# don't flow (fully) through assign and store
|
176
|
-
elif s.op is Ops.STORE:
|
177
|
-
# ugh, deal with non-reduce locals. probably wrong
|
178
|
-
if isinstance(s.src[0].dtype, PtrDType) and s.src[0].dtype.local:
|
179
|
-
idx_context, store_context = temp_block_ctxs[s.src[0]], temp_block_ctxs[s]
|
180
|
-
this_block_ctx += [x for x in store_context if x not in idx_context and x.op is Ops.RANGE]
|
181
|
-
elif s.op is Ops.ASSIGN:
|
182
|
-
# flow though assign, but remove the ranges used in the assign
|
183
|
-
assert s.src[0].op is Ops.DEFINE_ACC
|
184
|
-
this_block_ctx += [x for x in temp_block_ctxs[s.src[1]] if x not in s.src[0].src[1:]]
|
185
|
-
else:
|
186
|
-
# flow though everything else
|
187
|
-
this_block_ctx += temp_block_ctxs[s]
|
188
|
-
temp_block_ctxs[u] = sorted(dedup(this_block_ctx), key=lambda x: x.tuplize)
|
189
|
-
|
190
|
-
# make final block_ctxs, add BLOCKSTART to block_ctxs for IF and RANGE
|
191
|
-
block_ctxs: dict[UOp, tuple[UOp, ...]] = {}
|
192
|
-
for u in sink.toposort:
|
193
|
-
block_ctxs[u] = ((UOp(Ops.BLOCKSTART, src=(u,)),) + tuple(temp_block_ctxs[u])) if u.op in {Ops.IF, Ops.RANGE} else tuple(temp_block_ctxs[u])
|
194
|
-
|
195
|
-
# TODO: there's probably a clever way to remove this while loop
|
196
|
-
while 1:
|
197
|
-
sink = graph_rewrite(sink, make_basic_blocks, ctx=(block_ctxs, children))
|
198
|
-
|
199
|
-
# add BLOCKFORK (slow!)
|
200
|
-
block_parent_count = collections.Counter(flatten([x.src for x in sink.toposort if x.op is Ops.BLOCK]))
|
201
|
-
non_block_parents = set(flatten([x.src for x in sink.toposort if x.op is not Ops.BLOCK]))
|
202
|
-
forks = {u:UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCK, src=u.src, arg=BasicBlock(block_ctxs[u], (u,))),), arg=child_count)
|
203
|
-
for u,child_count in block_parent_count.items() if u.op not in DONT_PLACE_IN_BLOCK and child_count > 1 and u not in non_block_parents}
|
157
|
+
block_create = PatternMatcher([
|
158
|
+
(UPat(GroupOp.All-DONT_PLACE_IN_BLOCK.union({Ops.BLOCK, Ops.BLOCKEND}), name="x"), make_block_bottom_up),
|
159
|
+
])
|
204
160
|
|
205
|
-
|
206
|
-
sink = sink.substitute(forks)
|
161
|
+
# ***** blockend merging ****
|
207
162
|
|
208
|
-
|
163
|
+
def merge_blockends(sink:UOp) -> UOp|None:
|
164
|
+
# only run on the final BLOCK with the SINK in it
|
165
|
+
if sink.arg.lst[-1].op is not Ops.SINK: return None
|
166
|
+
# combine matching BLOCKENDS, the keys of this dictionary are the RANGE UOps, values are the BLOCKENDs
|
209
167
|
blockends_to_arg: dict[UOp, list[UOp]] = {}
|
210
|
-
for be in sink.toposort:
|
168
|
+
for be in sink.toposort():
|
211
169
|
if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
|
212
170
|
new_forks = {}
|
213
171
|
for k,v in blockends_to_arg.items():
|
214
172
|
# NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
|
215
173
|
if len(v) > 1:
|
216
|
-
|
217
|
-
|
174
|
+
bb = BasicBlock(v[0].arg.lst, _sort_ctx(flatten([y.arg.ctx for y in v])), k, cnt=sum(y.arg.cnt for y in v))
|
175
|
+
out = UOp(Ops.BLOCKEND, src=tuple(flatten([x.src for x in v])), arg=bb)
|
176
|
+
# NOTE: bb.ctx != u.arg.ctx can cause problems here
|
218
177
|
for u in v: new_forks[u] = out
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
178
|
+
if len(new_forks) == 0: return None
|
179
|
+
return sink.substitute(new_forks)
|
180
|
+
|
181
|
+
pm_blockend_merge = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), merge_blockends)])
|
182
|
+
|
183
|
+
# ***** block merging ****
|
184
|
+
|
185
|
+
def merge_block(x:UOp):
  """Fold source BLOCKs of `x` into `x` once every reference to them is present in `x.src`.

  A source BLOCK is a merge candidate when `x` is a BLOCK with an equal ctx, or when `x` is a BLOCKEND
  whose `end` is in the source's ctx. A candidate is merged only if it appears in `x.src` exactly
  `arg.cnt` times; its ops are placed before `x`'s ops and its sources before the unmerged sources.
  Returns None when nothing can be merged.
  """
  unmergable_blocks = []
  hits: defaultdict[UOp, int] = defaultdict(int)
  for y in x.src:
    if y.op is not Ops.BLOCK: candidate = False
    elif x.op is Ops.BLOCK: candidate = x.arg.ctx == y.arg.ctx
    elif x.op is Ops.BLOCKEND: candidate = x.arg.end in y.arg.ctx
    else: candidate = False
    if candidate: hits[y] += 1
    else: unmergable_blocks.append(y)

  mergable_blocks = []
  for blk, seen in hits.items():
    # only merge when all of the block's references are sources of x
    if seen == blk.arg.cnt: mergable_blocks.append(blk)
    else: unmergable_blocks.extend([blk]*seen)
  if not mergable_blocks: return None

  # create the block
  merged_lst = tuple(flatten([b.arg.lst for b in mergable_blocks])) + x.arg.lst
  merged_src = tuple(flatten([b.src for b in mergable_blocks]) + unmergable_blocks)
  return UOp(x.op, src=merged_src, arg=replace(x.arg, lst=merged_lst))
|
201
|
+
|
202
|
+
def remove_blockend(x:UOp):
|
203
|
+
# if there's any remaining blocks that need to go in this BLOCKEND, we don't remove it
|
204
|
+
if any(x.arg.end in y.arg.ctx for y in x.src if y.op in {Ops.BLOCK, Ops.BLOCKEND}): return None
|
205
|
+
|
206
|
+
if (parent_blocks := [y for y in x.src if y.op is Ops.BLOCK and y.arg.child_ctx is not None and x.arg.end in y.arg.child_ctx]):
|
207
|
+
assert all_same(parent_blocks), f"should never have two parent blocks (has {len(parent_blocks)})"
|
208
|
+
parent_block = parent_blocks[0]
|
209
|
+
assert len(parent_blocks) == parent_block.arg.cnt
|
210
|
+
# NOTE: DEFINE_ACC doesn't have to be handled in any special way
|
211
|
+
late_ops = list(x.arg.lst)
|
212
|
+
# NOTE: we have to add a barrier at the start if barrier is used in the range
|
213
|
+
if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.ENDRANGE:
|
214
|
+
late_ops = [UOp(Ops.BARRIER)] + late_ops
|
215
|
+
# peephole opt, remove any BARRIERs next to each other
|
216
|
+
for i in range(len(late_ops)-1):
|
217
|
+
if late_ops[i].op is Ops.BARRIER and late_ops[i+1].op is Ops.BARRIER: late_ops[i+1] = UOp(Ops.NOOP)
|
218
|
+
arg = BasicBlock(parent_block.arg.lst+tuple(late_ops), tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt)
|
219
|
+
return UOp(Ops.BLOCK, src=tuple(y for y in x.src if y is not parent_block)+parent_block.src, arg=arg)
|
220
|
+
|
221
|
+
block_merge = PatternMatcher([
|
222
|
+
(UPat((Ops.BLOCK, Ops.BLOCKEND), name="x"), merge_block),
|
223
|
+
(UPat(Ops.BLOCKEND, name="x"), remove_blockend),
|
224
|
+
])
|
223
225
|
|
224
|
-
|
225
|
-
sink = graph_rewrite(sink, pm_block_merge, ctx=children)
|
226
|
+
# ****** finalize ******
|
226
227
|
|
227
|
-
|
228
|
-
sink
|
228
|
+
def finalize(sink:UOp) -> UOp:
|
229
|
+
if sink.op is not Ops.BLOCK or not all(x.op in DONT_PLACE_IN_BLOCK for x in sink.src):
|
230
|
+
raise RuntimeError(f"linearize failure {sink.op} {[x.op for x in sink.src if x.op not in DONT_PLACE_IN_BLOCK]}")
|
229
231
|
|
230
|
-
#
|
231
|
-
|
232
|
+
# place the early things
|
233
|
+
lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst)
|
234
|
+
return UOp(Ops.BLOCKFINAL, arg=BasicBlock(tuple(lst)))
|
232
235
|
|
233
|
-
|
234
|
-
return list(sink.arg.lst)
|
236
|
+
pm_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), finalize)])
|