tinygrad-0.10.2-py3-none-any.whl → tinygrad-0.11.0-py3-none-any.whl

This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/codegen/opt/kernel.py
@@ -0,0 +1,510 @@
+from __future__ import annotations
+import itertools, functools, math
+from dataclasses import dataclass
+from collections import defaultdict
+from typing import cast, Final, Callable, Sequence
+from enum import Enum, auto
+
+from tinygrad.uop.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, AxisType
+from tinygrad.uop.spec import type_verify, ast_spec
+from tinygrad.device import Device
+from tinygrad.codegen.opt.tc import TensorCore
+from tinygrad.renderer import Renderer
+from tinygrad.dtype import ImageDType, AddrSpace
+from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
+from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.shape.view import strides_for_shape, get_contraction
+from tinygrad.codegen.opt.swizzler import view_left, view_left_through_load
+
+class OptOps(Enum):
+  TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
+  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
+  def __lt__(self, x:OptOps): return self.value < x.value
+
+@dataclass(frozen=True, order=True)
+class Opt:
+  op: OptOps
+  axis: int|None = None
+  arg: int|tuple|None = None
+  def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"
+
+axis_letters = {AxisType.GLOBAL: "g", AxisType.LOCAL: "l", AxisType.LOOP: "L", AxisType.UPCAST: "u",
+                AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
+axis_colors = {AxisType.GLOBAL: "blue", AxisType.LOCAL: "cyan", AxisType.LOOP: "WHITE", AxisType.UPCAST: "yellow",
+               AxisType.GROUP_REDUCE: "green", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
+
+class KernelOptError(Exception): pass
+def check(cond:bool, msg:str=""):
+  if not cond: raise KernelOptError(msg)
+
+@dataclass
+class TensorCoreOptions:
+  axes: tuple[int, ...] # the location of the original N and M axes if still in the shape
+  axes_exist: tuple[bool, ...] # true if the original N and M axes are still in the shape
+  axis_pads: tuple[tuple[int, int], ...]
+  def fix_axes(self, removed_axis:int): # adjust the TC axes if necessary when a dimension is removed
+    axes, axes_exist = list(self.axes), list(self.axes_exist)
+    for tc_dim in [i for i in range(2) if axes_exist[i]]:
+      if removed_axis < axes[tc_dim]: axes[tc_dim] -= 1
+      elif removed_axis == axes[tc_dim]: axes_exist[tc_dim] = False
+    self.axes, self.axes_exist = tuple(axes), tuple(axes_exist)
+
+class Kernel:
+  def __init__(self, ast:UOp, opts:Renderer|None=None):
+    assert ast.op is Ops.SINK, ast.op
+    self.ast = ast
+
+    self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
+    # verify AST matches the spec
+    if __debug__: type_verify(list(self.ast.toposort()), ast_spec)
+
+    self.vars: list[Variable] = self.ast.variables()
+    # NOTE: this requires a specific order with the [::-1], this is likely a bug
+    self.bufs: list[UOp] = [x for x in self.ast.toposort() if x.op in GroupOp.Buffer][::-1]
+
+    # create new shapetrackers inside this kernel, we will permute them
+    self.sts: list[ShapeTracker] = [x.st_arg for x in self.bufs]
+
+    # add the shapetrackers for each reduce
+    # we use this to track which axes are reduced in each reduce
+    self.reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE_AXIS]
+    for x in self.reduceops:
+      self.sts.append(unwrap(x.st))
+      self.sts.append(unwrap(x.src[0].st))
+
+    # add a shapetracker to the end to track the full shape, with 0 strides so it can merge
+    full_shape = ast.full_shape
+    self.sts.append(ShapeTracker.from_shape(full_shape, (0,)*len(full_shape)))
+
+    # parameters for optimization
+    self.tensor_core: TensorCore|None = None
+    self.tensor_core_opts: TensorCoreOptions|None = None
+    self.use_tensor_cores: int = 0
+    self.applied_opts: list[Opt] = []
+    self.dont_use_locals = False
+    self.finalized: bool = False
+
+    # group simplifies
+    self.simplify_ones()
+    self.simplify_merge_adjacent()
+
+    # axis types
+    global_loops = AxisType.GLOBAL if self.opts.has_local else AxisType.LOOP
+    self.axis_types: list[AxisType] = [AxisType.REDUCE if resolve(x!=y) else global_loops for x,y in zip(self.output_shape, self.full_shape)]
+
+    # confirm all reduce axes are at the end
+    if (final_reduces := [x for x in self.axis_types if x == AxisType.REDUCE]) and final_reduces != self.axis_types[-len(final_reduces):]:
+      raise RuntimeError(f"reduces are not at the end of the shape {self.full_shape} -> {self.output_shape}")
+
+  def copy(self):
+    ret = type(self).__new__(type(self))
+
+    # base linearizer params
+    ret.opts, ret.ast = self.opts, self.ast
+
+    # things downstream of the AST
+    ret.reduceops, ret.vars, ret.bufs = self.reduceops, self.vars, self.bufs
+    ret.sts = self.sts[:]
+    ret.axis_types = self.axis_types[:]
+
+    # parameters for optimizations
+    ret.applied_opts, ret.dont_use_locals = self.applied_opts[:], self.dont_use_locals
+    ret.tensor_core, ret.tensor_core_opts, ret.use_tensor_cores = self.tensor_core, self.tensor_core_opts, self.use_tensor_cores
+    ret.finalized = self.finalized
+
+    return ret
+
+  @property
+  def reduceop(self) -> UOp|None: return self.reduceops[0] if len(self.reduceops) > 0 else None
+  @property
+  def full_shape(self) -> tuple[sint, ...]: return self.sts[-1].shape
+
+  @property
+  def output_shape(self) -> tuple[sint, ...]: return self.sts[0].shape
+  @property
+  def shape_len(self) -> int: return len(self.sts[0].shape)
+
+  def axes_of(self, *axis_type:AxisType) -> list[int]: return [i for i,t in enumerate(self.axis_types) if t in argfix(axis_type)]
+  @property
+  def upcasted(self) -> int: return len(self.axes_of(AxisType.UPCAST, AxisType.UNROLL))
+  @property
+  def group_for_reduces(self) -> int: return len(self.axes_of(AxisType.GROUP_REDUCE))
+
+  # heuristic helpers
+  @property
+  def upcastable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP) \
+                                                  if isinstance(s:=self.full_shape[i], int) and s > 1]
+  @property
+  def unrollable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE) \
+                                                  if isinstance(s:=self.full_shape[i], int) and s > 1]
+
+  # ******************** colors and names ********************
+
+  def colors(self) -> list[str]:
+    assert len(self.axis_types) == self.shape_len, "colors size mismatch"
+    return [axis_colors[x] if not self.dont_use_locals or not x == AxisType.GLOBAL else "BLUE" for x in self.axis_types]
+
+  def colored_shape(self, pad:int|None=None, dense=False) -> str:
+    shape_strs = [(s if dense else f"{s:4d}") if isinstance(s, int) else s.render() for s in self.full_shape]
+    ret = ' '.join(colored(s, color) for s,color in zip(shape_strs, self.colors()))
+    if pad: ret += ' '*(pad-ansilen(ret))
+    return ret
+
+  kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
+  @functools.cached_property
+  def name(self) -> str:
+    # kernel name (before late upcast)
+    kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op is Ops.SINK or x.op in GroupOp.Buffer for x in self.ast.toposort()) else "E")
+    suffix = colored('_', 'BLACK').join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())])
+    name = kernel_type + (f"{len(self.ast.src)}" if len(self.ast.src) > 1 else "") + "_" + suffix
+
+    # name the function something unique
+    Kernel.kernel_cnt[(function_name := to_function_name(name))] += 1
+    num = f"n{Kernel.kernel_cnt[function_name]-1}" if Kernel.kernel_cnt[function_name] > 1 else ""
+    return name + colored(num, 'BLACK')
+
+  # ******************** base simplifiers ********************
+
+  # apply reshape and permute to all shapetrackers
+  def reshape(self, new_shape_fxn:Callable[[tuple[sint, ...]], Sequence[sint]]):
+    self.sts = [st.reshape(tuple(new_shape_fxn(st.shape))) for st in self.sts]
+  def permute(self, new_axes:Sequence[int]): self.sts = [st.permute(tuple(new_axes)) for st in self.sts]
+
+  # axis : the axis to pull from
+  # amount : the amount to take
+  # top : if you want to pull that amount from the top
+  # insert_at : place to insert the new stuff
+  def shift_to(self, axis:int, amount:int, new_type:AxisType, top:bool=False, insert_at:int|None=None):
+    if insert_at is None: insert_at = self.shape_len
+    self.axis_types.insert(insert_at, new_type)
+    move_axis = axis if top else axis+1
+    if move_axis < insert_at: insert_at += 1
+    def new_shape_fxn(x): return x[0:axis] + (((amount,x[axis]//amount) if top else (x[axis]//amount,amount)) if x[axis] > 1 else (1,1)) + x[axis+1:]
+    new_axes = [i for i in range(insert_at) if i != move_axis]+[move_axis]+[i for i in range(insert_at, self.shape_len+1) if i != move_axis]
+    self.reshape(new_shape_fxn)
+    self.permute(new_axes)
+
+  # ******************** complex simplifiers ********************
+
+  def simplify_ones(self) -> bool:
+    # remove places where the shape is all ones
+    if any(all_ones:=[s==1 for s in self.full_shape]):
+      if hasattr(self, 'axis_types'):
+        self.axis_types = [x for i,x in enumerate(self.axis_types) if not all_ones[i]]
+      self.reshape(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]])
+      return True
+    return False
+
+  def simplify_merge_adjacent(self):
+    assert not hasattr(self, 'axis_types'), "don't call this after init"
+    if self.shape_len == 0: return
+    shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
+    # NOTE: we can't use self.first_reduce yet
+    first_reduce = [resolve(x!=y) for x,y in zip(self.output_shape+(0,), self.full_shape+(1,))].index(True)
+
+    # if it's an image, insert fake strides such that this fusion doesn't happen across image axes
+    # TODO: remove membufs
+    membufs = dedup([x.src[0].base for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])
+    if isinstance(membufs[0].base.dtype, ImageDType):
+      base_shape = membufs[0].base.dtype.shape
+      if shape_idx_groups := get_contraction(self.output_shape, base_shape):
+        special_strides: tuple[sint, ...] = tuple()
+        for i,g in enumerate(shape_idx_groups):
+          shape_piece = tuple(self.output_shape[x] for x in g)
+          assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
+          special_strides += strides_for_shape(shape_piece)
+        # adding the fake image shape
+        shapes.append(self.output_shape)
+        strides.append(special_strides)
+
+    # merge dimensions if we can, multi _merge_dims
+    # NOTE: this does not always preserve the reduce dimension
+    # TODO: move this into shapetracker, with tests!
+    # TODO: how does this work with multi-reduce?
+    rets = [[(s[0], st[0])] for s,st in zip(shapes, strides)]
+    for i in range(1, len(shapes[0])):
+      can_merge = []
+      for s,st,ret in zip(shapes, strides, rets):
+        # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
+        si, sti, last_st = s[i], st[i], ret[-1][1]
+        can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0)))
+      # more can merge than this
+      mergeable = all(can_merge) and i != first_reduce
+      for j,(s,st) in enumerate(zip(shapes, strides)):
+        if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i])
+        else: rets[j].append((s[i], st[i]))
+
+    # do the reshapes
+    for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
+
+  # ******************** apply optimizations ********************
+
+  def real_axis(self, op:OptOps, axis:int|None):
+    try:
+      if axis is None: return -1
+      if op is OptOps.UNROLL: return self.unrollable_dims[axis]
+      if op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[axis]
+      check(axis < self.shape_len, "invalid axis")
+      return axis
+    except IndexError as e: raise KernelOptError from e
+
+  def apply_opt(self, opt:Opt, append_opt:bool=True):
+    if self.finalized: raise RuntimeError("can't optimize Kernel after it's finalized")
+    if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}, "not using locals")
+
+    if opt.op is OptOps.TC:
+      check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
+      check(len(self.opts.tensor_cores) > 0, "must have tensor cores")
+      check(opt.axis is not None, "tensor core opts must have an axis")
+      check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 3, "tensor core opts must have valid arg")
+      check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
+      check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
+      check(0 < (use_tensor_cores:=cast(tuple, opt.arg)[2]) <= 2, "use_tensor_cores value is not valid")
+      check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt), "no tensor core available")
+      self.applied_opts.append(opt)
+      return
+
+    axis = self.real_axis(opt.op, opt.axis)
+
+    if opt.op is OptOps.SWAP: amt = self.real_axis(opt.op, cast(int, opt.arg)) # arg is an axis in the SWAPs
+    elif opt.arg is not None:
+      check(isinstance(opt.arg, int), "arg should be int")
+      amt = arg if (arg:=cast(int, opt.arg)) != 0 else self.full_shape[axis]
+      check(isinstance(amt, int) and amt != 1, f"shift/padto of {amt=}, 1 or symbolic amount is meaningless")
+      if opt.op is not OptOps.PADTO:
+        # we check both the full_shape and each shape
+        check(self.full_shape[axis] % amt == 0, f"no longer valid shift {self.full_shape[axis]=}, {amt=}")
+        for st in self.sts: check(st.shape[axis] == 1 or st.shape[axis] % amt == 0, f"no longer valid shift {st.shape[axis]=}, {amt=}")
+    else: amt = -1
+
+    if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
+       (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
+      acc_sz = self.reduceop.dtype.itemsize
+      upcast_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.UPCAST)])
+      local_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.LOCAL)])
+      smem_sz = amt*acc_sz*upcast_sz*local_sz
+      check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
+
+    if opt.op is OptOps.LOCAL: # cyan
+      # NOTE: LLVM/CPU can use locals too, but they are treated the same as globals (still helpful for L1 cache)
+      # it's disabled for now since it makes BEAM slow for little gain
+      check(self.opts.has_local, "target does not support local")
+      check(self.axis_types[axis] is AxisType.GLOBAL, "local is for globals")
+      self.shift_to(axis, amt, AxisType.LOCAL, insert_at=max(self.axes_of(AxisType.GLOBAL, AxisType.LOCAL))+1)
+    elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green
+      check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
+      check(self.axis_types[axis] is AxisType.REDUCE, "must be reduce axis to group")
+      check(not self.tensor_core, "can't group with tensor cores")
+      check(len(reduce_axes:=[i for r in self.reduceops for i in r.axis_arg]) == len(set(reduce_axes)), "can't group with parallel reduces")
+      self.shift_to(axis, amt, AxisType.GROUP_REDUCE, top=(opt.op is OptOps.GROUPTOP), insert_at=min(self.axes_of(AxisType.REDUCE)))
+    elif opt.op is OptOps.UNROLL: # purple
+      check(self.axis_types[axis] not in (AxisType.UPCAST, AxisType.UNROLL), "can't upcasted already upcasted")
+      check(amt <= 32, "don't unroll more than 32")
+      self.shift_to(axis, amt, AxisType.UNROLL, insert_at=None)
+    elif opt.op is OptOps.UPCAST: # yellow
+      check(axis in self.upcastable_dims, f"{axis=} not in {self.upcastable_dims=}")
+      # NOTE: assume the first get_local_axes() LOCAL are for TC
+      check(not (self.tensor_core and axis in self.axes_of(AxisType.LOCAL)[:len(self.tensor_core.get_local_axes())]), "can't upcast TC locals")
+      check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16")
+      self.shift_to(axis, amt, AxisType.UPCAST, insert_at=max(self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP, AxisType.UPCAST))+1)
+    elif opt.op is OptOps.NOLOCALS:
+      check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
+      check(AxisType.LOCAL not in self.axis_types and self.group_for_reduces == 0, "can't have no locals with locals")
+      self.dont_use_locals = True
+    elif opt.op is OptOps.SWAP:
+      check(axis < amt, f"swap is only for axis < amt, getting {amt=}, {axis=}")
+      check(self.axis_types[axis]==self.axis_types[amt]==AxisType.GLOBAL, f"swap is for globals {self.axis_types[axis]=}, {self.axis_types[amt]=}")
+      permute = list(range(self.shape_len))
+      permute[axis], permute[amt] = permute[amt], permute[axis]
+      self.permute(tuple(permute))
+    elif opt.op is OptOps.PADTO:
+      check(not self.vars, "does not work with symbolic shape")
+      check(self.axis_types[axis] not in (AxisType.UPCAST, AxisType.UNROLL), "cannot pad upcasted")
+      # ok to pad SUM if all parent ALU ops have f(0) = 0
+      if (r:=self.reduceop) is not None and self.axis_types[axis] in (AxisType.GROUP_REDUCE, AxisType.REDUCE):
+        check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}")
+      padded = False
+      for i,st in enumerate(self.sts):
+        if (s:=st.shape[axis]) == 1: continue # reduced
+        check(s > amt//4, f"pad adds more than quadruple the work {st.shape[axis]=} > {amt//4=}")
+        if (ru := round_up(cast(int, s), amt) - s):
+          # pad right seems to be faster
+          self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
+          padded = True
+      check(padded, "nothing was padded")
+
+    if append_opt: self.applied_opts.append(opt)
+    if self.simplify_ones() and self.tensor_core_opts:
+      self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()
+
+  def apply_opts(self, opts:Sequence[Opt]) -> Kernel:
+    for opt in opts: self.apply_opt(opt)
+    return self
+
+  # **** kernel outputs, mostly tensor cores ****
+
+  def _create_tc_opts(self, reduceop:UOp, tc:TensorCore, axis:int, opt_level:int) -> TensorCoreOptions|None:
+    has_cast = tc.dtype_in != tc.dtype_out
+    if has_cast and not (reduceop.src[0].op is Ops.CAST and reduceop.src[0].dtype == tc.dtype_out): return None
+
+    mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
+    if mul_op.op is not Ops.MUL: return None
+
+    def buf_index(src:UOp) -> int|None:
+      # TODO: apply tc even if the sources are not from LOAD
+      if src.op is Ops.LOAD and src.dtype == tc.dtype_in: return self.bufs.index(src)
+      try:
+        if opt_level >= 1 and src.op is Ops.CAST and src.dtype == tc.dtype_in: return self.bufs.index(src.src[0])
+      except ValueError: return None
+      return None
+    if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None
+
+    buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
+    axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i in self.upcastable_dims if buf0_strides[i] == 0]
+    axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i in self.upcastable_dims if buf1_strides[i] == 0]
+    if not (axis_buf0 and axis_buf1 and (len(self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)) == 1 or (opt_level >= 1))): return None
+
+    axis_choices = list(itertools.product(axis_buf0, axis_buf1, self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)))
+    if not (axis < len(axis_choices)): return None
+
+    s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
+    axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if resolve(self.full_shape[x]%tc.dims[i] != 0))
+    if axis_pads and (opt_level < 2): return None
+    if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
+    return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)
+
+  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> bool:
+    if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD:
+      tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]]
+      for tc in tensor_cores:
+        tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
+        if tensor_core_opts[0] is None: continue
+        # can only fuse reduces with the same tc options
+        assert all_same(tensor_core_opts)
+        self.tensor_core_opts = tc_opts = tensor_core_opts[0]
+
+        # attempt to pad the tensor axes that require it
+        try:
+          for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
+        except KernelOptError: continue
+        # tensor core -- unroll the reduce dim (K), upcast and local the inner and outer dims (N, M)
+        for opt in tc.opts: self.apply_opt(Opt({"u":OptOps.UPCAST, "l":OptOps.LOCAL}[opt[0]], tc_opts.axes[int(opt[1])], 2), append_opt=False)
+        for dim, amt in tc.get_reduce_axes(): self.apply_opt(Opt(OptOps.UNROLL, 0, amt), append_opt=False) # TODO: this should be the reduce, not 0
+        self.tensor_core = tc
+        self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA
+        return True
+    return False
+
+  def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:list[Opt]|None=None, axis:int=0, tc_select:int|None=None, tc_opt:int|None=None) -> bool:
+    """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
+    Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
+
+    Keyword arguments:
+    use_tensor_cores -- controls how tensor cores are applied (default 1)
+      0: will disable any tensor core matching
+      1: enable tensor cores
+      2: apply tensor core shape but don't use UOp.WMMA
+    extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
+    tc_select -- specifies which tensor core(s) to use for optimization (default -1)
+      -1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes)
+      [0-N]: uses only the n'th tensor core available; useful for search
+    tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
+      0: applies to only kernels with a single reduce axis and direct Ops.LOAD into Ops.MUL
+      1: allows kernels with multiple reduce axes and also multiplication of Ops.CAST'd buffers
+      2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
+    """
+    if tc_select is None: tc_select = TC_SELECT.value
+    if tc_opt is None: tc_opt = TC_OPT.value
+    if not self.opts.tensor_cores: return False
+    try: # check TC first and apply hand-coded opts if successful
+      self.apply_opt(Opt(OptOps.TC, axis, (tc_select, tc_opt, use_tensor_cores)))
+
+      if (tc_opts:=self.tensor_core_opts) is not None:
+        if extra_opts is not None: self.apply_opts(extra_opts)
+        else:
+          if AMX: return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
+          # hand-coded TC opts
+          for tc_dim in [tc_dim for tc_dim in [1,0] if tc_opts.axes_exist[tc_dim]]: # attempt to upcast M and N
+            szs = [sz for sz in [5,4,3,2] if self.full_shape[tc_opts.axes[tc_dim]] % sz == 0]
+            if szs: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[tc_dim], szs[0]))
+
+          if tc_opts.axes_exist[0] and (szs := [sz for sz in [4,2] if self.full_shape[tc_opts.axes[0]] % sz == 0]): # attempt to local N
+            self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], szs[0]))
+      return True
+    except KernelOptError:
+      return False
+
+  # strings like ['g0', 'g1', 'l0', 'l1', 'l2', 'l3', 'l4', 'l5', 'R0', 'r0', 'r1', 'r2', 'u0', 'u1', 'u2']
+  def shape_str(self) -> list[str]:
+    ret: list[str] = []
+    cnt: dict[AxisType, int] = {}
+    for x in self.axis_types:
+      cnt[x] = (cnt[x] + 1) if x in cnt else 0
+      ret.append(f"{axis_letters[x]}{cnt[x]}")
+    return ret
+  def shape_str_to_axis(self, nms:list[str]) -> tuple[int, ...]: return tuple([self.shape_str().index(x) for x in nms])
+
+  def get_optimized_ast(self, name_override:str|None=None) -> UOp:
+    @functools.cache
+    def fixup_ast(op:UOp) -> UOp:
+      ret = op.replace(src=tuple(fixup_ast(x) for x in op.src)) # noqa: F821
+      if op.op in GroupOp.Buffer and op in self.bufs:
+        st = self.sts[self.bufs.index(op)]
+        # replace the VIEW source
+        return ret.replace(src=(ret.src[0].replace(arg=st),)+ret.src[1:])
+      if op.op is Ops.SINK:
+        # NOTE: should group_for_reduces be added to the local_dims?
+        # TODO: arg.name should be able to be None
+        kernel_name = ret.arg.name if ret.arg is not None and ret.arg.name != "test" else self.name if name_override is None else name_override
+        return ret.replace(arg=KernelInfo(kernel_name, tuple(self.axis_types), self.dont_use_locals, tuple(self.applied_opts)))
+      if op.op is Ops.REDUCE_AXIS:
+        reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
+        changed = tuple(i for i in range(self.shape_len) if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
+        axes = tuple(i for i in self.axes_of(AxisType.REDUCE, AxisType.UNROLL) if i in changed)
+        grouped_axes = tuple(i for i in self.axes_of(AxisType.GROUP_REDUCE) if i in changed)
+        if (tc := self.tensor_core) and self.use_tensor_cores == 1:
+          # get reduce/upcast axes for the tensor cores
+          tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))])
+          base_upcast_axes = tuple([(s,2) for s in self.shape_str_to_axis(tc.base_upcast_axes())])
+          tc_upcast_axes = tuple([base_upcast_axes[:int(math.log2(tc.elements_per_thread[i]))] for i in range(3)])
+
+          # permute the srcs
+          srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
+          for i, (src, permaxis) in enumerate(zip(srcs, tc.permutes_for_shape_str(self.shape_str()))):
+            src_st = (src if src.op is Ops.LOAD else src.src[0]).st_arg
+            srcs[i] = src.view(ShapeTracker.from_shape(src_st.shape).permute(permaxis))
+
+          # construct the op
+          wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, tc_reduce_axes)
+          wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
+            UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]),
+            UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1]),
+            UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg)
+          tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2])
+
+          # preserve any other reduce
+          return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_axes)) if (new_axes := tuple(i for i in axes if i not in tc_reduce_axes)) else tc_uop
+
+        ret = ret.replace(arg = (op.arg[0], axes))
+        if self.group_for_reduces and grouped_axes:
+          local_axes = tuple([i for i,t in enumerate(self.axis_types) if t in (AxisType.LOCAL, AxisType.UPCAST) or i in grouped_axes])
+          slocal, supcast, sgroup = sorted(self.axes_of(AxisType.LOCAL)), sorted(self.axes_of(AxisType.UPCAST)), sorted(grouped_axes)
+          # NOTE: start with UPCAST at the end so it has stride 1 and can merge
+          base_shape = tuple([self.full_shape[i] for i in slocal] + [self.full_shape[i] for i in sgroup] + [self.full_shape[i] for i in supcast])
+          permute_axes = tuple([local_axes.index(i) for i in slocal+sgroup+supcast])
+          local_shape = tuple([s if i in local_axes else 1 for i,s in enumerate(self.full_shape)])
+          local_src_shape = tuple([self.full_shape[i] if i in self.axes_of(AxisType.GLOBAL) else s for i,s in enumerate(local_shape)])
+          st = ShapeTracker.from_shape(base_shape).permute(permute_axes).reshape(local_shape).expand(local_src_shape)
+          local_size = st.real_size()
+          local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, addrspace=AddrSpace.LOCAL), (), f"temp{self.reduceops.index(op)}")
+          local_load = local_buffer.view(st).load(local_buffer.view(st).store(ret))
+          grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
+          if op is self.reduceops[-1]: return grouped_reduce
+          st = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else s for i,s in enumerate(local_shape)]))
+          return local_buffer.view(st).load(local_buffer.view(st).store(grouped_reduce))
+
+      return ret
+    self.finalized = True
+    fixed_ast = fixup_ast(self.ast)
+    del fixup_ast
+    return graph_rewrite(fixed_ast, view_left+view_left_through_load, name="fixup optimized AST")
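
To make the new API concrete: the Kernel/Opt machinery above (it replaces the removed tinygrad/codegen/kernel.py, entry 120 in the file list) is driven by constructing a Kernel from a scheduled AST and applying Opts to it. A minimal sketch, assuming `ast` is an Ops.SINK UOp whose axis 0 is upcastable and which has at least one reduce axis; the axis indices and split factors here are illustrative, not tuned:

from tinygrad.codegen.opt.kernel import Kernel, Opt, OptOps, KernelOptError

k = Kernel(ast)                 # verifies the AST against ast_spec, classifies axis types
try:
  k.apply_opts([
    Opt(OptOps.UPCAST, 0, 4),   # split axis 0 and upcast a factor of 4
    Opt(OptOps.UNROLL, 0, 4),   # unroll the first unrollable (reduce) axis by 4
  ])
except KernelOptError:
  pass                          # each apply_opt re-validates via check(); invalid opts raise
print(k.colored_shape())        # shape rendered with the per-axis colors from axis_colors
optimized = k.get_optimized_ast()  # finalizes the kernel; apply_opt afterwards raises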
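
Every splitting opt (UPCAST, UNROLL, LOCAL, GROUP, GROUPTOP) bottoms out in shift_to above, which reshapes one axis into two and permutes the new sub-axis to insert_at. A standalone sketch of just the reshape step, mirroring new_shape_fxn (the helper name split_axis is introduced here for illustration):

def split_axis(shape: tuple[int, ...], axis: int, amount: int, top: bool = False) -> tuple[int, ...]:
  # mirrors new_shape_fxn inside shift_to: the axis becomes (amount, rest) when top, else (rest, amount)
  piece = (amount, shape[axis]//amount) if top else (shape[axis]//amount, amount)
  return shape[:axis] + (piece if shape[axis] > 1 else (1, 1)) + shape[axis+1:]

assert split_axis((256, 64), axis=1, amount=16) == (256, 4, 16)            # e.g. an UNROLL by 16
assert split_axis((256, 64), axis=1, amount=16, top=True) == (256, 16, 4)  # a GROUPTOP-style split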
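
And a hedged sketch of the apply_tensor_cores entry point documented above, with a simple fallback when no tensor core matches; the fallback choice is illustrative only (the real search and hand opts presumably live in codegen/opt/heuristic.py and codegen/opt/search.py, entries 10 and 12):

k = Kernel(ast)
# tc_opt=2 allows padding M/N/K up to the tensor core dims; use_tensor_cores=2 would apply the shapes without WMMA
if not k.apply_tensor_cores(use_tensor_cores=1, tc_opt=2) and k.upcastable_dims:
  # no tensor core matched: try a plain upcast instead (raises KernelOptError if 4 doesn't divide the axis)
  k.apply_opt(Opt(OptOps.UPCAST, k.upcastable_dims[0], 4))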