tinygrad-0.9.1-py3-none-any.whl → tinygrad-0.10.0-py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/codegen/linearizer.py (deleted)
@@ -1,528 +0,0 @@
1
- from __future__ import annotations
2
- from typing import List, Tuple, Optional, Type, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence, Callable
3
- import itertools, math, functools
4
- from collections import defaultdict
5
-
6
- from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
7
- from tinygrad.helpers import colored, DEBUG, dedup, diskcache_put, prod, getenv, to_function_name, flatten
8
- from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
9
- from tinygrad.shape.shapetracker import ShapeTracker
10
- from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node, sint
11
- from tinygrad.codegen.kernel import LocalBuffer, Kernel
12
- from tinygrad.renderer import Program
13
-
14
- from tinygrad.codegen.uops import UOps, UOp, UOpGraph
15
-
16
- def get_grouped_dims(prefix:str, off:int, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse_dims:bool=False):
17
- """ Maps all global/local dims onto global/local sizes and returns the idxs, loop_idxs and sizes.
18
-
19
- * If there are fewer dims than size, size will be padded with 1s to the length of max_sizes.
20
- * If there are more dims than size, dims will be collapsed onto size starting from left-most (i.e. onto x, then y, then z).
21
- * If the dim is too large for the size, the dim will be split between adjacent size axes space permitting, otherwise assert
22
-
23
- Keyword arguments:
24
- prefix -- the prefix to use for the size Variable names.
25
- off -- the starting index for the size Variable names.
26
- dims -- the global or local dims of the full shape.
27
- max_sizes -- the maximum values for each size in (x, y, z) order.
28
- reverse_dims -- reverse the order of the dims as they are mapped into size, i.e. if True, the right dim will go to the left size (.x).
29
- """
30
-
31
- # check the edge cases on max_sizes
32
- if max_sizes is None: max_sizes = tuple([0xFFFFFFFFFFFFFFFF] * len(dims))
33
- assert len(max_sizes) > 0 or len(dims) == 0, f"{prefix} dims should be empty because no size axes available"
34
- if len(max_sizes) == 0: return [], [], None
35
-
36
- # initialize the map of dims to size with a single dim in each size axis
37
- # TODO: support sint properly
38
- size_dims:List[List[Tuple[int, sint, sint]]] = [[(dim_idx, dim, dim if isinstance(dim, int) else dim.max+1)] for dim_idx, dim in enumerate(dims)]
39
-
40
- # reverse the order of the dims to size map, if desired (currently for globals where smallest stride is on the right)
41
- # TODO: remove reverse_dims, the mapping of dims to size for globals should be co-searched with memory layouts for optimal performance
42
- if reverse_dims: size_dims = size_dims[::-1]
43
-
44
- # ensure that the dims initially fit within the valid size axes
45
- for size_idx in range(min(len(max_sizes), len(size_dims))):
46
- # if the initial dim is too large, split the dim to separate size axes, if possible
47
- dim_idx, dim, dim_max = size_dims[size_idx][0]
48
- if dim_max <= (max_sz:=max_sizes[size_idx]): continue
49
- assert isinstance(dim, int), "variable shape too large for size"
50
- for factor in range(2, int(dim**0.5)+1):
51
- if dim % factor == 0 and dim // factor <= max_sz:
52
- size_dims = size_dims[:size_idx] + [[(dim_idx, dim//factor, dim//factor)], [(dim_idx, factor, factor)]] + size_dims[size_idx+1:]
53
- break
54
- assert size_dims[size_idx][0][2] <= max_sz, f"dim at {size_idx} too large and non-factorable: {dim} > {max_sz}"
55
-
56
- # compress the extra dims, collapsing them onto the left-most valid size axis
57
- cur_size_idx = 0
58
- while len(size_dims) > len(max_sizes):
59
- if prod([dim_max for (_, _, dim_max) in size_dims[cur_size_idx]])*size_dims[cur_size_idx+1][0][2] <= max_sizes[cur_size_idx]:
60
- size_dims = size_dims[:cur_size_idx] + [size_dims[cur_size_idx] + size_dims[cur_size_idx+1]] + size_dims[cur_size_idx+2:]
61
- elif cur_size_idx < len(max_sizes)-1: cur_size_idx += 1
62
- else: raise AssertionError(f"cannot fit dims in size: {dims=} {max_sizes=}")
63
-
64
- # construct the final dim idx variables from the portions of the size variables
65
- sizes, idxs = [prod([dim for (_, dim, _) in size_dim]) for size_dim in size_dims], [NumNode(0)] * len(dims)
66
- size_vars = loop_idxs = [Variable(f"{prefix}{len(sizes)-1-(i+off) if reverse_dims else i+off}", 0, s-1) for i,s in enumerate(sizes)]
67
- for size_idx, size_var in enumerate(size_vars):
68
- for dim_idx, dim, _ in size_dims[size_idx]:
69
- idxs[dim_idx] += (size_var % dim) * (idxs[dim_idx].max+1)
70
- size_var //= dim
71
-
72
- # pad the final sizes array to the proper length if necessary
73
- return idxs, [x for x in loop_idxs if not isinstance(x, NumNode)], sizes + [1]*(len(max_sizes)-len(sizes))
74
-
75
- def expand_idx(node:Node) -> Union[Variable, NumNode]: return next((v for v in node.vars() if v.expr.startswith("_uidx")), NumNode(0))
76
- def expand_idxs(nodes:Sequence[Node]) -> Tuple[Union[Variable, NumNode], ...]:
77
- eidxs = [expand_idx(node) for node in nodes]
78
- return tuple([v if v not in eidxs[:j] else NumNode(0) for j, v in enumerate(eidxs)]) # take only first occurrence of expand variable
79
- def iter_idxs(idxs:Tuple[Union[Variable, NumNode], ...]) -> Iterator[Tuple[int,...]]:
80
- yield from (x[::-1] for x in itertools.product(*[list(range(v.min, v.max + 1)) for v in idxs[::-1]]))
81
-
82
- def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
83
- idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
84
- # TODO: bring back the valid removal logic (correct!)
85
- if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid)
86
- return (idx, idy), valid
87
-
88
- # expand a Node into List[Node] that enumerates the underlying Variables from min to max
89
- # expand increments earlier variables faster than later variables (as specified in the argument)
90
- @functools.lru_cache(maxsize=None)
91
- def expand_node(node:Node, idxs:Optional[Tuple[Union[Variable, NumNode], ...]]=None) -> List[Node]:
92
- if idxs is None: idxs = (expand_idx(node),)
93
- return [node.substitute({k:v for k,v in zip(idxs, (NumNode(x) for x in rep)) if isinstance(k, Variable)}) for rep in iter_idxs(idxs)]
94
-
95
- def variable_to_uop(x, ctx=None) -> UOp:
96
- if isinstance(x, int): return UOp.const(dtypes.int, x)
97
- return x.render(render_ops, ctx)
98
-
99
- render_ops: Dict[Type, Callable[..., UOp]] = {
100
- NumNode: lambda self, ops, ctx: UOp.const(dtypes.int, self.b),
101
- Variable: lambda self, ops, ctx: ctx[self.expr] if self.expr in ctx else UOp(UOps.DEFINE_VAR, dtypes.int, (), self),
102
- MulNode: lambda self, ops, ctx: self.a.render(ops, ctx)*variable_to_uop(self.b, ctx),
103
- DivNode: lambda self, ops, ctx: self.a.render(ops, ctx)//variable_to_uop(self.b, ctx),
104
- ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx),
105
- LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
106
- SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+variable_to_uop(b, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
107
- AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*variable_to_uop(b, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
108
-
109
- class Linearizer(Kernel):
110
- def get_reduce_acc(self, reduceop:LazyOp):
111
- if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0
112
- if reduceop.op is ReduceOps.MAX:
113
- if dtypes.is_int(reduceop.dtype): return 0 if dtypes.is_unsigned(reduceop.dtype) else -2**(reduceop.dtype.itemsize*8-1)
114
- return -math.inf if dtypes.is_float(reduceop.dtype) else False
115
-
116
- # NOTE: once images are loaded, we uop them as their base float
117
- def get_base_dtype(self, dt:DType) -> DType: return dt.base if isinstance(dt, ImageDType) else dt
118
-
119
- def global_load(self, i:int, idxs:List[Node], acc:Optional[LazyOp]=None, barrier:Optional[UOp]=None, loop_ctx:Tuple[UOp, ...]=()) -> List[UOp]:
120
- buf = self.bufs[i]
121
- localtype = self.get_base_dtype(buf.dtype if acc is None else acc.dtype)
122
- const = buf.val if isinstance(buf, ConstBuffer) else None
123
-
124
- expand_vars = expand_idxs(idxs)
125
-
126
- dim, amt = None, 1
127
- # float 4 grouping
128
- if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expand_node(idxs[upcast_dim[0]])) in [4,2]:
129
- dim, amt = upcast_dim[0], len(float4_expand)
130
- g_idx, g_valid = self.sts[i].expr_idxs(idxs[:dim] + [float4_expand[0]] + idxs[dim+1:])
131
- # do not use float4 if idx is not aligned
132
- if g_idx != (g_idx//amt*amt): dim, amt = None, 1
133
- if dim is None:
134
- g_idx, g_valid = self.sts[i].expr_idxs(idxs)
135
- # todo: multioutput test with different output valids to add if acc is None: g_valid = NumNode(1)
136
-
137
- if amt > 1: localtype = localtype.vec(amt)
138
- e_idxs, e_valids = expand_node(g_idx, expand_vars), expand_node(g_valid, expand_vars) # pylint: disable=possibly-used-before-assignment
139
-
140
- ret = []
141
- invalid_value = 0
142
- acc_count = 0
143
- for idx, valid, rep_idx in zip(e_idxs, e_valids, iter_idxs(expand_vars)):
144
- this_const, idx, valid = (invalid_value, NumNode(0), NumNode(1)) if valid.max == 0 else (const, idx, valid)
145
- key = f"{'' if acc is None else self.reduceops.index(acc)}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
146
- if key not in self.load_cache:
147
- if acc is not None:
148
- self.load_cache[key] = UOp(UOps.DEFINE_ACC, localtype, (UOp.const(localtype.scalar(), self.get_reduce_acc(acc)), *loop_ctx), (i, acc_count))
149
- acc_count += 1
150
- elif this_const is not None:
151
- self.load_cache[key] = UOp.const(localtype, this_const)
152
- if valid.min == 0 and valid.max == 1:
153
- valid_rendered = valid.render(render_ops, self.loop_uops)
154
- self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_rendered, self.load_cache[key], UOp.const(localtype, invalid_value))
155
- elif isinstance(buf.dtype, ImageDType):
156
- buf_uop = self.buf_uops[i]
157
- assert buf_uop is not None, f"buffer {i} wasn't UOped"
158
- image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
159
- rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
160
- valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
161
- self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4),
162
- (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
163
- if localtype == localtype.scalar():
164
- idx_small = idx%4
165
- res = idx_small.render(render_ops, self.loop_uops)
166
- out = UOp(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
167
- for ix in range(idx_small.max, idx_small.min, -1):
168
- rvv = UOp(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
169
- sel = UOp.alu(BinaryOps.CMPLT, res, UOp.const(dtypes.int, ix))
170
- out = UOp.alu(TernaryOps.WHERE, sel, rvv, out)
171
- self.load_cache[key] = out
172
- else:
173
- buf_uop = self.buf_uops[i]
174
- assert buf_uop is not None, f"buffer {i} wasn't UOped"
175
- rendered_idx = idx.render(render_ops, self.loop_uops)
176
- valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple()
177
- self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
178
- ret.append(UOp(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
179
- return ret
180
-
181
- def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]:
182
- buf = self.bufs[i]
183
- buf_uop = self.buf_uops[i]
184
- assert buf_uop is not None, f"buffer {i} wasn't UOped"
185
-
186
- expand_vars = expand_idxs(idxs)
187
- _idxs = zip(*[expand_node(idx, expand_vars) for idx in idxs]) if idxs else [tuple()] # transpose
188
- store_offset = dict(zip(_idxs, store))
189
-
190
- # float4 grouping
191
- if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expand_node(idxs[upcast_dim[0]])) in [2,4]:
192
- grouped_store_offset = defaultdict(list)
193
- for k in store_offset:
194
- _idx = k[:upcast_dim[0]] + (float4_expand[0],) + k[upcast_dim[0]+1:]
195
- grouped_store_offset[_idx].append(store_offset[k])
196
- store_offset_new = {}
197
- for k,grouped in grouped_store_offset.items():
198
- amt = len(grouped)
199
- idx, valid = self.sts[i].expr_idxs(k)
200
- assert idx == ((idx//amt)*amt), "float4 stores are always aligned"
201
- store_offset_new[k] = UOp(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
202
- store_offset = store_offset_new
203
-
204
- stores = []
205
- for _idx, var in store_offset.items():
206
- idx, valid = self.sts[i].expr_idxs(_idx)
207
- if isinstance(buf.dtype, ImageDType):
208
- image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
209
- rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), \
210
- tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
211
- else:
212
- rendered_idx = idx.render(render_ops, self.loop_uops)
213
- # TODO: let UPat check this once it's fast
214
- if valid.min == 1: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var)))
215
- else: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(render_ops, self.loop_uops))))
216
- return stores
217
-
218
- # render loop
219
- def render_loop(self, xx:List[Variable], depth:int, reduce:bool) -> Tuple[UOp, ...]:
220
- new_loops = {x.expr:UOp(UOps.RANGE, dtypes.int32, (
221
- UOp.const(dtypes.int, x.min) if isinstance(x.min, int) else cast(Node, x.min).render(render_ops, self.loop_uops),
222
- UOp.const(dtypes.int, x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(render_ops, self.loop_uops)), arg=(depth,i,reduce)) for i,x in enumerate(xx) if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
223
- self.loop_uops.update(new_loops)
224
- return tuple(new_loops.values())
225
-
226
- def index_local_aliases(self, global_idxs, local_idxs, reduce_idxs, upcast_idxs, full_upcast_idxs):
227
- def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
228
- replace_idxs, thread_idxs, thread_idx = [], [], Variable("_uidx_tc", 0, prod(local_sizes)-1)
229
- for s in local_sizes:
230
- thread_idxs.append(thread_idx % s)
231
- thread_idx //= s
232
- for alias in aliases:
233
- full_var, full_var_sz = NumNode(0), 1
234
- if alias[0] != 0:
235
- for i in alias:
236
- next_var = local_idxs[i-1] if i > 0 else thread_idxs[-i-1]
237
- full_var += next_var * full_var_sz
238
- full_var_sz *= next_var.max+1
239
- replace_idxs.append(full_var)
240
- return replace_idxs
241
-
242
- # compute local aliases
243
- alias_buf_idxs: DefaultDict[LazyOp, List[Tuple[int, int, List]]] = defaultdict(list)
244
- for op, local_alias in self.local_alias.items():
245
- for i in local_alias:
246
- localbuf_idx = self.bufs.index(local_alias[i])
247
- buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
248
- if (tc:=self.tensor_core):
249
- min_alias_idx = min(local_alias.keys())
250
- replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
251
- for n in range(len(tc.threads)):
252
- buf_idxs[self.global_dims+n] = replace_input_idxs[n] # replace locals
253
- for n in range(tc.num_upcasts()):
254
- buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
255
- if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
256
- alias_buf_idxs[op].append((i, localbuf_idx, buf_idxs))
257
- # modify idxs if necessary for TC
258
- if (tc:=self.tensor_core):
259
- replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
260
- for n in range(len(tc.threads)):
261
- local_idxs[n] = replace_acc_idxs[n] # replace locals
262
- for n in range(len(replace_acc_idxs)-len(tc.threads)):
263
- upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
264
- if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+upcast_idxs}")
265
- return alias_buf_idxs
266
-
267
- def render_reduceop(self, reduceop:LazyOp, accs:Dict[LazyOp, List[UOp]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]],
268
- global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, reduce_idxs, fake_reduce_idxs,
269
- alias_buf_idxs:List[Tuple[int, int, List]]) -> Tuple[List[NumNode|Variable], List[NumNode|Variable]]:
270
- # reduce loop
271
- loop_ctx = self.render_loop(reduce_idxs, (i:=self.reduceops.index(reduceop))*2+2, True)
272
-
273
- # define accumulator - modify idxs if necessary for TC
274
- out_buf = -len(self.reduceops)+i if self.group_for_reduces else 0
275
- accs[reduceop] = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
276
-
277
- # store local aliases
278
- locals_to_store = [(localbuf_idx, buf_idxs, self.global_load(i, buf_idxs)) for i, localbuf_idx, buf_idxs in alias_buf_idxs]
279
-
280
- if (tc:=self.tensor_core):
281
- # run tensor cores AST
282
- wmma_sz = [prod(l) for l in tc.thread_local_sizes]
283
- def upcast_strides(buf:int):
284
- strides, next_ = [], 1
285
- for (sz, stride, _) in self.upcasted_axis(buf)[tc.num_upcasts():]:
286
- strides.append((0 if stride == 0 else next_, sz))
287
- next_ *= 1 if stride == 0 else sz
288
- return strides
289
- upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
290
- # cast initial accs
291
- wmmas = [UOp(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]]))
292
- for x in range(0, len(accs[reduceop]), wmma_sz[2])]
293
- for it in [x[::-1] for x in itertools.product(*list([range(sz) for _,sz in upcasts[0]][::-1]))]:
294
- offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(it, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
295
- ops = (UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
296
- UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
297
- wmmas[(wmma_idx:=offs[2]//wmma_sz[2])])
298
- # TODO: don't need to DEFINE_ACC, pass to WMMA in op3, or PHI accs that are not valid
299
- wmmas[wmma_idx] = UOp(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), dev))
300
- # phi the last wmmas back to accs
301
- accs[reduceop] = [UOp(UOps.PHI, tc.dtype_out, (acc, UOp(UOps.GEP, tc.dtype_out, (wmmas[z//wmma_sz[2]],), z%wmma_sz[2])))
302
- for z, acc in enumerate(accs[reduceop])]
303
- else:
304
- assert not locals_to_store, "storing locals isn't supported here"
305
-
306
- # load earlybufs
307
- loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[reduceop][i]) if i in self.local_alias else i,
308
- global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs})
309
-
310
- def gate_acc(r, idxs): return [
311
- UOp.alu(TernaryOps.WHERE, valid.render(render_ops, self.loop_uops), acc, UOp.const(r.dtype, 0)) if valid.min == 0 and valid.max == 1 else acc
312
- for valid, acc in zip(expand_node(self.sts[self.full_buf_index].expr_idxs(idxs)[1], expand_idxs(idxs)), accs[r])]
313
- local_accs = {r: gate_acc(r,global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for r in accs}
314
-
315
- # run early AST (with reduce)
316
- self.ast_parse(reduceop, local_accs, self.acc_offsets(self.full_buf_index), loaded_buffers, reduce_acc=accs[reduceop])
317
-
318
- # end the reduce loop
319
- self.load_cache.clear()
320
-
321
- # end the local loop, do the local reduce
322
- if self.group_for_reduces:
323
- fake_global_idxs = [x*0 for x in global_idxs]
324
- stores = self.global_store(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop]) # store accumulators
325
- barrier = UOp(UOps.BARRIER, None, tuple(stores))
326
- if self.opts.has_local:
327
- fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
328
- fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
329
- if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(render_ops, self.loop_uops)
330
- barrier = UOp(UOps.IF, None, (if_cond, barrier))
331
-
332
- # create new late reduce local loops and replace local_idxs that have been used
333
- end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)] # noqa: E501
334
- local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]
335
-
336
- # if any group_for_reduce items aren't reduces, upcast them here
337
- for j in self.upcast_in_mid_reduce_axes:
338
- self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
339
- self.upcast()
340
- self.group_for_reduces -= 1
341
- local_idxs = local_idxs[:-1]
342
- end_local_idxs = end_local_idxs[:-1]
343
- # regenerate upcast_idxs
344
- upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
345
-
346
- # NOTE: this structure is the same as the reduce op above
347
-
348
- # late reduce loop
349
- loop_ctx = self.render_loop(end_local_idxs, i*2+3, True)
350
-
351
- # define late accumulator
352
- accs[reduceop] = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
353
-
354
- # load localbufs
355
- loaded_buffers[self.bufs[out_buf]] = self.global_load(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
356
-
357
- # there's no AST here (and there's no shape for the reduce LazyOp)
358
- self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[out_buf]),)),\
359
- accs, self.acc_offsets(-1), loaded_buffers, reduce_acc=accs[reduceop])
360
-
361
- # end the late reduce loop
362
- self.load_cache.clear()
363
-
364
- if reduceop is not self.reduceops[-1]:
365
- for j in self.upcast_in_mid_reduce_axes:
366
- self.upcasted -= 1
367
- self.group_for_reduces += 1
368
- assert self.buf_uops[out_buf] is not None, "Local reduce buf must have been uoped at this point"
369
- fake_local_idxs = local_idxs[:self.local_dims] + [x*0 for x in local_idxs[self.local_dims:]]
370
- stores = self.global_store(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop])
371
- barrier = UOp(UOps.BARRIER, None, tuple(stores))
372
- accs[reduceop] = self.global_load(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
373
- return local_idxs[:self.local_dims] + [NumNode(0) for _ in range(self.group_for_reduces)], upcast_idxs
374
-
375
- kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
376
- def linearize(self) -> Linearizer:
377
- # no new opts and we already ran? skip relinearizing
378
- if self.applied_opts == self.applied_opts_cache: return self
379
-
380
- # late alias the tensor core buffers
381
- if (tc:=self.tensor_core) and self.tensor_core_opts is not None:
382
- alias_pattern = [0]*(self.global_dims) + [2]*(len(tc.threads)) + [0]*(self.local_dims-len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2) # noqa: E501
383
- for op, tc_bufs in self.bufs_for_tensor_core.items():
384
- for tc_buf in tc_bufs: self.alias_buffer(op, tc_buf, alias_pattern)
385
-
386
- # save backups
387
- sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduces, self.upcasted
388
-
389
- # uops
390
- self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs)
391
- self.loop_uops: Dict[str, UOp] = {}
392
-
393
- # add global buffers
394
- for i,buf in enumerate(self.bufs):
395
- if isinstance(buf, MemBuffer):
396
- self.buf_uops[i] = UOp(UOps.DEFINE_GLOBAL,
397
- buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
398
- (buf.idx, any(buf.idx == x.idx for x in self.outbufs)))
399
- # define local buffers
400
- for aliases in self.local_alias.values():
401
- for lb in aliases.values(): self.buf_uops[self.bufs.index(lb)] = UOp(UOps.DEFINE_LOCAL, PtrDType(lb.dtype),
402
- (), (lb.name, self.sts[self.bufs.index(lb)].size))
403
- # add a local buffer for multistage reduce. # TODO: use local alias
404
- if self.group_for_reduces:
405
- for i in range(len(self.reduceops)):
406
- # TODO: the strides of this can be controlled
407
- self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501
408
- temp_dtype = self.get_base_dtype(cast(LazyOp, self.reduceop).dtype)
409
- self.bufs.append(LocalBuffer(name:=f"temp{i if len(self.reduceops) > 1 else ''}", buf_size:=self.sts[-1].size, temp_dtype))
410
- self.buf_uops.append(UOp(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), (name, buf_size)))
411
-
412
- # kernel name (before late upcast)
413
- self.name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \
414
- (f"{len(self.outbufs)}_" if len(self.outbufs) > 1 else "_") + \
415
- colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
416
-
417
- # name the function something unique
418
- Linearizer.kernel_cnt[(function_name := to_function_name(self.name))] += 1
419
- suffix = f"{'n'+str(Linearizer.kernel_cnt[function_name]-1)}" if Linearizer.kernel_cnt[function_name] > 1 else ""
420
- self.name = self.name+colored(suffix, 'BLACK')
421
-
422
- # define indexes
423
- gl_dims = self.full_shape[:self.first_reduce+self.group_for_reduces]
424
- global_idxs, loop_global_idxs, self.global_size = get_grouped_dims("idx" if self.dont_use_locals else "gidx", 0, gl_dims[:self.global_dims],
425
- self.opts.global_max, self.opts.has_local)
426
- local_idxs, loop_local_idxs, self.local_size = get_grouped_dims("lidx", self.global_dims, gl_dims[self.global_dims:],
427
- self.opts.local_max if self.opts.has_local else (), False)
428
- upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
429
- full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
430
-
431
- # render global and local as specials or a loop
432
- if self.opts.has_local:
433
- self.loop_uops.update({x.expr:UOp(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
434
- if not self.dont_use_locals:
435
- self.loop_uops.update({x.expr:UOp(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
436
- else:
437
- self.global_size, self.local_size = None, None
438
- self.render_loop(loop_global_idxs+loop_local_idxs, 1, False)
439
-
440
- # define idxs for aliased buffers TODO: this doesn't belong in Kernel, but it can't exist in Block either (because of multireduce tensor cores)
441
- reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
442
- alias_buf_idxs = self.index_local_aliases(global_idxs,local_idxs,reduce_idxs,upcast_idxs,full_upcast_idxs)
443
-
444
- # parse AST
445
- self.load_cache: Dict[str, UOp] = {}
446
- loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]] = {}
447
- accs: Dict[LazyOp, List[UOp]] = {}
448
-
449
- # render reduceops by depth
450
- for reduceop in self.reduceops:
451
- self.render_block((reduceop, ), global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)
452
- stores = self.render_block(self.ast, global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)
453
-
454
- # only the final stores are needed to define the full UOps graph
455
- self.uops:UOpGraph = UOpGraph(flatten(stores))
456
-
457
- # maybe graph the uops
458
- if DEBUG >= 5: self.uops.print()
459
- if getenv("GRAPHUOPS"): self.uops.graph()
460
-
461
- # restore backups
462
- self.sts, self.group_for_reduces, self.upcasted = sts_backup, gfr_backup, upc_backup
463
-
464
- # set cache and return
465
- self.applied_opts_cache = self.applied_opts[:]
466
- return self
467
-
468
- def render_block(self, outputs:Tuple[LazyOp, ...], global_idxs, local_idxs, upcast_idxs, full_upcast_idxs,
469
- alias_buf_idxs:DefaultDict[LazyOp,List[Tuple[int,int,List[NumNode|Variable]]]],
470
- loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], accs:Dict[LazyOp,List[UOp]]) -> List[List[UOp]]:
471
- reduceops = dedup(x for x in outputs if x.op in ReduceOps)
472
- assert len(reduceops) <= 1, "max one reduceop per block"
473
- reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
474
- fake_reduce_idxs = [x*0 for x in reduce_idxs]
475
-
476
- if len(reduceops) != 0:
477
- # TODO: delete render_reduceop and move the logic for group_for_reduces to Block
478
- nlidx, nuidx = self.render_reduceop((r:=reduceops[0]),accs,loaded_buffers,\
479
- global_idxs,local_idxs,upcast_idxs,full_upcast_idxs,reduce_idxs,fake_reduce_idxs,alias_buf_idxs[r])
480
-
481
- # all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have
482
- # been rewritten with fake end_local_idxs.
483
- if r is self.reduceops[-1]: local_idxs[:], upcast_idxs[:] = nlidx, nuidx
484
- return [accs[r]]
485
-
486
- # load latebufs
487
- loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
488
- for i,b in enumerate(self.bufs) if b not in self.earlybufs and b.__class__ is not LocalBuffer})
489
- # run late AST (without the store)
490
- store_vals = {op.arg.idx:self.ast_parse(op.src[0], accs, None, loaded_buffers) for op in self.ast}
491
- return [self.global_store(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val) for i, val in store_vals.items()]
492
-
493
- def ast_parse(self, x:LazyOp, accs:Dict[LazyOp, List[UOp]], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], reduce_acc:Optional[List[UOp]]=None, cache=None) -> List[UOp]: # noqa: E501
494
- if cache is None: cache = {}
495
- if x in cache: return cache[x]
496
- if x.op in BufferOps: return loaded_buffers[x.arg]
497
- if x.op in [UnaryOps.CAST, UnaryOps.BITCAST]:
498
- return [UOp(UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST,
499
- self.get_base_dtype(x.arg), (u,)) for u in self.ast_parse(x.src[0], accs, offs, loaded_buffers)]
500
- if x.op in ReduceOps and reduce_acc is None:
501
- return [accs[x][i] for i in offs] if offs else accs[x]
502
-
503
- values = [self.ast_parse(v, accs, offs, loaded_buffers, cache=cache) for v in x.src]
504
- ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}
505
- if x.op in ops:
506
- assert reduce_acc is not None
507
- ret: List[UOp] = []
508
- acc, input_acc = reduce_acc, reduce_acc[:]
509
- for val, off in zip(zip(*values), cast(List[int], offs)):
510
- acc[off] = UOp.alu(ops[cast(ReduceOps, x.op)], *(val+(acc[off], )))
511
- ret.append(acc[off])
512
- for off in range(len(acc)):
513
- if input_acc[off] != acc[off]:
514
- acc[off] = UOp(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]))
515
- else: ret = [UOp.alu(x.op, *src) for src in zip(*values)]
516
- cache[x] = ret
517
- return ret
518
-
519
- def to_program(self) -> Program:
520
- self.linearize()
521
- info = get_lazyop_info(self.ast[0])
522
- src = self.opts.render(name:=to_function_name(self.name), self.uops)
523
- if getenv("RUN_PROCESS_REPLAY"): diskcache_put("process_replay", id(self), (self.ast, self.opts, self.applied_opts, name, src))
524
- ops, mem = self.uops.flops_mem()
525
- run_count = prod((self.global_size or []) + (self.local_size or []))
526
- # NOTE: we use min here to ignore the indexing FLOPS
527
- return Program(self.name, src, self.opts.device, self.global_size, self.local_size,
528
- self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))