tinygrad-0.9.0-py3-none-any.whl → tinygrad-0.9.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. tinygrad/codegen/__init__.py +0 -0
  2. tinygrad/codegen/kernel.py +78 -90
  3. tinygrad/codegen/linearizer.py +237 -169
  4. tinygrad/codegen/uops.py +278 -242
  5. tinygrad/device.py +147 -10
  6. tinygrad/dtype.py +7 -7
  7. tinygrad/engine/graph.py +16 -16
  8. tinygrad/engine/jit.py +39 -36
  9. tinygrad/engine/realize.py +6 -5
  10. tinygrad/engine/schedule.py +15 -7
  11. tinygrad/engine/search.py +6 -3
  12. tinygrad/function.py +17 -23
  13. tinygrad/helpers.py +77 -8
  14. tinygrad/lazy.py +26 -26
  15. tinygrad/multi.py +13 -9
  16. tinygrad/nn/__init__.py +1 -1
  17. tinygrad/nn/datasets.py +2 -1
  18. tinygrad/nn/state.py +3 -4
  19. tinygrad/ops.py +49 -16
  20. tinygrad/renderer/__init__.py +8 -4
  21. tinygrad/renderer/assembly.py +93 -100
  22. tinygrad/renderer/cstyle.py +47 -42
  23. tinygrad/renderer/llvmir.py +30 -30
  24. tinygrad/runtime/__init__.py +0 -0
  25. tinygrad/runtime/autogen/amd_gpu.py +11504 -1
  26. tinygrad/runtime/autogen/comgr.py +36 -10
  27. tinygrad/runtime/autogen/hsa.py +146 -14
  28. tinygrad/runtime/autogen/io_uring.py +1486 -0
  29. tinygrad/runtime/autogen/nv_gpu.py +269 -0
  30. tinygrad/runtime/driver/__init__.py +0 -0
  31. tinygrad/runtime/driver/hip_comgr.py +20 -11
  32. tinygrad/runtime/graph/__init__.py +0 -0
  33. tinygrad/runtime/graph/clang.py +3 -2
  34. tinygrad/runtime/graph/cuda.py +2 -2
  35. tinygrad/runtime/graph/hcq.py +122 -78
  36. tinygrad/runtime/ops_amd.py +302 -316
  37. tinygrad/runtime/ops_cuda.py +3 -3
  38. tinygrad/runtime/ops_disk.py +70 -5
  39. tinygrad/runtime/ops_gpu.py +2 -2
  40. tinygrad/runtime/ops_metal.py +5 -6
  41. tinygrad/runtime/ops_npy.py +1 -1
  42. tinygrad/runtime/ops_nv.py +161 -166
  43. tinygrad/runtime/ops_python.py +20 -16
  44. tinygrad/shape/__init__.py +0 -0
  45. tinygrad/shape/shapetracker.py +5 -2
  46. tinygrad/shape/symbolic.py +1 -3
  47. tinygrad/shape/view.py +34 -19
  48. tinygrad/tensor.py +219 -135
  49. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
  50. tinygrad-0.9.1.dist-info/RECORD +63 -0
  51. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  52. tinygrad/runtime/driver/hsa.py +0 -143
  53. tinygrad/runtime/graph/hsa.py +0 -171
  54. tinygrad/runtime/ops_hsa.py +0 -278
  55. tinygrad-0.9.0.dist-info/RECORD +0 -60
  56. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
  57. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
@@ -1,35 +1,83 @@
1
1
  from __future__ import annotations
2
- from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence
2
+ from typing import List, Tuple, Optional, Type, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence, Callable
3
3
  import itertools, math, functools
4
4
  from collections import defaultdict
5
5
 
6
- from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
7
- from tinygrad.helpers import colored, DEBUG, prod, getenv, to_function_name
6
+ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
7
+ from tinygrad.helpers import colored, DEBUG, dedup, diskcache_put, prod, getenv, to_function_name, flatten
8
8
  from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
9
9
  from tinygrad.shape.shapetracker import ShapeTracker
10
- from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node
10
+ from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node, sint
11
11
  from tinygrad.codegen.kernel import LocalBuffer, Kernel
12
12
  from tinygrad.renderer import Program
13
13
 
14
14
  from tinygrad.codegen.uops import UOps, UOp, UOpGraph
15
15
 
16
- def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
17
- local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate((prod(local_dims[:-(maxdim-1)]),) + local_dims[-(maxdim-1):] if len(local_dims) > maxdim else local_dims)] # noqa: E501
18
- if maxdim != 0 and len(local_dims) > maxdim:
19
- dd = local_idxs[0]
20
- nli = []
21
- for s in local_dims[:-(maxdim-1)]:
22
- nli.append(dd % s)
23
- dd //= s
24
- local_idxs = nli + local_idxs[-(maxdim-1):]
25
- return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]
16
+ def get_grouped_dims(prefix:str, off:int, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse_dims:bool=False):
17
+ """ Maps all global/local dims onto global/local sizes and returns the idxs, loop_idxs and sizes.
18
+
19
+ * If there are fewer dims than size, size will be padded with 1s to the length of max_sizes.
20
+ * If there are more dims than size, dims will be collapsed onto size starting from left-most (i.e. onto x, then y, then z).
21
+ * If the dim is too large for the size, the dim will be split between adjacent size axes space permitting, otherwise assert
22
+
23
+ Keyword arguments:
24
+ prefix -- the prefix to use for the size Variable names.
25
+ off -- the starting index for the size Variable names.
26
+ dims -- the global or local dims of the full shape.
27
+ max_sizes -- the maximum values for each size in (x, y, z) order.
28
+ reverse_dims -- reverse the order of the dims as they are mapped into size, i.e. if True, the right dim will go to the left size (.x).
29
+ """
30
+
31
+ # check the edge cases on max_sizes
32
+ if max_sizes is None: max_sizes = tuple([0xFFFFFFFFFFFFFFFF] * len(dims))
33
+ assert len(max_sizes) > 0 or len(dims) == 0, f"{prefix} dims should be empty because no size axes available"
34
+ if len(max_sizes) == 0: return [], [], None
35
+
36
+ # initialize the map of dims to size with a single dim in each size axis
37
+ # TODO: support sint properly
38
+ size_dims:List[List[Tuple[int, sint, sint]]] = [[(dim_idx, dim, dim if isinstance(dim, int) else dim.max+1)] for dim_idx, dim in enumerate(dims)]
39
+
40
+ # reverse the order of the dims to size map, if desired (currently for globals where smallest stride is on the right)
41
+ # TODO: remove reverse_dims, the mapping of dims to size for globals should be cosearched with memory layouts for optimal peformance
42
+ if reverse_dims: size_dims = size_dims[::-1]
43
+
44
+ # ensure that the initial dims initially fit the valid size axes
45
+ for size_idx in range(min(len(max_sizes), len(size_dims))):
46
+ # if the initial dim is too large, split the dim to separate size axes, if possible
47
+ dim_idx, dim, dim_max = size_dims[size_idx][0]
48
+ if dim_max <= (max_sz:=max_sizes[size_idx]): continue
49
+ assert isinstance(dim, int), "variable shape too large for size"
50
+ for factor in range(2, int(dim**0.5)+1):
51
+ if dim % factor == 0 and dim // factor <= max_sz:
52
+ size_dims = size_dims[:size_idx] + [[(dim_idx, dim//factor, dim//factor)], [(dim_idx, factor, factor)]] + size_dims[size_idx+1:]
53
+ break
54
+ assert size_dims[size_idx][0][2] <= max_sz, f"dim at {size_idx} too large and non-factorable: {dim} > {max_sz}"
55
+
56
+ # compress the extra dims, collapsing them onto the left-most valid size axis
57
+ cur_size_idx = 0
58
+ while len(size_dims) > len(max_sizes):
59
+ if prod([dim_max for (_, _, dim_max) in size_dims[cur_size_idx]])*size_dims[cur_size_idx+1][0][2] <= max_sizes[cur_size_idx]:
60
+ size_dims = size_dims[:cur_size_idx] + [size_dims[cur_size_idx] + size_dims[cur_size_idx+1]] + size_dims[cur_size_idx+2:]
61
+ elif cur_size_idx < len(max_sizes)-1: cur_size_idx += 1
62
+ else: raise AssertionError(f"cannot fit dims in size: {dims=} {max_sizes=}")
63
+
64
+ # construct the final dim idx variables from the the portions of the size variables
65
+ sizes, idxs = [prod([dim for (_, dim, _) in size_dim]) for size_dim in size_dims], [NumNode(0)] * len(dims)
66
+ size_vars = loop_idxs = [Variable(f"{prefix}{len(sizes)-1-(i+off) if reverse_dims else i+off}", 0, s-1) for i,s in enumerate(sizes)]
67
+ for size_idx, size_var in enumerate(size_vars):
68
+ for dim_idx, dim, _ in size_dims[size_idx]:
69
+ idxs[dim_idx] += (size_var % dim) * (idxs[dim_idx].max+1)
70
+ size_var //= dim
71
+
72
+ # pad the final sizes array to the proper length if necessary
73
+ return idxs, [x for x in loop_idxs if not isinstance(x, NumNode)], sizes + [1]*(len(max_sizes)-len(sizes))
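Editor's note: the collapsing rules described in the docstring above can be illustrated with a small sketch. The values below are hypothetical (not taken from the diff) and the call is assumed to run inside this module, so Variable and the usual %// simplifications from tinygrad.shape.symbolic are available.

# Hypothetical example: four dims folded onto three size axes limited to 16 each.
# With reverse_dims=True (as used for globals), the right-most dim maps to size .x.
idxs, loop_idxs, sizes = get_grouped_dims("gidx", 0, (2, 3, 4, 5), (16, 16, 16), reverse_dims=True)
# The 4 and 3 dims are merged onto one size axis (4*3 = 12 <= 16), so sizes comes back
# as [5, 12, 2]; the merged dims are recovered from that axis's Variable via % 4 and // 4
# when building idxs, and sizes would be padded with 1s if fewer dims than axes were given.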
26
74
 
27
75
  def expand_idx(node:Node) -> Union[Variable, NumNode]: return next((v for v in node.vars() if v.expr.startswith("_uidx")), NumNode(0))
28
76
  def expand_idxs(nodes:Sequence[Node]) -> Tuple[Union[Variable, NumNode], ...]:
29
77
  eidxs = [expand_idx(node) for node in nodes]
30
78
  return tuple([v if v not in eidxs[:j] else NumNode(0) for j, v in enumerate(eidxs)]) # take only first occurrence of expand variable
31
79
  def iter_idxs(idxs:Tuple[Union[Variable, NumNode], ...]) -> Iterator[Tuple[int,...]]:
32
- yield from (x[::-1] for x in itertools.product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]]))
80
+ yield from (x[::-1] for x in itertools.product(*[list(range(v.min, v.max + 1)) for v in idxs[::-1]]))
33
81
 
34
82
  def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
35
83
  idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
@@ -44,13 +92,21 @@ def expand_node(node:Node, idxs:Optional[Tuple[Union[Variable, NumNode], ...]]=N
44
92
  if idxs is None: idxs = (expand_idx(node),)
45
93
  return [node.substitute({k:v for k,v in zip(idxs, (NumNode(x) for x in rep)) if isinstance(k, Variable)}) for rep in iter_idxs(idxs)]
46
94
 
47
- class Linearizer(Kernel):
48
- def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op): return UOp.alu(op, a, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
95
+ def variable_to_uop(x, ctx=None) -> UOp:
96
+ if isinstance(x, int): return UOp.const(dtypes.int, x)
97
+ return x.render(render_ops, ctx)
49
98
 
50
- # NOTE: the consts have to be cached for deduping of downstream uops to work
51
- def const(self, b:ConstType, dtype:DType=dtypes.int32) -> UOp:
52
- return self.uops.add(UOps.DEFINE_VAR, dtype, (), b.unbind()[0]) if isinstance(b, Variable) else UOp.const(dtype, b)
99
+ render_ops: Dict[Type, Callable[..., UOp]] = {
100
+ NumNode: lambda self, ops, ctx: UOp.const(dtypes.int, self.b),
101
+ Variable: lambda self, ops, ctx: ctx[self.expr] if self.expr in ctx else UOp(UOps.DEFINE_VAR, dtypes.int, (), self),
102
+ MulNode: lambda self, ops, ctx: self.a.render(ops, ctx)*variable_to_uop(self.b, ctx),
103
+ DivNode: lambda self, ops, ctx: self.a.render(ops, ctx)//variable_to_uop(self.b, ctx),
104
+ ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx),
105
+ LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
106
+ SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+variable_to_uop(b, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
107
+ AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*variable_to_uop(b, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
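Editor's note: the new module-level render_ops table turns symbolic Nodes directly into UOp arithmetic, with ctx being the loop_uops dictionary that maps variable names to already-built UOps. The sketch below is hypothetical (not part of the diff) and assumes it runs inside this module, using the SPECIAL argument form (index, name, size) that appears later in this file.

# Hypothetical: render the symbolic index gidx0*4 + lidx0 against a loop_uops-style ctx.
gidx0, lidx0 = Variable("gidx0", 0, 7), Variable("lidx0", 0, 3)
loop_uops = {"gidx0": UOp(UOps.SPECIAL, dtypes.int32, (), (0, "gidx0", 8)),
             "lidx0": UOp(UOps.SPECIAL, dtypes.int32, (), (1, "lidx0", 4))}
idx_uop = (gidx0*4 + lidx0).render(render_ops, loop_uops)
# The SumNode/MulNode entries fold this into UOp MUL and ADD ops; a Variable whose name
# is not in ctx would instead become a DEFINE_VAR UOp.

This replaces the per-instance render_ops removed further down, which routed every node through uop_alu_idx on the Linearizer.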
53
108
 
109
+ class Linearizer(Kernel):
54
110
  def get_reduce_acc(self, reduceop:LazyOp):
55
111
  if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0
56
112
  if reduceop.op is ReduceOps.MAX:
@@ -60,16 +116,6 @@ class Linearizer(Kernel):
60
116
  # NOTE: once images are loaded, we uop them as their base float
61
117
  def get_base_dtype(self, dt:DType) -> DType: return dt.base if isinstance(dt, ImageDType) else dt
62
118
 
63
- render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
64
- MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
65
- DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV),
66
- ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
67
- LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.CMPLT),
68
- SumNode: lambda self,ops,ctx:
69
- functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)),
70
- AndNode: lambda self,ops,ctx:
71
- functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
72
-
73
119
  def global_load(self, i:int, idxs:List[Node], acc:Optional[LazyOp]=None, barrier:Optional[UOp]=None, loop_ctx:Tuple[UOp, ...]=()) -> List[UOp]:
74
120
  buf = self.bufs[i]
75
121
  localtype = self.get_base_dtype(buf.dtype if acc is None else acc.dtype)
@@ -89,48 +135,47 @@ class Linearizer(Kernel):
89
135
  # todo: multioutput test with different output valids to add if acc is None: g_valid = NumNode(1)
90
136
 
91
137
  if amt > 1: localtype = localtype.vec(amt)
92
- e_idxs, e_valids = expand_node(g_idx, expand_vars), expand_node(g_valid, expand_vars)
138
+ e_idxs, e_valids = expand_node(g_idx, expand_vars), expand_node(g_valid, expand_vars) # pylint: disable=possibly-used-before-assignment
93
139
 
94
140
  ret = []
95
141
  invalid_value = 0
96
142
  acc_count = 0
97
143
  for idx, valid, rep_idx in zip(e_idxs, e_valids, iter_idxs(expand_vars)):
98
144
  this_const, idx, valid = (invalid_value, NumNode(0), NumNode(1)) if valid.max == 0 else (const, idx, valid)
99
- # todo: when multiple reduceops are supported, clearly disambiguate and test acc load keys are unique for each reduceop
100
- key = f"{acc is not None}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
145
+ key = f"{'' if acc is None else self.reduceops.index(acc)}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
101
146
  if key not in self.load_cache:
102
147
  if acc is not None:
103
- self.load_cache[key] = self.uops.add(UOps.DEFINE_ACC, localtype, loop_ctx, (self.get_reduce_acc(acc), i, acc_count))
148
+ self.load_cache[key] = UOp(UOps.DEFINE_ACC, localtype, (UOp.const(localtype.scalar(), self.get_reduce_acc(acc)), *loop_ctx), (i, acc_count))
104
149
  acc_count += 1
105
150
  elif this_const is not None:
106
- self.load_cache[key] = self.const(this_const, localtype)
151
+ self.load_cache[key] = UOp.const(localtype, this_const)
107
152
  if valid.min == 0 and valid.max == 1:
108
- valid_rendered = valid.render(self.render_ops, self)
109
- self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_rendered, self.load_cache[key], self.const(invalid_value, localtype))
153
+ valid_rendered = valid.render(render_ops, self.loop_uops)
154
+ self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_rendered, self.load_cache[key], UOp.const(localtype, invalid_value))
110
155
  elif isinstance(buf.dtype, ImageDType):
111
156
  buf_uop = self.buf_uops[i]
112
157
  assert buf_uop is not None, f"buffer {i} wasn't UOped"
113
158
  image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
114
- rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
115
- valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, buf.dtype.base.vec(4))) if valid.min == 0 else tuple()
116
- self.load_cache[key] = self.uops.add(UOps.LOAD, buf.dtype.base.vec(4),
159
+ rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
160
+ valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
161
+ self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4),
117
162
  (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
118
163
  if localtype == localtype.scalar():
119
164
  idx_small = idx%4
120
- res = idx_small.render(self.render_ops, self)
121
- out = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
165
+ res = idx_small.render(render_ops, self.loop_uops)
166
+ out = UOp(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
122
167
  for ix in range(idx_small.max, idx_small.min, -1):
123
- rvv = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
124
- sel = UOp.alu(BinaryOps.CMPLT, res, self.const(ix))
168
+ rvv = UOp(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
169
+ sel = UOp.alu(BinaryOps.CMPLT, res, UOp.const(dtypes.int, ix))
125
170
  out = UOp.alu(TernaryOps.WHERE, sel, rvv, out)
126
171
  self.load_cache[key] = out
127
172
  else:
128
173
  buf_uop = self.buf_uops[i]
129
174
  assert buf_uop is not None, f"buffer {i} wasn't UOped"
130
- rendered_idx = idx.render(self.render_ops, self)
131
- valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, localtype)) if valid.min == 0 else tuple()
132
- self.load_cache[key] = self.uops.add(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
133
- ret.append(self.uops.add(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
175
+ rendered_idx = idx.render(render_ops, self.loop_uops)
176
+ valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple()
177
+ self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
178
+ ret.append(UOp(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
134
179
  return ret
135
180
 
136
181
  def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]:
@@ -153,7 +198,7 @@ class Linearizer(Kernel):
153
198
  amt = len(grouped)
154
199
  idx, valid = self.sts[i].expr_idxs(k)
155
200
  assert idx == ((idx//amt)*amt), "float4 stores are always aligned"
156
- store_offset_new[k] = self.uops.add(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
201
+ store_offset_new[k] = UOp(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
157
202
  store_offset = store_offset_new
158
203
 
159
204
  stores = []
@@ -161,29 +206,24 @@ class Linearizer(Kernel):
161
206
  idx, valid = self.sts[i].expr_idxs(_idx)
162
207
  if isinstance(buf.dtype, ImageDType):
163
208
  image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
164
- rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), \
165
- tuple(x.render(self.render_ops, self) for x in image_idx))
209
+ rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), \
210
+ tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
166
211
  else:
167
- rendered_idx = idx.render(self.render_ops, self)
168
- if valid.min == 1: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var)))
169
- else: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
212
+ rendered_idx = idx.render(render_ops, self.loop_uops)
213
+ # TODO: let UPat check this once it's fast
214
+ if valid.min == 1: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var)))
215
+ else: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(render_ops, self.loop_uops))))
170
216
  return stores
171
217
 
172
218
  # render loop
173
- def render_loop(self, xx:List[Variable], depth:int) -> Tuple[UOp, ...]:
174
- new_loops = {x.expr:self.uops.add(UOps.RANGE, dtypes.int32, (
175
- self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
176
- self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), arg=(depth,i)) for i,x in enumerate(xx) if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
219
+ def render_loop(self, xx:List[Variable], depth:int, reduce:bool) -> Tuple[UOp, ...]:
220
+ new_loops = {x.expr:UOp(UOps.RANGE, dtypes.int32, (
221
+ UOp.const(dtypes.int, x.min) if isinstance(x.min, int) else cast(Node, x.min).render(render_ops, self.loop_uops),
222
+ UOp.const(dtypes.int, x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(render_ops, self.loop_uops)), arg=(depth,i,reduce)) for i,x in enumerate(xx) if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
177
223
  self.loop_uops.update(new_loops)
178
224
  return tuple(new_loops.values())
179
225
 
180
- def render_reduceop(self, reduceop:LazyOp, accs:Dict[LazyOp, List[UOp]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]],
181
- global_idxs, local_idxs, upcast_idxs):
182
- # define indicies
183
- full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
184
- reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
185
- fake_reduce_idxs = [x*0 for x in reduce_idxs]
186
-
226
+ def index_local_aliases(self, global_idxs, local_idxs, reduce_idxs, upcast_idxs, full_upcast_idxs):
187
227
  def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
188
228
  replace_idxs, thread_idxs, thread_idx = [], [], Variable("_uidx_tc", 0, prod(local_sizes)-1)
189
229
  for s in local_sizes:
@@ -199,33 +239,39 @@ class Linearizer(Kernel):
199
239
  replace_idxs.append(full_var)
200
240
  return replace_idxs
201
241
 
202
- # compute local aliases - modify idxs if necessary for TC
203
- alias_buf_idxs = []
204
- for i in self.local_alias:
205
- localbuf_idx = self.bufs.index(self.local_alias[i])
206
- buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
207
- if (tc:=self.tensor_core):
208
- min_alias_idx = min(self.local_alias.keys())
209
- replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
210
- for n in range(len(tc.threads)):
211
- buf_idxs[self.global_dims+n] = replace_input_idxs[n] # replace locals
212
- for n in range(tc.num_upcasts()):
213
- buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
214
- if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
215
- alias_buf_idxs.append((i, localbuf_idx, buf_idxs,))
216
-
217
- # reduce loop
218
- loop_ctx = self.render_loop(reduce_idxs, 2)
219
-
220
- # define accumulator - modify idxs if necessary for TC
221
- out_buf = -1 if self.group_for_reduces else 0
242
+ # compute local aliases
243
+ alias_buf_idxs: DefaultDict[LazyOp, List[Tuple[int, int, List]]] = defaultdict(list)
244
+ for op, local_alias in self.local_alias.items():
245
+ for i in local_alias:
246
+ localbuf_idx = self.bufs.index(local_alias[i])
247
+ buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
248
+ if (tc:=self.tensor_core):
249
+ min_alias_idx = min(local_alias.keys())
250
+ replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
251
+ for n in range(len(tc.threads)):
252
+ buf_idxs[self.global_dims+n] = replace_input_idxs[n] # replace locals
253
+ for n in range(tc.num_upcasts()):
254
+ buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
255
+ if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
256
+ alias_buf_idxs[op].append((i, localbuf_idx, buf_idxs))
257
+ # modify idxs if necessary for TC
222
258
  if (tc:=self.tensor_core):
223
259
  replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
224
260
  for n in range(len(tc.threads)):
225
261
  local_idxs[n] = replace_acc_idxs[n] # replace locals
226
262
  for n in range(len(replace_acc_idxs)-len(tc.threads)):
227
263
  upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
228
- if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs}")
264
+ if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+upcast_idxs}")
265
+ return alias_buf_idxs
266
+
267
+ def render_reduceop(self, reduceop:LazyOp, accs:Dict[LazyOp, List[UOp]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]],
268
+ global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, reduce_idxs, fake_reduce_idxs,
269
+ alias_buf_idxs:List[Tuple[int, int, List]]) -> Tuple[List[NumNode|Variable], List[NumNode|Variable]]:
270
+ # reduce loop
271
+ loop_ctx = self.render_loop(reduce_idxs, (i:=self.reduceops.index(reduceop))*2+2, True)
272
+
273
+ # define accumulator - modify idxs if necessary for TC
274
+ out_buf = -len(self.reduceops)+i if self.group_for_reduces else 0
229
275
  accs[reduceop] = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
230
276
 
231
277
  # store local aliases
@@ -235,34 +281,39 @@ class Linearizer(Kernel):
235
281
  # run tensor cores AST
236
282
  wmma_sz = [prod(l) for l in tc.thread_local_sizes]
237
283
  def upcast_strides(buf:int):
238
- strides, next = [], 1
239
- for (sz, stride, reduce) in self.upcasted_axis(buf)[tc.num_upcasts():]:
240
- strides.append((0 if stride == 0 else next, sz))
241
- next *= 1 if stride == 0 else sz
284
+ strides, next_ = [], 1
285
+ for (sz, stride, _) in self.upcasted_axis(buf)[tc.num_upcasts():]:
286
+ strides.append((0 if stride == 0 else next_, sz))
287
+ next_ *= 1 if stride == 0 else sz
242
288
  return strides
243
289
  upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
244
290
  # cast initial accs
245
- wmmas = [self.uops.add(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]]))
291
+ wmmas = [UOp(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]]))
246
292
  for x in range(0, len(accs[reduceop]), wmma_sz[2])]
247
- for iter in [x[::-1] for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]:
248
- offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(iter, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
249
- ops = (self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
250
- self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
293
+ for it in [x[::-1] for x in itertools.product(*list([range(sz) for _,sz in upcasts[0]][::-1]))]:
294
+ offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(it, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
295
+ ops = (UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
296
+ UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
251
297
  wmmas[(wmma_idx:=offs[2]//wmma_sz[2])])
252
298
  # TODO: don't need to DEFINE_ACC, pass to WMMA in op3, or PHI accs that are not valid
253
- wmmas[wmma_idx] = self.uops.add(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), dev))
299
+ wmmas[wmma_idx] = UOp(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), dev))
254
300
  # phi the last wmmas back to accs
255
- accs[reduceop] = [self.uops.add(UOps.PHI, tc.dtype_out, (acc, self.uops.add(UOps.GEP, tc.dtype_out, (wmmas[z//wmma_sz[2]],), z%wmma_sz[2])))
301
+ accs[reduceop] = [UOp(UOps.PHI, tc.dtype_out, (acc, UOp(UOps.GEP, tc.dtype_out, (wmmas[z//wmma_sz[2]],), z%wmma_sz[2])))
256
302
  for z, acc in enumerate(accs[reduceop])]
257
303
  else:
258
304
  assert not locals_to_store, "storing locals isn't supported here"
259
305
 
260
306
  # load earlybufs
261
- loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i,
307
+ loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[reduceop][i]) if i in self.local_alias else i,
262
308
  global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs})
263
309
 
310
+ def gate_acc(r, idxs): return [
311
+ UOp.alu(TernaryOps.WHERE, valid.render(render_ops, self.loop_uops), acc, UOp.const(r.dtype, 0)) if valid.min == 0 and valid.max == 1 else acc
312
+ for valid, acc in zip(expand_node(self.sts[self.full_buf_index].expr_idxs(idxs)[1], expand_idxs(idxs)), accs[r])]
313
+ local_accs = {r: gate_acc(r,global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for r in accs}
314
+
264
315
  # run early AST (with reduce)
265
- self.ast_parse(reduceop, accs, self.acc_offsets(self.full_buf_index), loaded_buffers, reduce_acc=accs[reduceop])
316
+ self.ast_parse(reduceop, local_accs, self.acc_offsets(self.full_buf_index), loaded_buffers, reduce_acc=accs[reduceop])
266
317
 
267
318
  # end the reduce loop
268
319
  self.load_cache.clear()
@@ -270,13 +321,13 @@ class Linearizer(Kernel):
270
321
  # end the local loop, do the local reduce
271
322
  if self.group_for_reduces:
272
323
  fake_global_idxs = [x*0 for x in global_idxs]
273
- stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop]) # store accumulators
274
- barrier = self.uops.add(UOps.BARRIER, None, tuple(stores))
324
+ stores = self.global_store(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop]) # store accumulators
325
+ barrier = UOp(UOps.BARRIER, None, tuple(stores))
275
326
  if self.opts.has_local:
276
327
  fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
277
328
  fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
278
- if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(self.render_ops, self)
279
- barrier = self.uops.add(UOps.IF, None, (if_cond, barrier))
329
+ if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(render_ops, self.loop_uops)
330
+ barrier = UOp(UOps.IF, None, (if_cond, barrier))
280
331
 
281
332
  # create new late reduce local loops and replace local_idxs that have been used
282
333
  end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)] # noqa: E501
@@ -295,71 +346,68 @@ class Linearizer(Kernel):
295
346
  # NOTE: this structure is the same as the reduce op above
296
347
 
297
348
  # late reduce loop
298
- loop_ctx = self.render_loop(end_local_idxs, 3)
349
+ loop_ctx = self.render_loop(end_local_idxs, i*2+3, True)
299
350
 
300
351
  # define late accumulator
301
352
  accs[reduceop] = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
302
353
 
303
354
  # load localbufs
304
- loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
355
+ loaded_buffers[self.bufs[out_buf]] = self.global_load(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
305
356
 
306
357
  # there's no AST here (and there's no shape for the reduce LazyOp)
307
- self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)),\
358
+ self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[out_buf]),)),\
308
359
  accs, self.acc_offsets(-1), loaded_buffers, reduce_acc=accs[reduceop])
309
360
 
310
361
  # end the late reduce loop
311
362
  self.load_cache.clear()
312
363
 
313
- # all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have
314
- # been rewritten with fake end_local_idxs.
315
- return (accs, loaded_buffers, fake_reduce_idxs, local_idxs[:self.local_dims] + [NumNode(0) for i in range(self.group_for_reduces)], upcast_idxs)
364
+ if reduceop is not self.reduceops[-1]:
365
+ for j in self.upcast_in_mid_reduce_axes:
366
+ self.upcasted -= 1
367
+ self.group_for_reduces += 1
368
+ assert self.buf_uops[out_buf] is not None, "Local reduce buf must have been uoped at this point"
369
+ fake_local_idxs = local_idxs[:self.local_dims] + [x*0 for x in local_idxs[self.local_dims:]]
370
+ stores = self.global_store(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop])
371
+ barrier = UOp(UOps.BARRIER, None, tuple(stores))
372
+ accs[reduceop] = self.global_load(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
373
+ return local_idxs[:self.local_dims] + [NumNode(0) for _ in range(self.group_for_reduces)], upcast_idxs
316
374
 
317
375
  kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
318
- def linearize(self):
376
+ def linearize(self) -> Linearizer:
319
377
  # no new opts and we already ran? skip relinearizing
320
378
  if self.applied_opts == self.applied_opts_cache: return self
321
379
 
322
380
  # late alias the tensor core buffers
323
- if (tc:=self.tensor_core) and (tc_opts:=self.tensor_core_opts):
381
+ if (tc:=self.tensor_core) and self.tensor_core_opts is not None:
324
382
  alias_pattern = [0]*(self.global_dims) + [2]*(len(tc.threads)) + [0]*(self.local_dims-len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2) # noqa: E501
325
- for tc_buf in tc_opts.bufs:
326
- self.alias_buffer(tc_buf, alias_pattern)
383
+ for op, tc_bufs in self.bufs_for_tensor_core.items():
384
+ for tc_buf in tc_bufs: self.alias_buffer(op, tc_buf, alias_pattern)
327
385
 
328
386
  # save backups
329
387
  sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduces, self.upcasted
330
388
 
331
- # global uop cache
332
- self.saved_exprs: Dict[Tuple, UOp] = dict()
333
-
334
- # limit dims if we need to
335
- if self.opts.global_max and self.opts.local_max: self.limit_dims_to_max(self.opts.global_max, self.opts.local_max)
336
-
337
389
  # uops
338
- self.uops:UOpGraph = UOpGraph()
339
390
  self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs)
340
391
  self.loop_uops: Dict[str, UOp] = {}
341
392
 
342
393
  # add global buffers
343
394
  for i,buf in enumerate(self.bufs):
344
395
  if isinstance(buf, MemBuffer):
345
- self.buf_uops[i] = self.uops.add(UOps.DEFINE_GLOBAL,
396
+ self.buf_uops[i] = UOp(UOps.DEFINE_GLOBAL,
346
397
  buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
347
398
  (buf.idx, any(buf.idx == x.idx for x in self.outbufs)))
348
- # add var vals
349
- for i,var in enumerate(self.vars):
350
- assert var.expr is not None
351
- self.loop_uops[var.expr] = self.uops.add(UOps.DEFINE_VAR, dtypes.int32, (), var)
352
399
  # define local buffers
353
- for lb in self.local_alias.values():
354
- self.buf_uops[self.bufs.index(lb)] = self.uops.add(UOps.DEFINE_LOCAL,
355
- PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size))
400
+ for aliases in self.local_alias.values():
401
+ for lb in aliases.values(): self.buf_uops[self.bufs.index(lb)] = UOp(UOps.DEFINE_LOCAL, PtrDType(lb.dtype),
402
+ (), (lb.name, self.sts[self.bufs.index(lb)].size))
356
403
  # add a local buffer for multistage reduce. # TODO: use local alias
357
404
  if self.group_for_reduces:
358
- # TODO: the strides of this can be controlled
359
- self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501
360
- temp_dtype = self.get_base_dtype(cast(LazyOp, self.reduceop).dtype)
361
- self.bufs.append(LocalBuffer("temp", self.sts[-1].size, temp_dtype))
362
- self.buf_uops.append(self.uops.add(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size)))
405
+ for i in range(len(self.reduceops)):
406
+ # TODO: the strides of this can be controlled
407
+ self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501
408
+ temp_dtype = self.get_base_dtype(cast(LazyOp, self.reduceop).dtype)
409
+ self.bufs.append(LocalBuffer(name:=f"temp{i if len(self.reduceops) > 1 else ''}", buf_size:=self.sts[-1].size, temp_dtype))
410
+ self.buf_uops.append(UOp(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), (name, buf_size)))
363
411
 
364
412
  # kernel name (before late upcast)
365
413
  self.name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \
@@ -372,44 +420,39 @@ class Linearizer(Kernel):
372
420
  self.name = self.name+colored(suffix, 'BLACK')
373
421
 
374
422
  # define indexes
375
- global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
376
- local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+self.group_for_reduces], 3 if self.opts.has_local else 0) # noqa: E501
423
+ gl_dims = self.full_shape[:self.first_reduce+self.group_for_reduces]
424
+ global_idxs, loop_global_idxs, self.global_size = get_grouped_dims("idx" if self.dont_use_locals else "gidx", 0, gl_dims[:self.global_dims],
425
+ self.opts.global_max, self.opts.has_local)
426
+ local_idxs, loop_local_idxs, self.local_size = get_grouped_dims("lidx", self.global_dims, gl_dims[self.global_dims:],
427
+ self.opts.local_max if self.opts.has_local else (), False)
377
428
  upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
429
+ full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
378
430
 
379
- # set global/local size
380
- self.global_size: Optional[List[int]] = None
381
- self.local_size: Optional[List[int]] = None
382
- if self.dont_use_locals:
383
- self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
384
- self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
385
- elif self.opts.has_local:
386
- self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs]
387
- self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
388
- self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
431
+ # render global and local as specials or a loop
432
+ if self.opts.has_local:
433
+ self.loop_uops.update({x.expr:UOp(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
434
+ if not self.dont_use_locals:
435
+ self.loop_uops.update({x.expr:UOp(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
389
436
  else:
390
- self.render_loop(loop_global_idxs+loop_local_idxs, 1)
391
- if self.global_size is not None: self.global_size += [1]*(3-len(self.global_size))
392
- if self.local_size is not None: self.local_size += [1]*(3-len(self.local_size))
437
+ self.global_size, self.local_size = None, None
438
+ self.render_loop(loop_global_idxs+loop_local_idxs, 1, False)
439
+
440
+ # define idxs for aliased buffers TODO: this doesn't belong in Kernel, but it can't exist in Block either (because of multireduce tensor cores)
441
+ reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
442
+ alias_buf_idxs = self.index_local_aliases(global_idxs,local_idxs,reduce_idxs,upcast_idxs,full_upcast_idxs)
393
443
 
394
444
  # parse AST
445
+ self.load_cache: Dict[str, UOp] = {}
395
446
  loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]] = {}
396
447
  accs: Dict[LazyOp, List[UOp]] = {}
397
- self.load_cache: Dict[str, UOp] = {}
398
448
 
399
- # reduce op
400
- fake_reduce_idxs: List[Variable] = []
401
- for reduceop in [self.reduceop] if self.reduceop is not None else []:
402
- accs,loaded_buffers,fake_reduce_idxs,local_idxs,upcast_idxs = \
403
- self.render_reduceop(reduceop,accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs)
449
+ # render reduceops by depth
450
+ for reduceop in self.reduceops:
451
+ self.render_block((reduceop, ), global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)
452
+ stores = self.render_block(self.ast, global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)
404
453
 
405
- # load latebufs
406
- loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
407
- for i,b in enumerate(self.bufs) if b not in self.earlybufs and b.__class__ is not LocalBuffer})
408
-
409
- # run late AST (without the store)
410
- for op in self.ast:
411
- val = self.ast_parse(op.src[0], accs, None, loaded_buffers)
412
- self.global_store(op.arg.idx, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
454
+ # only the final stores are needed to define the full UOps graph
455
+ self.uops:UOpGraph = UOpGraph(flatten(stores))
413
456
 
414
457
  # maybe graph the uops
415
458
  if DEBUG >= 5: self.uops.print()
@@ -422,16 +465,40 @@ class Linearizer(Kernel):
422
465
  self.applied_opts_cache = self.applied_opts[:]
423
466
  return self
424
467
 
468
+ def render_block(self, outputs:Tuple[LazyOp, ...], global_idxs, local_idxs, upcast_idxs, full_upcast_idxs,
469
+ alias_buf_idxs:DefaultDict[LazyOp,List[Tuple[int,int,List[NumNode|Variable]]]],
470
+ loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], accs:Dict[LazyOp,List[UOp]]) -> List[List[UOp]]:
471
+ reduceops = dedup(x for x in outputs if x.op in ReduceOps)
472
+ assert len(reduceops) <= 1, "max one reduceop per block"
473
+ reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
474
+ fake_reduce_idxs = [x*0 for x in reduce_idxs]
475
+
476
+ if len(reduceops) != 0:
477
+ # TODO: delete render_reduceop and move the logic for group_for_reduces to Block
478
+ nlidx, nuidx = self.render_reduceop((r:=reduceops[0]),accs,loaded_buffers,\
479
+ global_idxs,local_idxs,upcast_idxs,full_upcast_idxs,reduce_idxs,fake_reduce_idxs,alias_buf_idxs[r])
480
+
481
+ # all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have
482
+ # been rewritten with fake end_local_idxs.
483
+ if r is self.reduceops[-1]: local_idxs[:], upcast_idxs[:] = nlidx, nuidx
484
+ return [accs[r]]
485
+
486
+ # load latebufs
487
+ loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
488
+ for i,b in enumerate(self.bufs) if b not in self.earlybufs and b.__class__ is not LocalBuffer})
489
+ # run late AST (without the store)
490
+ store_vals = {op.arg.idx:self.ast_parse(op.src[0], accs, None, loaded_buffers) for op in self.ast}
491
+ return [self.global_store(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val) for i, val in store_vals.items()]
492
+
425
493
  def ast_parse(self, x:LazyOp, accs:Dict[LazyOp, List[UOp]], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], reduce_acc:Optional[List[UOp]]=None, cache=None) -> List[UOp]: # noqa: E501
426
494
  if cache is None: cache = {}
427
495
  if x in cache: return cache[x]
428
496
  if x.op in BufferOps: return loaded_buffers[x.arg]
429
497
  if x.op in [UnaryOps.CAST, UnaryOps.BITCAST]:
430
- return [self.uops.add(UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST,
498
+ return [UOp(UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST,
431
499
  self.get_base_dtype(x.arg), (u,)) for u in self.ast_parse(x.src[0], accs, offs, loaded_buffers)]
432
500
  if x.op in ReduceOps and reduce_acc is None:
433
- assert offs is None, "not available if we aren't doing reduce"
434
- return accs[x]
501
+ return [accs[x][i] for i in offs] if offs else accs[x]
435
502
 
436
503
  values = [self.ast_parse(v, accs, offs, loaded_buffers, cache=cache) for v in x.src]
437
504
  ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}
@@ -444,17 +511,18 @@ class Linearizer(Kernel):
444
511
  ret.append(acc[off])
445
512
  for off in range(len(acc)):
446
513
  if input_acc[off] != acc[off]:
447
- acc[off] = self.uops.add(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]))
448
- else: ret = [UOp.alu(x.op, *vin) for vin in zip(*values)]
514
+ acc[off] = UOp(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]))
515
+ else: ret = [UOp.alu(x.op, *src) for src in zip(*values)]
449
516
  cache[x] = ret
450
517
  return ret
451
518
 
452
519
  def to_program(self) -> Program:
453
520
  self.linearize()
454
521
  info = get_lazyop_info(self.ast[0])
455
- src = self.opts.render(to_function_name(self.name), self.uops)
522
+ src = self.opts.render(name:=to_function_name(self.name), self.uops)
523
+ if getenv("RUN_PROCESS_REPLAY"): diskcache_put("process_replay", id(self), (self.ast, self.opts, self.applied_opts, name, src))
456
524
  ops, mem = self.uops.flops_mem()
457
- run_count = prod((self.global_size if self.global_size else []) + (self.local_size if self.local_size else []))
525
+ run_count = prod((self.global_size or []) + (self.local_size or []))
458
526
  # NOTE: we use min here to ignore the indexing FLOPS
459
527
  return Program(self.name, src, self.opts.device, self.global_size, self.local_size,
460
528
  self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))