tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +248 -115
- tinygrad/codegen/lowerer.py +215 -0
- tinygrad/codegen/transcendental.py +310 -0
- tinygrad/codegen/uopgraph.py +622 -0
- tinygrad/codegen/uops.py +235 -393
- tinygrad/device.py +428 -69
- tinygrad/dtype.py +18 -4
- tinygrad/engine/graph.py +19 -32
- tinygrad/engine/jit.py +148 -70
- tinygrad/engine/realize.py +127 -51
- tinygrad/engine/schedule.py +259 -216
- tinygrad/engine/search.py +29 -22
- tinygrad/function.py +9 -0
- tinygrad/helpers.py +87 -49
- tinygrad/lazy.py +34 -35
- tinygrad/multi.py +41 -36
- tinygrad/nn/__init__.py +39 -22
- tinygrad/nn/state.py +3 -3
- tinygrad/ops.py +63 -62
- tinygrad/renderer/__init__.py +43 -21
- tinygrad/renderer/assembly.py +104 -106
- tinygrad/renderer/cstyle.py +87 -60
- tinygrad/renderer/llvmir.py +21 -30
- tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/kfd.py +32 -0
- tinygrad/runtime/autogen/libc.py +4260 -0
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/graph/clang.py +2 -2
- tinygrad/runtime/graph/cuda.py +8 -11
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +18 -15
- tinygrad/runtime/ops_amd.py +197 -305
- tinygrad/runtime/ops_clang.py +2 -2
- tinygrad/runtime/ops_cuda.py +36 -94
- tinygrad/runtime/ops_disk.py +3 -7
- tinygrad/runtime/ops_gpu.py +4 -2
- tinygrad/runtime/ops_hip.py +70 -0
- tinygrad/runtime/ops_metal.py +38 -27
- tinygrad/runtime/ops_nv.py +283 -363
- tinygrad/runtime/ops_python.py +26 -30
- tinygrad/runtime/support/compiler_cuda.py +78 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/shape/shapetracker.py +5 -14
- tinygrad/shape/symbolic.py +4 -8
- tinygrad/shape/view.py +34 -22
- tinygrad/tensor.py +399 -97
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
- tinygrad-0.9.2.dist-info/RECORD +70 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/runtime/{driver → support}/__init__.py +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/codegen/kernel.py
CHANGED
@@ -1,21 +1,25 @@
 from __future__ import annotations
+import itertools, functools
+from dataclasses import dataclass, replace
 from collections import defaultdict
-import
-
-from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, UNSAFE_PAD_OPS, verify_lazyop
+from typing import Optional, List, Tuple, cast, Dict, Union, Final, DefaultDict, Any
+
+from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, MetaOps, UNSAFE_PAD_OPS, verify_lazyop, KernelInfo
 from tinygrad.device import Device
-from tinygrad.renderer import Renderer, TensorCore
-from tinygrad.dtype import
-from tinygrad.helpers import all_same, colored, ansilen, dedup,
+from tinygrad.renderer import Renderer, TensorCore, Program
+from tinygrad.dtype import ImageDType
+from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, DEBUG, TC_OPT, USE_TC, round_up, all_int, \
+  get_contraction, to_function_name, diskcache_put, ContextVar
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import sint
-from tinygrad.shape.view import
-from
+from tinygrad.shape.view import strides_for_shape
+from tinygrad.codegen.uopgraph import UOpGraph
+from tinygrad.codegen.lowerer import lazyop_to_uop
 from enum import Enum, auto
 
 class OptOps(Enum):
   TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
-  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto() # noqa: E702
+  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
   def __lt__(self, x:OptOps): return self.value < x.value
 
 class KernelOptError(Exception): pass
@@ -47,37 +51,42 @@ class TensorCoreOptions:
       elif removed_axis == axes[tc_dim]: axes_exist[tc_dim] = False
     self.axes, self.axes_exist = tuple(axes), tuple(axes_exist)
 
-@dataclass(frozen=True)
-class LocalBuffer:
-  name: str
-  size: int
-  dtype: DType = dtypes.float32
-  realized: None = None
-  def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"
-
 class Kernel:
   def __init__(self, *ast:LazyOp, opts:Optional[Renderer]=None):
+    if len(ast) > 1 or ast[0].op is BufferOps.STORE:
+      assert all(x.op is BufferOps.STORE for x in ast)
+      self.ast = LazyOp(MetaOps.KERNEL, ast)
+    else:
+      assert len(ast) == 1 and ast[0].op is MetaOps.KERNEL
+      self.ast = ast[0]
+
     self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
-    verify_lazyop(
-
-
+    try: lazyop_sts_map = verify_lazyop(self.ast)
+    except AssertionError as e:
+      print("INVALID AST")
+      for op in ast: print(op)
+      raise e
 
-
-    def ordered_lazyops(op):
-
-      return cached_ordered_lazyops[op]
-    self.reduceops = dedup([x for out in self.ast for x in ordered_lazyops(out) if x.op in ReduceOps])
+    @functools.lru_cache(None)
+    def ordered_lazyops(op): return dedup([item for x in op.src for item in ordered_lazyops(x)] + [op])
+    self.reduceops = dedup([x for x in ordered_lazyops(self.ast) if x.op in ReduceOps])
 
-    self.
-
-    self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops])
+    self.vars = self.ast.vars()
+    self.bufs: List[Union[MemBuffer, ConstBuffer]] = dedup([x.arg for x in self.ast.lazyops if x.op in BufferOps])
 
     # get earlybufs, before any reduceops
-
-    self.full_buf_index: int = self.bufs.index(
+    earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps]
+    self.full_buf_index: int = self.bufs.index(earlybufs[0]) if earlybufs else 0
+    # NOTE: full_shape can be wrong if there's a tree of reduces
 
     # create new shapetrackers inside this kernel, we will permute them
-    self.sts: List[ShapeTracker] = [x.st for x in
+    self.sts: List[ShapeTracker] = [x.st for x in self.bufs]
+
+    # add the shapetrackers for each reduce
+    # we use this to track which axes are reduced in each reduce
+    for x in self.reduceops:
+      self.sts.append(lazyop_sts_map[x])
+      self.sts.append(lazyop_sts_map[x.src[0]])
 
     # move all reduce axes to the end
    reduce = list(enumerate(zip(self.full_shape, self.output_shape)))
@@ -89,9 +98,9 @@ class Kernel:
     self.group_for_reduces: int = 0
     self.upcasted: int = 0
     self.local_dims: int = 0
-    self.local_alias: DefaultDict[LazyOp, Dict[int, LocalBuffer]] = defaultdict(dict)
     self.tensor_core: Optional[TensorCore] = None
     self.tensor_core_opts: Optional[TensorCoreOptions] = None
+    self.use_tensor_cores: int = 0
     # the local aliased buffers for A and B
     self.bufs_for_tensor_core: Dict[LazyOp, Tuple[int, int]] = {}
     self.dont_use_locals: bool = False
@@ -100,28 +109,22 @@ class Kernel:
     self.simplify_ones()
     self.simplify_merge_adjacent()
 
-    # cache
-    self.applied_opts_cache: Optional[List[Opt]] = None
-
   def copy(self):
     ret = type(self).__new__(type(self))
 
     # base linearizer params
-    ret.opts, ret.ast
+    ret.opts, ret.ast = self.opts, self.ast
 
     # things downstream of the AST
-    ret.reduceops, ret.
-      self.reduceops, self.
-    ret.sts = self.sts[:len(ret.bufs)] # NOTE: must redo the local buffers with TC in beam
+    ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = \
+      self.reduceops, self.vars, self.bufs, self.full_buf_index
+    ret.sts = self.sts[:len(ret.bufs)+len(ret.reduceops)*2] # NOTE: must redo the local buffers with TC in beam
 
     # parameters for optimizations
     ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
       self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
-    ret.tensor_core, ret.tensor_core_opts, ret.
-
-
-    # uncached since linearize didn't run
-    ret.applied_opts_cache = None
+    ret.tensor_core, ret.tensor_core_opts, ret.bufs_for_tensor_core, ret.use_tensor_cores = \
+      self.tensor_core, self.tensor_core_opts, self.bufs_for_tensor_core, self.use_tensor_cores
 
     return ret
 
@@ -129,29 +132,20 @@ class Kernel:
   def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)]
 
   # TODO: these need more tests or it might silently be no-op
-  def
-  def float4_axis(self, i:int): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0] # noqa: E501
+  def float4_axis(self, i:int): return [x-self.first_upcast for x in self.sts[i].unit_stride_axes() if x >= self.first_upcast and self.sts[i].shape[x]%4 == 0] # noqa: E501
 
   def upcasted_axis(self, i:int) -> List[Tuple[int, Optional[sint], bool]]:
-    upcasted_shape, upcasted_stride = self.sts[i].shape[self.
+    upcasted_shape, upcasted_stride = self.sts[i].shape[self.first_upcast:], self.sts[i].real_strides()[self.first_upcast:]
     assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
     return list(zip(upcasted_shape, upcasted_stride,
-                    [x!=y for x,y in zip(self.sts[0].shape[self.
-
-  # TODO: is there a better way to write this?
-  def acc_offsets(self, i:int) -> List[int]:
-    if self.upcasted == 0: return [0]
-    upcasted_i = self.upcasted_axis(i)
-    acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))]
-    return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])]
-
-  def get_float4_upcast_dim(self, i:int) -> List[int]:
-    should_upcast = self.opts.supports_float4 and (self.bufs[i].dtype in (dtypes.float, dtypes.half) or isinstance(self.bufs[i].dtype, ImageDType))
-    return [x for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1] if should_upcast else []
+                    [x!=y for x,y in zip(self.sts[0].shape[self.first_upcast:], self.full_shape[self.first_upcast:])]))
 
   @property
   def first_reduce(self) -> int:
-    return [x!=y for x,y in zip(self.sts[0].shape[:self.
+    return [x!=y for x,y in zip(self.sts[0].shape[:self.first_upcast]+(0,), self.full_shape[:self.first_upcast]+(1,))].index(True)
+
+  @property
+  def first_upcast(self) -> int: return self.shape_len-self.upcasted
 
   @property
   def reduceop(self) -> Optional[LazyOp]: return self.reduceops[0] if len(self.reduceops) > 0 else None
@@ -163,7 +157,7 @@ class Kernel:
   def full_shape(self) -> Tuple[sint, ...]: return self.sts[self.full_buf_index].shape
 
   @property
-  def full_unupcasted_shape(self) -> Tuple[sint, ...]: return self.full_shape[:self.
+  def full_unupcasted_shape(self) -> Tuple[sint, ...]: return self.full_shape[:self.first_upcast]
 
   @property
   def shape_len(self) -> int: return len(self.sts[0].shape)
@@ -193,9 +187,9 @@ class Kernel:
     # between first_reduce and first_reduce + group_for_reduces, they are either upcast mid reduce (white), or late upcasted (green)
     colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + self.group_for_reduces)] # noqa: E501
     # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
-    colors += ["red"] * (
+    colors += ["red"] * (self.first_upcast - (self.first_reduce + self.group_for_reduces))
     # upcasted dimensions are reduce (magenta) or normal (yellow)
-    colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.
+    colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.first_upcast, self.shape_len)]
     assert len(colors) == self.shape_len, "colors size mismatch"
     return colors
 
@@ -229,7 +223,7 @@ class Kernel:
     move_axis = axis if top else axis+1
     if move_axis < insert_before: insert_before += 1
     self.reshape_and_permute(
-      lambda x:
+      lambda x: x[0:axis] + (((amount, x[axis]//amount) if top else (x[axis]//amount, amount)) if x[axis] > 1 else (1,1)) + x[axis+1:],
      [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis])
 
   # ******************** complex simplifiers ********************
@@ -240,7 +234,7 @@ class Kernel:
    if self.shape_len == 0: return False
    all_ones = [s==1 for s in self.full_shape]
    self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
-    self.upcasted -= sum(all_ones[self.
+    self.upcasted -= sum(all_ones[self.first_upcast:]) # TODO: no necessary since upcasted axis can't be un-upcasted
    self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
    return any(all_ones)
 
@@ -281,25 +275,6 @@ class Kernel:
     # do the reshapes
     for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
 
-  # ******************** helpers ********************
-
-  def alias_buffer(self, op:LazyOp, i:int, pattern:List[int]) -> None:
-    assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}"
-
-    bst = 1
-    real_strides = self.sts[i].real_strides()
-    shp, stride = [(s if p != 0 else 1) for s,p in zip(self.sts[i].shape, pattern)], [0]*len(pattern)
-    for priority in range(1, max(pattern)+1): # priority. 0 is non local and ignored
-      for j,p in enumerate(pattern):
-        if priority == p and real_strides[j] != 0:
-          stride[j] = bst
-          bst *= shp[j]
-
-    self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
-    self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size))
-    if DEBUG >= 4: print("aliasing buffer", self.sts[i])
-    self.local_alias[op][i] = cast(LocalBuffer, self.bufs[-1])
-
   # ******************** high level optimizers ********************
 
   def _create_tc_opts(self, reduceop:LazyOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
@@ -347,14 +322,26 @@ class Kernel:
        try:
          for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
        except KernelOptError: continue
-        self.
-
-
-
-
-
-
-
+        if self.opts.device in {"AMD", "HIP"}:
+          # NOTE: AMD requires locals first
+          self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, tc.dims[2]), append_opt=False)
+          for (tc_dim, tc_amt) in tc.threads: self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
+          for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
+            if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
+        elif self.opts.device == "METAL":
+          self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, tc.dims[2]), append_opt=False)
+          for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
+            if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
+          for (tc_dim, tc_amt) in tc.threads: self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
+        elif self.opts.device in {"CUDA", "NV"}:
+          self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, 8), append_opt=False)
+          self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, 2), append_opt=False)
+          # NOTE: LOCALS and UPCAST can be swapped here. it doesn't seem faster
+          self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[1], 2), append_opt=False)
+          self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[0], 2), append_opt=False)
+          for (tc_dim, tc_amt) in tc.threads: self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
+        self.tensor_core = tc
+        self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA
        return True
    return False
 
@@ -373,7 +360,7 @@ class Kernel:
    1: allows kernels with multiple reduce axes and also multiplication of UnaryOps.CAST'd buffers
    2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
    """
-    if tc_opt is None: tc_opt =
+    if tc_opt is None: tc_opt = TC_OPT.value
    if not self.opts.tensor_cores and use_tensor_cores != 2: return False
    try: # check TC first and apply hand-coded opts if successful
      self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
@@ -395,7 +382,6 @@ class Kernel:
          if self.full_shape[tc_opts.axes[0]] % upc == 0:
            self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], upc))
            break
-
      return True
    except KernelOptError:
      return False
@@ -406,7 +392,7 @@ class Kernel:
    if opt.op is OptOps.TC:
      check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
      check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt")
-      check((use_tensor_cores:=
+      check((use_tensor_cores:=USE_TC.value) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
      check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.amt)), "no tensor core available")
      self.applied_opts.append(opt)
      return
@@ -414,15 +400,16 @@ class Kernel:
    axis = opt.real_axis(self)
    check(axis < len(self.full_shape), "invalid axis")
 
-    if opt.amt is
+    if opt.op is OptOps.SWAP: amt = cast(int, opt.amt) # amt is an axis in the SWAPs
+    elif opt.amt is not None:
      amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
      check(isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless")
      if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
    else: amt = -1
 
    if self.reduceop and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
-      acc_sz
-      upcast_sz = prod([a for a,b in zip(self.full_shape[
+      acc_sz = dt.base.itemsize if isinstance((dt:=self.reduceop.dtype), ImageDType) else dt.itemsize
+      upcast_sz = prod([a for a,b in zip(self.full_shape[self.first_upcast:], self.sts[0].shape[self.first_upcast:]) if a == b])
      local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
      smem_sz = amt*acc_sz*upcast_sz*local_sz
      check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
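
The shared-memory guard above estimates the scratch space a GROUP/GROUPTOP opt would need as amt*acc_sz*upcast_sz*local_sz and rejects the opt when it exceeds the renderer's shared_max. A worked example with illustrative numbers (these values are not taken from the diff):

# illustrative only: a GROUPTOP of 256 on a float32 reduce (acc_sz=4 bytes) with a
# 4-wide upcast and 16 local threads asks for 256*4*4*16 = 65536 bytes (64 KiB),
# more than a typical 48 KiB shared_max, so check() would raise KernelOptError
amt, acc_sz, upcast_sz, local_sz = 256, 4, 4, 16
smem_sz = amt * acc_sz * upcast_sz * local_sz
assert smem_sz == 65536
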
@@ -434,12 +421,13 @@ class Kernel:
      self.local_dims += 1
    elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green
      check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
-      check(
+      check(self.first_reduce + self.group_for_reduces <= axis < self.first_upcast, "must be reduce axis to group")
      check(not self.tensor_core, "can't group with tensor cores")
+      check(len(self.reduceops) == 1, "can't group with multiple reduces")
      self.shift_to(axis, amt, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
      self.group_for_reduces += 1
    elif opt.op is OptOps.UNROLL: # purple
-      check(axis < self.
+      check(axis < self.first_upcast, "can't upcasted already upcasted")
      check(amt <= 32, "don't unroll more than 32")
      # TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores
      #upcast_count = sum(x == y for x,y in zip(self.full_shape[-self.upcasted:], self.output_shape[-self.upcasted:])) if self.upcasted else 0
@@ -451,7 +439,7 @@ class Kernel:
    elif opt.op is OptOps.UPCAST: # yellow
      check(axis < self.first_reduce, "upcast is for non-reduce")
      check(not(self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
-      check(amt <=
+      check(amt <= 16, "don't upcast more than 16")
      self.shift_to(axis, amt, insert_before=None)
      self.upcast()
    elif opt.op is OptOps.UPCASTMID: # white
@@ -466,18 +454,23 @@ class Kernel:
      check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
      check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
      self.dont_use_locals = True
+    elif opt.op is OptOps.SWAP:
+      check(axis < amt and amt < self.global_dims, f"swap is only for globals with axis < amt, getting {amt=}, {axis=}, {self.global_dims=}")
+      permute = list(range(self.shape_len))
+      permute[axis], permute[amt] = permute[amt], permute[axis]
+      self.reshape_and_permute(None, tuple(permute))
    elif opt.op is OptOps.PADTO:
      check(not self.vars, "does not work with symbolic shape")
-      check(axis < self.
+      check(axis < self.first_upcast, "cannot pad upcasted")
      # ok to pad SUM if all parent ops have f(0) = 0
      if self.first_reduce <= axis:
        check((r:=cast(LazyOp, self.reduceop)).op is ReduceOps.SUM and \
-            all(op.op not in UNSAFE_PAD_OPS for
+            all(op.op not in UNSAFE_PAD_OPS for sop in r.src for op in sop.lazyops), "cannot pad")
      padded = False
      for i,st in enumerate(self.sts):
        if self.sts[i].shape[axis] == 1: continue # reduced
        check(self.sts[i].shape[axis] > amt//4, f"pad adds more than quadruple the work {self.sts[i].shape[axis]=} > {amt//4=}")
-        if (ru := round_up(cast(int, self.sts[i].shape[axis]),
+        if (ru := round_up(cast(int, self.sts[i].shape[axis]), amt) - self.sts[i].shape[axis]):
          # pad right seems to be faster
          self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
          padded = True
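
The new SWAP opt added above reorders two global axes by building an identity permutation, exchanging the two entries, and passing it to reshape_and_permute. A standalone sketch of just that permutation step (shape_len, axis and amt are illustrative values, not from the diff):

# for shape_len=4, swapping global axes 0 and 2 produces the permutation (2, 1, 0, 3)
shape_len, axis, amt = 4, 0, 2
permute = list(range(shape_len))
permute[axis], permute[amt] = permute[amt], permute[axis]
assert tuple(permute) == (2, 1, 0, 3)
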
@@ -487,14 +480,15 @@ class Kernel:
    if self.simplify_ones() and self.tensor_core_opts:
      self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()
 
-  def required_optimizations(self):
+  def required_optimizations(self) -> Kernel:
    if self.bufs[0].dtype.__class__ is ImageDType:
      unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
      assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[0]}"
-      if len(unit_stride_axes_mul_4) and all(x <
+      if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
        self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
+    return self
 
-  def hand_coded_optimizations(self):
+  def hand_coded_optimizations(self) -> Kernel:
    self.required_optimizations()
 
    # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
@@ -513,13 +507,13 @@ class Kernel:
      if MV_THREADS_PER_ROW > 1: self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
      if MV_BLOCKSIZE > 1: self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
      if MV_ROWS_PER_THREAD > 1: self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
-      return
+      return self
 
    if self.opts.has_local and self.opts.has_shared and all_int(self.sts[0].shape[:self.first_reduce]):
      # are we grouping? (requires local shape support)
      if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: # noqa: E501
        # TODO: use 1024 if it's allowed in a smarter way
-        for sz in (
+        for sz in ([256, 16] if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
          if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
            try: # may fail due to excessive smem usage
              self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
@@ -538,19 +532,19 @@ class Kernel:
      unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
      if buf.dtype.__class__ is ImageDType:
        #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
-        if len(unit_stride_axes_mul_4) and all(x <
+        if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
          if unit_stride_axes_mul_4[0] < self.first_reduce:
            self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
          else:
            self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))
 
    # no more opt if we are grouping
-    if self.group_for_reduces: return
+    if self.group_for_reduces: return self
 
    # **** below this line need to be optional and benchmarked ****
 
    # TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
-    # to trigger the above bug, remove prod(self.full_shape[self.
+    # to trigger the above bug, remove prod(self.full_shape[self.first_upcast:]) from the below
    # expression and run test/test_ops.py with IMAGE=2
    # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
    # this can be made much smarter
@@ -560,7 +554,7 @@ class Kernel:
      # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
      # for now skip upcasting here if there is a symbolic axis
      if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and \
-        prod(self.full_shape[self.
+        prod(self.full_shape[self.first_upcast:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
        if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
        to_upcast.append(axis)
    for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
@@ -581,11 +575,11 @@ class Kernel:
      else: break
 
    # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast.
-    if self.first_reduce <
+    if self.first_reduce < self.first_upcast and (prod(self.full_shape[self.first_upcast:]) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64): # noqa: E501
      if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
        self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
        # if it's small, upcast a second reduce dimension too
-        if self.first_reduce <
+        if self.first_reduce < self.first_upcast and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int):
          self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
      else:
        for splits in [4]:
@@ -618,3 +612,142 @@ class Kernel:
        will_delete_shape = local_sz == self.full_shape[axis]
        self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
        if will_delete_shape: deleted_shape += 1
+
+    return self
+
+  # **** kernel outputs ****
+
+  kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
+  @functools.cached_property
+  def name(self) -> str:
+    # kernel name (before late upcast)
+    name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.ast.lazyops) else "E")) + \
+           (f"{len(self.ast.src)}_" if len(self.ast.src) > 1 else "_") + \
+           colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
+
+    # name the function something unique
+    Kernel.kernel_cnt[(function_name := to_function_name(name))] += 1
+    suffix = f"{'n'+str(Kernel.kernel_cnt[function_name]-1)}" if Kernel.kernel_cnt[function_name] > 1 else ""
+    return name+colored(suffix, 'BLACK')
+
+  def get_optimized_ast(self) -> LazyOp:
+    # set the shapetrackers to the optimized ones, fixup reduceop
+    # transformed to the final LazyOp
+    @functools.lru_cache(None)
+    def fixup_ast(op:LazyOp, apply_to_st=None) -> LazyOp:
+      if op.op in BufferOps:
+        if isinstance(op.arg, MemBuffer) and op.arg.idx < 0:
+          # for locals, we use the ShapeTracker that's in the MemBuffer
+          arg:Any = replace(op.arg, st=apply_to_st(op.arg.st)) if apply_to_st is not None else op.arg
+        else:
+          idx = self.bufs.index(op.arg)
+          arg = replace(op.arg, st=self.sts[idx] if apply_to_st is None else apply_to_st(self.sts[idx]))
+      elif op.op in ReduceOps:
+        reduce_idx = len(self.bufs) + self.reduceops.index(op)*2
+        arg = tuple(i for i in range(self.first_reduce+self.group_for_reduces, self.shape_len)
+                    if self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i])
+        if op in self.bufs_for_tensor_core and (tc := self.tensor_core):
+          rsrc = op.src[0]
+          if rsrc.op is UnaryOps.CAST: rsrc = rsrc.src[0]
+          assert rsrc.op is BinaryOps.MUL
+
+          def fix_st(warp_dims, tcd_dims, tcd_expand, pattern_1, pattern_2, st1):
+            wd, tcd = self.global_dims, self.first_upcast
+            assert st1.shape[wd:wd+len(warp_dims)] == warp_dims, f"warp dims wrong: {st1.shape[wd:wd+len(warp_dims)]=} != {warp_dims=}"
+            assert st1.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, f"tcd dims wrong: {st1.shape[tcd:tcd+len(tcd_dims)]=} != {tcd_dims=}"
+            new_shape = st1.shape[:tcd] + tcd_expand + st1.shape[tcd+len(tcd_dims):] # expand the tcd
+            permaxis = list(range(wd)) + [y + (wd if x == 0 else tcd) for x,y in pattern_1] + list(range(wd+len(warp_dims), tcd)) + \
+              [y + (wd if x == 0 else tcd) for x,y in pattern_2] + list(range(tcd+len(tcd_expand), len(new_shape)))
+            return st1.reshape(new_shape).simplify().permute(tuple(permaxis)).reshape(st1.shape).simplify()
+
+          if self.opts.device in {"AMD", "HIP"}:
+            reduce_axes, upcast_axes = [0], [[(0, 16)], [(0, 16)], [(1, 8)]]
+            # https://gpuopen.com/learn/wmma_on_rdna3/
+            fix_st1 = functools.partial(fix_st, (8,2,2), (16,8), (16,2,4), ((1,2), (0,2), (1,1), (0,1)), ((1,0), (0,0)))
+            fix_st2 = None
+          elif self.opts.device == "METAL":
+            reduce_axes, upcast_axes = [0], [[(1, 2)], [(1, 2)], [(1, 2)]]
+            fix_st1 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((1,1), (0,1), (1,0), (0,3)), ((0,0), (0,2), (1,3), (1,2)))
+            fix_st2 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((0,0), (1,1), (1,2), (0,2), (1,0)), ((0,1), (0,3), (1,3)))
+          elif self.opts.device in {"CUDA", "NV"}:
+            reduce_axes, upcast_axes = [0, 1], [[(0, 8)], [(2, 2), (3, 2)], [(2, 2), (3, 2)]]
+            # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
+            fix_st1 = functools.partial(fix_st, (2,2,2,2,2), (8,2,2,2), (2,2,2,2,2,2),
+              ((1,1), (1,0), (0,2), (0,3), (0,4)), ((1,3), (1,4), (1,2), (0,0), (0,1), (1,5)))
+            fix_st2 = functools.partial(fix_st, (2,2,2,2,2), (8,2,2,2), (2,2,2,2,2,2),
+              ((1,1), (1,0), (1,5), (0,0), (0,1)), ((0,4), (0,2), (1,4), (0,3), (1,3), (1,2)))
+          else:
+            raise RuntimeError("unsupported device for tensor cores")
+
+          assert apply_to_st is None, "double tensor core? not supported"
+          wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device,
+                      tuple(tuple((self.first_upcast+ax, sz) for ax, sz in up) for up in upcast_axes),
+                      tuple(self.first_upcast+ax for ax in reduce_axes))
+          if self.use_tensor_cores >= 2:
+            if self.use_tensor_cores == 3:
+              # TC=3, emulate the warp addressing with locals
+              ex_shape = tuple(1 if i < self.global_dims or (i >= self.first_reduce and i < self.first_upcast) else s \
+                for i,s in enumerate(self.full_shape))
+              srcs = []
+              for i,(src,fix_st_fxn) in enumerate(zip(rsrc.src, [fix_st1, fix_st2])):
+                st_load = [self.sts[self.bufs.index(op.arg)].real_strides() for op in src.lazyops if op.op is BufferOps.LOAD]
+                local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
+                membuf = MemBuffer(-1-i, tc.dtype_in, ShapeTracker.from_shape(local_shape).expand(ex_shape))
+                srcs.append(LazyOp(BufferOps.LOAD, (fixup_ast(LazyOp(BufferOps.STORE, (src,), membuf), fix_st_fxn),), membuf))
+            else:
+              # for TC=2, we can't do the shapetracker fixup
+              srcs = [fixup_ast(rsrc.src[0]), fixup_ast(rsrc.src[1])]
+            # MUL/SUM instead of WMMA
+            ret = LazyOp(ReduceOps.SUM, (LazyOp(UnaryOps.CAST, (LazyOp(BinaryOps.MUL, tuple(srcs)),), tc.dtype_out),), wmma_arg[-1])
+          else:
+            ret = LazyOp(ReduceOps.WMMA, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg)
+          return LazyOp(op.op, (ret,), new_reduce_axes) if (new_reduce_axes:=tuple(i for i in arg if i-self.first_upcast not in reduce_axes)) else ret
+        if self.group_for_reduces:
+          start = LazyOp(op.op, tuple(fixup_ast(x, apply_to_st) for x in op.src), arg)
+          local_shape = (1,) * self.global_dims + self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces] + \
+            (1,) * (self.first_upcast - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
+          local_buffer = MemBuffer(-1, start.dtype, ShapeTracker.from_shape(local_shape))
+          local_store = LazyOp(BufferOps.STORE, (start,), local_buffer)
+          local_load = LazyOp(BufferOps.LOAD, (local_store,), local_buffer)
+          return LazyOp(op.op, (local_load,), tuple(range(self.first_reduce, self.first_reduce+self.group_for_reduces)))
+      elif op.op is MetaOps.KERNEL:
+        arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals)
+      else:
+        arg = op.arg
+      return LazyOp(op.op, tuple(fixup_ast(x, apply_to_st) for x in op.src), arg)
+    return fixup_ast(self.ast)
+
+  # **** this is the lowerer ****
+
+  def linearize(self) -> Kernel:
+    modified_ast = self.get_optimized_ast()
+
+    if DEBUG >= 3:
+      print(self.name)
+      if getenv("RAWAST"): print(self.ast)
+      print(modified_ast)
+      print(self.applied_opts)
+    verify_lazyop(modified_ast)
+
+    # generate the UOpGraph
+    self.uops:UOpGraph = UOpGraph(lazyop_to_uop(modified_ast, self.opts), self.opts)
+    if DEBUG >= 5: self.uops.print()
+    if getenv("GRAPHUOPS"): self.uops.graph()
+    return self
+
+  def to_program(self, name_override:Optional[str]=None) -> Program:
+    self.linearize()
+    self.uops.linearize(self.opts.extra_matcher)
+    src = self.opts.render(name:=to_function_name(ansiname:=(name_override if name_override is not None else self.name)), self.uops.uops)
+
+    if getenv("RUN_PROCESS_REPLAY"):
+      table_name = f"process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{getenv('GITHUB_RUN_ATTEMPT')}"
+      diskcache_put(table_name, id(self), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()}))
+
+    # group non-local MemBuffers by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
+    # TODO: these max and min don't work on symbolic, and results are very wrong.
+    mem_bytes = sum(max(x.arg.dtype.itemsize * x.arg.st.real_size() for x in group) for _, group in
+      itertools.groupby([x for x in self.ast.lazyops if x.op in BufferOps and isinstance(x.arg, MemBuffer) and x.arg.idx >= 0],
+                        key=lambda x: (x.op, x.arg.idx)))
+    return Program(ansiname, src, self.opts.device, self.uops.uops, mem_estimate=mem_bytes,
+                   global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)