tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
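Beyond the per-file line counts, the listing shows a reorganization: `tinygrad/lazy.py` moved to `tinygrad/engine/lazy.py`, `runtime/driver/hip_comgr.py` became `runtime/support/compiler_hip.py`, and `codegen/linearizer.py`, `codegen/uops.py`, `engine/graph.py`, `renderer/assembly.py`, and `shape/symbolic.py` were deleted, while `codegen/linearize.py`, `codegen/uopgraph.py`, `codegen/lowerer.py`, and `renderer/ptx.py` are new. Downstream code that reaches into the old module paths has to follow the moves. A minimal, hypothetical compatibility shim (it assumes only that the relocated module keeps exporting `LazyBuffer`, which the listing itself does not spell out):

```python
# Hypothetical shim for downstream code that must import LazyBuffer across both
# versions; the file listing above only shows that lazy.py moved into engine/.
try:
  from tinygrad.engine.lazy import LazyBuffer   # 0.10.0 layout
except ImportError:
  from tinygrad.lazy import LazyBuffer          # 0.9.1 layout
```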
tinygrad/codegen/kernel.py
CHANGED
```diff
@@ -1,21 +1,26 @@
 from __future__ import annotations
+import itertools, functools
+from dataclasses import dataclass
 from collections import defaultdict
-import
-from
-
+from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence
+from enum import Enum, auto
+
+from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, PatternMatcher, can_pad, print_uops, type_verify, resolve, Variable, sint, \
+  graph_rewrite, track_rewrites, UPat
 from tinygrad.device import Device
-from tinygrad.renderer import Renderer, TensorCore
-from tinygrad.dtype import
-from tinygrad.helpers import all_same, colored, ansilen, dedup,
+from tinygrad.renderer import Renderer, TensorCore, Program
+from tinygrad.dtype import ImageDType
+from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap
+from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.
-from tinygrad.
-from
-from
+from tinygrad.shape.view import strides_for_shape
+from tinygrad.codegen.linearize import linearize_uop
+from tinygrad.codegen.uopgraph import full_graph_rewrite
+from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction

 class OptOps(Enum):
   TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
-  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto() # noqa: E702
+  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
   def __lt__(self, x:OptOps): return self.value < x.value

 class KernelOptError(Exception): pass
@@ -47,41 +52,41 @@ class TensorCoreOptions:
       elif removed_axis == axes[tc_dim]: axes_exist[tc_dim] = False
     self.axes, self.axes_exist = tuple(axes), tuple(axes_exist)

-@dataclass(frozen=True)
-class LocalBuffer:
-  name: str
-  size: int
-  dtype: DType = dtypes.float32
-  realized: None = None
-  def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"
-
 class Kernel:
-  def __init__(self,
+  def __init__(self, ast:UOp, opts:Optional[Renderer]=None):
+    if ast.op is Ops.SINK: self.ast = ast
+
     self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
-
-
-
+    try: uop_sts_map = verify_ast(self.ast)
+    except AssertionError as e:
+      print("INVALID AST")
+      print(self.ast)
+      raise e

-
-    def
-
-      return cached_ordered_lazyops[op]
-    self.reduceops = dedup([x for out in self.ast for x in ordered_lazyops(out) if x.op in ReduceOps])
+    @functools.lru_cache(None)
+    def ordered_parents(op:UOp) -> List[UOp]: return dedup([item for x in op.src for item in ordered_parents(x)] + [op])
+    self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is Ops.REDUCE_AXIS])

-    self.
-
-    self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops])
+    self.vars: List[Variable] = self.ast.variables()
+    self.bufs: List[UOp] = [x for x in self.ast.parents if x.op in GroupOp.Buffer]

     # get earlybufs, before any reduceops
-
-    self.full_buf_index: int = self.bufs.index(
+    earlybufs: List[UOp] = [x for reduceop in self.reduceops for x in reduceop.parents if x.op in GroupOp.Buffer]
+    self.full_buf_index: int = self.bufs.index(earlybufs[0]) if earlybufs else 0
+    # NOTE: full_shape can be wrong if there's a tree of reduces

     # create new shapetrackers inside this kernel, we will permute them
-    self.sts: List[ShapeTracker] = [x.
+    self.sts: List[ShapeTracker] = [x.st_arg for x in self.bufs]
+
+    # add the shapetrackers for each reduce
+    # we use this to track which axes are reduced in each reduce
+    for x in self.reduceops:
+      self.sts.append(uop_sts_map[x])
+      self.sts.append(uop_sts_map[x.src[0]])

     # move all reduce axes to the end
     reduce = list(enumerate(zip(self.full_shape, self.output_shape)))
-    permute = tuple([i for i,(s,n) in reduce if s
+    permute = tuple([i for i,(s,n) in reduce if not resolve(s != n)] + [i for i,(s,n) in reduce if resolve(s != n)])
     self.reshape_and_permute(None, permute)

     # parameters for optimization
@@ -89,72 +94,57 @@ class Kernel:
     self.group_for_reduces: int = 0
     self.upcasted: int = 0
     self.local_dims: int = 0
-    self.local_alias: DefaultDict[LazyOp, Dict[int, LocalBuffer]] = defaultdict(dict)
     self.tensor_core: Optional[TensorCore] = None
     self.tensor_core_opts: Optional[TensorCoreOptions] = None
+    self.use_tensor_cores: int = 0
     # the local aliased buffers for A and B
-    self.bufs_for_tensor_core: Dict[
+    self.bufs_for_tensor_core: Dict[UOp, Tuple[int, int]] = {}
     self.dont_use_locals: bool = False

     # group simplifies
     self.simplify_ones()
     self.simplify_merge_adjacent()

-    # cache
-    self.applied_opts_cache: Optional[List[Opt]] = None
-
   def copy(self):
     ret = type(self).__new__(type(self))

     # base linearizer params
-    ret.opts, ret.ast
+    ret.opts, ret.ast = self.opts, self.ast

     # things downstream of the AST
-    ret.reduceops, ret.
-      self.reduceops, self.
-    ret.sts = self.sts[:len(ret.bufs)] # NOTE: must redo the local buffers with TC in beam
+    ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = \
+      self.reduceops, self.vars, self.bufs, self.full_buf_index
+    ret.sts = self.sts[:len(ret.bufs)+len(ret.reduceops)*2] # NOTE: must redo the local buffers with TC in beam

     # parameters for optimizations
     ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
       self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
-    ret.tensor_core, ret.tensor_core_opts, ret.
-
-
-    # uncached since linearize didn't run
-    ret.applied_opts_cache = None
+    ret.tensor_core, ret.tensor_core_opts, ret.bufs_for_tensor_core, ret.use_tensor_cores = \
+      self.tensor_core, self.tensor_core_opts, self.bufs_for_tensor_core, self.use_tensor_cores

     return ret

   @property
-  def membufs(self) -> List[
+  def membufs(self) -> List[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])

   # TODO: these need more tests or it might silently be no-op
-  def
-  def float4_axis(self, i:int): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0] # noqa: E501
+  def float4_axis(self, i:int): return [x-self.first_upcast for x in self.sts[i].unit_stride_axes() if x >= self.first_upcast and self.sts[i].shape[x]%4 == 0] # noqa: E501

   def upcasted_axis(self, i:int) -> List[Tuple[int, Optional[sint], bool]]:
-    upcasted_shape, upcasted_stride = self.sts[i].shape[self.
+    upcasted_shape, upcasted_stride = self.sts[i].shape[self.first_upcast:], self.sts[i].real_strides()[self.first_upcast:]
     assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
     return list(zip(upcasted_shape, upcasted_stride,
-                    [x!=y for x,y in zip(self.sts[0].shape[self.
-
-  # TODO: is there a better way to write this?
-  def acc_offsets(self, i:int) -> List[int]:
-    if self.upcasted == 0: return [0]
-    upcasted_i = self.upcasted_axis(i)
-    acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))]
-    return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])]
-
-  def get_float4_upcast_dim(self, i:int) -> List[int]:
-    should_upcast = self.opts.supports_float4 and (self.bufs[i].dtype in (dtypes.float, dtypes.half) or isinstance(self.bufs[i].dtype, ImageDType))
-    return [x for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1] if should_upcast else []
+                    [x!=y for x,y in zip(self.sts[0].shape[self.first_upcast:], self.full_shape[self.first_upcast:])]))

   @property
   def first_reduce(self) -> int:
-    return [x!=y for x,y in zip(self.sts[0].shape[:self.
+    return [resolve(x!=y) for x,y in zip(self.sts[0].shape[:self.first_upcast]+(0,), self.full_shape[:self.first_upcast]+(1,))].index(True)

   @property
-  def
+  def first_upcast(self) -> int: return self.shape_len-self.upcasted
+
+  @property
+  def reduceop(self) -> Optional[UOp]: return self.reduceops[0] if len(self.reduceops) > 0 else None

   @property
   def output_shape(self) -> Tuple[sint, ...]: return self.sts[0].shape
@@ -163,7 +153,7 @@ class Kernel:
   def full_shape(self) -> Tuple[sint, ...]: return self.sts[self.full_buf_index].shape

   @property
-  def full_unupcasted_shape(self) -> Tuple[sint, ...]: return self.full_shape[:self.
+  def full_unupcasted_shape(self) -> Tuple[sint, ...]: return self.full_shape[:self.first_upcast]

   @property
   def shape_len(self) -> int: return len(self.sts[0].shape)
@@ -193,27 +183,25 @@ class Kernel:
     # between first_reduce and first_reduce + group_for_reduces, they are either upcast mid reduce (white), or late upcasted (green)
     colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + self.group_for_reduces)] # noqa: E501
     # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
-    colors += ["red"] * (
+    colors += ["red"] * (self.first_upcast - (self.first_reduce + self.group_for_reduces))
     # upcasted dimensions are reduce (magenta) or normal (yellow)
-    colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.
+    colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.first_upcast, self.shape_len)]
     assert len(colors) == self.shape_len, "colors size mismatch"
     return colors

   def colored_shape(self, pad:Optional[int]=None, dense=False) -> str:
-
+    shape_strs = [(s if dense else f"{s:4d}") if isinstance(s, int) else s.render() for s in self.full_shape]
+    ret = ' '.join(colored(s, color) for s,color in zip(shape_strs, self.colors()))
     if pad: ret += ' '*(pad-ansilen(ret))
     return ret

   # ******************** base simplifiers ********************

   # apply reshape and permute to all shapetrackers
-  def reshape_and_permute(self, new_shape_fxn, axis):
-
-
-
-      if axis is not None: st = st.permute(tuple(axis))
-      new_sts.append(st)
-    self.sts = new_sts
+  def reshape_and_permute(self, new_shape_fxn:Optional[Callable[[Tuple[sint, ...]], Sequence[sint]]], axis:Optional[Sequence[int]]):
+    def reshape(st:ShapeTracker): return st.reshape(tuple(new_shape_fxn(st.shape))) if new_shape_fxn is not None else st
+    def permute(st:ShapeTracker): return st.permute(tuple(axis)) if axis is not None else st
+    self.sts = [permute(reshape(st)) for st in self.sts]

   # drops the final dimension
   def upcast(self):
@@ -229,7 +217,7 @@ class Kernel:
     move_axis = axis if top else axis+1
     if move_axis < insert_before: insert_before += 1
     self.reshape_and_permute(
-      lambda x:
+      lambda x: x[0:axis] + (((amount, x[axis]//amount) if top else (x[axis]//amount, amount)) if x[axis] > 1 else (1,1)) + x[axis+1:],
       [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis])

   # ******************** complex simplifiers ********************
@@ -240,7 +228,7 @@ class Kernel:
     if self.shape_len == 0: return False
     all_ones = [s==1 for s in self.full_shape]
     self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
-    self.upcasted -= sum(all_ones[self.
+    self.upcasted -= sum(all_ones[self.first_upcast:]) # TODO: no necessary since upcasted axis can't be un-upcasted
     self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
     return any(all_ones)

@@ -249,8 +237,8 @@ class Kernel:
     shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]

     # if it's an image, insert fake strides such that this fusion doesn't happen across image axes
-    if isinstance(self.
-      base_shape = self.
+    if isinstance(self.membufs[0].dtype, ImageDType):
+      base_shape = self.membufs[0].dtype.shape
       if shape_idx_groups := get_contraction(self.output_shape, base_shape):
         special_strides: Tuple[sint, ...] = tuple()
         for i,g in enumerate(shape_idx_groups):
@@ -281,39 +269,20 @@ class Kernel:
     # do the reshapes
     for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))

-  # ******************** helpers ********************
-
-  def alias_buffer(self, op:LazyOp, i:int, pattern:List[int]) -> None:
-    assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}"
-
-    bst = 1
-    real_strides = self.sts[i].real_strides()
-    shp, stride = [(s if p != 0 else 1) for s,p in zip(self.sts[i].shape, pattern)], [0]*len(pattern)
-    for priority in range(1, max(pattern)+1): # priority. 0 is non local and ignored
-      for j,p in enumerate(pattern):
-        if priority == p and real_strides[j] != 0:
-          stride[j] = bst
-          bst *= shp[j]
-
-    self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
-    self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size))
-    if DEBUG >= 4: print("aliasing buffer", self.sts[i])
-    self.local_alias[op][i] = cast(LocalBuffer, self.bufs[-1])
-
   # ******************** high level optimizers ********************

-  def _create_tc_opts(self, reduceop:
+  def _create_tc_opts(self, reduceop:UOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
     has_cast = tc.dtype_in != tc.dtype_out
-    if has_cast and not(reduceop.src[0].op is
+    if has_cast and not (reduceop.src[0].op is Ops.CAST and reduceop.src[0].dtype == tc.dtype_out): return None

     mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
-    if mul_op.op is not
+    if mul_op.op is not Ops.MUL: return None

-    def buf_index(src:
+    def buf_index(src:UOp) -> Optional[int]:
       # TODO: apply tc even if the sources are not from LOAD
-      if src.op is
+      if src.op is Ops.LOAD and src.dtype == tc.dtype_in: return self.bufs.index(src)
       try:
-        if opt_level >= 1 and src.op is
+        if opt_level >= 1 and src.op is Ops.CAST and src.dtype == tc.dtype_in: return self.bufs.index(src.src[0])
       except ValueError: return None
       return None
     if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None
@@ -321,40 +290,40 @@ class Kernel:
     buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
     axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
     axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
-    if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None
+    if not (axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None

     axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
-    if not(axis < len(axis_choices)): return None
+    if not (axis < len(axis_choices)): return None

     s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
-    axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0)
+    axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if resolve(self.full_shape[x]%tc.dims[i] != 0))
     if axis_pads and (opt_level < 2): return None
     self.bufs_for_tensor_core[reduceop] = (buf0, buf1)
     if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
     return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)

   def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
-    if use_tensor_cores and self.
+    if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD:
       for tc in self.opts.tensor_cores:
         tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
         # can only fuse reduces with the same tc options
         assert all_same(tensor_core_opts)
         if tensor_core_opts[0] is None: continue
-        # tensor core -- unroll the reduce dim, upcast input
+        # tensor core -- unroll the reduce dim, upcast input and local the correct thread pattern
         self.tensor_core_opts = tc_opts = tensor_core_opts[0]

         # attempt to pad the tensor axes that require it
         try:
           for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
         except KernelOptError: continue
-        self.apply_opt(Opt(OptOps.UNROLL,
-        for
-          if
-
-
-
-
-
+        for tc_dim, amt in tc.reduce_axes: self.apply_opt(Opt(OptOps.UNROLL,tc_opts.axes[2]-self.first_reduce,amt), append_opt=False)
+        for opt in tc.opts_seq:
+          if opt == "UP":
+            for tc_dim, amt in tc.early_upcast_axes: self.apply_opt(Opt(OptOps.UPCAST,tc_opts.axes[tc_dim],amt), append_opt=False)
+          elif opt == "LC":
+            for tc_dim, amt in tc.threads: self.apply_opt(Opt(OptOps.LOCAL,tc_opts.axes[tc_dim],amt), append_opt=False)
+        self.tensor_core = tc
+        self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA
         return True
     return False

@@ -369,11 +338,11 @@ class Kernel:
       2: apply tensor core shape but don't use UOp.WMMA
     extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
     tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
-      0: applies to only kernels with a single reduce axis and direct
-      1: allows kernels with multiple reduce axes and also multiplication of
+      0: applies to only kernels with a single reduce axis and direct UOps.LOAD into Ops.MUL
+      1: allows kernels with multiple reduce axes and also multiplication of UOps.CAST'd buffers
       2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
     """
-    if tc_opt is None: tc_opt =
+    if tc_opt is None: tc_opt = TC_OPT.value
     if not self.opts.tensor_cores and use_tensor_cores != 2: return False
     try: # check TC first and apply hand-coded opts if successful
       self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
@@ -382,31 +351,25 @@ class Kernel:
         if extra_opts is not None:
           for opt in extra_opts: self.apply_opt(opt)
         else:
+          if (self.opts.device == "CLANG" and AMX): return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
           # hand-coded TC opts
-
-          if tc_opts.
-
-              if ax_div != 1: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[tc_dim], ax_div))
-          late_upcast_tc(1) # attempt to upcast M
-          late_upcast_tc(0) # attempt to upcast N
-
-          if self.tensor_core and tc_opts.axes_exist[0]: # attempt to local N
-            for upc in [4,2]:
-              if self.full_shape[tc_opts.axes[0]] % upc == 0:
-                self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], upc))
-                break
+          for tc_dim in [tc_dim for tc_dim in [1,0] if tc_opts.axes_exist[tc_dim]]: # attempt to upcast M and N
+            szs = [sz for sz in [5,4,3,2] if self.full_shape[tc_opts.axes[tc_dim]] % sz == 0]
+            if szs: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[tc_dim], szs[0]))

+          if tc_opts.axes_exist[0] and (szs := [sz for sz in [4,2] if self.full_shape[tc_opts.axes[0]] % sz == 0]): # attempt to local N
+            self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], szs[0]))
       return True
     except KernelOptError:
       return False

   def apply_opt(self, opt:Opt, append_opt:bool=True):
-
+    if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals")

     if opt.op is OptOps.TC:
       check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
       check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt")
-      check((use_tensor_cores:=
+      check((use_tensor_cores:=USE_TC.value) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
       check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.amt)), "no tensor core available")
       self.applied_opts.append(opt)
       return
@@ -414,15 +377,17 @@ class Kernel:
     axis = opt.real_axis(self)
     check(axis < len(self.full_shape), "invalid axis")

-    if opt.amt is
+    if opt.op is OptOps.SWAP: amt = cast(int, opt.amt) # amt is an axis in the SWAPs
+    elif opt.amt is not None:
       amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
       check(isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless")
       if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
     else: amt = -1

-    if self.reduceop and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or
-
-
+    if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
+       (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
+      acc_sz = self.reduceop.dtype.itemsize
+      upcast_sz = prod([a for a,b in zip(self.full_shape[self.first_upcast:], self.sts[0].shape[self.first_upcast:]) if a == b])
       local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
       smem_sz = amt*acc_sz*upcast_sz*local_sz
       check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
@@ -434,12 +399,13 @@ class Kernel:
       self.local_dims += 1
     elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green
       check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
-      check(
+      check(self.first_reduce + self.group_for_reduces <= axis < self.first_upcast, "must be reduce axis to group")
       check(not self.tensor_core, "can't group with tensor cores")
+      check(len(reduce_axes:=[i for r in self.reduceops for i in r.axis_arg]) == len(set(reduce_axes)), "can't group with parallel reduces")
       self.shift_to(axis, amt, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
       self.group_for_reduces += 1
     elif opt.op is OptOps.UNROLL: # purple
-      check(axis < self.
+      check(axis < self.first_upcast, "can't upcasted already upcasted")
       check(amt <= 32, "don't unroll more than 32")
       # TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores
       #upcast_count = sum(x == y for x,y in zip(self.full_shape[-self.upcasted:], self.output_shape[-self.upcasted:])) if self.upcasted else 0
@@ -450,12 +416,12 @@ class Kernel:
       self.upcast()
     elif opt.op is OptOps.UPCAST: # yellow
       check(axis < self.first_reduce, "upcast is for non-reduce")
-      check(not(self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
-      check(amt <=
+      check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
+      check(amt <= 16, "don't upcast more than 16")
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
     elif opt.op is OptOps.UPCASTMID: # white
-      check(self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
+      check(self.bufs[0].src[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
       axes = self.sts[0].unit_stride_axes()
       check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}")
       check(axes[0] == axis, "wrong axis")
@@ -466,18 +432,21 @@ class Kernel:
       check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
       check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
       self.dont_use_locals = True
+    elif opt.op is OptOps.SWAP:
+      check(axis < amt < self.global_dims, f"swap is only for globals with axis < amt, getting {amt=}, {axis=}, {self.global_dims=}")
+      permute = list(range(self.shape_len))
+      permute[axis], permute[amt] = permute[amt], permute[axis]
+      self.reshape_and_permute(None, tuple(permute))
     elif opt.op is OptOps.PADTO:
       check(not self.vars, "does not work with symbolic shape")
-      check(axis < self.
-      # ok to pad SUM if all parent ops have f(0) = 0
-      if self.first_reduce <= axis:
-        check((r:=cast(LazyOp, self.reduceop)).op is ReduceOps.SUM and \
-            all(op.op not in UNSAFE_PAD_OPS for ops in r.src for op in ops.lazyops), "cannot pad")
+      check(axis < self.first_upcast, "cannot pad upcasted")
+      # ok to pad SUM if all parent ALU ops have f(0) = 0
+      if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r), f"cannot pad {r}")
       padded = False
       for i,st in enumerate(self.sts):
-        if
-        check(
-        if (ru := round_up(cast(int,
+        if (s:=st.shape[axis]) == 1: continue # reduced
+        check(s > amt//4, f"pad adds more than quadruple the work {st.shape[axis]=} > {amt//4=}")
+        if (ru := round_up(cast(int, s), amt) - s):
           # pad right seems to be faster
           self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
           padded = True
@@ -487,24 +456,25 @@ class Kernel:
     if self.simplify_ones() and self.tensor_core_opts:
       self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()

-  def required_optimizations(self):
-    if self.
+  def required_optimizations(self) -> Kernel:
+    if isinstance(self.membufs[0].dtype, ImageDType):
       unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
-      assert
-      if
+      assert unit_stride_axes_mul_4, f"needs a unit stride axis in {self.bufs[0]}"
+      if all(x < self.first_upcast for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
        self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
+    return self

-  def hand_coded_optimizations(self):
+  def hand_coded_optimizations(self) -> Kernel:
     self.required_optimizations()

     # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
     MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
     if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
-        self.reduceop is not None and self.reduceop.
-        (mulop:=self.reduceop.src[0]).op is
-      st0, st1 = self.sts[self.bufs.index(mulop.src[0]
+        self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \
+        (mulop:=self.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
+      st0, st1 = self.sts[self.bufs.index(mulop.src[0])], self.sts[self.bufs.index(mulop.src[1])]
       strides0, strides1 = st0.real_strides(), st1.real_strides()
-      def has_expanded_axis(shape, strides): return any(s > 1 and st
+      def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
       if strides0[self.first_reduce] == 1 and not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)):
         for global_idx in range(self.global_dims):
           if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
@@ -513,13 +483,13 @@ class Kernel:
             if MV_THREADS_PER_ROW > 1: self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
             if MV_BLOCKSIZE > 1: self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
             if MV_ROWS_PER_THREAD > 1: self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
-            return
+            return self

     if self.opts.has_local and self.opts.has_shared and all_int(self.sts[0].shape[:self.first_reduce]):
       # are we grouping? (requires local shape support)
       if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: # noqa: E501
         # TODO: use 1024 if it's allowed in a smarter way
-        for sz in (
+        for sz in ([256, 16] if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
           if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
             try: # may fail due to excessive smem usage
               self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
@@ -527,7 +497,7 @@ class Kernel:
             except KernelOptError: pass

       # are we upcasting in mid reduce? (only for images)
-      if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
+      if self.bufs[0].src[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
         axes = self.sts[0].unit_stride_axes()
         assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
         if self.sts[0].shape[axes[0]]%4 == 0:
@@ -536,21 +506,21 @@ class Kernel:
     # upcast float4 images
     for buf_index,buf in enumerate(self.bufs):
       unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
-      if buf.dtype.__class__ is ImageDType:
+      if buf.src[0].dtype.__class__ is ImageDType:
         #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
-        if len(unit_stride_axes_mul_4) and all(x <
+        if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
          if unit_stride_axes_mul_4[0] < self.first_reduce:
            self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
          else:
            self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))

     # no more opt if we are grouping
-    if self.group_for_reduces: return
+    if self.group_for_reduces: return self

     # **** below this line need to be optional and benchmarked ****

     # TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
-    # to trigger the above bug, remove prod(self.full_shape[self.
+    # to trigger the above bug, remove prod(self.full_shape[self.first_upcast:]) from the below
     # expression and run test/test_ops.py with IMAGE=2
     # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
     # this can be made much smarter
@@ -560,14 +530,14 @@ class Kernel:
       # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
      # for now skip upcasting here if there is a symbolic axis
      if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and \
-        prod(self.full_shape[self.
+        prod(self.full_shape[self.first_upcast:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
        if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
        to_upcast.append(axis)
    for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))

    # potentially do more upcasts of non reduce axes based on a heuristic
    upcasted_axis = set()
-    while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
+    while resolve(prod(self.sts[0].shape[:self.first_reduce]) >= 1024):
      xb_choices = []
      for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
        # if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
@@ -581,11 +551,11 @@ class Kernel:
      else: break

    # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast.
-    if self.first_reduce <
-      if (s:=self.full_unupcasted_shape[-1])
+    if self.first_reduce < self.first_upcast and (prod(self.full_shape[self.first_upcast:]) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64): # noqa: E501
+      if isinstance(s:=self.full_unupcasted_shape[-1], int) and s <= 32: # NOTE: cannot loop unroll symbolic axis
        self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
        # if it's small, upcast a second reduce dimension too
-        if self.first_reduce <
+        if self.first_reduce < self.first_upcast and s <= 3 and isinstance(s2:=self.full_unupcasted_shape[-1], int) and s2 <= 3:
          self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
      else:
        for splits in [4]:
@@ -618,3 +588,166 @@ class Kernel:
          will_delete_shape = local_sz == self.full_shape[axis]
          self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
          if will_delete_shape: deleted_shape += 1
+
+    return self
+
+  # **** kernel outputs ****
+
+  kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
+  @functools.cached_property
+  def name(self) -> str:
+    # kernel name (before late upcast)
+    kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op in GroupOp.Buffer for x in self.ast.parents) else "E")
+    suffix = colored('_', 'BLACK').join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())])
+    name = kernel_type + (f"{len(self.ast.src)}" if len(self.ast.src) > 1 else "") + "_" + suffix
+
+    # name the function something unique
+    Kernel.kernel_cnt[(function_name := to_function_name(name))] += 1
+    num = f"n{Kernel.kernel_cnt[function_name]-1}" if Kernel.kernel_cnt[function_name] > 1 else ""
+    return name + colored(num, 'BLACK')
+
+  def get_optimized_ast(self) -> UOp:
+    @functools.lru_cache(None)
+    def fixup_ast(op:UOp) -> UOp:
+      ret = op.replace(src=tuple(fixup_ast(x) for x in op.src))
+      if op.op in GroupOp.Buffer and op in self.bufs:
+        st_uop = self.sts[self.bufs.index(op)].to_uop()
+        return ret.replace(src=(st_uop,)) if op.op is Ops.VALID else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
+      if op.op is Ops.SINK: return ret.replace(arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals))
+      if op.op is Ops.REDUCE_AXIS:
+        reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
+
+        def reduced_axes(start, stop):
+          return tuple(i for i in range(start, stop) if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
+        axes = reduced_axes(self.first_reduce + self.group_for_reduces, self.shape_len)
+        grouped_axes = reduced_axes(self.first_reduce, self.first_reduce + self.group_for_reduces)
+
+        if (tc := self.tensor_core) and (self.use_tensor_cores == 1 or self.use_tensor_cores == 3):
+          def fix_st(st: ShapeTracker, wd_pattern, tcd_pattern):
+            wd, warp_dims = self.global_dims, tuple(sz for _, sz in tc.threads)
+            tcd, tcd_dims = self.first_upcast, tuple(sz for _, sz in tc.reduce_axes + tc.early_upcast_axes)
+
+            assert st.shape[wd:wd+len(warp_dims)] == warp_dims, f"warp dims wrong: {st.shape[wd:wd+len(warp_dims)]=} != {warp_dims=}"
+            assert st.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, f"tcd dims wrong: {st.shape[tcd:tcd+len(tcd_dims)]=} != {tcd_dims=}"
+            assert tc.expanded_shape is not None
+
+            new_shape = st.shape[:tcd] + tc.expanded_shape + st.shape[tcd+len(tcd_dims):] # expand the tcd
+            permaxis = list(range(wd)) + [y + (wd if x == 0 else tcd) for x,y in wd_pattern] + list(range(wd+len(warp_dims),tcd)) + \
+              [y + (wd if x == 0 else tcd) for x,y in tcd_pattern] + list(range(tcd+len(tc.expanded_shape),len(new_shape)))
+            return st.reshape(new_shape).permute(tuple(permaxis)).reshape(st.shape).simplify()
+
+          srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
+          for i, tc_pattern in enumerate([tc.st1_pattern, tc.st2_pattern]):
+            if tc_pattern: srcs[i] = srcs[i].view(fix_st(unwrap(srcs[i].st), *tc_pattern))
+
+            if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals
+              local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape))
+              st = store_st = ShapeTracker.from_shape(local_shape)
+              local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{i + 1}", st.real_size()))
+              if tc_pattern: store_st = fix_st(store_st, *tc_pattern)
+              local_store = UOp.store(local_buffer, store_st.to_uop(), srcs[i])
+              srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer, st.to_uop(), local_store))
+
+          tc_reduce_axes = tuple(self.first_upcast + ax for ax, _ in tc.reduce_axes)
+          if self.use_tensor_cores == 1: # real WMMA, use CONTRACT/EXPAND to get the vectorization right
+            upcast_axes = tuple(tuple((self.first_upcast + ax, sz) for ax, sz in up) for up in tc.upcast_axes)
+            wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, prod(sz for _, sz in tc.threads), upcast_axes, tc_reduce_axes)
+            wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
+            wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(wmma_sz[2]), src=(
+              UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(wmma_sz[0]), src=(srcs[0],), arg=upcast_axes[0]),
+              UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(wmma_sz[1]), src=(srcs[1],), arg=upcast_axes[1]),
+              UOp.const(tc.dtype_out.vec(wmma_sz[2]), 0.0)), arg=wmma_arg)
+            tc_uop = UOp(Ops.EXPAND, tc.dtype_out, (wmma,), arg=upcast_axes[2])
+
+          else: # for TC=3 MUL/SUM instead of WMMA
+            tc_uop = UOp(Ops.REDUCE_AXIS, tc.dtype_out, ((srcs[0] * srcs[1]).cast(tc.dtype_out),), (Ops.ADD, tc_reduce_axes))
+
+          new_reduce_axes = tuple(i for i in axes if i not in tc_reduce_axes)
+          return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_reduce_axes)) if new_reduce_axes else tc_uop
+
+        ret = ret.replace(arg = (op.arg[0], axes))
+        if self.group_for_reduces and grouped_axes:
+          local_shape = (1,) * self.global_dims + self.full_shape[self.global_dims:self.global_dims+self.local_dims] + \
+            tuple([self.full_shape[i] if self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i] else 1 \
+              for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
+            (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
+          st_uop = ShapeTracker.from_shape(local_shape).to_uop()
+          local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local=True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
+          local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, ret)))
+          grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
+          if op is self.reduceops[-1]: return grouped_reduce
+          st_uop = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else a for i,a in enumerate(local_shape)])).to_uop()
+          return UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, grouped_reduce)))
+
+      return ret
+
+    return graph_rewrite(fixup_ast(self.ast), PatternMatcher([
+      (UPat({*GroupOp.ALU,Ops.CAST,Ops.BITCAST,Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
+      (UPat(Ops.LOAD, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src)))]))
+
+  # **** this is the lowerer ****
+
+  @track_rewrites()
+  def linearize(self) -> Kernel:
+    modified_ast = self.get_optimized_ast()
+
+    if DEBUG >= 3:
+      print(self.name)
+      if getenv("RAWAST"): print(self.ast)
+      print(modified_ast)
+      print(self.applied_opts)
+      verify_ast(modified_ast)
+
+    self.uops:List[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
+    if DEBUG >= 5: print_uops(self.uops)
+    return self
+
+  def to_program(self, name_override:Optional[str]=None) -> Program:
+    self.linearize()
+    src = self.opts.render(name:=to_function_name(ansiname:=(name_override if name_override is not None else self.name)), self.uops)
+
+    if getenv("RUN_PROCESS_REPLAY"):
+      from test.external.process_replay.helpers import get_process_replay_ctx
+      diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, name, *get_process_replay_ctx(), src))
+
+    # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
+    # TODO: these max and min don't work on symbolic, and results are very wrong.
+    mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
+      for _, group in itertools.groupby([x for x in self.ast.parents if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
+                                        key=lambda x: (x.op, x.src[0].arg)))
+    return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
+                   global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
+
+# the living definition of intermediate UOps
+
+def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
+  if not uop.has_st or uop in sts: return
+  # restore globals from the two stage reduce
+  if uop.op is Ops.LOAD and uop.src[0].op is Ops.DEFINE_LOCAL:
+    _assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts)
+    sts[uop] = sts[local_reduce]
+    return
+  for x in uop.src: _assert_valid_uop(x, st, sts)
+  # only reduceuop is allowed to change shape, limited to turning n to 1
+  if uop.op in {Ops.REDUCE_AXIS, Ops.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
+  # movementops are pushed to VIEW
+  elif uop.op is Ops.VIEW:
+    assert len(uop.src) == 0, f"can't swizzle in kernel yet {uop}"
+    st = uop.arg
+  # everything else inherits shape
+  else:
+    st = (src_sts:=[sts[x] for x in uop.src if x.has_st])[0]
+    if not all_same(shapes:=[x.shape for x in src_sts]):
+      if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}")
+      raise AssertionError(f"found implicit expand {sizes} {shapes}")
+  sts[uop] = st
+
+def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
+  assert ast.op is Ops.SINK and all(x.op is Ops.STORE for x in ast.src), "must be SINK"
+  assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size"
+  sts: Dict[UOp, ShapeTracker] = {}
+  for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
+  shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
+  assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
+  type_verify(list(sts))
+  return sts
```
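Read as a whole, the kernel.py diff replaces the old LazyOp/LocalBuffer-based front end with a UOp AST: `Kernel` now takes an `Ops.SINK` UOp plus an optional `Renderer`, `required_optimizations`/`hand_coded_optimizations`/`linearize` return `self`, a new `OptOps.SWAP` opt permutes two global axes (`apply_opt(Opt(OptOps.SWAP, axis, amt))`, where `amt` names the other axis and must satisfy `axis < amt < global_dims`), and `to_program()` runs `get_optimized_ast` → `full_graph_rewrite`/`linearize_uop` → `Renderer.render` to produce a `Program`. A rough usage sketch built only from the signatures visible above; the real call sites (presumably in `tinygrad/engine/realize.py`, which also changed in this release) are not part of this diff:

```python
# Sketch of the 0.10.0 Kernel pipeline as exposed by this file. `sink` must be an
# Ops.SINK UOp produced by the scheduler; building one is outside the scope of this diff.
from tinygrad.codegen.kernel import Kernel
from tinygrad.device import Device

def lower_to_program(sink):
  k = Kernel(sink, opts=Device[Device.DEFAULT].renderer)  # opts falls back to the default device's renderer
  k.hand_coded_optimizations()   # optimization passes now return self, so calls can be chained
  return k.to_program()          # get_optimized_ast() -> linearize() -> render(), yielding a Program
```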