tinygrad 0.8.0-py3-none-any.whl → 0.9.1-py3-none-any.whl

This diff is provided for informational purposes only; it reflects the changes between publicly released package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/__init__.py +0 -0
  3. tinygrad/codegen/kernel.py +253 -225
  4. tinygrad/codegen/linearizer.py +398 -436
  5. tinygrad/codegen/uops.py +451 -0
  6. tinygrad/device.py +268 -274
  7. tinygrad/dtype.py +56 -40
  8. tinygrad/engine/__init__.py +0 -0
  9. tinygrad/engine/graph.py +100 -0
  10. tinygrad/engine/jit.py +198 -0
  11. tinygrad/engine/realize.py +192 -0
  12. tinygrad/engine/schedule.py +370 -0
  13. tinygrad/engine/search.py +199 -0
  14. tinygrad/{mlops.py → function.py} +40 -32
  15. tinygrad/helpers.py +144 -46
  16. tinygrad/lazy.py +143 -242
  17. tinygrad/multi.py +173 -0
  18. tinygrad/nn/__init__.py +180 -9
  19. tinygrad/nn/datasets.py +8 -0
  20. tinygrad/nn/optim.py +106 -28
  21. tinygrad/nn/state.py +87 -19
  22. tinygrad/ops.py +104 -45
  23. tinygrad/renderer/__init__.py +65 -0
  24. tinygrad/renderer/assembly.py +269 -0
  25. tinygrad/renderer/cstyle.py +308 -210
  26. tinygrad/renderer/llvmir.py +119 -124
  27. tinygrad/runtime/__init__.py +0 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +13403 -0
  29. tinygrad/runtime/autogen/comgr.py +891 -0
  30. tinygrad/runtime/autogen/cuda.py +5923 -0
  31. tinygrad/runtime/autogen/hip.py +5909 -0
  32. tinygrad/runtime/autogen/hsa.py +5893 -0
  33. tinygrad/runtime/autogen/io_uring.py +1486 -0
  34. tinygrad/runtime/autogen/kfd.py +812 -0
  35. tinygrad/runtime/autogen/nv_gpu.py +33597 -0
  36. tinygrad/runtime/autogen/opencl.py +1795 -0
  37. tinygrad/runtime/driver/__init__.py +0 -0
  38. tinygrad/runtime/driver/hip_comgr.py +56 -0
  39. tinygrad/runtime/graph/__init__.py +0 -0
  40. tinygrad/runtime/graph/clang.py +39 -0
  41. tinygrad/runtime/graph/cuda.py +59 -54
  42. tinygrad/runtime/graph/hcq.py +187 -0
  43. tinygrad/runtime/graph/metal.py +37 -41
  44. tinygrad/runtime/ops_amd.py +550 -0
  45. tinygrad/runtime/ops_clang.py +16 -14
  46. tinygrad/runtime/ops_cuda.py +129 -37
  47. tinygrad/runtime/ops_disk.py +111 -43
  48. tinygrad/runtime/ops_gpu.py +52 -50
  49. tinygrad/runtime/ops_llvm.py +36 -56
  50. tinygrad/runtime/ops_metal.py +41 -24
  51. tinygrad/runtime/ops_npy.py +9 -0
  52. tinygrad/runtime/ops_nv.py +625 -0
  53. tinygrad/runtime/ops_python.py +208 -0
  54. tinygrad/shape/__init__.py +0 -0
  55. tinygrad/shape/shapetracker.py +46 -107
  56. tinygrad/shape/symbolic.py +99 -98
  57. tinygrad/shape/view.py +162 -45
  58. tinygrad/tensor.py +2492 -483
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
  60. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
  61. tinygrad-0.9.1.dist-info/RECORD +63 -0
  62. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  63. tinygrad/features/image.py +0 -93
  64. tinygrad/features/multi.py +0 -103
  65. tinygrad/features/search.py +0 -160
  66. tinygrad/graph.py +0 -106
  67. tinygrad/jit.py +0 -152
  68. tinygrad/realize.py +0 -50
  69. tinygrad/runtime/graph/hip.py +0 -24
  70. tinygrad/runtime/ops_cpu.py +0 -45
  71. tinygrad/runtime/ops_hip.py +0 -97
  72. tinygrad/runtime/ops_torch.py +0 -49
  73. tinygrad-0.8.0.dist-info/RECORD +0 -41
  74. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
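
The file list above reflects a reorganization of the package's internals: `mlops.py` is renamed to `function.py`, the old top-level `jit.py`, `realize.py`, and `graph.py` and the `features/` modules are removed, and a new `tinygrad/engine/` package (plus `tinygrad/multi.py`) takes their place. The sketch below shows what that move likely means for code that imported these internal modules directly; the new dotted paths are inferred from the file renames above and are an assumption, while the top-level exports are the stable surface in both versions.

```python
# Sketch only: 0.9.1 module paths inferred from the file moves listed above.
#   tinygrad.mlops            ->  tinygrad.function
#   tinygrad.jit              ->  tinygrad.engine.jit
#   tinygrad.realize          ->  tinygrad.engine.realize
#   tinygrad.graph            ->  tinygrad.engine.graph
#   tinygrad.features.search  ->  tinygrad.engine.search
#   tinygrad.features.multi   ->  tinygrad.multi
# Downstream code is safest sticking to the package's top-level exports:
from tinygrad import Tensor, dtypes

x = Tensor.ones(4, dtype=dtypes.float32)
print((x * 2).numpy())  # [2. 2. 2. 2.]
```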
@@ -1,48 +1,89 @@
1
1
  from __future__ import annotations
2
- from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Sequence, Final, Set, Iterator
2
+ from typing import List, Tuple, Optional, Type, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence, Callable
3
3
  import itertools, math, functools
4
4
  from collections import defaultdict
5
- from enum import Enum, auto
6
- from dataclasses import dataclass
7
5
 
8
6
  from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
9
- from tinygrad.helpers import colored, DEBUG, prod, getenv, all_same, to_function_name, flatten
7
+ from tinygrad.helpers import colored, DEBUG, dedup, diskcache_put, prod, getenv, to_function_name, flatten
10
8
  from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
11
9
  from tinygrad.shape.shapetracker import ShapeTracker
12
- from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
10
+ from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node, sint
13
11
  from tinygrad.codegen.kernel import LocalBuffer, Kernel
14
- from tinygrad.features.image import to_image_idx
15
-
16
- # bottom ones are asm only
17
- class UOps(Enum):
18
- LOOP = auto(); IF = auto(); END = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702
19
- DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702
20
- LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702
21
- ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
22
-
23
- @dataclass(eq=False)
24
- class UOp:
25
- uop: UOps
26
- dtype: Optional[DType]
27
- vin: Tuple[UOp, ...]
28
- arg: Any
29
- def __repr__(self):
30
- return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
31
-
32
- def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
33
- local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate(local_dims[0:maxdim-1] + (prod(local_dims[maxdim-1:]),) if len(local_dims) > maxdim else local_dims)] # noqa: E501
34
- if maxdim != 0 and len(local_dims) > maxdim:
35
- dd = local_idxs[maxdim-1]
36
- nli = []
37
- for s in local_dims[maxdim-1:][::-1]:
38
- nli.append(dd % s)
39
- dd //= s
40
- local_idxs = local_idxs[0:maxdim-1] + nli[::-1]
41
- return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]
42
-
43
- def expand_idx(node:Node) -> Union[Variable, NumNode]: return next((v for v in node.vars() if v.expr is None), NumNode(0))
12
+ from tinygrad.renderer import Program
13
+
14
+ from tinygrad.codegen.uops import UOps, UOp, UOpGraph
15
+
16
+ def get_grouped_dims(prefix:str, off:int, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse_dims:bool=False):
17
+ """ Maps all global/local dims onto global/local sizes and returns the idxs, loop_idxs and sizes.
18
+
19
+ * If there are fewer dims than size, size will be padded with 1s to the length of max_sizes.
20
+ * If there are more dims than size, dims will be collapsed onto size starting from left-most (i.e. onto x, then y, then z).
21
+ * If the dim is too large for the size, the dim will be split between adjacent size axes space permitting, otherwise assert
22
+
23
+ Keyword arguments:
24
+ prefix -- the prefix to use for the size Variable names.
25
+ off -- the starting index for the size Variable names.
26
+ dims -- the global or local dims of the full shape.
27
+ max_sizes -- the maximum values for each size in (x, y, z) order.
28
+ reverse_dims -- reverse the order of the dims as they are mapped into size, i.e. if True, the right dim will go to the left size (.x).
29
+ """
30
+
31
+ # check the edge cases on max_sizes
32
+ if max_sizes is None: max_sizes = tuple([0xFFFFFFFFFFFFFFFF] * len(dims))
33
+ assert len(max_sizes) > 0 or len(dims) == 0, f"{prefix} dims should be empty because no size axes available"
34
+ if len(max_sizes) == 0: return [], [], None
35
+
36
+ # initialize the map of dims to size with a single dim in each size axis
37
+ # TODO: support sint properly
38
+ size_dims:List[List[Tuple[int, sint, sint]]] = [[(dim_idx, dim, dim if isinstance(dim, int) else dim.max+1)] for dim_idx, dim in enumerate(dims)]
39
+
40
+ # reverse the order of the dims to size map, if desired (currently for globals where smallest stride is on the right)
41
+ # TODO: remove reverse_dims, the mapping of dims to size for globals should be cosearched with memory layouts for optimal peformance
42
+ if reverse_dims: size_dims = size_dims[::-1]
43
+
44
+ # ensure that the initial dims initially fit the valid size axes
45
+ for size_idx in range(min(len(max_sizes), len(size_dims))):
46
+ # if the initial dim is too large, split the dim to separate size axes, if possible
47
+ dim_idx, dim, dim_max = size_dims[size_idx][0]
48
+ if dim_max <= (max_sz:=max_sizes[size_idx]): continue
49
+ assert isinstance(dim, int), "variable shape too large for size"
50
+ for factor in range(2, int(dim**0.5)+1):
51
+ if dim % factor == 0 and dim // factor <= max_sz:
52
+ size_dims = size_dims[:size_idx] + [[(dim_idx, dim//factor, dim//factor)], [(dim_idx, factor, factor)]] + size_dims[size_idx+1:]
53
+ break
54
+ assert size_dims[size_idx][0][2] <= max_sz, f"dim at {size_idx} too large and non-factorable: {dim} > {max_sz}"
55
+
56
+ # compress the extra dims, collapsing them onto the left-most valid size axis
57
+ cur_size_idx = 0
58
+ while len(size_dims) > len(max_sizes):
59
+ if prod([dim_max for (_, _, dim_max) in size_dims[cur_size_idx]])*size_dims[cur_size_idx+1][0][2] <= max_sizes[cur_size_idx]:
60
+ size_dims = size_dims[:cur_size_idx] + [size_dims[cur_size_idx] + size_dims[cur_size_idx+1]] + size_dims[cur_size_idx+2:]
61
+ elif cur_size_idx < len(max_sizes)-1: cur_size_idx += 1
62
+ else: raise AssertionError(f"cannot fit dims in size: {dims=} {max_sizes=}")
63
+
64
+ # construct the final dim idx variables from the the portions of the size variables
65
+ sizes, idxs = [prod([dim for (_, dim, _) in size_dim]) for size_dim in size_dims], [NumNode(0)] * len(dims)
66
+ size_vars = loop_idxs = [Variable(f"{prefix}{len(sizes)-1-(i+off) if reverse_dims else i+off}", 0, s-1) for i,s in enumerate(sizes)]
67
+ for size_idx, size_var in enumerate(size_vars):
68
+ for dim_idx, dim, _ in size_dims[size_idx]:
69
+ idxs[dim_idx] += (size_var % dim) * (idxs[dim_idx].max+1)
70
+ size_var //= dim
71
+
72
+ # pad the final sizes array to the proper length if necessary
73
+ return idxs, [x for x in loop_idxs if not isinstance(x, NumNode)], sizes + [1]*(len(max_sizes)-len(sizes))
74
+
75
+ def expand_idx(node:Node) -> Union[Variable, NumNode]: return next((v for v in node.vars() if v.expr.startswith("_uidx")), NumNode(0))
76
+ def expand_idxs(nodes:Sequence[Node]) -> Tuple[Union[Variable, NumNode], ...]:
77
+ eidxs = [expand_idx(node) for node in nodes]
78
+ return tuple([v if v not in eidxs[:j] else NumNode(0) for j, v in enumerate(eidxs)]) # take only first occurrence of expand variable
44
79
  def iter_idxs(idxs:Tuple[Union[Variable, NumNode], ...]) -> Iterator[Tuple[int,...]]:
45
- yield from (x[::-1] for x in itertools.product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]]))
80
+ yield from (x[::-1] for x in itertools.product(*[list(range(v.min, v.max + 1)) for v in idxs[::-1]]))
81
+
82
+ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
83
+ idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
84
+ # TODO: bring back the valid removal logic (correct!)
85
+ if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid)
86
+ return (idx, idy), valid
46
87
 
47
88
  # expand a Node into List[Node] that enumerates the underlying Variables from min to max
48
89
  # expand increments earlier variables faster than later variables (as specified in the argument)
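
The docstring of the new `get_grouped_dims` above spells out three mapping rules: pad with 1s when there are fewer dims than size axes, collapse surplus dims onto the left-most axis that still fits, and split a single over-sized dim across adjacent axes when it factors. Below is a minimal standalone sketch of the first two rules only; it ignores `reverse_dims`, symbolic `sint` dims, the factor-splitting step, and the index reconstruction, so it is an illustration of the documented behaviour rather than tinygrad's implementation.

```python
def map_dims(dims: tuple, max_sizes: tuple) -> list:
    """Pad/collapse rules from the get_grouped_dims docstring (illustration only)."""
    # rule 1: fewer dims than size axes -> pad the sizes with 1s
    sizes = list(dims) + [1] * max(0, len(max_sizes) - len(dims))
    # rule 2: more dims than size axes -> collapse onto the left-most axis that still fits
    while len(sizes) > len(max_sizes):
        for i in range(len(max_sizes)):
            if sizes[i] * sizes[i + 1] <= max_sizes[i]:
                sizes[i:i + 2] = [sizes[i] * sizes[i + 1]]
                break
        else:
            raise AssertionError(f"cannot fit dims in size: {dims=} {max_sizes=}")
    assert all(s <= m for s, m in zip(sizes, max_sizes)), "dim too large for its size axis"
    return sizes

assert map_dims((5, 3), (256, 256, 256)) == [5, 3, 1]            # padded with a 1
assert map_dims((4, 5, 6, 7), (1024, 1024, 1024)) == [20, 6, 7]  # 4 and 5 collapsed onto x
```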
@@ -51,95 +92,90 @@ def expand_node(node:Node, idxs:Optional[Tuple[Union[Variable, NumNode], ...]]=N
51
92
  if idxs is None: idxs = (expand_idx(node),)
52
93
  return [node.substitute({k:v for k,v in zip(idxs, (NumNode(x) for x in rep)) if isinstance(k, Variable)}) for rep in iter_idxs(idxs)]
53
94
 
54
- class Linearizer(Kernel):
55
- def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32):
56
- render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
57
- return self.uop(UOps.ALU, dtype, (a, render_b), op)
58
-
59
- # NOTE: the consts have to be cached for deduping of downstream uops to work
60
- def const(self, b:Union[int,float], dtype=dtypes.int32, insert_before=None) -> UOp:
61
- return self.uop(UOps.CONST, dtype, tuple(), b, insert_before=insert_before)
95
+ def variable_to_uop(x, ctx=None) -> UOp:
96
+ if isinstance(x, int): return UOp.const(dtypes.int, x)
97
+ return x.render(render_ops, ctx)
62
98
 
63
- def cast(self, val: UOp, dtype) -> UOp: return self.uop(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val
99
+ render_ops: Dict[Type, Callable[..., UOp]] = {
100
+ NumNode: lambda self, ops, ctx: UOp.const(dtypes.int, self.b),
101
+ Variable: lambda self, ops, ctx: ctx[self.expr] if self.expr in ctx else UOp(UOps.DEFINE_VAR, dtypes.int, (), self),
102
+ MulNode: lambda self, ops, ctx: self.a.render(ops, ctx)*variable_to_uop(self.b, ctx),
103
+ DivNode: lambda self, ops, ctx: self.a.render(ops, ctx)//variable_to_uop(self.b, ctx),
104
+ ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx),
105
+ LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
106
+ SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+variable_to_uop(b, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
107
+ AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*variable_to_uop(b, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
64
108
 
109
+ class Linearizer(Kernel):
65
110
  def get_reduce_acc(self, reduceop:LazyOp):
66
- dtype = get_lazyop_info(reduceop).dtype
67
- if reduceop.op == ReduceOps.SUM: return 0.0 if dtypes.is_float(dtype) else 0
68
- elif reduceop.op == ReduceOps.MAX:
69
- if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
70
- return -math.inf if dtypes.is_float(dtype) else False
111
+ if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0
112
+ if reduceop.op is ReduceOps.MAX:
113
+ if dtypes.is_int(reduceop.dtype): return 0 if dtypes.is_unsigned(reduceop.dtype) else -2**(reduceop.dtype.itemsize*8-1)
114
+ return -math.inf if dtypes.is_float(reduceop.dtype) else False
71
115
 
72
116
  # NOTE: once images are loaded, we uop them as their base float
73
- def get_base_dtype(self, dt:DType): return dt.base if isinstance(dt, ImageDType) else dt
74
-
75
- render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
76
- MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
77
- DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV),
78
- ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
79
- LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.CMPLT, dtype=dtypes.bool),
80
- SumNode: lambda self,ops,ctx:
81
- functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)),
82
- AndNode: lambda self,ops,ctx:
83
- functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
84
-
85
- def global_load(self, i:int, idxs:Sequence[Node], acc=None, barrier:Optional[UOp]=None) -> List[UOp]:
117
+ def get_base_dtype(self, dt:DType) -> DType: return dt.base if isinstance(dt, ImageDType) else dt
118
+
119
+ def global_load(self, i:int, idxs:List[Node], acc:Optional[LazyOp]=None, barrier:Optional[UOp]=None, loop_ctx:Tuple[UOp, ...]=()) -> List[UOp]:
86
120
  buf = self.bufs[i]
87
- localtype = self.get_base_dtype(buf.dtype if acc is None else get_lazyop_info(self.reduceop).dtype)
88
- const = buf.val if isinstance(buf, ConstBuffer) else acc
121
+ localtype = self.get_base_dtype(buf.dtype if acc is None else acc.dtype)
122
+ const = buf.val if isinstance(buf, ConstBuffer) else None
89
123
 
90
- def rename_var(v: Union[Variable, NumNode], expr: str): return v if isinstance(v, NumNode) else Variable(expr, v.min, v.max)
91
- expand_vars = tuple([rename_var(expand_idx(idx), f"_uidx{j}") for j, idx in enumerate(idxs)])
92
- fake_idxs = [idx.substitute({eidx: ev}) if isinstance(eidx:=expand_idx(idx), Variable) else idx for idx, ev in zip(idxs, expand_vars)]
124
+ expand_vars = expand_idxs(idxs)
93
125
 
94
126
  dim, amt = None, 1
95
127
  # float 4 grouping
96
128
  if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expand_node(idxs[upcast_dim[0]])) in [4,2]:
97
129
  dim, amt = upcast_dim[0], len(float4_expand)
98
- g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs[:dim] + [float4_expand[0]] + fake_idxs[dim+1:])
130
+ g_idx, g_valid = self.sts[i].expr_idxs(idxs[:dim] + [float4_expand[0]] + idxs[dim+1:])
99
131
  # do not use float4 if idx is not aligned
100
132
  if g_idx != (g_idx//amt*amt): dim, amt = None, 1
101
133
  if dim is None:
102
- g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs)
134
+ g_idx, g_valid = self.sts[i].expr_idxs(idxs)
135
+ # todo: multioutput test with different output valids to add if acc is None: g_valid = NumNode(1)
103
136
 
104
137
  if amt > 1: localtype = localtype.vec(amt)
105
- e_idxs, e_valids = expand_node(g_idx, expand_vars), expand_node(g_valid, expand_vars)
138
+ e_idxs, e_valids = expand_node(g_idx, expand_vars), expand_node(g_valid, expand_vars) # pylint: disable=possibly-used-before-assignment
106
139
 
107
140
  ret = []
108
- invalid_value = 0 if dtypes.is_int(buf.dtype) else 0.0
141
+ invalid_value = 0
142
+ acc_count = 0
109
143
  for idx, valid, rep_idx in zip(e_idxs, e_valids, iter_idxs(expand_vars)):
110
144
  this_const, idx, valid = (invalid_value, NumNode(0), NumNode(1)) if valid.max == 0 else (const, idx, valid)
111
- key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
145
+ key = f"{'' if acc is None else self.reduceops.index(acc)}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
112
146
  if key not in self.load_cache:
113
147
  if acc is not None:
114
- self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, (), this_const, cachable=False)
148
+ self.load_cache[key] = UOp(UOps.DEFINE_ACC, localtype, (UOp.const(localtype.scalar(), self.get_reduce_acc(acc)), *loop_ctx), (i, acc_count))
149
+ acc_count += 1
115
150
  elif this_const is not None:
116
- self.load_cache[key] = self.const(this_const, localtype)
151
+ self.load_cache[key] = UOp.const(localtype, this_const)
117
152
  if valid.min == 0 and valid.max == 1:
118
- valid_rendered = valid.render(self.render_ops, self)
119
- self.load_cache[key] = self.uop(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)), TernaryOps.WHERE) # noqa: E501
153
+ valid_rendered = valid.render(render_ops, self.loop_uops)
154
+ self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_rendered, self.load_cache[key], UOp.const(localtype, invalid_value))
120
155
  elif isinstance(buf.dtype, ImageDType):
121
156
  buf_uop = self.buf_uops[i]
122
157
  assert buf_uop is not None, f"buffer {i} wasn't UOped"
123
158
  image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
124
- rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
125
- valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, buf.dtype.base.vec(4))) if valid.min == 0 else tuple()
126
- self.load_cache[key] = self.uop(UOps.LOAD, buf.dtype.base.vec(4), (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
159
+ rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
160
+ valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
161
+ self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4),
162
+ (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
127
163
  if localtype == localtype.scalar():
128
164
  idx_small = idx%4
129
- res = idx_small.render(self.render_ops, self)
130
- out = self.uop(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
165
+ res = idx_small.render(render_ops, self.loop_uops)
166
+ out = UOp(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
131
167
  for ix in range(idx_small.max, idx_small.min, -1):
132
- rvv = self.uop(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
133
- sel = self.uop(UOps.ALU, dtypes.bool, (res, self.const(ix)), BinaryOps.CMPLT)
134
- out = self.uop(UOps.ALU, localtype, (sel, rvv, out), TernaryOps.WHERE)
168
+ rvv = UOp(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
169
+ sel = UOp.alu(BinaryOps.CMPLT, res, UOp.const(dtypes.int, ix))
170
+ out = UOp.alu(TernaryOps.WHERE, sel, rvv, out)
135
171
  self.load_cache[key] = out
136
172
  else:
137
173
  buf_uop = self.buf_uops[i]
138
174
  assert buf_uop is not None, f"buffer {i} wasn't UOped"
139
- rendered_idx = idx.render(self.render_ops, self)
140
- valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, localtype)) if valid.min == 0 else tuple()
141
- self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
142
- ret.append(self.uop(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
175
+ rendered_idx = idx.render(render_ops, self.loop_uops)
176
+ valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple()
177
+ self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
178
+ ret.append(UOp(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
143
179
  return ret
144
180
 
145
181
  def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]:
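
The module-level `render_ops` table above replaces the old `Linearizer.render_ops`/`uop_alu_idx` pair: each symbolic `Node` subclass maps to a small lambda that lowers it by recursing through `Node.render` and combining the results with `UOp` operator overloads. The toy below shows the same dispatch-table pattern in isolation, lowering to strings instead of `UOp`s; the classes are stand-ins, not tinygrad's symbolic nodes.

```python
from dataclasses import dataclass

class Node:
    def render(self, ops): return ops[type(self)](self, ops)

@dataclass
class Num(Node):
    b: int

@dataclass
class Mul(Node):
    a: Node
    b: Node

@dataclass
class Sum(Node):
    nodes: tuple

# one lambda per node class, mirroring how render_ops maps NumNode/MulNode/SumNode/... to UOp builders
lower = {
    Num: lambda s, ops: str(s.b),
    Mul: lambda s, ops: f"({s.a.render(ops)}*{s.b.render(ops)})",
    Sum: lambda s, ops: "(" + "+".join(n.render(ops) for n in s.nodes) + ")",
}

assert Sum((Mul(Num(2), Num(3)), Num(1))).render(lower) == "((2*3)+1)"
```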
@@ -147,12 +183,12 @@ class Linearizer(Kernel):
147
183
  buf_uop = self.buf_uops[i]
148
184
  assert buf_uop is not None, f"buffer {i} wasn't UOped"
149
185
 
150
- expanded_nodes = [expand_node(idx) for idx in idxs]
151
- _idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
186
+ expand_vars = expand_idxs(idxs)
187
+ _idxs = zip(*[expand_node(idx, expand_vars) for idx in idxs]) if idxs else [tuple()] # transpose
152
188
  store_offset = dict(zip(_idxs, store))
153
189
 
154
190
  # float4 grouping
155
- if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expanded_nodes[upcast_dim[0]]) in [2,4]:
191
+ if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expand_node(idxs[upcast_dim[0]])) in [2,4]:
156
192
  grouped_store_offset = defaultdict(list)
157
193
  for k in store_offset:
158
194
  _idx = k[:upcast_dim[0]] + (float4_expand[0],) + k[upcast_dim[0]+1:]
@@ -162,61 +198,221 @@ class Linearizer(Kernel):
162
198
  amt = len(grouped)
163
199
  idx, valid = self.sts[i].expr_idxs(k)
164
200
  assert idx == ((idx//amt)*amt), "float4 stores are always aligned"
165
- store_offset_new[k] = self.uop(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
201
+ store_offset_new[k] = UOp(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
166
202
  store_offset = store_offset_new
167
203
 
168
204
  stores = []
169
- for idx, var in store_offset.items():
170
- idx, valid = self.sts[i].expr_idxs(idx)
205
+ for _idx, var in store_offset.items():
206
+ idx, valid = self.sts[i].expr_idxs(_idx)
171
207
  if isinstance(buf.dtype, ImageDType):
172
- idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
173
- rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in idx))
208
+ image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
209
+ rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), \
210
+ tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
174
211
  else:
175
- rendered_idx = idx.render(self.render_ops, self)
176
- if valid.min == 1: stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var)))
177
- else: stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
212
+ rendered_idx = idx.render(render_ops, self.loop_uops)
213
+ # TODO: let UPat check this once it's fast
214
+ if valid.min == 1: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var)))
215
+ else: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(render_ops, self.loop_uops))))
178
216
  return stores
179
217
 
218
+ # render loop
219
+ def render_loop(self, xx:List[Variable], depth:int, reduce:bool) -> Tuple[UOp, ...]:
220
+ new_loops = {x.expr:UOp(UOps.RANGE, dtypes.int32, (
221
+ UOp.const(dtypes.int, x.min) if isinstance(x.min, int) else cast(Node, x.min).render(render_ops, self.loop_uops),
222
+ UOp.const(dtypes.int, x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(render_ops, self.loop_uops)), arg=(depth,i,reduce)) for i,x in enumerate(xx) if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
223
+ self.loop_uops.update(new_loops)
224
+ return tuple(new_loops.values())
225
+
226
+ def index_local_aliases(self, global_idxs, local_idxs, reduce_idxs, upcast_idxs, full_upcast_idxs):
227
+ def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
228
+ replace_idxs, thread_idxs, thread_idx = [], [], Variable("_uidx_tc", 0, prod(local_sizes)-1)
229
+ for s in local_sizes:
230
+ thread_idxs.append(thread_idx % s)
231
+ thread_idx //= s
232
+ for alias in aliases:
233
+ full_var, full_var_sz = NumNode(0), 1
234
+ if alias[0] != 0:
235
+ for i in alias:
236
+ next_var = local_idxs[i-1] if i > 0 else thread_idxs[-i-1]
237
+ full_var += next_var * full_var_sz
238
+ full_var_sz *= next_var.max+1
239
+ replace_idxs.append(full_var)
240
+ return replace_idxs
241
+
242
+ # compute local aliases
243
+ alias_buf_idxs: DefaultDict[LazyOp, List[Tuple[int, int, List]]] = defaultdict(list)
244
+ for op, local_alias in self.local_alias.items():
245
+ for i in local_alias:
246
+ localbuf_idx = self.bufs.index(local_alias[i])
247
+ buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
248
+ if (tc:=self.tensor_core):
249
+ min_alias_idx = min(local_alias.keys())
250
+ replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
251
+ for n in range(len(tc.threads)):
252
+ buf_idxs[self.global_dims+n] = replace_input_idxs[n] # replace locals
253
+ for n in range(tc.num_upcasts()):
254
+ buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
255
+ if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
256
+ alias_buf_idxs[op].append((i, localbuf_idx, buf_idxs))
257
+ # modify idxs if necessary for TC
258
+ if (tc:=self.tensor_core):
259
+ replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
260
+ for n in range(len(tc.threads)):
261
+ local_idxs[n] = replace_acc_idxs[n] # replace locals
262
+ for n in range(len(replace_acc_idxs)-len(tc.threads)):
263
+ upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
264
+ if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+upcast_idxs}")
265
+ return alias_buf_idxs
266
+
267
+ def render_reduceop(self, reduceop:LazyOp, accs:Dict[LazyOp, List[UOp]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]],
268
+ global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, reduce_idxs, fake_reduce_idxs,
269
+ alias_buf_idxs:List[Tuple[int, int, List]]) -> Tuple[List[NumNode|Variable], List[NumNode|Variable]]:
270
+ # reduce loop
271
+ loop_ctx = self.render_loop(reduce_idxs, (i:=self.reduceops.index(reduceop))*2+2, True)
272
+
273
+ # define accumulator - modify idxs if necessary for TC
274
+ out_buf = -len(self.reduceops)+i if self.group_for_reduces else 0
275
+ accs[reduceop] = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
276
+
277
+ # store local aliases
278
+ locals_to_store = [(localbuf_idx, buf_idxs, self.global_load(i, buf_idxs)) for i, localbuf_idx, buf_idxs in alias_buf_idxs]
279
+
280
+ if (tc:=self.tensor_core):
281
+ # run tensor cores AST
282
+ wmma_sz = [prod(l) for l in tc.thread_local_sizes]
283
+ def upcast_strides(buf:int):
284
+ strides, next_ = [], 1
285
+ for (sz, stride, _) in self.upcasted_axis(buf)[tc.num_upcasts():]:
286
+ strides.append((0 if stride == 0 else next_, sz))
287
+ next_ *= 1 if stride == 0 else sz
288
+ return strides
289
+ upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
290
+ # cast initial accs
291
+ wmmas = [UOp(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]]))
292
+ for x in range(0, len(accs[reduceop]), wmma_sz[2])]
293
+ for it in [x[::-1] for x in itertools.product(*list([range(sz) for _,sz in upcasts[0]][::-1]))]:
294
+ offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(it, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
295
+ ops = (UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
296
+ UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
297
+ wmmas[(wmma_idx:=offs[2]//wmma_sz[2])])
298
+ # TODO: don't need to DEFINE_ACC, pass to WMMA in op3, or PHI accs that are not valid
299
+ wmmas[wmma_idx] = UOp(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), dev))
300
+ # phi the last wmmas back to accs
301
+ accs[reduceop] = [UOp(UOps.PHI, tc.dtype_out, (acc, UOp(UOps.GEP, tc.dtype_out, (wmmas[z//wmma_sz[2]],), z%wmma_sz[2])))
302
+ for z, acc in enumerate(accs[reduceop])]
303
+ else:
304
+ assert not locals_to_store, "storing locals isn't supported here"
305
+
306
+ # load earlybufs
307
+ loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[reduceop][i]) if i in self.local_alias else i,
308
+ global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs})
309
+
310
+ def gate_acc(r, idxs): return [
311
+ UOp.alu(TernaryOps.WHERE, valid.render(render_ops, self.loop_uops), acc, UOp.const(r.dtype, 0)) if valid.min == 0 and valid.max == 1 else acc
312
+ for valid, acc in zip(expand_node(self.sts[self.full_buf_index].expr_idxs(idxs)[1], expand_idxs(idxs)), accs[r])]
313
+ local_accs = {r: gate_acc(r,global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for r in accs}
314
+
315
+ # run early AST (with reduce)
316
+ self.ast_parse(reduceop, local_accs, self.acc_offsets(self.full_buf_index), loaded_buffers, reduce_acc=accs[reduceop])
317
+
318
+ # end the reduce loop
319
+ self.load_cache.clear()
320
+
321
+ # end the local loop, do the local reduce
322
+ if self.group_for_reduces:
323
+ fake_global_idxs = [x*0 for x in global_idxs]
324
+ stores = self.global_store(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop]) # store accumulators
325
+ barrier = UOp(UOps.BARRIER, None, tuple(stores))
326
+ if self.opts.has_local:
327
+ fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
328
+ fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
329
+ if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(render_ops, self.loop_uops)
330
+ barrier = UOp(UOps.IF, None, (if_cond, barrier))
331
+
332
+ # create new late reduce local loops and replace local_idxs that have been used
333
+ end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)] # noqa: E501
334
+ local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]
335
+
336
+ # if any group_for_reduce items aren't reduces, upcast them here
337
+ for j in self.upcast_in_mid_reduce_axes:
338
+ self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
339
+ self.upcast()
340
+ self.group_for_reduces -= 1
341
+ local_idxs = local_idxs[:-1]
342
+ end_local_idxs = end_local_idxs[:-1]
343
+ # regenerate upcast_idxs
344
+ upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
345
+
346
+ # NOTE: this structure is the same as the reduce op above
347
+
348
+ # late reduce loop
349
+ loop_ctx = self.render_loop(end_local_idxs, i*2+3, True)
350
+
351
+ # define late accumulator
352
+ accs[reduceop] = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
353
+
354
+ # load localbufs
355
+ loaded_buffers[self.bufs[out_buf]] = self.global_load(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
356
+
357
+ # there's no AST here (and there's no shape for the reduce LazyOp)
358
+ self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[out_buf]),)),\
359
+ accs, self.acc_offsets(-1), loaded_buffers, reduce_acc=accs[reduceop])
360
+
361
+ # end the late reduce loop
362
+ self.load_cache.clear()
363
+
364
+ if reduceop is not self.reduceops[-1]:
365
+ for j in self.upcast_in_mid_reduce_axes:
366
+ self.upcasted -= 1
367
+ self.group_for_reduces += 1
368
+ assert self.buf_uops[out_buf] is not None, "Local reduce buf must have been uoped at this point"
369
+ fake_local_idxs = local_idxs[:self.local_dims] + [x*0 for x in local_idxs[self.local_dims:]]
370
+ stores = self.global_store(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop])
371
+ barrier = UOp(UOps.BARRIER, None, tuple(stores))
372
+ accs[reduceop] = self.global_load(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
373
+ return local_idxs[:self.local_dims] + [NumNode(0) for _ in range(self.group_for_reduces)], upcast_idxs
374
+
180
375
  kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
181
- def linearize(self):
376
+ def linearize(self) -> Linearizer:
182
377
  # no new opts and we already ran? skip relinearizing
183
378
  if self.applied_opts == self.applied_opts_cache: return self
184
379
 
185
- # save backups
186
- sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduce[:], self.upcasted
187
-
188
- # global uop cache
189
- self.saved_exprs: Dict[Tuple, UOp] = dict()
380
+ # late alias the tensor core buffers
381
+ if (tc:=self.tensor_core) and self.tensor_core_opts is not None:
382
+ alias_pattern = [0]*(self.global_dims) + [2]*(len(tc.threads)) + [0]*(self.local_dims-len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2) # noqa: E501
383
+ for op, tc_bufs in self.bufs_for_tensor_core.items():
384
+ for tc_buf in tc_bufs: self.alias_buffer(op, tc_buf, alias_pattern)
190
385
 
191
- # limit dims if we need to
192
- if self.opts.global_max and self.opts.local_max: self.limit_dims_to_max(self.opts.global_max, self.opts.local_max)
386
+ # save backups
387
+ sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduces, self.upcasted
193
388
 
194
389
  # uops
195
- self.uops: List[UOp] = []
196
390
  self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs)
197
391
  self.loop_uops: Dict[str, UOp] = {}
198
392
 
199
393
  # add global buffers
200
394
  for i,buf in enumerate(self.bufs):
201
395
  if isinstance(buf, MemBuffer):
202
- self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), f"data{buf.idx}")
203
- # add var vals
204
- for var in self.ast.vars():
205
- assert var.expr is not None
206
- self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), var.expr)
396
+ self.buf_uops[i] = UOp(UOps.DEFINE_GLOBAL,
397
+ buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
398
+ (buf.idx, any(buf.idx == x.idx for x in self.outbufs)))
207
399
  # define local buffers
208
- for lb in self.local_alias.values():
209
- self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size))
400
+ for aliases in self.local_alias.values():
401
+ for lb in aliases.values(): self.buf_uops[self.bufs.index(lb)] = UOp(UOps.DEFINE_LOCAL, PtrDType(lb.dtype),
402
+ (), (lb.name, self.sts[self.bufs.index(lb)].size))
210
403
  # add a local buffer for multistage reduce. # TODO: use local alias
211
- if self.group_for_reduce:
212
- # TODO: the strides of this can be controlled
213
- self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501
214
- temp_dtype = self.get_base_dtype(get_lazyop_info(self.reduceop).dtype)
215
- self.bufs.append(LocalBuffer("temp", self.sts[-1].size, temp_dtype))
216
- self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size)))
404
+ if self.group_for_reduces:
405
+ for i in range(len(self.reduceops)):
406
+ # TODO: the strides of this can be controlled
407
+ self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501
408
+ temp_dtype = self.get_base_dtype(cast(LazyOp, self.reduceop).dtype)
409
+ self.bufs.append(LocalBuffer(name:=f"temp{i if len(self.reduceops) > 1 else ''}", buf_size:=self.sts[-1].size, temp_dtype))
410
+ self.buf_uops.append(UOp(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), (name, buf_size)))
217
411
 
218
412
  # kernel name (before late upcast)
219
- self.name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
413
+ self.name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \
414
+ (f"{len(self.outbufs)}_" if len(self.outbufs) > 1 else "_") + \
415
+ colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
220
416
 
221
417
  # name the function something unique
222
418
  Linearizer.kernel_cnt[(function_name := to_function_name(self.name))] += 1
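
In `render_reduceop` above, each accumulator comes from `global_load(..., acc=reduceop, loop_ctx=loop_ctx)`, which emits a `DEFINE_ACC` seeded with `self.get_reduce_acc(acc)`, i.e. the identity element of the reduction. The snippet below just restates those identity values in standalone form; the bool case and the dtype plumbing of `get_reduce_acc` are omitted, so treat it as an illustration.

```python
import math

def reduce_identity(op: str, is_float: bool, is_unsigned: bool = False, itemsize: int = 4):
    """Identity elements matching get_reduce_acc (illustration only)."""
    if op == "SUM": return 0.0 if is_float else 0
    if op == "MAX":
        if is_float: return -math.inf
        return 0 if is_unsigned else -2**(itemsize * 8 - 1)
    raise ValueError(f"unsupported reduce op: {op}")

assert reduce_identity("SUM", is_float=True) == 0.0
assert reduce_identity("MAX", is_float=True) == -math.inf
assert reduce_identity("MAX", is_float=False, itemsize=1) == -128  # int8 minimum
```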
@@ -224,343 +420,109 @@ class Linearizer(Kernel):
224
420
  self.name = self.name+colored(suffix, 'BLACK')
225
421
 
226
422
  # define indexes
227
- global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
228
- local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+len(self.group_for_reduce)], 3 if self.opts.has_local else 0) # noqa: E501
229
- full_upcast_idxs = [Variable(None, 0, s-1) for s in self.full_shape[self.shape_len-self.upcasted:]]
230
- upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
231
-
232
- # global and local loops
233
- def render_loop(xx:List[Variable]) -> Tuple[UOp, ...]:
234
- new_loops = {x.expr:self.uop(UOps.LOOP, dtypes.int32, (
235
- self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
236
- self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
237
- self.loop_uops.update(new_loops)
238
- return tuple(new_loops.values())
239
-
240
- # set global/local size
241
- self.global_size: Optional[List[int]] = None
242
- self.local_size: Optional[List[int]] = None
243
- if self.dont_use_locals:
244
- self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
245
- self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
246
- elif self.opts.has_local:
247
- self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs][::-1]
248
- self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
249
- self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)}) # noqa: E501
423
+ gl_dims = self.full_shape[:self.first_reduce+self.group_for_reduces]
424
+ global_idxs, loop_global_idxs, self.global_size = get_grouped_dims("idx" if self.dont_use_locals else "gidx", 0, gl_dims[:self.global_dims],
425
+ self.opts.global_max, self.opts.has_local)
426
+ local_idxs, loop_local_idxs, self.local_size = get_grouped_dims("lidx", self.global_dims, gl_dims[self.global_dims:],
427
+ self.opts.local_max if self.opts.has_local else (), False)
428
+ upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
429
+ full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
430
+
431
+ # render global and local as specials or a loop
432
+ if self.opts.has_local:
433
+ self.loop_uops.update({x.expr:UOp(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
434
+ if not self.dont_use_locals:
435
+ self.loop_uops.update({x.expr:UOp(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
250
436
  else:
251
- render_loop(loop_global_idxs+loop_local_idxs)
437
+ self.global_size, self.local_size = None, None
438
+ self.render_loop(loop_global_idxs+loop_local_idxs, 1, False)
439
+
440
+ # define idxs for aliased buffers TODO: this doesn't belong in Kernel, but it can't exist in Block either (because of multireduce tensor cores)
441
+ reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
442
+ alias_buf_idxs = self.index_local_aliases(global_idxs,local_idxs,reduce_idxs,upcast_idxs,full_upcast_idxs)
252
443
 
253
444
  # parse AST
254
- loaded_buffers = {}
255
- acc: List[UOp] = []
256
445
  self.load_cache: Dict[str, UOp] = {}
446
+ loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]] = {}
447
+ accs: Dict[LazyOp, List[UOp]] = {}
257
448
 
258
- # reduce op
259
- fake_reduce_idxs: List[Variable] = []
260
- if self.reduceop is not None:
261
- # define indexes
262
- reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)] # noqa: E501
263
- fake_reduce_idxs = [x*0 for x in reduce_idxs]
264
-
265
- # define accumulator
266
- acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop))
267
-
268
- if self.tensor_core:
269
- def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
270
- replace_idxs = []
271
- for alias in aliases:
272
- full_var, full_var_sz = NumNode(0), 1
273
- if alias[0] != 0:
274
- for i in alias:
275
- next_var = local_idxs[-i] if i > 0 else Variable(None, 0, local_size-1)
276
- full_var += next_var * full_var_sz
277
- full_var_sz *= next_var.max+1
278
- replace_idxs.append(full_var)
279
- return replace_idxs
280
- replace_acc_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[2], self.tensor_core.thread_local_aliases[2])
281
- for n in range(len(self.tensor_core.threads)):
282
- local_idxs[self.local_dims-len(self.tensor_core.threads)+n] = replace_acc_idxs[n] # replace locals
283
- for n in range(len(replace_acc_idxs)-len(self.tensor_core.threads)):
284
- upcast_idxs[n] = replace_acc_idxs[len(self.tensor_core.threads)+n] # replace upcasts
285
-
286
- # reduce loop
287
- loop_ctx = render_loop(reduce_idxs)
288
-
289
- # barrier for fast GEMM
290
- if self.tensor_core: self.uop(UOps.BARRIER, None, (), cachable=False)
291
-
292
- # compute local aliases
293
- locals_to_store = []
294
- for i in self.local_alias:
295
- localbuf_idx = self.bufs.index(self.local_alias[i])
296
- buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
297
- if self.tensor_core:
298
- min_alias_idx = min(self.local_alias.keys())
299
- replace_input_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[i-min_alias_idx], self.tensor_core.thread_local_aliases[i-min_alias_idx]) # noqa: E501
300
- for n in range(len(self.tensor_core.threads)):
301
- buf_idxs[self.first_reduce-len(self.tensor_core.threads)+n] = replace_input_idxs[n] # replace locals
302
- for n in range(len(replace_input_idxs)-len(self.tensor_core.threads)):
303
- buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(self.tensor_core.threads)+n] # replace upcasts
304
- if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: idxs=", buf_idxs)
305
- ll = self.global_load(i, buf_idxs)
306
- locals_to_store.append((localbuf_idx, buf_idxs, ll))
307
-
308
- # copy in any global buffers
309
- if self.tensor_core:
310
- wmma_sz = self.tensor_core.thread_local_sizes
311
- # calculate the number of local accumulator reduces and render WMMAs: this is bad... this needs to come from someplace else
312
- nx, ny, nacc = (len(locals_to_store[0][2])//wmma_sz[0]), (len(locals_to_store[1][2])//wmma_sz[1]), (len(acc)//wmma_sz[2])
313
- acc_reds = math.isqrt((nx*ny)//nacc)
314
- i, bx, by = 0, nx//acc_reds, ny//acc_reds
315
- for y in range(by):
316
- for x in range(bx):
317
- for j in range(acc_reds):
318
- op1, op2, op3 = locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]], locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]], acc[i:i+wmma_sz[2]] # noqa: E501
319
- if self.opts.device != "HIP":
320
- ops = tuple(op1+op2+op3)
321
- else:
322
- ops = (self.uop(UOps.CAST, dtypes.half.vec(16), tuple(op1)),
323
- self.uop(UOps.CAST, dtypes.half.vec(16), tuple(op2)),
324
- self.uop(UOps.CAST, dtypes.float.vec(8), tuple(op3)))
325
- ret = self.uop(UOps.WMMA, dtypes.float.vec(2) if wmma_sz[2] == 2 else dtypes.float.vec(8), ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,)) # noqa: E501
326
- for z in range(cast(DType, ret.dtype).sz):
327
- acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + loop_ctx)
328
- i += wmma_sz[2]
329
- else:
330
- if locals_to_store:
331
- self.uop(UOps.BARRIER, None, (), cachable=False)
332
- for i, idxs, ll in locals_to_store: self.global_store(i, idxs, ll)
333
- self.uop(UOps.BARRIER, None, (), cachable=False)
334
-
335
- # load earlybufs
336
- loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs}) # noqa: E501
337
-
338
- # run early AST (with reduce)
339
- self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)
340
-
341
- # end the reduce loop
342
- self.load_cache.clear()
343
-
344
- # end the local loop, do the local reduce
345
- if self.group_for_reduce:
346
- fake_global_idxs = [x*0 for x in global_idxs]
347
- stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
348
- barrier = self.uop(UOps.BARRIER, None, tuple(stores), cachable=False)
349
- if self.opts.has_local:
350
- fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
351
- fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
352
- if_cond: UOp = (self.sts[-1].expr_idxs(fake_idxs)[0]<1).render(self.render_ops, self)
353
- barrier = self.uop(UOps.IF, None, (if_cond, barrier), cachable=False)
354
-
355
- # create new late reduce local loops and replace local_idxs that have been used
356
- end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))] # noqa: E501
357
- local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]
358
-
359
- # if any group_for_reduce items aren't reduces, upcast them here
360
- for j in self.upcast_in_mid_reduce_axes:
361
- self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
362
- self.upcast()
363
- self.group_for_reduce.pop()
364
- local_idxs = local_idxs[:-1]
365
- end_local_idxs = end_local_idxs[:-1]
366
- # regenerate upcast_idxs
367
- upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
368
-
369
- # NOTE: this structure is the same as the reduce op above
370
-
371
- # define late accumulator
372
- acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop))
373
-
374
- # late reduce loop
375
- loop_ctx = render_loop(end_local_idxs)
376
-
377
- # load localbufs
378
- loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
449
+ # render reduceops by depth
450
+ for reduceop in self.reduceops:
451
+ self.render_block((reduceop, ), global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)
452
+ stores = self.render_block(self.ast, global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)
379
453
 
380
- # there's no AST here (and there's no shape for the reduce LazyOp)
381
- self.ast_parse(LazyOp(self.reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # noqa: E501
382
-
383
- # end the late reduce loop
384
- self.load_cache.clear()
385
-
386
- # load latebufs
387
- loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer}) # noqa: E501
388
-
389
- # run late AST (without the store)
390
- val = self.ast_parse(self.ast.src[0], acc, None, loaded_buffers)
391
-
392
- # store
393
- self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
394
-
395
- # get PHI node loop scope, link anything using a DEFINE_ACC to the loop as a "parent"
396
- acc_scope: DefaultDict[UOp, List[UOp]] = defaultdict(list)
397
- for u in self.uops:
398
- if u.uop == UOps.PHI:
399
- acc_scope[u.vin[0]] += u.vin[2:]
400
-
401
- # graph helper functions
402
- @functools.lru_cache(None)
403
- def get_recursive_parents(x:UOp, with_phi=False) -> Set[UOp]:
404
- return set.union(set(x.vin), *[get_recursive_parents(p, with_phi) for p in x.vin], set(acc_scope[x]) if with_phi else set())
405
-
406
- def get_recursive_children(x:UOp) -> Set[UOp]:
407
- deps = set([x])
408
- ssize = 0
409
- while ssize != len(deps):
410
- ssize = len(deps)
411
- for u in self.uops:
412
- if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
413
- deps.add(u)
414
- return deps
415
-
416
- def replace_op(old:UOp, new:UOp):
417
- for u in self.uops:
418
- u.vin = tuple(new if x is old else x for x in u.vin)
419
- self.uops.remove(old)
420
-
421
- # fix loop scope, push uops upward out of loop if it does not depend on the loop
422
- loop_stack: List[List[UOp]] = [[]]
423
- for u in self.uops:
424
- if not loop_stack[-1]: loop_stack[-1].append(u)
425
- elif u.uop == UOps.LOOP: loop_stack.append([u])
426
- elif u.uop not in [UOps.CONST, UOps.ALU, UOps.CAST, UOps.LOAD]: loop_stack[-1].append(u)
427
- else:
428
- parents = get_recursive_parents(u, with_phi=True)
429
- # don't push any local buffer because there might have STORE and BARRIER (not considered as parent) between DEFINE_LOCAL and here
430
- if any(u.uop == UOps.DEFINE_LOCAL for u in parents): loop_stack[-1].append(u)
431
- else:
432
- for i in reversed(range(len(loop_stack))):
433
- # check backwards and put the uop in the first encounter with some dependency
434
- if any(x in parents for x in loop_stack[i]) or i == 0:
435
- loop_stack[i].append(u)
436
- break
437
- self.uops = flatten(loop_stack)
438
-
439
- # uops optimization
440
- changed_something = True
441
- while changed_something:
442
- changed_something = False
443
- for u in self.uops:
444
- if u.uop == UOps.PHI and len(u.vin) == 3:
445
- # if the parents of the PHI node don't have the LOOP in their parents, it can be folded
446
- # TODO: ADD becomes a MUL, MAX can just become nothing
447
- if all(x.uop != UOps.LOOP for x in get_recursive_parents(UOp(u.uop, u.dtype, u.vin[0:2], u.arg))) and u.vin[1].arg == BinaryOps.ADD:
448
- if DEBUG >= 4: print(f"removing PHI node {u}")
449
- del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)]
450
- # NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype
451
- loop_len = self.uop(UOps.ALU, u.vin[2].vin[1].dtype, (u.vin[2].vin[1], u.vin[2].vin[0]), BinaryOps.SUB, insert_before=self.uops.index(u))
452
- if loop_len.dtype != u.dtype: loop_len = self.uop(UOps.CAST, u.dtype, (loop_len,), insert_before=self.uops.index(u))
453
- replace_op(u, self.uop(UOps.ALU, u.dtype, (u.vin[1], loop_len,), BinaryOps.MUL, insert_before=self.uops.index(u)))
454
- changed_something = True
455
-
456
- # (recursively) remove childless uops
457
- # NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
458
- UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_GLOBAL}
459
- while 1:
460
- has_child: Set[UOp] = set()
461
- for ru in self.uops:
462
- for vu in ru.vin:
463
- has_child.add(vu)
464
- nu: List[UOp] = [x for x in self.uops if x in has_child or x.uop in UOPS_W_SIDE_EFFECTS]
465
- if len(nu) == len(self.uops): break
466
- if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
467
- self.uops = nu
468
- del nu
469
-
470
- # add UOps.END
471
- for u in self.uops:
472
- if u.uop == UOps.LOOP:
473
- # add END of loops after the last thing that (recursively) depends on them
474
- self.uop(UOps.END, None, (u,), cachable=False, insert_before=self.uops.index(sorted(list(get_recursive_children(u)), key=self.uops.index)[-1])+1) # noqa: E501
475
- elif u.uop == UOps.IF:
476
- # END any if statements at the end of the uops
477
- self.uop(UOps.END, None, (u,), cachable=False)
454
+ # only the final stores are needed to define the full UOps graph
455
+ self.uops:UOpGraph = UOpGraph(flatten(stores))
478
456
 
479
457
  # maybe graph the uops
480
- if DEBUG >= 5:
481
- for u in self.uops:
482
- print(f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}") # noqa: E501
483
- if getenv("GRAPHUOPS"):
484
- from tinygrad.graph import graph_uops
485
- graph_uops(self.uops)
458
+ if DEBUG >= 5: self.uops.print()
459
+ if getenv("GRAPHUOPS"): self.uops.graph()
486
460
 
487
461
  # restore backups
488
- self.sts, self.group_for_reduce, self.upcasted = sts_backup, gfr_backup, upc_backup
462
+ self.sts, self.group_for_reduces, self.upcasted = sts_backup, gfr_backup, upc_backup
489
463
 
490
464
  # set cache and return
491
465
  self.applied_opts_cache = self.applied_opts[:]
492
466
  return self
493
467
 
494
-  def uop(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp: # noqa: E501
-    if uop == UOps.ALU:
-      if arg in UnaryOps:
-        assert dtype == vin[0].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=}"
-      elif arg in (BinaryOps.CMPLT, BinaryOps.CMPEQ):
-        assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
-        assert vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
-      elif arg in BinaryOps:
-        assert dtype == vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
-      elif arg == TernaryOps.WHERE:
-        assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}"
-        assert dtype == vin[1].dtype == vin[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {vin[1].dtype=} != {vin[2].dtype=}"
-
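The removed assertions encode small typing rules: comparisons return bool, WHERE takes a bool selector and two matching choices, and ordinary ALU ops share one dtype across inputs and output. A compact sketch of the same checks over plain strings (hypothetical helper, not the tinygrad API):

from typing import List

def check_alu_dtypes(op: str, out_dtype: str, in_dtypes: List[str]) -> None:
  if op in ("CMPLT", "CMPEQ"):
    assert out_dtype == "bool", f"{op} must return bool, got {out_dtype}"
    assert in_dtypes[0] == in_dtypes[1], f"{op} operand dtypes must match: {in_dtypes}"
  elif op == "WHERE":
    assert in_dtypes[0] == "bool", f"{op} selector must be bool, got {in_dtypes[0]}"
    assert out_dtype == in_dtypes[1] == in_dtypes[2], f"{op} choice dtypes must match the output: {in_dtypes}"
  else:  # plain unary/binary ALU ops: everything shares one dtype
    assert all(d == out_dtype for d in in_dtypes), f"{op} dtype mismatch: {out_dtype} vs {in_dtypes}"

check_alu_dtypes("CMPLT", "bool", ["float32", "float32"])             # ok
check_alu_dtypes("WHERE", "float32", ["bool", "float32", "float32"])  # ok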
-    if simplify:
-      if uop == UOps.PHI and len(vin) == 2: return vin[1] # a phi without loops is a noop
-      if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype, insert_before)
-      if uop == UOps.CAST and all(x.uop == UOps.CONST for x in vin) and all_same([x.arg for x in vin]):
-        return self.const(vin[0].arg, dtype, insert_before)
-      if uop == UOps.ALU:
-        # rewrites. NOTE: the rewritten NEG op is still around...
-        if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG:
-          return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable, insert_before)
-        # constant folding
-        if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.const(-vin[0].arg, dtype, insert_before)
-        if arg == TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop
-        # zero folding
-        for x in [0,1]:
-          if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
-          if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
-          if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[x]
-        if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0]
-        if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
-
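The simplify branch above is a set of algebraic peepholes: x+0 -> x, x*1 -> x, x*0 -> 0, x-0 -> x, x/1 -> x, negate a constant at build time, and collapse a WHERE whose branches agree. A standalone sketch of those folds on a tiny tuple-based expression (illustrative, not the UOp builder):

from typing import Tuple, Union

Expr = Union[float, Tuple]  # a constant, or (op, *operands)

def fold(op: str, *src: Expr) -> Expr:
  if op == "NEG" and isinstance(src[0], float): return -src[0]
  if op == "WHERE" and src[1] == src[2]: return src[1]   # both branches equal: selector is irrelevant
  if op in ("ADD", "MUL"):
    for x in (0, 1):                                     # commutative: check both operand slots
      if op == "ADD" and src[x] == 0.0: return src[1 - x]
      if op == "MUL" and src[x] == 1.0: return src[1 - x]
      if op == "MUL" and src[x] == 0.0: return 0.0
  if op == "SUB" and src[1] == 0.0: return src[0]
  if op == "DIV" and src[1] == 1.0: return src[0]
  return (op, *src)                                      # nothing folded, build the node

print(fold("ADD", ("LOAD",), 0.0))   # ('LOAD',)  -- x + 0 folds to x
print(fold("MUL", 0.0, ("LOAD",)))   # 0.0        -- x * 0 folds to the constant 0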
-    key = (uop, dtype, vin, arg)
-    if insert_before is None: insert_before = len(self.uops)
-    # check if the cached expr is valid with the given insert place.
-    if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr
-    ret = UOp(uop, dtype, vin, arg)
-    self.uops.insert(insert_before, ret)
-    if cachable: self.saved_exprs[key] = ret
-    return ret
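The saved_exprs dictionary above is common-subexpression elimination by hash-consing: an expression keyed by (op, dtype, inputs, arg) is built once and every later request for the same key returns the existing node. A minimal sketch of that keying with a hypothetical builder class:

from typing import Any, Dict, List, Tuple

class Builder:
  def __init__(self) -> None:
    self.nodes: List[Tuple] = []
    self.saved: Dict[Tuple, Any] = {}

  def emit(self, op: str, dtype: str, srcs: Tuple = (), arg: Any = None, cachable: bool = True):
    key = (op, dtype, srcs, arg)
    if cachable and key in self.saved: return self.saved[key]   # reuse the identical expression
    node = key                                                  # stand-in for a real node object
    self.nodes.append(node)
    if cachable: self.saved[key] = node
    return node

b = Builder()
a1 = b.emit("CONST", "int", arg=1)
a2 = b.emit("CONST", "int", arg=1)
assert a1 is a2 and len(b.nodes) == 1   # the duplicate constant was deduplicated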
+  def render_block(self, outputs:Tuple[LazyOp, ...], global_idxs, local_idxs, upcast_idxs, full_upcast_idxs,
+                   alias_buf_idxs:DefaultDict[LazyOp,List[Tuple[int,int,List[NumNode|Variable]]]],
+                   loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], accs:Dict[LazyOp,List[UOp]]) -> List[List[UOp]]:
+    reduceops = dedup(x for x in outputs if x.op in ReduceOps)
+    assert len(reduceops) <= 1, "max one reduceop per block"
+    reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
+    fake_reduce_idxs = [x*0 for x in reduce_idxs]
+
+    if len(reduceops) != 0:
+      # TODO: delete render_reduceop and move the logic for group_for_reduces to Block
+      nlidx, nuidx = self.render_reduceop((r:=reduceops[0]),accs,loaded_buffers,\
+        global_idxs,local_idxs,upcast_idxs,full_upcast_idxs,reduce_idxs,fake_reduce_idxs,alias_buf_idxs[r])
+
+      # all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have
+      # been rewritten with fake end_local_idxs.
+      if r is self.reduceops[-1]: local_idxs[:], upcast_idxs[:] = nlidx, nuidx
+      return [accs[r]]

-  def ast_parse(self, x:LazyOp, acc: List[UOp], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], do_reduce=False, loop_ctx=tuple(), cache=None) -> List[UOp]: # noqa: E501
+    # load latebufs
+    loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
+      for i,b in enumerate(self.bufs) if b not in self.earlybufs and b.__class__ is not LocalBuffer})
+    # run late AST (without the store)
+    store_vals = {op.arg.idx:self.ast_parse(op.src[0], accs, None, loaded_buffers) for op in self.ast}
+    return [self.global_store(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val) for i, val in store_vals.items()]
+
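In render_block above, the late AST is evaluated once per output against the loaded buffers, the results are keyed by the store's buffer index, and one store is emitted per index. A trimmed sketch of that evaluate-then-store shape with toy types (not the kernel's buffers or UOps):

from typing import Callable, Dict, List

def render_stores(outputs: Dict[int, Callable[[], float]],
                  store: Callable[[int, float], str]) -> List[str]:
  # evaluate each output's expression, then emit one store per output buffer index
  store_vals = {idx: expr() for idx, expr in outputs.items()}
  return [store(idx, val) for idx, val in store_vals.items()]

stores = render_stores({0: lambda: 1.0 + 2.0, 1: lambda: 4.0},
                       lambda idx, val: f"STORE buf{idx} <- {val}")
print(stores)  # ['STORE buf0 <- 3.0', 'STORE buf1 <- 4.0']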
+  def ast_parse(self, x:LazyOp, accs:Dict[LazyOp, List[UOp]], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], reduce_acc:Optional[List[UOp]]=None, cache=None) -> List[UOp]: # noqa: E501
     if cache is None: cache = {}
     if x in cache: return cache[x]
     if x.op in BufferOps: return loaded_buffers[x.arg]
-    if x.op == UnaryOps.CAST:
-      return [self.uop(UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg) for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
-    if x.op in ReduceOps and not do_reduce:
-      assert offs is None, "not available if we aren't doing reduce"
-      return acc
-    # MULACC fusion. TODO: this is copied from Interpreted
-    if x.op == ReduceOps.SUM:
-      if x.src[0].op == BinaryOps.MUL: x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
-      if (castop:=x.src[0]).op == UnaryOps.CAST and (mulop:=castop.src[0]).op == BinaryOps.MUL:
-        # MULACC with acc cast rewrite: MUL -> CAST -> SUM => CAST -> MULACC
-        x = LazyOp(TernaryOps.MULACC, tuple(LazyOp(UnaryOps.CAST, (s, ), castop.arg) for s in mulop.src), x.arg)
-
-    values = [self.ast_parse(v, acc, offs, loaded_buffers, loop_ctx=loop_ctx, cache=cache) for v in x.src]
-    ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
+    if x.op in [UnaryOps.CAST, UnaryOps.BITCAST]:
+      return [UOp(UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST,
+                  self.get_base_dtype(x.arg), (u,)) for u in self.ast_parse(x.src[0], accs, offs, loaded_buffers)]
+    if x.op in ReduceOps and reduce_acc is None:
+      return [accs[x][i] for i in offs] if offs else accs[x]
+
+    values = [self.ast_parse(v, accs, offs, loaded_buffers, cache=cache) for v in x.src]
+    ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}
     if x.op in ops:
+      assert reduce_acc is not None
       ret: List[UOp] = []
-      input_acc = acc[:]
+      acc, input_acc = reduce_acc, reduce_acc[:]
       for val, off in zip(zip(*values), cast(List[int], offs)):
-        acc[off] = self.uop(UOps.ALU, acc[off].dtype, vin=val+(acc[off],), arg=ops[x.op])
+        acc[off] = UOp.alu(ops[cast(ReduceOps, x.op)], *(val+(acc[off], )))
         ret.append(acc[off])
       for off in range(len(acc)):
         if input_acc[off] != acc[off]:
-          acc[off] = self.uop(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx))
-    else:
-      ret = [self.uop(UOps.ALU, dtypes.bool if x.op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else val[-1].dtype, val, x.op) for val in zip(*values)]
+          acc[off] = UOp(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]))
+    else: ret = [UOp.alu(x.op, *src) for src in zip(*values)]
     cache[x] = ret
     return ret
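For a reduce, ast_parse folds each new value into its accumulator slot with the reduce's binary op, and a PHI ties the initial and final accumulator together so the loop carries it; outside a reduce the ALU is applied elementwise. A plain-Python sketch of the accumulate step (no real UOps or PHIs):

from typing import Callable, List

def accumulate(acc: List[float], values: List[float], alu: Callable[[float, float], float]) -> List[float]:
  # one reduce step: fold each new value into its accumulator slot (the PHI is what
  # carries acc between loop iterations in the real UOp graph)
  for off, val in enumerate(values):
    acc[off] = alu(val, acc[off])
  return acc

acc = [0.0, 0.0]
for step_vals in ([1.0, 2.0], [3.0, 4.0]):                 # two "loop iterations"
  acc = accumulate(acc, step_vals, lambda a, b: a + b)     # ReduceOps.SUM -> BinaryOps.ADD
print(acc)  # [4.0, 6.0]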
+
+  def to_program(self) -> Program:
+    self.linearize()
+    info = get_lazyop_info(self.ast[0])
+    src = self.opts.render(name:=to_function_name(self.name), self.uops)
+    if getenv("RUN_PROCESS_REPLAY"): diskcache_put("process_replay", id(self), (self.ast, self.opts, self.applied_opts, name, src))
+    ops, mem = self.uops.flops_mem()
+    run_count = prod((self.global_size or []) + (self.local_size or []))
+    # NOTE: we use min here to ignore the indexing FLOPS
+    return Program(self.name, src, self.opts.device, self.global_size, self.local_size,
+                   self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
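to_program combines two cost estimates: the analytic FLOP and memory counts from the AST, and the per-invocation counts from the UOps scaled by the launch grid, taking the min so indexing arithmetic in the UOps does not inflate the numbers. A hedged sketch of that combination with made-up figures (not a tinygrad API):

from math import prod
from typing import List, Tuple

def estimate(ast_flops: int, ast_mem: int, uop_flops_per_call: int, uop_mem_per_call: int,
             global_size: List[int], local_size: List[int]) -> Tuple[int, int]:
  run_count = prod((global_size or []) + (local_size or []))
  # min() keeps whichever estimate is smaller, e.g. it drops indexing FLOPs counted in the UOps
  return (min(ast_flops, uop_flops_per_call * run_count),
          min(ast_mem, uop_mem_per_call * run_count))

print(estimate(ast_flops=1_000_000, ast_mem=4_000_000,
               uop_flops_per_call=70, uop_mem_per_call=256,
               global_size=[128, 128], local_size=[1]))  # (1000000, 4000000)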