tinygrad-0.10.0-py3-none-any.whl → tinygrad-0.10.2-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/codegen/linearize.py
CHANGED
@@ -1,95 +1,234 @@
The file is rewritten wholesale: the 95-line 0.10.0 linearizer is replaced by the 234-line version below.

```python
from __future__ import annotations
import collections, heapq
from dataclasses import dataclass
from tinygrad.ops import UOp, Ops, PatternMatcher, UPat, graph_rewrite, GroupOp
from tinygrad.spec import type_verify
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.helpers import dedup, flatten, partition

DONT_PLACE_IN_BLOCK = {Ops.NAME, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST, *GroupOp.Block}

def disp(y:UOp) -> str:
  if y.op is Ops.BLOCKSTART: return "w"+disp(y.src[0])
  if y.op is Ops.IF: return f'IF{id(y)}'
  if y.op is Ops.RANGE: return str(y.arg)
  return "<NONE>"

@dataclass(frozen=True)
class BasicBlock:
  ctx: tuple[UOp, ...]
  lst: tuple[UOp, ...]
  end: UOp|None = None
  def __lt__(self, o:BasicBlock): return tuple(x.tuplize for x in self.ctx+self.lst) < tuple(x.tuplize for x in o.ctx+o.lst)
  def __repr__(self):
    return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+\
           f"{[disp(y) for y in self.ctx]} {len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])

def append_to_block(ctx:tuple[dict[UOp, tuple[UOp, ...]], dict[UOp, list[UOp]]], x:UOp):
  block_ctxs, children = ctx
  in_this_block = set(x.arg.lst)

  # collections to build
  new_srcs: list[UOp] = []
  to_append: list[UOp] = []
  old_blocks: dict[tuple[UOp, ...], UOp] = {}
  new_blocks: dict[tuple[UOp, ...], list[UOp]] = {}

  for u in x.src:
    if u.op is Ops.BLOCK:
      # merge sibling blocks. NOTE: blocks must only have one output source
      assert u.arg.ctx not in old_blocks, "sibling should never have been created"
      old_blocks[u.arg.ctx] = u
    elif u.op not in DONT_PLACE_IN_BLOCK and set(children[u]).issubset(in_this_block):
      # if it can go in blocks and all its children are in the block, we add it to the block
      if (block_ctx:=block_ctxs[u]) == x.arg.ctx:
        # if it's the same context, we place the UOp in this block and append the parents to its srcs
        new_srcs.extend(u.src)
        to_append.append(u)
      else:
        # if it's a different context, we create a new block with this UOp
        new_blocks.setdefault(block_ctx, []).append(u)
    else:
      # otherwise, we keep it in the srcs
      new_srcs.append(u)
  if len(to_append) == 0 and len(new_blocks) == 0: return None

  for rng,lst in new_blocks.items():
    srcs = flatten(y.src for y in lst)
    if (old_block:=old_blocks.pop(rng, None)) is not None:
      # NOTE: order shouldn't matter here
      srcs.extend(old_block.src)
      lst.extend(old_block.arg.lst)
    new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(srcs)), BasicBlock(rng, tuple(lst)))
    lrng = list(rng)
    for r in rng[::-1]:
      if r not in x.arg.ctx and r.op is not Ops.BLOCKSTART:
        lrng.remove(r)
        new_block = UOp(Ops.BLOCKEND, src=(new_block,),
                        arg=BasicBlock(tuple(lrng), (UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)),), r))
    new_srcs.append(new_block)
  return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(list(old_blocks.values())+new_srcs)), BasicBlock(x.arg.ctx, tuple(to_append)+x.arg.lst))

make_basic_blocks = PatternMatcher([
  (UPat(Ops.SINK, name="x"),
   lambda x: UOp(Ops.BLOCK, src=x.src+((UOp(Ops.NAME, arg=x.arg.name),) if x.arg is not None else ()), arg=BasicBlock((), (x,)))),
  (UPat(Ops.BLOCK, name="x"), append_to_block),
])

def block_merge(ctx, x:UOp):
  # ctx is children here
  if x.op is Ops.BLOCKEND:
    # if it's a BLOCKEND, see if we are done with placement. if all the children of the range are in here
    in_this_block = set(x.arg.lst)
    if len([y for y in ctx[x.arg.end] if y not in in_this_block]) == 0:
      # find the parent block that has the BLOCKSTART in the ctx
      parent_blocks = [y for y in x.src if y.op is Ops.BLOCK and UOp(Ops.BLOCKSTART, src=(x.arg.end,)) in y.arg.ctx]
      assert len(parent_blocks) <= 1, "should never have two parent blocks"
      if len(parent_blocks) == 1:
        parent_block = parent_blocks[0]
        # range needs DEFINE_ACC to be before the range (never in DEFINE_ACC for if)
        early_ops, late_ops = partition(x.arg.lst, lambda y: y.op is Ops.DEFINE_ACC and x.arg.end in y.src)
        return UOp(Ops.BLOCK, dtypes.void, tuple(y for y in x.src if y is not parent_block)+parent_block.src,
                   BasicBlock(tuple(y for y in x.arg.ctx if y is not x.arg.end), tuple(early_ops)+parent_block.arg.lst+tuple(late_ops)))

  new_srcs: list[UOp] = []
  to_append: list[UOp] = []
  new_ctx = x.arg.ctx
  placed = set()
  for u in x.src:
    if u.op is Ops.BLOCK and (tuple(u.arg.ctx) == tuple(x.arg.ctx) or (x.arg.end is not None and x.arg.end in u.arg.ctx)):
      # NOTE: this can't appear in srcs twice or it would be a BLOCKFORK
      new_ctx += tuple(y for y in u.arg.ctx if y not in x.arg.ctx)
      new_srcs.extend(u.src)
      to_append.extend(u.arg.lst)
    elif u.op is Ops.BLOCKFORK and x.src.count(u) == u.arg: # block fork appears # of times in srcs
      if u not in placed:
        new_srcs.extend(u.src)
        placed.add(u)
    else:
      # keep it in srcs
      new_srcs.append(u)
  if len(to_append) == 0 and len(placed) == 0: return None
  return UOp(x.op, dtypes.void, tuple(new_srcs), BasicBlock(tuple(sorted(new_ctx, key=lambda x: x.tuplize)), tuple(to_append)+x.arg.lst, x.arg.end))

pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),])

def block_finalize(block:UOp):
  if len(block.src) == 0: return None
  _uops = sorted(dedup(block.src), key=lambda x: x.tuplize)
  assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops)
  _uops += block.arg.lst
  # strip the SINK
  assert _uops[-1].op is Ops.SINK, "doesn't end with SINK"
  return UOp(Ops.BLOCK, arg=BasicBlock((), tuple(_uops[:-1])))

pm_block_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="block"), block_finalize)])

# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
def block_reorder(in_block:UOp):
  in_this_block = set(in_block.arg.lst)
  local_children: collections.defaultdict[UOp, list[UOp]] = collections.defaultdict(list)
  in_degree: collections.defaultdict[UOp, int] = collections.defaultdict(int)
  priorities:dict[UOp, int] = {}

  # get local children and assign priorities
  for u in reversed(in_block.arg.lst):
    for s in u.src:
      if s in in_this_block:
        local_children[s].append(u)
        in_degree[u] += 1
    # put loads in the beginning of the block and prevent priority inversion
    priorities[u] = min([-1000 if u.op is Ops.LOAD else 0] + [priorities[x] for x in local_children[u]])

  # placement queue
  queue:list[tuple[int, tuple, UOp]] = []
  def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u))

  # place the first ones that don't have deps
  for u in in_block.arg.lst:
    if u not in in_degree: push(u)

  newlst = []
  while queue:
    _,_,x = heapq.heappop(queue)
    newlst.append(x)
    for u in local_children[x]:
      in_degree[u] -= 1
      if in_degree[u] == 0: push(u)

  assert len(newlst) == len(in_block.arg.lst), f"len mismatch {len(newlst)} != {len(in_block.arg.lst)}"
  return in_block.replace(arg=BasicBlock(in_block.arg.ctx, tuple(newlst)))

def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
  assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"

  # get children and all block contexts
  temp_block_ctxs: dict[UOp, list[UOp]] = {}
  children: dict[UOp, list[UOp]] = {}
  for u in sink.toposort:
    this_block_ctx: list[UOp] = []
    for s in u.src:
      # save children
      children.setdefault(s, []).append(u)
      # compute block ctx
      if s.op in {Ops.RANGE, Ops.IF}: this_block_ctx.append(s)
      # don't flow (fully) through assign and store
      elif s.op is Ops.STORE:
        # ugh, deal with non-reduce locals. probably wrong
        if isinstance(s.src[0].dtype, PtrDType) and s.src[0].dtype.local:
          idx_context, store_context = temp_block_ctxs[s.src[0]], temp_block_ctxs[s]
          this_block_ctx += [x for x in store_context if x not in idx_context and x.op is Ops.RANGE]
      elif s.op is Ops.ASSIGN:
        # flow though assign, but remove the ranges used in the assign
        assert s.src[0].op is Ops.DEFINE_ACC
        this_block_ctx += [x for x in temp_block_ctxs[s.src[1]] if x not in s.src[0].src[1:]]
      else:
        # flow though everything else
        this_block_ctx += temp_block_ctxs[s]
    temp_block_ctxs[u] = sorted(dedup(this_block_ctx), key=lambda x: x.tuplize)

  # make final block_ctxs, add BLOCKSTART to block_ctxs for IF and RANGE
  block_ctxs: dict[UOp, tuple[UOp, ...]] = {}
  for u in sink.toposort:
    block_ctxs[u] = ((UOp(Ops.BLOCKSTART, src=(u,)),) + tuple(temp_block_ctxs[u])) if u.op in {Ops.IF, Ops.RANGE} else tuple(temp_block_ctxs[u])

  # TODO: there's probably a clever way to remove this while loop
  while 1:
    sink = graph_rewrite(sink, make_basic_blocks, ctx=(block_ctxs, children))

    # add BLOCKFORK (slow!)
    block_parent_count = collections.Counter(flatten([x.src for x in sink.toposort if x.op is Ops.BLOCK]))
    non_block_parents = set(flatten([x.src for x in sink.toposort if x.op is not Ops.BLOCK]))
    forks = {u:UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCK, src=u.src, arg=BasicBlock(block_ctxs[u], (u,))),), arg=child_count)
      for u,child_count in block_parent_count.items() if u.op not in DONT_PLACE_IN_BLOCK and child_count > 1 and u not in non_block_parents}

    if not len(forks): break
    sink = sink.substitute(forks)

  # combine matching BLOCKENDS
  blockends_to_arg: dict[UOp, list[UOp]] = {}
  for be in sink.toposort:
    if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
  new_forks = {}
  for k,v in blockends_to_arg.items():
    # NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
    if len(v) > 1:
      out = UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCKEND, src=tuple(flatten(x.src for x in v)),
                                        arg=BasicBlock(tuple(dedup(flatten([y.arg.ctx for y in v]))), v[0].arg.lst, k)),), arg=len(v))
      for u in v: new_forks[u] = out
  sink = sink.substitute(new_forks)

  # reorder ops in block for speed
  sink = sink.substitute({u:newu for u in sink.toposort if u.op is Ops.BLOCK and (newu:=block_reorder(u)) is not u})

  # final rewrite to merge all blocks into one
  sink = graph_rewrite(sink, pm_block_merge, ctx=children)

  # there should just be one block left, with a few parents with 0 srcs (now done in a rewriter)
  sink = graph_rewrite(sink, pm_block_finalize)

  # sanity checks (NOTE: these can cause things to be skipped in BEAM)
  if not skip_check: type_verify(sink.arg.lst)

  # return the list. TODO: refactor to return the UOp
  return list(sink.arg.lst)
```
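The interesting piece above is `block_reorder`: a heap-driven topological sort in which every op inherits the minimum priority of its block-local children, so LOAD ops (bias -1000) bubble to the front of the block without violating dependencies. Here is a minimal standalone sketch of that scheduling pattern on a toy graph of strings rather than tinygrad UOps (the op names, edges, and the -1000 bias are illustrative only):

```python
import collections, heapq

# toy block: op -> its srcs (dependencies); "load*" ops get the -1000 bias like Ops.LOAD
srcs = {"load0": [], "load1": [], "alu0": ["load0", "load1"], "alu1": ["alu0"], "store": ["alu1", "load1"]}
order = list(srcs)  # any valid toposort of the block's lst

children: dict[str, list[str]] = collections.defaultdict(list)
in_degree: dict[str, int] = collections.defaultdict(int)
priorities: dict[str, int] = {}

# walk in reverse so every child's priority exists before its parents compute theirs
for u in reversed(order):
  for s in srcs[u]:
    children[s].append(u)
    in_degree[u] += 1
  # an op takes the min of its own bias and its children's priorities (prevents priority inversion)
  priorities[u] = min([-1000 if u.startswith("load") else 0] + [priorities[c] for c in children[u]])

# heap-based placement: pop the lowest-priority ready op, release its children
queue = [(priorities[u], u) for u in order if in_degree[u] == 0]
heapq.heapify(queue)
newlst = []
while queue:
  _, x = heapq.heappop(queue)
  newlst.append(x)
  for c in children[x]:
    in_degree[c] -= 1
    if in_degree[c] == 0: heapq.heappush(queue, (priorities[c], c))

print(newlst)  # ['load0', 'load1', 'alu0', 'alu1', 'store'] -- loads surface first
```

The real version breaks priority ties with `u.tuplize` so the resulting order is deterministic.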
tinygrad/codegen/lowerer.py
CHANGED
@@ -1,54 +1,70 @@
The head of the file after the change (imports, `get_contraction`, and the new `_group_dims`/`_split_dims` helpers):

```python
# the job of the lowerer is to do indexing
import functools, itertools, operator, math
from dataclasses import dataclass
from typing import cast
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
from tinygrad.renderer import Renderer
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap
from tinygrad.codegen.expander import expand_rewrite

# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
  acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
  try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
  except ValueError: return None
  return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]

# ***** indexing *****
def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
  # TODO: symbolic shape
  if not all_int(dims): return dims
  while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
    for i,m in enumerate(max_sizes):
      if i < (len(dims)-1) and dims[i] * dims[i+1] <= m:
        dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
        break
    else: return None
  return dims

def _split_dims(dims, max_sizes):
  if all(d <= m for d,m in zip(dims, max_sizes)): return dims
  _dims = list(dims) + [1]*(3-len(dims))
  for i in range(len(_dims)):
    while _dims[i] > max_sizes[i]:
      div = next((d for d in range(2, math.ceil(math.sqrt(_dims[i])) + 1) if (_dims[i] % d) == 0), 1)
      if div == 1: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
      _dims[i], _dims[(i+1)%len(_dims)] = _dims[i]//div, _dims[(i+1)%len(_dims)]*div
  return tuple(_dims[:2] if _dims[2] == 1 else _dims[0] if _dims[1:3] == [1,1] else _dims)

def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|None, reverse=False) -> list[UOp]:
  if reverse: dims = dims[::-1]
  # try to group first: (a, b, c, d) -> (ab, c, d)
  limited = (grouped if (grouped := _group_dims(dims, max_sizes)) else dims) if max_sizes is not None else dims
  # check if grouping failed
  if max_sizes is not None and len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
  # try to split up dims: (a,) -> (b, c)
  if limited == dims: limited = _split_dims(dims, max_sizes) if max_sizes is not None else dims
  ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
  if len(limited) < len(dims):
    ret = []
    if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}")
    for idx, contraction_group in zip(raw_idxs, contraction):
      for c in contraction_group[:-1]:
        ret.append(idx % dims[c])
        idx //= dims[c]
      ret.append(idx)
  elif len(limited) > len(dims):
    a, b = len(limited), len(dims)
    if a == 2 and b == 1: ret = [raw_idxs[0] * limited[1] + raw_idxs[1]]
    if a == 3 and b == 1: ret = [raw_idxs[0] * (limited[1] * limited[2]) + raw_idxs[1] * limited[2] + raw_idxs[2]]
    if a == 3 and b == 2: ret = [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
  return ret[::-1] if reverse else ret

@dataclass
class IndexContext:
  idxs: list[UOp]
  ridxs: list[UOp]
  acc_num: int = 0

def get_index(ast:UOp, opts:Renderer) -> IndexContext:
```
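Since the new `_split_dims` is plain integer arithmetic, its factoring behavior can be sanity-checked directly. A quick check, assuming the module layout in this wheel (the per-dim limits below are made up):

```python
from tinygrad.codegen.lowerer import _split_dims

# a 4096-wide launch under a hypothetical 1024 per-dim limit gets factored across two grid dims
assert _split_dims((4096,), (1024, 1024, 1024)) == (1024, 4)
# dims already within the limits pass through untouched
assert _split_dims((256, 4), (1024, 1024, 1024)) == (256, 4)
```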
@@ -56,14 +72,11 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:

```python
  # NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
  full_shape = ast.full_shape
  first_upcasted = len(full_shape)-ki.upcasted
  # if there's no reduce, this is first_upcasted. assumes reduces are at the end
  first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.toposort if x.op is Ops.REDUCE_AXIS))
  local_loads = [x for x in ast.toposort if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
  # NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
  group_for_reduces = sum([any(l.st_arg.shape[i]!=ast.src[0].st_arg.shape[i] for l in local_loads) for i in range(first_reduce,first_upcasted)])
  global_dims = first_reduce-ki.local_dims

  if opts.has_local:
```

@@ -76,22 +89,21 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:

```python
             get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
  else:
    # all loops are RANGES
    idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i) for i,g in enumerate(full_shape[:first_reduce])]

  # reduce loops
  idxs += [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i)
    for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]

  # upcast loops
  for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
    assert isinstance(g, int), "needs to be int to upcast/unroll"
    idxs.append(UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))

  # late indexes (group for reduce)
  ridxs = idxs[:]
  for a in range(first_reduce, first_reduce+group_for_reduces):
    ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(full_shape[a])), 1000+a)

  return IndexContext(idxs, ridxs)
```

@@ -100,7 +112,7 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:

```python
def lower_reduce_axis(ctx: IndexContext, x: UOp):
  # NOTE: always using ridxs is fine here
  reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
  assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
  alu_op: Ops = x.arg[0]
  ret = x.src[0]
  if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
```

@@ -114,12 +126,10 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):

```python
def lower_load_store(ctx: IndexContext, x: UOp):
  idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
  buf = x.src[0]
  if x.op is Ops.LOAD:
    barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
    return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
  # NOTE: only store the local reduceop in the threads that are actually doing the reduce
  if cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.ASSIGN:
    reduce_input = x.src[2].src[1].src[1] if x.src[2].src[1].src[1] is not x.src[2].src[0] else x.src[2].src[1].src[0]
```

@@ -130,14 +140,22 @@ def lower_load_store(ctx: IndexContext, x: UOp):

```python
  if (not cast(PtrDType, x.src[0].dtype).local) or store_back:
    for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
      if oidx is not ridx: valid = valid * oidx.eq(0)
  return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid), x.src[2]))

def lower_const(x:UOp):
  assert all(v.mask is None for v in unwrap(x.st).views), f"VIEW in CONST/DEFINE_VAR source must be unmasked, got {x.st}"
  return x.replace(src=())

pm_lowerer = PatternMatcher([
  (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
  (UPat((Ops.CONST, Ops.DEFINE_VAR), src=(UPat(Ops.VIEW),), name="x"), lower_const),
  (UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
  # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
  (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
])

def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
  sink = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
  # expand_rewrite turns this into a vectorized program
  return expand_rewrite(sink)
```
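`get_grouped_dims` leans on `get_contraction` plus `%`/`//` arithmetic to rebuild per-dim indices after grouping. A standalone sketch checking that the recovery is a bijection, using plain ints in place of the index UOps (the shapes and the grouped size are made up; the import assumes the module layout in this wheel):

```python
import itertools
from tinygrad.codegen.lowerer import get_contraction

dims, limited = (5, 4, 3), (5, 12)       # (4, 3) collapsed into 12 to satisfy a size limit
groups = get_contraction(dims, limited)  # -> [[0], [1, 2]]

def recover(raw_idxs):
  # unpack each grouped index back into the original dims, mirroring get_grouped_dims
  ret = []
  for idx, group in zip(raw_idxs, groups):
    for c in group[:-1]:
      ret.append(idx % dims[c])
      idx //= dims[c]
    ret.append(idx)
  return tuple(ret)

# every point of the original 5x4x3 space is produced exactly once
assert {recover((g0, g1)) for g0 in range(5) for g1 in range(12)} == set(itertools.product(range(5), range(4), range(3)))
```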