tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
@@ -1,693 +0,0 @@
1
- from __future__ import annotations
2
- import itertools, functools, math
3
- from dataclasses import dataclass
4
- from collections import defaultdict
5
- from typing import Optional, cast, Final, Callable, Sequence
6
-
7
- from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, view_left, print_uops
8
- from tinygrad.ops import PatternMatcher
9
- from tinygrad.spec import type_verify, shape_spec
10
- from tinygrad.device import Device
11
- from tinygrad.renderer import Renderer, TensorCore, ProgramSpec, Opt, OptOps
12
- from tinygrad.dtype import ImageDType
13
- from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap, ContextVar
14
- from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, USE_TC, AMX, CAPTURE_PROCESS_REPLAY
15
- from tinygrad.shape.shapetracker import ShapeTracker
16
- from tinygrad.shape.view import strides_for_shape
17
- from tinygrad.codegen.linearize import linearize_uop
18
- from tinygrad.codegen.devectorizer import full_graph_rewrite
19
- from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction
20
-
21
class KernelOptError(Exception):
  """Raised when a requested kernel optimization cannot be applied to a Kernel."""

def check(cond:bool, msg:str=""):
  """Raise KernelOptError(msg) when `cond` is falsy; otherwise do nothing."""
  if cond: return
  raise KernelOptError(msg)
25
-
26
@dataclass
class TensorCoreOptions:
  # (n, m, k) axis indices chosen for the tensor core; only n and m are tracked by fix_axes
  axes: tuple[int, ...]
  # one flag per tracked axis (n, m): False once that axis has been removed from the shape
  axes_exist: tuple[bool, ...]
  # (axis, multiple) pairs for axes that must be padded up to a multiple of the tensor core dim
  axis_pads: tuple[tuple[int, int], ...]

  def fix_axes(self, removed_axis:int):
    """Update the tracked n/m axes after the dimension `removed_axis` has been removed from the shape.

    Tracked axes after the removed one shift down by one; a tracked axis equal to it is marked as no longer existing.
    """
    new_axes, new_exist = list(self.axes), list(self.axes_exist)
    for dim in range(2):
      if not new_exist[dim]: continue
      if removed_axis == new_axes[dim]: new_exist[dim] = False
      elif removed_axis < new_axes[dim]: new_axes[dim] -= 1
    self.axes, self.axes_exist = tuple(new_axes), tuple(new_exist)
37
-
38
- class Kernel:
39
- def __init__(self, ast:UOp, opts:Optional[Renderer]=None):
40
- if ast.op is Ops.SINK: self.ast = ast
41
-
42
- self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
43
- # verify AST matches the spec
44
- if __debug__: type_verify(list(self.ast.toposort), shape_spec)
45
-
46
- self.reduceops = [x for x in self.ast.toposort if x.op is Ops.REDUCE_AXIS]
47
-
48
- self.vars: list[Variable] = self.ast.variables()
49
- # NOTE: this requires a specific order with the [::-1], this is likely a bug
50
- self.bufs: list[UOp] = [x for x in self.ast.toposort if x.op in GroupOp.Buffer][::-1]
51
-
52
- # get earlybufs, before any reduceops
53
- earlybufs: list[UOp] = [x for reduceop in self.reduceops for x in reduceop.src[0].toposort if x.op in GroupOp.Buffer]
54
- self.full_buf_index: int = self.bufs.index(earlybufs[0]) if earlybufs else 0
55
- # NOTE: full_shape can be wrong if there's a tree of reduces
56
-
57
- # create new shapetrackers inside this kernel, we will permute them
58
- self.sts: list[ShapeTracker] = [x.st_arg for x in self.bufs]
59
-
60
- # add the shapetrackers for each reduce
61
- # we use this to track which axes are reduced in each reduce
62
- for x in self.reduceops:
63
- self.sts.append(unwrap(x.st))
64
- self.sts.append(unwrap(x.src[0].st))
65
-
66
- # move all reduce axes to the end
67
- reduce = list(enumerate(zip(self.full_shape, self.output_shape)))
68
- permute = tuple([i for i,(s,n) in reduce if not resolve(s != n)] + [i for i,(s,n) in reduce if resolve(s != n)])
69
- self.reshape_and_permute(None, permute)
70
-
71
- # parameters for optimization
72
- self.applied_opts: list[Opt] = []
73
- self.group_for_reduces: int = 0
74
- self.upcasted: int = 0
75
- self.local_dims: int = 0
76
- self.tensor_core: Optional[TensorCore] = None
77
- self.tensor_core_opts: Optional[TensorCoreOptions] = None
78
- self.use_tensor_cores: int = 0
79
- self.dont_use_locals: bool = False
80
-
81
- # group simplifies
82
- self.simplify_ones()
83
- self.simplify_merge_adjacent()
84
-
85
- def copy(self):
86
- ret = type(self).__new__(type(self))
87
-
88
- # base linearizer params
89
- ret.opts, ret.ast = self.opts, self.ast
90
-
91
- # things downstream of the AST
92
- ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = self.reduceops, self.vars, self.bufs, self.full_buf_index
93
- ret.sts = self.sts[:len(ret.bufs)+len(ret.reduceops)*2] # NOTE: must redo the local buffers with TC in beam
94
-
95
- # parameters for optimizations
96
- ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
97
- self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
98
- ret.tensor_core, ret.tensor_core_opts, ret.use_tensor_cores = self.tensor_core, self.tensor_core_opts, self.use_tensor_cores
99
-
100
- return ret
101
-
102
- @property
103
- def membufs(self) -> list[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])
104
-
105
- # TODO: these need more tests or it might silently be no-op
106
- def float4_axis(self, i:int): return [x-self.first_upcast for x in self.sts[i].unit_stride_axes() if x >= self.first_upcast and self.sts[i].shape[x]%4 == 0] # noqa: E501
107
-
108
- def upcasted_axis(self, i:int) -> list[tuple[int, Optional[sint], bool]]:
109
- upcasted_shape, upcasted_stride = self.sts[i].shape[self.first_upcast:], self.sts[i].real_strides()[self.first_upcast:]
110
- assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
111
- return list(zip(upcasted_shape, upcasted_stride,
112
- [x!=y for x,y in zip(self.sts[0].shape[self.first_upcast:], self.full_shape[self.first_upcast:])]))
113
-
114
- @property
115
- def first_reduce(self) -> int:
116
- return [resolve(x!=y) for x,y in zip(self.sts[0].shape[:self.first_upcast]+(0,), self.full_shape[:self.first_upcast]+(1,))].index(True)
117
-
118
- @property
119
- def first_upcast(self) -> int: return self.shape_len-self.upcasted
120
-
121
- @property
122
- def reduceop(self) -> UOp|None: return self.reduceops[0] if len(self.reduceops) > 0 else None
123
-
124
- @property
125
- def output_shape(self) -> tuple[sint, ...]: return self.sts[0].shape
126
-
127
- @property
128
- def full_shape(self) -> tuple[sint, ...]: return self.sts[self.full_buf_index].shape
129
-
130
- @property
131
- def full_unupcasted_shape(self) -> tuple[sint, ...]: return self.full_shape[:self.first_upcast]
132
-
133
- @property
134
- def shape_len(self) -> int: return len(self.sts[0].shape)
135
-
136
- @property
137
- def global_dims(self) -> int: return self.first_reduce-self.local_dims
138
-
139
- # there's eight chunks of the shape
140
- # blue -- global dims
141
- # cyan -- local dims (warp ones first)
142
- # *** self.first_reduce
143
- # green -- reduce-local dims
144
- # red -- reduce loops
145
- # *** self.upcasted
146
- # purple -- reduce upcasted
147
- # yellow -- normal upcasted dimensions
148
- def colors(self) -> list[str]:
149
- # first non local non reduce dims are global (blue)
150
- colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims
151
- # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
152
- colors += ["cyan"] * self.local_dims
153
- # between first_reduce and first_reduce + group_for_reduces, they are late upcasted (green)
154
- colors += ["green"] * self.group_for_reduces
155
- # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
156
- colors += ["red"] * (self.first_upcast - (self.first_reduce + self.group_for_reduces))
157
- # upcasted dimensions are reduce (magenta) or normal (yellow)
158
- colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.first_upcast, self.shape_len)]
159
- assert len(colors) == self.shape_len, "colors size mismatch"
160
- return colors
161
-
162
- def colored_shape(self, pad:Optional[int]=None, dense=False) -> str:
163
- shape_strs = [(s if dense else f"{s:4d}") if isinstance(s, int) else s.render() for s in self.full_shape]
164
- ret = ' '.join(colored(s, color) for s,color in zip(shape_strs, self.colors()))
165
- if pad: ret += ' '*(pad-ansilen(ret))
166
- return ret
167
-
168
- # ******************** base simplifiers ********************
169
-
170
- # apply reshape and permute to all shapetrackers
171
- def reshape_and_permute(self, new_shape_fxn:Optional[Callable[[tuple[sint, ...]], Sequence[sint]]], axis:Optional[Sequence[int]]):
172
- def reshape(st:ShapeTracker): return st.reshape(tuple(new_shape_fxn(st.shape))) if new_shape_fxn is not None else st
173
- def permute(st:ShapeTracker): return st.permute(tuple(axis)) if axis is not None else st
174
- self.sts = [permute(reshape(st)) for st in self.sts]
175
-
176
- # drops the final dimension
177
- def upcast(self):
178
- check(self.full_shape[-1] != 1, "can't upcast a dimension with size 1")
179
- self.upcasted += 1
180
-
181
  # axis : the axis to pull from
  # amount : the amount to take
  # top : if you want to pull that amount from the top
  # insert_before : place to insert the new stuff
  def shift_to(self, axis, amount, top=False, insert_before=None):
    """Split `axis` in two and move the `amount`-sized piece in front of `insert_before`, in every shapetracker.

    The axis is reshaped to (amount, size//amount) when `top`, else (size//amount, amount); an axis of size <= 1
    becomes (1, 1). `insert_before` is given in pre-split axis indices and defaults to self.shape_len (the end).
    """
    if insert_before is None: insert_before = self.shape_len
    # position of the `amount` piece right after the split
    move_axis = axis if top else axis+1
    # the split adds one axis, so an insertion point after the moved piece shifts right by one
    if move_axis < insert_before: insert_before += 1
    # NOTE: the permutation list is built before the reshape runs, so self.shape_len here is the pre-split length
    self.reshape_and_permute(
      lambda x: x[0:axis] + (((amount, x[axis]//amount) if top else (x[axis]//amount, amount)) if x[axis] > 1 else (1,1)) + x[axis+1:],
      [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis])
192
-
193
- # ******************** complex simplifiers ********************
194
-
195
- def simplify_ones(self) -> bool:
196
- # remove places where the shape is all ones
197
- # TODO: this should be factored in to multi shape stride
198
- if self.shape_len == 0: return False
199
- all_ones = [s==1 for s in self.full_shape]
200
- self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
201
- self.upcasted -= sum(all_ones[self.first_upcast:]) # TODO: no necessary since upcasted axis can't be un-upcasted
202
- self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
203
- return any(all_ones)
204
-
205
  def simplify_merge_adjacent(self):
    """Merge adjacent axes that are contiguous in every shapetracker, reshaping all of self.sts in place.

    Axis i is folded into the previous (merged) axis only when, for every shapetracker, the previous stride equals
    size_i * stride_i (or both strides are 0), and i is not first_reduce. For image outputs, an extra fake stride set is
    added so merges do not cross image axes.
    """
    if self.shape_len == 0: return
    shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]

    # if it's an image, insert fake strides such that this fusion doesn't happen across image axes
    if isinstance(self.membufs[0].dtype, ImageDType):
      base_shape = self.membufs[0].dtype.shape
      if shape_idx_groups := get_contraction(self.output_shape, base_shape):
        special_strides: tuple[sint, ...] = tuple()
        for i,g in enumerate(shape_idx_groups):
          shape_piece = tuple(self.output_shape[x] for x in g)
          assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
          special_strides += strides_for_shape(shape_piece)
        # adding the fake image shape
        shapes.append(self.output_shape)
        strides.append(special_strides)

    # merge dimensions if we can, multi _merge_dims
    # NOTE: this does not always preserve the reduce dimension
    # TODO: move this into shapetracker, with tests!
    # TODO: how does this work with multi-reduce?
    # rets[j] is the list of (size, stride) pairs built so far for shapes[j]
    rets = [[(s[0], st[0])] for s,st in zip(shapes, strides)]
    for i in range(1, len(shapes[0])):
      can_merge = []
      for s,st,ret in zip(shapes, strides, rets):
        # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
        si, sti, last_st = s[i], st[i], ret[-1][1]
        can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0)))
      # more can merge than this
      mergeable = all(can_merge) and i != self.first_reduce
      for j,(s,st) in enumerate(zip(shapes, strides)):
        # merging multiplies into the last size and takes the inner axis' stride
        if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i])
        else: rets[j].append((s[i], st[i]))

    # do the reshapes (the fake image entry, if any, is past len(self.sts) and is only used to block merges)
    for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
241
-
242
- # ******************** high level optimizers ********************
243
-
244
  def _create_tc_opts(self, reduceop:UOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
    """Try to match `reduceop` against tensor core `tc` and pick its (n, m, k) axes.

    reduceop  -- a REDUCE_AXIS whose source must be MUL(a, b), wrapped in a CAST to tc.dtype_out when the TC dtypes differ
    axis      -- which candidate (n, m, k) triple to use, counting from the end of the candidate list
    opt_level -- tc_opt level: >= 1 allows CAST'd inputs and multiple reduce axes, >= 2 allows axes that need padding
    Returns the TensorCoreOptions, or None if anything does not match.
    """
    has_cast = tc.dtype_in != tc.dtype_out
    # with differing in/out dtypes the reduce input must be a CAST to the accumulation dtype
    if has_cast and not (reduceop.src[0].op is Ops.CAST and reduceop.src[0].dtype == tc.dtype_out): return None

    mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
    if mul_op.op is not Ops.MUL: return None

    # index into self.bufs of a LOAD of dtype_in, or (opt_level >= 1) of the buffer behind a CAST to dtype_in
    def buf_index(src:UOp) -> Optional[int]:
      # TODO: apply tc even if the sources are not from LOAD
      if src.op is Ops.LOAD and src.dtype == tc.dtype_in: return self.bufs.index(src)
      try:
        if opt_level >= 1 and src.op is Ops.CAST and src.dtype == tc.dtype_in: return self.bufs.index(src.src[0])
      except ValueError: return None
      return None
    if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None

    # candidate axes (before first_reduce) where one operand has stride 0: (axis, size, other operand's stride)
    buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
    axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
    axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
    # both operands need such an axis; more than one reduce axis requires opt_level >= 1
    if not (axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None

    # every (buf0 axis, buf1 axis, reduce axis) triple is a candidate
    axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
    if not (axis < len(axis_choices)): return None

    s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
    # axes whose size is not a multiple of the matching tc dim must be padded, which needs opt_level >= 2
    axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if resolve(self.full_shape[x]%tc.dims[i] != 0))
    if axis_pads and (opt_level < 2): return None
    if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
    return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)
273
-
274
- def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> bool:
275
- if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD:
276
- tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]]
277
- for tc in tensor_cores:
278
- tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
279
- # can only fuse reduces with the same tc options
280
- assert all_same(tensor_core_opts)
281
- if tensor_core_opts[0] is None: continue
282
- self.tensor_core_opts = tc_opts = tensor_core_opts[0]
283
-
284
- # attempt to pad the tensor axes that require it
285
- try:
286
- for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
287
- except KernelOptError: continue
288
- # tensor core -- unroll the reduce dim (K), upcast and local the inner and outer dims (N, M)
289
- for dim, amt in tc.get_reduce_axes(): self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, amt), append_opt=False)
290
- for opt in tc.opts: self.apply_opt(Opt({"u":OptOps.UPCAST, "l":OptOps.LOCAL}[opt[0]], tc_opts.axes[int(opt[1])], 2), append_opt=False)
291
- self.tensor_core = tc
292
- self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA
293
- return True
294
- return False
295
-
296
- def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[list[Opt]]=None, axis:int=0, tc_select:Optional[int]=None,
297
- tc_opt:Optional[int]=None) -> bool:
298
- """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
299
- Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
300
-
301
- Keyword arguments:
302
- use_tensor_cores -- controls how tensor cores are applied (default 1)
303
- 0: will disable any tensor core matching
304
- 1: enable tensor cores
305
- 2: apply tensor core shape but don't use UOp.WMMA
306
- extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
307
- tc_select -- specifies which tensor core(s) to use for optimization (default -1)
308
- -1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes)
309
- [0-N]: uses only the n'th tensor core available; useful for search
310
- tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
311
- 0: applies to only kernels with a single reduce axis and direct Ops.LOAD into Ops.MUL
312
- 1: allows kernels with multiple reduce axes and also multiplication of Ops.CAST'd buffers
313
- 2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
314
- """
315
- if tc_select is None: tc_select = TC_SELECT.value
316
- if tc_opt is None: tc_opt = TC_OPT.value
317
- if not self.opts.tensor_cores and use_tensor_cores != 2: return False
318
- try: # check TC first and apply hand-coded opts if successful
319
- self.apply_opt(Opt(OptOps.TC, axis, (tc_select, tc_opt)))
320
-
321
- if (tc_opts:=self.tensor_core_opts) is not None:
322
- if extra_opts is not None:
323
- for opt in extra_opts: self.apply_opt(opt)
324
- else:
325
- if AMX: return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
326
- # hand-coded TC opts
327
- for tc_dim in [tc_dim for tc_dim in [1,0] if tc_opts.axes_exist[tc_dim]]: # attempt to upcast M and N
328
- szs = [sz for sz in [5,4,3,2] if self.full_shape[tc_opts.axes[tc_dim]] % sz == 0]
329
- if szs: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[tc_dim], szs[0]))
330
-
331
- if tc_opts.axes_exist[0] and (szs := [sz for sz in [4,2] if self.full_shape[tc_opts.axes[0]] % sz == 0]): # attempt to local N
332
- self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], szs[0]))
333
- return True
334
- except KernelOptError:
335
- return False
336
-
337
- def real_axis(self, opt:Opt):
338
- if opt.axis is None: return -1
339
- if opt.op is OptOps.UNROLL: return self.first_reduce+opt.axis
340
- if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.first_reduce+self.group_for_reduces+opt.axis
341
- return opt.axis
342
-
343
  def apply_opt(self, opt:Opt, append_opt:bool=True):
    """Apply a single optimization to this kernel, mutating its shapetrackers and bookkeeping in place.

    opt        -- the optimization; its axis is translated with real_axis() except for TC
    append_opt -- record opt in self.applied_opts (TC opts are always recorded; internal TC sub-opts pass False)
    Raises KernelOptError (via check) when the opt is not valid for the current kernel.
    """
    if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}, "not using locals")

    if opt.op is OptOps.TC:
      check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
      check((use_tensor_cores:=USE_TC.value) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
      check(opt.axis is not None, "tensor core opts must have an axis")
      check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 2, "tensor core opts must have tc_select and tc_opt")
      check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
      check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
      check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt), "no tensor core available")
      self.applied_opts.append(opt)
      return

    axis = self.real_axis(opt)
    check(axis < len(self.full_shape), "invalid axis")

    # amt: the split/pad size; an arg of 0 means the whole axis, and arg-less opts get -1
    if opt.op is OptOps.SWAP: amt = cast(int, opt.arg) # arg is an axis in the SWAPs
    elif opt.arg is not None:
      check(isinstance(opt.arg, int), "arg should be int")
      amt = arg if (arg:=cast(int, opt.arg)) != 0 else self.full_shape[axis]
      check(isinstance(amt, int) and amt != 1, f"shift/padto of {amt=}, 1 or symbolic amount is meaningless")
      if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, f"no longer valid shift {self.full_shape[axis]=}, {amt=}")
    else: amt = -1

    # estimate the shared memory needed by grouped reduces and reject opts that would exceed opts.shared_max
    if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
                                      (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
      acc_sz = self.reduceop.dtype.itemsize
      upcast_sz = prod([a for a,b in zip(self.full_shape[self.first_upcast:], self.sts[0].shape[self.first_upcast:]) if a == b])
      local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
      smem_sz = amt*acc_sz*upcast_sz*local_sz
      check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")

    if opt.op is OptOps.LOCAL:    # cyan
      # NOTE: LLVM/CPU can use locals too, but they are treated the same as globals (still helpful for L1 cache)
      # it's disabled for now since it makes BEAM slow for little gain
      check(self.opts.has_local, "target does not support local")
      check(axis < self.global_dims, "local is for globals")
      self.shift_to(axis, amt, insert_before=self.first_reduce)
      self.local_dims += 1
    elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}:   # green
      check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
      check(self.first_reduce + self.group_for_reduces <= axis < self.first_upcast, "must be reduce axis to group")
      check(not self.tensor_core, "can't group with tensor cores")
      check(len(reduce_axes:=[i for r in self.reduceops for i in r.axis_arg]) == len(set(reduce_axes)), "can't group with parallel reduces")
      self.shift_to(axis, amt, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
      self.group_for_reduces += 1
    elif opt.op is OptOps.UNROLL:                     # purple
      check(axis < self.first_upcast, "can't upcasted already upcasted")
      check(amt <= 32, "don't unroll more than 32")
      # TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores
      #upcast_count = sum(x == y for x,y in zip(self.full_shape[-self.upcasted:], self.output_shape[-self.upcasted:])) if self.upcasted else 0
      #self.shift_to(axis, amt, insert_before=None if upcast_count == 0 else self.shape_len-upcast_count)
      if self.full_shape[axis] == amt and axis == self.first_reduce: self.local_dims += 1 # first_reduce will ++, so offset loss in simplify_ones
      if self.full_shape[axis] == amt and axis < self.first_reduce+self.group_for_reduces: self.group_for_reduces -= 1 # fully unrolling a GROUP
      self.shift_to(axis, amt, insert_before=None)
      self.upcast()
    elif opt.op is OptOps.UPCAST:                     # yellow
      check(axis < self.first_reduce, "upcast is for non-reduce")
      check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.get_local_axes())), "can't upcast TC locals")
      check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16")
      self.shift_to(axis, amt, insert_before=None)
      self.upcast()
    elif opt.op is OptOps.NOLOCALS:
      check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
      check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
      self.dont_use_locals = True
    elif opt.op is OptOps.SWAP:
      # swap two global axes; here amt is the other axis index, not a size
      check(axis < amt < self.global_dims, f"swap is only for globals with axis < amt, getting {amt=}, {axis=}, {self.global_dims=}")
      permute = list(range(self.shape_len))
      permute[axis], permute[amt] = permute[amt], permute[axis]
      self.reshape_and_permute(None, tuple(permute))
    elif opt.op is OptOps.PADTO:
      check(not self.vars, "does not work with symbolic shape")
      check(axis < self.first_upcast, "cannot pad upcasted")
      # ok to pad SUM if all parent ALU ops have f(0) = 0
      if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r, {}, cache={}), f"cannot pad {r}")
      padded = False
      for i,st in enumerate(self.sts):
        if (s:=st.shape[axis]) == 1: continue  # reduced
        check(s > amt//4, f"pad adds more than quadruple the work {st.shape[axis]=} > {amt//4=}")
        # ru: how much to add on the right to reach the next multiple of amt
        if (ru := round_up(cast(int, s), amt) - s):
          # pad right seems to be faster
          self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
          padded = True
      check(padded, "nothing was padded")

    if append_opt: self.applied_opts.append(opt)
    # NOTE(review): fix_axes is given the opt's axis, assuming that is the axis simplify_ones removed; confirm
    if self.simplify_ones() and self.tensor_core_opts:
      self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()
433
-
434
- def required_optimizations(self) -> Kernel:
435
- if isinstance(self.membufs[0].dtype, ImageDType):
436
- unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
437
- assert unit_stride_axes_mul_4, f"needs a unit stride axis in {self.bufs[0]}"
438
- if all(x < self.first_upcast for x in unit_stride_axes_mul_4): self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
439
- return self
440
-
441
- def hand_coded_optimizations(self) -> Kernel:
442
- self.required_optimizations()
443
-
444
- # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
445
- MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
446
- if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
447
- self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \
448
- (mulop:=self.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
449
- st0, st1 = self.sts[self.bufs.index(mulop.src[0])], self.sts[self.bufs.index(mulop.src[1])]
450
- strides0, strides1 = st0.real_strides(), st1.real_strides()
451
- def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
452
- if strides0[self.first_reduce] == 1 and not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)):
453
- for global_idx in range(self.global_dims):
454
- if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
455
- if DEBUG >= 3:
456
- print(f"MATVEC: {self.full_shape=} {self.first_reduce=} {strides0=} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
457
- if MV_THREADS_PER_ROW > 1: self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
458
- if MV_BLOCKSIZE > 1: self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
459
- if MV_ROWS_PER_THREAD > 1: self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
460
- return self
461
-
462
- if self.opts.has_local and self.opts.has_shared and all_int(self.sts[0].shape[:self.first_reduce]):
463
- # are we grouping? (requires local shape support)
464
- if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: # noqa: E501
465
- # TODO: use 1024 if it's allowed in a smarter way
466
- for sz in ([256, 16] if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
467
- if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
468
- try: # may fail due to excessive smem usage
469
- self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
470
- break
471
- except KernelOptError: pass
472
-
473
- # upcast float4 images
474
- for buf_index,buf in enumerate(self.bufs):
475
- unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
476
- if buf.src[0].dtype.__class__ is ImageDType:
477
- #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
478
- if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4):
479
- if unit_stride_axes_mul_4[0] < self.first_reduce:
480
- self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
481
- else:
482
- self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))
483
-
484
- # no more opt if we are grouping
485
- if self.group_for_reduces: return self
486
-
487
- # **** below this line need to be optional and benchmarked ****
488
-
489
- # TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
490
- # to trigger the above bug, remove prod(self.full_shape[self.first_upcast:]) from the below
491
- # expression and run test/test_ops.py with IMAGE=2
492
- # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
493
- # this can be made much smarter
494
- to_upcast: list[int] = []
495
- # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
496
- for axis in range(self.first_reduce):
497
- # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
498
- # for now skip upcasting here if there is a symbolic axis
499
- if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and \
500
- prod(self.full_shape[self.first_upcast:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
501
- if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
502
- to_upcast.append(axis)
503
- for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
504
-
505
- # potentially do more upcasts of non reduce axes based on a heuristic
506
- upcasted_axis: set[int] = set()
507
- while resolve(prod(self.sts[0].shape[:self.first_reduce]) >= 1024):
508
- xb_choices = []
509
- for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
510
- # if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
511
- if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)): # noqa: E501
512
- xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount)) # noqa: E501
513
- if xb_choices:
514
- xb_choices = sorted(xb_choices)
515
- if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
516
- self.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
517
- upcasted_axis.add(xb_choices[0][2])
518
- else: break
519
-
520
- # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast.
521
- if self.first_reduce < self.first_upcast and (prod(self.full_shape[self.first_upcast:]) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64): # noqa: E501
522
- if isinstance(s:=self.full_unupcasted_shape[-1], int) and s <= 32: # NOTE: cannot loop unroll symbolic axis
523
- self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
524
- # if it's small, upcast a second reduce dimension too
525
- if self.first_reduce < self.first_upcast and s <= 3 and isinstance(s2:=self.full_unupcasted_shape[-1], int) and s2 <= 3:
526
- self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
527
- else:
528
- for splits in [4]:
529
- if self.full_unupcasted_shape[-1]%splits == 0:
530
- self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, splits))
531
- break
532
-
533
- # if nothing at all is upcasted and it's easy to, do an upcast
534
- # TODO: this is breaking the tests
535
- for splits in [4]:
536
- if self.upcasted == 0 and self.full_unupcasted_shape and self.full_unupcasted_shape[-1] % splits == 0:
537
- self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, splits))
538
-
539
- # **** local groups ****
540
-
541
- if self.opts.has_local:
542
- if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduces:
543
- self.apply_opt(Opt(OptOps.NOLOCALS))
544
- else:
545
- # prioritize making expand axes local
546
- local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))] # noqa: E501
547
- to_local: list[tuple[int, int]] = []
548
- for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
549
- local_size = prod(sz for _, sz in to_local)
550
- local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None) # noqa: E501
551
- if local_sz is not None: to_local.append((axis, local_sz))
552
- deleted_shape = 0
553
- for axis, local_sz in sorted(to_local[:3]):
554
- axis = axis - deleted_shape
555
- will_delete_shape = local_sz == self.full_shape[axis]
556
- self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
557
- if will_delete_shape: deleted_shape += 1
558
-
559
- return self
560
-
561
- # **** kernel outputs ****
562
-
563
- kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
564
- @functools.cached_property
565
- def name(self) -> str:
566
- # kernel name (before late upcast)
567
- kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op is Ops.SINK or x.op in GroupOp.Buffer for x in self.ast.toposort) else "E")
568
- suffix = colored('_', 'BLACK').join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())])
569
- name = kernel_type + (f"{len(self.ast.src)}" if len(self.ast.src) > 1 else "") + "_" + suffix
570
-
571
- # name the function something unique
572
- Kernel.kernel_cnt[(function_name := to_function_name(name))] += 1
573
- num = f"n{Kernel.kernel_cnt[function_name]-1}" if Kernel.kernel_cnt[function_name] > 1 else ""
574
- return name + colored(num, 'BLACK')
575
-
576
- def get_optimized_ast(self, name_override:Optional[str]=None) -> UOp:
577
- @functools.lru_cache(None)
578
- def fixup_ast(op:UOp) -> UOp:
579
- ret = op.replace(src=tuple(fixup_ast(x) for x in op.src))
580
- if op.op in GroupOp.Buffer and op in self.bufs:
581
- st_uop = self.sts[self.bufs.index(op)].to_uop()
582
- # NOTE: if CONST got masked after applying opts, we create a new VALID
583
- if op.op is Ops.CONST and any(v.mask is not None for v in unwrap(st_uop.st).views): return op.valid(unwrap(st_uop.st))
584
- # otherwise we just replace the VIEW source
585
- return ret.replace(src=(st_uop,)) if len(op.src) == 1 else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
586
- if op.op is Ops.SINK:
587
- return ret.replace(arg = KernelInfo(to_function_name(self.name) if name_override is None else name_override,
588
- self.local_dims, self.upcasted, self.dont_use_locals))
589
- if op.op is Ops.REDUCE_AXIS:
590
- reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
591
-
592
- def reduced_axes(start, stop):
593
- return tuple(i for i in range(start, stop) if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
594
- axes = reduced_axes(self.first_reduce + self.group_for_reduces, self.shape_len)
595
- grouped_axes = reduced_axes(self.first_reduce, self.first_reduce + self.group_for_reduces)
596
-
597
- if (tc := self.tensor_core) and (self.use_tensor_cores == 1 or self.use_tensor_cores == 3):
598
- wd, tcd = self.global_dims, self.first_upcast
599
- def get_upcast_axes(buf): # upcast along non-zero dimensions of (tc_reduce + tc_upcast)
600
- upcast_axes = int(math.log2(tc.elements_per_thread[buf]))
601
- return tuple((tcd + len(tc.get_reduce_axes()) + len(tc.get_upcast_axes()) - (i+1), 2) for i in range(upcast_axes))
602
- def get_tc_swizzle_st(shape, local_perm, upcast_perm):
603
- offset = (tcd - (wd + len(local_perm)))
604
- permaxis = list(range(wd)) \
605
- + [wd + x + (offset if x >= len(local_perm) else 0) for x in local_perm] + list(range(wd + len(local_perm), tcd)) \
606
- + [wd + x + (offset if x >= len(local_perm) else 0) for x in upcast_perm] + list(range(tcd + len(upcast_perm), len(shape)))
607
- return ShapeTracker.from_shape(shape).permute(tuple(permaxis))
608
-
609
- srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
610
- for i, (src, swizzle) in enumerate(zip(srcs, tc.swizzle)):
611
- if swizzle: srcs[i] = src.view(get_tc_swizzle_st((src if src.op is Ops.LOAD else src.src[0]).st_arg.shape, *swizzle))
612
-
613
- if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals
614
- local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape))
615
- st = store_st = ShapeTracker.from_shape(local_shape)
616
- local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), f"temp{i}")
617
- if swizzle: store_st = get_tc_swizzle_st(store_st.shape, *swizzle)
618
- local_store = UOp.store(local_buffer, store_st.to_uop(), srcs[i])
619
- srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer, st.to_uop(), local_store))
620
-
621
- tc_reduce_axes = tuple(tcd + ax for ax, _ in tc.get_reduce_axes())
622
- if self.use_tensor_cores == 1: # real WMMA, use CONTRACT/UNROLL to get the vectorization right
623
- tc_upcast_axes = (get_upcast_axes(0), get_upcast_axes(1), get_upcast_axes(2))
624
- wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, tc_reduce_axes)
625
- wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
626
- UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]),
627
- UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1]),
628
- UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg)
629
- tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2])
630
-
631
- else: # for TC=3 MUL/SUM instead of WMMA
632
- tc_uop = UOp(Ops.REDUCE_AXIS, tc.dtype_out, ((srcs[0] * srcs[1]).cast(tc.dtype_out),), (Ops.ADD, tc_reduce_axes))
633
-
634
- return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_axes)) if (new_axes := tuple(i for i in axes if i not in tc_reduce_axes)) else tc_uop
635
-
636
- ret = ret.replace(arg = (op.arg[0], axes))
637
- if self.group_for_reduces and grouped_axes:
638
- local_shape = (1,) * self.global_dims + self.full_shape[self.global_dims:self.global_dims+self.local_dims] + \
639
- tuple([self.full_shape[i] if self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i] else 1 \
640
- for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
641
- (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
642
- st_uop = ShapeTracker.from_shape(local_shape).to_uop()
643
- local_size = st_uop.arg.real_size()
644
- local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
645
- local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, ret)))
646
- grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
647
- if op is self.reduceops[-1]: return grouped_reduce
648
- st_uop = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else a for i,a in enumerate(local_shape)])).to_uop()
649
- return UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, grouped_reduce)))
650
-
651
- return ret
652
-
653
- return graph_rewrite(fixup_ast(self.ast), view_left)
654
-
655
- # **** this is the lowerer ****
656
-
657
- @track_rewrites()
658
- def linearize(self, name_override:Optional[str]=None) -> Kernel:
659
- # display the AST
660
- if getenv("VIZ"): graph_rewrite(self.ast, PatternMatcher([]), name="View Base AST")
661
-
662
- modified_ast = self.get_optimized_ast(name_override)
663
-
664
- if DEBUG >= 3:
665
- print(self.name)
666
- if getenv("RAWAST"): print(self.ast)
667
- for i,(buf,st) in enumerate([(buf,st) for buf,st in zip(self.bufs, self.sts) if buf.op not in {Ops.CONST, Ops.VALID}]):
668
- print(f"{i:2d}: {str(st.shape):25s} {str(buf.src[0].dtype).replace('dtypes.',''):20s}", st.real_strides())
669
- print(self.applied_opts)
670
- # verify AST matches the spec after applying opts
671
- if __debug__: type_verify(list(modified_ast.toposort))
672
- # TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this
673
- #if __debug__: type_verify(list(modified_ast.toposort), shape_spec)
674
-
675
- self.uops:list[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
676
- if DEBUG >= 5: print_uops(self.uops)
677
- return self
678
-
679
- def to_program(self, name_override:Optional[str]=None) -> ProgramSpec:
680
- self.linearize(name_override)
681
- assert self.uops[0].op is Ops.NAME, "first uop must be name"
682
- src = self.opts.render(self.uops)
683
-
684
- if CAPTURE_PROCESS_REPLAY:
685
- diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, self.uops[0].arg, ContextVar._cache, src))
686
-
687
- # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
688
- # TODO: these max and min don't work on symbolic, and results are very wrong.
689
- mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
690
- for _, group in itertools.groupby([x for x in self.ast.toposort if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
691
- key=lambda x: (x.op, x.src[0].arg)))
692
- return ProgramSpec(self.name if not name_override else name_override, src, self.opts.device, self.ast, self.uops, self.applied_opts, mem_bytes,
693
- global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)