tinygrad-0.9.0-py3-none-any.whl → tinygrad-0.9.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +78 -90
- tinygrad/codegen/linearizer.py +237 -169
- tinygrad/codegen/uops.py +278 -242
- tinygrad/device.py +147 -10
- tinygrad/dtype.py +7 -7
- tinygrad/engine/graph.py +16 -16
- tinygrad/engine/jit.py +39 -36
- tinygrad/engine/realize.py +6 -5
- tinygrad/engine/schedule.py +15 -7
- tinygrad/engine/search.py +6 -3
- tinygrad/function.py +17 -23
- tinygrad/helpers.py +77 -8
- tinygrad/lazy.py +26 -26
- tinygrad/multi.py +13 -9
- tinygrad/nn/__init__.py +1 -1
- tinygrad/nn/datasets.py +2 -1
- tinygrad/nn/state.py +3 -4
- tinygrad/ops.py +49 -16
- tinygrad/renderer/__init__.py +8 -4
- tinygrad/renderer/assembly.py +93 -100
- tinygrad/renderer/cstyle.py +47 -42
- tinygrad/renderer/llvmir.py +30 -30
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +11504 -1
- tinygrad/runtime/autogen/comgr.py +36 -10
- tinygrad/runtime/autogen/hsa.py +146 -14
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/nv_gpu.py +269 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +20 -11
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +3 -2
- tinygrad/runtime/graph/cuda.py +2 -2
- tinygrad/runtime/graph/hcq.py +122 -78
- tinygrad/runtime/ops_amd.py +302 -316
- tinygrad/runtime/ops_cuda.py +3 -3
- tinygrad/runtime/ops_disk.py +70 -5
- tinygrad/runtime/ops_gpu.py +2 -2
- tinygrad/runtime/ops_metal.py +5 -6
- tinygrad/runtime/ops_npy.py +1 -1
- tinygrad/runtime/ops_nv.py +161 -166
- tinygrad/runtime/ops_python.py +20 -16
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +5 -2
- tinygrad/shape/symbolic.py +1 -3
- tinygrad/shape/view.py +34 -19
- tinygrad/tensor.py +219 -135
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/runtime/driver/hsa.py +0 -143
- tinygrad/runtime/graph/hsa.py +0 -171
- tinygrad/runtime/ops_hsa.py +0 -278
- tinygrad-0.9.0.dist-info/RECORD +0 -60
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/codegen/linearizer.py
CHANGED
@@ -1,35 +1,83 @@
 from __future__ import annotations
-from typing import List, Tuple,
+from typing import List, Tuple, Optional, Type, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence, Callable
 import itertools, math, functools
 from collections import defaultdict

-from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
-from tinygrad.helpers import colored, DEBUG, prod, getenv, to_function_name
+from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
+from tinygrad.helpers import colored, DEBUG, dedup, diskcache_put, prod, getenv, to_function_name, flatten
 from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node
+from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node, sint
 from tinygrad.codegen.kernel import LocalBuffer, Kernel
 from tinygrad.renderer import Program

 from tinygrad.codegen.uops import UOps, UOp, UOpGraph

-def get_grouped_dims(prefix,
-
-
-
-
-
-
-
-
-
+def get_grouped_dims(prefix:str, off:int, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse_dims:bool=False):
+  """ Maps all global/local dims onto global/local sizes and returns the idxs, loop_idxs and sizes.
+
+  * If there are fewer dims than size, size will be padded with 1s to the length of max_sizes.
+  * If there are more dims than size, dims will be collapsed onto size starting from left-most (i.e. onto x, then y, then z).
+  * If the dim is too large for the size, the dim will be split between adjacent size axes space permitting, otherwise assert
+
+  Keyword arguments:
+  prefix -- the prefix to use for the size Variable names.
+  off -- the starting index for the size Variable names.
+  dims -- the global or local dims of the full shape.
+  max_sizes -- the maximum values for each size in (x, y, z) order.
+  reverse_dims -- reverse the order of the dims as they are mapped into size, i.e. if True, the right dim will go to the left size (.x).
+  """
+
+  # check the edge cases on max_sizes
+  if max_sizes is None: max_sizes = tuple([0xFFFFFFFFFFFFFFFF] * len(dims))
+  assert len(max_sizes) > 0 or len(dims) == 0, f"{prefix} dims should be empty because no size axes available"
+  if len(max_sizes) == 0: return [], [], None
+
+  # initialize the map of dims to size with a single dim in each size axis
+  # TODO: support sint properly
+  size_dims:List[List[Tuple[int, sint, sint]]] = [[(dim_idx, dim, dim if isinstance(dim, int) else dim.max+1)] for dim_idx, dim in enumerate(dims)]
+
+  # reverse the order of the dims to size map, if desired (currently for globals where smallest stride is on the right)
+  # TODO: remove reverse_dims, the mapping of dims to size for globals should be cosearched with memory layouts for optimal peformance
+  if reverse_dims: size_dims = size_dims[::-1]
+
+  # ensure that the initial dims initially fit the valid size axes
+  for size_idx in range(min(len(max_sizes), len(size_dims))):
+    # if the initial dim is too large, split the dim to separate size axes, if possible
+    dim_idx, dim, dim_max = size_dims[size_idx][0]
+    if dim_max <= (max_sz:=max_sizes[size_idx]): continue
+    assert isinstance(dim, int), "variable shape too large for size"
+    for factor in range(2, int(dim**0.5)+1):
+      if dim % factor == 0 and dim // factor <= max_sz:
+        size_dims = size_dims[:size_idx] + [[(dim_idx, dim//factor, dim//factor)], [(dim_idx, factor, factor)]] + size_dims[size_idx+1:]
+        break
+    assert size_dims[size_idx][0][2] <= max_sz, f"dim at {size_idx} too large and non-factorable: {dim} > {max_sz}"
+
+  # compress the extra dims, collapsing them onto the left-most valid size axis
+  cur_size_idx = 0
+  while len(size_dims) > len(max_sizes):
+    if prod([dim_max for (_, _, dim_max) in size_dims[cur_size_idx]])*size_dims[cur_size_idx+1][0][2] <= max_sizes[cur_size_idx]:
+      size_dims = size_dims[:cur_size_idx] + [size_dims[cur_size_idx] + size_dims[cur_size_idx+1]] + size_dims[cur_size_idx+2:]
+    elif cur_size_idx < len(max_sizes)-1: cur_size_idx += 1
+    else: raise AssertionError(f"cannot fit dims in size: {dims=} {max_sizes=}")
+
+  # construct the final dim idx variables from the the portions of the size variables
+  sizes, idxs = [prod([dim for (_, dim, _) in size_dim]) for size_dim in size_dims], [NumNode(0)] * len(dims)
+  size_vars = loop_idxs = [Variable(f"{prefix}{len(sizes)-1-(i+off) if reverse_dims else i+off}", 0, s-1) for i,s in enumerate(sizes)]
+  for size_idx, size_var in enumerate(size_vars):
+    for dim_idx, dim, _ in size_dims[size_idx]:
+      idxs[dim_idx] += (size_var % dim) * (idxs[dim_idx].max+1)
+      size_var //= dim
+
+  # pad the final sizes array to the proper length if necessary
+  return idxs, [x for x in loop_idxs if not isinstance(x, NumNode)], sizes + [1]*(len(max_sizes)-len(sizes))

 def expand_idx(node:Node) -> Union[Variable, NumNode]: return next((v for v in node.vars() if v.expr.startswith("_uidx")), NumNode(0))
 def expand_idxs(nodes:Sequence[Node]) -> Tuple[Union[Variable, NumNode], ...]:
   eidxs = [expand_idx(node) for node in nodes]
   return tuple([v if v not in eidxs[:j] else NumNode(0) for j, v in enumerate(eidxs)]) # take only first occurrence of expand variable
 def iter_idxs(idxs:Tuple[Union[Variable, NumNode], ...]) -> Iterator[Tuple[int,...]]:
-  yield from (x[::-1] for x in itertools.product(*[
+  yield from (x[::-1] for x in itertools.product(*[list(range(v.min, v.max + 1)) for v in idxs[::-1]]))

 def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
   idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
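The docstring above spells out how global/local dims are fitted to the (x, y, z) size axes: pad with 1s when there are fewer dims than axes, collapse extra dims onto the left-most axis that still fits, and split an oversized dim across adjacent axes where possible. A minimal standalone sketch of just the pad/collapse rules, using a hypothetical `group_sizes` helper that works on plain ints and skips the factor-splitting step and the symbolic Variable bookkeeping:

```python
# Simplified model of the pad/collapse behaviour described in the docstring above.
# NOT tinygrad's implementation: ints only, no factor splitting, no Variables.
def group_sizes(dims, max_sizes):
  sizes = list(dims)
  # collapse extra dims onto the left-most size axis that still has room
  i = 0
  while len(sizes) > len(max_sizes):
    if sizes[i] * sizes[i+1] <= max_sizes[i]:
      sizes[i:i+2] = [sizes[i] * sizes[i+1]]
    elif i < len(max_sizes) - 1:
      i += 1
    else:
      raise AssertionError(f"cannot fit {dims} in {max_sizes}")
  assert all(s <= m for s, m in zip(sizes, max_sizes)), "a single dim is too large for its axis"
  # pad with 1s up to the number of size axes
  return sizes + [1] * (len(max_sizes) - len(sizes))

print(group_sizes((3, 5, 2, 4), (64, 64, 64)))  # [15, 2, 4]
print(group_sizes((7,), (64, 64, 64)))          # [7, 1, 1]
```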
@@ -44,13 +92,21 @@ def expand_node(node:Node, idxs:Optional[Tuple[Union[Variable, NumNode], ...]]=N
   if idxs is None: idxs = (expand_idx(node),)
   return [node.substitute({k:v for k,v in zip(idxs, (NumNode(x) for x in rep)) if isinstance(k, Variable)}) for rep in iter_idxs(idxs)]

-
-
+def variable_to_uop(x, ctx=None) -> UOp:
+  if isinstance(x, int): return UOp.const(dtypes.int, x)
+  return x.render(render_ops, ctx)

-
-
-
+render_ops: Dict[Type, Callable[..., UOp]] = {
+  NumNode: lambda self, ops, ctx: UOp.const(dtypes.int, self.b),
+  Variable: lambda self, ops, ctx: ctx[self.expr] if self.expr in ctx else UOp(UOps.DEFINE_VAR, dtypes.int, (), self),
+  MulNode: lambda self, ops, ctx: self.a.render(ops, ctx)*variable_to_uop(self.b, ctx),
+  DivNode: lambda self, ops, ctx: self.a.render(ops, ctx)//variable_to_uop(self.b, ctx),
+  ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx),
+  LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
+  SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+variable_to_uop(b, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
+  AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*variable_to_uop(b, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }

+class Linearizer(Kernel):
   def get_reduce_acc(self, reduceop:LazyOp):
     if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0
     if reduceop.op is ReduceOps.MAX:
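The `render_ops` table added above moves symbolic-node rendering out of the Linearizer class: it is a plain type-to-callable dispatch map, and `Node.render(render_ops, ctx)` looks up the lambda for each node class. A small self-contained sketch of that dispatch pattern, with made-up stand-in node classes that render to strings instead of UOps:

```python
from dataclasses import dataclass
from typing import Callable, Dict, Type

# Stand-in symbolic nodes; tinygrad's real Node classes carry min/max bounds
# and build UOps rather than strings.
@dataclass
class Num:
  b: int
@dataclass
class Var:
  expr: str
@dataclass
class Mul:
  a: object
  b: int
@dataclass
class Sum:
  nodes: tuple

def render(node, ops: Dict[Type, Callable], ctx: Dict[str, str]) -> str:
  # dispatch on the node's concrete type, like Node.render(render_ops, self.loop_uops)
  return ops[type(node)](node, ops, ctx)

demo_ops: Dict[Type, Callable] = {
  Num: lambda s, ops, ctx: str(s.b),
  Var: lambda s, ops, ctx: ctx.get(s.expr, s.expr),
  Mul: lambda s, ops, ctx: f"({render(s.a, ops, ctx)}*{s.b})",
  Sum: lambda s, ops, ctx: "(" + "+".join(render(n, ops, ctx) for n in s.nodes) + ")",
}

idx = Sum((Mul(Var("gidx0"), 4), Num(3)))
print(render(idx, demo_ops, {"gidx0": "gid.x"}))  # ((gid.x*4)+3)
```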
@@ -60,16 +116,6 @@ class Linearizer(Kernel):
   # NOTE: once images are loaded, we uop them as their base float
   def get_base_dtype(self, dt:DType) -> DType: return dt.base if isinstance(dt, ImageDType) else dt

-  render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
-    MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
-    DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV),
-    ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
-    LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.CMPLT),
-    SumNode: lambda self,ops,ctx:
-      functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)),
-    AndNode: lambda self,ops,ctx:
-      functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
-
   def global_load(self, i:int, idxs:List[Node], acc:Optional[LazyOp]=None, barrier:Optional[UOp]=None, loop_ctx:Tuple[UOp, ...]=()) -> List[UOp]:
     buf = self.bufs[i]
     localtype = self.get_base_dtype(buf.dtype if acc is None else acc.dtype)
@@ -89,48 +135,47 @@ class Linearizer(Kernel):
     # todo: multioutput test with different output valids to add if acc is None: g_valid = NumNode(1)

     if amt > 1: localtype = localtype.vec(amt)
-    e_idxs, e_valids = expand_node(g_idx, expand_vars), expand_node(g_valid, expand_vars)
+    e_idxs, e_valids = expand_node(g_idx, expand_vars), expand_node(g_valid, expand_vars) # pylint: disable=possibly-used-before-assignment

     ret = []
     invalid_value = 0
     acc_count = 0
     for idx, valid, rep_idx in zip(e_idxs, e_valids, iter_idxs(expand_vars)):
       this_const, idx, valid = (invalid_value, NumNode(0), NumNode(1)) if valid.max == 0 else (const, idx, valid)
-
-      key = f"{acc is not None}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
+      key = f"{'' if acc is None else self.reduceops.index(acc)}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
       if key not in self.load_cache:
         if acc is not None:
-          self.load_cache[key] =
+          self.load_cache[key] = UOp(UOps.DEFINE_ACC, localtype, (UOp.const(localtype.scalar(), self.get_reduce_acc(acc)), *loop_ctx), (i, acc_count))
           acc_count += 1
         elif this_const is not None:
-          self.load_cache[key] =
+          self.load_cache[key] = UOp.const(localtype, this_const)
           if valid.min == 0 and valid.max == 1:
-            valid_rendered = valid.render(
-            self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_rendered, self.load_cache[key],
+            valid_rendered = valid.render(render_ops, self.loop_uops)
+            self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_rendered, self.load_cache[key], UOp.const(localtype, invalid_value))
         elif isinstance(buf.dtype, ImageDType):
           buf_uop = self.buf_uops[i]
           assert buf_uop is not None, f"buffer {i} wasn't UOped"
           image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
-          rendered_idx =
-          valid_tuple = (valid.render(
-          self.load_cache[key] =
+          rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
+          valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
+          self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4),
                                      (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
           if localtype == localtype.scalar():
             idx_small = idx%4
-            res = idx_small.render(
-            out =
+            res = idx_small.render(render_ops, self.loop_uops)
+            out = UOp(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
             for ix in range(idx_small.max, idx_small.min, -1):
-              rvv =
-              sel = UOp.alu(BinaryOps.CMPLT, res,
+              rvv = UOp(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
+              sel = UOp.alu(BinaryOps.CMPLT, res, UOp.const(dtypes.int, ix))
               out = UOp.alu(TernaryOps.WHERE, sel, rvv, out)
             self.load_cache[key] = out
         else:
           buf_uop = self.buf_uops[i]
           assert buf_uop is not None, f"buffer {i} wasn't UOped"
-          rendered_idx = idx.render(
-          valid_tuple = (valid.render(
-          self.load_cache[key] =
-      ret.append(
+          rendered_idx = idx.render(render_ops, self.loop_uops)
+          valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple()
+          self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
+      ret.append(UOp(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
     return ret

   def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]:
@@ -153,7 +198,7 @@ class Linearizer(Kernel):
         amt = len(grouped)
         idx, valid = self.sts[i].expr_idxs(k)
         assert idx == ((idx//amt)*amt), "float4 stores are always aligned"
-        store_offset_new[k] =
+        store_offset_new[k] = UOp(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
       store_offset = store_offset_new

     stores = []
@@ -161,29 +206,24 @@ class Linearizer(Kernel):
       idx, valid = self.sts[i].expr_idxs(_idx)
       if isinstance(buf.dtype, ImageDType):
         image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
-        rendered_idx =
-          tuple(x.render(
+        rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), \
+          tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
       else:
-        rendered_idx = idx.render(
-
-
+        rendered_idx = idx.render(render_ops, self.loop_uops)
+      # TODO: let UPat check this once it's fast
+      if valid.min == 1: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var)))
+      else: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(render_ops, self.loop_uops))))
     return stores

   # render loop
-  def render_loop(self, xx:List[Variable], depth:int) -> Tuple[UOp, ...]:
-    new_loops = {x.expr:
-
-
+  def render_loop(self, xx:List[Variable], depth:int, reduce:bool) -> Tuple[UOp, ...]:
+    new_loops = {x.expr:UOp(UOps.RANGE, dtypes.int32, (
+      UOp.const(dtypes.int, x.min) if isinstance(x.min, int) else cast(Node, x.min).render(render_ops, self.loop_uops),
+      UOp.const(dtypes.int, x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(render_ops, self.loop_uops)), arg=(depth,i,reduce)) for i,x in enumerate(xx) if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
     self.loop_uops.update(new_loops)
     return tuple(new_loops.values())

-  def
-    global_idxs, local_idxs, upcast_idxs):
-    # define indicies
-    full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
-    reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
-    fake_reduce_idxs = [x*0 for x in reduce_idxs]
-
+  def index_local_aliases(self, global_idxs, local_idxs, reduce_idxs, upcast_idxs, full_upcast_idxs):
     def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
       replace_idxs, thread_idxs, thread_idx = [], [], Variable("_uidx_tc", 0, prod(local_sizes)-1)
       for s in local_sizes:
@@ -199,33 +239,39 @@ class Linearizer(Kernel):
         replace_idxs.append(full_var)
       return replace_idxs

-    # compute local aliases
-    alias_buf_idxs
-    for
-
-
-
-
-
-
-
-
-
-
-
-
-    #
-    loop_ctx = self.render_loop(reduce_idxs, 2)
-
-    # define accumulator - modify idxs if necessary for TC
-    out_buf = -1 if self.group_for_reduces else 0
+    # compute local aliases
+    alias_buf_idxs: DefaultDict[LazyOp, List[Tuple[int, int, List]]] = defaultdict(list)
+    for op, local_alias in self.local_alias.items():
+      for i in local_alias:
+        localbuf_idx = self.bufs.index(local_alias[i])
+        buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
+        if (tc:=self.tensor_core):
+          min_alias_idx = min(local_alias.keys())
+          replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
+          for n in range(len(tc.threads)):
+            buf_idxs[self.global_dims+n] = replace_input_idxs[n] # replace locals
+          for n in range(tc.num_upcasts()):
+            buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
+        if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
+        alias_buf_idxs[op].append((i, localbuf_idx, buf_idxs))
+    # modify idxs if necessary for TC
     if (tc:=self.tensor_core):
       replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
       for n in range(len(tc.threads)):
         local_idxs[n] = replace_acc_idxs[n] # replace locals
       for n in range(len(replace_acc_idxs)-len(tc.threads)):
         upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
-      if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+
+      if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+upcast_idxs}")
+    return alias_buf_idxs
+
+  def render_reduceop(self, reduceop:LazyOp, accs:Dict[LazyOp, List[UOp]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]],
+                      global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, reduce_idxs, fake_reduce_idxs,
+                      alias_buf_idxs:List[Tuple[int, int, List]]) -> Tuple[List[NumNode|Variable], List[NumNode|Variable]]:
+    # reduce loop
+    loop_ctx = self.render_loop(reduce_idxs, (i:=self.reduceops.index(reduceop))*2+2, True)
+
+    # define accumulator - modify idxs if necessary for TC
+    out_buf = -len(self.reduceops)+i if self.group_for_reduces else 0
     accs[reduceop] = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)

     # store local aliases
@@ -235,34 +281,39 @@ class Linearizer(Kernel):
       # run tensor cores AST
       wmma_sz = [prod(l) for l in tc.thread_local_sizes]
       def upcast_strides(buf:int):
-        strides,
-        for (sz, stride,
-          strides.append((0 if stride == 0 else
-
+        strides, next_ = [], 1
+        for (sz, stride, _) in self.upcasted_axis(buf)[tc.num_upcasts():]:
+          strides.append((0 if stride == 0 else next_, sz))
+          next_ *= 1 if stride == 0 else sz
         return strides
       upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
       # cast initial accs
-      wmmas = [
+      wmmas = [UOp(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]]))
               for x in range(0, len(accs[reduceop]), wmma_sz[2])]
-      for
-        offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(
-        ops = (
-
+      for it in [x[::-1] for x in itertools.product(*list([range(sz) for _,sz in upcasts[0]][::-1]))]:
+        offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(it, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
+        ops = (UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
+               UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
               wmmas[(wmma_idx:=offs[2]//wmma_sz[2])])
        # TODO: don't need to DEFINE_ACC, pass to WMMA in op3, or PHI accs that are not valid
-        wmmas[wmma_idx] =
+        wmmas[wmma_idx] = UOp(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), dev))
       # phi the last wmmas back to accs
-      accs[reduceop] = [
+      accs[reduceop] = [UOp(UOps.PHI, tc.dtype_out, (acc, UOp(UOps.GEP, tc.dtype_out, (wmmas[z//wmma_sz[2]],), z%wmma_sz[2])))
                        for z, acc in enumerate(accs[reduceop])]
     else:
       assert not locals_to_store, "storing locals isn't supported here"

     # load earlybufs
-    loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i,
+    loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[reduceop][i]) if i in self.local_alias else i,
       global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs})

+    def gate_acc(r, idxs): return [
+      UOp.alu(TernaryOps.WHERE, valid.render(render_ops, self.loop_uops), acc, UOp.const(r.dtype, 0)) if valid.min == 0 and valid.max == 1 else acc
+      for valid, acc in zip(expand_node(self.sts[self.full_buf_index].expr_idxs(idxs)[1], expand_idxs(idxs)), accs[r])]
+    local_accs = {r: gate_acc(r,global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for r in accs}
+
     # run early AST (with reduce)
-    self.ast_parse(reduceop,
+    self.ast_parse(reduceop, local_accs, self.acc_offsets(self.full_buf_index), loaded_buffers, reduce_acc=accs[reduceop])

     # end the reduce loop
     self.load_cache.clear()
@@ -270,13 +321,13 @@ class Linearizer(Kernel):
     # end the local loop, do the local reduce
     if self.group_for_reduces:
       fake_global_idxs = [x*0 for x in global_idxs]
-      stores = self.global_store(
-      barrier =
+      stores = self.global_store(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop]) # store accumulators
+      barrier = UOp(UOps.BARRIER, None, tuple(stores))
       if self.opts.has_local:
         fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
         fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
-        if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(
-        barrier =
+        if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(render_ops, self.loop_uops)
+        barrier = UOp(UOps.IF, None, (if_cond, barrier))

       # create new late reduce local loops and replace local_idxs that have been used
       end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)] # noqa: E501
@@ -295,71 +346,68 @@ class Linearizer(Kernel):
       # NOTE: this structure is the same as the reduce op above

       # late reduce loop
-      loop_ctx = self.render_loop(end_local_idxs, 3)
+      loop_ctx = self.render_loop(end_local_idxs, i*2+3, True)

       # define late accumulator
       accs[reduceop] = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)

       # load localbufs
-      loaded_buffers[self.bufs[
+      loaded_buffers[self.bufs[out_buf]] = self.global_load(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)

       # there's no AST here (and there's no shape for the reduce LazyOp)
-      self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[
+      self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[out_buf]),)),\
        accs, self.acc_offsets(-1), loaded_buffers, reduce_acc=accs[reduceop])

       # end the late reduce loop
       self.load_cache.clear()

-
-
-
+    if reduceop is not self.reduceops[-1]:
+      for j in self.upcast_in_mid_reduce_axes:
+        self.upcasted -= 1
+        self.group_for_reduces += 1
+      assert self.buf_uops[out_buf] is not None, "Local reduce buf must have been uoped at this point"
+      fake_local_idxs = local_idxs[:self.local_dims] + [x*0 for x in local_idxs[self.local_dims:]]
+      stores = self.global_store(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop])
+      barrier = UOp(UOps.BARRIER, None, tuple(stores))
+      accs[reduceop] = self.global_load(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
+    return local_idxs[:self.local_dims] + [NumNode(0) for _ in range(self.group_for_reduces)], upcast_idxs

   kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
-  def linearize(self):
+  def linearize(self) -> Linearizer:
     # no new opts and we already ran? skip relinearizing
     if self.applied_opts == self.applied_opts_cache: return self

     # late alias the tensor core buffers
-    if (tc:=self.tensor_core) and
+    if (tc:=self.tensor_core) and self.tensor_core_opts is not None:
       alias_pattern = [0]*(self.global_dims) + [2]*(len(tc.threads)) + [0]*(self.local_dims-len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2) # noqa: E501
-      for
-        self.alias_buffer(tc_buf, alias_pattern)
+      for op, tc_bufs in self.bufs_for_tensor_core.items():
+        for tc_buf in tc_bufs: self.alias_buffer(op, tc_buf, alias_pattern)

     # save backups
     sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduces, self.upcasted

-    # global uop cache
-    self.saved_exprs: Dict[Tuple, UOp] = dict()
-
-    # limit dims if we need to
-    if self.opts.global_max and self.opts.local_max: self.limit_dims_to_max(self.opts.global_max, self.opts.local_max)
-
     # uops
-    self.uops:UOpGraph = UOpGraph()
     self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs)
     self.loop_uops: Dict[str, UOp] = {}

     # add global buffers
     for i,buf in enumerate(self.bufs):
       if isinstance(buf, MemBuffer):
-        self.buf_uops[i] =
+        self.buf_uops[i] = UOp(UOps.DEFINE_GLOBAL,
                                buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
                                (buf.idx, any(buf.idx == x.idx for x in self.outbufs)))
-    # add var vals
-    for i,var in enumerate(self.vars):
-      assert var.expr is not None
-      self.loop_uops[var.expr] = self.uops.add(UOps.DEFINE_VAR, dtypes.int32, (), var)
     # define local buffers
-    for
-      self.buf_uops[self.bufs.index(lb)] =
-
+    for aliases in self.local_alias.values():
+      for lb in aliases.values(): self.buf_uops[self.bufs.index(lb)] = UOp(UOps.DEFINE_LOCAL, PtrDType(lb.dtype),
+                                                                           (), (lb.name, self.sts[self.bufs.index(lb)].size))
     # add a local buffer for multistage reduce. # TODO: use local alias
     if self.group_for_reduces:
-
-
-
-
-
+      for i in range(len(self.reduceops)):
+        # TODO: the strides of this can be controlled
+        self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501
+        temp_dtype = self.get_base_dtype(cast(LazyOp, self.reduceop).dtype)
+        self.bufs.append(LocalBuffer(name:=f"temp{i if len(self.reduceops) > 1 else ''}", buf_size:=self.sts[-1].size, temp_dtype))
+        self.buf_uops.append(UOp(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), (name, buf_size)))

     # kernel name (before late upcast)
     self.name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \
@@ -372,44 +420,39 @@ class Linearizer(Kernel):
       self.name = self.name+colored(suffix, 'BLACK')

     # define indexes
-
-
+    gl_dims = self.full_shape[:self.first_reduce+self.group_for_reduces]
+    global_idxs, loop_global_idxs, self.global_size = get_grouped_dims("idx" if self.dont_use_locals else "gidx", 0, gl_dims[:self.global_dims],
+                                                                       self.opts.global_max, self.opts.has_local)
+    local_idxs, loop_local_idxs, self.local_size = get_grouped_dims("lidx", self.global_dims, gl_dims[self.global_dims:],
+                                                                    self.opts.local_max if self.opts.has_local else (), False)
     upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
+    full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]

-    #
-    self.
-
-
-
-      self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
-    elif self.opts.has_local:
-      self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs]
-      self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
-      self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
+    # render global and local as specials or a loop
+    if self.opts.has_local:
+      self.loop_uops.update({x.expr:UOp(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
+      if not self.dont_use_locals:
+        self.loop_uops.update({x.expr:UOp(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
     else:
-      self.
-
-
+      self.global_size, self.local_size = None, None
+      self.render_loop(loop_global_idxs+loop_local_idxs, 1, False)
+
+    # define idxs for aliased buffers TODO: this doesn't belong in Kernel, but it can't exist in Block either (because of multireduce tensor cores)
+    reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
+    alias_buf_idxs = self.index_local_aliases(global_idxs,local_idxs,reduce_idxs,upcast_idxs,full_upcast_idxs)

     # parse AST
+    self.load_cache: Dict[str, UOp] = {}
     loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]] = {}
     accs: Dict[LazyOp, List[UOp]] = {}
-    self.load_cache: Dict[str, UOp] = {}

-    #
-
-
-
-      self.render_reduceop(reduceop,accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs)
+    # render reduceops by depth
+    for reduceop in self.reduceops:
+      self.render_block((reduceop, ), global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)
+    stores = self.render_block(self.ast, global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)

-    #
-
-      for i,b in enumerate(self.bufs) if b not in self.earlybufs and b.__class__ is not LocalBuffer})
-
-    # run late AST (without the store)
-    for op in self.ast:
-      val = self.ast_parse(op.src[0], accs, None, loaded_buffers)
-      self.global_store(op.arg.idx, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
+    # only the final stores are needed to define the full UOps graph
+    self.uops:UOpGraph = UOpGraph(flatten(stores))

     # maybe graph the uops
     if DEBUG >= 5: self.uops.print()
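The linearize change above builds the UOp graph from `UOpGraph(flatten(stores))` alone; per the added comment, only the final stores are needed because every other uop is reachable through their source edges. A toy sketch of that idea, with plain tuples standing in for UOps and a hypothetical `toposort_from_sinks` walk:

```python
# Walk parents from the sink nodes (the final STOREs) and return every reachable
# node in dependency order; nodes here are (name, parents) tuples, not real UOps.
def toposort_from_sinks(sinks):
  order, seen = [], set()
  def visit(node):
    if id(node) in seen: return
    seen.add(id(node))
    for parent in node[1]: visit(parent)
    order.append(node)
  for s in sinks: visit(s)
  return order

buf   = ("DEFINE_GLOBAL", ())
gidx  = ("SPECIAL", ())
val   = ("ALU", (gidx,))
store = ("STORE", (buf, gidx, val))
print([name for name, _ in toposort_from_sinks([store])])
# ['DEFINE_GLOBAL', 'SPECIAL', 'ALU', 'STORE']
```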
@@ -422,16 +465,40 @@ class Linearizer(Kernel):
     self.applied_opts_cache = self.applied_opts[:]
     return self

+  def render_block(self, outputs:Tuple[LazyOp, ...], global_idxs, local_idxs, upcast_idxs, full_upcast_idxs,
+                   alias_buf_idxs:DefaultDict[LazyOp,List[Tuple[int,int,List[NumNode|Variable]]]],
+                   loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], accs:Dict[LazyOp,List[UOp]]) -> List[List[UOp]]:
+    reduceops = dedup(x for x in outputs if x.op in ReduceOps)
+    assert len(reduceops) <= 1, "max one reduceop per block"
+    reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
+    fake_reduce_idxs = [x*0 for x in reduce_idxs]
+
+    if len(reduceops) != 0:
+      # TODO: delete render_reduceop and move the logic for group_for_reduces to Block
+      nlidx, nuidx = self.render_reduceop((r:=reduceops[0]),accs,loaded_buffers,\
+        global_idxs,local_idxs,upcast_idxs,full_upcast_idxs,reduce_idxs,fake_reduce_idxs,alias_buf_idxs[r])
+
+      # all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have
+      # been rewritten with fake end_local_idxs.
+      if r is self.reduceops[-1]: local_idxs[:], upcast_idxs[:] = nlidx, nuidx
+      return [accs[r]]
+
+    # load latebufs
+    loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
+      for i,b in enumerate(self.bufs) if b not in self.earlybufs and b.__class__ is not LocalBuffer})
+    # run late AST (without the store)
+    store_vals = {op.arg.idx:self.ast_parse(op.src[0], accs, None, loaded_buffers) for op in self.ast}
+    return [self.global_store(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val) for i, val in store_vals.items()]
+
   def ast_parse(self, x:LazyOp, accs:Dict[LazyOp, List[UOp]], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], reduce_acc:Optional[List[UOp]]=None, cache=None) -> List[UOp]: # noqa: E501
     if cache is None: cache = {}
     if x in cache: return cache[x]
     if x.op in BufferOps: return loaded_buffers[x.arg]
     if x.op in [UnaryOps.CAST, UnaryOps.BITCAST]:
-      return [
+      return [UOp(UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST,
                   self.get_base_dtype(x.arg), (u,)) for u in self.ast_parse(x.src[0], accs, offs, loaded_buffers)]
     if x.op in ReduceOps and reduce_acc is None:
-
-      return accs[x]
+      return [accs[x][i] for i in offs] if offs else accs[x]

     values = [self.ast_parse(v, accs, offs, loaded_buffers, cache=cache) for v in x.src]
     ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}
@@ -444,17 +511,18 @@ class Linearizer(Kernel):
         ret.append(acc[off])
       for off in range(len(acc)):
         if input_acc[off] != acc[off]:
-          acc[off] =
-    else: ret = [UOp.alu(x.op, *
+          acc[off] = UOp(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]))
+    else: ret = [UOp.alu(x.op, *src) for src in zip(*values)]
     cache[x] = ret
     return ret

   def to_program(self) -> Program:
     self.linearize()
     info = get_lazyop_info(self.ast[0])
-    src = self.opts.render(to_function_name(self.name), self.uops)
+    src = self.opts.render(name:=to_function_name(self.name), self.uops)
+    if getenv("RUN_PROCESS_REPLAY"): diskcache_put("process_replay", id(self), (self.ast, self.opts, self.applied_opts, name, src))
     ops, mem = self.uops.flops_mem()
-    run_count = prod((self.global_size
+    run_count = prod((self.global_size or []) + (self.local_size or []))
     # NOTE: we use min here to ignore the indexing FLOPS
     return Program(self.name, src, self.opts.device, self.global_size, self.local_size,
                    self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
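The reworked `run_count` in `to_program` multiplies all global and local sizes to scale the per-thread uop flops/memory counts into whole-kernel estimates. A quick illustration with made-up launch dimensions:

```python
from math import prod

# illustrative launch dimensions only (not taken from the diff)
global_size, local_size = [256, 1, 1], [32, 1, 1]
run_count = prod((global_size or []) + (local_size or []))
print(run_count)  # 8192: per-uop flops/mem counts are scaled by every thread that runs them
```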