tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +114 -172
- tinygrad/codegen/linearize.py +211 -81
- tinygrad/codegen/lowerer.py +30 -35
- tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
- tinygrad/codegen/transcendental.py +12 -13
- tinygrad/device.py +170 -47
- tinygrad/dtype.py +28 -26
- tinygrad/engine/jit.py +80 -63
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +162 -0
- tinygrad/engine/realize.py +58 -107
- tinygrad/engine/schedule.py +381 -314
- tinygrad/engine/search.py +40 -44
- tinygrad/gradient.py +70 -0
- tinygrad/helpers.py +77 -58
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +89 -64
- tinygrad/ops.py +562 -446
- tinygrad/renderer/__init__.py +79 -36
- tinygrad/renderer/cstyle.py +70 -84
- tinygrad/renderer/llvmir.py +32 -20
- tinygrad/renderer/ptx.py +79 -99
- tinygrad/renderer/wgsl.py +87 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libpciaccess.py +2023 -0
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +19 -21
- tinygrad/runtime/ops_amd.py +488 -327
- tinygrad/runtime/ops_clang.py +15 -28
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +129 -38
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +45 -40
- tinygrad/runtime/ops_metal.py +93 -73
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +232 -270
- tinygrad/runtime/ops_python.py +51 -46
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +63 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +384 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +26 -4
- tinygrad/runtime/support/hcq.py +254 -324
- tinygrad/runtime/support/llvm.py +32 -0
- tinygrad/shape/shapetracker.py +84 -53
- tinygrad/shape/view.py +103 -138
- tinygrad/spec.py +154 -0
- tinygrad/tensor.py +744 -496
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
- tinygrad-0.10.1.dist-info/RECORD +86 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/codegen/linearize.py
CHANGED
@@ -1,92 +1,222 @@
+from __future__ import annotations
+import collections, heapq
+from dataclasses import dataclass
+from tinygrad.ops import UOp, Ops, PatternMatcher, UPat, graph_rewrite, GroupOp
+from tinygrad.spec import type_verify
+from tinygrad.dtype import dtypes, PtrDType
+from tinygrad.helpers import dedup, flatten, partition
+
+DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST, *GroupOp.Block}
+
+def disp(y:UOp) -> str:
+  if y.op is Ops.BLOCKSTART: return "w"+disp(y.src[0])
+  if y.op is Ops.IF: return f'IF{id(y)}'
+  if y.op is Ops.RANGE: return str(y.arg)
+  return "<NONE>"
+
+@dataclass(frozen=True)
+class BasicBlock:
+  ctx: tuple[UOp, ...]
+  lst: tuple[UOp, ...]
+  end: UOp|None = None
+  def __lt__(self, o:BasicBlock): return tuple(x.tuplize for x in self.ctx+self.lst) < tuple(x.tuplize for x in o.ctx+o.lst)
+  def __repr__(self):
+    return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+\
+      f"{[disp(y) for y in self.ctx]} {len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])
+
+def append_to_block(ctx:tuple[dict[UOp, tuple[UOp, ...]], dict[UOp, list[UOp]]], x:UOp):
+  block_ctxs, children = ctx
+  in_this_block = set(x.arg.lst)
+
+  # collections to build
+  new_srcs: list[UOp] = []
+  to_append: list[UOp] = []
+  old_blocks: dict[tuple[UOp, ...], UOp] = {}
+  new_blocks: dict[tuple[UOp, ...], list[UOp]] = {}
+
+  for u in x.src:
+    if u.op is Ops.BLOCK:
+      # merge sibling blocks. NOTE: blocks must only have one output source
+      assert u.arg.ctx not in old_blocks, "sibling should never have been created"
+      old_blocks[u.arg.ctx] = u
+    elif u.op not in DONT_PLACE_IN_BLOCK and set(children[u]).issubset(in_this_block):
+      # if it can go in blocks and all its children are in the block, we add it to the block
+      if (block_ctx:=block_ctxs[u]) == x.arg.ctx:
+        # if it's the same context, we place the UOp in this block and append the parents to its srcs
+        new_srcs.extend(u.src)
+        to_append.append(u)
+      else:
+        # if it's a different context, we create a new block with this UOp
+        new_blocks.setdefault(block_ctx, []).append(u)
     else:
+      # otherwise, we keep it in the srcs
+      new_srcs.append(u)
+  if len(to_append) == 0 and len(new_blocks) == 0: return None
+
+  for rng,lst in new_blocks.items():
+    srcs = flatten(y.src for y in lst)
+    if (old_block:=old_blocks.pop(rng, None)) is not None:
+      # NOTE: order shouldn't matter here
+      srcs.extend(old_block.src)
+      lst.extend(old_block.arg.lst)
+    new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(srcs)), BasicBlock(rng, tuple(lst)))
+    lrng = list(rng)
+    for r in rng[::-1]:
+      if r not in x.arg.ctx and r.op is not Ops.BLOCKSTART:
+        lrng.remove(r)
+        new_block = UOp(Ops.BLOCKEND, src=(new_block,),
+          arg=BasicBlock(tuple(lrng), (UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)),), r))
+    new_srcs.append(new_block)
+  return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(list(old_blocks.values())+new_srcs)), BasicBlock(x.arg.ctx, tuple(to_append)+x.arg.lst))
+
+make_basic_blocks = PatternMatcher([
+  (UPat(Ops.SINK, name="x"), lambda x: UOp(Ops.BLOCK, src=x.src, arg=BasicBlock((), (x,)))),
+  (UPat(Ops.BLOCK, name="x"), append_to_block),
+])
+
+def block_merge(ctx, x:UOp):
+  # ctx is children here
+  if x.op is Ops.BLOCKEND:
+    # if it's a BLOCKEND, see if we are done with placement. if all the children of the range are in here
+    in_this_block = set(x.arg.lst)
+    if len([y for y in ctx[x.arg.end] if y not in in_this_block]) == 0:
+      # find the parent block that has the BLOCKSTART in the ctx
+      parent_blocks = [y for y in x.src if y.op is Ops.BLOCK and UOp(Ops.BLOCKSTART, src=(x.arg.end,)) in y.arg.ctx]
+      assert len(parent_blocks) <= 1, "should never have two parent blocks"
+      if len(parent_blocks) == 1:
+        parent_block = parent_blocks[0]
+        # range needs DEFINE_ACC to be before the range (never in DEFINE_ACC for if)
+        early_ops, late_ops = partition(x.arg.lst, lambda y: y.op is Ops.DEFINE_ACC and x.arg.end in y.src)
+        return UOp(Ops.BLOCK, dtypes.void, tuple(y for y in x.src if y is not parent_block)+parent_block.src,
+          BasicBlock(tuple(y for y in x.arg.ctx if y is not x.arg.end), tuple(early_ops)+parent_block.arg.lst+tuple(late_ops)))
+
+  new_srcs: list[UOp] = []
+  to_append: list[UOp] = []
+  new_ctx = x.arg.ctx
+  placed = set()
+  for u in x.src:
+    if u.op is Ops.BLOCK and (tuple(u.arg.ctx) == tuple(x.arg.ctx) or (x.arg.end is not None and x.arg.end in u.arg.ctx)):
+      # NOTE: this can't appear in srcs twice or it would be a BLOCKFORK
+      new_ctx += tuple(y for y in u.arg.ctx if y not in x.arg.ctx)
+      new_srcs.extend(u.src)
+      to_append.extend(u.arg.lst)
+    elif u.op is Ops.BLOCKFORK and x.src.count(u) == u.arg: # block fork appears # of times in srcs
+      if u not in placed:
+        new_srcs.extend(u.src)
+        placed.add(u)
+    else:
+      # keep it in srcs
+      new_srcs.append(u)
+  if len(to_append) == 0 and len(placed) == 0: return None
+  return UOp(x.op, dtypes.void, tuple(new_srcs), BasicBlock(tuple(sorted(new_ctx, key=lambda x: x.tuplize)), tuple(to_append)+x.arg.lst, x.arg.end))
+
+pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),])
+
+# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
+def block_reorder(in_block:UOp):
+  in_this_block = set(in_block.arg.lst)
+  local_children: collections.defaultdict[UOp, list[UOp]] = collections.defaultdict(list)
+  in_degree: collections.defaultdict[UOp, int] = collections.defaultdict(int)
+  priorities:dict[UOp, int] = {}
+
+  # get local children and assign priorities
+  for u in reversed(in_block.arg.lst):
+    for s in u.src:
+      if s in in_this_block:
+        local_children[s].append(u)
+        in_degree[u] += 1
+    # put loads in the beginning of the block and prevent priority inversion
+    priorities[u] = min([-1000 if u.op is Ops.LOAD else 0] + [priorities[x] for x in local_children[u]])
+
+  # placement queue
+  queue:list[tuple[int, tuple, UOp]] = []
   def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u))
 
+  # place the first ones that don't have deps
+  for u in in_block.arg.lst:
+    if u not in in_degree: push(u)
 
-  _uops: List[UOp] = []
+  newlst = []
   while queue:
-    if x.op is Ops.DEFINE_ACC:
-      idx = min([_uops.index(l) for l in x.src if l.op is Ops.RANGE])
-      _uops.insert(idx, x)
-    else: _uops.append(x)
-    for u, ss in scope_children.items():
-      if x in ss:
-        ss.remove(x)
-        if len(ss) == 0: scope_end[u] = x
-    for u in children[x]:
+    _,_,x = heapq.heappop(queue)
+    newlst.append(x)
+    for u in local_children[x]:
       in_degree[u] -= 1
       if in_degree[u] == 0: push(u)
 
+  assert len(newlst) == len(in_block.arg.lst), f"len mismatch {len(newlst)} != {len(in_block.arg.lst)}"
+  return in_block.replace(arg=BasicBlock(in_block.arg.ctx, tuple(newlst)))
+
+def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
+  assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
+
+  # get children and all block contexts
+  temp_block_ctxs: dict[UOp, list[UOp]] = {}
+  children: dict[UOp, list[UOp]] = {}
+  for u in sink.toposort:
+    this_block_ctx: list[UOp] = []
+    for s in u.src:
+      # save children
+      children.setdefault(s, []).append(u)
+      # compute block ctx
+      if s.op in {Ops.RANGE, Ops.IF}: this_block_ctx.append(s)
+      # don't flow (fully) through assign and store
+      elif s.op is Ops.STORE:
+        # ugh, deal with non-reduce locals. probably wrong
+        if isinstance(s.src[0].dtype, PtrDType) and s.src[0].dtype.local:
+          idx_context, store_context = temp_block_ctxs[s.src[0]], temp_block_ctxs[s]
+          this_block_ctx += [x for x in store_context if x not in idx_context and x.op is Ops.RANGE]
+      elif s.op is Ops.ASSIGN:
+        # flow though assign, but remove the ranges used in the assign
+        assert s.src[0].op is Ops.DEFINE_ACC
+        this_block_ctx += [x for x in temp_block_ctxs[s.src[1]] if x not in s.src[0].src[1:]]
+      else:
+        # flow though everything else
+        this_block_ctx += temp_block_ctxs[s]
+    temp_block_ctxs[u] = sorted(dedup(this_block_ctx), key=lambda x: x.tuplize)
+
+  # make final block_ctxs, add BLOCKSTART to block_ctxs for IF and RANGE
+  block_ctxs: dict[UOp, tuple[UOp, ...]] = {}
+  for u in sink.toposort:
+    block_ctxs[u] = ((UOp(Ops.BLOCKSTART, src=(u,)),) + tuple(temp_block_ctxs[u])) if u.op in {Ops.IF, Ops.RANGE} else tuple(temp_block_ctxs[u])
+
+  # TODO: there's probably a clever way to remove this while loop
+  while 1:
+    sink = graph_rewrite(sink, make_basic_blocks, ctx=(block_ctxs, children))
+
+    # add BLOCKFORK (slow!)
+    block_parent_count = collections.Counter(flatten([x.src for x in sink.toposort if x.op is Ops.BLOCK]))
+    non_block_parents = set(flatten([x.src for x in sink.toposort if x.op is not Ops.BLOCK]))
+    forks = {u:UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCK, src=u.src, arg=BasicBlock(block_ctxs[u], (u,))),), arg=child_count)
+      for u,child_count in block_parent_count.items() if u.op not in DONT_PLACE_IN_BLOCK and child_count > 1 and u not in non_block_parents}
+
+    if not len(forks): break
+    sink = sink.substitute(forks)
+
+  # combine matching BLOCKENDS
+  blockends_to_arg: dict[UOp, list[UOp]] = {}
+  for be in sink.toposort:
+    if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
+  new_forks = {}
+  for k,v in blockends_to_arg.items():
+    # NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
+    if len(v) > 1:
+      out = UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCKEND, src=tuple(flatten(x.src for x in v)),
+        arg=BasicBlock(tuple(dedup(flatten([y.arg.ctx for y in v]))), v[0].arg.lst, k)),), arg=len(v))
+      for u in v: new_forks[u] = out
+  sink = sink.substitute(new_forks)
+
+  # reorder ops in block for speed
+  sink = sink.substitute({u:newu for u in sink.toposort if u.op is Ops.BLOCK and (newu:=block_reorder(u)) is not u})
+
+  # final rewrite to merge all blocks into one
+  sink = graph_rewrite(sink, pm_block_merge, ctx=children)
+
+  # there should just be one block left, with a few parents with 0 srcs
+  assert sink.op is Ops.BLOCK
+  _uops = sorted(dedup(sink.src), key=lambda x: x.tuplize)
+  assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops)
+  _uops += sink.arg.lst
 
   # sanity checks (NOTE: these can cause things to be skipped in BEAM)
   if not skip_check: type_verify(_uops)
tinygrad/codegen/lowerer.py
CHANGED
@@ -1,17 +1,14 @@
 # the job of the lowerer is to do indexing
-from __future__ import annotations
 import functools, itertools, operator
 from dataclasses import dataclass
-from typing import
-from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.view import variable_to_uop
+from typing import cast
 from tinygrad.dtype import dtypes, PtrDType
-from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element
+from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
 from tinygrad.renderer import Renderer
-from tinygrad.helpers import all_int, prod, partition, flatten
+from tinygrad.helpers import all_int, prod, partition, flatten, unwrap
 
 # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
-def get_contraction(old_shape:
+def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
   acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
   try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
   except ValueError: return None
@@ -19,7 +16,7 @@ def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> O
 
 # ***** indexing *****
 
-def _limit_dims(dims:
+def _limit_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
   # TODO: symbolic shape
   if not all_int(dims): return dims
   while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
@@ -30,25 +27,24 @@ def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
     else: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
   return dims
 
-def get_grouped_dims(prefix, dims:
+def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|None, reverse=False) -> list[UOp]:
   if reverse: dims = dims[::-1]
   limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
   ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
   if limited != dims:
     ret = []
-    for idx,
-        idx //= dims[c]
+    if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}")
+    for idx, contraction_group in zip(raw_idxs, contraction):
+      for c in contraction_group[:-1]:
+        ret.append(idx % dims[c])
+        idx //= dims[c]
+      ret.append(idx)
   return ret[::-1] if reverse else ret
 
 @dataclass
 class IndexContext:
-  idxs:
-  ridxs:
+  idxs: list[UOp]
+  ridxs: list[UOp]
   acc_num: int = 0
 
 def get_index(ast:UOp, opts:Renderer) -> IndexContext:
@@ -56,14 +52,11 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
   # NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
   full_shape = ast.full_shape
   first_upcasted = len(full_shape)-ki.upcasted
-  first_output_st: ShapeTracker = ast.src[0].st_arg
   # if there's no reduce, this is first_upcasted. assumes reduces are at the end
-  first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.
-  local_loads = [x for x in ast.
+  first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.toposort if x.op is Ops.REDUCE_AXIS))
+  local_loads = [x for x in ast.toposort if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
   # NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
-  group_for_reduces = sum([any(
-    [[l.st_arg.shape[i] for l in local_loads] for i in range(first_reduce,first_upcasted)],
-    first_output_st.shape[first_reduce:first_upcasted])]) if local_loads else 0
+  group_for_reduces = sum([any(l.st_arg.shape[i]!=ast.src[0].st_arg.shape[i] for l in local_loads) for i in range(first_reduce,first_upcasted)])
   global_dims = first_reduce-ki.local_dims
 
   if opts.has_local:
@@ -76,22 +69,21 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
       get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
   else:
     # all loops are RANGES
-    idxs = [UOp(Ops.RANGE, dtypes.int, (
-      for i,g in enumerate(full_shape[:first_reduce])]
+    idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i) for i,g in enumerate(full_shape[:first_reduce])]
 
   # reduce loops
-  idxs += [UOp(Ops.RANGE, dtypes.int, (
+  idxs += [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i)
     for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
 
   # upcast loops
   for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
     assert isinstance(g, int), "needs to be int to upcast/unroll"
-    idxs.append(UOp(Ops.
+    idxs.append(UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
 
   # late indexes (group for reduce)
   ridxs = idxs[:]
   for a in range(first_reduce, first_reduce+group_for_reduces):
-    ridxs[a] = UOp(Ops.RANGE, dtypes.int, (
+    ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(full_shape[a])), 1000+a)
 
   return IndexContext(idxs, ridxs)
 
@@ -100,7 +92,7 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
 def lower_reduce_axis(ctx: IndexContext, x: UOp):
   # NOTE: always using ridxs is fine here
   reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
-  assert all(x.op is Ops.
+  assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
   alu_op: Ops = x.arg[0]
   ret = x.src[0]
   if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
@@ -114,12 +106,10 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
 
 def lower_load_store(ctx: IndexContext, x: UOp):
   idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
-  # TODO: check has_valid in UPat, not here
-  has_valid = valid.op is not Ops.CONST or valid.arg is not True
   buf = x.src[0]
   if x.op is Ops.LOAD:
     barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
-    return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid
+    return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
   # NOTE: only store the local reduceop in the threads that are actually doing the reduce
   if cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.ASSIGN:
     reduce_input = x.src[2].src[1].src[1] if x.src[2].src[1].src[1] is not x.src[2].src[0] else x.src[2].src[1].src[0]
@@ -130,14 +120,19 @@ def lower_load_store(ctx: IndexContext, x: UOp):
   if (not cast(PtrDType, x.src[0].dtype).local) or store_back:
     for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
       if oidx is not ridx: valid = valid * oidx.eq(0)
+  return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid), x.src[2]))
+
+def lower_const(x:UOp):
+  assert all(v.mask is None for v in unwrap(x.st).views), f"VIEW in CONST/DEFINE_VAR source must be unmasked, got {x.st}"
+  return x.replace(src=())
 
 pm_lowerer = PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
+  (UPat((Ops.CONST, Ops.DEFINE_VAR), src=(UPat(Ops.VIEW),), name="x"), lower_const),
   (UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
   # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
   (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
+  (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
 ])
 
 def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))