tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +253 -225
- tinygrad/codegen/linearizer.py +398 -436
- tinygrad/codegen/uops.py +451 -0
- tinygrad/device.py +268 -274
- tinygrad/dtype.py +56 -40
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +198 -0
- tinygrad/engine/realize.py +192 -0
- tinygrad/engine/schedule.py +370 -0
- tinygrad/engine/search.py +199 -0
- tinygrad/{mlops.py → function.py} +40 -32
- tinygrad/helpers.py +144 -46
- tinygrad/lazy.py +143 -242
- tinygrad/multi.py +173 -0
- tinygrad/nn/__init__.py +180 -9
- tinygrad/nn/datasets.py +8 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +87 -19
- tinygrad/ops.py +104 -45
- tinygrad/renderer/__init__.py +65 -0
- tinygrad/renderer/assembly.py +269 -0
- tinygrad/renderer/cstyle.py +308 -210
- tinygrad/renderer/llvmir.py +119 -124
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +13403 -0
- tinygrad/runtime/autogen/comgr.py +891 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5893 -0
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33597 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +56 -0
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +39 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +187 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +550 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +129 -37
- tinygrad/runtime/ops_disk.py +111 -43
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +41 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +625 -0
- tinygrad/runtime/ops_python.py +208 -0
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +46 -107
- tinygrad/shape/symbolic.py +99 -98
- tinygrad/shape/view.py +162 -45
- tinygrad/tensor.py +2492 -483
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/codegen/linearizer.py
CHANGED
@@ -1,48 +1,89 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
from typing import List, Tuple,
|
2
|
+
from typing import List, Tuple, Optional, Type, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence, Callable
|
3
3
|
import itertools, math, functools
|
4
4
|
from collections import defaultdict
|
5
|
-
from enum import Enum, auto
|
6
|
-
from dataclasses import dataclass
|
7
5
|
|
8
6
|
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
|
9
|
-
from tinygrad.helpers import colored, DEBUG, prod, getenv,
|
7
|
+
from tinygrad.helpers import colored, DEBUG, dedup, diskcache_put, prod, getenv, to_function_name, flatten
|
10
8
|
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
|
11
9
|
from tinygrad.shape.shapetracker import ShapeTracker
|
12
|
-
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
|
10
|
+
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node, sint
|
13
11
|
from tinygrad.codegen.kernel import LocalBuffer, Kernel
|
14
|
-
from tinygrad.
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
if
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
12
|
+
from tinygrad.renderer import Program
|
13
|
+
|
14
|
+
from tinygrad.codegen.uops import UOps, UOp, UOpGraph
|
15
|
+
|
16
|
+
def get_grouped_dims(prefix:str, off:int, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse_dims:bool=False):
|
17
|
+
""" Maps all global/local dims onto global/local sizes and returns the idxs, loop_idxs and sizes.
|
18
|
+
|
19
|
+
* If there are fewer dims than size, size will be padded with 1s to the length of max_sizes.
|
20
|
+
* If there are more dims than size, dims will be collapsed onto size starting from left-most (i.e. onto x, then y, then z).
|
21
|
+
* If the dim is too large for the size, the dim will be split between adjacent size axes space permitting, otherwise assert
|
22
|
+
|
23
|
+
Keyword arguments:
|
24
|
+
prefix -- the prefix to use for the size Variable names.
|
25
|
+
off -- the starting index for the size Variable names.
|
26
|
+
dims -- the global or local dims of the full shape.
|
27
|
+
max_sizes -- the maximum values for each size in (x, y, z) order.
|
28
|
+
reverse_dims -- reverse the order of the dims as they are mapped into size, i.e. if True, the right dim will go to the left size (.x).
|
29
|
+
"""
|
30
|
+
|
31
|
+
# check the edge cases on max_sizes
|
32
|
+
if max_sizes is None: max_sizes = tuple([0xFFFFFFFFFFFFFFFF] * len(dims))
|
33
|
+
assert len(max_sizes) > 0 or len(dims) == 0, f"{prefix} dims should be empty because no size axes available"
|
34
|
+
if len(max_sizes) == 0: return [], [], None
|
35
|
+
|
36
|
+
# initialize the map of dims to size with a single dim in each size axis
|
37
|
+
# TODO: support sint properly
|
38
|
+
size_dims:List[List[Tuple[int, sint, sint]]] = [[(dim_idx, dim, dim if isinstance(dim, int) else dim.max+1)] for dim_idx, dim in enumerate(dims)]
|
39
|
+
|
40
|
+
# reverse the order of the dims to size map, if desired (currently for globals where smallest stride is on the right)
|
41
|
+
# TODO: remove reverse_dims, the mapping of dims to size for globals should be cosearched with memory layouts for optimal peformance
|
42
|
+
if reverse_dims: size_dims = size_dims[::-1]
|
43
|
+
|
44
|
+
# ensure that the initial dims initially fit the valid size axes
|
45
|
+
for size_idx in range(min(len(max_sizes), len(size_dims))):
|
46
|
+
# if the initial dim is too large, split the dim to separate size axes, if possible
|
47
|
+
dim_idx, dim, dim_max = size_dims[size_idx][0]
|
48
|
+
if dim_max <= (max_sz:=max_sizes[size_idx]): continue
|
49
|
+
assert isinstance(dim, int), "variable shape too large for size"
|
50
|
+
for factor in range(2, int(dim**0.5)+1):
|
51
|
+
if dim % factor == 0 and dim // factor <= max_sz:
|
52
|
+
size_dims = size_dims[:size_idx] + [[(dim_idx, dim//factor, dim//factor)], [(dim_idx, factor, factor)]] + size_dims[size_idx+1:]
|
53
|
+
break
|
54
|
+
assert size_dims[size_idx][0][2] <= max_sz, f"dim at {size_idx} too large and non-factorable: {dim} > {max_sz}"
|
55
|
+
|
56
|
+
# compress the extra dims, collapsing them onto the left-most valid size axis
|
57
|
+
cur_size_idx = 0
|
58
|
+
while len(size_dims) > len(max_sizes):
|
59
|
+
if prod([dim_max for (_, _, dim_max) in size_dims[cur_size_idx]])*size_dims[cur_size_idx+1][0][2] <= max_sizes[cur_size_idx]:
|
60
|
+
size_dims = size_dims[:cur_size_idx] + [size_dims[cur_size_idx] + size_dims[cur_size_idx+1]] + size_dims[cur_size_idx+2:]
|
61
|
+
elif cur_size_idx < len(max_sizes)-1: cur_size_idx += 1
|
62
|
+
else: raise AssertionError(f"cannot fit dims in size: {dims=} {max_sizes=}")
|
63
|
+
|
64
|
+
# construct the final dim idx variables from the the portions of the size variables
|
65
|
+
sizes, idxs = [prod([dim for (_, dim, _) in size_dim]) for size_dim in size_dims], [NumNode(0)] * len(dims)
|
66
|
+
size_vars = loop_idxs = [Variable(f"{prefix}{len(sizes)-1-(i+off) if reverse_dims else i+off}", 0, s-1) for i,s in enumerate(sizes)]
|
67
|
+
for size_idx, size_var in enumerate(size_vars):
|
68
|
+
for dim_idx, dim, _ in size_dims[size_idx]:
|
69
|
+
idxs[dim_idx] += (size_var % dim) * (idxs[dim_idx].max+1)
|
70
|
+
size_var //= dim
|
71
|
+
|
72
|
+
# pad the final sizes array to the proper length if necessary
|
73
|
+
return idxs, [x for x in loop_idxs if not isinstance(x, NumNode)], sizes + [1]*(len(max_sizes)-len(sizes))
|
74
|
+
|
75
|
+
def expand_idx(node:Node) -> Union[Variable, NumNode]: return next((v for v in node.vars() if v.expr.startswith("_uidx")), NumNode(0))
|
76
|
+
def expand_idxs(nodes:Sequence[Node]) -> Tuple[Union[Variable, NumNode], ...]:
|
77
|
+
eidxs = [expand_idx(node) for node in nodes]
|
78
|
+
return tuple([v if v not in eidxs[:j] else NumNode(0) for j, v in enumerate(eidxs)]) # take only first occurrence of expand variable
|
44
79
|
def iter_idxs(idxs:Tuple[Union[Variable, NumNode], ...]) -> Iterator[Tuple[int,...]]:
|
45
|
-
yield from (x[::-1] for x in itertools.product(*[
|
80
|
+
yield from (x[::-1] for x in itertools.product(*[list(range(v.min, v.max + 1)) for v in idxs[::-1]]))
|
81
|
+
|
82
|
+
def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
|
83
|
+
idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
|
84
|
+
# TODO: bring back the valid removal logic (correct!)
|
85
|
+
if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid)
|
86
|
+
return (idx, idy), valid
|
46
87
|
|
47
88
|
# expand a Node into List[Node] that enumerates the underlying Variables from min to max
|
48
89
|
# expand increments earlier variables faster than later variables (as specified in the argument)
|
@@ -51,95 +92,90 @@ def expand_node(node:Node, idxs:Optional[Tuple[Union[Variable, NumNode], ...]]=N
|
|
51
92
|
if idxs is None: idxs = (expand_idx(node),)
|
52
93
|
return [node.substitute({k:v for k,v in zip(idxs, (NumNode(x) for x in rep)) if isinstance(k, Variable)}) for rep in iter_idxs(idxs)]
|
53
94
|
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
return self.uop(UOps.ALU, dtype, (a, render_b), op)
|
58
|
-
|
59
|
-
# NOTE: the consts have to be cached for deduping of downstream uops to work
|
60
|
-
def const(self, b:Union[int,float], dtype=dtypes.int32, insert_before=None) -> UOp:
|
61
|
-
return self.uop(UOps.CONST, dtype, tuple(), b, insert_before=insert_before)
|
95
|
+
def variable_to_uop(x, ctx=None) -> UOp:
|
96
|
+
if isinstance(x, int): return UOp.const(dtypes.int, x)
|
97
|
+
return x.render(render_ops, ctx)
|
62
98
|
|
63
|
-
|
99
|
+
render_ops: Dict[Type, Callable[..., UOp]] = {
|
100
|
+
NumNode: lambda self, ops, ctx: UOp.const(dtypes.int, self.b),
|
101
|
+
Variable: lambda self, ops, ctx: ctx[self.expr] if self.expr in ctx else UOp(UOps.DEFINE_VAR, dtypes.int, (), self),
|
102
|
+
MulNode: lambda self, ops, ctx: self.a.render(ops, ctx)*variable_to_uop(self.b, ctx),
|
103
|
+
DivNode: lambda self, ops, ctx: self.a.render(ops, ctx)//variable_to_uop(self.b, ctx),
|
104
|
+
ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx),
|
105
|
+
LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
|
106
|
+
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+variable_to_uop(b, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
|
107
|
+
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*variable_to_uop(b, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
|
64
108
|
|
109
|
+
class Linearizer(Kernel):
|
65
110
|
def get_reduce_acc(self, reduceop:LazyOp):
|
66
|
-
|
67
|
-
if reduceop.op
|
68
|
-
|
69
|
-
|
70
|
-
return -math.inf if dtypes.is_float(dtype) else False
|
111
|
+
if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0
|
112
|
+
if reduceop.op is ReduceOps.MAX:
|
113
|
+
if dtypes.is_int(reduceop.dtype): return 0 if dtypes.is_unsigned(reduceop.dtype) else -2**(reduceop.dtype.itemsize*8-1)
|
114
|
+
return -math.inf if dtypes.is_float(reduceop.dtype) else False
|
71
115
|
|
72
116
|
# NOTE: once images are loaded, we uop them as their base float
|
73
|
-
def get_base_dtype(self, dt:DType): return dt.base if isinstance(dt, ImageDType) else dt
|
74
|
-
|
75
|
-
|
76
|
-
MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
|
77
|
-
DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV),
|
78
|
-
ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
|
79
|
-
LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.CMPLT, dtype=dtypes.bool),
|
80
|
-
SumNode: lambda self,ops,ctx:
|
81
|
-
functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)),
|
82
|
-
AndNode: lambda self,ops,ctx:
|
83
|
-
functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
|
84
|
-
|
85
|
-
def global_load(self, i:int, idxs:Sequence[Node], acc=None, barrier:Optional[UOp]=None) -> List[UOp]:
|
117
|
+
def get_base_dtype(self, dt:DType) -> DType: return dt.base if isinstance(dt, ImageDType) else dt
|
118
|
+
|
119
|
+
def global_load(self, i:int, idxs:List[Node], acc:Optional[LazyOp]=None, barrier:Optional[UOp]=None, loop_ctx:Tuple[UOp, ...]=()) -> List[UOp]:
|
86
120
|
buf = self.bufs[i]
|
87
|
-
localtype = self.get_base_dtype(buf.dtype if acc is None else
|
88
|
-
const = buf.val if isinstance(buf, ConstBuffer) else
|
121
|
+
localtype = self.get_base_dtype(buf.dtype if acc is None else acc.dtype)
|
122
|
+
const = buf.val if isinstance(buf, ConstBuffer) else None
|
89
123
|
|
90
|
-
|
91
|
-
expand_vars = tuple([rename_var(expand_idx(idx), f"_uidx{j}") for j, idx in enumerate(idxs)])
|
92
|
-
fake_idxs = [idx.substitute({eidx: ev}) if isinstance(eidx:=expand_idx(idx), Variable) else idx for idx, ev in zip(idxs, expand_vars)]
|
124
|
+
expand_vars = expand_idxs(idxs)
|
93
125
|
|
94
126
|
dim, amt = None, 1
|
95
127
|
# float 4 grouping
|
96
128
|
if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expand_node(idxs[upcast_dim[0]])) in [4,2]:
|
97
129
|
dim, amt = upcast_dim[0], len(float4_expand)
|
98
|
-
g_idx, g_valid = self.sts[i].expr_idxs(
|
130
|
+
g_idx, g_valid = self.sts[i].expr_idxs(idxs[:dim] + [float4_expand[0]] + idxs[dim+1:])
|
99
131
|
# do not use float4 if idx is not aligned
|
100
132
|
if g_idx != (g_idx//amt*amt): dim, amt = None, 1
|
101
133
|
if dim is None:
|
102
|
-
g_idx, g_valid = self.sts[i].expr_idxs(
|
134
|
+
g_idx, g_valid = self.sts[i].expr_idxs(idxs)
|
135
|
+
# todo: multioutput test with different output valids to add if acc is None: g_valid = NumNode(1)
|
103
136
|
|
104
137
|
if amt > 1: localtype = localtype.vec(amt)
|
105
|
-
e_idxs, e_valids = expand_node(g_idx, expand_vars), expand_node(g_valid, expand_vars)
|
138
|
+
e_idxs, e_valids = expand_node(g_idx, expand_vars), expand_node(g_valid, expand_vars) # pylint: disable=possibly-used-before-assignment
|
106
139
|
|
107
140
|
ret = []
|
108
|
-
invalid_value = 0
|
141
|
+
invalid_value = 0
|
142
|
+
acc_count = 0
|
109
143
|
for idx, valid, rep_idx in zip(e_idxs, e_valids, iter_idxs(expand_vars)):
|
110
144
|
this_const, idx, valid = (invalid_value, NumNode(0), NumNode(1)) if valid.max == 0 else (const, idx, valid)
|
111
|
-
key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
|
145
|
+
key = f"{'' if acc is None else self.reduceops.index(acc)}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
|
112
146
|
if key not in self.load_cache:
|
113
147
|
if acc is not None:
|
114
|
-
self.load_cache[key] =
|
148
|
+
self.load_cache[key] = UOp(UOps.DEFINE_ACC, localtype, (UOp.const(localtype.scalar(), self.get_reduce_acc(acc)), *loop_ctx), (i, acc_count))
|
149
|
+
acc_count += 1
|
115
150
|
elif this_const is not None:
|
116
|
-
self.load_cache[key] =
|
151
|
+
self.load_cache[key] = UOp.const(localtype, this_const)
|
117
152
|
if valid.min == 0 and valid.max == 1:
|
118
|
-
valid_rendered = valid.render(
|
119
|
-
self.load_cache[key] =
|
153
|
+
valid_rendered = valid.render(render_ops, self.loop_uops)
|
154
|
+
self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_rendered, self.load_cache[key], UOp.const(localtype, invalid_value))
|
120
155
|
elif isinstance(buf.dtype, ImageDType):
|
121
156
|
buf_uop = self.buf_uops[i]
|
122
157
|
assert buf_uop is not None, f"buffer {i} wasn't UOped"
|
123
158
|
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
|
124
|
-
rendered_idx =
|
125
|
-
valid_tuple = (valid.render(
|
126
|
-
self.load_cache[key] =
|
159
|
+
rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
|
160
|
+
valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
|
161
|
+
self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4),
|
162
|
+
(buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
|
127
163
|
if localtype == localtype.scalar():
|
128
164
|
idx_small = idx%4
|
129
|
-
res = idx_small.render(
|
130
|
-
out =
|
165
|
+
res = idx_small.render(render_ops, self.loop_uops)
|
166
|
+
out = UOp(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
|
131
167
|
for ix in range(idx_small.max, idx_small.min, -1):
|
132
|
-
rvv =
|
133
|
-
sel =
|
134
|
-
out =
|
168
|
+
rvv = UOp(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
|
169
|
+
sel = UOp.alu(BinaryOps.CMPLT, res, UOp.const(dtypes.int, ix))
|
170
|
+
out = UOp.alu(TernaryOps.WHERE, sel, rvv, out)
|
135
171
|
self.load_cache[key] = out
|
136
172
|
else:
|
137
173
|
buf_uop = self.buf_uops[i]
|
138
174
|
assert buf_uop is not None, f"buffer {i} wasn't UOped"
|
139
|
-
rendered_idx = idx.render(
|
140
|
-
valid_tuple = (valid.render(
|
141
|
-
self.load_cache[key] =
|
142
|
-
ret.append(
|
175
|
+
rendered_idx = idx.render(render_ops, self.loop_uops)
|
176
|
+
valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple()
|
177
|
+
self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
|
178
|
+
ret.append(UOp(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
|
143
179
|
return ret
|
144
180
|
|
145
181
|
def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]:
|
@@ -147,12 +183,12 @@ class Linearizer(Kernel):
|
|
147
183
|
buf_uop = self.buf_uops[i]
|
148
184
|
assert buf_uop is not None, f"buffer {i} wasn't UOped"
|
149
185
|
|
150
|
-
|
151
|
-
_idxs = [
|
186
|
+
expand_vars = expand_idxs(idxs)
|
187
|
+
_idxs = zip(*[expand_node(idx, expand_vars) for idx in idxs]) if idxs else [tuple()] # transpose
|
152
188
|
store_offset = dict(zip(_idxs, store))
|
153
189
|
|
154
190
|
# float4 grouping
|
155
|
-
if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand :=
|
191
|
+
if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expand_node(idxs[upcast_dim[0]])) in [2,4]:
|
156
192
|
grouped_store_offset = defaultdict(list)
|
157
193
|
for k in store_offset:
|
158
194
|
_idx = k[:upcast_dim[0]] + (float4_expand[0],) + k[upcast_dim[0]+1:]
|
@@ -162,61 +198,221 @@ class Linearizer(Kernel):
|
|
162
198
|
amt = len(grouped)
|
163
199
|
idx, valid = self.sts[i].expr_idxs(k)
|
164
200
|
assert idx == ((idx//amt)*amt), "float4 stores are always aligned"
|
165
|
-
store_offset_new[k] =
|
201
|
+
store_offset_new[k] = UOp(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
|
166
202
|
store_offset = store_offset_new
|
167
203
|
|
168
204
|
stores = []
|
169
|
-
for
|
170
|
-
idx, valid = self.sts[i].expr_idxs(
|
205
|
+
for _idx, var in store_offset.items():
|
206
|
+
idx, valid = self.sts[i].expr_idxs(_idx)
|
171
207
|
if isinstance(buf.dtype, ImageDType):
|
172
|
-
|
173
|
-
rendered_idx =
|
208
|
+
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
|
209
|
+
rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), \
|
210
|
+
tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
|
174
211
|
else:
|
175
|
-
rendered_idx = idx.render(
|
176
|
-
|
177
|
-
|
212
|
+
rendered_idx = idx.render(render_ops, self.loop_uops)
|
213
|
+
# TODO: let UPat check this once it's fast
|
214
|
+
if valid.min == 1: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var)))
|
215
|
+
else: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(render_ops, self.loop_uops))))
|
178
216
|
return stores
|
179
217
|
|
218
|
+
# render loop
|
219
|
+
def render_loop(self, xx:List[Variable], depth:int, reduce:bool) -> Tuple[UOp, ...]:
|
220
|
+
new_loops = {x.expr:UOp(UOps.RANGE, dtypes.int32, (
|
221
|
+
UOp.const(dtypes.int, x.min) if isinstance(x.min, int) else cast(Node, x.min).render(render_ops, self.loop_uops),
|
222
|
+
UOp.const(dtypes.int, x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(render_ops, self.loop_uops)), arg=(depth,i,reduce)) for i,x in enumerate(xx) if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
|
223
|
+
self.loop_uops.update(new_loops)
|
224
|
+
return tuple(new_loops.values())
|
225
|
+
|
226
|
+
def index_local_aliases(self, global_idxs, local_idxs, reduce_idxs, upcast_idxs, full_upcast_idxs):
|
227
|
+
def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
|
228
|
+
replace_idxs, thread_idxs, thread_idx = [], [], Variable("_uidx_tc", 0, prod(local_sizes)-1)
|
229
|
+
for s in local_sizes:
|
230
|
+
thread_idxs.append(thread_idx % s)
|
231
|
+
thread_idx //= s
|
232
|
+
for alias in aliases:
|
233
|
+
full_var, full_var_sz = NumNode(0), 1
|
234
|
+
if alias[0] != 0:
|
235
|
+
for i in alias:
|
236
|
+
next_var = local_idxs[i-1] if i > 0 else thread_idxs[-i-1]
|
237
|
+
full_var += next_var * full_var_sz
|
238
|
+
full_var_sz *= next_var.max+1
|
239
|
+
replace_idxs.append(full_var)
|
240
|
+
return replace_idxs
|
241
|
+
|
242
|
+
# compute local aliases
|
243
|
+
alias_buf_idxs: DefaultDict[LazyOp, List[Tuple[int, int, List]]] = defaultdict(list)
|
244
|
+
for op, local_alias in self.local_alias.items():
|
245
|
+
for i in local_alias:
|
246
|
+
localbuf_idx = self.bufs.index(local_alias[i])
|
247
|
+
buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
|
248
|
+
if (tc:=self.tensor_core):
|
249
|
+
min_alias_idx = min(local_alias.keys())
|
250
|
+
replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
|
251
|
+
for n in range(len(tc.threads)):
|
252
|
+
buf_idxs[self.global_dims+n] = replace_input_idxs[n] # replace locals
|
253
|
+
for n in range(tc.num_upcasts()):
|
254
|
+
buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
|
255
|
+
if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
|
256
|
+
alias_buf_idxs[op].append((i, localbuf_idx, buf_idxs))
|
257
|
+
# modify idxs if necessary for TC
|
258
|
+
if (tc:=self.tensor_core):
|
259
|
+
replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
|
260
|
+
for n in range(len(tc.threads)):
|
261
|
+
local_idxs[n] = replace_acc_idxs[n] # replace locals
|
262
|
+
for n in range(len(replace_acc_idxs)-len(tc.threads)):
|
263
|
+
upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
|
264
|
+
if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+upcast_idxs}")
|
265
|
+
return alias_buf_idxs
|
266
|
+
|
267
|
+
def render_reduceop(self, reduceop:LazyOp, accs:Dict[LazyOp, List[UOp]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]],
|
268
|
+
global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, reduce_idxs, fake_reduce_idxs,
|
269
|
+
alias_buf_idxs:List[Tuple[int, int, List]]) -> Tuple[List[NumNode|Variable], List[NumNode|Variable]]:
|
270
|
+
# reduce loop
|
271
|
+
loop_ctx = self.render_loop(reduce_idxs, (i:=self.reduceops.index(reduceop))*2+2, True)
|
272
|
+
|
273
|
+
# define accumulator - modify idxs if necessary for TC
|
274
|
+
out_buf = -len(self.reduceops)+i if self.group_for_reduces else 0
|
275
|
+
accs[reduceop] = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
|
276
|
+
|
277
|
+
# store local aliases
|
278
|
+
locals_to_store = [(localbuf_idx, buf_idxs, self.global_load(i, buf_idxs)) for i, localbuf_idx, buf_idxs in alias_buf_idxs]
|
279
|
+
|
280
|
+
if (tc:=self.tensor_core):
|
281
|
+
# run tensor cores AST
|
282
|
+
wmma_sz = [prod(l) for l in tc.thread_local_sizes]
|
283
|
+
def upcast_strides(buf:int):
|
284
|
+
strides, next_ = [], 1
|
285
|
+
for (sz, stride, _) in self.upcasted_axis(buf)[tc.num_upcasts():]:
|
286
|
+
strides.append((0 if stride == 0 else next_, sz))
|
287
|
+
next_ *= 1 if stride == 0 else sz
|
288
|
+
return strides
|
289
|
+
upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
|
290
|
+
# cast initial accs
|
291
|
+
wmmas = [UOp(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]]))
|
292
|
+
for x in range(0, len(accs[reduceop]), wmma_sz[2])]
|
293
|
+
for it in [x[::-1] for x in itertools.product(*list([range(sz) for _,sz in upcasts[0]][::-1]))]:
|
294
|
+
offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(it, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
|
295
|
+
ops = (UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
|
296
|
+
UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
|
297
|
+
wmmas[(wmma_idx:=offs[2]//wmma_sz[2])])
|
298
|
+
# TODO: don't need to DEFINE_ACC, pass to WMMA in op3, or PHI accs that are not valid
|
299
|
+
wmmas[wmma_idx] = UOp(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), dev))
|
300
|
+
# phi the last wmmas back to accs
|
301
|
+
accs[reduceop] = [UOp(UOps.PHI, tc.dtype_out, (acc, UOp(UOps.GEP, tc.dtype_out, (wmmas[z//wmma_sz[2]],), z%wmma_sz[2])))
|
302
|
+
for z, acc in enumerate(accs[reduceop])]
|
303
|
+
else:
|
304
|
+
assert not locals_to_store, "storing locals isn't supported here"
|
305
|
+
|
306
|
+
# load earlybufs
|
307
|
+
loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[reduceop][i]) if i in self.local_alias else i,
|
308
|
+
global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs})
|
309
|
+
|
310
|
+
def gate_acc(r, idxs): return [
|
311
|
+
UOp.alu(TernaryOps.WHERE, valid.render(render_ops, self.loop_uops), acc, UOp.const(r.dtype, 0)) if valid.min == 0 and valid.max == 1 else acc
|
312
|
+
for valid, acc in zip(expand_node(self.sts[self.full_buf_index].expr_idxs(idxs)[1], expand_idxs(idxs)), accs[r])]
|
313
|
+
local_accs = {r: gate_acc(r,global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for r in accs}
|
314
|
+
|
315
|
+
# run early AST (with reduce)
|
316
|
+
self.ast_parse(reduceop, local_accs, self.acc_offsets(self.full_buf_index), loaded_buffers, reduce_acc=accs[reduceop])
|
317
|
+
|
318
|
+
# end the reduce loop
|
319
|
+
self.load_cache.clear()
|
320
|
+
|
321
|
+
# end the local loop, do the local reduce
|
322
|
+
if self.group_for_reduces:
|
323
|
+
fake_global_idxs = [x*0 for x in global_idxs]
|
324
|
+
stores = self.global_store(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop]) # store accumulators
|
325
|
+
barrier = UOp(UOps.BARRIER, None, tuple(stores))
|
326
|
+
if self.opts.has_local:
|
327
|
+
fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
|
328
|
+
fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
|
329
|
+
if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(render_ops, self.loop_uops)
|
330
|
+
barrier = UOp(UOps.IF, None, (if_cond, barrier))
|
331
|
+
|
332
|
+
# create new late reduce local loops and replace local_idxs that have been used
|
333
|
+
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)] # noqa: E501
|
334
|
+
local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]
|
335
|
+
|
336
|
+
# if any group_for_reduce items aren't reduces, upcast them here
|
337
|
+
for j in self.upcast_in_mid_reduce_axes:
|
338
|
+
self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
|
339
|
+
self.upcast()
|
340
|
+
self.group_for_reduces -= 1
|
341
|
+
local_idxs = local_idxs[:-1]
|
342
|
+
end_local_idxs = end_local_idxs[:-1]
|
343
|
+
# regenerate upcast_idxs
|
344
|
+
upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
|
345
|
+
|
346
|
+
# NOTE: this structure is the same as the reduce op above
|
347
|
+
|
348
|
+
# late reduce loop
|
349
|
+
loop_ctx = self.render_loop(end_local_idxs, i*2+3, True)
|
350
|
+
|
351
|
+
# define late accumulator
|
352
|
+
accs[reduceop] = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
|
353
|
+
|
354
|
+
# load localbufs
|
355
|
+
loaded_buffers[self.bufs[out_buf]] = self.global_load(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
|
356
|
+
|
357
|
+
# there's no AST here (and there's no shape for the reduce LazyOp)
|
358
|
+
self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[out_buf]),)),\
|
359
|
+
accs, self.acc_offsets(-1), loaded_buffers, reduce_acc=accs[reduceop])
|
360
|
+
|
361
|
+
# end the late reduce loop
|
362
|
+
self.load_cache.clear()
|
363
|
+
|
364
|
+
if reduceop is not self.reduceops[-1]:
|
365
|
+
for j in self.upcast_in_mid_reduce_axes:
|
366
|
+
self.upcasted -= 1
|
367
|
+
self.group_for_reduces += 1
|
368
|
+
assert self.buf_uops[out_buf] is not None, "Local reduce buf must have been uoped at this point"
|
369
|
+
fake_local_idxs = local_idxs[:self.local_dims] + [x*0 for x in local_idxs[self.local_dims:]]
|
370
|
+
stores = self.global_store(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop])
|
371
|
+
barrier = UOp(UOps.BARRIER, None, tuple(stores))
|
372
|
+
accs[reduceop] = self.global_load(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
|
373
|
+
return local_idxs[:self.local_dims] + [NumNode(0) for _ in range(self.group_for_reduces)], upcast_idxs
|
374
|
+
|
180
375
|
kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
|
181
|
-
def linearize(self):
|
376
|
+
def linearize(self) -> Linearizer:
|
182
377
|
# no new opts and we already ran? skip relinearizing
|
183
378
|
if self.applied_opts == self.applied_opts_cache: return self
|
184
379
|
|
185
|
-
#
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
380
|
+
# late alias the tensor core buffers
|
381
|
+
if (tc:=self.tensor_core) and self.tensor_core_opts is not None:
|
382
|
+
alias_pattern = [0]*(self.global_dims) + [2]*(len(tc.threads)) + [0]*(self.local_dims-len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2) # noqa: E501
|
383
|
+
for op, tc_bufs in self.bufs_for_tensor_core.items():
|
384
|
+
for tc_buf in tc_bufs: self.alias_buffer(op, tc_buf, alias_pattern)
|
190
385
|
|
191
|
-
#
|
192
|
-
|
386
|
+
# save backups
|
387
|
+
sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduces, self.upcasted
|
193
388
|
|
194
389
|
# uops
|
195
|
-
self.uops: List[UOp] = []
|
196
390
|
self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs)
|
197
391
|
self.loop_uops: Dict[str, UOp] = {}
|
198
392
|
|
199
393
|
# add global buffers
|
200
394
|
for i,buf in enumerate(self.bufs):
|
201
395
|
if isinstance(buf, MemBuffer):
|
202
|
-
self.buf_uops[i] =
|
203
|
-
|
204
|
-
|
205
|
-
assert var.expr is not None
|
206
|
-
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), var.expr)
|
396
|
+
self.buf_uops[i] = UOp(UOps.DEFINE_GLOBAL,
|
397
|
+
buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
|
398
|
+
(buf.idx, any(buf.idx == x.idx for x in self.outbufs)))
|
207
399
|
# define local buffers
|
208
|
-
for
|
209
|
-
self.buf_uops[self.bufs.index(lb)] =
|
400
|
+
for aliases in self.local_alias.values():
|
401
|
+
for lb in aliases.values(): self.buf_uops[self.bufs.index(lb)] = UOp(UOps.DEFINE_LOCAL, PtrDType(lb.dtype),
|
402
|
+
(), (lb.name, self.sts[self.bufs.index(lb)].size))
|
210
403
|
# add a local buffer for multistage reduce. # TODO: use local alias
|
211
|
-
if self.
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
404
|
+
if self.group_for_reduces:
|
405
|
+
for i in range(len(self.reduceops)):
|
406
|
+
# TODO: the strides of this can be controlled
|
407
|
+
self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501
|
408
|
+
temp_dtype = self.get_base_dtype(cast(LazyOp, self.reduceop).dtype)
|
409
|
+
self.bufs.append(LocalBuffer(name:=f"temp{i if len(self.reduceops) > 1 else ''}", buf_size:=self.sts[-1].size, temp_dtype))
|
410
|
+
self.buf_uops.append(UOp(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), (name, buf_size)))
|
217
411
|
|
218
412
|
# kernel name (before late upcast)
|
219
|
-
self.name = ("
|
413
|
+
self.name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \
|
414
|
+
(f"{len(self.outbufs)}_" if len(self.outbufs) > 1 else "_") + \
|
415
|
+
colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
|
220
416
|
|
221
417
|
# name the function something unique
|
222
418
|
Linearizer.kernel_cnt[(function_name := to_function_name(self.name))] += 1
|
@@ -224,343 +420,109 @@ class Linearizer(Kernel):
|
|
224
420
|
self.name = self.name+colored(suffix, 'BLACK')
|
225
421
|
|
226
422
|
# define indexes
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
self.loop_uops.update(
|
238
|
-
|
239
|
-
|
240
|
-
# set global/local size
|
241
|
-
self.global_size: Optional[List[int]] = None
|
242
|
-
self.local_size: Optional[List[int]] = None
|
243
|
-
if self.dont_use_locals:
|
244
|
-
self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
|
245
|
-
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
|
246
|
-
elif self.opts.has_local:
|
247
|
-
self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs][::-1]
|
248
|
-
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
|
249
|
-
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)}) # noqa: E501
|
423
|
+
gl_dims = self.full_shape[:self.first_reduce+self.group_for_reduces]
|
424
|
+
global_idxs, loop_global_idxs, self.global_size = get_grouped_dims("idx" if self.dont_use_locals else "gidx", 0, gl_dims[:self.global_dims],
|
425
|
+
self.opts.global_max, self.opts.has_local)
|
426
|
+
local_idxs, loop_local_idxs, self.local_size = get_grouped_dims("lidx", self.global_dims, gl_dims[self.global_dims:],
|
427
|
+
self.opts.local_max if self.opts.has_local else (), False)
|
428
|
+
upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
|
429
|
+
full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
|
430
|
+
|
431
|
+
# render global and local as specials or a loop
|
432
|
+
if self.opts.has_local:
|
433
|
+
self.loop_uops.update({x.expr:UOp(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
|
434
|
+
if not self.dont_use_locals:
|
435
|
+
self.loop_uops.update({x.expr:UOp(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
|
250
436
|
else:
|
251
|
-
|
437
|
+
self.global_size, self.local_size = None, None
|
438
|
+
self.render_loop(loop_global_idxs+loop_local_idxs, 1, False)
|
439
|
+
|
440
|
+
# define idxs for aliased buffers TODO: this doesn't belong in Kernel, but it can't exist in Block either (because of multireduce tensor cores)
|
441
|
+
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
|
442
|
+
alias_buf_idxs = self.index_local_aliases(global_idxs,local_idxs,reduce_idxs,upcast_idxs,full_upcast_idxs)
|
252
443
|
|
253
444
|
# parse AST
|
254
|
-
loaded_buffers = {}
|
255
|
-
acc: List[UOp] = []
|
256
445
|
self.load_cache: Dict[str, UOp] = {}
|
446
|
+
loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]] = {}
|
447
|
+
accs: Dict[LazyOp, List[UOp]] = {}
|
257
448
|
|
258
|
-
#
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)] # noqa: E501
|
263
|
-
fake_reduce_idxs = [x*0 for x in reduce_idxs]
|
264
|
-
|
265
|
-
# define accumulator
|
266
|
-
acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop))
|
267
|
-
|
268
|
-
if self.tensor_core:
|
269
|
-
def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
|
270
|
-
replace_idxs = []
|
271
|
-
for alias in aliases:
|
272
|
-
full_var, full_var_sz = NumNode(0), 1
|
273
|
-
if alias[0] != 0:
|
274
|
-
for i in alias:
|
275
|
-
next_var = local_idxs[-i] if i > 0 else Variable(None, 0, local_size-1)
|
276
|
-
full_var += next_var * full_var_sz
|
277
|
-
full_var_sz *= next_var.max+1
|
278
|
-
replace_idxs.append(full_var)
|
279
|
-
return replace_idxs
|
280
|
-
replace_acc_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[2], self.tensor_core.thread_local_aliases[2])
|
281
|
-
for n in range(len(self.tensor_core.threads)):
|
282
|
-
local_idxs[self.local_dims-len(self.tensor_core.threads)+n] = replace_acc_idxs[n] # replace locals
|
283
|
-
for n in range(len(replace_acc_idxs)-len(self.tensor_core.threads)):
|
284
|
-
upcast_idxs[n] = replace_acc_idxs[len(self.tensor_core.threads)+n] # replace upcasts
|
285
|
-
|
286
|
-
# reduce loop
|
287
|
-
loop_ctx = render_loop(reduce_idxs)
|
288
|
-
|
289
|
-
# barrier for fast GEMM
|
290
|
-
if self.tensor_core: self.uop(UOps.BARRIER, None, (), cachable=False)
|
291
|
-
|
292
|
-
# compute local aliases
|
293
|
-
locals_to_store = []
|
294
|
-
for i in self.local_alias:
|
295
|
-
localbuf_idx = self.bufs.index(self.local_alias[i])
|
296
|
-
buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
|
297
|
-
if self.tensor_core:
|
298
|
-
min_alias_idx = min(self.local_alias.keys())
|
299
|
-
replace_input_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[i-min_alias_idx], self.tensor_core.thread_local_aliases[i-min_alias_idx]) # noqa: E501
|
300
|
-
for n in range(len(self.tensor_core.threads)):
|
301
|
-
buf_idxs[self.first_reduce-len(self.tensor_core.threads)+n] = replace_input_idxs[n] # replace locals
|
302
|
-
for n in range(len(replace_input_idxs)-len(self.tensor_core.threads)):
|
303
|
-
buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(self.tensor_core.threads)+n] # replace upcasts
|
304
|
-
if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: idxs=", buf_idxs)
|
305
|
-
ll = self.global_load(i, buf_idxs)
|
306
|
-
locals_to_store.append((localbuf_idx, buf_idxs, ll))
|
307
|
-
|
308
|
-
# copy in any global buffers
|
309
|
-
if self.tensor_core:
|
310
|
-
wmma_sz = self.tensor_core.thread_local_sizes
|
311
|
-
# calculate the number of local accumulator reduces and render WMMAs: this is bad... this needs to come from someplace else
|
312
|
-
nx, ny, nacc = (len(locals_to_store[0][2])//wmma_sz[0]), (len(locals_to_store[1][2])//wmma_sz[1]), (len(acc)//wmma_sz[2])
|
313
|
-
acc_reds = math.isqrt((nx*ny)//nacc)
|
314
|
-
i, bx, by = 0, nx//acc_reds, ny//acc_reds
|
315
|
-
for y in range(by):
|
316
|
-
for x in range(bx):
|
317
|
-
for j in range(acc_reds):
|
318
|
-
op1, op2, op3 = locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]], locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]], acc[i:i+wmma_sz[2]] # noqa: E501
|
319
|
-
if self.opts.device != "HIP":
|
320
|
-
ops = tuple(op1+op2+op3)
|
321
|
-
else:
|
322
|
-
ops = (self.uop(UOps.CAST, dtypes.half.vec(16), tuple(op1)),
|
323
|
-
self.uop(UOps.CAST, dtypes.half.vec(16), tuple(op2)),
|
324
|
-
self.uop(UOps.CAST, dtypes.float.vec(8), tuple(op3)))
|
325
|
-
ret = self.uop(UOps.WMMA, dtypes.float.vec(2) if wmma_sz[2] == 2 else dtypes.float.vec(8), ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,)) # noqa: E501
|
326
|
-
for z in range(cast(DType, ret.dtype).sz):
|
327
|
-
acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + loop_ctx)
|
328
|
-
i += wmma_sz[2]
|
329
|
-
else:
|
330
|
-
if locals_to_store:
|
331
|
-
self.uop(UOps.BARRIER, None, (), cachable=False)
|
332
|
-
for i, idxs, ll in locals_to_store: self.global_store(i, idxs, ll)
|
333
|
-
self.uop(UOps.BARRIER, None, (), cachable=False)
|
334
|
-
|
335
|
-
# load earlybufs
|
336
|
-
loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs}) # noqa: E501
|
337
|
-
|
338
|
-
# run early AST (with reduce)
|
339
|
-
self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)
|
340
|
-
|
341
|
-
# end the reduce loop
|
342
|
-
self.load_cache.clear()
|
343
|
-
|
344
|
-
# end the local loop, do the local reduce
|
345
|
-
if self.group_for_reduce:
|
346
|
-
fake_global_idxs = [x*0 for x in global_idxs]
|
347
|
-
stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
|
348
|
-
barrier = self.uop(UOps.BARRIER, None, tuple(stores), cachable=False)
|
349
|
-
if self.opts.has_local:
|
350
|
-
fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
|
351
|
-
fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
|
352
|
-
if_cond: UOp = (self.sts[-1].expr_idxs(fake_idxs)[0]<1).render(self.render_ops, self)
|
353
|
-
barrier = self.uop(UOps.IF, None, (if_cond, barrier), cachable=False)
|
354
|
-
|
355
|
-
# create new late reduce local loops and replace local_idxs that have been used
|
356
|
-
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))] # noqa: E501
|
357
|
-
local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]
|
358
|
-
|
359
|
-
# if any group_for_reduce items aren't reduces, upcast them here
|
360
|
-
for j in self.upcast_in_mid_reduce_axes:
|
361
|
-
self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
|
362
|
-
self.upcast()
|
363
|
-
self.group_for_reduce.pop()
|
364
|
-
local_idxs = local_idxs[:-1]
|
365
|
-
end_local_idxs = end_local_idxs[:-1]
|
366
|
-
# regenerate upcast_idxs
|
367
|
-
upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
|
368
|
-
|
369
|
-
# NOTE: this structure is the same as the reduce op above
|
370
|
-
|
371
|
-
# define late accumulator
|
372
|
-
acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop))
|
373
|
-
|
374
|
-
# late reduce loop
|
375
|
-
loop_ctx = render_loop(end_local_idxs)
|
376
|
-
|
377
|
-
# load localbufs
|
378
|
-
loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
|
449
|
+
# render reduceops by depth
|
450
|
+
for reduceop in self.reduceops:
|
451
|
+
self.render_block((reduceop, ), global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)
|
452
|
+
stores = self.render_block(self.ast, global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)
|
379
453
|
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
# end the late reduce loop
|
384
|
-
self.load_cache.clear()
|
385
|
-
|
386
|
-
# load latebufs
|
387
|
-
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer}) # noqa: E501
|
388
|
-
|
389
|
-
# run late AST (without the store)
|
390
|
-
val = self.ast_parse(self.ast.src[0], acc, None, loaded_buffers)
|
391
|
-
|
392
|
-
# store
|
393
|
-
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
|
394
|
-
|
395
|
-
# get PHI node loop scope, link anything using a DEFINE_ACC to the loop as a "parent"
|
396
|
-
acc_scope: DefaultDict[UOp, List[UOp]] = defaultdict(list)
|
397
|
-
for u in self.uops:
|
398
|
-
if u.uop == UOps.PHI:
|
399
|
-
acc_scope[u.vin[0]] += u.vin[2:]
|
400
|
-
|
401
|
-
# graph helper functions
|
402
|
-
@functools.lru_cache(None)
|
403
|
-
def get_recursive_parents(x:UOp, with_phi=False) -> Set[UOp]:
|
404
|
-
return set.union(set(x.vin), *[get_recursive_parents(p, with_phi) for p in x.vin], set(acc_scope[x]) if with_phi else set())
|
405
|
-
|
406
|
-
def get_recursive_children(x:UOp) -> Set[UOp]:
|
407
|
-
deps = set([x])
|
408
|
-
ssize = 0
|
409
|
-
while ssize != len(deps):
|
410
|
-
ssize = len(deps)
|
411
|
-
for u in self.uops:
|
412
|
-
if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
|
413
|
-
deps.add(u)
|
414
|
-
return deps
|
415
|
-
|
416
|
-
def replace_op(old:UOp, new:UOp):
|
417
|
-
for u in self.uops:
|
418
|
-
u.vin = tuple(new if x is old else x for x in u.vin)
|
419
|
-
self.uops.remove(old)
|
420
|
-
|
421
|
-
# fix loop scope, push uops upward out of loop if it does not depend on the loop
|
422
|
-
loop_stack: List[List[UOp]] = [[]]
|
423
|
-
for u in self.uops:
|
424
|
-
if not loop_stack[-1]: loop_stack[-1].append(u)
|
425
|
-
elif u.uop == UOps.LOOP: loop_stack.append([u])
|
426
|
-
elif u.uop not in [UOps.CONST, UOps.ALU, UOps.CAST, UOps.LOAD]: loop_stack[-1].append(u)
|
427
|
-
else:
|
428
|
-
parents = get_recursive_parents(u, with_phi=True)
|
429
|
-
# don't push any local buffer because there might have STORE and BARRIER (not considered as parent) between DEFINE_LOCAL and here
|
430
|
-
if any(u.uop == UOps.DEFINE_LOCAL for u in parents): loop_stack[-1].append(u)
|
431
|
-
else:
|
432
|
-
for i in reversed(range(len(loop_stack))):
|
433
|
-
# check backwards and put the uop in the first encounter with some dependency
|
434
|
-
if any(x in parents for x in loop_stack[i]) or i == 0:
|
435
|
-
loop_stack[i].append(u)
|
436
|
-
break
|
437
|
-
self.uops = flatten(loop_stack)
|
438
|
-
|
439
|
-
# uops optimization
|
440
|
-
changed_something = True
|
441
|
-
while changed_something:
|
442
|
-
changed_something = False
|
443
|
-
for u in self.uops:
|
444
|
-
if u.uop == UOps.PHI and len(u.vin) == 3:
|
445
|
-
# if the parents of the PHI node don't have the LOOP in their parents, it can be folded
|
446
|
-
# TODO: ADD becomes a MUL, MAX can just become nothing
|
447
|
-
if all(x.uop != UOps.LOOP for x in get_recursive_parents(UOp(u.uop, u.dtype, u.vin[0:2], u.arg))) and u.vin[1].arg == BinaryOps.ADD:
|
448
|
-
if DEBUG >= 4: print(f"removing PHI node {u}")
|
449
|
-
del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)]
|
450
|
-
# NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype
|
451
|
-
loop_len = self.uop(UOps.ALU, u.vin[2].vin[1].dtype, (u.vin[2].vin[1], u.vin[2].vin[0]), BinaryOps.SUB, insert_before=self.uops.index(u))
|
452
|
-
if loop_len.dtype != u.dtype: loop_len = self.uop(UOps.CAST, u.dtype, (loop_len,), insert_before=self.uops.index(u))
|
453
|
-
replace_op(u, self.uop(UOps.ALU, u.dtype, (u.vin[1], loop_len,), BinaryOps.MUL, insert_before=self.uops.index(u)))
|
454
|
-
changed_something = True
|
455
|
-
|
456
|
-
# (recursively) remove childless uops
|
457
|
-
# NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
|
458
|
-
UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_GLOBAL}
|
459
|
-
while 1:
|
460
|
-
has_child: Set[UOp] = set()
|
461
|
-
for ru in self.uops:
|
462
|
-
for vu in ru.vin:
|
463
|
-
has_child.add(vu)
|
464
|
-
nu: List[UOp] = [x for x in self.uops if x in has_child or x.uop in UOPS_W_SIDE_EFFECTS]
|
465
|
-
if len(nu) == len(self.uops): break
|
466
|
-
if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
|
467
|
-
self.uops = nu
|
468
|
-
del nu
|
469
|
-
|
470
|
-
# add UOps.END
|
471
|
-
for u in self.uops:
|
472
|
-
if u.uop == UOps.LOOP:
|
473
|
-
# add END of loops after the last thing that (recursively) depends on them
|
474
|
-
self.uop(UOps.END, None, (u,), cachable=False, insert_before=self.uops.index(sorted(list(get_recursive_children(u)), key=self.uops.index)[-1])+1) # noqa: E501
|
475
|
-
elif u.uop == UOps.IF:
|
476
|
-
# END any if statements at the end of the uops
|
477
|
-
self.uop(UOps.END, None, (u,), cachable=False)
|
454
|
+
# only the final stores are needed to define the full UOps graph
|
455
|
+
self.uops:UOpGraph = UOpGraph(flatten(stores))
|
478
456
|
|
479
457
|
# maybe graph the uops
|
480
|
-
if DEBUG >= 5:
|
481
|
-
|
482
|
-
print(f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}") # noqa: E501
|
483
|
-
if getenv("GRAPHUOPS"):
|
484
|
-
from tinygrad.graph import graph_uops
|
485
|
-
graph_uops(self.uops)
|
458
|
+
if DEBUG >= 5: self.uops.print()
|
459
|
+
if getenv("GRAPHUOPS"): self.uops.graph()
|
486
460
|
|
487
461
|
# restore backups
|
488
|
-
self.sts, self.
|
462
|
+
self.sts, self.group_for_reduces, self.upcasted = sts_backup, gfr_backup, upc_backup
|
489
463
|
|
490
464
|
# set cache and return
|
491
465
|
self.applied_opts_cache = self.applied_opts[:]
|
492
466
|
return self
|
493
467
|
|
494
|
-
def
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
if
|
510
|
-
|
511
|
-
return self.const(vin[0].arg, dtype, insert_before)
|
512
|
-
if uop == UOps.ALU:
|
513
|
-
# rewrites. NOTE: the rewritten NEG op is still around...
|
514
|
-
if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG:
|
515
|
-
return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable, insert_before)
|
516
|
-
# constant folding
|
517
|
-
if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.const(-vin[0].arg, dtype, insert_before)
|
518
|
-
if arg == TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop
|
519
|
-
# zero folding
|
520
|
-
for x in [0,1]:
|
521
|
-
if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
|
522
|
-
if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
|
523
|
-
if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[x]
|
524
|
-
if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0]
|
525
|
-
if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
|
526
|
-
|
527
|
-
key = (uop, dtype, vin, arg)
|
528
|
-
if insert_before is None: insert_before = len(self.uops)
|
529
|
-
# check if the cached expr is valid with the given insert place.
|
530
|
-
if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr
|
531
|
-
ret = UOp(uop, dtype, vin, arg)
|
532
|
-
self.uops.insert(insert_before, ret)
|
533
|
-
if cachable: self.saved_exprs[key] = ret
|
534
|
-
return ret
|
468
|
+
def render_block(self, outputs:Tuple[LazyOp, ...], global_idxs, local_idxs, upcast_idxs, full_upcast_idxs,
|
469
|
+
alias_buf_idxs:DefaultDict[LazyOp,List[Tuple[int,int,List[NumNode|Variable]]]],
|
470
|
+
loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], accs:Dict[LazyOp,List[UOp]]) -> List[List[UOp]]:
|
471
|
+
reduceops = dedup(x for x in outputs if x.op in ReduceOps)
|
472
|
+
assert len(reduceops) <= 1, "max one reduceop per block"
|
473
|
+
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
|
474
|
+
fake_reduce_idxs = [x*0 for x in reduce_idxs]
|
475
|
+
|
476
|
+
if len(reduceops) != 0:
|
477
|
+
# TODO: delete render_reduceop and move the logic for group_for_reduces to Block
|
478
|
+
nlidx, nuidx = self.render_reduceop((r:=reduceops[0]),accs,loaded_buffers,\
|
479
|
+
global_idxs,local_idxs,upcast_idxs,full_upcast_idxs,reduce_idxs,fake_reduce_idxs,alias_buf_idxs[r])
|
480
|
+
|
481
|
+
# all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have
|
482
|
+
# been rewritten with fake end_local_idxs.
|
483
|
+
if r is self.reduceops[-1]: local_idxs[:], upcast_idxs[:] = nlidx, nuidx
|
484
|
+
return [accs[r]]
|
535
485
|
|
536
|
-
|
486
|
+
# load latebufs
|
487
|
+
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
|
488
|
+
for i,b in enumerate(self.bufs) if b not in self.earlybufs and b.__class__ is not LocalBuffer})
|
489
|
+
# run late AST (without the store)
|
490
|
+
store_vals = {op.arg.idx:self.ast_parse(op.src[0], accs, None, loaded_buffers) for op in self.ast}
|
491
|
+
return [self.global_store(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val) for i, val in store_vals.items()]
|
492
|
+
|
493
|
+
def ast_parse(self, x:LazyOp, accs:Dict[LazyOp, List[UOp]], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], reduce_acc:Optional[List[UOp]]=None, cache=None) -> List[UOp]: # noqa: E501
|
537
494
|
if cache is None: cache = {}
|
538
495
|
if x in cache: return cache[x]
|
539
496
|
if x.op in BufferOps: return loaded_buffers[x.arg]
|
540
|
-
if x.op
|
541
|
-
return [
|
542
|
-
|
543
|
-
|
544
|
-
return
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
if (castop:=x.src[0]).op == UnaryOps.CAST and (mulop:=castop.src[0]).op == BinaryOps.MUL:
|
549
|
-
# MULACC with acc cast rewrite: MUL -> CAST -> SUM => CAST -> MULACC
|
550
|
-
x = LazyOp(TernaryOps.MULACC, tuple(LazyOp(UnaryOps.CAST, (s, ), castop.arg) for s in mulop.src), x.arg)
|
551
|
-
|
552
|
-
values = [self.ast_parse(v, acc, offs, loaded_buffers, loop_ctx=loop_ctx, cache=cache) for v in x.src]
|
553
|
-
ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
|
497
|
+
if x.op in [UnaryOps.CAST, UnaryOps.BITCAST]:
|
498
|
+
return [UOp(UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST,
|
499
|
+
self.get_base_dtype(x.arg), (u,)) for u in self.ast_parse(x.src[0], accs, offs, loaded_buffers)]
|
500
|
+
if x.op in ReduceOps and reduce_acc is None:
|
501
|
+
return [accs[x][i] for i in offs] if offs else accs[x]
|
502
|
+
|
503
|
+
values = [self.ast_parse(v, accs, offs, loaded_buffers, cache=cache) for v in x.src]
|
504
|
+
ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}
|
554
505
|
if x.op in ops:
|
506
|
+
assert reduce_acc is not None
|
555
507
|
ret: List[UOp] = []
|
556
|
-
input_acc =
|
508
|
+
acc, input_acc = reduce_acc, reduce_acc[:]
|
557
509
|
for val, off in zip(zip(*values), cast(List[int], offs)):
|
558
|
-
acc[off] =
|
510
|
+
acc[off] = UOp.alu(ops[cast(ReduceOps, x.op)], *(val+(acc[off], )))
|
559
511
|
ret.append(acc[off])
|
560
512
|
for off in range(len(acc)):
|
561
513
|
if input_acc[off] != acc[off]:
|
562
|
-
acc[off] =
|
563
|
-
else:
|
564
|
-
ret = [self.uop(UOps.ALU, dtypes.bool if x.op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else val[-1].dtype, val, x.op) for val in zip(*values)]
|
514
|
+
acc[off] = UOp(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]))
|
515
|
+
else: ret = [UOp.alu(x.op, *src) for src in zip(*values)]
|
565
516
|
cache[x] = ret
|
566
517
|
return ret
|
518
|
+
|
519
|
+
def to_program(self) -> Program:
|
520
|
+
self.linearize()
|
521
|
+
info = get_lazyop_info(self.ast[0])
|
522
|
+
src = self.opts.render(name:=to_function_name(self.name), self.uops)
|
523
|
+
if getenv("RUN_PROCESS_REPLAY"): diskcache_put("process_replay", id(self), (self.ast, self.opts, self.applied_opts, name, src))
|
524
|
+
ops, mem = self.uops.flops_mem()
|
525
|
+
run_count = prod((self.global_size or []) + (self.local_size or []))
|
526
|
+
# NOTE: we use min here to ignore the indexing FLOPS
|
527
|
+
return Program(self.name, src, self.opts.device, self.global_size, self.local_size,
|
528
|
+
self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
|