tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
Files changed (77)
  1. tinygrad/__init__.py +6 -0
  2. tinygrad/codegen/kernel.py +572 -83
  3. tinygrad/codegen/linearizer.py +415 -395
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +183 -0
  6. tinygrad/dtype.py +113 -0
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +76 -55
  14. tinygrad/helpers.py +196 -89
  15. tinygrad/lazy.py +210 -371
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +202 -22
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +112 -32
  20. tinygrad/nn/state.py +136 -39
  21. tinygrad/ops.py +119 -202
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +353 -166
  25. tinygrad/renderer/llvmir.py +150 -138
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +81 -0
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +75 -0
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +24 -77
  43. tinygrad/runtime/ops_cuda.py +175 -89
  44. tinygrad/runtime/ops_disk.py +56 -33
  45. tinygrad/runtime/ops_gpu.py +92 -95
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +39 -60
  48. tinygrad/runtime/ops_metal.py +92 -74
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +86 -254
  53. tinygrad/shape/symbolic.py +166 -141
  54. tinygrad/shape/view.py +296 -0
  55. tinygrad/tensor.py +2619 -448
  56. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. tinygrad-0.9.0.dist-info/METADATA +227 -0
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/codegen/assembly.py +0 -190
  61. tinygrad/codegen/optimizer.py +0 -379
  62. tinygrad/codegen/search.py +0 -72
  63. tinygrad/graph.py +0 -83
  64. tinygrad/jit.py +0 -57
  65. tinygrad/nn/image.py +0 -100
  66. tinygrad/renderer/assembly_arm64.py +0 -169
  67. tinygrad/renderer/assembly_ptx.py +0 -98
  68. tinygrad/renderer/wgsl.py +0 -53
  69. tinygrad/runtime/lib.py +0 -113
  70. tinygrad/runtime/ops_cpu.py +0 -51
  71. tinygrad/runtime/ops_hip.py +0 -82
  72. tinygrad/runtime/ops_shm.py +0 -29
  73. tinygrad/runtime/ops_torch.py +0 -30
  74. tinygrad/runtime/ops_webgpu.py +0 -45
  75. tinygrad-0.7.0.dist-info/METADATA +0 -212
  76. tinygrad-0.7.0.dist-info/RECORD +0 -40
  77. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
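
Note on the reorganization visible in the list above: tinygrad/jit.py and tinygrad/mlops.py are gone (replaced by tinygrad/engine/jit.py and tinygrad/function.py), dtypes moved out of tinygrad/helpers.py into the new tinygrad/dtype.py, and tinygrad/__init__.py gained six lines of top-level re-exports. Below is a minimal sketch of how user-facing imports typically change between the two versions; the short `from tinygrad import ...` form assumes the 0.9.0 top-level re-exports, whose contents are not shown in this diff.

# tinygrad 0.7.0 spellings:
#   from tinygrad.tensor import Tensor
#   from tinygrad.jit import TinyJit        # tinygrad/jit.py, removed above
#   from tinygrad.helpers import dtypes     # dtypes lived in helpers.py

# tinygrad 0.9.0 spellings:
from tinygrad import Tensor, dtypes         # top-level re-exports (assumed; __init__.py contents not shown)
from tinygrad.engine.jit import TinyJit     # noqa: F401  # JIT moved into the new tinygrad/engine/ package

x = Tensor.randn(4, 4, dtype=dtypes.float32)
print((x @ x.T).numpy())
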
@@ -1,440 +1,460 @@
1
- from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, TypeVar, Dict, Iterator, Union, Sequence, Final
2
- import itertools, math
1
+ from __future__ import annotations
2
+ from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence
3
+ import itertools, math, functools
3
4
  from collections import defaultdict
4
- from enum import Enum, auto
5
5
 
6
- from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, mnum, DType, all_same, partition
7
- from tinygrad.ops import LazyOp, UnaryOps, Op
8
- from tinygrad.lazy import LazyBuffer
9
- from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
10
- from tinygrad.runtime.lib import RawConst
6
+ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
7
+ from tinygrad.helpers import colored, DEBUG, prod, getenv, to_function_name
8
+ from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
11
9
  from tinygrad.shape.shapetracker import ShapeTracker
12
- from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, sym_rename
13
- from tinygrad.codegen.optimizer import OptimizedKernel
14
- from tinygrad.codegen.kernel import LocalBuffer, LinearizerOptions # noqa: F401 # pylint:disable=unused-import
15
- VariableOrNum = Union[Variable, NumNode, Node]
16
-
17
- # bottom ones are asm only
18
- class UOps(Enum):
19
- LOOP = auto(); ENDLOOP = auto() # loops can be global, local, or other # noqa: E702
20
- DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto() # this defines buffers # noqa: E702
21
- LOAD = auto(); STORE = auto(); BARRIER = auto() # noqa: E702
22
- ALU = auto(); WMMA = auto(); CAST = auto() # noqa: E702
23
- # TODO: add CONST. use ALU WHERE for gated load
24
- # *** assembly only UOps ***
25
- SPECIAL = auto(); LABEL = auto(); COND_BRANCH = auto() # TODO: replace these with LOOP and ENDLOOP # noqa: E702
26
-
27
- def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]:
28
- idy = (idxy//(4*base_shape[1]))
29
- if validhacks and valid.min == 0:
30
- idx = (idxy//4) + (idy*-base_shape[1])
31
- # find the ones in idx that didn't factorize and remove them (TODO: this is not universal)
32
- if isinstance(idx, SumNode):
33
- unfactored, idx_nodes = partition(idx.nodes, lambda x: isinstance(x, MulNode) and x.b == -base_shape[1])
34
- assert len(unfactored) <= 1
35
- idx = Variable.sum(idx_nodes)
36
- unfactored = (Variable.sum(unfactored) // base_shape[1])
37
- idy += unfactored
38
- # ugh really...handtuned garbage
39
- if idx.min >= (base_shape[1]*3)//4:
40
- idx -= base_shape[1]
41
- idy += 1
42
- else:
43
- idx = (idxy//4)%base_shape[1]
44
- if DEBUG >= 5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy)
45
- return idx, idy
46
-
47
- class Token(NamedTuple):
48
- name: str
49
- dtype: DType
50
- offset: Optional[int] = None
51
- def render(self, with_type=False):
52
- if with_type:
53
- assert self.offset is None
54
- return f"{self.dtype.name} {self.name}"
55
- if self.offset is None: return self.name
56
- assert self.dtype in [dtypes._float4, dtypes._float2], f"{self.dtype} isn't okay with offset {self.offset}"
57
- return self.name+"."+"xyzw"[int(self.offset)]
58
- def __repr__(self): return f"<{self.name}>" if self.offset is None and self.dtype == dtypes.float32 else f"<{self.name}:{self.dtype.name}:{self.offset}>"
59
-
60
- # TODO: the next three functions are poorly written
61
- def get_grouped_float4_idxs(acc:List[Token]) -> Optional[List[int]]:
62
- idxs: Optional[List[int]] = []
63
- for i,a in enumerate(acc):
64
- if idxs is None: break
65
- if i in idxs: continue
66
- if a.dtype.sz > 1 and a.offset == 0:
67
- idxs.append(i)
68
- friends: List[int] = []
69
- for j,b in enumerate(acc):
70
- if len(friends) == 3: break
71
- if j in idxs: continue
72
- if a.name == b.name and b.dtype.sz > 1 and b.offset == len(friends)+1:
73
- friends.append(j)
74
- if len(friends) == 3: idxs += friends
75
- else: idxs = None
76
- else:
77
- idxs = None
78
- return idxs
79
-
80
- def to_float4(x:List[Token]) -> Optional[Token]:
81
- if all_same(x): return x[0]
82
- if all_same([y.name for y in x]) and all(y.dtype == dtypes._float4 and y.offset == i for i,y in enumerate(x)):
83
- return Token(x[0].name, dtypes._float4)
84
- return None
85
-
86
- def get_grouped_maybe_float4(*values:List[Token], grouping_allowed=True):
87
- assert all_same([len(x) for x in values]), f"all values are not the same length {values}"
88
- # these use accumulators, we can only fold if the acc is a float4
89
- idxs = get_grouped_float4_idxs(values[-1]) if grouping_allowed else None
90
- if idxs is not None:
91
- new_idxs = []
92
- new_values = []
93
- for i in range(0, len(idxs), 4):
94
- nv = [to_float4([v[j] for j in idxs[i:i+4]]) for v in values]
95
- if any(x is None for x in nv): break
96
- new_idxs.append(idxs[i:i+4])
97
- new_values.append(nv)
98
- if len(new_values) == len(idxs)//4:
99
- return zip(new_idxs, new_values)
100
- return zip([[i] for i in range(len(values[0]))], zip(*values))
101
-
102
- # TODO: generic visitor pattern?
103
- def expand_node(idx:Node) -> List[Node]:
104
- if isinstance(idx, Variable): return [idx] if idx.expr is not None else [Variable.num(j) for j in range(idx.min, idx.max+1)]
105
- if isinstance(idx, NumNode): return [idx]
106
- if isinstance(idx, MulNode): return [x*idx.b for x in expand_node(idx.a)]
107
- if isinstance(idx, SumNode): return [Variable.sum(list(it)) for it in itertools.product(*[expand_node(x) for x in idx.nodes])]
108
- raise NotImplementedError(idx)
109
-
110
- def expand_idxs(idxs:Sequence[Node]) -> Iterator[Tuple[Node, ...]]:
111
- for x in itertools.product(*[expand_node(idx) for idx in idxs[::-1]]):
112
- yield x[::-1]
113
-
114
- class MemOp(NamedTuple):
115
- name: str
116
- idx: Node
117
- local: bool
118
- memory_dtype: DType
119
-
120
- # shared
121
- valid: Node
122
- invalid_value: Union[float, int] = 0.0
123
-
124
- class ConstOp(NamedTuple):
125
- value: Union[float, int]
126
-
127
- # shared
128
- valid: Node
129
- invalid_value: Union[float, int] = 0.0
130
-
131
- class UOp(NamedTuple):
132
- uop: UOps
133
- out: Optional[Token]
134
- vin: List[Token]
135
- arg: Any
136
- def __repr__(self): return f"{str(self.uop):20s}: {str(self.out) if self.out is not None else '':25s} {str(self.vin):32s} {self.arg}"
137
-
138
- class Linearizer(OptimizedKernel):
139
- def get_buffer_name(self, i):
140
- if self.bufs[i].__class__ == LocalBuffer: return self.bufs[i].name
141
- assert self.bufs[i].realized.__class__ is not RawConst # constants shouldn't be loaded with memops
142
- return self.arg_bufs[self.bufs[i].realized]
143
-
144
- def global_load(self, i:int, idxs:Sequence[VariableOrNum], acc=None) -> List[Token]:
145
- const = self.bufs[i].realized._buf if isinstance(self.bufs[i].realized, RawConst) else acc
146
-
147
- expanded_nodes = [expand_node(idx) for idx in idxs]
148
- _idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
149
- upcast_dim = self.get_upcast_dim(i)
150
-
151
- amt = 1
152
- if len(upcast_dim) == 1 and len(expanded_nodes[upcast_dim[0]]) in [4,2]:
153
- dim, amt = upcast_dim[0], len(expanded_nodes[upcast_dim[0]])
10
+ from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node
11
+ from tinygrad.codegen.kernel import LocalBuffer, Kernel
12
+ from tinygrad.renderer import Program
13
+
14
+ from tinygrad.codegen.uops import UOps, UOp, UOpGraph
15
+
16
+ def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
17
+ local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate((prod(local_dims[:-(maxdim-1)]),) + local_dims[-(maxdim-1):] if len(local_dims) > maxdim else local_dims)] # noqa: E501
18
+ if maxdim != 0 and len(local_dims) > maxdim:
19
+ dd = local_idxs[0]
20
+ nli = []
21
+ for s in local_dims[:-(maxdim-1)]:
22
+ nli.append(dd % s)
23
+ dd //= s
24
+ local_idxs = nli + local_idxs[-(maxdim-1):]
25
+ return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]
26
+
27
+ def expand_idx(node:Node) -> Union[Variable, NumNode]: return next((v for v in node.vars() if v.expr.startswith("_uidx")), NumNode(0))
28
+ def expand_idxs(nodes:Sequence[Node]) -> Tuple[Union[Variable, NumNode], ...]:
29
+ eidxs = [expand_idx(node) for node in nodes]
30
+ return tuple([v if v not in eidxs[:j] else NumNode(0) for j, v in enumerate(eidxs)]) # take only first occurrence of expand variable
31
+ def iter_idxs(idxs:Tuple[Union[Variable, NumNode], ...]) -> Iterator[Tuple[int,...]]:
32
+ yield from (x[::-1] for x in itertools.product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]]))
33
+
34
+ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
35
+ idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
36
+ # TODO: bring back the valid removal logic (correct!)
37
+ if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid)
38
+ return (idx, idy), valid
39
+
40
+ # expand a Node into List[Node] that enumerates the underlying Variables from min to max
41
+ # expand increments earlier variables faster than later variables (as specified in the argument)
42
+ @functools.lru_cache(maxsize=None)
43
+ def expand_node(node:Node, idxs:Optional[Tuple[Union[Variable, NumNode], ...]]=None) -> List[Node]:
44
+ if idxs is None: idxs = (expand_idx(node),)
45
+ return [node.substitute({k:v for k,v in zip(idxs, (NumNode(x) for x in rep)) if isinstance(k, Variable)}) for rep in iter_idxs(idxs)]
46
+
47
+ class Linearizer(Kernel):
48
+ def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op): return UOp.alu(op, a, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
49
+
50
+ # NOTE: the consts have to be cached for deduping of downstream uops to work
51
+ def const(self, b:ConstType, dtype:DType=dtypes.int32) -> UOp:
52
+ return self.uops.add(UOps.DEFINE_VAR, dtype, (), b.unbind()[0]) if isinstance(b, Variable) else UOp.const(dtype, b)
53
+
54
+ def get_reduce_acc(self, reduceop:LazyOp):
55
+ if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0
56
+ if reduceop.op is ReduceOps.MAX:
57
+ if dtypes.is_int(reduceop.dtype): return 0 if dtypes.is_unsigned(reduceop.dtype) else -2**(reduceop.dtype.itemsize*8-1)
58
+ return -math.inf if dtypes.is_float(reduceop.dtype) else False
59
+
60
+ # NOTE: once images are loaded, we uop them as their base float
61
+ def get_base_dtype(self, dt:DType) -> DType: return dt.base if isinstance(dt, ImageDType) else dt
62
+
63
+ render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
64
+ MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
65
+ DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV),
66
+ ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
67
+ LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.CMPLT),
68
+ SumNode: lambda self,ops,ctx:
69
+ functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)),
70
+ AndNode: lambda self,ops,ctx:
71
+ functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
72
+
73
+ def global_load(self, i:int, idxs:List[Node], acc:Optional[LazyOp]=None, barrier:Optional[UOp]=None, loop_ctx:Tuple[UOp, ...]=()) -> List[UOp]:
74
+ buf = self.bufs[i]
75
+ localtype = self.get_base_dtype(buf.dtype if acc is None else acc.dtype)
76
+ const = buf.val if isinstance(buf, ConstBuffer) else None
77
+
78
+ expand_vars = expand_idxs(idxs)
79
+
80
+ dim, amt = None, 1
81
+ # float 4 grouping
82
+ if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expand_node(idxs[upcast_dim[0]])) in [4,2]:
83
+ dim, amt = upcast_dim[0], len(float4_expand)
84
+ g_idx, g_valid = self.sts[i].expr_idxs(idxs[:dim] + [float4_expand[0]] + idxs[dim+1:])
85
+ # do not use float4 if idx is not aligned
86
+ if g_idx != (g_idx//amt*amt): dim, amt = None, 1
87
+ if dim is None:
88
+ g_idx, g_valid = self.sts[i].expr_idxs(idxs)
89
+ # todo: multioutput test with different output valids to add if acc is None: g_valid = NumNode(1)
90
+
91
+ if amt > 1: localtype = localtype.vec(amt)
92
+ e_idxs, e_valids = expand_node(g_idx, expand_vars), expand_node(g_valid, expand_vars)
154
93
 
155
94
  ret = []
156
- invalid_value = 0 if dtypes.is_int(self.bufs[i].dtype) else 0.0
157
- for load_i, _idx in enumerate(_idxs):
158
- if amt > 1:
159
- idx, valid = self.sts[i].expr_idxs((_idx[:dim] + (expanded_nodes[dim][0],) + _idx[dim+1:]))
160
- localtype = dtypes._float4 if amt == 4 else dtypes._float2
161
- if idx.render() != ((idx//amt)*amt).render():
162
- idx, valid = self.sts[i].expr_idxs(_idx)
163
- localtype = dtypes.float32
164
- else:
165
- idx, valid = self.sts[i].expr_idxs(_idx)
166
- localtype = dtypes.float32
167
- this_const, idx, valid = (invalid_value, Variable.num(0), Variable.num(1)) if valid.max == 0 else (const, idx, valid)
168
- key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else self.get_buffer_name(i)}{idx.render()}{valid.render()}"
95
+ invalid_value = 0
96
+ acc_count = 0
97
+ for idx, valid, rep_idx in zip(e_idxs, e_valids, iter_idxs(expand_vars)):
98
+ this_const, idx, valid = (invalid_value, NumNode(0), NumNode(1)) if valid.max == 0 else (const, idx, valid)
99
+ # todo: when multiple reduceops are supported, clearly disambiguate and test acc load keys are unique for each reduceop
100
+ key = f"{acc is not None}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
169
101
  if key not in self.load_cache:
170
- if isinstance(self.bufs[i].dtype, ImageDType): idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
171
- self.load_cache[key] = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{load_i}", localtype), [], MemOp(self.get_buffer_name(i), idx, self.bufs[i].__class__ is LocalBuffer, self.bufs[i].dtype, valid, invalid_value)) if this_const is None else \
172
- self.uop(UOps.LOAD, Token(f"{'const' if acc is None else 'acc'}{mnum(i)}_{load_i}", localtype), [], ConstOp(this_const, valid))
173
- ret.append(Token(self.load_cache[key].name, self.load_cache[key].dtype, expanded_nodes[dim].index(_idx[dim])) if localtype != dtypes.float else self.load_cache[key])
102
+ if acc is not None:
103
+ self.load_cache[key] = self.uops.add(UOps.DEFINE_ACC, localtype, loop_ctx, (self.get_reduce_acc(acc), i, acc_count))
104
+ acc_count += 1
105
+ elif this_const is not None:
106
+ self.load_cache[key] = self.const(this_const, localtype)
107
+ if valid.min == 0 and valid.max == 1:
108
+ valid_rendered = valid.render(self.render_ops, self)
109
+ self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_rendered, self.load_cache[key], self.const(invalid_value, localtype))
110
+ elif isinstance(buf.dtype, ImageDType):
111
+ buf_uop = self.buf_uops[i]
112
+ assert buf_uop is not None, f"buffer {i} wasn't UOped"
113
+ image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
114
+ rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
115
+ valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, buf.dtype.base.vec(4))) if valid.min == 0 else tuple()
116
+ self.load_cache[key] = self.uops.add(UOps.LOAD, buf.dtype.base.vec(4),
117
+ (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
118
+ if localtype == localtype.scalar():
119
+ idx_small = idx%4
120
+ res = idx_small.render(self.render_ops, self)
121
+ out = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
122
+ for ix in range(idx_small.max, idx_small.min, -1):
123
+ rvv = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
124
+ sel = UOp.alu(BinaryOps.CMPLT, res, self.const(ix))
125
+ out = UOp.alu(TernaryOps.WHERE, sel, rvv, out)
126
+ self.load_cache[key] = out
127
+ else:
128
+ buf_uop = self.buf_uops[i]
129
+ assert buf_uop is not None, f"buffer {i} wasn't UOped"
130
+ rendered_idx = idx.render(self.render_ops, self)
131
+ valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, localtype)) if valid.min == 0 else tuple()
132
+ self.load_cache[key] = self.uops.add(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
133
+ ret.append(self.uops.add(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
174
134
  return ret
175
135
 
176
- def global_store(self, i, idxs:List[VariableOrNum], store:List[Token], ssa) -> None:
177
- expanded_nodes = [expand_node(idx) for idx in idxs]
178
- _idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
179
- upcast_dim = self.get_upcast_dim(i)
136
+ def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]:
137
+ buf = self.bufs[i]
138
+ buf_uop = self.buf_uops[i]
139
+ assert buf_uop is not None, f"buffer {i} wasn't UOped"
180
140
 
141
+ expand_vars = expand_idxs(idxs)
142
+ _idxs = zip(*[expand_node(idx, expand_vars) for idx in idxs]) if idxs else [tuple()] # transpose
181
143
  store_offset = dict(zip(_idxs, store))
182
144
 
183
145
  # float4 grouping
184
- if len(upcast_dim) == 1 and len(expanded_nodes[upcast_dim[0]]) in [2,4]:
146
+ if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expand_node(idxs[upcast_dim[0]])) in [2,4]:
185
147
  grouped_store_offset = defaultdict(list)
186
148
  for k in store_offset:
187
- _idx = k[:upcast_dim[0]] + (expanded_nodes[upcast_dim[0]][0],) + k[upcast_dim[0]+1:]
149
+ _idx = k[:upcast_dim[0]] + (float4_expand[0],) + k[upcast_dim[0]+1:]
188
150
  grouped_store_offset[_idx].append(store_offset[k])
189
151
  store_offset_new = {}
190
- for k,out_tokens in grouped_store_offset.items():
191
- amt = len(out_tokens)
152
+ for k,grouped in grouped_store_offset.items():
153
+ amt = len(grouped)
192
154
  idx, valid = self.sts[i].expr_idxs(k)
193
- assert idx.render() == ((idx//amt)*amt).render(), "float4 stores are always aligned"
194
- assert valid.min == 1, "stores are always valid"
195
- if all_same([x.name for x in out_tokens]) and tuple(range(amt)) == tuple(x.offset for x in out_tokens):
196
- store_offset_new[k] = Token(out_tokens[0].name, dtypes._float4 if amt == 4 else dtypes._float2)
197
- else:
198
- store_offset_new[k] = self.uop(UOps.CAST, ssa("alu", dtypes._float4 if amt == 4 else dtypes._float2), out_tokens)
155
+ assert idx == ((idx//amt)*amt), "float4 stores are always aligned"
156
+ store_offset_new[k] = self.uops.add(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
199
157
  store_offset = store_offset_new
200
158
 
201
- for idx, var in store_offset.items():
202
- idx, valid = self.sts[i].expr_idxs(idx)
203
- if isinstance(self.bufs[i].dtype, ImageDType): idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
204
- self.uop(UOps.STORE, None, [var], MemOp(self.get_buffer_name(i), idx, self.bufs[i].__class__ is LocalBuffer, self.bufs[i].dtype, valid))
159
+ stores = []
160
+ for _idx, var in store_offset.items():
161
+ idx, valid = self.sts[i].expr_idxs(_idx)
162
+ if isinstance(buf.dtype, ImageDType):
163
+ image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
164
+ rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), \
165
+ tuple(x.render(self.render_ops, self) for x in image_idx))
166
+ else:
167
+ rendered_idx = idx.render(self.render_ops, self)
168
+ if valid.min == 1: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var)))
169
+ else: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
170
+ return stores
171
+
172
+ # render loop
173
+ def render_loop(self, xx:List[Variable], depth:int) -> Tuple[UOp, ...]:
174
+ new_loops = {x.expr:self.uops.add(UOps.RANGE, dtypes.int32, (
175
+ self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
176
+ self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), arg=(depth,i)) for i,x in enumerate(xx) if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
177
+ self.loop_uops.update(new_loops)
178
+ return tuple(new_loops.values())
179
+
180
+ def render_reduceop(self, reduceop:LazyOp, accs:Dict[LazyOp, List[UOp]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]],
181
+ global_idxs, local_idxs, upcast_idxs):
182
+ # define indicies
183
+ full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
184
+ reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
185
+ fake_reduce_idxs = [x*0 for x in reduce_idxs]
186
+
187
+ def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
188
+ replace_idxs, thread_idxs, thread_idx = [], [], Variable("_uidx_tc", 0, prod(local_sizes)-1)
189
+ for s in local_sizes:
190
+ thread_idxs.append(thread_idx % s)
191
+ thread_idx //= s
192
+ for alias in aliases:
193
+ full_var, full_var_sz = NumNode(0), 1
194
+ if alias[0] != 0:
195
+ for i in alias:
196
+ next_var = local_idxs[i-1] if i > 0 else thread_idxs[-i-1]
197
+ full_var += next_var * full_var_sz
198
+ full_var_sz *= next_var.max+1
199
+ replace_idxs.append(full_var)
200
+ return replace_idxs
201
+
202
+ # compute local aliases - modify idxs if necessary for TC
203
+ alias_buf_idxs = []
204
+ for i in self.local_alias:
205
+ localbuf_idx = self.bufs.index(self.local_alias[i])
206
+ buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
207
+ if (tc:=self.tensor_core):
208
+ min_alias_idx = min(self.local_alias.keys())
209
+ replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
210
+ for n in range(len(tc.threads)):
211
+ buf_idxs[self.global_dims+n] = replace_input_idxs[n] # replace locals
212
+ for n in range(tc.num_upcasts()):
213
+ buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
214
+ if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
215
+ alias_buf_idxs.append((i, localbuf_idx, buf_idxs,))
216
+
217
+ # reduce loop
218
+ loop_ctx = self.render_loop(reduce_idxs, 2)
219
+
220
+ # define accumulator - modify idxs if necessary for TC
221
+ out_buf = -1 if self.group_for_reduces else 0
222
+ if (tc:=self.tensor_core):
223
+ replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
224
+ for n in range(len(tc.threads)):
225
+ local_idxs[n] = replace_acc_idxs[n] # replace locals
226
+ for n in range(len(replace_acc_idxs)-len(tc.threads)):
227
+ upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
228
+ if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs}")
229
+ accs[reduceop] = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
230
+
231
+ # store local aliases
232
+ locals_to_store = [(localbuf_idx, buf_idxs, self.global_load(i, buf_idxs)) for i, localbuf_idx, buf_idxs in alias_buf_idxs]
233
+
234
+ if (tc:=self.tensor_core):
235
+ # run tensor cores AST
236
+ wmma_sz = [prod(l) for l in tc.thread_local_sizes]
237
+ def upcast_strides(buf:int):
238
+ strides, next = [], 1
239
+ for (sz, stride, reduce) in self.upcasted_axis(buf)[tc.num_upcasts():]:
240
+ strides.append((0 if stride == 0 else next, sz))
241
+ next *= 1 if stride == 0 else sz
242
+ return strides
243
+ upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
244
+ # cast initial accs
245
+ wmmas = [self.uops.add(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]]))
246
+ for x in range(0, len(accs[reduceop]), wmma_sz[2])]
247
+ for iter in [x[::-1] for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]:
248
+ offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(iter, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
249
+ ops = (self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
250
+ self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
251
+ wmmas[(wmma_idx:=offs[2]//wmma_sz[2])])
252
+ # TODO: don't need to DEFINE_ACC, pass to WMMA in op3, or PHI accs that are not valid
253
+ wmmas[wmma_idx] = self.uops.add(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), dev))
254
+ # phi the last wmmas back to accs
255
+ accs[reduceop] = [self.uops.add(UOps.PHI, tc.dtype_out, (acc, self.uops.add(UOps.GEP, tc.dtype_out, (wmmas[z//wmma_sz[2]],), z%wmma_sz[2])))
256
+ for z, acc in enumerate(accs[reduceop])]
257
+ else:
258
+ assert not locals_to_store, "storing locals isn't supported here"
259
+
260
+ # load earlybufs
261
+ loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i,
262
+ global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs})
263
+
264
+ # run early AST (with reduce)
265
+ self.ast_parse(reduceop, accs, self.acc_offsets(self.full_buf_index), loaded_buffers, reduce_acc=accs[reduceop])
266
+
267
+ # end the reduce loop
268
+ self.load_cache.clear()
269
+
270
+ # end the local loop, do the local reduce
271
+ if self.group_for_reduces:
272
+ fake_global_idxs = [x*0 for x in global_idxs]
273
+ stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop]) # store accumulators
274
+ barrier = self.uops.add(UOps.BARRIER, None, tuple(stores))
275
+ if self.opts.has_local:
276
+ fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
277
+ fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
278
+ if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(self.render_ops, self)
279
+ barrier = self.uops.add(UOps.IF, None, (if_cond, barrier))
280
+
281
+ # create new late reduce local loops and replace local_idxs that have been used
282
+ end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)] # noqa: E501
283
+ local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]
284
+
285
+ # if any group_for_reduce items aren't reduces, upcast them here
286
+ for j in self.upcast_in_mid_reduce_axes:
287
+ self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
288
+ self.upcast()
289
+ self.group_for_reduces -= 1
290
+ local_idxs = local_idxs[:-1]
291
+ end_local_idxs = end_local_idxs[:-1]
292
+ # regenerate upcast_idxs
293
+ upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
294
+
295
+ # NOTE: this structure is the same as the reduce op above
296
+
297
+ # late reduce loop
298
+ loop_ctx = self.render_loop(end_local_idxs, 3)
299
+
300
+ # define late accumulator
301
+ accs[reduceop] = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
302
+
303
+ # load localbufs
304
+ loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
305
+
306
+ # there's no AST here (and there's no shape for the reduce LazyOp)
307
+ self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)),\
308
+ accs, self.acc_offsets(-1), loaded_buffers, reduce_acc=accs[reduceop])
309
+
310
+ # end the late reduce loop
311
+ self.load_cache.clear()
312
+
313
+ # all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have
314
+ # been rewritten with fake end_local_idxs.
315
+ return (accs, loaded_buffers, fake_reduce_idxs, local_idxs[:self.local_dims] + [NumNode(0) for i in range(self.group_for_reduces)], upcast_idxs)
205
316
 
206
317
  kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
207
318
  def linearize(self):
208
- self.process()
319
+ # no new opts and we already ran? skip relinearizing
320
+ if self.applied_opts == self.applied_opts_cache: return self
321
+
322
+ # late alias the tensor core buffers
323
+ if (tc:=self.tensor_core) and (tc_opts:=self.tensor_core_opts):
324
+ alias_pattern = [0]*(self.global_dims) + [2]*(len(tc.threads)) + [0]*(self.local_dims-len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2) # noqa: E501
325
+ for tc_buf in tc_opts.bufs:
326
+ self.alias_buffer(tc_buf, alias_pattern)
327
+
328
+ # save backups
329
+ sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduces, self.upcasted
330
+
331
+ # global uop cache
332
+ self.saved_exprs: Dict[Tuple, UOp] = dict()
209
333
 
210
334
  # limit dims if we need to
211
- if self.opts.global_max and self.opts.local_max: self.limit_global_dims(3, self.opts.global_max, self.opts.local_max)
335
+ if self.opts.global_max and self.opts.local_max: self.limit_dims_to_max(self.opts.global_max, self.opts.local_max)
212
336
 
213
337
  # uops
214
- self.uops: List[UOp] = []
215
- self.load_cache: Dict[str, Token] = {}
216
- self.saved_exprs: Dict[Tuple[Op, Tuple[Token, ...]], Token] = dict()
338
+ self.uops:UOpGraph = UOpGraph()
339
+ self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs)
340
+ self.loop_uops: Dict[str, UOp] = {}
217
341
 
218
342
  # add global buffers
219
- for buf,name in self.arg_bufs.items():
220
- self.uop(UOps.DEFINE_GLOBAL, None, [], (name, buf.dtype))
221
- # add variables from symbolic shapes
222
- for var in sorted(set(v for buf in self.ast.buffers for v in buf.st.var_vals), key=lambda k: k.key):
223
- self.uop(UOps.DEFINE_GLOBAL, None, [], (var.expr, dtypes._arg_int32))
224
-
225
- # add a local buffer for multistage reduce
226
- if self.group_for_reduce:
227
- # TODO: the strides of this can be controlled
228
- self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
229
- self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
230
- self.uop(UOps.DEFINE_LOCAL, None, [], ("temp", self.sts[-1].size()))
231
-
343
+ for i,buf in enumerate(self.bufs):
344
+ if isinstance(buf, MemBuffer):
345
+ self.buf_uops[i] = self.uops.add(UOps.DEFINE_GLOBAL,
346
+ buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
347
+ (buf.idx, any(buf.idx == x.idx for x in self.outbufs)))
348
+ # add var vals
349
+ for i,var in enumerate(self.vars):
350
+ assert var.expr is not None
351
+ self.loop_uops[var.expr] = self.uops.add(UOps.DEFINE_VAR, dtypes.int32, (), var)
232
352
  # define local buffers
233
353
  for lb in self.local_alias.values():
234
- self.uop(UOps.DEFINE_LOCAL, None, [], (lb.name, self.sts[self.bufs.index(lb)].size()))
235
-
236
- # print
237
- if DEBUG >= 3: self.printbufs()
354
+ self.buf_uops[self.bufs.index(lb)] = self.uops.add(UOps.DEFINE_LOCAL,
355
+ PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size))
356
+ # add a local buffer for multistage reduce. # TODO: use local alias
357
+ if self.group_for_reduces:
358
+ # TODO: the strides of this can be controlled
359
+ self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501
360
+ temp_dtype = self.get_base_dtype(cast(LazyOp, self.reduceop).dtype)
361
+ self.bufs.append(LocalBuffer("temp", self.sts[-1].size, temp_dtype))
362
+ self.buf_uops.append(self.uops.add(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size)))
238
363
 
239
364
  # kernel name (before late upcast)
240
- self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) if isinstance(x, int) else sym_rename(x) for x in self.full_shape])
241
- self.display_name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
242
-
243
- # parse AST
244
- loaded_buffers = {}
245
- acc = []
246
-
247
- # ssa
248
- _ssa:DefaultDict[str,int] = defaultdict(int)
249
- def ssa(name, ltype=dtypes.float) -> Token:
250
- _ssa[name] += 1
251
- return Token(f"{name}{_ssa[name]-1}", ltype)
252
-
253
- # global loop
254
- global_idxs = [Variable(f"gidx{i}", 0, self.full_shape[i]-1) for i in range(0, self.first_reduce-self.local_dims)]
255
- self.uop(UOps.LOOP, None, [], (global_idxs, "global"))
365
+ self.name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \
366
+ (f"{len(self.outbufs)}_" if len(self.outbufs) > 1 else "_") + \
367
+ colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
256
368
 
257
- # local loop
258
- local_idxs = [Variable(f"lidx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce-self.local_dims, self.first_reduce+len(self.group_for_reduce))]
259
- self.uop(UOps.LOOP, None, [], (local_idxs, "local"))
369
+ # name the function something unique
370
+ Linearizer.kernel_cnt[(function_name := to_function_name(self.name))] += 1
371
+ suffix = f"{'n'+str(Linearizer.kernel_cnt[function_name]-1)}" if Linearizer.kernel_cnt[function_name] > 1 else ""
372
+ self.name = self.name+colored(suffix, 'BLACK')
373
+
374
+ # define indexes
375
+ global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
376
+ local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+self.group_for_reduces], 3 if self.opts.has_local else 0) # noqa: E501
377
+ upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
378
+
379
+ # set global/local size
380
+ self.global_size: Optional[List[int]] = None
381
+ self.local_size: Optional[List[int]] = None
382
+ if self.dont_use_locals:
383
+ self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
384
+ self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
385
+ elif self.opts.has_local:
386
+ self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs]
387
+ self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
388
+ self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
389
+ else:
390
+ self.render_loop(loop_global_idxs+loop_local_idxs, 1)
391
+ if self.global_size is not None: self.global_size += [1]*(3-len(self.global_size))
392
+ if self.local_size is not None: self.local_size += [1]*(3-len(self.local_size))
260
393
 
261
- # upcast indexes
262
- full_upcast_idxs = [Variable(None, 0, s-1) for s in self.full_shape[self.shape_len-self.upcasted:]]
263
- upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
394
+ # parse AST
395
+ loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]] = {}
396
+ accs: Dict[LazyOp, List[UOp]] = {}
397
+ self.load_cache: Dict[str, UOp] = {}
264
398
 
265
399
  # reduce op
266
- fake_reduce_idxs = []
267
- if self.reduceop is not None:
268
- # define indexes
269
- reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)]
270
- fake_reduce_idxs = [x*0 for x in reduce_idxs]
271
-
272
- # define accumulator
273
- acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
274
-
275
- # reduce loop
276
- self.uop(UOps.LOOP, None, [], (reduce_idxs, "reduce"))
277
-
278
- # barrier for fast GEMM
279
- if self.use_tensor_cores: self.uop(UOps.BARRIER, None, [], ())
280
-
281
- # compute local aliases
282
- locals_to_store = []
283
- for i in self.local_alias:
284
- strides = self.sts[i].real_strides()
285
- extra_locals = [lidx for lidx,st in zip(local_idxs[self.exclude_local_upcast:], strides[len(global_idxs)+self.exclude_local_upcast:self.first_reduce]) if st == 0]
286
- this_upcast_idxs: List[Node] = []
287
- # TODO: just flipping the order here is likely not generic at all
288
- for j,v in list(enumerate(full_upcast_idxs))[::-1] if self.reverse_upcast_dir else list(enumerate(full_upcast_idxs)):
289
- if strides[len(global_idxs)+len(local_idxs)+len(reduce_idxs)+j] == 0:
290
- if DEBUG >= 4: print(f"upcasting@{j} stride 0")
291
- this_upcast_idxs.append(Variable.num(0))
292
- elif (elc:=[el for el in extra_locals if v.min == el.min and v.max == el.max]):
293
- if DEBUG >= 4: print(f"upcasting@{j} matched stride {elc[0]}")
294
- this_upcast_idxs.append(elc[0])
295
- extra_locals.remove(elc[0])
296
- elif (elc:=[el for el in extra_locals if v.min == el.min and (v.max+1)%(el.max+1) == 0]):
297
- tacc = Variable.num(0)
298
- rem = v.max+1
299
- while len(elc) and rem%(elc[0].max+1) == 0:
300
- if DEBUG >= 4: print(f"upcasting@{j} partial stride {rem} {elc[0]} left: {elc[1:]}")
301
- rem = rem//(elc[0].max+1)
302
- tacc += (elc[0] * rem)
303
- extra_locals.remove(elc[0])
304
- elc = [el for el in extra_locals if v.min == el.min and rem%(el.max+1) == 0]
305
- if DEBUG >= 4 and rem > 1: print(f"failed upcasting@{j} partial stride {rem} extra locals {extra_locals}")
306
- this_upcast_idxs.append(tacc + Variable(None, 0, rem-1))
307
- else:
308
- if DEBUG >= 4: print(f"failed upcasting@{j} stride {v} extra locals {extra_locals}")
309
- this_upcast_idxs.append(v)
310
- idxs = global_idxs+local_idxs+reduce_idxs+(this_upcast_idxs[::-1] if self.reverse_upcast_dir else this_upcast_idxs)
311
- ll = self.global_load(i, idxs)
312
- locals_to_store.append((self.bufs.index(self.local_alias[i]), idxs, ll))
313
-
314
- # copy in any global buffers
315
- if self.use_tensor_cores:
316
- if self.bufs[0].device == "METAL":
317
- i = 0
318
- for y0,y1 in zip(locals_to_store[1][2][::2], locals_to_store[1][2][1::2]):
319
- for x0,x1 in zip(locals_to_store[0][2][::2], locals_to_store[0][2][1::2]):
320
- self.uop(UOps.WMMA, None, [x0, x1, y0, y1, acc[i], acc[i+1]], "METAL")
321
- i += 2
322
- elif self.bufs[0].device == "HIP":
323
- i = 0
324
- for y in range(0, len(locals_to_store[1][2]), 0x10):
325
- for x in range(0, len(locals_to_store[0][2]), 0x10):
326
- self.uop(UOps.WMMA, None, acc[i:i+8]+locals_to_store[0][2][x:x+0x10]+locals_to_store[1][2][y:y+0x10], "HIP")
327
- i += 8
328
- else:
329
- if locals_to_store:
330
- self.uop(UOps.BARRIER, None, [], ())
331
- for i, idxs, ll in locals_to_store: self.global_store(i, idxs, ll, ssa)
332
- self.uop(UOps.BARRIER, None, [], ())
333
-
334
- # load earlybufs
335
- loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})
336
-
337
- # run early AST (with reduce)
338
- self.ast_parse(self.reduceop, [acc[off] for off in self.acc_offsets(self.full_buf_index)], loaded_buffers, ssa, do_reduce=True)
339
-
340
- # end the reduce loop
341
- self.uop(UOps.ENDLOOP, None, [], (reduce_idxs, "reduce"))
342
- self.load_cache.clear()
343
-
344
- # end the local loop, do the local reduce
345
- if self.group_for_reduce:
346
- fake_global_idxs = [x*0 for x in global_idxs]
347
- self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc, ssa) # store accumulators
348
- self.uop(UOps.BARRIER, None, [], ())
349
- self.uop(UOps.ENDLOOP, None, [], (local_idxs, "local"))
350
-
351
- # local indexs are over, 0 them out
352
- local_idxs = [x*0 for x in local_idxs]
353
-
354
- # if any group_for_reduce items aren't reduces, upcast them here
355
- for j in self.upcast_in_mid_reduce_axes:
356
- self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
357
- self.upcast()
358
- self.group_for_reduce.pop()
359
- local_idxs = local_idxs[:-1]
360
- # regenerate upcast_idxs
361
- upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
362
-
363
- # NOTE: this structure is the same as the reduce op above
364
-
365
- # define late accumulator
366
- acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
367
-
368
- # late reduce loop
369
- end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
370
- self.uop(UOps.LOOP, None, [], (end_local_idxs, "late_reduce"))
371
-
372
- # load localbufs
373
- loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs+fake_reduce_idxs+upcast_idxs)
374
-
375
- # there's no AST here (and there's no shape for the reduce LazyOp)
376
- self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, ssa, do_reduce=True) # type: ignore
377
-
378
- # end the late reduce loop
379
- self.uop(UOps.ENDLOOP, None, [], (end_local_idxs, "late_reduce"))
380
- self.load_cache.clear()
400
+ fake_reduce_idxs: List[Variable] = []
401
+ for reduceop in [self.reduceop] if self.reduceop is not None else []:
402
+ accs,loaded_buffers,fake_reduce_idxs,local_idxs,upcast_idxs = \
403
+ self.render_reduceop(reduceop,accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs)
381
404
 
382
405
  # load latebufs
383
- loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})
406
+ loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
407
+ for i,b in enumerate(self.bufs) if b not in self.earlybufs and b.__class__ is not LocalBuffer})
384
408
 
385
- # run late AST
386
- val = self.ast_parse(self.ast, acc, loaded_buffers, ssa)
409
+ # run late AST (without the store)
410
+ for op in self.ast:
411
+ val = self.ast_parse(op.src[0], accs, None, loaded_buffers)
412
+ self.global_store(op.arg.idx, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
387
413
 
388
- # store
389
- self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val, ssa)
414
+ # maybe graph the uops
415
+ if DEBUG >= 5: self.uops.print()
416
+ if getenv("GRAPHUOPS"): self.uops.graph()
390
417
 
391
- if not self.group_for_reduce:
392
- # end the global+local loop
393
- self.uop(UOps.ENDLOOP, None, [], (global_idxs+local_idxs, "global+local"))
394
- else:
395
- # end the global loop
396
- self.uop(UOps.ENDLOOP, None, [], (global_idxs, "global"))
418
+ # restore backups
419
+ self.sts, self.group_for_reduces, self.upcasted = sts_backup, gfr_backup, upc_backup
397
420
 
398
- # name the function something unique
399
- Linearizer.kernel_cnt[self.function_name] += 1
400
- suffix = f"{'n'+str(Linearizer.kernel_cnt[self.function_name]-1)}" if Linearizer.kernel_cnt[self.function_name] > 1 else ""
401
- self.function_name, self.display_name = self.function_name+suffix, self.display_name+colored(suffix, 'BLACK')
421
+ # set cache and return
422
+ self.applied_opts_cache = self.applied_opts[:]
402
423
  return self
403
424
 
404
- _OT = TypeVar("_OT")
405
- def uop(self, uop:UOps, out:_OT, vin:List[Token], arg:Any=None) -> _OT:
406
- self.uops.append(UOp(uop, cast(Optional[Token], out), vin, arg))
407
- if DEBUG >= 4: print(self.uops[-1])
408
- return out
409
-
410
- def uop_alu(self, out: Token, vin: List[Token], op: Op) -> Token:
411
- key = (op, tuple(vin))
412
- if key not in self.saved_exprs: self.saved_exprs[key] = self.uop(UOps.ALU, out, vin, op)
413
- return self.saved_exprs[key]
414
-
415
- def ast_parse(self, x, acc, loaded_buffers, ssa, do_reduce=False) -> List[Token]:
416
- if x.__class__ is not LazyOp: return loaded_buffers[x]
417
- if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers, ssa) # cast isn't an ALU op
418
- if x.op in ReduceOps and not do_reduce: return acc
419
- # MULACC fusion. TODO: this is copied from Interpreted
420
- if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == BinaryOps.MUL:
421
- x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
422
- if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL:
423
- x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg)
424
- if x.op in {BinaryOps.ADD, BinaryOps.MUL}:
425
- # Reorder sources to put constants first so get_grouped_maybe_float4 can fold the op
426
- srcs = sorted(x.src, key=lambda x: (x.realized.__class__ != RawConst) if x.__class__ == LazyBuffer else 0)
427
- x.src = tuple(srcs)
428
- values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
429
- ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
425
+ def ast_parse(self, x:LazyOp, accs:Dict[LazyOp, List[UOp]], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], reduce_acc:Optional[List[UOp]]=None, cache=None) -> List[UOp]: # noqa: E501
426
+ if cache is None: cache = {}
427
+ if x in cache: return cache[x]
428
+ if x.op in BufferOps: return loaded_buffers[x.arg]
429
+ if x.op in [UnaryOps.CAST, UnaryOps.BITCAST]:
430
+ return [self.uops.add(UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST,
431
+ self.get_base_dtype(x.arg), (u,)) for u in self.ast_parse(x.src[0], accs, offs, loaded_buffers)]
432
+ if x.op in ReduceOps and reduce_acc is None:
433
+ assert offs is None, "not available if we aren't doing reduce"
434
+ return accs[x]
435
+
436
+ values = [self.ast_parse(v, accs, offs, loaded_buffers, cache=cache) for v in x.src]
437
+ ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}
430
438
  if x.op in ops:
431
- ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), ops[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.opts.supports_float4_alu)]
432
- else:
433
- ret = [(idx, self.uop_alu(ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.opts.supports_float4_alu and x.op not in {BinaryOps.CMPLT, TernaryOps.WHERE})]
434
- ordered_ret: List[Optional[Token]] = [None]*len(values[0])
435
- # scatter
436
- for i,j in ret:
437
- for o,k in enumerate(i):
438
- ordered_ret[k] = Token(j.name, j.dtype, o) if j.dtype == dtypes._float4 else j
439
- assert all(isinstance(x, Token) for x in ordered_ret), "some tokens didn't get scattered?"
440
- return cast(List[Token], ordered_ret)
439
+ assert reduce_acc is not None
440
+ ret: List[UOp] = []
441
+ acc, input_acc = reduce_acc, reduce_acc[:]
442
+ for val, off in zip(zip(*values), cast(List[int], offs)):
443
+ acc[off] = UOp.alu(ops[cast(ReduceOps, x.op)], *(val+(acc[off], )))
444
+ ret.append(acc[off])
445
+ for off in range(len(acc)):
446
+ if input_acc[off] != acc[off]:
447
+ acc[off] = self.uops.add(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]))
448
+ else: ret = [UOp.alu(x.op, *vin) for vin in zip(*values)]
449
+ cache[x] = ret
450
+ return ret
451
+
452
+ def to_program(self) -> Program:
453
+ self.linearize()
454
+ info = get_lazyop_info(self.ast[0])
455
+ src = self.opts.render(to_function_name(self.name), self.uops)
456
+ ops, mem = self.uops.flops_mem()
457
+ run_count = prod((self.global_size if self.global_size else []) + (self.local_size if self.local_size else []))
458
+ # NOTE: we use min here to ignore the indexing FLOPS
459
+ return Program(self.name, src, self.opts.device, self.global_size, self.local_size,
460
+ self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
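
The linearizer rewrite above replaces the 0.7.0 Token/MemOp/ConstOp tuples with a UOp graph (UOpGraph, built through self.uops.add) and, per the new linearize(), prints that graph when DEBUG >= 5 or renders it when GRAPHUOPS is set. A minimal sketch of inspecting the new output from user code, assuming tinygrad 0.9.0 is installed; the exact log format is not specified by this diff:

import os
os.environ["DEBUG"] = "5"        # DEBUG >= 5 reaches the self.uops.print() call in linearize() above
# os.environ["GRAPHUOPS"] = "1"  # alternatively, graph the UOp graph (getenv("GRAPHUOPS") above)

from tinygrad import Tensor      # import after setting the env vars so tinygrad.helpers picks them up

a, b = Tensor.rand(16, 16), Tensor.rand(16, 16)
(a @ b).realize()                # scheduling and linearizing the kernel dumps its UOps and generated source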