tinygrad 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/kernel.py +230 -190
- tinygrad/codegen/linearizer.py +278 -384
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +132 -275
- tinygrad/dtype.py +53 -37
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +28 -14
- tinygrad/helpers.py +72 -43
- tinygrad/lazy.py +141 -240
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +179 -8
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +86 -17
- tinygrad/ops.py +70 -44
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +299 -206
- tinygrad/renderer/llvmir.py +118 -123
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +130 -38
- tinygrad/runtime/ops_disk.py +45 -42
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +42 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +41 -105
- tinygrad/shape/symbolic.py +98 -95
- tinygrad/shape/view.py +137 -35
- tinygrad/tensor.py +2367 -442
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
tinygrad/codegen/linearizer.py
CHANGED
@@ -1,49 +1,42 @@
 from __future__ import annotations
-from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union,
+from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence
 import itertools, math, functools
 from collections import defaultdict
-from enum import Enum, auto
-from dataclasses import dataclass
 
-from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
-from tinygrad.helpers import colored, DEBUG, prod, getenv,
+from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
+from tinygrad.helpers import colored, DEBUG, prod, getenv, to_function_name
 from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
+from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node
 from tinygrad.codegen.kernel import LocalBuffer, Kernel
-from tinygrad.
-
-
-class UOps(Enum):
-  LOOP = auto(); IF = auto(); END = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702
-  DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702
-  LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702
-  ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
-
-@dataclass(eq=False)
-class UOp:
-  uop: UOps
-  dtype: Optional[DType]
-  vin: Tuple[UOp, ...]
-  arg: Any
-  def __repr__(self):
-    return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
+from tinygrad.renderer import Program
+
+from tinygrad.codegen.uops import UOps, UOp, UOpGraph
 
 def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
-  local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate(local_dims[
+  local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate((prod(local_dims[:-(maxdim-1)]),) + local_dims[-(maxdim-1):] if len(local_dims) > maxdim else local_dims)] # noqa: E501
   if maxdim != 0 and len(local_dims) > maxdim:
-    dd = local_idxs[
+    dd = local_idxs[0]
     nli = []
-    for s in local_dims[maxdim-1
+    for s in local_dims[:-(maxdim-1)]:
       nli.append(dd % s)
       dd //= s
-    local_idxs = local_idxs[
+    local_idxs = nli + local_idxs[-(maxdim-1):]
   return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]
 
-def expand_idx(node:Node) -> Union[Variable, NumNode]: return next((v for v in node.vars() if v.expr
+def expand_idx(node:Node) -> Union[Variable, NumNode]: return next((v for v in node.vars() if v.expr.startswith("_uidx")), NumNode(0))
+def expand_idxs(nodes:Sequence[Node]) -> Tuple[Union[Variable, NumNode], ...]:
+  eidxs = [expand_idx(node) for node in nodes]
+  return tuple([v if v not in eidxs[:j] else NumNode(0) for j, v in enumerate(eidxs)]) # take only first occurrence of expand variable
 def iter_idxs(idxs:Tuple[Union[Variable, NumNode], ...]) -> Iterator[Tuple[int,...]]:
   yield from (x[::-1] for x in itertools.product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]]))
 
+def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
+  idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
+  # TODO: bring back the valid removal logic (correct!)
+  if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid)
+  return (idx, idy), valid
+
 # expand a Node into List[Node] that enumerates the underlying Variables from min to max
 # expand increments earlier variables faster than later variables (as specified in the argument)
 @functools.lru_cache(maxsize=None)
@@ -52,94 +45,92 @@ def expand_node(node:Node, idxs:Optional[Tuple[Union[Variable, NumNode], ...]]=N
   return [node.substitute({k:v for k,v in zip(idxs, (NumNode(x) for x in rep)) if isinstance(k, Variable)}) for rep in iter_idxs(idxs)]
 
 class Linearizer(Kernel):
-  def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op,
-    render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
-    return self.uop(UOps.ALU, dtype, (a, render_b), op)
+  def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op): return UOp.alu(op, a, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
 
   # NOTE: the consts have to be cached for deduping of downstream uops to work
-  def const(self, b:
-    return self.
-
-  def cast(self, val: UOp, dtype) -> UOp: return self.uop(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val
+  def const(self, b:ConstType, dtype:DType=dtypes.int32) -> UOp:
+    return self.uops.add(UOps.DEFINE_VAR, dtype, (), b.unbind()[0]) if isinstance(b, Variable) else UOp.const(dtype, b)
 
   def get_reduce_acc(self, reduceop:LazyOp):
-
-    if reduceop.op
-
-
-    return -math.inf if dtypes.is_float(dtype) else False
+    if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0
+    if reduceop.op is ReduceOps.MAX:
+      if dtypes.is_int(reduceop.dtype): return 0 if dtypes.is_unsigned(reduceop.dtype) else -2**(reduceop.dtype.itemsize*8-1)
+      return -math.inf if dtypes.is_float(reduceop.dtype) else False
 
   # NOTE: once images are loaded, we uop them as their base float
-  def get_base_dtype(self, dt:DType): return dt.base if isinstance(dt, ImageDType) else dt
+  def get_base_dtype(self, dt:DType) -> DType: return dt.base if isinstance(dt, ImageDType) else dt
 
   render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
                 MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
                 DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV),
                 ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
-                LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.CMPLT
+                LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.CMPLT),
     SumNode: lambda self,ops,ctx:
       functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)),
     AndNode: lambda self,ops,ctx:
-      functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL
+      functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
 
-  def global_load(self, i:int, idxs:
+  def global_load(self, i:int, idxs:List[Node], acc:Optional[LazyOp]=None, barrier:Optional[UOp]=None, loop_ctx:Tuple[UOp, ...]=()) -> List[UOp]:
     buf = self.bufs[i]
-    localtype = self.get_base_dtype(buf.dtype if acc is None else
-    const = buf.val if isinstance(buf, ConstBuffer) else
+    localtype = self.get_base_dtype(buf.dtype if acc is None else acc.dtype)
+    const = buf.val if isinstance(buf, ConstBuffer) else None
 
-
-    expand_vars = tuple([rename_var(expand_idx(idx), f"_uidx{j}") for j, idx in enumerate(idxs)])
-    fake_idxs = [idx.substitute({eidx: ev}) if isinstance(eidx:=expand_idx(idx), Variable) else idx for idx, ev in zip(idxs, expand_vars)]
+    expand_vars = expand_idxs(idxs)
 
     dim, amt = None, 1
     # float 4 grouping
     if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expand_node(idxs[upcast_dim[0]])) in [4,2]:
       dim, amt = upcast_dim[0], len(float4_expand)
-      g_idx, g_valid = self.sts[i].expr_idxs(
+      g_idx, g_valid = self.sts[i].expr_idxs(idxs[:dim] + [float4_expand[0]] + idxs[dim+1:])
       # do not use float4 if idx is not aligned
       if g_idx != (g_idx//amt*amt): dim, amt = None, 1
     if dim is None:
-      g_idx, g_valid = self.sts[i].expr_idxs(
+      g_idx, g_valid = self.sts[i].expr_idxs(idxs)
+    # todo: multioutput test with different output valids to add if acc is None: g_valid = NumNode(1)
 
     if amt > 1: localtype = localtype.vec(amt)
     e_idxs, e_valids = expand_node(g_idx, expand_vars), expand_node(g_valid, expand_vars)
 
     ret = []
-    invalid_value = 0
+    invalid_value = 0
+    acc_count = 0
     for idx, valid, rep_idx in zip(e_idxs, e_valids, iter_idxs(expand_vars)):
       this_const, idx, valid = (invalid_value, NumNode(0), NumNode(1)) if valid.max == 0 else (const, idx, valid)
-
+      # todo: when multiple reduceops are supported, clearly disambiguate and test acc load keys are unique for each reduceop
+      key = f"{acc is not None}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
       if key not in self.load_cache:
         if acc is not None:
-          self.load_cache[key] = self.
+          self.load_cache[key] = self.uops.add(UOps.DEFINE_ACC, localtype, loop_ctx, (self.get_reduce_acc(acc), i, acc_count))
+          acc_count += 1
         elif this_const is not None:
           self.load_cache[key] = self.const(this_const, localtype)
           if valid.min == 0 and valid.max == 1:
             valid_rendered = valid.render(self.render_ops, self)
-            self.load_cache[key] =
+            self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_rendered, self.load_cache[key], self.const(invalid_value, localtype))
         elif isinstance(buf.dtype, ImageDType):
           buf_uop = self.buf_uops[i]
           assert buf_uop is not None, f"buffer {i} wasn't UOped"
           image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
-          rendered_idx = self.
+          rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
           valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, buf.dtype.base.vec(4))) if valid.min == 0 else tuple()
-          self.load_cache[key] = self.
+          self.load_cache[key] = self.uops.add(UOps.LOAD, buf.dtype.base.vec(4),
+            (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
           if localtype == localtype.scalar():
             idx_small = idx%4
             res = idx_small.render(self.render_ops, self)
-            out = self.
+            out = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
             for ix in range(idx_small.max, idx_small.min, -1):
-              rvv = self.
-              sel =
-              out =
+              rvv = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
+              sel = UOp.alu(BinaryOps.CMPLT, res, self.const(ix))
+              out = UOp.alu(TernaryOps.WHERE, sel, rvv, out)
            self.load_cache[key] = out
         else:
           buf_uop = self.buf_uops[i]
           assert buf_uop is not None, f"buffer {i} wasn't UOped"
           rendered_idx = idx.render(self.render_ops, self)
           valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, localtype)) if valid.min == 0 else tuple()
-          self.load_cache[key] = self.
-      ret.append(self.
+          self.load_cache[key] = self.uops.add(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
+      ret.append(self.uops.add(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
     return ret
 
   def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]:
@@ -147,12 +138,12 @@ class Linearizer(Kernel):
     buf_uop = self.buf_uops[i]
     assert buf_uop is not None, f"buffer {i} wasn't UOped"
 
-
-    _idxs = [
+    expand_vars = expand_idxs(idxs)
+    _idxs = zip(*[expand_node(idx, expand_vars) for idx in idxs]) if idxs else [tuple()] # transpose
     store_offset = dict(zip(_idxs, store))
 
     # float4 grouping
-    if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand :=
+    if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expand_node(idxs[upcast_dim[0]])) in [2,4]:
       grouped_store_offset = defaultdict(list)
       for k in store_offset:
         _idx = k[:upcast_dim[0]] + (float4_expand[0],) + k[upcast_dim[0]+1:]
@@ -162,28 +153,180 @@ class Linearizer(Kernel):
         amt = len(grouped)
         idx, valid = self.sts[i].expr_idxs(k)
         assert idx == ((idx//amt)*amt), "float4 stores are always aligned"
-        store_offset_new[k] = self.
+        store_offset_new[k] = self.uops.add(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
       store_offset = store_offset_new
 
     stores = []
-    for
-      idx, valid = self.sts[i].expr_idxs(
+    for _idx, var in store_offset.items():
+      idx, valid = self.sts[i].expr_idxs(_idx)
       if isinstance(buf.dtype, ImageDType):
-
-        rendered_idx = self.
+        image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
+        rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), \
+          tuple(x.render(self.render_ops, self) for x in image_idx))
       else:
         rendered_idx = idx.render(self.render_ops, self)
-      if valid.min == 1: stores.append(self.
-      else: stores.append(self.
+      if valid.min == 1: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var)))
+      else: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
     return stores
 
+  # render loop
+  def render_loop(self, xx:List[Variable], depth:int) -> Tuple[UOp, ...]:
+    new_loops = {x.expr:self.uops.add(UOps.RANGE, dtypes.int32, (
+      self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
+      self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), arg=(depth,i)) for i,x in enumerate(xx) if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
+    self.loop_uops.update(new_loops)
+    return tuple(new_loops.values())
+
+  def render_reduceop(self, reduceop:LazyOp, accs:Dict[LazyOp, List[UOp]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]],
+                      global_idxs, local_idxs, upcast_idxs):
+    # define indicies
+    full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
+    reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
+    fake_reduce_idxs = [x*0 for x in reduce_idxs]
+
+    def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
+      replace_idxs, thread_idxs, thread_idx = [], [], Variable("_uidx_tc", 0, prod(local_sizes)-1)
+      for s in local_sizes:
+        thread_idxs.append(thread_idx % s)
+        thread_idx //= s
+      for alias in aliases:
+        full_var, full_var_sz = NumNode(0), 1
+        if alias[0] != 0:
+          for i in alias:
+            next_var = local_idxs[i-1] if i > 0 else thread_idxs[-i-1]
+            full_var += next_var * full_var_sz
+            full_var_sz *= next_var.max+1
+        replace_idxs.append(full_var)
+      return replace_idxs
+
+    # compute local aliases - modify idxs if necessary for TC
+    alias_buf_idxs = []
+    for i in self.local_alias:
+      localbuf_idx = self.bufs.index(self.local_alias[i])
+      buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
+      if (tc:=self.tensor_core):
+        min_alias_idx = min(self.local_alias.keys())
+        replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
+        for n in range(len(tc.threads)):
+          buf_idxs[self.global_dims+n] = replace_input_idxs[n] # replace locals
+        for n in range(tc.num_upcasts()):
+          buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
+      if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
+      alias_buf_idxs.append((i, localbuf_idx, buf_idxs,))
+
+    # reduce loop
+    loop_ctx = self.render_loop(reduce_idxs, 2)
+
+    # define accumulator - modify idxs if necessary for TC
+    out_buf = -1 if self.group_for_reduces else 0
+    if (tc:=self.tensor_core):
+      replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
+      for n in range(len(tc.threads)):
+        local_idxs[n] = replace_acc_idxs[n] # replace locals
+      for n in range(len(replace_acc_idxs)-len(tc.threads)):
+        upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
+      if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs}")
+    accs[reduceop] = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
+
+    # store local aliases
+    locals_to_store = [(localbuf_idx, buf_idxs, self.global_load(i, buf_idxs)) for i, localbuf_idx, buf_idxs in alias_buf_idxs]
+
+    if (tc:=self.tensor_core):
+      # run tensor cores AST
+      wmma_sz = [prod(l) for l in tc.thread_local_sizes]
+      def upcast_strides(buf:int):
+        strides, next = [], 1
+        for (sz, stride, reduce) in self.upcasted_axis(buf)[tc.num_upcasts():]:
+          strides.append((0 if stride == 0 else next, sz))
+          next *= 1 if stride == 0 else sz
+        return strides
+      upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
+      # cast initial accs
+      wmmas = [self.uops.add(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]]))
+               for x in range(0, len(accs[reduceop]), wmma_sz[2])]
+      for iter in [x[::-1] for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]:
+        offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(iter, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
+        ops = (self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
+               self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
+               wmmas[(wmma_idx:=offs[2]//wmma_sz[2])])
+        # TODO: don't need to DEFINE_ACC, pass to WMMA in op3, or PHI accs that are not valid
+        wmmas[wmma_idx] = self.uops.add(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), dev))
+      # phi the last wmmas back to accs
+      accs[reduceop] = [self.uops.add(UOps.PHI, tc.dtype_out, (acc, self.uops.add(UOps.GEP, tc.dtype_out, (wmmas[z//wmma_sz[2]],), z%wmma_sz[2])))
+                        for z, acc in enumerate(accs[reduceop])]
+    else:
+      assert not locals_to_store, "storing locals isn't supported here"
+
+      # load earlybufs
+      loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i,
+        global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs})
+
+      # run early AST (with reduce)
+      self.ast_parse(reduceop, accs, self.acc_offsets(self.full_buf_index), loaded_buffers, reduce_acc=accs[reduceop])
+
+    # end the reduce loop
+    self.load_cache.clear()
+
+    # end the local loop, do the local reduce
+    if self.group_for_reduces:
+      fake_global_idxs = [x*0 for x in global_idxs]
+      stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop]) # store accumulators
+      barrier = self.uops.add(UOps.BARRIER, None, tuple(stores))
+      if self.opts.has_local:
+        fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
+        fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
+        if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(self.render_ops, self)
+        barrier = self.uops.add(UOps.IF, None, (if_cond, barrier))
+
+      # create new late reduce local loops and replace local_idxs that have been used
+      end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)] # noqa: E501
+      local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]
+
+      # if any group_for_reduce items aren't reduces, upcast them here
+      for j in self.upcast_in_mid_reduce_axes:
+        self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
+        self.upcast()
+        self.group_for_reduces -= 1
+        local_idxs = local_idxs[:-1]
+        end_local_idxs = end_local_idxs[:-1]
+      # regenerate upcast_idxs
+      upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
+
+      # NOTE: this structure is the same as the reduce op above
+
+      # late reduce loop
+      loop_ctx = self.render_loop(end_local_idxs, 3)
+
+      # define late accumulator
+      accs[reduceop] = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
+
+      # load localbufs
+      loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
+
+      # there's no AST here (and there's no shape for the reduce LazyOp)
+      self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)),\
+        accs, self.acc_offsets(-1), loaded_buffers, reduce_acc=accs[reduceop])
+
+      # end the late reduce loop
+      self.load_cache.clear()
+
+    # all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have
+    # been rewritten with fake end_local_idxs.
+    return (accs, loaded_buffers, fake_reduce_idxs, local_idxs[:self.local_dims] + [NumNode(0) for i in range(self.group_for_reduces)], upcast_idxs)
+
   kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
   def linearize(self):
     # no new opts and we already ran? skip relinearizing
     if self.applied_opts == self.applied_opts_cache: return self
 
+    # late alias the tensor core buffers
+    if (tc:=self.tensor_core) and (tc_opts:=self.tensor_core_opts):
+      alias_pattern = [0]*(self.global_dims) + [2]*(len(tc.threads)) + [0]*(self.local_dims-len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2) # noqa: E501
+      for tc_buf in tc_opts.bufs:
+        self.alias_buffer(tc_buf, alias_pattern)
+
     # save backups
-    sts_backup, gfr_backup, upc_backup = self.sts[:], self.
+    sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduces, self.upcasted
 
     # global uop cache
     self.saved_exprs: Dict[Tuple, UOp] = dict()
@@ -192,31 +335,36 @@ class Linearizer(Kernel):
     if self.opts.global_max and self.opts.local_max: self.limit_dims_to_max(self.opts.global_max, self.opts.local_max)
 
     # uops
-    self.uops:
+    self.uops:UOpGraph = UOpGraph()
     self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs)
     self.loop_uops: Dict[str, UOp] = {}
 
     # add global buffers
     for i,buf in enumerate(self.bufs):
       if isinstance(buf, MemBuffer):
-        self.buf_uops[i] = self.
+        self.buf_uops[i] = self.uops.add(UOps.DEFINE_GLOBAL,
+          buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
+          (buf.idx, any(buf.idx == x.idx for x in self.outbufs)))
     # add var vals
-    for var in self.
+    for i,var in enumerate(self.vars):
       assert var.expr is not None
-      self.loop_uops[var.expr] = self.
+      self.loop_uops[var.expr] = self.uops.add(UOps.DEFINE_VAR, dtypes.int32, (), var)
     # define local buffers
     for lb in self.local_alias.values():
-      self.buf_uops[self.bufs.index(lb)] = self.
+      self.buf_uops[self.bufs.index(lb)] = self.uops.add(UOps.DEFINE_LOCAL,
+        PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size))
     # add a local buffer for multistage reduce. # TODO: use local alias
-    if self.
+    if self.group_for_reduces:
       # TODO: the strides of this can be controlled
-      self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+
-      temp_dtype = self.get_base_dtype(
+      self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501
+      temp_dtype = self.get_base_dtype(cast(LazyOp, self.reduceop).dtype)
       self.bufs.append(LocalBuffer("temp", self.sts[-1].size, temp_dtype))
-      self.buf_uops.append(self.
+      self.buf_uops.append(self.uops.add(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size)))
 
     # kernel name (before late upcast)
-    self.name = ("
+    self.name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \
+                (f"{len(self.outbufs)}_" if len(self.outbufs) > 1 else "_") + \
+                colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
 
     # name the function something unique
     Linearizer.kernel_cnt[(function_name := to_function_name(self.name))] += 1
@@ -225,342 +373,88 @@ class Linearizer(Kernel):
 
     # define indexes
     global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
-    local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+
-
-    upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
-
-    # global and local loops
-    def render_loop(xx:List[Variable]) -> Tuple[UOp, ...]:
-      new_loops = {x.expr:self.uop(UOps.LOOP, dtypes.int32, (
-        self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
-        self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
-      self.loop_uops.update(new_loops)
-      return tuple(new_loops.values())
+    local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+self.group_for_reduces], 3 if self.opts.has_local else 0) # noqa: E501
+    upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
 
     # set global/local size
     self.global_size: Optional[List[int]] = None
     self.local_size: Optional[List[int]] = None
     if self.dont_use_locals:
       self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
-      self.loop_uops.update({x.expr:self.
+      self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
     elif self.opts.has_local:
-      self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs]
-      self.loop_uops.update({x.expr:self.
-      self.loop_uops.update({x.expr:self.
+      self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs]
+      self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
+      self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
     else:
-      render_loop(loop_global_idxs+loop_local_idxs)
+      self.render_loop(loop_global_idxs+loop_local_idxs, 1)
+    if self.global_size is not None: self.global_size += [1]*(3-len(self.global_size))
+    if self.local_size is not None: self.local_size += [1]*(3-len(self.local_size))
 
     # parse AST
-    loaded_buffers = {}
-
+    loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]] = {}
+    accs: Dict[LazyOp, List[UOp]] = {}
     self.load_cache: Dict[str, UOp] = {}
 
     # reduce op
     fake_reduce_idxs: List[Variable] = []
-    if self.reduceop is not None:
-
-
-      fake_reduce_idxs = [x*0 for x in reduce_idxs]
-
-      # define accumulator
-      acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop))
-
-      if self.tensor_core:
-        def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
-          replace_idxs = []
-          for alias in aliases:
-            full_var, full_var_sz = NumNode(0), 1
-            if alias[0] != 0:
-              for i in alias:
-                next_var = local_idxs[-i] if i > 0 else Variable(None, 0, local_size-1)
-                full_var += next_var * full_var_sz
-                full_var_sz *= next_var.max+1
-            replace_idxs.append(full_var)
-          return replace_idxs
-        replace_acc_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[2], self.tensor_core.thread_local_aliases[2])
-        for n in range(len(self.tensor_core.threads)):
-          local_idxs[self.local_dims-len(self.tensor_core.threads)+n] = replace_acc_idxs[n] # replace locals
-        for n in range(len(replace_acc_idxs)-len(self.tensor_core.threads)):
-          upcast_idxs[n] = replace_acc_idxs[len(self.tensor_core.threads)+n] # replace upcasts
-
-      # reduce loop
-      loop_ctx = render_loop(reduce_idxs)
-
-      # barrier for fast GEMM
-      if self.tensor_core: self.uop(UOps.BARRIER, None, (), cachable=False)
-
-      # compute local aliases
-      locals_to_store = []
-      for i in self.local_alias:
-        localbuf_idx = self.bufs.index(self.local_alias[i])
-        buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
-        if self.tensor_core:
-          min_alias_idx = min(self.local_alias.keys())
-          replace_input_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[i-min_alias_idx], self.tensor_core.thread_local_aliases[i-min_alias_idx]) # noqa: E501
-          for n in range(len(self.tensor_core.threads)):
-            buf_idxs[self.first_reduce-len(self.tensor_core.threads)+n] = replace_input_idxs[n] # replace locals
-          for n in range(len(replace_input_idxs)-len(self.tensor_core.threads)):
-            buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(self.tensor_core.threads)+n] # replace upcasts
-        if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: idxs=", buf_idxs)
-        ll = self.global_load(i, buf_idxs)
-        locals_to_store.append((localbuf_idx, buf_idxs, ll))
-
-      # copy in any global buffers
-      if self.tensor_core:
-        wmma_sz = self.tensor_core.thread_local_sizes
-        # calculate the number of local accumulator reduces and render WMMAs: this is bad... this needs to come from someplace else
-        nx, ny, nacc = (len(locals_to_store[0][2])//wmma_sz[0]), (len(locals_to_store[1][2])//wmma_sz[1]), (len(acc)//wmma_sz[2])
-        acc_reds = math.isqrt((nx*ny)//nacc)
-        i, bx, by = 0, nx//acc_reds, ny//acc_reds
-        for y in range(by):
-          for x in range(bx):
-            for j in range(acc_reds):
-              op1, op2, op3 = locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]], locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]], acc[i:i+wmma_sz[2]] # noqa: E501
-              if self.opts.device != "HIP":
-                ops = tuple(op1+op2+op3)
-              else:
-                ops = (self.uop(UOps.CAST, dtypes.half.vec(16), tuple(op1)),
-                       self.uop(UOps.CAST, dtypes.half.vec(16), tuple(op2)),
-                       self.uop(UOps.CAST, dtypes.float.vec(8), tuple(op3)))
-              ret = self.uop(UOps.WMMA, dtypes.float.vec(2) if wmma_sz[2] == 2 else dtypes.float.vec(8), ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,)) # noqa: E501
-              for z in range(cast(DType, ret.dtype).sz):
-                acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + loop_ctx)
-              i += wmma_sz[2]
-      else:
-        if locals_to_store:
-          self.uop(UOps.BARRIER, None, (), cachable=False)
-          for i, idxs, ll in locals_to_store: self.global_store(i, idxs, ll)
-          self.uop(UOps.BARRIER, None, (), cachable=False)
-
-      # load earlybufs
-      loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs}) # noqa: E501
-
-      # run early AST (with reduce)
-      self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)
-
-      # end the reduce loop
-      self.load_cache.clear()
-
-      # end the local loop, do the local reduce
-      if self.group_for_reduce:
-        fake_global_idxs = [x*0 for x in global_idxs]
-        stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
-        barrier = self.uop(UOps.BARRIER, None, tuple(stores), cachable=False)
-        if self.opts.has_local:
-          fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
-          fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
-          if_cond: UOp = (self.sts[-1].expr_idxs(fake_idxs)[0]<1).render(self.render_ops, self)
-          barrier = self.uop(UOps.IF, None, (if_cond, barrier), cachable=False)
-
-        # create new late reduce local loops and replace local_idxs that have been used
-        end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))] # noqa: E501
-        local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]
-
-        # if any group_for_reduce items aren't reduces, upcast them here
-        for j in self.upcast_in_mid_reduce_axes:
-          self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
-          self.upcast()
-          self.group_for_reduce.pop()
-          local_idxs = local_idxs[:-1]
-          end_local_idxs = end_local_idxs[:-1]
-        # regenerate upcast_idxs
-        upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
-
-        # NOTE: this structure is the same as the reduce op above
-
-        # define late accumulator
-        acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop))
-
-        # late reduce loop
-        loop_ctx = render_loop(end_local_idxs)
-
-        # load localbufs
-        loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
-
-        # there's no AST here (and there's no shape for the reduce LazyOp)
-        self.ast_parse(LazyOp(self.reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # noqa: E501
-
-        # end the late reduce loop
-        self.load_cache.clear()
+    for reduceop in [self.reduceop] if self.reduceop is not None else []:
+      accs,loaded_buffers,fake_reduce_idxs,local_idxs,upcast_idxs = \
+        self.render_reduceop(reduceop,accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs)
 
     # load latebufs
-    loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs)
+    loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
+      for i,b in enumerate(self.bufs) if b not in self.earlybufs and b.__class__ is not LocalBuffer})
 
     # run late AST (without the store)
-
-
-
-    self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
-
-    # get PHI node loop scope, link anything using a DEFINE_ACC to the loop as a "parent"
-    acc_scope: DefaultDict[UOp, List[UOp]] = defaultdict(list)
-    for u in self.uops:
-      if u.uop == UOps.PHI:
-        acc_scope[u.vin[0]] += u.vin[2:]
-
-    # graph helper functions
-    @functools.lru_cache(None)
-    def get_recursive_parents(x:UOp, with_phi=False) -> Set[UOp]:
-      return set.union(set(x.vin), *[get_recursive_parents(p, with_phi) for p in x.vin], set(acc_scope[x]) if with_phi else set())
-
-    def get_recursive_children(x:UOp) -> Set[UOp]:
-      deps = set([x])
-      ssize = 0
-      while ssize != len(deps):
-        ssize = len(deps)
-        for u in self.uops:
-          if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
-            deps.add(u)
-      return deps
-
-    def replace_op(old:UOp, new:UOp):
-      for u in self.uops:
-        u.vin = tuple(new if x is old else x for x in u.vin)
-      self.uops.remove(old)
-
-    # fix loop scope, push uops upward out of loop if it does not depend on the loop
-    loop_stack: List[List[UOp]] = [[]]
-    for u in self.uops:
-      if not loop_stack[-1]: loop_stack[-1].append(u)
-      elif u.uop == UOps.LOOP: loop_stack.append([u])
-      elif u.uop not in [UOps.CONST, UOps.ALU, UOps.CAST, UOps.LOAD]: loop_stack[-1].append(u)
-      else:
-        parents = get_recursive_parents(u, with_phi=True)
-        # don't push any local buffer because there might have STORE and BARRIER (not considered as parent) between DEFINE_LOCAL and here
-        if any(u.uop == UOps.DEFINE_LOCAL for u in parents): loop_stack[-1].append(u)
-        else:
-          for i in reversed(range(len(loop_stack))):
-            # check backwards and put the uop in the first encounter with some dependency
-            if any(x in parents for x in loop_stack[i]) or i == 0:
-              loop_stack[i].append(u)
-              break
-    self.uops = flatten(loop_stack)
-
-    # uops optimization
-    changed_something = True
-    while changed_something:
-      changed_something = False
-      for u in self.uops:
-        if u.uop == UOps.PHI and len(u.vin) == 3:
-          # if the parents of the PHI node don't have the LOOP in their parents, it can be folded
-          # TODO: ADD becomes a MUL, MAX can just become nothing
-          if all(x.uop != UOps.LOOP for x in get_recursive_parents(UOp(u.uop, u.dtype, u.vin[0:2], u.arg))) and u.vin[1].arg == BinaryOps.ADD:
-            if DEBUG >= 4: print(f"removing PHI node {u}")
-            del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)]
-            # NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype
-            loop_len = self.uop(UOps.ALU, u.vin[2].vin[1].dtype, (u.vin[2].vin[1], u.vin[2].vin[0]), BinaryOps.SUB, insert_before=self.uops.index(u))
-            if loop_len.dtype != u.dtype: loop_len = self.uop(UOps.CAST, u.dtype, (loop_len,), insert_before=self.uops.index(u))
-            replace_op(u, self.uop(UOps.ALU, u.dtype, (u.vin[1], loop_len,), BinaryOps.MUL, insert_before=self.uops.index(u)))
-            changed_something = True
-
-    # (recursively) remove childless uops
-    # NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
-    UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_GLOBAL}
-    while 1:
-      has_child: Set[UOp] = set()
-      for ru in self.uops:
-        for vu in ru.vin:
-          has_child.add(vu)
-      nu: List[UOp] = [x for x in self.uops if x in has_child or x.uop in UOPS_W_SIDE_EFFECTS]
-      if len(nu) == len(self.uops): break
-      if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
-      self.uops = nu
-      del nu
-
-    # add UOps.END
-    for u in self.uops:
-      if u.uop == UOps.LOOP:
-        # add END of loops after the last thing that (recursively) depends on them
-        self.uop(UOps.END, None, (u,), cachable=False, insert_before=self.uops.index(sorted(list(get_recursive_children(u)), key=self.uops.index)[-1])+1) # noqa: E501
-      elif u.uop == UOps.IF:
-        # END any if statements at the end of the uops
-        self.uop(UOps.END, None, (u,), cachable=False)
+    for op in self.ast:
+      val = self.ast_parse(op.src[0], accs, None, loaded_buffers)
+      self.global_store(op.arg.idx, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
 
     # maybe graph the uops
-    if DEBUG >= 5:
-
-        print(f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}") # noqa: E501
-    if getenv("GRAPHUOPS"):
-      from tinygrad.graph import graph_uops
-      graph_uops(self.uops)
+    if DEBUG >= 5: self.uops.print()
+    if getenv("GRAPHUOPS"): self.uops.graph()
 
     # restore backups
-    self.sts, self.
+    self.sts, self.group_for_reduces, self.upcasted = sts_backup, gfr_backup, upc_backup
 
     # set cache and return
     self.applied_opts_cache = self.applied_opts[:]
     return self
 
-  def
-    if uop == UOps.ALU:
-      if arg in UnaryOps:
-        assert dtype == vin[0].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=}"
-      elif arg in (BinaryOps.CMPLT, BinaryOps.CMPEQ):
-        assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
-        assert vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
-      elif arg in BinaryOps:
-        assert dtype == vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
-      elif arg == TernaryOps.WHERE:
-        assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}"
-        assert dtype == vin[1].dtype == vin[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {vin[1].dtype=} != {vin[2].dtype=}"
-
-    if simplify:
-      if uop == UOps.PHI and len(vin) == 2: return vin[1] # a phi without loops is a noop
-      if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype, insert_before)
-      if uop == UOps.CAST and all(x.uop == UOps.CONST for x in vin) and all_same([x.arg for x in vin]):
-        return self.const(vin[0].arg, dtype, insert_before)
-      if uop == UOps.ALU:
-        # rewrites. NOTE: the rewritten NEG op is still around...
-        if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG:
-          return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable, insert_before)
-        # constant folding
-        if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.const(-vin[0].arg, dtype, insert_before)
-        if arg == TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop
-        # zero folding
-        for x in [0,1]:
-          if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
-          if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
-          if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[x]
-        if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0]
-        if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
-
-    key = (uop, dtype, vin, arg)
-    if insert_before is None: insert_before = len(self.uops)
-    # check if the cached expr is valid with the given insert place.
-    if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr
-    ret = UOp(uop, dtype, vin, arg)
-    self.uops.insert(insert_before, ret)
-    if cachable: self.saved_exprs[key] = ret
-    return ret
-
-  def ast_parse(self, x:LazyOp, acc: List[UOp], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], do_reduce=False, loop_ctx=tuple(), cache=None) -> List[UOp]: # noqa: E501
+  def ast_parse(self, x:LazyOp, accs:Dict[LazyOp, List[UOp]], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], reduce_acc:Optional[List[UOp]]=None, cache=None) -> List[UOp]: # noqa: E501
     if cache is None: cache = {}
     if x in cache: return cache[x]
     if x.op in BufferOps: return loaded_buffers[x.arg]
-    if x.op
-      return [self.
-
+    if x.op in [UnaryOps.CAST, UnaryOps.BITCAST]:
+      return [self.uops.add(UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST,
+        self.get_base_dtype(x.arg), (u,)) for u in self.ast_parse(x.src[0], accs, offs, loaded_buffers)]
+    if x.op in ReduceOps and reduce_acc is None:
       assert offs is None, "not available if we aren't doing reduce"
-      return
-
-
-
-    if (castop:=x.src[0]).op == UnaryOps.CAST and (mulop:=castop.src[0]).op == BinaryOps.MUL:
-      # MULACC with acc cast rewrite: MUL -> CAST -> SUM => CAST -> MULACC
-      x = LazyOp(TernaryOps.MULACC, tuple(LazyOp(UnaryOps.CAST, (s, ), castop.arg) for s in mulop.src), x.arg)
-
-    values = [self.ast_parse(v, acc, offs, loaded_buffers, loop_ctx=loop_ctx, cache=cache) for v in x.src]
-    ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
+      return accs[x]
+
+    values = [self.ast_parse(v, accs, offs, loaded_buffers, cache=cache) for v in x.src]
+    ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}
     if x.op in ops:
+      assert reduce_acc is not None
      ret: List[UOp] = []
-      input_acc =
+      acc, input_acc = reduce_acc, reduce_acc[:]
      for val, off in zip(zip(*values), cast(List[int], offs)):
-        acc[off] =
+        acc[off] = UOp.alu(ops[cast(ReduceOps, x.op)], *(val+(acc[off], )))
        ret.append(acc[off])
      for off in range(len(acc)):
        if input_acc[off] != acc[off]:
-          acc[off] = self.
-    else:
-      ret = [self.uop(UOps.ALU, dtypes.bool if x.op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else val[-1].dtype, val, x.op) for val in zip(*values)]
+          acc[off] = self.uops.add(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]))
+    else: ret = [UOp.alu(x.op, *vin) for vin in zip(*values)]
     cache[x] = ret
     return ret
+
+  def to_program(self) -> Program:
+    self.linearize()
+    info = get_lazyop_info(self.ast[0])
+    src = self.opts.render(to_function_name(self.name), self.uops)
+    ops, mem = self.uops.flops_mem()
+    run_count = prod((self.global_size if self.global_size else []) + (self.local_size if self.local_size else []))
+    # NOTE: we use min here to ignore the indexing FLOPS
+    return Program(self.name, src, self.opts.device, self.global_size, self.local_size,
+      self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))