tinygrad 0.8.0-py3-none-any.whl → 0.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/kernel.py +230 -190
  3. tinygrad/codegen/linearizer.py +278 -384
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +132 -275
  6. tinygrad/dtype.py +53 -37
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +28 -14
  14. tinygrad/helpers.py +72 -43
  15. tinygrad/lazy.py +141 -240
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +179 -8
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +106 -28
  20. tinygrad/nn/state.py +86 -17
  21. tinygrad/ops.py +70 -44
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +299 -206
  25. tinygrad/renderer/llvmir.py +118 -123
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +59 -54
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +37 -41
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +16 -14
  43. tinygrad/runtime/ops_cuda.py +130 -38
  44. tinygrad/runtime/ops_disk.py +45 -42
  45. tinygrad/runtime/ops_gpu.py +52 -50
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +36 -56
  48. tinygrad/runtime/ops_metal.py +42 -24
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +41 -105
  53. tinygrad/shape/symbolic.py +98 -95
  54. tinygrad/shape/view.py +137 -35
  55. tinygrad/tensor.py +2367 -442
  56. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/features/image.py +0 -93
  61. tinygrad/features/multi.py +0 -103
  62. tinygrad/features/search.py +0 -160
  63. tinygrad/graph.py +0 -106
  64. tinygrad/jit.py +0 -152
  65. tinygrad/realize.py +0 -50
  66. tinygrad/runtime/graph/hip.py +0 -24
  67. tinygrad/runtime/ops_cpu.py +0 -45
  68. tinygrad/runtime/ops_hip.py +0 -97
  69. tinygrad/runtime/ops_torch.py +0 -49
  70. tinygrad-0.8.0.dist-info/RECORD +0 -41
  71. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
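
Note on the reorganization visible in the list above: most of the execution plumbing moves under tinygrad/engine/ (graph, jit, realize, schedule, search), the tinygrad/features/ package is dissolved, and tinygrad/mlops.py is renamed to tinygrad/function.py. The short Python sketch below is only an illustration of how import paths would shift, under the assumption that the relocated modules keep their public names (e.g. TinyJit, beam_search); it is not taken from the diff itself, so verify against the 0.9.0 wheel before relying on it.

# tinygrad 0.8.0 layout (old import paths)
#   from tinygrad.jit import TinyJit
#   from tinygrad.features.search import beam_search
#   import tinygrad.mlops as mlops

# tinygrad 0.9.0 layout (assumed equivalents)
from tinygrad.engine.jit import TinyJit         # tinygrad/jit.py -> tinygrad/engine/jit.py
from tinygrad.engine.search import beam_search  # tinygrad/features/search.py -> tinygrad/engine/search.py
import tinygrad.function as function            # tinygrad/mlops.py -> tinygrad/function.py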
@@ -1,52 +1,47 @@
  from __future__ import annotations
- import os, math, itertools
+ import math, itertools
  from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
- from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps
- from tinygrad.device import Device, Compiled
+ from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, UNSAFE_PAD_OPS
+ from tinygrad.device import Device
+ from tinygrad.renderer import Renderer, TensorCore
  from tinygrad.dtype import dtypes, ImageDType, DType
- from tinygrad.helpers import dedup, colored, ansilen, getenv, prod, DEBUG, round_up, all_int
- from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
+ from tinygrad.helpers import colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction
+ from tinygrad.shape.shapetracker import ShapeTracker
  from tinygrad.shape.symbolic import sint
  from tinygrad.shape.view import View, strides_for_shape
  from dataclasses import dataclass
  from enum import Enum, auto

  class OptOps(Enum):
- UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto() # noqa: E702
+ TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto() # noqa: E702
  def __lt__(self, x:OptOps): return self.value < x.value

+ class KernelOptError(Exception): pass
+
+ def check(cond:bool, msg:str=""):
+ if not cond: raise KernelOptError(msg)
+
  @dataclass(frozen=True, order=True)
  class Opt:
  op: OptOps
  axis: Optional[int] = None
  amt: Optional[int] = None
  def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})"
-
- @dataclass(frozen=True)
- class TensorCore:
- device: str
- dims: List[int]
- dtype_in: DType
- dtype_out: DType
- threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
- upcast_dim: int # which TC dim to upcast
- thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
- thread_local_sizes: List[int] # in each thread, the number of elements stored in registers for each TC dim
- arch: Optional[str] = None
- def __str__(self): return f"tensor_core<{self.device}, {self.dims}, {self.dtype_in}, {self.dtype_out}>"
-
- tensor_cores: Dict[str, List[TensorCore]] = {
- "METAL": [
- TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501
- # TODO: enable half @ half -> half tensor core with correct dtypes in uop
- # TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501
- ],
- "HIP": [
- TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
- TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
- ]
- }
+ def real_axis(self, k:Kernel):
+ if self.axis is None: return -1
+ if self.op is OptOps.UNROLL: return k.first_reduce+self.axis
+ if self.op in {OptOps.GROUP, OptOps.GROUPTOP}: return k.first_reduce+k.group_for_reduces+self.axis
+ return self.axis
+
+ class TensorCoreOptions(NamedTuple):
+ bufs: Tuple[int, int] # the local aliased buffers for A and B
+ axes: List[int] # the location of the original N and M axes if still in the shape
+ axes_exist: List[bool] # true if the original N and M axes are still in the shape
+ def fix_axes(self, removed_axis:int): # adjust the TC axes if necesssary when an dimension is removed
+ for tc_dim in [i for i in range(2) if self.axes_exist[i]]:
+ if removed_axis < self.axes[tc_dim]: self.axes[tc_dim] -= 1
+ elif removed_axis == self.axes[tc_dim]: self.axes_exist[tc_dim] = False

  class LocalBuffer(NamedTuple):
  name: str
@@ -55,53 +50,46 @@ class LocalBuffer(NamedTuple):
  realized: None = None
  def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"

- class LinearizerOptions(NamedTuple):
- device: str = ""
- # TODO: make this generic with a list of supported types
- supports_float4: bool = True
- supports_float4_alu: bool = True
- has_local: bool = True
- has_shared: bool = True
- # NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
- global_max: Optional[List[int]] = None
- local_max: Optional[List[int]] = None
-
  class Kernel:
- def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None):
- self.opts = opts or (device.linearizer_opts if isinstance(device:=Device[Device.DEFAULT], Compiled) else LinearizerOptions())
+ def __init__(self, *ast:LazyOp, opts:Optional[Renderer]=None):
+ self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
+ assert all(op.op is BufferOps.STORE for op in ast), f"kernels must have stores as the output, got {ast}"
+ assert len(set(op.arg.st.size for op in ast)) == 1, f"all outbufs should have the same size, got {[op.arg.st for op in ast]}"
  self.ast = ast
- assert ast.op == BufferOps.STORE, f"kernels must have a store as the output, got {ast.op}"
-
- # fetch lazyop info
- self.info: FlopCounter = get_lazyop_info(self.ast)
+ self.lazyops = flatten([op.lazyops for op in self.ast])

  # there's only allowed to be one reduceop
- reduceops = [x for x in self.ast.lazyops if x.op in ReduceOps]
- assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast"
- self.reduceop = reduceops[0] if reduceops else None
+ cached_ordered_lazyops: Dict[LazyOp, List[LazyOp]] = {}
+ def ordered_lazyops(op):
+ if op not in cached_ordered_lazyops: cached_ordered_lazyops[op] = dedup([item for x in op.src for item in ordered_lazyops(x)] + [op])
+ return cached_ordered_lazyops[op]
+ self.reduceops = dedup([x for out in self.ast for x in ordered_lazyops(out) if x.op in ReduceOps])
+ assert len(self.reduceops) < 2, "Only one reduceop allowed"

- self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = dedup([x.arg for x in self.ast.lazyops if x.op in BufferOps])
- assert isinstance(self.bufs[0], MemBuffer) and self.bufs[0].idx == 0, f"buffer 0 is not the store buffer {self.bufs[0]}"
+ self.outbufs, self.vars = [x.arg for x in self.ast], flatten([x.vars() for x in self.ast])
+ loadops = [BufferOps.LOAD, BufferOps.CONST]
+ self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops])

  # get earlybufs, before the one reduce op
- self.earlybufs = [x.arg for x in self.reduceop.lazyops if x.op in BufferOps] if self.reduceop else []
+ self.earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps]
  self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0

  # create new shapetrackers inside this kernel, we will permute them
  self.sts: List[ShapeTracker] = [x.st for x in cast(List[Union[MemBuffer, ConstBuffer]], self.bufs)]

  # move all reduce axes to the end
- reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape)))
+ reduce = list(enumerate(zip(self.full_shape, self.output_shape)))
  permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
  self.reshape_and_permute(None, permute)

  # parameters for optimization
  self.applied_opts: List[Opt] = []
- self.group_for_reduce: List[int] = []
+ self.group_for_reduces: int = 0
  self.upcasted: int = 0
  self.local_dims: int = 0
  self.local_alias: Dict[int, LocalBuffer] = {}
  self.tensor_core: Optional[TensorCore] = None
+ self.tensor_core_opts: Optional[TensorCoreOptions] = None
  self.dont_use_locals: bool = False

  # group simplifies
@@ -115,16 +103,17 @@ class Kernel:
  ret = type(self).__new__(type(self))

  # base linearizer params
- ret.opts, ret.ast = self.opts, self.ast
+ ret.opts, ret.ast, ret.lazyops = self.opts, self.ast, self.lazyops

  # things downstream of the AST
- # NOTE: we copy bufs for local buffers and sts for optimizations
- ret.info, ret.reduceop, ret.bufs, ret.earlybufs, ret.full_buf_index, ret.sts = \
- self.info, self.reduceop, self.bufs[:], self.earlybufs, self.full_buf_index, self.sts[:]
+ ret.reduceops, ret.outbufs, ret.vars, ret.bufs, ret.earlybufs, ret.full_buf_index = \
+ self.reduceops, self.outbufs, self.vars, [x for x in self.bufs if not isinstance(x, LocalBuffer)], self.earlybufs, self.full_buf_index
+ ret.sts = self.sts[:len(ret.bufs)] # NOTE: must redo the local buffers with TC in beam

  # parameters for optimizations
- ret.applied_opts, ret.group_for_reduce, ret.upcasted, ret.local_dims, ret.local_alias, ret.tensor_core, ret.dont_use_locals = \
- self.applied_opts[:], self.group_for_reduce[:], self.upcasted, self.local_dims, self.local_alias.copy(), self.tensor_core, self.dont_use_locals
+ ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
+ self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
+ ret.tensor_core, ret.tensor_core_opts, ret.local_alias = self.tensor_core, self.tensor_core_opts, {}

  # uncached since linearize didn't run
  ret.applied_opts_cache = None
@@ -138,9 +127,10 @@ class Kernel:
  def shape_offsets(self, i:int): return itertools.product(*[list(range(cast(int, s))) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()] # noqa: E501
  def float4_axis(self, i:int): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0] # noqa: E501

- def upcasted_axis(self, i:int):
- return list(zip(self.sts[i].shape[self.shape_len-self.upcasted:],
- self.sts[i].real_strides()[self.shape_len-self.upcasted:],
+ def upcasted_axis(self, i:int) -> List[Tuple[int, Optional[sint], bool]]:
+ upcasted_shape, upcasted_stride = self.sts[i].shape[self.shape_len-self.upcasted:], self.sts[i].real_strides()[self.shape_len-self.upcasted:]
+ assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
+ return list(zip(upcasted_shape, upcasted_stride,
  [x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))

  # TODO: is there a better way to write this?
@@ -158,6 +148,9 @@ class Kernel:
  def first_reduce(self) -> int:
  return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True) # noqa: E501

+ @property
+ def reduceop(self) -> Optional[LazyOp]: return self.reduceops[0] if len(self.reduceops) > 0 else None
+
  @property
  def output_shape(self) -> Tuple[sint, ...]: return self.sts[0].shape

@@ -172,7 +165,7 @@ class Kernel:

  @property
  def upcast_in_mid_reduce_axes(self) -> List[int]:
- return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]]
+ return [j for j in range(self.first_reduce, self.first_reduce+self.group_for_reduces) if self.full_shape[j] == self.sts[0].shape[j]]

  @property
  def global_dims(self) -> int: return self.first_reduce-self.local_dims
@@ -192,10 +185,10 @@ class Kernel:
  colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims
  # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
  colors += ["cyan"] * self.local_dims
- # between first_reduce and first_reduce + group_for_reduce, they are either upcast mid reduce (white), or late upcasted (green)
- colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + len(self.group_for_reduce))] # noqa: E501
- # between first_reduce + group_for_reduce and upcasted, they are reduce (red)
- colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + len(self.group_for_reduce)))
+ # between first_reduce and first_reduce + group_for_reduces, they are either upcast mid reduce (white), or late upcasted (green)
+ colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + self.group_for_reduces)] # noqa: E501
+ # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
+ colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + self.group_for_reduces))
  # upcasted dimensions are reduce (magenta) or normal (yellow)
  colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.shape_len-self.upcasted, self.shape_len)]
  assert len(colors) == self.shape_len, "colors size mismatch"
@@ -219,7 +212,7 @@ class Kernel:

  # drops the final dimension
  def upcast(self):
- assert self.full_shape[-1] != 1, "can't upcast a dimension with size 1"
+ check(self.full_shape[-1] != 1, "can't upcast a dimension with size 1")
  self.upcasted += 1

  # axis : the axis to pull from
@@ -242,7 +235,7 @@ class Kernel:
  if self.shape_len == 0: return False
  all_ones = [s==1 for s in self.full_shape]
  self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
- self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:])
+ self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:]) # TODO: no necessary since upcasted axis can't be un-upcasted
  self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
  return any(all_ones)

@@ -254,7 +247,7 @@ class Kernel:
  if isinstance(self.bufs[0].dtype, ImageDType):
  base_shape = self.bufs[0].dtype.shape
  if shape_idx_groups := get_contraction(self.output_shape, base_shape):
- special_strides: Tuple[int, ...] = tuple()
+ special_strides: Tuple[sint, ...] = tuple()
  for i,g in enumerate(shape_idx_groups):
  shape_piece = tuple(self.output_shape[x] for x in g)
  assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
@@ -281,35 +274,32 @@ class Kernel:
  # do the reshapes
  for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))

- # ******************** GPU simplifiers ********************
+ # ******************** helpers ********************

- def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
- new_shape,dims = list(x), len(x)
- for i in range(dims):
- next_idx = (i + 1) % dims
+ def _limit_size(self, x: Tuple[int], max_size: List[Union[int,float]]) -> Tuple[int, ...]:
+ new_shape = list(x)
+ for i in range(len(new_shape)):
+ next_idx = (i + 1) % len(new_shape)
  while new_shape[i] > max_size[i]:
+ # TODO: what if new_shape[i] is not a multiple of 2??
  new_shape[i] = new_shape[i] // 2
- if (new_shape[next_idx] <= max_size[next_idx]):
- new_shape[next_idx] = new_shape[next_idx] * 2
- else:
- next_idx = (next_idx + 1) % dims
- new_shape[next_idx] = new_shape[next_idx] * 2
+ next_idx = next_idx if new_shape[next_idx] <= max_size[next_idx] else (next_idx + 1) % len(new_shape)
+ new_shape[next_idx] = new_shape[next_idx] * 2
  return tuple(new_shape)

  def limit_dims_to_max(self, global_max: List[int], local_max: List[int]):
  # Check the global allocation limit, current the global_size will be flipped during codegen
  # and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
- global_dims = self.first_reduce-self.local_dims
- if global_dims > 0:
+ if self.global_dims > 0:
  if global_max:
- tmp = global_max[:global_dims] + (local_max[:self.local_dims] if local_max else [])
- if max(global_max) < max(self.full_shape[:global_dims]):
+ tmp = global_max[:self.global_dims] + (local_max[:self.local_dims] if local_max else [])
+ if max(global_max) < max(self.full_shape[:self.global_dims]):
  self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None)
- assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}" # noqa: E501
- for i in range(global_dims-1):
+ assert max(global_max) >= max(self.full_shape[:self.global_dims]), f"device max allocation {max(self.full_shape[:self.global_dims])} exceeds global dim maximum {max(global_max)}" # noqa: E501
+ for i in range(self.global_dims-1):
  if i < len(global_max) and self.full_shape[i] > global_max[i]:
  order = list(range(len(self.full_shape)))
- order[i], order[global_dims-1] = order[global_dims-1], order[i]
+ order[i], order[self.global_dims-1] = order[self.global_dims-1], order[i]
  self.reshape_and_permute(None, order)
  if DEBUG >= 3: print("permuted global dim", order, "due to allocation exceeds global limit")

@@ -332,134 +322,182 @@ class Kernel:

  # ******************** high level optimizers ********************

- def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None) -> bool:
- if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op == ReduceOps.SUM and self.opts.device in tensor_cores:
- for tc in tensor_cores[self.opts.device]:
- if not (use_tensor_cores==2 or (tc.arch is None or tc.arch == os.uname().machine)): continue
+ def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
+ if use_tensor_cores and self.opts.has_local and self.reduceop is not None and self.reduceop.op is ReduceOps.SUM:
+ for tc in self.opts.tensor_cores:
  has_cast = tc.dtype_in != tc.dtype_out
+ if has_cast and not(self.reduceop.src[0].op is UnaryOps.CAST and self.reduceop.src[0].arg == tc.dtype_out): continue

- if has_cast and not(self.reduceop.src[0].op == UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue
  mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
+ if mul_op.op is not BinaryOps.MUL: continue

- if mul_op.op != BinaryOps.MUL: continue
- if not (mul_op.src[0].op == BufferOps.LOAD and mul_op.src[0].arg.dtype == tc.dtype_in): continue
- if not (mul_op.src[1].op == BufferOps.LOAD and mul_op.src[1].arg.dtype == tc.dtype_in): continue
- buf0, buf1 = self.bufs.index(cast(MemBuffer, mul_op.src[0].arg)), self.bufs.index(cast(MemBuffer, mul_op.src[1].arg))
- buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
- axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0] # noqa: E501
- axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0] # noqa: E501
+ def buf_index(src: LazyOp) -> Optional[int]:
+ # TODO: apply tc even if the sources are not from LOAD
+ if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
+ try:
+ if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
+ except ValueError: return None
+ return None
+ if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue

- if not(axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%tc.dims[2] == 0 and self.full_shape[self.first_reduce] >= tc.dims[2] and (self.shape_len-self.first_reduce) == 1): continue # noqa: E501
+ buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
+ axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
+ axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
+ if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): continue

- if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
+ axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
+ if not(axis < len(axis_choices)): continue

- s0, s1 = axis_buf0[-1][0], axis_buf1[-1][0] # TODO: select axis in smart way
- s0_exists, s1_exists = True, True
- assert s0 != s1 and self.full_shape[s0]%tc.dims[0] == 0 and self.full_shape[s1]%tc.dims[1] == 0
- def fix(needed, ax):
- nonlocal s0, s1, s0_exists, s1_exists
- if not needed: return
- if s0_exists and ax == s0:
- if s1_exists and s0 < s1: s1 -= 1
- s0_exists = False
- elif s1_exists and ax == s1:
- if s0_exists and s1 < s0: s0 -= 1
- s1_exists = False
+ s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
+ axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0]
+ if axis_pads and (opt_level < 2): continue

  # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
- self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]))
- self.apply_opt(Opt(OptOps.UPCAST, s0 if tc.upcast_dim == 0 else s1, (tc.dims[0]*tc.dims[2])//prod([a[1] for a in tc.threads])))
+ self.tensor_core_opts = (tc_opts:=TensorCoreOptions(bufs=(buf0, buf1), axes=[s0, s1], axes_exist=[True, True]))
+
+ # attempt to pad the tensor axes that require it
+ try:
+ for axis, dim in axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
+ except KernelOptError: continue
+ self.apply_opt(Opt(OptOps.UNROLL, s2-self.first_reduce, tc.dims[2]), append_opt=False)
+ for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
+ if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
  for (tc_dim, tc_amt) in tc.threads:
- fix(self.apply_opt(Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
+ self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)

- # assert tensor core and prevent extra_opts from altering the key shape structure
+ # assert tensor core
+ if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
  if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
+ return True
+ return False

+ def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, axis:int=0, tc_opt:int=getenv("TC_OPT")) -> bool:
+ """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
+ Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
+
+ Keyword arguments:
+ use_tensor_cores -- controls how tensor cores are applied (default 1)
+ 0: will disable any tensor core matching
+ 1: enable tensor cores
+ 2: apply tensor core shape but don't use UOp.WMMA
+ extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
+ tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
+ 0: applies to only kernels with a single reduce axis and direct BufferOps.LOAD into BinaryOps.MUL
+ 1: allows kernels with multiple reduce axes and also multiplication of UnaryOps.CAST'd buffers
+ 2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
+ """
+ if not self.opts.tensor_cores and use_tensor_cores != 2: return False
+ try: # check TC first and apply hand-coded opts if successful
+ self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
+
+ if (tc_opts:=self.tensor_core_opts) is not None:
  if extra_opts is not None:
  for opt in extra_opts: self.apply_opt(opt)
  else:
  # hand-coded TC opts
- if s1_exists:
- s1_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s1]%upc == 0][0]
- if s1_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s1, s1_div)), s1)
- if s0_exists:
- s0_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s0]%upc == 0][0]
- if s0_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s0, s0_div)), s0)
- if self.tensor_core and s0_exists:
+ def late_upcast_tc(tc_dim: int):
+ if tc_opts.axes_exist[tc_dim]:
+ ax_div = [upc for upc in [5,4,3,2,1] if self.full_shape[tc_opts.axes[tc_dim]]%upc == 0][0]
+ if ax_div != 1: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[tc_dim], ax_div))
+ late_upcast_tc(1) # attempt to upcast M
+ late_upcast_tc(0) # attempt to upcast N
+
+ if self.tensor_core and tc_opts.axes_exist[0]: # attempt to local N
  for upc in [4,2]:
- if self.full_shape[s0] % upc == 0:
- self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc))
+ if self.full_shape[tc_opts.axes[0]] % upc == 0:
+ self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], upc))
  break

- # alias buffer
- alias_pattern = [0]*(self.global_dims+(self.local_dims-len(tc.threads))) + [2]*(len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2) # noqa: E501
- self.alias_buffer(buf0, alias_pattern)
- self.alias_buffer(buf1, alias_pattern)
- return True
- return False
+ return True
+ except KernelOptError:
+ return False
+
+ def apply_opt(self, opt:Opt, append_opt:bool=True):
+ check(not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals")
+
+ if opt.op is OptOps.TC:
+ check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
+ check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt")
+ check((use_tensor_cores:=getenv("TC", 1)) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
+ check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.amt)), "no tensor core available")
+ self.applied_opts.append(opt)
+ return
+
+ axis = opt.real_axis(self)
+ check(axis < len(self.full_shape), "invalid axis")

- def apply_opt(self, opt:Opt):
- assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals" # noqa: E501
- self.applied_opts.append(opt)
- if opt.axis is not None:
- axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op in [OptOps.GROUP, OptOps.GROUPTOP] else 0)) # noqa: E501
- else:
- axis = -1
  if opt.amt is not None:
  amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
- assert isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless"
- if opt.op != OptOps.PADTO: assert self.full_shape[axis] % amt == 0, "no longer valid shift"
- else:
- amt = -1
- if opt.op in [OptOps.LOCAL, OptOps.LASTLOCAL]: # cyan
- assert self.opts.has_local, "target does not support local"
- assert axis < self.first_reduce, "can't local a reduce"
- if opt.op == OptOps.LOCAL:
- assert not self.tensor_core, "can't local with tensor cores"
- self.shift_to(axis, amt, insert_before=self.first_reduce)
- else:
- self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims)
+ check(isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless")
+ if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
+ else: amt = -1
+
+ if self.reduceop and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
+ acc_sz, upcast_idx = dt.base.itemsize if isinstance((dt:=self.reduceop.dtype), ImageDType) else dt.itemsize, self.shape_len-self.upcasted
+ upcast_sz = prod([a for a,b in zip(self.full_shape[upcast_idx:], self.sts[0].shape[upcast_idx:]) if a == b])
+ local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
+ smem_sz = amt*acc_sz*upcast_sz*local_sz
+ check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
+
+ if opt.op is OptOps.LOCAL: # cyan
+ check(self.opts.has_local, "target does not support local")
+ check(axis < self.global_dims, "local is for globals")
+ self.shift_to(axis, amt, insert_before=self.first_reduce)
  self.local_dims += 1
- elif opt.op in [OptOps.GROUP, OptOps.GROUPTOP]: # green
- assert self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem"
- assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
- assert not self.tensor_core, "can't group with tensor cores"
- self.shift_to(axis, amt, top=(opt.op==OptOps.GROUPTOP), insert_before=self.first_reduce + len(self.group_for_reduce))
- self.group_for_reduce.append(amt)
- elif opt.op == OptOps.UNROLL: # purple
- assert axis < self.shape_len-self.upcasted, "can't upcasted already upcasted"
- assert amt <= 32, "don't unroll more than 32"
+ elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green
+ check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
+ check(axis >= self.first_reduce + self.group_for_reduces and axis < self.shape_len-self.upcasted, "must be reduce axis to group")
+ check(not self.tensor_core, "can't group with tensor cores")
+ self.shift_to(axis, amt, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
+ self.group_for_reduces += 1
+ elif opt.op is OptOps.UNROLL: # purple
+ check(axis < self.shape_len-self.upcasted, "can't upcasted already upcasted")
+ check(amt <= 32, "don't unroll more than 32")
+ # TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores
+ #upcast_count = sum(x == y for x,y in zip(self.full_shape[-self.upcasted:], self.output_shape[-self.upcasted:])) if self.upcasted else 0
+ #self.shift_to(axis, amt, insert_before=None if upcast_count == 0 else self.shape_len-upcast_count)
+ if self.full_shape[axis] == amt and axis == self.first_reduce: self.local_dims += 1 # first_reduce will ++, so offset loss in simplify_ones
+ if self.full_shape[axis] == amt and axis < self.first_reduce+self.group_for_reduces: self.group_for_reduces -= 1 # fully unrolling a GROUP
  self.shift_to(axis, amt, insert_before=None)
  self.upcast()
- elif opt.op == OptOps.UPCAST: # yellow
- assert axis < self.first_reduce, "upcast is for non-reduce"
- assert amt <= 8, "don't upcast more than 8"
+ elif opt.op is OptOps.UPCAST: # yellow
+ check(axis < self.first_reduce, "upcast is for non-reduce")
+ check(not(self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
+ check(amt <= 8, "don't upcast more than 8")
  self.shift_to(axis, amt, insert_before=None)
  self.upcast()
- elif opt.op == OptOps.UPCASTMID: # white
- assert self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce" # noqa: E501
+ elif opt.op is OptOps.UPCASTMID: # white
+ check(self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
  axes = self.sts[0].unit_stride_axes()
- assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
- assert axes[0] == axis, "wrong axis"
- assert amt == 4, "don't upcast mid anything but 4"
- self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
- self.group_for_reduce.append(amt)
- elif opt.op == OptOps.NOLOCALS:
- assert self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals"
- assert self.local_dims == 0 and len(self.group_for_reduce) == 0, "can't have no locals with locals"
+ check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}")
+ check(axes[0] == axis, "wrong axis")
+ check(amt == 4, "don't upcast mid anything but 4")
+ self.shift_to(axis, amt, insert_before=self.first_reduce + self.group_for_reduces)
+ self.group_for_reduces += 1
+ elif opt.op is OptOps.NOLOCALS:
+ check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
+ check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
  self.dont_use_locals = True
- elif opt.op == OptOps.PADTO:
- assert not self.ast.vars(), "does not work with symbolic shape"
- assert axis < self.first_reduce, "cannot pad a reduce axis"
+ elif opt.op is OptOps.PADTO:
+ check(not self.vars, "does not work with symbolic shape")
+ check(axis < self.shape_len - self.upcasted, "cannot pad upcasted")
+ # ok to pad SUM if all parent ops have f(0) = 0
+ if self.first_reduce <= axis:
+ check((r:=cast(LazyOp, self.reduceop)).op is ReduceOps.SUM and \
+ all(op.op not in UNSAFE_PAD_OPS for ops in r.src for op in ops.lazyops), "cannot pad")
  padded = False
  for i,st in enumerate(self.sts):
- assert self.sts[i].shape[axis] > amt//2, "pad adds more than double the work"
- if (ru := round_up(self.sts[i].shape[axis], amt) - self.sts[i].shape[axis]):
+ if self.sts[i].shape[axis] == 1: continue # reduced
+ check(self.sts[i].shape[axis] > amt//4, f"pad adds more than quadruple the work {self.sts[i].shape[axis]=} > {amt//4=}")
+ if (ru := round_up(cast(int, self.sts[i].shape[axis]), cast(int, amt)) - self.sts[i].shape[axis]):
  # pad right seems to be faster
  self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
  padded = True
- assert padded, "nothing was padded"
- return self.simplify_ones()
+ check(padded, "nothing was padded")
+
+ if append_opt: self.applied_opts.append(opt)
+ if self.simplify_ones() and self.tensor_core_opts:
+ self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()

  def required_optimizations(self):
  if self.bufs[0].dtype.__class__ is ImageDType:
@@ -474,8 +512,8 @@ class Kernel:
  # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
  MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
  if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
- self.reduceop and self.reduceop.op == ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
- (mulop:=self.reduceop.src[0]).op == BinaryOps.MUL and mulop.src[0].op == BufferOps.LOAD and mulop.src[1].op == BufferOps.LOAD:
+ self.reduceop is not None and self.reduceop.op is ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
+ (mulop:=self.reduceop.src[0]).op is BinaryOps.MUL and mulop.src[0].op is BufferOps.LOAD and mulop.src[1].op is BufferOps.LOAD:
  st0, st1 = self.sts[self.bufs.index(mulop.src[0].arg)], self.sts[self.bufs.index(mulop.src[1].arg)]
  strides0, strides1 = st0.real_strides(), st1.real_strides()
  def has_expanded_axis(shape, strides): return any(s > 1 and st == 0 for s,st in zip(shape,strides))
@@ -495,11 +533,13 @@ class Kernel:
  # TODO: use 1024 if it's allowed in a smarter way
  for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
  if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
- self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
- break
+ try: # may fail due to excessive smem usage
+ self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
+ break
+ except KernelOptError: pass

  # are we upcasting in mid reduce? (only for images)
- if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
+ if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
  axes = self.sts[0].unit_stride_axes()
  assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
  if self.sts[0].shape[axes[0]]%4 == 0:
@@ -517,7 +557,7 @@ class Kernel:
  self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))

  # no more opt if we are grouping
- if self.group_for_reduce: return
+ if self.group_for_reduces: return

  # **** below this line need to be optional and benchmarked ****

@@ -574,7 +614,7 @@ class Kernel:
  # **** local groups ****

  if self.opts.has_local:
- if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduce:
+ if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduces:
  self.apply_opt(Opt(OptOps.NOLOCALS))
  else:
  # prioritize making expand axes local