tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -0
- tinygrad/codegen/kernel.py +572 -83
- tinygrad/codegen/linearizer.py +415 -395
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +183 -0
- tinygrad/dtype.py +113 -0
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +76 -55
- tinygrad/helpers.py +196 -89
- tinygrad/lazy.py +210 -371
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +202 -22
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +112 -32
- tinygrad/nn/state.py +136 -39
- tinygrad/ops.py +119 -202
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +353 -166
- tinygrad/renderer/llvmir.py +150 -138
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +81 -0
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +75 -0
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +24 -77
- tinygrad/runtime/ops_cuda.py +175 -89
- tinygrad/runtime/ops_disk.py +56 -33
- tinygrad/runtime/ops_gpu.py +92 -95
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +39 -60
- tinygrad/runtime/ops_metal.py +92 -74
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +86 -254
- tinygrad/shape/symbolic.py +166 -141
- tinygrad/shape/view.py +296 -0
- tinygrad/tensor.py +2619 -448
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- tinygrad-0.9.0.dist-info/METADATA +227 -0
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/assembly.py +0 -190
- tinygrad/codegen/optimizer.py +0 -379
- tinygrad/codegen/search.py +0 -72
- tinygrad/graph.py +0 -83
- tinygrad/jit.py +0 -57
- tinygrad/nn/image.py +0 -100
- tinygrad/renderer/assembly_arm64.py +0 -169
- tinygrad/renderer/assembly_ptx.py +0 -98
- tinygrad/renderer/wgsl.py +0 -53
- tinygrad/runtime/lib.py +0 -113
- tinygrad/runtime/ops_cpu.py +0 -51
- tinygrad/runtime/ops_hip.py +0 -82
- tinygrad/runtime/ops_shm.py +0 -29
- tinygrad/runtime/ops_torch.py +0 -30
- tinygrad/runtime/ops_webgpu.py +0 -45
- tinygrad-0.7.0.dist-info/METADATA +0 -212
- tinygrad-0.7.0.dist-info/RECORD +0 -40
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
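Beyond the per-file line counts, the listing shows the 0.7.0 → 0.9.0 reorganization at a glance: `tinygrad/mlops.py` is renamed to `tinygrad/function.py`, scheduling/JIT/search code moves into the new `tinygrad/engine/` package, dtypes get their own `tinygrad/dtype.py`, and the old CPU/Torch/WebGPU runtimes and `tinygrad/runtime/lib.py` are removed in favor of the new device/runtime layout. A minimal usage sketch against the 0.9.0 layout follows; the module paths come from the file list above, while the exported names (`Tensor`, `TinyJit`) are assumptions that should be verified against the wheel itself.

```python
# Minimal sketch against the 0.9.0 layout listed above (exported names are assumptions).
from tinygrad import Tensor              # re-exported via the new 6-line tinygrad/__init__.py
from tinygrad.engine.jit import TinyJit  # moved from tinygrad/jit.py into tinygrad/engine/

w = Tensor.randn(4, 4)

@TinyJit
def step(a: Tensor) -> Tensor:
  # TinyJit replays the captured kernels after a couple of warm-up calls
  return (a @ w).relu().sum().realize()

for _ in range(3): out = step(Tensor.randn(4, 4).realize())
print(out.numpy())
```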
tinygrad/codegen/linearizer.py
CHANGED
```diff
@@ -1,440 +1,460 @@
-from
-import
+from __future__ import annotations
+from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence
+import itertools, math, functools
 from collections import defaultdict
-from enum import Enum, auto
 
-from tinygrad.
-from tinygrad.
-from tinygrad.
-from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
-from tinygrad.runtime.lib import RawConst
+from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
+from tinygrad.helpers import colored, DEBUG, prod, getenv, to_function_name
+from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode,
-from tinygrad.codegen.
-from tinygrad.
-[old lines 15-56 removed; content not captured in the source view]
-    return
-[old lines 58-94 removed; content not captured in the source view]
-      if any(x is None for x in nv): break
-      new_idxs.append(idxs[i:i+4])
-      new_values.append(nv)
-    if len(new_values) == len(idxs)//4:
-      return zip(new_idxs, new_values)
-  return zip([[i] for i in range(len(values[0]))], zip(*values))
-
-# TODO: generic visitor pattern?
-def expand_node(idx:Node) -> List[Node]:
-  if isinstance(idx, Variable): return [idx] if idx.expr is not None else [Variable.num(j) for j in range(idx.min, idx.max+1)]
-  if isinstance(idx, NumNode): return [idx]
-  if isinstance(idx, MulNode): return [x*idx.b for x in expand_node(idx.a)]
-  if isinstance(idx, SumNode): return [Variable.sum(list(it)) for it in itertools.product(*[expand_node(x) for x in idx.nodes])]
-  raise NotImplementedError(idx)
-
-def expand_idxs(idxs:Sequence[Node]) -> Iterator[Tuple[Node, ...]]:
-  for x in itertools.product(*[expand_node(idx) for idx in idxs[::-1]]):
-    yield x[::-1]
-
-class MemOp(NamedTuple):
-  name: str
-  idx: Node
-  local: bool
-  memory_dtype: DType
-
-  # shared
-  valid: Node
-  invalid_value: Union[float, int] = 0.0
-
-class ConstOp(NamedTuple):
-  value: Union[float, int]
-
-  # shared
-  valid: Node
-  invalid_value: Union[float, int] = 0.0
-
-class UOp(NamedTuple):
-  uop: UOps
-  out: Optional[Token]
-  vin: List[Token]
-  arg: Any
-  def __repr__(self): return f"{str(self.uop):20s}: {str(self.out) if self.out is not None else '':25s} {str(self.vin):32s} {self.arg}"
-
-class Linearizer(OptimizedKernel):
-  def get_buffer_name(self, i):
-    if self.bufs[i].__class__ == LocalBuffer: return self.bufs[i].name
-    assert self.bufs[i].realized.__class__ is not RawConst # constants shouldn't be loaded with memops
-    return self.arg_bufs[self.bufs[i].realized]
-
-  def global_load(self, i:int, idxs:Sequence[VariableOrNum], acc=None) -> List[Token]:
-    const = self.bufs[i].realized._buf if isinstance(self.bufs[i].realized, RawConst) else acc
-
-    expanded_nodes = [expand_node(idx) for idx in idxs]
-    _idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
-    upcast_dim = self.get_upcast_dim(i)
-
-    amt = 1
-    if len(upcast_dim) == 1 and len(expanded_nodes[upcast_dim[0]]) in [4,2]:
-      dim, amt = upcast_dim[0], len(expanded_nodes[upcast_dim[0]])
+from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node
+from tinygrad.codegen.kernel import LocalBuffer, Kernel
+from tinygrad.renderer import Program
+
+from tinygrad.codegen.uops import UOps, UOp, UOpGraph
+
+def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
+  local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate((prod(local_dims[:-(maxdim-1)]),) + local_dims[-(maxdim-1):] if len(local_dims) > maxdim else local_dims)]  # noqa: E501
+  if maxdim != 0 and len(local_dims) > maxdim:
+    dd = local_idxs[0]
+    nli = []
+    for s in local_dims[:-(maxdim-1)]:
+      nli.append(dd % s)
+      dd //= s
+    local_idxs = nli + local_idxs[-(maxdim-1):]
+  return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]
+
+def expand_idx(node:Node) -> Union[Variable, NumNode]: return next((v for v in node.vars() if v.expr.startswith("_uidx")), NumNode(0))
+def expand_idxs(nodes:Sequence[Node]) -> Tuple[Union[Variable, NumNode], ...]:
+  eidxs = [expand_idx(node) for node in nodes]
+  return tuple([v if v not in eidxs[:j] else NumNode(0) for j, v in enumerate(eidxs)])  # take only first occurrence of expand variable
+def iter_idxs(idxs:Tuple[Union[Variable, NumNode], ...]) -> Iterator[Tuple[int,...]]:
+  yield from (x[::-1] for x in itertools.product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]]))
+
+def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
+  idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
+  # TODO: bring back the valid removal logic (correct!)
+  if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid)
+  return (idx, idy), valid
+
+# expand a Node into List[Node] that enumerates the underlying Variables from min to max
+# expand increments earlier variables faster than later variables (as specified in the argument)
+@functools.lru_cache(maxsize=None)
+def expand_node(node:Node, idxs:Optional[Tuple[Union[Variable, NumNode], ...]]=None) -> List[Node]:
+  if idxs is None: idxs = (expand_idx(node),)
+  return [node.substitute({k:v for k,v in zip(idxs, (NumNode(x) for x in rep)) if isinstance(k, Variable)}) for rep in iter_idxs(idxs)]
+
+class Linearizer(Kernel):
+  def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op): return UOp.alu(op, a, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
+
+  # NOTE: the consts have to be cached for deduping of downstream uops to work
+  def const(self, b:ConstType, dtype:DType=dtypes.int32) -> UOp:
+    return self.uops.add(UOps.DEFINE_VAR, dtype, (), b.unbind()[0]) if isinstance(b, Variable) else UOp.const(dtype, b)
+
+  def get_reduce_acc(self, reduceop:LazyOp):
+    if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0
+    if reduceop.op is ReduceOps.MAX:
+      if dtypes.is_int(reduceop.dtype): return 0 if dtypes.is_unsigned(reduceop.dtype) else -2**(reduceop.dtype.itemsize*8-1)
+      return -math.inf if dtypes.is_float(reduceop.dtype) else False
+
+  # NOTE: once images are loaded, we uop them as their base float
+  def get_base_dtype(self, dt:DType) -> DType: return dt.base if isinstance(dt, ImageDType) else dt
+
+  render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
+                MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
+                DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV),
+                ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
+                LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.CMPLT),
+    SumNode: lambda self,ops,ctx:
+      functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)),
+    AndNode: lambda self,ops,ctx:
+      functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
+
+  def global_load(self, i:int, idxs:List[Node], acc:Optional[LazyOp]=None, barrier:Optional[UOp]=None, loop_ctx:Tuple[UOp, ...]=()) -> List[UOp]:
+    buf = self.bufs[i]
+    localtype = self.get_base_dtype(buf.dtype if acc is None else acc.dtype)
+    const = buf.val if isinstance(buf, ConstBuffer) else None
+
+    expand_vars = expand_idxs(idxs)
+
+    dim, amt = None, 1
+    # float 4 grouping
+    if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expand_node(idxs[upcast_dim[0]])) in [4,2]:
+      dim, amt = upcast_dim[0], len(float4_expand)
+      g_idx, g_valid = self.sts[i].expr_idxs(idxs[:dim] + [float4_expand[0]] + idxs[dim+1:])
+      # do not use float4 if idx is not aligned
+      if g_idx != (g_idx//amt*amt): dim, amt = None, 1
+    if dim is None:
+      g_idx, g_valid = self.sts[i].expr_idxs(idxs)
+    # todo: multioutput test with different output valids to add if acc is None: g_valid = NumNode(1)
+
+    if amt > 1: localtype = localtype.vec(amt)
+    e_idxs, e_valids = expand_node(g_idx, expand_vars), expand_node(g_valid, expand_vars)
 
     ret = []
-    invalid_value = 0
-[old lines 157-161 removed; content not captured in the source view]
-        idx, valid = self.sts[i].expr_idxs(_idx)
-        localtype = dtypes.float32
-      else:
-        idx, valid = self.sts[i].expr_idxs(_idx)
-        localtype = dtypes.float32
-      this_const, idx, valid = (invalid_value, Variable.num(0), Variable.num(1)) if valid.max == 0 else (const, idx, valid)
-      key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else self.get_buffer_name(i)}{idx.render()}{valid.render()}"
+    invalid_value = 0
+    acc_count = 0
+    for idx, valid, rep_idx in zip(e_idxs, e_valids, iter_idxs(expand_vars)):
+      this_const, idx, valid = (invalid_value, NumNode(0), NumNode(1)) if valid.max == 0 else (const, idx, valid)
+      # todo: when multiple reduceops are supported, clearly disambiguate and test acc load keys are unique for each reduceop
+      key = f"{acc is not None}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}"  # noqa: E501
       if key not in self.load_cache:
-        if
-[old lines 171-173 removed; content not captured in the source view]
+        if acc is not None:
+          self.load_cache[key] = self.uops.add(UOps.DEFINE_ACC, localtype, loop_ctx, (self.get_reduce_acc(acc), i, acc_count))
+          acc_count += 1
+        elif this_const is not None:
+          self.load_cache[key] = self.const(this_const, localtype)
+          if valid.min == 0 and valid.max == 1:
+            valid_rendered = valid.render(self.render_ops, self)
+            self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_rendered, self.load_cache[key], self.const(invalid_value, localtype))
+        elif isinstance(buf.dtype, ImageDType):
+          buf_uop = self.buf_uops[i]
+          assert buf_uop is not None, f"buffer {i} wasn't UOped"
+          image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
+          rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
+          valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, buf.dtype.base.vec(4))) if valid.min == 0 else tuple()
+          self.load_cache[key] = self.uops.add(UOps.LOAD, buf.dtype.base.vec(4),
+                                               (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
+          if localtype == localtype.scalar():
+            idx_small = idx%4
+            res = idx_small.render(self.render_ops, self)
+            out = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
+            for ix in range(idx_small.max, idx_small.min, -1):
+              rvv = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
+              sel = UOp.alu(BinaryOps.CMPLT, res, self.const(ix))
+              out = UOp.alu(TernaryOps.WHERE, sel, rvv, out)
+            self.load_cache[key] = out
+        else:
+          buf_uop = self.buf_uops[i]
+          assert buf_uop is not None, f"buffer {i} wasn't UOped"
+          rendered_idx = idx.render(self.render_ops, self)
+          valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, localtype)) if valid.min == 0 else tuple()
+          self.load_cache[key] = self.uops.add(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
+      ret.append(self.uops.add(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
     return ret
 
-  def global_store(self, i, idxs:List[
-[old lines 177-179 removed; content not captured in the source view]
+  def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]:
+    buf = self.bufs[i]
+    buf_uop = self.buf_uops[i]
+    assert buf_uop is not None, f"buffer {i} wasn't UOped"
 
+    expand_vars = expand_idxs(idxs)
+    _idxs = zip(*[expand_node(idx, expand_vars) for idx in idxs]) if idxs else [tuple()]  # transpose
     store_offset = dict(zip(_idxs, store))
 
     # float4 grouping
-    if len(upcast_dim) == 1 and len(
+    if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expand_node(idxs[upcast_dim[0]])) in [2,4]:
       grouped_store_offset = defaultdict(list)
       for k in store_offset:
-        _idx = k[:upcast_dim[0]] + (
+        _idx = k[:upcast_dim[0]] + (float4_expand[0],) + k[upcast_dim[0]+1:]
         grouped_store_offset[_idx].append(store_offset[k])
       store_offset_new = {}
-      for k,
-        amt = len(
+      for k,grouped in grouped_store_offset.items():
+        amt = len(grouped)
         idx, valid = self.sts[i].expr_idxs(k)
-        assert idx
-
-        if all_same([x.name for x in out_tokens]) and tuple(range(amt)) == tuple(x.offset for x in out_tokens):
-          store_offset_new[k] = Token(out_tokens[0].name, dtypes._float4 if amt == 4 else dtypes._float2)
-        else:
-          store_offset_new[k] = self.uop(UOps.CAST, ssa("alu", dtypes._float4 if amt == 4 else dtypes._float2), out_tokens)
+        assert idx == ((idx//amt)*amt), "float4 stores are always aligned"
+        store_offset_new[k] = self.uops.add(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
       store_offset = store_offset_new
 
-[old lines 201-204 removed; content not captured in the source view]
+    stores = []
+    for _idx, var in store_offset.items():
+      idx, valid = self.sts[i].expr_idxs(_idx)
+      if isinstance(buf.dtype, ImageDType):
+        image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
+        rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), \
+          tuple(x.render(self.render_ops, self) for x in image_idx))
+      else:
+        rendered_idx = idx.render(self.render_ops, self)
+      if valid.min == 1: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var)))
+      else: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
+    return stores
+
+  # render loop
+  def render_loop(self, xx:List[Variable], depth:int) -> Tuple[UOp, ...]:
+    new_loops = {x.expr:self.uops.add(UOps.RANGE, dtypes.int32, (
+      self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
+      self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), arg=(depth,i)) for i,x in enumerate(xx) if not isinstance(x, NumNode) and x.expr is not None}  # noqa: E501
+    self.loop_uops.update(new_loops)
+    return tuple(new_loops.values())
+
+  def render_reduceop(self, reduceop:LazyOp, accs:Dict[LazyOp, List[UOp]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]],
+                      global_idxs, local_idxs, upcast_idxs):
+    # define indicies
+    full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
+    reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)]  # noqa: E501
+    fake_reduce_idxs = [x*0 for x in reduce_idxs]
+
+    def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
+      replace_idxs, thread_idxs, thread_idx = [], [], Variable("_uidx_tc", 0, prod(local_sizes)-1)
+      for s in local_sizes:
+        thread_idxs.append(thread_idx % s)
+        thread_idx //= s
+      for alias in aliases:
+        full_var, full_var_sz = NumNode(0), 1
+        if alias[0] != 0:
+          for i in alias:
+            next_var = local_idxs[i-1] if i > 0 else thread_idxs[-i-1]
+            full_var += next_var * full_var_sz
+            full_var_sz *= next_var.max+1
+        replace_idxs.append(full_var)
+      return replace_idxs
+
+    # compute local aliases - modify idxs if necessary for TC
+    alias_buf_idxs = []
+    for i in self.local_alias:
+      localbuf_idx = self.bufs.index(self.local_alias[i])
+      buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
+      if (tc:=self.tensor_core):
+        min_alias_idx = min(self.local_alias.keys())
+        replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
+        for n in range(len(tc.threads)):
+          buf_idxs[self.global_dims+n] = replace_input_idxs[n]  # replace locals
+        for n in range(tc.num_upcasts()):
+          buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n]  # replace upcasts
+      if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
+      alias_buf_idxs.append((i, localbuf_idx, buf_idxs,))
+
+    # reduce loop
+    loop_ctx = self.render_loop(reduce_idxs, 2)
+
+    # define accumulator - modify idxs if necessary for TC
+    out_buf = -1 if self.group_for_reduces else 0
+    if (tc:=self.tensor_core):
+      replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
+      for n in range(len(tc.threads)):
+        local_idxs[n] = replace_acc_idxs[n]  # replace locals
+      for n in range(len(replace_acc_idxs)-len(tc.threads)):
+        upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n]  # replace upcasts
+    if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs}")
+    accs[reduceop] = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
+
+    # store local aliases
+    locals_to_store = [(localbuf_idx, buf_idxs, self.global_load(i, buf_idxs)) for i, localbuf_idx, buf_idxs in alias_buf_idxs]
+
+    if (tc:=self.tensor_core):
+      # run tensor cores AST
+      wmma_sz = [prod(l) for l in tc.thread_local_sizes]
+      def upcast_strides(buf:int):
+        strides, next = [], 1
+        for (sz, stride, reduce) in self.upcasted_axis(buf)[tc.num_upcasts():]:
+          strides.append((0 if stride == 0 else next, sz))
+          next *= 1 if stride == 0 else sz
+        return strides
+      upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
+      # cast initial accs
+      wmmas = [self.uops.add(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]]))
+               for x in range(0, len(accs[reduceop]), wmma_sz[2])]
+      for iter in [x[::-1] for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]:
+        offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(iter, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
+        ops = (self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
+               self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
+               wmmas[(wmma_idx:=offs[2]//wmma_sz[2])])
+        # TODO: don't need to DEFINE_ACC, pass to WMMA in op3, or PHI accs that are not valid
+        wmmas[wmma_idx] = self.uops.add(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), dev))
+      # phi the last wmmas back to accs
+      accs[reduceop] = [self.uops.add(UOps.PHI, tc.dtype_out, (acc, self.uops.add(UOps.GEP, tc.dtype_out, (wmmas[z//wmma_sz[2]],), z%wmma_sz[2])))
+                        for z, acc in enumerate(accs[reduceop])]
+    else:
+      assert not locals_to_store, "storing locals isn't supported here"
+
+      # load earlybufs
+      loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i,
+        global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs})
+
+      # run early AST (with reduce)
+      self.ast_parse(reduceop, accs, self.acc_offsets(self.full_buf_index), loaded_buffers, reduce_acc=accs[reduceop])
+
+    # end the reduce loop
+    self.load_cache.clear()
+
+    # end the local loop, do the local reduce
+    if self.group_for_reduces:
+      fake_global_idxs = [x*0 for x in global_idxs]
+      stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop])  # store accumulators
+      barrier = self.uops.add(UOps.BARRIER, None, tuple(stores))
+      if self.opts.has_local:
+        fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
+        fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
+        if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(self.render_ops, self)
+        barrier = self.uops.add(UOps.IF, None, (if_cond, barrier))
+
+      # create new late reduce local loops and replace local_idxs that have been used
+      end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)]  # noqa: E501
+      local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]
+
+      # if any group_for_reduce items aren't reduces, upcast them here
+      for j in self.upcast_in_mid_reduce_axes:
+        self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
+        self.upcast()
+        self.group_for_reduces -= 1
+        local_idxs = local_idxs[:-1]
+        end_local_idxs = end_local_idxs[:-1]
+      # regenerate upcast_idxs
+      upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
+
+      # NOTE: this structure is the same as the reduce op above
+
+      # late reduce loop
+      loop_ctx = self.render_loop(end_local_idxs, 3)
+
+      # define late accumulator
+      accs[reduceop] = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
+
+      # load localbufs
+      loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
+
+      # there's no AST here (and there's no shape for the reduce LazyOp)
+      self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)),\
+                     accs, self.acc_offsets(-1), loaded_buffers, reduce_acc=accs[reduceop])
+
+      # end the late reduce loop
+      self.load_cache.clear()
+
+    # all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have
+    # been rewritten with fake end_local_idxs.
+    return (accs, loaded_buffers, fake_reduce_idxs, local_idxs[:self.local_dims] + [NumNode(0) for i in range(self.group_for_reduces)], upcast_idxs)
 
   kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
   def linearize(self):
-
+    # no new opts and we already ran? skip relinearizing
+    if self.applied_opts == self.applied_opts_cache: return self
+
+    # late alias the tensor core buffers
+    if (tc:=self.tensor_core) and (tc_opts:=self.tensor_core_opts):
+      alias_pattern = [0]*(self.global_dims) + [2]*(len(tc.threads)) + [0]*(self.local_dims-len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2)  # noqa: E501
+      for tc_buf in tc_opts.bufs:
+        self.alias_buffer(tc_buf, alias_pattern)
+
+    # save backups
+    sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduces, self.upcasted
+
+    # global uop cache
+    self.saved_exprs: Dict[Tuple, UOp] = dict()
 
     # limit dims if we need to
-    if self.opts.global_max and self.opts.local_max: self.
+    if self.opts.global_max and self.opts.local_max: self.limit_dims_to_max(self.opts.global_max, self.opts.local_max)
 
     # uops
-    self.uops:
-    self.
-    self.
+    self.uops:UOpGraph = UOpGraph()
+    self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs)
+    self.loop_uops: Dict[str, UOp] = {}
 
     # add global buffers
-    for buf
-[old lines 220-227 removed; content not captured in the source view]
-      self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
-      self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
-      self.uop(UOps.DEFINE_LOCAL, None, [], ("temp", self.sts[-1].size()))
-
+    for i,buf in enumerate(self.bufs):
+      if isinstance(buf, MemBuffer):
+        self.buf_uops[i] = self.uops.add(UOps.DEFINE_GLOBAL,
+                                         buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
+                                         (buf.idx, any(buf.idx == x.idx for x in self.outbufs)))
+    # add var vals
+    for i,var in enumerate(self.vars):
+      assert var.expr is not None
+      self.loop_uops[var.expr] = self.uops.add(UOps.DEFINE_VAR, dtypes.int32, (), var)
     # define local buffers
     for lb in self.local_alias.values():
-      self.
-
-    #
-    if
+      self.buf_uops[self.bufs.index(lb)] = self.uops.add(UOps.DEFINE_LOCAL,
+                                                         PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size))
+    # add a local buffer for multistage reduce. # TODO: use local alias
+    if self.group_for_reduces:
+      # TODO: the strides of this can be controlled
+      self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))  # noqa: E501
+      temp_dtype = self.get_base_dtype(cast(LazyOp, self.reduceop).dtype)
+      self.bufs.append(LocalBuffer("temp", self.sts[-1].size, temp_dtype))
+      self.buf_uops.append(self.uops.add(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size)))
 
     # kernel name (before late upcast)
-    self.
-
-
-    # parse AST
-    loaded_buffers = {}
-    acc = []
-
-    # ssa
-    _ssa:DefaultDict[str,int] = defaultdict(int)
-    def ssa(name, ltype=dtypes.float) -> Token:
-      _ssa[name] += 1
-      return Token(f"{name}{_ssa[name]-1}", ltype)
-
-    # global loop
-    global_idxs = [Variable(f"gidx{i}", 0, self.full_shape[i]-1) for i in range(0, self.first_reduce-self.local_dims)]
-    self.uop(UOps.LOOP, None, [], (global_idxs, "global"))
+    self.name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \
+                (f"{len(self.outbufs)}_" if len(self.outbufs) > 1 else "_") + \
+                colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
 
-    #
-
-
+    # name the function something unique
+    Linearizer.kernel_cnt[(function_name := to_function_name(self.name))] += 1
+    suffix = f"{'n'+str(Linearizer.kernel_cnt[function_name]-1)}" if Linearizer.kernel_cnt[function_name] > 1 else ""
+    self.name = self.name+colored(suffix, 'BLACK')
+
+    # define indexes
+    global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
+    local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+self.group_for_reduces], 3 if self.opts.has_local else 0)  # noqa: E501
+    upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
+
+    # set global/local size
+    self.global_size: Optional[List[int]] = None
+    self.local_size: Optional[List[int]] = None
+    if self.dont_use_locals:
+      self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
+      self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})  # noqa: E501
+    elif self.opts.has_local:
+      self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs]
+      self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})  # noqa: E501
+      self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
+    else:
+      self.render_loop(loop_global_idxs+loop_local_idxs, 1)
+    if self.global_size is not None: self.global_size += [1]*(3-len(self.global_size))
+    if self.local_size is not None: self.local_size += [1]*(3-len(self.local_size))
 
-    #
-
-
+    # parse AST
+    loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]] = {}
+    accs: Dict[LazyOp, List[UOp]] = {}
+    self.load_cache: Dict[str, UOp] = {}
 
     # reduce op
-    fake_reduce_idxs = []
-    if self.reduceop is not None:
-
-
-      fake_reduce_idxs = [x*0 for x in reduce_idxs]
-
-      # define accumulator
-      acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
-
-      # reduce loop
-      self.uop(UOps.LOOP, None, [], (reduce_idxs, "reduce"))
-
-      # barrier for fast GEMM
-      if self.use_tensor_cores: self.uop(UOps.BARRIER, None, [], ())
-
-      # compute local aliases
-      locals_to_store = []
-      for i in self.local_alias:
-        strides = self.sts[i].real_strides()
-        extra_locals = [lidx for lidx,st in zip(local_idxs[self.exclude_local_upcast:], strides[len(global_idxs)+self.exclude_local_upcast:self.first_reduce]) if st == 0]
-        this_upcast_idxs: List[Node] = []
-        # TODO: just flipping the order here is likely not generic at all
-        for j,v in list(enumerate(full_upcast_idxs))[::-1] if self.reverse_upcast_dir else list(enumerate(full_upcast_idxs)):
-          if strides[len(global_idxs)+len(local_idxs)+len(reduce_idxs)+j] == 0:
-            if DEBUG >= 4: print(f"upcasting@{j} stride 0")
-            this_upcast_idxs.append(Variable.num(0))
-          elif (elc:=[el for el in extra_locals if v.min == el.min and v.max == el.max]):
-            if DEBUG >= 4: print(f"upcasting@{j} matched stride {elc[0]}")
-            this_upcast_idxs.append(elc[0])
-            extra_locals.remove(elc[0])
-          elif (elc:=[el for el in extra_locals if v.min == el.min and (v.max+1)%(el.max+1) == 0]):
-            tacc = Variable.num(0)
-            rem = v.max+1
-            while len(elc) and rem%(elc[0].max+1) == 0:
-              if DEBUG >= 4: print(f"upcasting@{j} partial stride {rem} {elc[0]} left: {elc[1:]}")
-              rem = rem//(elc[0].max+1)
-              tacc += (elc[0] * rem)
-              extra_locals.remove(elc[0])
-              elc = [el for el in extra_locals if v.min == el.min and rem%(el.max+1) == 0]
-            if DEBUG >= 4 and rem > 1: print(f"failed upcasting@{j} partial stride {rem} extra locals {extra_locals}")
-            this_upcast_idxs.append(tacc + Variable(None, 0, rem-1))
-          else:
-            if DEBUG >= 4: print(f"failed upcasting@{j} stride {v} extra locals {extra_locals}")
-            this_upcast_idxs.append(v)
-        idxs = global_idxs+local_idxs+reduce_idxs+(this_upcast_idxs[::-1] if self.reverse_upcast_dir else this_upcast_idxs)
-        ll = self.global_load(i, idxs)
-        locals_to_store.append((self.bufs.index(self.local_alias[i]), idxs, ll))
-
-      # copy in any global buffers
-      if self.use_tensor_cores:
-        if self.bufs[0].device == "METAL":
-          i = 0
-          for y0,y1 in zip(locals_to_store[1][2][::2], locals_to_store[1][2][1::2]):
-            for x0,x1 in zip(locals_to_store[0][2][::2], locals_to_store[0][2][1::2]):
-              self.uop(UOps.WMMA, None, [x0, x1, y0, y1, acc[i], acc[i+1]], "METAL")
-              i += 2
-        elif self.bufs[0].device == "HIP":
-          i = 0
-          for y in range(0, len(locals_to_store[1][2]), 0x10):
-            for x in range(0, len(locals_to_store[0][2]), 0x10):
-              self.uop(UOps.WMMA, None, acc[i:i+8]+locals_to_store[0][2][x:x+0x10]+locals_to_store[1][2][y:y+0x10], "HIP")
-              i += 8
-      else:
-        if locals_to_store:
-          self.uop(UOps.BARRIER, None, [], ())
-          for i, idxs, ll in locals_to_store: self.global_store(i, idxs, ll, ssa)
-          self.uop(UOps.BARRIER, None, [], ())
-
-      # load earlybufs
-      loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})
-
-      # run early AST (with reduce)
-      self.ast_parse(self.reduceop, [acc[off] for off in self.acc_offsets(self.full_buf_index)], loaded_buffers, ssa, do_reduce=True)
-
-      # end the reduce loop
-      self.uop(UOps.ENDLOOP, None, [], (reduce_idxs, "reduce"))
-      self.load_cache.clear()
-
-      # end the local loop, do the local reduce
-      if self.group_for_reduce:
-        fake_global_idxs = [x*0 for x in global_idxs]
-        self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc, ssa) # store accumulators
-        self.uop(UOps.BARRIER, None, [], ())
-        self.uop(UOps.ENDLOOP, None, [], (local_idxs, "local"))
-
-        # local indexs are over, 0 them out
-        local_idxs = [x*0 for x in local_idxs]
-
-        # if any group_for_reduce items aren't reduces, upcast them here
-        for j in self.upcast_in_mid_reduce_axes:
-          self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
-          self.upcast()
-          self.group_for_reduce.pop()
-          local_idxs = local_idxs[:-1]
-        # regenerate upcast_idxs
-        upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
-
-        # NOTE: this structure is the same as the reduce op above
-
-        # define late accumulator
-        acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
-
-        # late reduce loop
-        end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
-        self.uop(UOps.LOOP, None, [], (end_local_idxs, "late_reduce"))
-
-        # load localbufs
-        loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs+fake_reduce_idxs+upcast_idxs)
-
-        # there's no AST here (and there's no shape for the reduce LazyOp)
-        self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, ssa, do_reduce=True) # type: ignore
-
-        # end the late reduce loop
-        self.uop(UOps.ENDLOOP, None, [], (end_local_idxs, "late_reduce"))
-        self.load_cache.clear()
+    fake_reduce_idxs: List[Variable] = []
+    for reduceop in [self.reduceop] if self.reduceop is not None else []:
+      accs,loaded_buffers,fake_reduce_idxs,local_idxs,upcast_idxs = \
+        self.render_reduceop(reduceop,accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs)
 
     # load latebufs
-    loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs)
+    loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
+      for i,b in enumerate(self.bufs) if b not in self.earlybufs and b.__class__ is not LocalBuffer})
 
-    # run late AST
-
+    # run late AST (without the store)
+    for op in self.ast:
+      val = self.ast_parse(op.src[0], accs, None, loaded_buffers)
+      self.global_store(op.arg.idx, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
 
-    #
-    self.
+    # maybe graph the uops
+    if DEBUG >= 5: self.uops.print()
+    if getenv("GRAPHUOPS"): self.uops.graph()
 
-
-
-      self.uop(UOps.ENDLOOP, None, [], (global_idxs+local_idxs, "global+local"))
-    else:
-      # end the global loop
-      self.uop(UOps.ENDLOOP, None, [], (global_idxs, "global"))
+    # restore backups
+    self.sts, self.group_for_reduces, self.upcasted = sts_backup, gfr_backup, upc_backup
 
-    #
-
-    suffix = f"{'n'+str(Linearizer.kernel_cnt[self.function_name]-1)}" if Linearizer.kernel_cnt[self.function_name] > 1 else ""
-    self.function_name, self.display_name = self.function_name+suffix, self.display_name+colored(suffix, 'BLACK')
+    # set cache and return
+    self.applied_opts_cache = self.applied_opts[:]
     return self
 
-[old lines 404-416 removed; content not captured in the source view]
-    if
-    if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers, ssa) # cast isn't an ALU op
-    if x.op in ReduceOps and not do_reduce: return acc
-    # MULACC fusion. TODO: this is copied from Interpreted
-    if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == BinaryOps.MUL:
-      x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
-    if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL:
-      x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg)
-    if x.op in {BinaryOps.ADD, BinaryOps.MUL}:
-      # Reorder sources to put constants first so get_grouped_maybe_float4 can fold the op
-      srcs = sorted(x.src, key=lambda x: (x.realized.__class__ != RawConst) if x.__class__ == LazyBuffer else 0)
-      x.src = tuple(srcs)
-    values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
-    ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
+  def ast_parse(self, x:LazyOp, accs:Dict[LazyOp, List[UOp]], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], reduce_acc:Optional[List[UOp]]=None, cache=None) -> List[UOp]:  # noqa: E501
+    if cache is None: cache = {}
+    if x in cache: return cache[x]
+    if x.op in BufferOps: return loaded_buffers[x.arg]
+    if x.op in [UnaryOps.CAST, UnaryOps.BITCAST]:
+      return [self.uops.add(UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST,
+                            self.get_base_dtype(x.arg), (u,)) for u in self.ast_parse(x.src[0], accs, offs, loaded_buffers)]
+    if x.op in ReduceOps and reduce_acc is None:
+      assert offs is None, "not available if we aren't doing reduce"
+      return accs[x]
+
+    values = [self.ast_parse(v, accs, offs, loaded_buffers, cache=cache) for v in x.src]
+    ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}
     if x.op in ops:
-[old lines 431-436 removed; content not captured in the source view]
-      for
-[old lines 438-440 removed; content not captured in the source view]
+      assert reduce_acc is not None
+      ret: List[UOp] = []
+      acc, input_acc = reduce_acc, reduce_acc[:]
+      for val, off in zip(zip(*values), cast(List[int], offs)):
+        acc[off] = UOp.alu(ops[cast(ReduceOps, x.op)], *(val+(acc[off], )))
+        ret.append(acc[off])
+      for off in range(len(acc)):
+        if input_acc[off] != acc[off]:
+          acc[off] = self.uops.add(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]))
+    else: ret = [UOp.alu(x.op, *vin) for vin in zip(*values)]
+    cache[x] = ret
+    return ret
+
+  def to_program(self) -> Program:
+    self.linearize()
+    info = get_lazyop_info(self.ast[0])
+    src = self.opts.render(to_function_name(self.name), self.uops)
+    ops, mem = self.uops.flops_mem()
+    run_count = prod((self.global_size if self.global_size else []) + (self.local_size if self.local_size else []))
+    # NOTE: we use min here to ignore the indexing FLOPS
+    return Program(self.name, src, self.opts.device, self.global_size, self.local_size,
+                   self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
```