tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
@@ -1,51 +1,34 @@
  from __future__ import annotations
- import itertools, functools
+ import itertools, functools, math
  from dataclasses import dataclass
  from collections import defaultdict
- from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence
- from enum import Enum, auto
+ from typing import Optional, cast, Final, Callable, Sequence
 
- from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, PatternMatcher, can_pad, print_uops, type_verify, resolve, Variable, sint, \
- graph_rewrite, track_rewrites, UPat
+ from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, view_left, print_uops
+ from tinygrad.ops import PatternMatcher
+ from tinygrad.spec import type_verify, shape_spec
  from tinygrad.device import Device
- from tinygrad.renderer import Renderer, TensorCore, Program
+ from tinygrad.renderer import Renderer, TensorCore, ProgramSpec, Opt, OptOps
  from tinygrad.dtype import ImageDType
- from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap
- from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
+ from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap, ContextVar
+ from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, USE_TC, AMX, CAPTURE_PROCESS_REPLAY
  from tinygrad.shape.shapetracker import ShapeTracker
  from tinygrad.shape.view import strides_for_shape
  from tinygrad.codegen.linearize import linearize_uop
- from tinygrad.codegen.uopgraph import full_graph_rewrite
+ from tinygrad.codegen.devectorizer import full_graph_rewrite
  from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction
 
- class OptOps(Enum):
- TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
- GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
- def __lt__(self, x:OptOps): return self.value < x.value
-
  class KernelOptError(Exception): pass
 
  def check(cond:bool, msg:str=""):
  if not cond: raise KernelOptError(msg)
 
- @dataclass(frozen=True, order=True)
- class Opt:
- op: OptOps
- axis: Optional[int] = None
- amt: Optional[int] = None
- def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})"
- def real_axis(self, k:Kernel):
- if self.axis is None: return -1
- if self.op is OptOps.UNROLL: return k.first_reduce+self.axis
- if self.op in {OptOps.GROUP, OptOps.GROUPTOP}: return k.first_reduce+k.group_for_reduces+self.axis
- return self.axis
-
  @dataclass
  class TensorCoreOptions:
- axes: Tuple[int, ...] # the location of the original N and M axes if still in the shape
- axes_exist: Tuple[bool, ...] # true if the original N and M axes are still in the shape
- axis_pads: Tuple[Tuple[int, int], ...]
- def fix_axes(self, removed_axis:int): # adjust the TC axes if necesssary when a dimension is removed
+ axes: tuple[int, ...] # the location of the original N and M axes if still in the shape
+ axes_exist: tuple[bool, ...] # true if the original N and M axes are still in the shape
+ axis_pads: tuple[tuple[int, int], ...]
+ def fix_axes(self, removed_axis:int): # adjust the TC axes if necessary when a dimension is removed
  axes, axes_exist = list(self.axes), list(self.axes_exist)
  for tc_dim in [i for i in range(2) if axes_exist[i]]:
  if removed_axis < axes[tc_dim]: axes[tc_dim] -= 1
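
The import hunk above carries most of the structural moves in this release: `Opt` and `OptOps` now come from `tinygrad.renderer`, `type_verify`/`shape_spec` from the new `tinygrad.spec`, and `full_graph_rewrite` from `tinygrad.codegen.devectorizer` instead of the removed `tinygrad.codegen.uopgraph`. A minimal sketch, assuming downstream code that imported these symbols from their 0.10.0 locations (the try/except fallback is illustrative, not part of the diff):

    # Illustrative only: import the relocated symbols, falling back to the 0.10.0 locations.
    try:  # tinygrad 0.10.2
        from tinygrad.renderer import Opt, OptOps
        from tinygrad.codegen.devectorizer import full_graph_rewrite
    except ImportError:  # tinygrad 0.10.0
        from tinygrad.codegen.kernel import Opt, OptOps
        from tinygrad.codegen.uopgraph import full_graph_rewrite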
@@ -57,32 +40,28 @@ class Kernel:
  if ast.op is Ops.SINK: self.ast = ast
 
  self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
- try: uop_sts_map = verify_ast(self.ast)
- except AssertionError as e:
- print("INVALID AST")
- print(self.ast)
- raise e
+ # verify AST matches the spec
+ if __debug__: type_verify(list(self.ast.toposort), shape_spec)
 
- @functools.lru_cache(None)
- def ordered_parents(op:UOp) -> List[UOp]: return dedup([item for x in op.src for item in ordered_parents(x)] + [op])
- self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is Ops.REDUCE_AXIS])
+ self.reduceops = [x for x in self.ast.toposort if x.op is Ops.REDUCE_AXIS]
 
- self.vars: List[Variable] = self.ast.variables()
- self.bufs: List[UOp] = [x for x in self.ast.parents if x.op in GroupOp.Buffer]
+ self.vars: list[Variable] = self.ast.variables()
+ # NOTE: this requires a specific order with the [::-1], this is likely a bug
+ self.bufs: list[UOp] = [x for x in self.ast.toposort if x.op in GroupOp.Buffer][::-1]
 
  # get earlybufs, before any reduceops
- earlybufs: List[UOp] = [x for reduceop in self.reduceops for x in reduceop.parents if x.op in GroupOp.Buffer]
+ earlybufs: list[UOp] = [x for reduceop in self.reduceops for x in reduceop.src[0].toposort if x.op in GroupOp.Buffer]
  self.full_buf_index: int = self.bufs.index(earlybufs[0]) if earlybufs else 0
  # NOTE: full_shape can be wrong if there's a tree of reduces
 
  # create new shapetrackers inside this kernel, we will permute them
- self.sts: List[ShapeTracker] = [x.st_arg for x in self.bufs]
+ self.sts: list[ShapeTracker] = [x.st_arg for x in self.bufs]
 
  # add the shapetrackers for each reduce
  # we use this to track which axes are reduced in each reduce
  for x in self.reduceops:
- self.sts.append(uop_sts_map[x])
- self.sts.append(uop_sts_map[x.src[0]])
+ self.sts.append(unwrap(x.st))
+ self.sts.append(unwrap(x.src[0].st))
 
  # move all reduce axes to the end
  reduce = list(enumerate(zip(self.full_shape, self.output_shape)))
@@ -90,15 +69,13 @@ class Kernel:
  self.reshape_and_permute(None, permute)
 
  # parameters for optimization
- self.applied_opts: List[Opt] = []
+ self.applied_opts: list[Opt] = []
  self.group_for_reduces: int = 0
  self.upcasted: int = 0
  self.local_dims: int = 0
  self.tensor_core: Optional[TensorCore] = None
  self.tensor_core_opts: Optional[TensorCoreOptions] = None
  self.use_tensor_cores: int = 0
- # the local aliased buffers for A and B
- self.bufs_for_tensor_core: Dict[UOp, Tuple[int, int]] = {}
  self.dont_use_locals: bool = False
 
  # group simplifies
@@ -112,25 +89,23 @@ class Kernel:
  ret.opts, ret.ast = self.opts, self.ast
 
  # things downstream of the AST
- ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = \
- self.reduceops, self.vars, self.bufs, self.full_buf_index
+ ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = self.reduceops, self.vars, self.bufs, self.full_buf_index
  ret.sts = self.sts[:len(ret.bufs)+len(ret.reduceops)*2] # NOTE: must redo the local buffers with TC in beam
 
  # parameters for optimizations
  ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
  self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
- ret.tensor_core, ret.tensor_core_opts, ret.bufs_for_tensor_core, ret.use_tensor_cores = \
- self.tensor_core, self.tensor_core_opts, self.bufs_for_tensor_core, self.use_tensor_cores
+ ret.tensor_core, ret.tensor_core_opts, ret.use_tensor_cores = self.tensor_core, self.tensor_core_opts, self.use_tensor_cores
 
  return ret
 
  @property
- def membufs(self) -> List[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])
+ def membufs(self) -> list[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])
 
  # TODO: these need more tests or it might silently be no-op
  def float4_axis(self, i:int): return [x-self.first_upcast for x in self.sts[i].unit_stride_axes() if x >= self.first_upcast and self.sts[i].shape[x]%4 == 0] # noqa: E501
 
- def upcasted_axis(self, i:int) -> List[Tuple[int, Optional[sint], bool]]:
+ def upcasted_axis(self, i:int) -> list[tuple[int, Optional[sint], bool]]:
  upcasted_shape, upcasted_stride = self.sts[i].shape[self.first_upcast:], self.sts[i].real_strides()[self.first_upcast:]
  assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
  return list(zip(upcasted_shape, upcasted_stride,
@@ -144,24 +119,20 @@ class Kernel:
  def first_upcast(self) -> int: return self.shape_len-self.upcasted
 
  @property
- def reduceop(self) -> Optional[UOp]: return self.reduceops[0] if len(self.reduceops) > 0 else None
+ def reduceop(self) -> UOp|None: return self.reduceops[0] if len(self.reduceops) > 0 else None
 
  @property
- def output_shape(self) -> Tuple[sint, ...]: return self.sts[0].shape
+ def output_shape(self) -> tuple[sint, ...]: return self.sts[0].shape
 
  @property
- def full_shape(self) -> Tuple[sint, ...]: return self.sts[self.full_buf_index].shape
+ def full_shape(self) -> tuple[sint, ...]: return self.sts[self.full_buf_index].shape
 
  @property
- def full_unupcasted_shape(self) -> Tuple[sint, ...]: return self.full_shape[:self.first_upcast]
+ def full_unupcasted_shape(self) -> tuple[sint, ...]: return self.full_shape[:self.first_upcast]
 
  @property
  def shape_len(self) -> int: return len(self.sts[0].shape)
 
- @property
- def upcast_in_mid_reduce_axes(self) -> List[int]:
- return [j for j in range(self.first_reduce, self.first_reduce+self.group_for_reduces) if self.full_shape[j] == self.sts[0].shape[j]]
-
  @property
  def global_dims(self) -> int: return self.first_reduce-self.local_dims
 
@@ -170,18 +141,17 @@ class Kernel:
  # cyan -- local dims (warp ones first)
  # *** self.first_reduce
  # green -- reduce-local dims
- # white -- reduce-late upcasted dim (self.upcast_in_mid_reduce_axes)
  # red -- reduce loops
  # *** self.upcasted
  # purple -- reduce upcasted
  # yellow -- normal upcasted dimensions
- def colors(self) -> List[str]:
+ def colors(self) -> list[str]:
  # first non local non reduce dims are global (blue)
  colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims
  # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
  colors += ["cyan"] * self.local_dims
- # between first_reduce and first_reduce + group_for_reduces, they are either upcast mid reduce (white), or late upcasted (green)
- colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + self.group_for_reduces)] # noqa: E501
+ # between first_reduce and first_reduce + group_for_reduces, they are late upcasted (green)
+ colors += ["green"] * self.group_for_reduces
  # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
  colors += ["red"] * (self.first_upcast - (self.first_reduce + self.group_for_reduces))
  # upcasted dimensions are reduce (magenta) or normal (yellow)
@@ -198,7 +168,7 @@ class Kernel:
  # ******************** base simplifiers ********************
 
  # apply reshape and permute to all shapetrackers
- def reshape_and_permute(self, new_shape_fxn:Optional[Callable[[Tuple[sint, ...]], Sequence[sint]]], axis:Optional[Sequence[int]]):
+ def reshape_and_permute(self, new_shape_fxn:Optional[Callable[[tuple[sint, ...]], Sequence[sint]]], axis:Optional[Sequence[int]]):
  def reshape(st:ShapeTracker): return st.reshape(tuple(new_shape_fxn(st.shape))) if new_shape_fxn is not None else st
  def permute(st:ShapeTracker): return st.permute(tuple(axis)) if axis is not None else st
  self.sts = [permute(reshape(st)) for st in self.sts]
@@ -240,7 +210,7 @@ class Kernel:
  if isinstance(self.membufs[0].dtype, ImageDType):
  base_shape = self.membufs[0].dtype.shape
  if shape_idx_groups := get_contraction(self.output_shape, base_shape):
- special_strides: Tuple[sint, ...] = tuple()
+ special_strides: tuple[sint, ...] = tuple()
  for i,g in enumerate(shape_idx_groups):
  shape_piece = tuple(self.output_shape[x] for x in g)
  assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
@@ -298,37 +268,34 @@ class Kernel:
  s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
  axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if resolve(self.full_shape[x]%tc.dims[i] != 0))
  if axis_pads and (opt_level < 2): return None
- self.bufs_for_tensor_core[reduceop] = (buf0, buf1)
  if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
  return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)
 
- def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
+ def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> bool:
  if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD:
- for tc in self.opts.tensor_cores:
+ tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]]
+ for tc in tensor_cores:
  tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
  # can only fuse reduces with the same tc options
  assert all_same(tensor_core_opts)
  if tensor_core_opts[0] is None: continue
- # tensor core -- unroll the reduce dim, upcast input and local the correct thread pattern
  self.tensor_core_opts = tc_opts = tensor_core_opts[0]
 
  # attempt to pad the tensor axes that require it
  try:
  for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
  except KernelOptError: continue
- for tc_dim, amt in tc.reduce_axes: self.apply_opt(Opt(OptOps.UNROLL,tc_opts.axes[2]-self.first_reduce,amt), append_opt=False)
- for opt in tc.opts_seq:
- if opt == "UP":
- for tc_dim, amt in tc.early_upcast_axes: self.apply_opt(Opt(OptOps.UPCAST,tc_opts.axes[tc_dim],amt), append_opt=False)
- elif opt == "LC":
- for tc_dim, amt in tc.threads: self.apply_opt(Opt(OptOps.LOCAL,tc_opts.axes[tc_dim],amt), append_opt=False)
+ # tensor core -- unroll the reduce dim (K), upcast and local the inner and outer dims (N, M)
+ for dim, amt in tc.get_reduce_axes(): self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, amt), append_opt=False)
+ for opt in tc.opts: self.apply_opt(Opt({"u":OptOps.UPCAST, "l":OptOps.LOCAL}[opt[0]], tc_opts.axes[int(opt[1])], 2), append_opt=False)
  self.tensor_core = tc
  self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA
  return True
  return False
 
- def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, axis:int=0, tc_opt:Optional[int]=None) -> bool:
- """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
+ def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[list[Opt]]=None, axis:int=0, tc_select:Optional[int]=None,
+ tc_opt:Optional[int]=None) -> bool:
+ """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
  Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
 
  Keyword arguments:
@@ -337,21 +304,25 @@ class Kernel:
  1: enable tensor cores
  2: apply tensor core shape but don't use UOp.WMMA
  extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
+ tc_select -- specifies which tensor core(s) to use for optimization (default -1)
+ -1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes)
+ [0-N]: uses only the n'th tensor core available; useful for search
  tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
- 0: applies to only kernels with a single reduce axis and direct UOps.LOAD into Ops.MUL
- 1: allows kernels with multiple reduce axes and also multiplication of UOps.CAST'd buffers
+ 0: applies to only kernels with a single reduce axis and direct Ops.LOAD into Ops.MUL
+ 1: allows kernels with multiple reduce axes and also multiplication of Ops.CAST'd buffers
  2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
  """
+ if tc_select is None: tc_select = TC_SELECT.value
  if tc_opt is None: tc_opt = TC_OPT.value
  if not self.opts.tensor_cores and use_tensor_cores != 2: return False
  try: # check TC first and apply hand-coded opts if successful
- self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
+ self.apply_opt(Opt(OptOps.TC, axis, (tc_select, tc_opt)))
 
  if (tc_opts:=self.tensor_core_opts) is not None:
  if extra_opts is not None:
  for opt in extra_opts: self.apply_opt(opt)
  else:
- if (self.opts.device == "CLANG" and AMX): return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
+ if AMX: return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
  # hand-coded TC opts
  for tc_dim in [tc_dim for tc_dim in [1,0] if tc_opts.axes_exist[tc_dim]]: # attempt to upcast M and N
  szs = [sz for sz in [5,4,3,2] if self.full_shape[tc_opts.axes[tc_dim]] % sz == 0]
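
The docstring above documents the new `tc_select` and `tc_opt` knobs alongside `use_tensor_cores`. A minimal sketch of calling the 0.10.2 signature on an existing `Kernel` instance `k` (the kernel object itself is assumed, not built here); when these arguments are omitted, the `TC_SELECT` and `TC_OPT` context variables supply the defaults:

    # Hedged sketch: try every available tensor core (tc_select=-1) and allow padding M/N/K (tc_opt=2).
    if k.apply_tensor_cores(use_tensor_cores=1, tc_select=-1, tc_opt=2):
        print("tensor core applied:", k.tensor_core, k.applied_opts)
    else:
        print("no tensor core matched the kernel's reduce/dtype pattern")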
@@ -363,25 +334,35 @@ class Kernel:
  except KernelOptError:
  return False
 
+ def real_axis(self, opt:Opt):
+ if opt.axis is None: return -1
+ if opt.op is OptOps.UNROLL: return self.first_reduce+opt.axis
+ if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.first_reduce+self.group_for_reduces+opt.axis
+ return opt.axis
+
  def apply_opt(self, opt:Opt, append_opt:bool=True):
- if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals")
+ if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}, "not using locals")
 
  if opt.op is OptOps.TC:
  check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
- check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt")
  check((use_tensor_cores:=USE_TC.value) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
- check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.amt)), "no tensor core available")
+ check(opt.axis is not None, "tensor core opts must have an axis")
+ check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 2, "tensor core opts must have tc_select and tc_opt")
+ check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
+ check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
+ check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt), "no tensor core available")
  self.applied_opts.append(opt)
  return
 
- axis = opt.real_axis(self)
+ axis = self.real_axis(opt)
  check(axis < len(self.full_shape), "invalid axis")
 
- if opt.op is OptOps.SWAP: amt = cast(int, opt.amt) # amt is an axis in the SWAPs
- elif opt.amt is not None:
- amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
- check(isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless")
- if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
+ if opt.op is OptOps.SWAP: amt = cast(int, opt.arg) # arg is an axis in the SWAPs
+ elif opt.arg is not None:
+ check(isinstance(opt.arg, int), "arg should be int")
+ amt = arg if (arg:=cast(int, opt.arg)) != 0 else self.full_shape[axis]
+ check(isinstance(amt, int) and amt != 1, f"shift/padto of {amt=}, 1 or symbolic amount is meaningless")
+ if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, f"no longer valid shift {self.full_shape[axis]=}, {amt=}")
  else: amt = -1
 
  if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
@@ -393,6 +374,8 @@ class Kernel:
  check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
 
  if opt.op is OptOps.LOCAL: # cyan
+ # NOTE: LLVM/CPU can use locals too, but they are treated the same as globals (still helpful for L1 cache)
+ # it's disabled for now since it makes BEAM slow for little gain
  check(self.opts.has_local, "target does not support local")
  check(axis < self.global_dims, "local is for globals")
  self.shift_to(axis, amt, insert_before=self.first_reduce)
@@ -416,18 +399,10 @@ class Kernel:
  self.upcast()
  elif opt.op is OptOps.UPCAST: # yellow
  check(axis < self.first_reduce, "upcast is for non-reduce")
- check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
- check(amt <= 16, "don't upcast more than 16")
+ check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.get_local_axes())), "can't upcast TC locals")
+ check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16")
  self.shift_to(axis, amt, insert_before=None)
  self.upcast()
- elif opt.op is OptOps.UPCASTMID: # white
- check(self.bufs[0].src[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
- axes = self.sts[0].unit_stride_axes()
- check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}")
- check(axes[0] == axis, "wrong axis")
- check(amt == 4, "don't upcast mid anything but 4")
- self.shift_to(axis, amt, insert_before=self.first_reduce + self.group_for_reduces)
- self.group_for_reduces += 1
  elif opt.op is OptOps.NOLOCALS:
  check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
  check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
@@ -441,7 +416,7 @@ class Kernel:
  check(not self.vars, "does not work with symbolic shape")
  check(axis < self.first_upcast, "cannot pad upcasted")
  # ok to pad SUM if all parent ALU ops have f(0) = 0
- if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r), f"cannot pad {r}")
+ if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r, {}, cache={}), f"cannot pad {r}")
  padded = False
  for i,st in enumerate(self.sts):
  if (s:=st.shape[axis]) == 1: continue # reduced
@@ -460,8 +435,7 @@ class Kernel:
  if isinstance(self.membufs[0].dtype, ImageDType):
  unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
  assert unit_stride_axes_mul_4, f"needs a unit stride axis in {self.bufs[0]}"
- if all(x < self.first_upcast for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
- self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
+ if all(x < self.first_upcast for x in unit_stride_axes_mul_4): self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
  return self
 
  def hand_coded_optimizations(self) -> Kernel:
@@ -496,19 +470,12 @@ class Kernel:
  break
  except KernelOptError: pass
 
- # are we upcasting in mid reduce? (only for images)
- if self.bufs[0].src[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
- axes = self.sts[0].unit_stride_axes()
- assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
- if self.sts[0].shape[axes[0]]%4 == 0:
- self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4))
-
  # upcast float4 images
  for buf_index,buf in enumerate(self.bufs):
  unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
  if buf.src[0].dtype.__class__ is ImageDType:
  #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
- if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
+ if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4):
  if unit_stride_axes_mul_4[0] < self.first_reduce:
  self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
  else:
@@ -524,7 +491,7 @@ class Kernel:
  # expression and run test/test_ops.py with IMAGE=2
  # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
  # this can be made much smarter
- to_upcast: List[int] = []
+ to_upcast: list[int] = []
  # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
  for axis in range(self.first_reduce):
  # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
@@ -536,7 +503,7 @@ class Kernel:
  for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
 
  # potentially do more upcasts of non reduce axes based on a heuristic
- upcasted_axis = set()
+ upcasted_axis: set[int] = set()
  while resolve(prod(self.sts[0].shape[:self.first_reduce]) >= 1024):
  xb_choices = []
  for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
@@ -577,7 +544,7 @@ class Kernel:
  else:
  # prioritize making expand axes local
  local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))] # noqa: E501
- to_local: List[Tuple[int, int]] = []
+ to_local: list[tuple[int, int]] = []
  for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
  local_size = prod(sz for _, sz in to_local)
  local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None) # noqa: E501
@@ -593,11 +560,11 @@ class Kernel:
 
  # **** kernel outputs ****
 
- kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
+ kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
  @functools.cached_property
  def name(self) -> str:
  # kernel name (before late upcast)
- kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op in GroupOp.Buffer for x in self.ast.parents) else "E")
+ kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op is Ops.SINK or x.op in GroupOp.Buffer for x in self.ast.toposort) else "E")
  suffix = colored('_', 'BLACK').join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())])
  name = kernel_type + (f"{len(self.ast.src)}" if len(self.ast.src) > 1 else "") + "_" + suffix
 
@@ -606,14 +573,19 @@ class Kernel:
  num = f"n{Kernel.kernel_cnt[function_name]-1}" if Kernel.kernel_cnt[function_name] > 1 else ""
  return name + colored(num, 'BLACK')
 
- def get_optimized_ast(self) -> UOp:
+ def get_optimized_ast(self, name_override:Optional[str]=None) -> UOp:
  @functools.lru_cache(None)
  def fixup_ast(op:UOp) -> UOp:
  ret = op.replace(src=tuple(fixup_ast(x) for x in op.src))
  if op.op in GroupOp.Buffer and op in self.bufs:
  st_uop = self.sts[self.bufs.index(op)].to_uop()
- return ret.replace(src=(st_uop,)) if op.op is Ops.VALID else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
- if op.op is Ops.SINK: return ret.replace(arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals))
+ # NOTE: if CONST got masked after applying opts, we create a new VALID
+ if op.op is Ops.CONST and any(v.mask is not None for v in unwrap(st_uop.st).views): return op.valid(unwrap(st_uop.st))
+ # otherwise we just replace the VIEW source
+ return ret.replace(src=(st_uop,)) if len(op.src) == 1 else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
+ if op.op is Ops.SINK:
+ return ret.replace(arg = KernelInfo(to_function_name(self.name) if name_override is None else name_override,
+ self.local_dims, self.upcasted, self.dont_use_locals))
  if op.op is Ops.REDUCE_AXIS:
  reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
 
@@ -623,47 +595,43 @@ class Kernel:
  grouped_axes = reduced_axes(self.first_reduce, self.first_reduce + self.group_for_reduces)
 
  if (tc := self.tensor_core) and (self.use_tensor_cores == 1 or self.use_tensor_cores == 3):
- def fix_st(st: ShapeTracker, wd_pattern, tcd_pattern):
- wd, warp_dims = self.global_dims, tuple(sz for _, sz in tc.threads)
- tcd, tcd_dims = self.first_upcast, tuple(sz for _, sz in tc.reduce_axes + tc.early_upcast_axes)
-
- assert st.shape[wd:wd+len(warp_dims)] == warp_dims, f"warp dims wrong: {st.shape[wd:wd+len(warp_dims)]=} != {warp_dims=}"
- assert st.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, f"tcd dims wrong: {st.shape[tcd:tcd+len(tcd_dims)]=} != {tcd_dims=}"
- assert tc.expanded_shape is not None
-
- new_shape = st.shape[:tcd] + tc.expanded_shape + st.shape[tcd+len(tcd_dims):] # expand the tcd
- permaxis = list(range(wd)) + [y + (wd if x == 0 else tcd) for x,y in wd_pattern] + list(range(wd+len(warp_dims),tcd)) + \
- [y + (wd if x == 0 else tcd) for x,y in tcd_pattern] + list(range(tcd+len(tc.expanded_shape),len(new_shape)))
- return st.reshape(new_shape).permute(tuple(permaxis)).reshape(st.shape).simplify()
+ wd, tcd = self.global_dims, self.first_upcast
+ def get_upcast_axes(buf): # upcast along non-zero dimensions of (tc_reduce + tc_upcast)
+ upcast_axes = int(math.log2(tc.elements_per_thread[buf]))
+ return tuple((tcd + len(tc.get_reduce_axes()) + len(tc.get_upcast_axes()) - (i+1), 2) for i in range(upcast_axes))
+ def get_tc_swizzle_st(shape, local_perm, upcast_perm):
+ offset = (tcd - (wd + len(local_perm)))
+ permaxis = list(range(wd)) \
+ + [wd + x + (offset if x >= len(local_perm) else 0) for x in local_perm] + list(range(wd + len(local_perm), tcd)) \
+ + [wd + x + (offset if x >= len(local_perm) else 0) for x in upcast_perm] + list(range(tcd + len(upcast_perm), len(shape)))
+ return ShapeTracker.from_shape(shape).permute(tuple(permaxis))
 
  srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
- for i, tc_pattern in enumerate([tc.st1_pattern, tc.st2_pattern]):
- if tc_pattern: srcs[i] = srcs[i].view(fix_st(unwrap(srcs[i].st), *tc_pattern))
+ for i, (src, swizzle) in enumerate(zip(srcs, tc.swizzle)):
+ if swizzle: srcs[i] = src.view(get_tc_swizzle_st((src if src.op is Ops.LOAD else src.src[0]).st_arg.shape, *swizzle))
 
  if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals
  local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape))
  st = store_st = ShapeTracker.from_shape(local_shape)
- local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{i + 1}", st.real_size()))
- if tc_pattern: store_st = fix_st(store_st, *tc_pattern)
+ local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), f"temp{i}")
+ if swizzle: store_st = get_tc_swizzle_st(store_st.shape, *swizzle)
  local_store = UOp.store(local_buffer, store_st.to_uop(), srcs[i])
  srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer, st.to_uop(), local_store))
 
- tc_reduce_axes = tuple(self.first_upcast + ax for ax, _ in tc.reduce_axes)
- if self.use_tensor_cores == 1: # real WMMA, use CONTRACT/EXPAND to get the vectorization right
- upcast_axes = tuple(tuple((self.first_upcast + ax, sz) for ax, sz in up) for up in tc.upcast_axes)
- wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, prod(sz for _, sz in tc.threads), upcast_axes, tc_reduce_axes)
- wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
- wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(wmma_sz[2]), src=(
- UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(wmma_sz[0]), src=(srcs[0],), arg=upcast_axes[0]),
- UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(wmma_sz[1]), src=(srcs[1],), arg=upcast_axes[1]),
- UOp.const(tc.dtype_out.vec(wmma_sz[2]), 0.0)), arg=wmma_arg)
- tc_uop = UOp(Ops.EXPAND, tc.dtype_out, (wmma,), arg=upcast_axes[2])
+ tc_reduce_axes = tuple(tcd + ax for ax, _ in tc.get_reduce_axes())
+ if self.use_tensor_cores == 1: # real WMMA, use CONTRACT/UNROLL to get the vectorization right
+ tc_upcast_axes = (get_upcast_axes(0), get_upcast_axes(1), get_upcast_axes(2))
+ wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, tc_reduce_axes)
+ wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
+ UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]),
+ UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1]),
+ UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg)
+ tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2])
 
  else: # for TC=3 MUL/SUM instead of WMMA
  tc_uop = UOp(Ops.REDUCE_AXIS, tc.dtype_out, ((srcs[0] * srcs[1]).cast(tc.dtype_out),), (Ops.ADD, tc_reduce_axes))
 
- new_reduce_axes = tuple(i for i in axes if i not in tc_reduce_axes)
- return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_reduce_axes)) if new_reduce_axes else tc_uop
+ return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_axes)) if (new_axes := tuple(i for i in axes if i not in tc_reduce_axes)) else tc_uop
 
  ret = ret.replace(arg = (op.arg[0], axes))
  if self.group_for_reduces and grouped_axes:
@@ -672,7 +640,8 @@ class Kernel:
  for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
  (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
  st_uop = ShapeTracker.from_shape(local_shape).to_uop()
- local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local=True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
+ local_size = st_uop.arg.real_size()
+ local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
  local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, ret)))
  grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
  if op is self.reduceops[-1]: return grouped_reduce
@@ -681,73 +650,44 @@ class Kernel:
 
  return ret
 
- return graph_rewrite(fixup_ast(self.ast), PatternMatcher([
- (UPat({*GroupOp.ALU,Ops.CAST,Ops.BITCAST,Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
- (UPat(Ops.LOAD, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src)))]))
+ return graph_rewrite(fixup_ast(self.ast), view_left)
 
  # **** this is the lowerer ****
 
  @track_rewrites()
- def linearize(self) -> Kernel:
- modified_ast = self.get_optimized_ast()
+ def linearize(self, name_override:Optional[str]=None) -> Kernel:
+ # display the AST
+ if getenv("VIZ"): graph_rewrite(self.ast, PatternMatcher([]), name="View Base AST")
+
+ modified_ast = self.get_optimized_ast(name_override)
 
  if DEBUG >= 3:
  print(self.name)
  if getenv("RAWAST"): print(self.ast)
- print(modified_ast)
+ for i,(buf,st) in enumerate([(buf,st) for buf,st in zip(self.bufs, self.sts) if buf.op not in {Ops.CONST, Ops.VALID}]):
+ print(f"{i:2d}: {str(st.shape):25s} {str(buf.src[0].dtype).replace('dtypes.',''):20s}", st.real_strides())
  print(self.applied_opts)
- verify_ast(modified_ast)
+ # verify AST matches the spec after applying opts
+ if __debug__: type_verify(list(modified_ast.toposort))
+ # TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this
+ #if __debug__: type_verify(list(modified_ast.toposort), shape_spec)
 
- self.uops:List[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
+ self.uops:list[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
  if DEBUG >= 5: print_uops(self.uops)
  return self
 
- def to_program(self, name_override:Optional[str]=None) -> Program:
- self.linearize()
- src = self.opts.render(name:=to_function_name(ansiname:=(name_override if name_override is not None else self.name)), self.uops)
+ def to_program(self, name_override:Optional[str]=None) -> ProgramSpec:
+ self.linearize(name_override)
+ assert self.uops[0].op is Ops.NAME, "first uop must be name"
+ src = self.opts.render(self.uops)
 
- if getenv("RUN_PROCESS_REPLAY"):
- from test.external.process_replay.helpers import get_process_replay_ctx
- diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, name, *get_process_replay_ctx(), src))
+ if CAPTURE_PROCESS_REPLAY:
+ diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, self.uops[0].arg, ContextVar._cache, src))
 
  # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
  # TODO: these max and min don't work on symbolic, and results are very wrong.
  mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
- for _, group in itertools.groupby([x for x in self.ast.parents if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
+ for _, group in itertools.groupby([x for x in self.ast.toposort if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
  key=lambda x: (x.op, x.src[0].arg)))
- return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
- global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
-
- # the living definition of intermediate UOps
-
- def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
- if not uop.has_st or uop in sts: return
- # restore globals from the two stage reduce
- if uop.op is Ops.LOAD and uop.src[0].op is Ops.DEFINE_LOCAL:
- _assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts)
- sts[uop] = sts[local_reduce]
- return
- for x in uop.src: _assert_valid_uop(x, st, sts)
- # only reduceuop is allowed to change shape, limited to turning n to 1
- if uop.op in {Ops.REDUCE_AXIS, Ops.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
- # movementops are pushed to VIEW
- elif uop.op is Ops.VIEW:
- assert len(uop.src) == 0, f"can't swizzle in kernel yet {uop}"
- st = uop.arg
- # everything else inherits shape
- else:
- st = (src_sts:=[sts[x] for x in uop.src if x.has_st])[0]
- if not all_same(shapes:=[x.shape for x in src_sts]):
- if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}")
- raise AssertionError(f"found implicit expand {sizes} {shapes}")
- sts[uop] = st
-
- def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
- assert ast.op is Ops.SINK and all(x.op is Ops.STORE for x in ast.src), "must be SINK"
- assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size"
- sts: Dict[UOp, ShapeTracker] = {}
- for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
- shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
- assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
- type_verify(list(sts))
- return sts
+ return ProgramSpec(self.name if not name_override else name_override, src, self.opts.device, self.ast, self.uops, self.applied_opts, mem_bytes,
+ global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
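
Taken together, the kernel.py changes shift the caller-facing surface: `Opt`/`OptOps` live in `tinygrad.renderer` and carry `arg` rather than `amt`, `linearize` takes a `name_override`, and `to_program` now returns a `ProgramSpec`. A minimal usage sketch under those assumptions (`ast` is a scheduler-produced SINK UOp not constructed here, and the `prg.name`/`prg.src` field names are inferred from the positional `ProgramSpec(...)` call above):

    # Hedged sketch of driving the 0.10.2 Kernel API shown in this diff.
    from tinygrad.codegen.kernel import Kernel
    from tinygrad.renderer import Opt, OptOps
    from tinygrad.device import Device

    k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
    k.apply_opt(Opt(OptOps.UPCAST, 0, 4))   # third field is `arg` (0.10.0 called it `amt`)
    prg = k.to_program()                    # ProgramSpec in 0.10.2 (was Program in 0.10.0)
    print(prg.name)
    print(prg.src)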