tinygrad-0.8.0-py3-none-any.whl → tinygrad-0.9.1-py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (74)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/__init__.py +0 -0
  3. tinygrad/codegen/kernel.py +253 -225
  4. tinygrad/codegen/linearizer.py +398 -436
  5. tinygrad/codegen/uops.py +451 -0
  6. tinygrad/device.py +268 -274
  7. tinygrad/dtype.py +56 -40
  8. tinygrad/engine/__init__.py +0 -0
  9. tinygrad/engine/graph.py +100 -0
  10. tinygrad/engine/jit.py +198 -0
  11. tinygrad/engine/realize.py +192 -0
  12. tinygrad/engine/schedule.py +370 -0
  13. tinygrad/engine/search.py +199 -0
  14. tinygrad/{mlops.py → function.py} +40 -32
  15. tinygrad/helpers.py +144 -46
  16. tinygrad/lazy.py +143 -242
  17. tinygrad/multi.py +173 -0
  18. tinygrad/nn/__init__.py +180 -9
  19. tinygrad/nn/datasets.py +8 -0
  20. tinygrad/nn/optim.py +106 -28
  21. tinygrad/nn/state.py +87 -19
  22. tinygrad/ops.py +104 -45
  23. tinygrad/renderer/__init__.py +65 -0
  24. tinygrad/renderer/assembly.py +269 -0
  25. tinygrad/renderer/cstyle.py +308 -210
  26. tinygrad/renderer/llvmir.py +119 -124
  27. tinygrad/runtime/__init__.py +0 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +13403 -0
  29. tinygrad/runtime/autogen/comgr.py +891 -0
  30. tinygrad/runtime/autogen/cuda.py +5923 -0
  31. tinygrad/runtime/autogen/hip.py +5909 -0
  32. tinygrad/runtime/autogen/hsa.py +5893 -0
  33. tinygrad/runtime/autogen/io_uring.py +1486 -0
  34. tinygrad/runtime/autogen/kfd.py +812 -0
  35. tinygrad/runtime/autogen/nv_gpu.py +33597 -0
  36. tinygrad/runtime/autogen/opencl.py +1795 -0
  37. tinygrad/runtime/driver/__init__.py +0 -0
  38. tinygrad/runtime/driver/hip_comgr.py +56 -0
  39. tinygrad/runtime/graph/__init__.py +0 -0
  40. tinygrad/runtime/graph/clang.py +39 -0
  41. tinygrad/runtime/graph/cuda.py +59 -54
  42. tinygrad/runtime/graph/hcq.py +187 -0
  43. tinygrad/runtime/graph/metal.py +37 -41
  44. tinygrad/runtime/ops_amd.py +550 -0
  45. tinygrad/runtime/ops_clang.py +16 -14
  46. tinygrad/runtime/ops_cuda.py +129 -37
  47. tinygrad/runtime/ops_disk.py +111 -43
  48. tinygrad/runtime/ops_gpu.py +52 -50
  49. tinygrad/runtime/ops_llvm.py +36 -56
  50. tinygrad/runtime/ops_metal.py +41 -24
  51. tinygrad/runtime/ops_npy.py +9 -0
  52. tinygrad/runtime/ops_nv.py +625 -0
  53. tinygrad/runtime/ops_python.py +208 -0
  54. tinygrad/shape/__init__.py +0 -0
  55. tinygrad/shape/shapetracker.py +46 -107
  56. tinygrad/shape/symbolic.py +99 -98
  57. tinygrad/shape/view.py +162 -45
  58. tinygrad/tensor.py +2492 -483
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
  60. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
  61. tinygrad-0.9.1.dist-info/RECORD +63 -0
  62. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  63. tinygrad/features/image.py +0 -93
  64. tinygrad/features/multi.py +0 -103
  65. tinygrad/features/search.py +0 -160
  66. tinygrad/graph.py +0 -106
  67. tinygrad/jit.py +0 -152
  68. tinygrad/realize.py +0 -50
  69. tinygrad/runtime/graph/hip.py +0 -24
  70. tinygrad/runtime/ops_cpu.py +0 -45
  71. tinygrad/runtime/ops_hip.py +0 -97
  72. tinygrad/runtime/ops_torch.py +0 -49
  73. tinygrad-0.8.0.dist-info/RECORD +0 -41
  74. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
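
The list above reflects a reorganization of the package: the JIT, scheduler, graph, realize, and search code moved under tinygrad/engine/, features/multi.py became tinygrad/multi.py, and mlops.py was renamed to function.py. A minimal sketch of version-tolerant imports for downstream code, assuming tinygrad is installed; the symbol names (TinyJit, beam_search) are assumptions not shown in this listing, only the module paths are:

    # Version-tolerant imports across the 0.8.0 -> 0.9.1 module moves shown above.
    # TinyJit and beam_search are assumed symbol names, not taken from this diff.
    try:
        from tinygrad.engine.jit import TinyJit          # 0.9.1 layout
    except ImportError:
        from tinygrad.jit import TinyJit                 # 0.8.0 layout

    try:
        from tinygrad.engine.search import beam_search   # 0.9.1 layout
    except ImportError:
        from tinygrad.features.search import beam_search # 0.8.0 layout
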
@@ -1,107 +1,99 @@
1
1
  from __future__ import annotations
2
- import os, math, itertools
3
- from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
4
- from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps
5
- from tinygrad.device import Device, Compiled
2
+ from collections import defaultdict
3
+ import itertools
4
+ from typing import DefaultDict, Optional, List, Tuple, cast, Dict, Union
5
+ from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, UNSAFE_PAD_OPS, verify_lazyop
6
+ from tinygrad.device import Device
7
+ from tinygrad.renderer import Renderer, TensorCore
6
8
  from tinygrad.dtype import dtypes, ImageDType, DType
7
- from tinygrad.helpers import dedup, colored, ansilen, getenv, prod, DEBUG, round_up, all_int
8
- from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
9
+ from tinygrad.helpers import all_same, colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction
10
+ from tinygrad.shape.shapetracker import ShapeTracker
9
11
  from tinygrad.shape.symbolic import sint
10
12
  from tinygrad.shape.view import View, strides_for_shape
11
13
  from dataclasses import dataclass
12
14
  from enum import Enum, auto
13
15
 
14
16
  class OptOps(Enum):
15
- UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto() # noqa: E702
17
+ TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
16
18
  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto() # noqa: E702
17
19
  def __lt__(self, x:OptOps): return self.value < x.value
18
20
 
21
+ class KernelOptError(Exception): pass
22
+
23
+ def check(cond:bool, msg:str=""):
24
+ if not cond: raise KernelOptError(msg)
25
+
19
26
  @dataclass(frozen=True, order=True)
20
27
  class Opt:
21
28
  op: OptOps
22
29
  axis: Optional[int] = None
23
30
  amt: Optional[int] = None
24
31
  def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})"
32
+ def real_axis(self, k:Kernel):
33
+ if self.axis is None: return -1
34
+ if self.op is OptOps.UNROLL: return k.first_reduce+self.axis
35
+ if self.op in {OptOps.GROUP, OptOps.GROUPTOP}: return k.first_reduce+k.group_for_reduces+self.axis
36
+ return self.axis
37
+
38
+ @dataclass
39
+ class TensorCoreOptions:
40
+ axes: Tuple[int, ...] # the location of the original N and M axes if still in the shape
41
+ axes_exist: Tuple[bool, ...] # true if the original N and M axes are still in the shape
42
+ axis_pads: Tuple[Tuple[int, int], ...]
43
+ def fix_axes(self, removed_axis:int): # adjust the TC axes if necesssary when a dimension is removed
44
+ axes, axes_exist = list(self.axes), list(self.axes_exist)
45
+ for tc_dim in [i for i in range(2) if axes_exist[i]]:
46
+ if removed_axis < axes[tc_dim]: axes[tc_dim] -= 1
47
+ elif removed_axis == axes[tc_dim]: axes_exist[tc_dim] = False
48
+ self.axes, self.axes_exist = tuple(axes), tuple(axes_exist)
25
49
 
26
50
  @dataclass(frozen=True)
27
- class TensorCore:
28
- device: str
29
- dims: List[int]
30
- dtype_in: DType
31
- dtype_out: DType
32
- threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
33
- upcast_dim: int # which TC dim to upcast
34
- thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
35
- thread_local_sizes: List[int] # in each thread, the number of elements stored in registers for each TC dim
36
- arch: Optional[str] = None
37
- def __str__(self): return f"tensor_core<{self.device}, {self.dims}, {self.dtype_in}, {self.dtype_out}>"
38
-
39
- tensor_cores: Dict[str, List[TensorCore]] = {
40
- "METAL": [
41
- TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501
42
- # TODO: enable half @ half -> half tensor core with correct dtypes in uop
43
- # TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501
44
- ],
45
- "HIP": [
46
- TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
47
- TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
48
- ]
49
- }
50
-
51
- class LocalBuffer(NamedTuple):
51
+ class LocalBuffer:
52
52
  name: str
53
53
  size: int
54
54
  dtype: DType = dtypes.float32
55
55
  realized: None = None
56
56
  def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"
57
57
 
58
- class LinearizerOptions(NamedTuple):
59
- device: str = ""
60
- # TODO: make this generic with a list of supported types
61
- supports_float4: bool = True
62
- supports_float4_alu: bool = True
63
- has_local: bool = True
64
- has_shared: bool = True
65
- # NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
66
- global_max: Optional[List[int]] = None
67
- local_max: Optional[List[int]] = None
68
-
69
58
  class Kernel:
70
- def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None):
71
- self.opts = opts or (device.linearizer_opts if isinstance(device:=Device[Device.DEFAULT], Compiled) else LinearizerOptions())
59
+ def __init__(self, *ast:LazyOp, opts:Optional[Renderer]=None):
60
+ self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
61
+ verify_lazyop(*ast)
72
62
  self.ast = ast
73
- assert ast.op == BufferOps.STORE, f"kernels must have a store as the output, got {ast.op}"
63
+ self.lazyops = flatten([op.lazyops for op in self.ast])
74
64
 
75
- # fetch lazyop info
76
- self.info: FlopCounter = get_lazyop_info(self.ast)
65
+ cached_ordered_lazyops: Dict[LazyOp, List[LazyOp]] = {}
66
+ def ordered_lazyops(op):
67
+ if op not in cached_ordered_lazyops: cached_ordered_lazyops[op] = dedup([item for x in op.src for item in ordered_lazyops(x)] + [op])
68
+ return cached_ordered_lazyops[op]
69
+ self.reduceops = dedup([x for out in self.ast for x in ordered_lazyops(out) if x.op in ReduceOps])
77
70
 
78
- # there's only allowed to be one reduceop
79
- reduceops = [x for x in self.ast.lazyops if x.op in ReduceOps]
80
- assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast"
81
- self.reduceop = reduceops[0] if reduceops else None
71
+ self.outbufs, self.vars = [x.arg for x in self.ast], flatten([x.vars() for x in self.ast])
72
+ loadops = [BufferOps.LOAD, BufferOps.CONST]
73
+ self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops])
82
74
 
83
- self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = dedup([x.arg for x in self.ast.lazyops if x.op in BufferOps])
84
- assert isinstance(self.bufs[0], MemBuffer) and self.bufs[0].idx == 0, f"buffer 0 is not the store buffer {self.bufs[0]}"
85
-
86
- # get earlybufs, before the one reduce op
87
- self.earlybufs = [x.arg for x in self.reduceop.lazyops if x.op in BufferOps] if self.reduceop else []
75
+ # get earlybufs, before any reduceops
76
+ self.earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps]
88
77
  self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0
89
78
 
90
79
  # create new shapetrackers inside this kernel, we will permute them
91
80
  self.sts: List[ShapeTracker] = [x.st for x in cast(List[Union[MemBuffer, ConstBuffer]], self.bufs)]
92
81
 
93
82
  # move all reduce axes to the end
94
- reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape)))
83
+ reduce = list(enumerate(zip(self.full_shape, self.output_shape)))
95
84
  permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
96
85
  self.reshape_and_permute(None, permute)
97
86
 
98
87
  # parameters for optimization
99
88
  self.applied_opts: List[Opt] = []
100
- self.group_for_reduce: List[int] = []
89
+ self.group_for_reduces: int = 0
101
90
  self.upcasted: int = 0
102
91
  self.local_dims: int = 0
103
- self.local_alias: Dict[int, LocalBuffer] = {}
92
+ self.local_alias: DefaultDict[LazyOp, Dict[int, LocalBuffer]] = defaultdict(dict)
104
93
  self.tensor_core: Optional[TensorCore] = None
94
+ self.tensor_core_opts: Optional[TensorCoreOptions] = None
95
+ # the local aliased buffers for A and B
96
+ self.bufs_for_tensor_core: Dict[LazyOp, Tuple[int, int]] = {}
105
97
  self.dont_use_locals: bool = False
106
98
 
107
99
  # group simplifies
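
The hunk above replaces hard asserts with a KernelOptError exception and a check() helper, so a rejected optimization can be caught instead of aborting the process. A self-contained sketch of that pattern; check and KernelOptError mirror the definitions above, while apply_upcast is a hypothetical stand-in for Kernel.apply_opt:

    # Illustration of the check/KernelOptError pattern introduced above.
    class KernelOptError(Exception): pass

    def check(cond: bool, msg: str = ""):
        if not cond: raise KernelOptError(msg)

    def apply_upcast(shape: tuple, axis: int, amt: int) -> tuple:
        # same style of guard as Kernel.apply_opt: reject instead of asserting
        check(axis < len(shape), "invalid axis")
        check(shape[axis] % amt == 0, "no longer valid shift")
        return shape[:axis] + (shape[axis] // amt, amt) + shape[axis+1:]

    try:
        apply_upcast((8, 9), axis=1, amt=4)
    except KernelOptError as e:
        print("rejected:", e)   # rejected: no longer valid shift
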
@@ -115,16 +107,18 @@ class Kernel:
115
107
  ret = type(self).__new__(type(self))
116
108
 
117
109
  # base linearizer params
118
- ret.opts, ret.ast = self.opts, self.ast
110
+ ret.opts, ret.ast, ret.lazyops = self.opts, self.ast, self.lazyops
119
111
 
120
112
  # things downstream of the AST
121
- # NOTE: we copy bufs for local buffers and sts for optimizations
122
- ret.info, ret.reduceop, ret.bufs, ret.earlybufs, ret.full_buf_index, ret.sts = \
123
- self.info, self.reduceop, self.bufs[:], self.earlybufs, self.full_buf_index, self.sts[:]
113
+ ret.reduceops, ret.outbufs, ret.vars, ret.bufs, ret.earlybufs, ret.full_buf_index = \
114
+ self.reduceops, self.outbufs, self.vars, [x for x in self.bufs if not isinstance(x, LocalBuffer)], self.earlybufs, self.full_buf_index
115
+ ret.sts = self.sts[:len(ret.bufs)] # NOTE: must redo the local buffers with TC in beam
124
116
 
125
117
  # parameters for optimizations
126
- ret.applied_opts, ret.group_for_reduce, ret.upcasted, ret.local_dims, ret.local_alias, ret.tensor_core, ret.dont_use_locals = \
127
- self.applied_opts[:], self.group_for_reduce[:], self.upcasted, self.local_dims, self.local_alias.copy(), self.tensor_core, self.dont_use_locals
118
+ ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
119
+ self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
120
+ ret.tensor_core, ret.tensor_core_opts, ret.local_alias, ret.bufs_for_tensor_core = self.tensor_core, self.tensor_core_opts, defaultdict(dict), \
121
+ self.bufs_for_tensor_core
128
122
 
129
123
  # uncached since linearize didn't run
130
124
  ret.applied_opts_cache = None
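
This clone path exists so a search (e.g. BEAM) can copy a kernel, try an optimization on the copy, and discard it if the opt is rejected. A hedged sketch of that flow, assuming tinygrad 0.9.1 is installed; the method name copy() is an assumption, since only its body appears in this hunk:

    from typing import Optional
    from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError

    def try_opt(k: Kernel, opt: Opt) -> Optional[Kernel]:
        k2 = k.copy()              # assumed name of the method whose body is shown above
        try:
            k2.apply_opt(opt)
            return k2              # a candidate worth timing
        except KernelOptError:
            return None            # invalid on this kernel, drop it

    # e.g. try_opt(k, Opt(OptOps.UPCAST, 0, 4)) on some existing Kernel k
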
@@ -138,9 +132,10 @@ class Kernel:
138
132
  def shape_offsets(self, i:int): return itertools.product(*[list(range(cast(int, s))) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()] # noqa: E501
139
133
  def float4_axis(self, i:int): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0] # noqa: E501
140
134
 
141
- def upcasted_axis(self, i:int):
142
- return list(zip(self.sts[i].shape[self.shape_len-self.upcasted:],
143
- self.sts[i].real_strides()[self.shape_len-self.upcasted:],
135
+ def upcasted_axis(self, i:int) -> List[Tuple[int, Optional[sint], bool]]:
136
+ upcasted_shape, upcasted_stride = self.sts[i].shape[self.shape_len-self.upcasted:], self.sts[i].real_strides()[self.shape_len-self.upcasted:]
137
+ assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
138
+ return list(zip(upcasted_shape, upcasted_stride,
144
139
  [x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))
145
140
 
146
141
  # TODO: is there a better way to write this?
@@ -158,6 +153,9 @@ class Kernel:
158
153
  def first_reduce(self) -> int:
159
154
  return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True) # noqa: E501
160
155
 
156
+ @property
157
+ def reduceop(self) -> Optional[LazyOp]: return self.reduceops[0] if len(self.reduceops) > 0 else None
158
+
161
159
  @property
162
160
  def output_shape(self) -> Tuple[sint, ...]: return self.sts[0].shape
163
161
 
@@ -172,7 +170,7 @@ class Kernel:
172
170
 
173
171
  @property
174
172
  def upcast_in_mid_reduce_axes(self) -> List[int]:
175
- return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]]
173
+ return [j for j in range(self.first_reduce, self.first_reduce+self.group_for_reduces) if self.full_shape[j] == self.sts[0].shape[j]]
176
174
 
177
175
  @property
178
176
  def global_dims(self) -> int: return self.first_reduce-self.local_dims
@@ -192,10 +190,10 @@ class Kernel:
192
190
  colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims
193
191
  # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
194
192
  colors += ["cyan"] * self.local_dims
195
- # between first_reduce and first_reduce + group_for_reduce, they are either upcast mid reduce (white), or late upcasted (green)
196
- colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + len(self.group_for_reduce))] # noqa: E501
197
- # between first_reduce + group_for_reduce and upcasted, they are reduce (red)
198
- colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + len(self.group_for_reduce)))
193
+ # between first_reduce and first_reduce + group_for_reduces, they are either upcast mid reduce (white), or late upcasted (green)
194
+ colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + self.group_for_reduces)] # noqa: E501
195
+ # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
196
+ colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + self.group_for_reduces))
199
197
  # upcasted dimensions are reduce (magenta) or normal (yellow)
200
198
  colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.shape_len-self.upcasted, self.shape_len)]
201
199
  assert len(colors) == self.shape_len, "colors size mismatch"
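
The colors() hunk above documents how shape axes are banded: globals (blue), locals (cyan), grouped reduces (white/green), plain reduces (red), then upcasted axes (magenta/yellow). A simplified restatement as a standalone function; it takes the band sizes directly and omits the white upcast-mid-reduce case and the per-axis magenta/yellow interleaving of the real code:

    from typing import List

    def shape_bands(global_dims: int, local_dims: int, group_for_reduces: int,
                    reduce_dims: int, upcast_reduce: int, upcast_normal: int) -> List[str]:
        return (["blue"]    * global_dims +        # globals (BLUE when locals are disabled)
                ["cyan"]    * local_dims +         # locals / TC warp threads
                ["green"]   * group_for_reduces +  # grouped (late upcasted) reduces
                ["red"]     * reduce_dims +        # plain reduce axes
                ["magenta"] * upcast_reduce +      # upcasted reduce axes
                ["yellow"]  * upcast_normal)       # upcasted normal axes

    print(shape_bands(2, 1, 1, 1, 1, 1))
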
@@ -219,7 +217,7 @@ class Kernel:
219
217
 
220
218
  # drops the final dimension
221
219
  def upcast(self):
222
- assert self.full_shape[-1] != 1, "can't upcast a dimension with size 1"
220
+ check(self.full_shape[-1] != 1, "can't upcast a dimension with size 1")
223
221
  self.upcasted += 1
224
222
 
225
223
  # axis : the axis to pull from
@@ -242,7 +240,7 @@ class Kernel:
242
240
  if self.shape_len == 0: return False
243
241
  all_ones = [s==1 for s in self.full_shape]
244
242
  self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
245
- self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:])
243
+ self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:]) # TODO: no necessary since upcasted axis can't be un-upcasted
246
244
  self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
247
245
  return any(all_ones)
248
246
 
@@ -254,7 +252,7 @@ class Kernel:
254
252
  if isinstance(self.bufs[0].dtype, ImageDType):
255
253
  base_shape = self.bufs[0].dtype.shape
256
254
  if shape_idx_groups := get_contraction(self.output_shape, base_shape):
257
- special_strides: Tuple[int, ...] = tuple()
255
+ special_strides: Tuple[sint, ...] = tuple()
258
256
  for i,g in enumerate(shape_idx_groups):
259
257
  shape_piece = tuple(self.output_shape[x] for x in g)
260
258
  assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
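
get_contraction, now imported from tinygrad.helpers per the first hunk, is what makes the assert above hold: it groups axes of the larger shape so that each group multiplies out to one axis of the smaller shape. A small hedged example; the expected grouping is inferred from that assert, not taken from this diff:

    from math import prod
    from tinygrad.helpers import get_contraction

    groups = get_contraction((2, 3, 4), (6, 4))   # expected: [[0, 1], [2]]
    if groups is not None:
        for i, g in enumerate(groups):
            # the property the kernel code relies on
            assert prod((2, 3, 4)[x] for x in g) == (6, 4)[i]
    print(groups)
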
@@ -263,57 +261,29 @@ class Kernel:
263
261
  shapes.append(self.output_shape)
264
262
  strides.append(special_strides)
265
263
 
266
- # merge dimensions if we can, multi get_shape_strides
264
+ # merge dimensions if we can, multi _merge_dims
267
265
  # NOTE: this does not always preserve the reduce dimension
268
266
  # TODO: move this into shapetracker, with tests!
269
- rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))]
267
+ # TODO: how does this work with multi-reduce?
268
+ rets = [[(s[0], st[0])] for s,st in zip(shapes, strides)]
270
269
  for i in range(1, len(shapes[0])):
271
270
  can_merge = []
272
- for j in range(len(shapes)):
271
+ for s,st,ret in zip(shapes, strides, rets):
273
272
  # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
274
- can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0))) # noqa: E501
273
+ si, sti, last_st = s[i], st[i], ret[-1][1]
274
+ can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0)))
275
275
  # more can merge than this
276
276
  mergeable = all(can_merge) and i != self.first_reduce
277
- for j in range(len(shapes)):
278
- if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
279
- else: rets[j].append((shapes[j][i], strides[j][i]))
277
+ for j,(s,st) in enumerate(zip(shapes, strides)):
278
+ if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i])
279
+ else: rets[j].append((s[i], st[i]))
280
280
 
281
281
  # do the reshapes
282
282
  for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
283
283
 
284
- # ******************** GPU simplifiers ********************
284
+ # ******************** helpers ********************
285
285
 
286
- def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
287
- new_shape,dims = list(x), len(x)
288
- for i in range(dims):
289
- next_idx = (i + 1) % dims
290
- while new_shape[i] > max_size[i]:
291
- new_shape[i] = new_shape[i] // 2
292
- if (new_shape[next_idx] <= max_size[next_idx]):
293
- new_shape[next_idx] = new_shape[next_idx] * 2
294
- else:
295
- next_idx = (next_idx + 1) % dims
296
- new_shape[next_idx] = new_shape[next_idx] * 2
297
- return tuple(new_shape)
298
-
299
- def limit_dims_to_max(self, global_max: List[int], local_max: List[int]):
300
- # Check the global allocation limit, current the global_size will be flipped during codegen
301
- # and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
302
- global_dims = self.first_reduce-self.local_dims
303
- if global_dims > 0:
304
- if global_max:
305
- tmp = global_max[:global_dims] + (local_max[:self.local_dims] if local_max else [])
306
- if max(global_max) < max(self.full_shape[:global_dims]):
307
- self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None)
308
- assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}" # noqa: E501
309
- for i in range(global_dims-1):
310
- if i < len(global_max) and self.full_shape[i] > global_max[i]:
311
- order = list(range(len(self.full_shape)))
312
- order[i], order[global_dims-1] = order[global_dims-1], order[i]
313
- self.reshape_and_permute(None, order)
314
- if DEBUG >= 3: print("permuted global dim", order, "due to allocation exceeds global limit")
315
-
316
- def alias_buffer(self, i, pattern):
286
+ def alias_buffer(self, op:LazyOp, i:int, pattern:List[int]) -> None:
317
287
  assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}"
318
288
 
319
289
  bst = 1
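
The rewritten loop above merges adjacent dimensions when the outer stride equals inner size times inner stride (contiguous) or when both strides are zero (broadcast); the real code additionally refuses to merge across first_reduce or when any buffer disagrees. A single-buffer sketch of just the merge rule:

    from typing import List, Optional, Tuple

    def merge_dims(shape: Tuple[int, ...], strides: Tuple[Optional[int], ...]) -> List[Tuple[int, Optional[int]]]:
        ret = [(shape[0], strides[0])]
        for s, st in zip(shape[1:], strides[1:]):
            last_s, last_st = ret[-1]
            if st is not None and ((st != 0 and last_st == s * st) or (st == 0 and last_st == 0)):
                ret[-1] = (last_s * s, st)      # fold into the previous dim
            else:
                ret.append((s, st))
        return ret

    print(merge_dims((2, 4, 5), (20, 5, 1)))   # -> [(40, 1)]  a contiguous view folds to one dim
    print(merge_dims((2, 3, 4), (0, 0, 1)))    # -> [(6, 0), (4, 1)]  broadcast dims merge together
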
@@ -328,138 +298,194 @@ class Kernel:
328
298
  self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
329
299
  self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size))
330
300
  if DEBUG >= 4: print("aliasing buffer", self.sts[i])
331
- self.local_alias[i] = cast(LocalBuffer, self.bufs[-1])
301
+ self.local_alias[op][i] = cast(LocalBuffer, self.bufs[-1])
332
302
 
333
303
  # ******************** high level optimizers ********************
334
304
 
335
- def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None) -> bool:
336
- if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op == ReduceOps.SUM and self.opts.device in tensor_cores:
337
- for tc in tensor_cores[self.opts.device]:
338
- if not (use_tensor_cores==2 or (tc.arch is None or tc.arch == os.uname().machine)): continue
339
- has_cast = tc.dtype_in != tc.dtype_out
340
-
341
- if has_cast and not(self.reduceop.src[0].op == UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue
342
- mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
343
-
344
- if mul_op.op != BinaryOps.MUL: continue
345
- if not (mul_op.src[0].op == BufferOps.LOAD and mul_op.src[0].arg.dtype == tc.dtype_in): continue
346
- if not (mul_op.src[1].op == BufferOps.LOAD and mul_op.src[1].arg.dtype == tc.dtype_in): continue
347
- buf0, buf1 = self.bufs.index(cast(MemBuffer, mul_op.src[0].arg)), self.bufs.index(cast(MemBuffer, mul_op.src[1].arg))
348
- buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
349
- axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0] # noqa: E501
350
- axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0] # noqa: E501
351
-
352
- if not(axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%tc.dims[2] == 0 and self.full_shape[self.first_reduce] >= tc.dims[2] and (self.shape_len-self.first_reduce) == 1): continue # noqa: E501
353
-
354
- if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
355
-
356
- s0, s1 = axis_buf0[-1][0], axis_buf1[-1][0] # TODO: select axis in smart way
357
- s0_exists, s1_exists = True, True
358
- assert s0 != s1 and self.full_shape[s0]%tc.dims[0] == 0 and self.full_shape[s1]%tc.dims[1] == 0
359
- def fix(needed, ax):
360
- nonlocal s0, s1, s0_exists, s1_exists
361
- if not needed: return
362
- if s0_exists and ax == s0:
363
- if s1_exists and s0 < s1: s1 -= 1
364
- s0_exists = False
365
- elif s1_exists and ax == s1:
366
- if s0_exists and s1 < s0: s0 -= 1
367
- s1_exists = False
368
-
305
+ def _create_tc_opts(self, reduceop:LazyOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
306
+ has_cast = tc.dtype_in != tc.dtype_out
307
+ if has_cast and not(reduceop.src[0].op is UnaryOps.CAST and reduceop.src[0].arg == tc.dtype_out): return None
308
+
309
+ mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
310
+ if mul_op.op is not BinaryOps.MUL: return None
311
+
312
+ def buf_index(src: LazyOp) -> Optional[int]:
313
+ # TODO: apply tc even if the sources are not from LOAD
314
+ if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
315
+ try:
316
+ if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
317
+ except ValueError: return None
318
+ return None
319
+ if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None
320
+
321
+ buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
322
+ axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
323
+ axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
324
+ if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None
325
+
326
+ axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
327
+ if not(axis < len(axis_choices)): return None
328
+
329
+ s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
330
+ axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0)
331
+ if axis_pads and (opt_level < 2): return None
332
+ self.bufs_for_tensor_core[reduceop] = (buf0, buf1)
333
+ if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
334
+ return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)
335
+
336
+ def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
337
+ if use_tensor_cores and self.opts.has_local and self.reduceop is not None and self.reduceop.op is ReduceOps.SUM:
338
+ for tc in self.opts.tensor_cores:
339
+ tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
340
+ # can only fuse reduces with the same tc options
341
+ assert all_same(tensor_core_opts)
342
+ if tensor_core_opts[0] is None: continue
369
343
  # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
370
- self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]))
371
- self.apply_opt(Opt(OptOps.UPCAST, s0 if tc.upcast_dim == 0 else s1, (tc.dims[0]*tc.dims[2])//prod([a[1] for a in tc.threads])))
344
+ self.tensor_core_opts = tc_opts = tensor_core_opts[0]
345
+
346
+ # attempt to pad the tensor axes that require it
347
+ try:
348
+ for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
349
+ except KernelOptError: continue
350
+ self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, tc.dims[2]), append_opt=False)
351
+ for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
352
+ if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
372
353
  for (tc_dim, tc_amt) in tc.threads:
373
- fix(self.apply_opt(Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
354
+ self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
374
355
 
375
- # assert tensor core and prevent extra_opts from altering the key shape structure
356
+ # assert tensor core
376
357
  if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
358
+ return True
359
+ return False
377
360
 
361
+ def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, axis:int=0, tc_opt:Optional[int]=None) -> bool:
362
+ """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
363
+ Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
364
+
365
+ Keyword arguments:
366
+ use_tensor_cores -- controls how tensor cores are applied (default 1)
367
+ 0: will disable any tensor core matching
368
+ 1: enable tensor cores
369
+ 2: apply tensor core shape but don't use UOp.WMMA
370
+ extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
371
+ tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
372
+ 0: applies to only kernels with a single reduce axis and direct BufferOps.LOAD into BinaryOps.MUL
373
+ 1: allows kernels with multiple reduce axes and also multiplication of UnaryOps.CAST'd buffers
374
+ 2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
375
+ """
376
+ if tc_opt is None: tc_opt = self.opts.tc_opt
377
+ if not self.opts.tensor_cores and use_tensor_cores != 2: return False
378
+ try: # check TC first and apply hand-coded opts if successful
379
+ self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
380
+
381
+ if (tc_opts:=self.tensor_core_opts) is not None:
378
382
  if extra_opts is not None:
379
383
  for opt in extra_opts: self.apply_opt(opt)
380
384
  else:
381
385
  # hand-coded TC opts
382
- if s1_exists:
383
- s1_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s1]%upc == 0][0]
384
- if s1_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s1, s1_div)), s1)
385
- if s0_exists:
386
- s0_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s0]%upc == 0][0]
387
- if s0_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s0, s0_div)), s0)
388
- if self.tensor_core and s0_exists:
386
+ def late_upcast_tc(tc_dim: int):
387
+ if tc_opts.axes_exist[tc_dim]:
388
+ ax_div = [upc for upc in [5,4,3,2,1] if self.full_shape[tc_opts.axes[tc_dim]]%upc == 0][0]
389
+ if ax_div != 1: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[tc_dim], ax_div))
390
+ late_upcast_tc(1) # attempt to upcast M
391
+ late_upcast_tc(0) # attempt to upcast N
392
+
393
+ if self.tensor_core and tc_opts.axes_exist[0]: # attempt to local N
389
394
  for upc in [4,2]:
390
- if self.full_shape[s0] % upc == 0:
391
- self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc))
395
+ if self.full_shape[tc_opts.axes[0]] % upc == 0:
396
+ self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], upc))
392
397
  break
393
398
 
394
- # alias buffer
395
- alias_pattern = [0]*(self.global_dims+(self.local_dims-len(tc.threads))) + [2]*(len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2) # noqa: E501
396
- self.alias_buffer(buf0, alias_pattern)
397
- self.alias_buffer(buf1, alias_pattern)
398
- return True
399
- return False
399
+ return True
400
+ except KernelOptError:
401
+ return False
402
+
403
+ def apply_opt(self, opt:Opt, append_opt:bool=True):
404
+ check(not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals")
405
+
406
+ if opt.op is OptOps.TC:
407
+ check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
408
+ check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt")
409
+ check((use_tensor_cores:=self.opts.tc) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
410
+ check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.amt)), "no tensor core available")
411
+ self.applied_opts.append(opt)
412
+ return
413
+
414
+ axis = opt.real_axis(self)
415
+ check(axis < len(self.full_shape), "invalid axis")
400
416
 
401
- def apply_opt(self, opt:Opt):
402
- assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals" # noqa: E501
403
- self.applied_opts.append(opt)
404
- if opt.axis is not None:
405
- axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op in [OptOps.GROUP, OptOps.GROUPTOP] else 0)) # noqa: E501
406
- else:
407
- axis = -1
408
417
  if opt.amt is not None:
409
418
  amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
410
- assert isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless"
411
- if opt.op != OptOps.PADTO: assert self.full_shape[axis] % amt == 0, "no longer valid shift"
412
- else:
413
- amt = -1
414
- if opt.op in [OptOps.LOCAL, OptOps.LASTLOCAL]: # cyan
415
- assert self.opts.has_local, "target does not support local"
416
- assert axis < self.first_reduce, "can't local a reduce"
417
- if opt.op == OptOps.LOCAL:
418
- assert not self.tensor_core, "can't local with tensor cores"
419
- self.shift_to(axis, amt, insert_before=self.first_reduce)
420
- else:
421
- self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims)
419
+ check(isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless")
420
+ if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
421
+ else: amt = -1
422
+
423
+ if self.reduceop and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
424
+ acc_sz, upcast_idx = dt.base.itemsize if isinstance((dt:=self.reduceop.dtype), ImageDType) else dt.itemsize, self.shape_len-self.upcasted
425
+ upcast_sz = prod([a for a,b in zip(self.full_shape[upcast_idx:], self.sts[0].shape[upcast_idx:]) if a == b])
426
+ local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
427
+ smem_sz = amt*acc_sz*upcast_sz*local_sz
428
+ check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
429
+
430
+ if opt.op is OptOps.LOCAL: # cyan
431
+ check(self.opts.has_local, "target does not support local")
432
+ check(axis < self.global_dims, "local is for globals")
433
+ self.shift_to(axis, amt, insert_before=self.first_reduce)
422
434
  self.local_dims += 1
423
- elif opt.op in [OptOps.GROUP, OptOps.GROUPTOP]: # green
424
- assert self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem"
425
- assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
426
- assert not self.tensor_core, "can't group with tensor cores"
427
- self.shift_to(axis, amt, top=(opt.op==OptOps.GROUPTOP), insert_before=self.first_reduce + len(self.group_for_reduce))
428
- self.group_for_reduce.append(amt)
429
- elif opt.op == OptOps.UNROLL: # purple
430
- assert axis < self.shape_len-self.upcasted, "can't upcasted already upcasted"
431
- assert amt <= 32, "don't unroll more than 32"
435
+ elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green
436
+ check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
437
+ check(axis >= self.first_reduce + self.group_for_reduces and axis < self.shape_len-self.upcasted, "must be reduce axis to group")
438
+ check(not self.tensor_core, "can't group with tensor cores")
439
+ self.shift_to(axis, amt, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
440
+ self.group_for_reduces += 1
441
+ elif opt.op is OptOps.UNROLL: # purple
442
+ check(axis < self.shape_len-self.upcasted, "can't upcasted already upcasted")
443
+ check(amt <= 32, "don't unroll more than 32")
444
+ # TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores
445
+ #upcast_count = sum(x == y for x,y in zip(self.full_shape[-self.upcasted:], self.output_shape[-self.upcasted:])) if self.upcasted else 0
446
+ #self.shift_to(axis, amt, insert_before=None if upcast_count == 0 else self.shape_len-upcast_count)
447
+ if self.full_shape[axis] == amt and axis == self.first_reduce: self.local_dims += 1 # first_reduce will ++, so offset loss in simplify_ones
448
+ if self.full_shape[axis] == amt and axis < self.first_reduce+self.group_for_reduces: self.group_for_reduces -= 1 # fully unrolling a GROUP
432
449
  self.shift_to(axis, amt, insert_before=None)
433
450
  self.upcast()
434
- elif opt.op == OptOps.UPCAST: # yellow
435
- assert axis < self.first_reduce, "upcast is for non-reduce"
436
- assert amt <= 8, "don't upcast more than 8"
451
+ elif opt.op is OptOps.UPCAST: # yellow
452
+ check(axis < self.first_reduce, "upcast is for non-reduce")
453
+ check(not(self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
454
+ check(amt <= 8, "don't upcast more than 8")
437
455
  self.shift_to(axis, amt, insert_before=None)
438
456
  self.upcast()
439
- elif opt.op == OptOps.UPCASTMID: # white
440
- assert self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce" # noqa: E501
457
+ elif opt.op is OptOps.UPCASTMID: # white
458
+ check(self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
441
459
  axes = self.sts[0].unit_stride_axes()
442
- assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
443
- assert axes[0] == axis, "wrong axis"
444
- assert amt == 4, "don't upcast mid anything but 4"
445
- self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
446
- self.group_for_reduce.append(amt)
447
- elif opt.op == OptOps.NOLOCALS:
448
- assert self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals"
449
- assert self.local_dims == 0 and len(self.group_for_reduce) == 0, "can't have no locals with locals"
460
+ check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}")
461
+ check(axes[0] == axis, "wrong axis")
462
+ check(amt == 4, "don't upcast mid anything but 4")
463
+ self.shift_to(axis, amt, insert_before=self.first_reduce + self.group_for_reduces)
464
+ self.group_for_reduces += 1
465
+ elif opt.op is OptOps.NOLOCALS:
466
+ check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
467
+ check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
450
468
  self.dont_use_locals = True
451
- elif opt.op == OptOps.PADTO:
452
- assert not self.ast.vars(), "does not work with symbolic shape"
453
- assert axis < self.first_reduce, "cannot pad a reduce axis"
469
+ elif opt.op is OptOps.PADTO:
470
+ check(not self.vars, "does not work with symbolic shape")
471
+ check(axis < self.shape_len - self.upcasted, "cannot pad upcasted")
472
+ # ok to pad SUM if all parent ops have f(0) = 0
473
+ if self.first_reduce <= axis:
474
+ check((r:=cast(LazyOp, self.reduceop)).op is ReduceOps.SUM and \
475
+ all(op.op not in UNSAFE_PAD_OPS for ops in r.src for op in ops.lazyops), "cannot pad")
454
476
  padded = False
455
477
  for i,st in enumerate(self.sts):
456
- assert self.sts[i].shape[axis] > amt//2, "pad adds more than double the work"
457
- if (ru := round_up(self.sts[i].shape[axis], amt) - self.sts[i].shape[axis]):
478
+ if self.sts[i].shape[axis] == 1: continue # reduced
479
+ check(self.sts[i].shape[axis] > amt//4, f"pad adds more than quadruple the work {self.sts[i].shape[axis]=} > {amt//4=}")
480
+ if (ru := round_up(cast(int, self.sts[i].shape[axis]), cast(int, amt)) - self.sts[i].shape[axis]):
458
481
  # pad right seems to be faster
459
482
  self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
460
483
  padded = True
461
- assert padded, "nothing was padded"
462
- return self.simplify_ones()
484
+ check(padded, "nothing was padded")
485
+
486
+ if append_opt: self.applied_opts.append(opt)
487
+ if self.simplify_ones() and self.tensor_core_opts:
488
+ self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()
463
489
 
464
490
  def required_optimizations(self):
465
491
  if self.bufs[0].dtype.__class__ is ImageDType:
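
Per the apply_tensor_cores docstring and the new OptOps.TC branch above, tensor cores are now requested as a regular Opt that must be applied first, and any mismatch surfaces as KernelOptError. A hedged sketch of driving that from outside, assuming tinygrad 0.9.1 and an already-built Kernel `lin` for a matmul-like AST:

    from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError

    def try_tensor_cores(lin: Kernel, axis: int = 0, tc_opt: int = 2) -> bool:
        try:
            lin.apply_opt(Opt(OptOps.TC, axis, tc_opt))   # same call apply_tensor_cores makes
            return True
        except KernelOptError:
            return False   # no matching tensor core, TC not first, padding refused, ...
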
@@ -474,8 +500,8 @@ class Kernel:
474
500
  # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
475
501
  MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
476
502
  if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
477
- self.reduceop and self.reduceop.op == ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
478
- (mulop:=self.reduceop.src[0]).op == BinaryOps.MUL and mulop.src[0].op == BufferOps.LOAD and mulop.src[1].op == BufferOps.LOAD:
503
+ self.reduceop is not None and self.reduceop.op is ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
504
+ (mulop:=self.reduceop.src[0]).op is BinaryOps.MUL and mulop.src[0].op is BufferOps.LOAD and mulop.src[1].op is BufferOps.LOAD:
479
505
  st0, st1 = self.sts[self.bufs.index(mulop.src[0].arg)], self.sts[self.bufs.index(mulop.src[1].arg)]
480
506
  strides0, strides1 = st0.real_strides(), st1.real_strides()
481
507
  def has_expanded_axis(shape, strides): return any(s > 1 and st == 0 for s,st in zip(shape,strides))
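
The matvec heuristic above is tuned purely through environment variables read with getenv (defaults 4 / 8 / 4, and MV=0 disables it). A hedged example of pinning them; the values below are illustrative, and they must be set before tinygrad reads them:

    import os
    os.environ["MV_BLOCKSIZE"] = "4"
    os.environ["MV_THREADS_PER_ROW"] = "8"
    os.environ["MV_ROWS_PER_THREAD"] = "4"
    # os.environ["MV"] = "0"   # uncomment to skip the matvec specialization entirely
    import tinygrad  # noqa: E402  (set the env vars before tinygrad reads them)
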
@@ -495,11 +521,13 @@ class Kernel:
495
521
  # TODO: use 1024 if it's allowed in a smarter way
496
522
  for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
497
523
  if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
498
- self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
499
- break
524
+ try: # may fail due to excessive smem usage
525
+ self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
526
+ break
527
+ except KernelOptError: pass
500
528
 
501
529
  # are we upcasting in mid reduce? (only for images)
502
- if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
530
+ if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
503
531
  axes = self.sts[0].unit_stride_axes()
504
532
  assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
505
533
  if self.sts[0].shape[axes[0]]%4 == 0:
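
The GROUPTOP attempt above can now fail because apply_opt enforces the shared-memory budget smem_sz = amt * acc_sz * upcast_sz * local_sz <= shared_max (see the previous hunk). A worked example of that arithmetic; the 49152-byte limit is a typical GPU value, not one taken from this diff:

    amt, acc_sz, upcast_sz, local_sz = 256, 4, 4, 16    # group size, accumulator bytes, upcast product, local product
    smem_sz = amt * acc_sz * upcast_sz * local_sz        # = 65536 bytes
    shared_max = 49152                                   # illustrative device limit
    print(smem_sz, "fits" if smem_sz <= shared_max else "raises KernelOptError")  # raises
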
@@ -517,7 +545,7 @@ class Kernel:
517
545
  self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))
518
546
 
519
547
  # no more opt if we are grouping
520
- if self.group_for_reduce: return
548
+ if self.group_for_reduces: return
521
549
 
522
550
  # **** below this line need to be optional and benchmarked ****
523
551
 
@@ -574,7 +602,7 @@ class Kernel:
574
602
  # **** local groups ****
575
603
 
576
604
  if self.opts.has_local:
577
- if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduce:
605
+ if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduces:
578
606
  self.apply_opt(Opt(OptOps.NOLOCALS))
579
607
  else:
580
608
  # prioritize making expand axes local
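
NOLOCALS is another environment knob read with getenv: when it is set and no local or group dimensions were created, the kernel opts out of locals entirely via Opt(OptOps.NOLOCALS). A hedged example of both ways to request it:

    import os
    os.environ["NOLOCALS"] = "1"   # global switch, set before tinygrad reads it

    # or, per kernel (assuming a Kernel `lin` with no local/group dims yet):
    #   from tinygrad.codegen.kernel import Opt, OptOps
    #   lin.apply_opt(Opt(OptOps.NOLOCALS))
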