tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +253 -225
- tinygrad/codegen/linearizer.py +398 -436
- tinygrad/codegen/uops.py +451 -0
- tinygrad/device.py +268 -274
- tinygrad/dtype.py +56 -40
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +198 -0
- tinygrad/engine/realize.py +192 -0
- tinygrad/engine/schedule.py +370 -0
- tinygrad/engine/search.py +199 -0
- tinygrad/{mlops.py → function.py} +40 -32
- tinygrad/helpers.py +144 -46
- tinygrad/lazy.py +143 -242
- tinygrad/multi.py +173 -0
- tinygrad/nn/__init__.py +180 -9
- tinygrad/nn/datasets.py +8 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +87 -19
- tinygrad/ops.py +104 -45
- tinygrad/renderer/__init__.py +65 -0
- tinygrad/renderer/assembly.py +269 -0
- tinygrad/renderer/cstyle.py +308 -210
- tinygrad/renderer/llvmir.py +119 -124
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +13403 -0
- tinygrad/runtime/autogen/comgr.py +891 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5893 -0
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33597 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +56 -0
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +39 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +187 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +550 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +129 -37
- tinygrad/runtime/ops_disk.py +111 -43
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +41 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +625 -0
- tinygrad/runtime/ops_python.py +208 -0
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +46 -107
- tinygrad/shape/symbolic.py +99 -98
- tinygrad/shape/view.py +162 -45
- tinygrad/tensor.py +2492 -483
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
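The listing above is largely a module reorganization: the JIT, graph, realize, schedule, and search code moved under `tinygrad/engine/`, `mlops.py` was renamed to `function.py`, `features/multi.py` became `tinygrad/multi.py`, and the `ops_cpu`/`ops_torch`/`ops_hip` runtimes were removed while new AMD/NV runtimes and autogenerated bindings were added. As a hedged sketch (the exact public re-exports should be verified against the 0.9.1 wheel), code that imported from the removed 0.8.0 modules would be updated roughly like this:

```python
# 0.8.0-style import (this module no longer exists in 0.9.1):
#   from tinygrad.jit import TinyJit

# 0.9.1-style imports, following the new layout in the listing above
from tinygrad import Tensor, TinyJit                 # TinyJit is assumed to still be re-exported at the package root
from tinygrad.engine.jit import TinyJit as EngineJit # or import it from its new home directly

@TinyJit
def double(x: Tensor) -> Tensor:
  return (x * 2).realize()
```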
tinygrad/codegen/kernel.py
CHANGED
@@ -1,107 +1,99 @@
 from __future__ import annotations
-
-
-from
-from tinygrad.
+from collections import defaultdict
+import itertools
+from typing import DefaultDict, Optional, List, Tuple, cast, Dict, Union
+from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, UNSAFE_PAD_OPS, verify_lazyop
+from tinygrad.device import Device
+from tinygrad.renderer import Renderer, TensorCore
 from tinygrad.dtype import dtypes, ImageDType, DType
-from tinygrad.helpers import
-from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.helpers import all_same, colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction
+from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import sint
 from tinygrad.shape.view import View, strides_for_shape
 from dataclasses import dataclass
 from enum import Enum, auto
 
 class OptOps(Enum):
-
+  TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
   GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto() # noqa: E702
   def __lt__(self, x:OptOps): return self.value < x.value
 
+class KernelOptError(Exception): pass
+
+def check(cond:bool, msg:str=""):
+  if not cond: raise KernelOptError(msg)
+
 @dataclass(frozen=True, order=True)
 class Opt:
   op: OptOps
   axis: Optional[int] = None
   amt: Optional[int] = None
   def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})"
+  def real_axis(self, k:Kernel):
+    if self.axis is None: return -1
+    if self.op is OptOps.UNROLL: return k.first_reduce+self.axis
+    if self.op in {OptOps.GROUP, OptOps.GROUPTOP}: return k.first_reduce+k.group_for_reduces+self.axis
+    return self.axis
+
+@dataclass
+class TensorCoreOptions:
+  axes: Tuple[int, ...] # the location of the original N and M axes if still in the shape
+  axes_exist: Tuple[bool, ...] # true if the original N and M axes are still in the shape
+  axis_pads: Tuple[Tuple[int, int], ...]
+  def fix_axes(self, removed_axis:int): # adjust the TC axes if necesssary when a dimension is removed
+    axes, axes_exist = list(self.axes), list(self.axes_exist)
+    for tc_dim in [i for i in range(2) if axes_exist[i]]:
+      if removed_axis < axes[tc_dim]: axes[tc_dim] -= 1
+      elif removed_axis == axes[tc_dim]: axes_exist[tc_dim] = False
+    self.axes, self.axes_exist = tuple(axes), tuple(axes_exist)
 
 @dataclass(frozen=True)
-class
-  device: str
-  dims: List[int]
-  dtype_in: DType
-  dtype_out: DType
-  threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
-  upcast_dim: int # which TC dim to upcast
-  thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
-  thread_local_sizes: List[int] # in each thread, the number of elements stored in registers for each TC dim
-  arch: Optional[str] = None
-  def __str__(self): return f"tensor_core<{self.device}, {self.dims}, {self.dtype_in}, {self.dtype_out}>"
-
-tensor_cores: Dict[str, List[TensorCore]] = {
-  "METAL": [
-    TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501
-    # TODO: enable half @ half -> half tensor core with correct dtypes in uop
-    # TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501
-  ],
-  "HIP": [
-    TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
-    TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
-  ]
-}
-
-class LocalBuffer(NamedTuple):
+class LocalBuffer:
   name: str
   size: int
   dtype: DType = dtypes.float32
   realized: None = None
   def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"
 
-class LinearizerOptions(NamedTuple):
-  device: str = ""
-  # TODO: make this generic with a list of supported types
-  supports_float4: bool = True
-  supports_float4_alu: bool = True
-  has_local: bool = True
-  has_shared: bool = True
-  # NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
-  global_max: Optional[List[int]] = None
-  local_max: Optional[List[int]] = None
-
 class Kernel:
-  def __init__(self, ast:LazyOp, opts:Optional[
-    self.opts = opts
+  def __init__(self, *ast:LazyOp, opts:Optional[Renderer]=None):
+    self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
+    verify_lazyop(*ast)
     self.ast = ast
-
+    self.lazyops = flatten([op.lazyops for op in self.ast])
 
-
-
+    cached_ordered_lazyops: Dict[LazyOp, List[LazyOp]] = {}
+    def ordered_lazyops(op):
+      if op not in cached_ordered_lazyops: cached_ordered_lazyops[op] = dedup([item for x in op.src for item in ordered_lazyops(x)] + [op])
+      return cached_ordered_lazyops[op]
+    self.reduceops = dedup([x for out in self.ast for x in ordered_lazyops(out) if x.op in ReduceOps])
 
-
-
-
-    self.reduceop = reduceops[0] if reduceops else None
+    self.outbufs, self.vars = [x.arg for x in self.ast], flatten([x.vars() for x in self.ast])
+    loadops = [BufferOps.LOAD, BufferOps.CONST]
+    self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops])
 
-
-
-
-    # get earlybufs, before the one reduce op
-    self.earlybufs = [x.arg for x in self.reduceop.lazyops if x.op in BufferOps] if self.reduceop else []
+    # get earlybufs, before any reduceops
+    self.earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps]
     self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0
 
     # create new shapetrackers inside this kernel, we will permute them
     self.sts: List[ShapeTracker] = [x.st for x in cast(List[Union[MemBuffer, ConstBuffer]], self.bufs)]
 
     # move all reduce axes to the end
-    reduce = list(enumerate(zip(self.full_shape, self.
+    reduce = list(enumerate(zip(self.full_shape, self.output_shape)))
     permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
     self.reshape_and_permute(None, permute)
 
     # parameters for optimization
     self.applied_opts: List[Opt] = []
-    self.
+    self.group_for_reduces: int = 0
     self.upcasted: int = 0
     self.local_dims: int = 0
-    self.local_alias: Dict[int, LocalBuffer] =
+    self.local_alias: DefaultDict[LazyOp, Dict[int, LocalBuffer]] = defaultdict(dict)
     self.tensor_core: Optional[TensorCore] = None
+    self.tensor_core_opts: Optional[TensorCoreOptions] = None
+    # the local aliased buffers for A and B
+    self.bufs_for_tensor_core: Dict[LazyOp, Tuple[int, int]] = {}
     self.dont_use_locals: bool = False
 
     # group simplifies
@@ -115,16 +107,18 @@ class Kernel:
     ret = type(self).__new__(type(self))
 
     # base linearizer params
-    ret.opts, ret.ast = self.opts, self.ast
+    ret.opts, ret.ast, ret.lazyops = self.opts, self.ast, self.lazyops
 
     # things downstream of the AST
-
-
-
+    ret.reduceops, ret.outbufs, ret.vars, ret.bufs, ret.earlybufs, ret.full_buf_index = \
+      self.reduceops, self.outbufs, self.vars, [x for x in self.bufs if not isinstance(x, LocalBuffer)], self.earlybufs, self.full_buf_index
+    ret.sts = self.sts[:len(ret.bufs)] # NOTE: must redo the local buffers with TC in beam
 
     # parameters for optimizations
-    ret.applied_opts, ret.
-      self.applied_opts[:], self.
+    ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
+      self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
+    ret.tensor_core, ret.tensor_core_opts, ret.local_alias, ret.bufs_for_tensor_core = self.tensor_core, self.tensor_core_opts, defaultdict(dict), \
+      self.bufs_for_tensor_core
 
     # uncached since linearize didn't run
     ret.applied_opts_cache = None
@@ -138,9 +132,10 @@ class Kernel:
   def shape_offsets(self, i:int): return itertools.product(*[list(range(cast(int, s))) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()] # noqa: E501
   def float4_axis(self, i:int): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0] # noqa: E501
 
-  def upcasted_axis(self, i:int):
-
-
+  def upcasted_axis(self, i:int) -> List[Tuple[int, Optional[sint], bool]]:
+    upcasted_shape, upcasted_stride = self.sts[i].shape[self.shape_len-self.upcasted:], self.sts[i].real_strides()[self.shape_len-self.upcasted:]
+    assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
+    return list(zip(upcasted_shape, upcasted_stride,
                     [x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))
 
   # TODO: is there a better way to write this?
@@ -158,6 +153,9 @@ class Kernel:
   def first_reduce(self) -> int:
     return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True) # noqa: E501
 
+  @property
+  def reduceop(self) -> Optional[LazyOp]: return self.reduceops[0] if len(self.reduceops) > 0 else None
+
   @property
   def output_shape(self) -> Tuple[sint, ...]: return self.sts[0].shape
 
@@ -172,7 +170,7 @@ class Kernel:
 
   @property
   def upcast_in_mid_reduce_axes(self) -> List[int]:
-    return [j for j in range(self.first_reduce, self.first_reduce+
+    return [j for j in range(self.first_reduce, self.first_reduce+self.group_for_reduces) if self.full_shape[j] == self.sts[0].shape[j]]
 
   @property
   def global_dims(self) -> int: return self.first_reduce-self.local_dims
@@ -192,10 +190,10 @@ class Kernel:
     colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims
     # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
     colors += ["cyan"] * self.local_dims
-    # between first_reduce and first_reduce +
-    colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce +
-    # between first_reduce +
-    colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce +
+    # between first_reduce and first_reduce + group_for_reduces, they are either upcast mid reduce (white), or late upcasted (green)
+    colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + self.group_for_reduces)] # noqa: E501
+    # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
+    colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + self.group_for_reduces))
     # upcasted dimensions are reduce (magenta) or normal (yellow)
     colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.shape_len-self.upcasted, self.shape_len)]
     assert len(colors) == self.shape_len, "colors size mismatch"
@@ -219,7 +217,7 @@ class Kernel:
 
   # drops the final dimension
   def upcast(self):
-
+    check(self.full_shape[-1] != 1, "can't upcast a dimension with size 1")
     self.upcasted += 1
 
   # axis : the axis to pull from
@@ -242,7 +240,7 @@ class Kernel:
     if self.shape_len == 0: return False
     all_ones = [s==1 for s in self.full_shape]
     self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
-    self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:])
+    self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:]) # TODO: no necessary since upcasted axis can't be un-upcasted
     self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
     return any(all_ones)
 
@@ -254,7 +252,7 @@ class Kernel:
     if isinstance(self.bufs[0].dtype, ImageDType):
       base_shape = self.bufs[0].dtype.shape
       if shape_idx_groups := get_contraction(self.output_shape, base_shape):
-        special_strides: Tuple[
+        special_strides: Tuple[sint, ...] = tuple()
         for i,g in enumerate(shape_idx_groups):
           shape_piece = tuple(self.output_shape[x] for x in g)
           assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
@@ -263,57 +261,29 @@ class Kernel:
     shapes.append(self.output_shape)
     strides.append(special_strides)
 
-    # merge dimensions if we can, multi
+    # merge dimensions if we can, multi _merge_dims
     # NOTE: this does not always preserve the reduce dimension
     # TODO: move this into shapetracker, with tests!
-
+    # TODO: how does this work with multi-reduce?
+    rets = [[(s[0], st[0])] for s,st in zip(shapes, strides)]
     for i in range(1, len(shapes[0])):
       can_merge = []
-      for
+      for s,st,ret in zip(shapes, strides, rets):
         # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
-
+        si, sti, last_st = s[i], st[i], ret[-1][1]
+        can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0)))
       # more can merge than this
       mergeable = all(can_merge) and i != self.first_reduce
-      for j in
-        if mergeable: rets[j][-1] = (rets[j][-1][0] *
-        else: rets[j].append((
+      for j,(s,st) in enumerate(zip(shapes, strides)):
+        if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i])
+        else: rets[j].append((s[i], st[i]))
 
     # do the reshapes
     for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
 
-  # ********************
+  # ******************** helpers ********************
 
-  def
-    new_shape,dims = list(x), len(x)
-    for i in range(dims):
-      next_idx = (i + 1) % dims
-      while new_shape[i] > max_size[i]:
-        new_shape[i] = new_shape[i] // 2
-        if (new_shape[next_idx] <= max_size[next_idx]):
-          new_shape[next_idx] = new_shape[next_idx] * 2
-        else:
-          next_idx = (next_idx + 1) % dims
-          new_shape[next_idx] = new_shape[next_idx] * 2
-    return tuple(new_shape)
-
-  def limit_dims_to_max(self, global_max: List[int], local_max: List[int]):
-    # Check the global allocation limit, current the global_size will be flipped during codegen
-    # and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
-    global_dims = self.first_reduce-self.local_dims
-    if global_dims > 0:
-      if global_max:
-        tmp = global_max[:global_dims] + (local_max[:self.local_dims] if local_max else [])
-        if max(global_max) < max(self.full_shape[:global_dims]):
-          self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None)
-        assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}" # noqa: E501
-      for i in range(global_dims-1):
-        if i < len(global_max) and self.full_shape[i] > global_max[i]:
-          order = list(range(len(self.full_shape)))
-          order[i], order[global_dims-1] = order[global_dims-1], order[i]
-          self.reshape_and_permute(None, order)
-          if DEBUG >= 3: print("permuted global dim", order, "due to allocation exceeds global limit")
-
-  def alias_buffer(self, i, pattern):
+  def alias_buffer(self, op:LazyOp, i:int, pattern:List[int]) -> None:
     assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}"
 
     bst = 1
@@ -328,138 +298,194 @@ class Kernel:
     self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
     self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size))
     if DEBUG >= 4: print("aliasing buffer", self.sts[i])
-    self.local_alias[i] = cast(LocalBuffer, self.bufs[-1])
+    self.local_alias[op][i] = cast(LocalBuffer, self.bufs[-1])
 
   # ******************** high level optimizers ********************
 
-  def
-
-
-
-
-
-
-
-
-
-
-    if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+  def _create_tc_opts(self, reduceop:LazyOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
+    has_cast = tc.dtype_in != tc.dtype_out
+    if has_cast and not(reduceop.src[0].op is UnaryOps.CAST and reduceop.src[0].arg == tc.dtype_out): return None
+
+    mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
+    if mul_op.op is not BinaryOps.MUL: return None
+
+    def buf_index(src: LazyOp) -> Optional[int]:
+      # TODO: apply tc even if the sources are not from LOAD
+      if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
+      try:
+        if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
+      except ValueError: return None
+      return None
+    if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None
+
+    buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
+    axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
+    axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
+    if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None
+
+    axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
+    if not(axis < len(axis_choices)): return None
+
+    s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
+    axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0)
+    if axis_pads and (opt_level < 2): return None
+    self.bufs_for_tensor_core[reduceop] = (buf0, buf1)
+    if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
+    return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)
+
+  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
+    if use_tensor_cores and self.opts.has_local and self.reduceop is not None and self.reduceop.op is ReduceOps.SUM:
+      for tc in self.opts.tensor_cores:
+        tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
+        # can only fuse reduces with the same tc options
+        assert all_same(tensor_core_opts)
+        if tensor_core_opts[0] is None: continue
         # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
-        self.
-
+        self.tensor_core_opts = tc_opts = tensor_core_opts[0]
+
+        # attempt to pad the tensor axes that require it
+        try:
+          for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
+        except KernelOptError: continue
+        self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, tc.dims[2]), append_opt=False)
+        for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
+          if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
         for (tc_dim, tc_amt) in tc.threads:
-
+          self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
 
-        # assert tensor core
+        # assert tensor core
         if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
+        return True
+    return False
 
+  def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, axis:int=0, tc_opt:Optional[int]=None) -> bool:
+    """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
+    Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
+
+    Keyword arguments:
+    use_tensor_cores -- controls how tensor cores are applied (default 1)
+      0: will disable any tensor core matching
+      1: enable tensor cores
+      2: apply tensor core shape but don't use UOp.WMMA
+    extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
+    tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
+      0: applies to only kernels with a single reduce axis and direct BufferOps.LOAD into BinaryOps.MUL
+      1: allows kernels with multiple reduce axes and also multiplication of UnaryOps.CAST'd buffers
+      2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
+    """
+    if tc_opt is None: tc_opt = self.opts.tc_opt
+    if not self.opts.tensor_cores and use_tensor_cores != 2: return False
+    try: # check TC first and apply hand-coded opts if successful
+      self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
+
+      if (tc_opts:=self.tensor_core_opts) is not None:
         if extra_opts is not None:
          for opt in extra_opts: self.apply_opt(opt)
        else:
          # hand-coded TC opts
-
-
-
-
-
-
-
+          def late_upcast_tc(tc_dim: int):
+            if tc_opts.axes_exist[tc_dim]:
+              ax_div = [upc for upc in [5,4,3,2,1] if self.full_shape[tc_opts.axes[tc_dim]]%upc == 0][0]
+              if ax_div != 1: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[tc_dim], ax_div))
+          late_upcast_tc(1) # attempt to upcast M
+          late_upcast_tc(0) # attempt to upcast N
+
+          if self.tensor_core and tc_opts.axes_exist[0]: # attempt to local N
            for upc in [4,2]:
-              if self.full_shape[
-                self.apply_opt(Opt(OptOps.
+              if self.full_shape[tc_opts.axes[0]] % upc == 0:
+                self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], upc))
                break
 
-
-
-
-
-
-
+      return True
+    except KernelOptError:
+      return False
+
+  def apply_opt(self, opt:Opt, append_opt:bool=True):
+    check(not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals")
+
+    if opt.op is OptOps.TC:
+      check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
+      check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt")
+      check((use_tensor_cores:=self.opts.tc) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
+      check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.amt)), "no tensor core available")
+      self.applied_opts.append(opt)
+      return
+
+    axis = opt.real_axis(self)
+    check(axis < len(self.full_shape), "invalid axis")
 
-  def apply_opt(self, opt:Opt):
-    assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals" # noqa: E501
-    self.applied_opts.append(opt)
-    if opt.axis is not None:
-      axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op in [OptOps.GROUP, OptOps.GROUPTOP] else 0)) # noqa: E501
-    else:
-      axis = -1
     if opt.amt is not None:
       amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
-
-      if opt.op
-      else:
-
-      if opt.op in
-
-
-
-
-
+      check(isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless")
+      if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
+    else: amt = -1
+
+    if self.reduceop and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
+      acc_sz, upcast_idx = dt.base.itemsize if isinstance((dt:=self.reduceop.dtype), ImageDType) else dt.itemsize, self.shape_len-self.upcasted
+      upcast_sz = prod([a for a,b in zip(self.full_shape[upcast_idx:], self.sts[0].shape[upcast_idx:]) if a == b])
+      local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
+      smem_sz = amt*acc_sz*upcast_sz*local_sz
+      check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
+
+    if opt.op is OptOps.LOCAL: # cyan
+      check(self.opts.has_local, "target does not support local")
+      check(axis < self.global_dims, "local is for globals")
+      self.shift_to(axis, amt, insert_before=self.first_reduce)
       self.local_dims += 1
-    elif opt.op in
-
-
-
-      self.shift_to(axis, amt, top=(opt.op
-      self.
-    elif opt.op
-
-
+    elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green
+      check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
+      check(axis >= self.first_reduce + self.group_for_reduces and axis < self.shape_len-self.upcasted, "must be reduce axis to group")
+      check(not self.tensor_core, "can't group with tensor cores")
+      self.shift_to(axis, amt, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
+      self.group_for_reduces += 1
+    elif opt.op is OptOps.UNROLL: # purple
+      check(axis < self.shape_len-self.upcasted, "can't upcasted already upcasted")
+      check(amt <= 32, "don't unroll more than 32")
+      # TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores
+      #upcast_count = sum(x == y for x,y in zip(self.full_shape[-self.upcasted:], self.output_shape[-self.upcasted:])) if self.upcasted else 0
+      #self.shift_to(axis, amt, insert_before=None if upcast_count == 0 else self.shape_len-upcast_count)
+      if self.full_shape[axis] == amt and axis == self.first_reduce: self.local_dims += 1 # first_reduce will ++, so offset loss in simplify_ones
+      if self.full_shape[axis] == amt and axis < self.first_reduce+self.group_for_reduces: self.group_for_reduces -= 1 # fully unrolling a GROUP
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
-    elif opt.op
-
-
+    elif opt.op is OptOps.UPCAST: # yellow
+      check(axis < self.first_reduce, "upcast is for non-reduce")
+      check(not(self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
+      check(amt <= 8, "don't upcast more than 8")
      self.shift_to(axis, amt, insert_before=None)
      self.upcast()
-    elif opt.op
-
+    elif opt.op is OptOps.UPCASTMID: # white
+      check(self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
      axes = self.sts[0].unit_stride_axes()
-
-
-
-      self.shift_to(axis, amt, insert_before=self.first_reduce +
-      self.
-    elif opt.op
-
+      check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}")
+      check(axes[0] == axis, "wrong axis")
+      check(amt == 4, "don't upcast mid anything but 4")
+      self.shift_to(axis, amt, insert_before=self.first_reduce + self.group_for_reduces)
+      self.group_for_reduces += 1
+    elif opt.op is OptOps.NOLOCALS:
+      check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
+      check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
      self.dont_use_locals = True
-    elif opt.op
-
-
+    elif opt.op is OptOps.PADTO:
+      check(not self.vars, "does not work with symbolic shape")
+      check(axis < self.shape_len - self.upcasted, "cannot pad upcasted")
+      # ok to pad SUM if all parent ops have f(0) = 0
+      if self.first_reduce <= axis:
+        check((r:=cast(LazyOp, self.reduceop)).op is ReduceOps.SUM and \
+          all(op.op not in UNSAFE_PAD_OPS for ops in r.src for op in ops.lazyops), "cannot pad")
      padded = False
      for i,st in enumerate(self.sts):
-
-
+        if self.sts[i].shape[axis] == 1: continue # reduced
+        check(self.sts[i].shape[axis] > amt//4, f"pad adds more than quadruple the work {self.sts[i].shape[axis]=} > {amt//4=}")
+        if (ru := round_up(cast(int, self.sts[i].shape[axis]), cast(int, amt)) - self.sts[i].shape[axis]):
          # pad right seems to be faster
          self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
          padded = True
-
-
+      check(padded, "nothing was padded")
+
+    if append_opt: self.applied_opts.append(opt)
+    if self.simplify_ones() and self.tensor_core_opts:
+      self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()
 
   def required_optimizations(self):
     if self.bufs[0].dtype.__class__ is ImageDType:
@@ -474,8 +500,8 @@ class Kernel:
     # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
     MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
     if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
-        self.reduceop and self.reduceop.op
-        (mulop:=self.reduceop.src[0]).op
+        self.reduceop is not None and self.reduceop.op is ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
+        (mulop:=self.reduceop.src[0]).op is BinaryOps.MUL and mulop.src[0].op is BufferOps.LOAD and mulop.src[1].op is BufferOps.LOAD:
       st0, st1 = self.sts[self.bufs.index(mulop.src[0].arg)], self.sts[self.bufs.index(mulop.src[1].arg)]
       strides0, strides1 = st0.real_strides(), st1.real_strides()
       def has_expanded_axis(shape, strides): return any(s > 1 and st == 0 for s,st in zip(shape,strides))
@@ -495,11 +521,13 @@ class Kernel:
         # TODO: use 1024 if it's allowed in a smarter way
         for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
           if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
-
-
+            try: # may fail due to excessive smem usage
+              self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
+              break
+            except KernelOptError: pass
 
       # are we upcasting in mid reduce? (only for images)
-      if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.
+      if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
         axes = self.sts[0].unit_stride_axes()
         assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
         if self.sts[0].shape[axes[0]]%4 == 0:
@@ -517,7 +545,7 @@ class Kernel:
           self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))
 
     # no more opt if we are grouping
-    if self.
+    if self.group_for_reduces: return
 
     # **** below this line need to be optional and benchmarked ****
 
@@ -574,7 +602,7 @@ class Kernel:
     # **** local groups ****
 
     if self.opts.has_local:
-      if getenv("NOLOCALS") and self.local_dims == 0 and not self.
+      if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduces:
        self.apply_opt(Opt(OptOps.NOLOCALS))
      else:
        # prioritize making expand axes local