tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
@@ -1,21 +1,26 @@
  from __future__ import annotations
+ import itertools, functools
+ from dataclasses import dataclass
  from collections import defaultdict
- import itertools
- from typing import DefaultDict, Optional, List, Tuple, cast, Dict, Union
- from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, UNSAFE_PAD_OPS, verify_lazyop
+ from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence
+ from enum import Enum, auto
+
+ from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, PatternMatcher, can_pad, print_uops, type_verify, resolve, Variable, sint, \
+ graph_rewrite, track_rewrites, UPat
  from tinygrad.device import Device
- from tinygrad.renderer import Renderer, TensorCore
- from tinygrad.dtype import dtypes, ImageDType, DType
- from tinygrad.helpers import all_same, colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction
+ from tinygrad.renderer import Renderer, TensorCore, Program
+ from tinygrad.dtype import ImageDType
+ from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap
+ from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
  from tinygrad.shape.shapetracker import ShapeTracker
- from tinygrad.shape.symbolic import sint
- from tinygrad.shape.view import View, strides_for_shape
- from dataclasses import dataclass
- from enum import Enum, auto
+ from tinygrad.shape.view import strides_for_shape
+ from tinygrad.codegen.linearize import linearize_uop
+ from tinygrad.codegen.uopgraph import full_graph_rewrite
+ from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction

  class OptOps(Enum):
  TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
- GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto() # noqa: E702
+ GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
  def __lt__(self, x:OptOps): return self.value < x.value

  class KernelOptError(Exception): pass
@@ -47,41 +52,41 @@ class TensorCoreOptions:
  elif removed_axis == axes[tc_dim]: axes_exist[tc_dim] = False
  self.axes, self.axes_exist = tuple(axes), tuple(axes_exist)

- @dataclass(frozen=True)
- class LocalBuffer:
- name: str
- size: int
- dtype: DType = dtypes.float32
- realized: None = None
- def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"
-
  class Kernel:
- def __init__(self, *ast:LazyOp, opts:Optional[Renderer]=None):
+ def __init__(self, ast:UOp, opts:Optional[Renderer]=None):
+ if ast.op is Ops.SINK: self.ast = ast
+
  self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
- verify_lazyop(*ast)
- self.ast = ast
- self.lazyops = flatten([op.lazyops for op in self.ast])
+ try: uop_sts_map = verify_ast(self.ast)
+ except AssertionError as e:
+ print("INVALID AST")
+ print(self.ast)
+ raise e

- cached_ordered_lazyops: Dict[LazyOp, List[LazyOp]] = {}
- def ordered_lazyops(op):
- if op not in cached_ordered_lazyops: cached_ordered_lazyops[op] = dedup([item for x in op.src for item in ordered_lazyops(x)] + [op])
- return cached_ordered_lazyops[op]
- self.reduceops = dedup([x for out in self.ast for x in ordered_lazyops(out) if x.op in ReduceOps])
+ @functools.lru_cache(None)
+ def ordered_parents(op:UOp) -> List[UOp]: return dedup([item for x in op.src for item in ordered_parents(x)] + [op])
+ self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is Ops.REDUCE_AXIS])

- self.outbufs, self.vars = [x.arg for x in self.ast], flatten([x.vars() for x in self.ast])
- loadops = [BufferOps.LOAD, BufferOps.CONST]
- self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops])
+ self.vars: List[Variable] = self.ast.variables()
+ self.bufs: List[UOp] = [x for x in self.ast.parents if x.op in GroupOp.Buffer]

  # get earlybufs, before any reduceops
- self.earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps]
- self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0
+ earlybufs: List[UOp] = [x for reduceop in self.reduceops for x in reduceop.parents if x.op in GroupOp.Buffer]
+ self.full_buf_index: int = self.bufs.index(earlybufs[0]) if earlybufs else 0
+ # NOTE: full_shape can be wrong if there's a tree of reduces

  # create new shapetrackers inside this kernel, we will permute them
- self.sts: List[ShapeTracker] = [x.st for x in cast(List[Union[MemBuffer, ConstBuffer]], self.bufs)]
+ self.sts: List[ShapeTracker] = [x.st_arg for x in self.bufs]
+
+ # add the shapetrackers for each reduce
+ # we use this to track which axes are reduced in each reduce
+ for x in self.reduceops:
+ self.sts.append(uop_sts_map[x])
+ self.sts.append(uop_sts_map[x.src[0]])

  # move all reduce axes to the end
  reduce = list(enumerate(zip(self.full_shape, self.output_shape)))
- permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
+ permute = tuple([i for i,(s,n) in reduce if not resolve(s != n)] + [i for i,(s,n) in reduce if resolve(s != n)])
  self.reshape_and_permute(None, permute)

  # parameters for optimization
@@ -89,72 +94,57 @@ class Kernel:
  self.group_for_reduces: int = 0
  self.upcasted: int = 0
  self.local_dims: int = 0
- self.local_alias: DefaultDict[LazyOp, Dict[int, LocalBuffer]] = defaultdict(dict)
  self.tensor_core: Optional[TensorCore] = None
  self.tensor_core_opts: Optional[TensorCoreOptions] = None
+ self.use_tensor_cores: int = 0
  # the local aliased buffers for A and B
- self.bufs_for_tensor_core: Dict[LazyOp, Tuple[int, int]] = {}
+ self.bufs_for_tensor_core: Dict[UOp, Tuple[int, int]] = {}
  self.dont_use_locals: bool = False

  # group simplifies
  self.simplify_ones()
  self.simplify_merge_adjacent()

- # cache
- self.applied_opts_cache: Optional[List[Opt]] = None
-
  def copy(self):
  ret = type(self).__new__(type(self))

  # base linearizer params
- ret.opts, ret.ast, ret.lazyops = self.opts, self.ast, self.lazyops
+ ret.opts, ret.ast = self.opts, self.ast

  # things downstream of the AST
- ret.reduceops, ret.outbufs, ret.vars, ret.bufs, ret.earlybufs, ret.full_buf_index = \
- self.reduceops, self.outbufs, self.vars, [x for x in self.bufs if not isinstance(x, LocalBuffer)], self.earlybufs, self.full_buf_index
- ret.sts = self.sts[:len(ret.bufs)] # NOTE: must redo the local buffers with TC in beam
+ ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = \
+ self.reduceops, self.vars, self.bufs, self.full_buf_index
+ ret.sts = self.sts[:len(ret.bufs)+len(ret.reduceops)*2] # NOTE: must redo the local buffers with TC in beam

  # parameters for optimizations
  ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
  self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
- ret.tensor_core, ret.tensor_core_opts, ret.local_alias, ret.bufs_for_tensor_core = self.tensor_core, self.tensor_core_opts, defaultdict(dict), \
- self.bufs_for_tensor_core
-
- # uncached since linearize didn't run
- ret.applied_opts_cache = None
+ ret.tensor_core, ret.tensor_core_opts, ret.bufs_for_tensor_core, ret.use_tensor_cores = \
+ self.tensor_core, self.tensor_core_opts, self.bufs_for_tensor_core, self.use_tensor_cores

  return ret

  @property
- def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)]
+ def membufs(self) -> List[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])

  # TODO: these need more tests or it might silently be no-op
- def shape_offsets(self, i:int): return itertools.product(*[list(range(cast(int, s))) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()] # noqa: E501
- def float4_axis(self, i:int): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0] # noqa: E501
+ def float4_axis(self, i:int): return [x-self.first_upcast for x in self.sts[i].unit_stride_axes() if x >= self.first_upcast and self.sts[i].shape[x]%4 == 0] # noqa: E501

  def upcasted_axis(self, i:int) -> List[Tuple[int, Optional[sint], bool]]:
- upcasted_shape, upcasted_stride = self.sts[i].shape[self.shape_len-self.upcasted:], self.sts[i].real_strides()[self.shape_len-self.upcasted:]
+ upcasted_shape, upcasted_stride = self.sts[i].shape[self.first_upcast:], self.sts[i].real_strides()[self.first_upcast:]
  assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
  return list(zip(upcasted_shape, upcasted_stride,
- [x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))
-
- # TODO: is there a better way to write this?
- def acc_offsets(self, i:int) -> List[int]:
- if self.upcasted == 0: return [0]
- upcasted_i = self.upcasted_axis(i)
- acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))]
- return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])]
-
- def get_float4_upcast_dim(self, i:int) -> List[int]:
- should_upcast = self.opts.supports_float4 and (self.bufs[i].dtype in (dtypes.float, dtypes.half) or isinstance(self.bufs[i].dtype, ImageDType))
- return [x for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1] if should_upcast else []
+ [x!=y for x,y in zip(self.sts[0].shape[self.first_upcast:], self.full_shape[self.first_upcast:])]))

  @property
  def first_reduce(self) -> int:
- return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True) # noqa: E501
+ return [resolve(x!=y) for x,y in zip(self.sts[0].shape[:self.first_upcast]+(0,), self.full_shape[:self.first_upcast]+(1,))].index(True)

  @property
- def reduceop(self) -> Optional[LazyOp]: return self.reduceops[0] if len(self.reduceops) > 0 else None
+ def first_upcast(self) -> int: return self.shape_len-self.upcasted
+
+ @property
+ def reduceop(self) -> Optional[UOp]: return self.reduceops[0] if len(self.reduceops) > 0 else None

  @property
  def output_shape(self) -> Tuple[sint, ...]: return self.sts[0].shape
@@ -163,7 +153,7 @@ class Kernel:
  def full_shape(self) -> Tuple[sint, ...]: return self.sts[self.full_buf_index].shape

  @property
- def full_unupcasted_shape(self) -> Tuple[sint, ...]: return self.full_shape[:self.shape_len-self.upcasted]
+ def full_unupcasted_shape(self) -> Tuple[sint, ...]: return self.full_shape[:self.first_upcast]

  @property
  def shape_len(self) -> int: return len(self.sts[0].shape)
@@ -193,27 +183,25 @@ class Kernel:
  # between first_reduce and first_reduce + group_for_reduces, they are either upcast mid reduce (white), or late upcasted (green)
  colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + self.group_for_reduces)] # noqa: E501
  # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
- colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + self.group_for_reduces))
+ colors += ["red"] * (self.first_upcast - (self.first_reduce + self.group_for_reduces))
  # upcasted dimensions are reduce (magenta) or normal (yellow)
- colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.shape_len-self.upcasted, self.shape_len)]
+ colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.first_upcast, self.shape_len)]
  assert len(colors) == self.shape_len, "colors size mismatch"
  return colors

  def colored_shape(self, pad:Optional[int]=None, dense=False) -> str:
- ret = ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) and not dense else s for s in self.full_shape], self.colors())) # noqa: E501
+ shape_strs = [(s if dense else f"{s:4d}") if isinstance(s, int) else s.render() for s in self.full_shape]
+ ret = ' '.join(colored(s, color) for s,color in zip(shape_strs, self.colors()))
  if pad: ret += ' '*(pad-ansilen(ret))
  return ret

  # ******************** base simplifiers ********************

  # apply reshape and permute to all shapetrackers
- def reshape_and_permute(self, new_shape_fxn, axis):
- new_sts = []
- for st in self.sts:
- if new_shape_fxn is not None: st = st.reshape(tuple(new_shape_fxn(st.shape)))
- if axis is not None: st = st.permute(tuple(axis))
- new_sts.append(st)
- self.sts = new_sts
+ def reshape_and_permute(self, new_shape_fxn:Optional[Callable[[Tuple[sint, ...]], Sequence[sint]]], axis:Optional[Sequence[int]]):
+ def reshape(st:ShapeTracker): return st.reshape(tuple(new_shape_fxn(st.shape))) if new_shape_fxn is not None else st
+ def permute(st:ShapeTracker): return st.permute(tuple(axis)) if axis is not None else st
+ self.sts = [permute(reshape(st)) for st in self.sts]

  # drops the final dimension
  def upcast(self):
@@ -229,7 +217,7 @@ class Kernel:
  move_axis = axis if top else axis+1
  if move_axis < insert_before: insert_before += 1
  self.reshape_and_permute(
- lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]),
+ lambda x: x[0:axis] + (((amount, x[axis]//amount) if top else (x[axis]//amount, amount)) if x[axis] > 1 else (1,1)) + x[axis+1:],
  [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis])

  # ******************** complex simplifiers ********************
@@ -240,7 +228,7 @@ class Kernel:
  if self.shape_len == 0: return False
  all_ones = [s==1 for s in self.full_shape]
  self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
- self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:]) # TODO: no necessary since upcasted axis can't be un-upcasted
+ self.upcasted -= sum(all_ones[self.first_upcast:]) # TODO: no necessary since upcasted axis can't be un-upcasted
  self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
  return any(all_ones)

@@ -249,8 +237,8 @@ class Kernel:
  shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]

  # if it's an image, insert fake strides such that this fusion doesn't happen across image axes
- if isinstance(self.bufs[0].dtype, ImageDType):
- base_shape = self.bufs[0].dtype.shape
+ if isinstance(self.membufs[0].dtype, ImageDType):
+ base_shape = self.membufs[0].dtype.shape
  if shape_idx_groups := get_contraction(self.output_shape, base_shape):
  special_strides: Tuple[sint, ...] = tuple()
  for i,g in enumerate(shape_idx_groups):
@@ -281,39 +269,20 @@ class Kernel:
  # do the reshapes
  for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))

- # ******************** helpers ********************
-
- def alias_buffer(self, op:LazyOp, i:int, pattern:List[int]) -> None:
- assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}"
-
- bst = 1
- real_strides = self.sts[i].real_strides()
- shp, stride = [(s if p != 0 else 1) for s,p in zip(self.sts[i].shape, pattern)], [0]*len(pattern)
- for priority in range(1, max(pattern)+1): # priority. 0 is non local and ignored
- for j,p in enumerate(pattern):
- if priority == p and real_strides[j] != 0:
- stride[j] = bst
- bst *= shp[j]
-
- self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
- self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size))
- if DEBUG >= 4: print("aliasing buffer", self.sts[i])
- self.local_alias[op][i] = cast(LocalBuffer, self.bufs[-1])
-
  # ******************** high level optimizers ********************

- def _create_tc_opts(self, reduceop:LazyOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
+ def _create_tc_opts(self, reduceop:UOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
  has_cast = tc.dtype_in != tc.dtype_out
- if has_cast and not(reduceop.src[0].op is UnaryOps.CAST and reduceop.src[0].arg == tc.dtype_out): return None
+ if has_cast and not (reduceop.src[0].op is Ops.CAST and reduceop.src[0].dtype == tc.dtype_out): return None

  mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
- if mul_op.op is not BinaryOps.MUL: return None
+ if mul_op.op is not Ops.MUL: return None

- def buf_index(src: LazyOp) -> Optional[int]:
+ def buf_index(src:UOp) -> Optional[int]:
  # TODO: apply tc even if the sources are not from LOAD
- if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
+ if src.op is Ops.LOAD and src.dtype == tc.dtype_in: return self.bufs.index(src)
  try:
- if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
+ if opt_level >= 1 and src.op is Ops.CAST and src.dtype == tc.dtype_in: return self.bufs.index(src.src[0])
  except ValueError: return None
  return None
  if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None
@@ -321,40 +290,40 @@ class Kernel:
  buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
  axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
  axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
- if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None
+ if not (axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None

  axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
- if not(axis < len(axis_choices)): return None
+ if not (axis < len(axis_choices)): return None

  s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
- axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0)
+ axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if resolve(self.full_shape[x]%tc.dims[i] != 0))
  if axis_pads and (opt_level < 2): return None
  self.bufs_for_tensor_core[reduceop] = (buf0, buf1)
  if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
  return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)

  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
- if use_tensor_cores and self.opts.has_local and self.reduceop is not None and self.reduceop.op is ReduceOps.SUM:
+ if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD:
  for tc in self.opts.tensor_cores:
  tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
  # can only fuse reduces with the same tc options
  assert all_same(tensor_core_opts)
  if tensor_core_opts[0] is None: continue
- # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
+ # tensor core -- unroll the reduce dim, upcast input and local the correct thread pattern
  self.tensor_core_opts = tc_opts = tensor_core_opts[0]

  # attempt to pad the tensor axes that require it
  try:
  for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
  except KernelOptError: continue
- self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, tc.dims[2]), append_opt=False)
- for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
- if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
- for (tc_dim, tc_amt) in tc.threads:
- self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
-
- # assert tensor core
- if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
+ for tc_dim, amt in tc.reduce_axes: self.apply_opt(Opt(OptOps.UNROLL,tc_opts.axes[2]-self.first_reduce,amt), append_opt=False)
+ for opt in tc.opts_seq:
+ if opt == "UP":
+ for tc_dim, amt in tc.early_upcast_axes: self.apply_opt(Opt(OptOps.UPCAST,tc_opts.axes[tc_dim],amt), append_opt=False)
+ elif opt == "LC":
+ for tc_dim, amt in tc.threads: self.apply_opt(Opt(OptOps.LOCAL,tc_opts.axes[tc_dim],amt), append_opt=False)
+ self.tensor_core = tc
+ self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA
  return True
  return False

@@ -369,11 +338,11 @@ class Kernel:
  2: apply tensor core shape but don't use UOp.WMMA
  extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
  tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
- 0: applies to only kernels with a single reduce axis and direct BufferOps.LOAD into BinaryOps.MUL
- 1: allows kernels with multiple reduce axes and also multiplication of UnaryOps.CAST'd buffers
+ 0: applies to only kernels with a single reduce axis and direct UOps.LOAD into Ops.MUL
+ 1: allows kernels with multiple reduce axes and also multiplication of UOps.CAST'd buffers
  2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
  """
- if tc_opt is None: tc_opt = self.opts.tc_opt
+ if tc_opt is None: tc_opt = TC_OPT.value
  if not self.opts.tensor_cores and use_tensor_cores != 2: return False
  try: # check TC first and apply hand-coded opts if successful
  self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
@@ -382,31 +351,25 @@ class Kernel:
  if extra_opts is not None:
  for opt in extra_opts: self.apply_opt(opt)
  else:
+ if (self.opts.device == "CLANG" and AMX): return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
  # hand-coded TC opts
- def late_upcast_tc(tc_dim: int):
- if tc_opts.axes_exist[tc_dim]:
- ax_div = [upc for upc in [5,4,3,2,1] if self.full_shape[tc_opts.axes[tc_dim]]%upc == 0][0]
- if ax_div != 1: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[tc_dim], ax_div))
- late_upcast_tc(1) # attempt to upcast M
- late_upcast_tc(0) # attempt to upcast N
-
- if self.tensor_core and tc_opts.axes_exist[0]: # attempt to local N
- for upc in [4,2]:
- if self.full_shape[tc_opts.axes[0]] % upc == 0:
- self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], upc))
- break
+ for tc_dim in [tc_dim for tc_dim in [1,0] if tc_opts.axes_exist[tc_dim]]: # attempt to upcast M and N
+ szs = [sz for sz in [5,4,3,2] if self.full_shape[tc_opts.axes[tc_dim]] % sz == 0]
+ if szs: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[tc_dim], szs[0]))

+ if tc_opts.axes_exist[0] and (szs := [sz for sz in [4,2] if self.full_shape[tc_opts.axes[0]] % sz == 0]): # attempt to local N
+ self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], szs[0]))
  return True
  except KernelOptError:
  return False

  def apply_opt(self, opt:Opt, append_opt:bool=True):
- check(not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals")
+ if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals")

  if opt.op is OptOps.TC:
  check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
  check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt")
- check((use_tensor_cores:=self.opts.tc) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
+ check((use_tensor_cores:=USE_TC.value) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
  check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.amt)), "no tensor core available")
  self.applied_opts.append(opt)
  return
@@ -414,15 +377,17 @@ class Kernel:
  axis = opt.real_axis(self)
  check(axis < len(self.full_shape), "invalid axis")

- if opt.amt is not None:
+ if opt.op is OptOps.SWAP: amt = cast(int, opt.amt) # amt is an axis in the SWAPs
+ elif opt.amt is not None:
  amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
  check(isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless")
  if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
  else: amt = -1

- if self.reduceop and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
- acc_sz, upcast_idx = dt.base.itemsize if isinstance((dt:=self.reduceop.dtype), ImageDType) else dt.itemsize, self.shape_len-self.upcasted
- upcast_sz = prod([a for a,b in zip(self.full_shape[upcast_idx:], self.sts[0].shape[upcast_idx:]) if a == b])
+ if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
+ (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
+ acc_sz = self.reduceop.dtype.itemsize
+ upcast_sz = prod([a for a,b in zip(self.full_shape[self.first_upcast:], self.sts[0].shape[self.first_upcast:]) if a == b])
  local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
  smem_sz = amt*acc_sz*upcast_sz*local_sz
  check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
@@ -434,12 +399,13 @@ class Kernel:
  self.local_dims += 1
  elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green
  check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
- check(axis >= self.first_reduce + self.group_for_reduces and axis < self.shape_len-self.upcasted, "must be reduce axis to group")
+ check(self.first_reduce + self.group_for_reduces <= axis < self.first_upcast, "must be reduce axis to group")
  check(not self.tensor_core, "can't group with tensor cores")
+ check(len(reduce_axes:=[i for r in self.reduceops for i in r.axis_arg]) == len(set(reduce_axes)), "can't group with parallel reduces")
  self.shift_to(axis, amt, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
  self.group_for_reduces += 1
  elif opt.op is OptOps.UNROLL: # purple
- check(axis < self.shape_len-self.upcasted, "can't upcasted already upcasted")
+ check(axis < self.first_upcast, "can't upcasted already upcasted")
  check(amt <= 32, "don't unroll more than 32")
  # TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores
  #upcast_count = sum(x == y for x,y in zip(self.full_shape[-self.upcasted:], self.output_shape[-self.upcasted:])) if self.upcasted else 0
@@ -450,12 +416,12 @@ class Kernel:
  self.upcast()
  elif opt.op is OptOps.UPCAST: # yellow
  check(axis < self.first_reduce, "upcast is for non-reduce")
- check(not(self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
- check(amt <= 8, "don't upcast more than 8")
+ check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
+ check(amt <= 16, "don't upcast more than 16")
  self.shift_to(axis, amt, insert_before=None)
  self.upcast()
  elif opt.op is OptOps.UPCASTMID: # white
- check(self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
+ check(self.bufs[0].src[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
  axes = self.sts[0].unit_stride_axes()
  check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}")
  check(axes[0] == axis, "wrong axis")
@@ -466,18 +432,21 @@ class Kernel:
  check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
  check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
  self.dont_use_locals = True
+ elif opt.op is OptOps.SWAP:
+ check(axis < amt < self.global_dims, f"swap is only for globals with axis < amt, getting {amt=}, {axis=}, {self.global_dims=}")
+ permute = list(range(self.shape_len))
+ permute[axis], permute[amt] = permute[amt], permute[axis]
+ self.reshape_and_permute(None, tuple(permute))
  elif opt.op is OptOps.PADTO:
  check(not self.vars, "does not work with symbolic shape")
- check(axis < self.shape_len - self.upcasted, "cannot pad upcasted")
- # ok to pad SUM if all parent ops have f(0) = 0
- if self.first_reduce <= axis:
- check((r:=cast(LazyOp, self.reduceop)).op is ReduceOps.SUM and \
- all(op.op not in UNSAFE_PAD_OPS for ops in r.src for op in ops.lazyops), "cannot pad")
+ check(axis < self.first_upcast, "cannot pad upcasted")
+ # ok to pad SUM if all parent ALU ops have f(0) = 0
+ if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r), f"cannot pad {r}")
  padded = False
  for i,st in enumerate(self.sts):
- if self.sts[i].shape[axis] == 1: continue # reduced
- check(self.sts[i].shape[axis] > amt//4, f"pad adds more than quadruple the work {self.sts[i].shape[axis]=} > {amt//4=}")
- if (ru := round_up(cast(int, self.sts[i].shape[axis]), cast(int, amt)) - self.sts[i].shape[axis]):
+ if (s:=st.shape[axis]) == 1: continue # reduced
+ check(s > amt//4, f"pad adds more than quadruple the work {st.shape[axis]=} > {amt//4=}")
+ if (ru := round_up(cast(int, s), amt) - s):
  # pad right seems to be faster
  self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
  padded = True
@@ -487,24 +456,25 @@ class Kernel:
  if self.simplify_ones() and self.tensor_core_opts:
  self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()

- def required_optimizations(self):
- if self.bufs[0].dtype.__class__ is ImageDType:
+ def required_optimizations(self) -> Kernel:
+ if isinstance(self.membufs[0].dtype, ImageDType):
  unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
- assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[0]}"
- if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
+ assert unit_stride_axes_mul_4, f"needs a unit stride axis in {self.bufs[0]}"
+ if all(x < self.first_upcast for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
  self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
+ return self

- def hand_coded_optimizations(self):
+ def hand_coded_optimizations(self) -> Kernel:
  self.required_optimizations()

  # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
  MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
  if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
- self.reduceop is not None and self.reduceop.op is ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
- (mulop:=self.reduceop.src[0]).op is BinaryOps.MUL and mulop.src[0].op is BufferOps.LOAD and mulop.src[1].op is BufferOps.LOAD:
- st0, st1 = self.sts[self.bufs.index(mulop.src[0].arg)], self.sts[self.bufs.index(mulop.src[1].arg)]
+ self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \
+ (mulop:=self.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
+ st0, st1 = self.sts[self.bufs.index(mulop.src[0])], self.sts[self.bufs.index(mulop.src[1])]
  strides0, strides1 = st0.real_strides(), st1.real_strides()
- def has_expanded_axis(shape, strides): return any(s > 1 and st == 0 for s,st in zip(shape,strides))
+ def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
  if strides0[self.first_reduce] == 1 and not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)):
  for global_idx in range(self.global_dims):
  if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
@@ -513,13 +483,13 @@ class Kernel:
  if MV_THREADS_PER_ROW > 1: self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
  if MV_BLOCKSIZE > 1: self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
  if MV_ROWS_PER_THREAD > 1: self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
- return
+ return self

  if self.opts.has_local and self.opts.has_shared and all_int(self.sts[0].shape[:self.first_reduce]):
  # are we grouping? (requires local shape support)
  if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: # noqa: E501
  # TODO: use 1024 if it's allowed in a smarter way
- for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
+ for sz in ([256, 16] if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
  if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
  try: # may fail due to excessive smem usage
  self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
@@ -527,7 +497,7 @@ class Kernel:
  except KernelOptError: pass

  # are we upcasting in mid reduce? (only for images)
- if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
+ if self.bufs[0].src[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
  axes = self.sts[0].unit_stride_axes()
  assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
  if self.sts[0].shape[axes[0]]%4 == 0:
@@ -536,21 +506,21 @@ class Kernel:
  # upcast float4 images
  for buf_index,buf in enumerate(self.bufs):
  unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
- if buf.dtype.__class__ is ImageDType:
+ if buf.src[0].dtype.__class__ is ImageDType:
  #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
- if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
+ if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
  if unit_stride_axes_mul_4[0] < self.first_reduce:
  self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
  else:
  self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))

  # no more opt if we are grouping
- if self.group_for_reduces: return
+ if self.group_for_reduces: return self

  # **** below this line need to be optional and benchmarked ****

  # TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
- # to trigger the above bug, remove prod(self.full_shape[self.shape_len - self.upcasted:]) from the below
+ # to trigger the above bug, remove prod(self.full_shape[self.first_upcast:]) from the below
  # expression and run test/test_ops.py with IMAGE=2
  # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
  # this can be made much smarter
@@ -560,14 +530,14 @@ class Kernel:
  # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
  # for now skip upcasting here if there is a symbolic axis
  if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and \
- prod(self.full_shape[self.shape_len - self.upcasted:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
+ prod(self.full_shape[self.first_upcast:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
  if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
  to_upcast.append(axis)
  for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))

  # potentially do more upcasts of non reduce axes based on a heuristic
  upcasted_axis = set()
- while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
+ while resolve(prod(self.sts[0].shape[:self.first_reduce]) >= 1024):
  xb_choices = []
  for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
  # if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
@@ -581,11 +551,11 @@ class Kernel:
  else: break

  # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast.
- if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64): # noqa: E501
- if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
+ if self.first_reduce < self.first_upcast and (prod(self.full_shape[self.first_upcast:]) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64): # noqa: E501
+ if isinstance(s:=self.full_unupcasted_shape[-1], int) and s <= 32: # NOTE: cannot loop unroll symbolic axis
  self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
  # if it's small, upcast a second reduce dimension too
- if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int):
+ if self.first_reduce < self.first_upcast and s <= 3 and isinstance(s2:=self.full_unupcasted_shape[-1], int) and s2 <= 3:
  self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
  else:
  for splits in [4]:
@@ -618,3 +588,166 @@ class Kernel:
  will_delete_shape = local_sz == self.full_shape[axis]
  self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
  if will_delete_shape: deleted_shape += 1
+
+ return self
+
+ # **** kernel outputs ****
+
+ kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
+ @functools.cached_property
+ def name(self) -> str:
+ # kernel name (before late upcast)
+ kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op in GroupOp.Buffer for x in self.ast.parents) else "E")
+ suffix = colored('_', 'BLACK').join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())])
+ name = kernel_type + (f"{len(self.ast.src)}" if len(self.ast.src) > 1 else "") + "_" + suffix
+
+ # name the function something unique
+ Kernel.kernel_cnt[(function_name := to_function_name(name))] += 1
+ num = f"n{Kernel.kernel_cnt[function_name]-1}" if Kernel.kernel_cnt[function_name] > 1 else ""
+ return name + colored(num, 'BLACK')
+
+ def get_optimized_ast(self) -> UOp:
+ @functools.lru_cache(None)
+ def fixup_ast(op:UOp) -> UOp:
+ ret = op.replace(src=tuple(fixup_ast(x) for x in op.src))
+ if op.op in GroupOp.Buffer and op in self.bufs:
+ st_uop = self.sts[self.bufs.index(op)].to_uop()
+ return ret.replace(src=(st_uop,)) if op.op is Ops.VALID else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
+ if op.op is Ops.SINK: return ret.replace(arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals))
+ if op.op is Ops.REDUCE_AXIS:
+ reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
+
+ def reduced_axes(start, stop):
+ return tuple(i for i in range(start, stop) if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
+ axes = reduced_axes(self.first_reduce + self.group_for_reduces, self.shape_len)
+ grouped_axes = reduced_axes(self.first_reduce, self.first_reduce + self.group_for_reduces)
+
+ if (tc := self.tensor_core) and (self.use_tensor_cores == 1 or self.use_tensor_cores == 3):
+ def fix_st(st: ShapeTracker, wd_pattern, tcd_pattern):
+ wd, warp_dims = self.global_dims, tuple(sz for _, sz in tc.threads)
+ tcd, tcd_dims = self.first_upcast, tuple(sz for _, sz in tc.reduce_axes + tc.early_upcast_axes)
+
+ assert st.shape[wd:wd+len(warp_dims)] == warp_dims, f"warp dims wrong: {st.shape[wd:wd+len(warp_dims)]=} != {warp_dims=}"
+ assert st.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, f"tcd dims wrong: {st.shape[tcd:tcd+len(tcd_dims)]=} != {tcd_dims=}"
+ assert tc.expanded_shape is not None
+
+ new_shape = st.shape[:tcd] + tc.expanded_shape + st.shape[tcd+len(tcd_dims):] # expand the tcd
+ permaxis = list(range(wd)) + [y + (wd if x == 0 else tcd) for x,y in wd_pattern] + list(range(wd+len(warp_dims),tcd)) + \
+ [y + (wd if x == 0 else tcd) for x,y in tcd_pattern] + list(range(tcd+len(tc.expanded_shape),len(new_shape)))
+ return st.reshape(new_shape).permute(tuple(permaxis)).reshape(st.shape).simplify()
+
+ srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
+ for i, tc_pattern in enumerate([tc.st1_pattern, tc.st2_pattern]):
+ if tc_pattern: srcs[i] = srcs[i].view(fix_st(unwrap(srcs[i].st), *tc_pattern))
+
+ if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals
+ local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape))
+ st = store_st = ShapeTracker.from_shape(local_shape)
+ local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{i + 1}", st.real_size()))
+ if tc_pattern: store_st = fix_st(store_st, *tc_pattern)
+ local_store = UOp.store(local_buffer, store_st.to_uop(), srcs[i])
+ srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer, st.to_uop(), local_store))
+
+ tc_reduce_axes = tuple(self.first_upcast + ax for ax, _ in tc.reduce_axes)
+ if self.use_tensor_cores == 1: # real WMMA, use CONTRACT/EXPAND to get the vectorization right
+ upcast_axes = tuple(tuple((self.first_upcast + ax, sz) for ax, sz in up) for up in tc.upcast_axes)
+ wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, prod(sz for _, sz in tc.threads), upcast_axes, tc_reduce_axes)
+ wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
+ wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(wmma_sz[2]), src=(
+ UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(wmma_sz[0]), src=(srcs[0],), arg=upcast_axes[0]),
+ UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(wmma_sz[1]), src=(srcs[1],), arg=upcast_axes[1]),
+ UOp.const(tc.dtype_out.vec(wmma_sz[2]), 0.0)), arg=wmma_arg)
+ tc_uop = UOp(Ops.EXPAND, tc.dtype_out, (wmma,), arg=upcast_axes[2])
+
+ else: # for TC=3 MUL/SUM instead of WMMA
+ tc_uop = UOp(Ops.REDUCE_AXIS, tc.dtype_out, ((srcs[0] * srcs[1]).cast(tc.dtype_out),), (Ops.ADD, tc_reduce_axes))
+
+ new_reduce_axes = tuple(i for i in axes if i not in tc_reduce_axes)
+ return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_reduce_axes)) if new_reduce_axes else tc_uop
+
+ ret = ret.replace(arg = (op.arg[0], axes))
+ if self.group_for_reduces and grouped_axes:
+ local_shape = (1,) * self.global_dims + self.full_shape[self.global_dims:self.global_dims+self.local_dims] + \
+ tuple([self.full_shape[i] if self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i] else 1 \
+ for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
+ (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
+ st_uop = ShapeTracker.from_shape(local_shape).to_uop()
+ local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local=True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
+ local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, ret)))
+ grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
+ if op is self.reduceops[-1]: return grouped_reduce
+ st_uop = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else a for i,a in enumerate(local_shape)])).to_uop()
+ return UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, grouped_reduce)))
+
+ return ret
+
+ return graph_rewrite(fixup_ast(self.ast), PatternMatcher([
+ (UPat({*GroupOp.ALU,Ops.CAST,Ops.BITCAST,Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
+ (UPat(Ops.LOAD, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src)))]))
+
+ # **** this is the lowerer ****
+
+ @track_rewrites()
+ def linearize(self) -> Kernel:
+ modified_ast = self.get_optimized_ast()
+
+ if DEBUG >= 3:
+ print(self.name)
+ if getenv("RAWAST"): print(self.ast)
+ print(modified_ast)
+ print(self.applied_opts)
+ verify_ast(modified_ast)
+
+ self.uops:List[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
+ if DEBUG >= 5: print_uops(self.uops)
+ return self
+
+ def to_program(self, name_override:Optional[str]=None) -> Program:
+ self.linearize()
+ src = self.opts.render(name:=to_function_name(ansiname:=(name_override if name_override is not None else self.name)), self.uops)
+
+ if getenv("RUN_PROCESS_REPLAY"):
+ from test.external.process_replay.helpers import get_process_replay_ctx
+ diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, name, *get_process_replay_ctx(), src))
+
+ # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
+ # TODO: these max and min don't work on symbolic, and results are very wrong.
+ mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
+ for _, group in itertools.groupby([x for x in self.ast.parents if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
+ key=lambda x: (x.op, x.src[0].arg)))
+ return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
+ global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
+
+ # the living definition of intermediate UOps
+
+ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
+ if not uop.has_st or uop in sts: return
+ # restore globals from the two stage reduce
+ if uop.op is Ops.LOAD and uop.src[0].op is Ops.DEFINE_LOCAL:
+ _assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts)
+ sts[uop] = sts[local_reduce]
+ return
+ for x in uop.src: _assert_valid_uop(x, st, sts)
+ # only reduceuop is allowed to change shape, limited to turning n to 1
+ if uop.op in {Ops.REDUCE_AXIS, Ops.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
+ # movementops are pushed to VIEW
+ elif uop.op is Ops.VIEW:
+ assert len(uop.src) == 0, f"can't swizzle in kernel yet {uop}"
+ st = uop.arg
+ # everything else inherits shape
+ else:
+ st = (src_sts:=[sts[x] for x in uop.src if x.has_st])[0]
+ if not all_same(shapes:=[x.shape for x in src_sts]):
+ if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}")
+ raise AssertionError(f"found implicit expand {sizes} {shapes}")
+ sts[uop] = st
+
+ def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
+ assert ast.op is Ops.SINK and all(x.op is Ops.STORE for x in ast.src), "must be SINK"
+ assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size"
+ sts: Dict[UOp, ShapeTracker] = {}
+ for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
+ shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
+ assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
+ type_verify(list(sts))
+ return sts
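
Example (added for orientation, not part of the published diff): a minimal sketch of how the refactored Kernel API in 0.10.0 chains together, based only on the signatures visible in the hunks above. Kernel now takes a single Ops.SINK UOp instead of *LazyOp, and required_optimizations / hand_coded_optimizations / linearize return self, so the steps compose into to_program(). The sink_ast input and the Program.src field name are assumptions here; sink_ast would come from the scheduler.

    from tinygrad.codegen.kernel import Kernel, Opt, OptOps

    k = Kernel(sink_ast)                              # sink_ast: an Ops.SINK UOp produced by the scheduler (assumed)
    k.apply_opt(Opt(OptOps.UPCAST, 0, 4))             # hand-apply one opt; raises KernelOptError if the shape disallows it
    prg = k.hand_coded_optimizations().to_program()   # heuristic opts, then linearize + render into a Program
    print(prg.src)                                    # rendered kernel source for the selected backend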