tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +114 -172
- tinygrad/codegen/linearize.py +211 -81
- tinygrad/codegen/lowerer.py +30 -35
- tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
- tinygrad/codegen/transcendental.py +12 -13
- tinygrad/device.py +170 -47
- tinygrad/dtype.py +28 -26
- tinygrad/engine/jit.py +80 -63
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +162 -0
- tinygrad/engine/realize.py +58 -107
- tinygrad/engine/schedule.py +381 -314
- tinygrad/engine/search.py +40 -44
- tinygrad/gradient.py +70 -0
- tinygrad/helpers.py +77 -58
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +89 -64
- tinygrad/ops.py +562 -446
- tinygrad/renderer/__init__.py +79 -36
- tinygrad/renderer/cstyle.py +70 -84
- tinygrad/renderer/llvmir.py +32 -20
- tinygrad/renderer/ptx.py +79 -99
- tinygrad/renderer/wgsl.py +87 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libpciaccess.py +2023 -0
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +19 -21
- tinygrad/runtime/ops_amd.py +488 -327
- tinygrad/runtime/ops_clang.py +15 -28
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +129 -38
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +45 -40
- tinygrad/runtime/ops_metal.py +93 -73
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +232 -270
- tinygrad/runtime/ops_python.py +51 -46
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +63 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +384 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +26 -4
- tinygrad/runtime/support/hcq.py +254 -324
- tinygrad/runtime/support/llvm.py +32 -0
- tinygrad/shape/shapetracker.py +84 -53
- tinygrad/shape/view.py +103 -138
- tinygrad/spec.py +154 -0
- tinygrad/tensor.py +744 -496
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
- tinygrad-0.10.1.dist-info/RECORD +86 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/codegen/kernel.py
CHANGED
@@ -1,25 +1,25 @@
 from __future__ import annotations
-import itertools, functools
+import itertools, functools, math
 from dataclasses import dataclass
 from collections import defaultdict
-from typing import Optional,
+from typing import Optional, cast, Final, Callable, Sequence
 from enum import Enum, auto

-from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops,
-
+from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, view_left, print_uops
+from tinygrad.spec import type_verify, shape_spec
 from tinygrad.device import Device
-from tinygrad.renderer import Renderer, TensorCore,
+from tinygrad.renderer import Renderer, TensorCore, ProgramSpec
 from tinygrad.dtype import ImageDType
-from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap
-from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
+from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap, ContextVar
+from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, USE_TC, AMX, CAPTURE_PROCESS_REPLAY
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import strides_for_shape
 from tinygrad.codegen.linearize import linearize_uop
-from tinygrad.codegen.
+from tinygrad.codegen.rewriter import full_graph_rewrite
 from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction

 class OptOps(Enum):
-  TC = auto(); UPCAST = auto();
+  TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
   GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
   def __lt__(self, x:OptOps): return self.value < x.value

@@ -32,8 +32,8 @@ def check(cond:bool, msg:str=""):
 class Opt:
   op: OptOps
   axis: Optional[int] = None
-
-  def __repr__(self): return f"Opt(op={self.op}, axis={self.axis},
+  arg: Optional[int | tuple] = None
+  def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"
   def real_axis(self, k:Kernel):
     if self.axis is None: return -1
     if self.op is OptOps.UNROLL: return k.first_reduce+self.axis
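The `Opt` dataclass now carries its amount in an `arg` field (an int, or a tuple for `OptOps.TC`) in place of 0.10.0's `amt`; per the `apply_opt` hunk further down, `arg=0` stands for the full axis size and the TC arg is a `(tc_select, tc_opt)` pair. A minimal sketch of the new construction (not part of this diff, only illustrating the field names shown above):

from tinygrad.codegen.kernel import Opt, OptOps

# hypothetical Opts built against the 0.10.1 field names
upcast = Opt(OptOps.UPCAST, axis=0, arg=4)    # upcast axis 0 by 4 (was amt=4 in 0.10.0)
unroll = Opt(OptOps.UNROLL, axis=0, arg=0)    # arg=0 means "the whole axis" in apply_opt
tc     = Opt(OptOps.TC, axis=0, arg=(-1, 2))  # (tc_select, tc_opt): try all tensor cores, allow padding
print(upcast)  # Opt(op=OptOps.UPCAST, axis=0, arg=4)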
@@ -42,10 +42,10 @@ class Opt:

 @dataclass
 class TensorCoreOptions:
-  axes:
-  axes_exist:
-  axis_pads:
-  def fix_axes(self, removed_axis:int): # adjust the TC axes if
+  axes: tuple[int, ...] # the location of the original N and M axes if still in the shape
+  axes_exist: tuple[bool, ...] # true if the original N and M axes are still in the shape
+  axis_pads: tuple[tuple[int, int], ...]
+  def fix_axes(self, removed_axis:int): # adjust the TC axes if necessary when a dimension is removed
     axes, axes_exist = list(self.axes), list(self.axes_exist)
     for tc_dim in [i for i in range(2) if axes_exist[i]]:
       if removed_axis < axes[tc_dim]: axes[tc_dim] -= 1
@@ -57,32 +57,28 @@ class Kernel:
     if ast.op is Ops.SINK: self.ast = ast

     self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
-
-
-      print("INVALID AST")
-      print(self.ast)
-      raise e
+    # verify AST matches the spec
+    if __debug__: type_verify(list(self.ast.toposort), shape_spec)

-
-    def ordered_parents(op:UOp) -> List[UOp]: return dedup([item for x in op.src for item in ordered_parents(x)] + [op])
-    self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is Ops.REDUCE_AXIS])
+    self.reduceops = [x for x in self.ast.toposort if x.op is Ops.REDUCE_AXIS]

-    self.vars:
-
+    self.vars: list[Variable] = self.ast.variables()
+    # NOTE: this requires a specific order with the [::-1], this is likely a bug
+    self.bufs: list[UOp] = [x for x in self.ast.toposort if x.op in GroupOp.Buffer][::-1]

     # get earlybufs, before any reduceops
-    earlybufs:
+    earlybufs: list[UOp] = [x for reduceop in self.reduceops for x in reduceop.src[0].toposort if x.op in GroupOp.Buffer]
     self.full_buf_index: int = self.bufs.index(earlybufs[0]) if earlybufs else 0
     # NOTE: full_shape can be wrong if there's a tree of reduces

     # create new shapetrackers inside this kernel, we will permute them
-    self.sts:
+    self.sts: list[ShapeTracker] = [x.st_arg for x in self.bufs]

     # add the shapetrackers for each reduce
     # we use this to track which axes are reduced in each reduce
     for x in self.reduceops:
-      self.sts.append(
-      self.sts.append(
+      self.sts.append(unwrap(x.st))
+      self.sts.append(unwrap(x.src[0].st))

     # move all reduce axes to the end
     reduce = list(enumerate(zip(self.full_shape, self.output_shape)))
@@ -90,15 +86,13 @@ class Kernel:
     self.reshape_and_permute(None, permute)

     # parameters for optimization
-    self.applied_opts:
+    self.applied_opts: list[Opt] = []
     self.group_for_reduces: int = 0
     self.upcasted: int = 0
     self.local_dims: int = 0
     self.tensor_core: Optional[TensorCore] = None
     self.tensor_core_opts: Optional[TensorCoreOptions] = None
     self.use_tensor_cores: int = 0
-    # the local aliased buffers for A and B
-    self.bufs_for_tensor_core: Dict[UOp, Tuple[int, int]] = {}
     self.dont_use_locals: bool = False

     # group simplifies
@@ -112,25 +106,23 @@ class Kernel:
     ret.opts, ret.ast = self.opts, self.ast

     # things downstream of the AST
-    ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index =
-      self.reduceops, self.vars, self.bufs, self.full_buf_index
+    ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = self.reduceops, self.vars, self.bufs, self.full_buf_index
     ret.sts = self.sts[:len(ret.bufs)+len(ret.reduceops)*2] # NOTE: must redo the local buffers with TC in beam

     # parameters for optimizations
     ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
       self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
-    ret.tensor_core, ret.tensor_core_opts, ret.
-      self.tensor_core, self.tensor_core_opts, self.bufs_for_tensor_core, self.use_tensor_cores
+    ret.tensor_core, ret.tensor_core_opts, ret.use_tensor_cores = self.tensor_core, self.tensor_core_opts, self.use_tensor_cores

     return ret

   @property
-  def membufs(self) ->
+  def membufs(self) -> list[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])

   # TODO: these need more tests or it might silently be no-op
   def float4_axis(self, i:int): return [x-self.first_upcast for x in self.sts[i].unit_stride_axes() if x >= self.first_upcast and self.sts[i].shape[x]%4 == 0] # noqa: E501

-  def upcasted_axis(self, i:int) ->
+  def upcasted_axis(self, i:int) -> list[tuple[int, Optional[sint], bool]]:
     upcasted_shape, upcasted_stride = self.sts[i].shape[self.first_upcast:], self.sts[i].real_strides()[self.first_upcast:]
     assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
     return list(zip(upcasted_shape, upcasted_stride,
@@ -144,24 +136,20 @@ class Kernel:
   def first_upcast(self) -> int: return self.shape_len-self.upcasted

   @property
-  def reduceop(self) ->
+  def reduceop(self) -> UOp|None: return self.reduceops[0] if len(self.reduceops) > 0 else None

   @property
-  def output_shape(self) ->
+  def output_shape(self) -> tuple[sint, ...]: return self.sts[0].shape

   @property
-  def full_shape(self) ->
+  def full_shape(self) -> tuple[sint, ...]: return self.sts[self.full_buf_index].shape

   @property
-  def full_unupcasted_shape(self) ->
+  def full_unupcasted_shape(self) -> tuple[sint, ...]: return self.full_shape[:self.first_upcast]

   @property
   def shape_len(self) -> int: return len(self.sts[0].shape)

-  @property
-  def upcast_in_mid_reduce_axes(self) -> List[int]:
-    return [j for j in range(self.first_reduce, self.first_reduce+self.group_for_reduces) if self.full_shape[j] == self.sts[0].shape[j]]
-
   @property
   def global_dims(self) -> int: return self.first_reduce-self.local_dims

@@ -170,18 +158,17 @@ class Kernel:
   # cyan -- local dims (warp ones first)
   # *** self.first_reduce
   # green -- reduce-local dims
-  # white -- reduce-late upcasted dim (self.upcast_in_mid_reduce_axes)
   # red -- reduce loops
   # *** self.upcasted
   # purple -- reduce upcasted
   # yellow -- normal upcasted dimensions
-  def colors(self) ->
+  def colors(self) -> list[str]:
     # first non local non reduce dims are global (blue)
     colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims
     # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
     colors += ["cyan"] * self.local_dims
-    # between first_reduce and first_reduce + group_for_reduces, they are
-    colors += ["
+    # between first_reduce and first_reduce + group_for_reduces, they are late upcasted (green)
+    colors += ["green"] * self.group_for_reduces
     # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
     colors += ["red"] * (self.first_upcast - (self.first_reduce + self.group_for_reduces))
     # upcasted dimensions are reduce (magenta) or normal (yellow)
@@ -198,7 +185,7 @@ class Kernel:
   # ******************** base simplifiers ********************

   # apply reshape and permute to all shapetrackers
-  def reshape_and_permute(self, new_shape_fxn:Optional[Callable[[
+  def reshape_and_permute(self, new_shape_fxn:Optional[Callable[[tuple[sint, ...]], Sequence[sint]]], axis:Optional[Sequence[int]]):
     def reshape(st:ShapeTracker): return st.reshape(tuple(new_shape_fxn(st.shape))) if new_shape_fxn is not None else st
     def permute(st:ShapeTracker): return st.permute(tuple(axis)) if axis is not None else st
     self.sts = [permute(reshape(st)) for st in self.sts]
@@ -240,7 +227,7 @@ class Kernel:
     if isinstance(self.membufs[0].dtype, ImageDType):
       base_shape = self.membufs[0].dtype.shape
       if shape_idx_groups := get_contraction(self.output_shape, base_shape):
-        special_strides:
+        special_strides: tuple[sint, ...] = tuple()
         for i,g in enumerate(shape_idx_groups):
           shape_piece = tuple(self.output_shape[x] for x in g)
           assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
@@ -298,37 +285,34 @@ class Kernel:
     s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
     axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if resolve(self.full_shape[x]%tc.dims[i] != 0))
     if axis_pads and (opt_level < 2): return None
-    self.bufs_for_tensor_core[reduceop] = (buf0, buf1)
     if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
     return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)

-  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
+  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> bool:
     if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD:
-
+      tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]]
+      for tc in tensor_cores:
         tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
         # can only fuse reduces with the same tc options
         assert all_same(tensor_core_opts)
         if tensor_core_opts[0] is None: continue
-        # tensor core -- unroll the reduce dim, upcast input and local the correct thread pattern
         self.tensor_core_opts = tc_opts = tensor_core_opts[0]

         # attempt to pad the tensor axes that require it
         try:
           for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
         except KernelOptError: continue
-
-        for
-
-          for tc_dim, amt in tc.early_upcast_axes: self.apply_opt(Opt(OptOps.UPCAST,tc_opts.axes[tc_dim],amt), append_opt=False)
-          elif opt == "LC":
-            for tc_dim, amt in tc.threads: self.apply_opt(Opt(OptOps.LOCAL,tc_opts.axes[tc_dim],amt), append_opt=False)
+        # tensor core -- unroll the reduce dim (K), upcast and local the inner and outer dims (N, M)
+        for dim, amt in tc.get_reduce_axes(): self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, amt), append_opt=False)
+        for opt in tc.opts: self.apply_opt(Opt({"u":OptOps.UPCAST, "l":OptOps.LOCAL}[opt[0]], tc_opts.axes[int(opt[1])], 2), append_opt=False)
         self.tensor_core = tc
         self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA
         return True
     return False

-  def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[
-
+  def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[list[Opt]]=None, axis:int=0, tc_select:Optional[int]=None,
+                         tc_opt:Optional[int]=None) -> bool:
+    """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
     Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).

     Keyword arguments:
@@ -337,15 +321,19 @@ class Kernel:
       1: enable tensor cores
       2: apply tensor core shape but don't use UOp.WMMA
     extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
+    tc_select -- specifies which tensor core(s) to use for optimization (default -1)
+      -1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes)
+      [0-N]: uses only the n'th tensor core available; useful for search
     tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
       0: applies to only kernels with a single reduce axis and direct UOps.LOAD into Ops.MUL
       1: allows kernels with multiple reduce axes and also multiplication of UOps.CAST'd buffers
       2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
     """
+    if tc_select is None: tc_select = TC_SELECT.value
     if tc_opt is None: tc_opt = TC_OPT.value
     if not self.opts.tensor_cores and use_tensor_cores != 2: return False
     try: # check TC first and apply hand-coded opts if successful
-      self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
+      self.apply_opt(Opt(OptOps.TC, axis, (tc_select, tc_opt)))

       if (tc_opts:=self.tensor_core_opts) is not None:
         if extra_opts is not None:
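With the new `tc_select` keyword (backed by the added `TC_SELECT` context variable), a caller can pin a single tensor core or let the search iterate through all of them. A usage sketch, not part of this diff, assuming the default device exposes tensor cores and that `Tensor.schedule()` items carry an `.ast`:

from tinygrad import Tensor, dtypes
from tinygrad.codegen.kernel import Kernel

a, b = Tensor.rand(64, 64, dtype=dtypes.half), Tensor.rand(64, 64, dtype=dtypes.half)
k = Kernel((a @ b).schedule()[-1].ast)
# tc_select=-1 tries every tensor core the renderer offers; tc_opt=2 also allows padding M/N/K
if k.apply_tensor_cores(use_tensor_cores=1, tc_select=-1, tc_opt=2):
  print(k.applied_opts)  # first entry is the Opt(OptOps.TC, 0, (-1, 2)) applied above

Whether the call succeeds depends on the backend; on a device without tensor cores it simply returns False.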
@@ -364,24 +352,28 @@ class Kernel:
     return False

   def apply_opt(self, opt:Opt, append_opt:bool=True):
-    if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP
+    if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}, "not using locals")

     if opt.op is OptOps.TC:
       check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
-      check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt")
       check((use_tensor_cores:=USE_TC.value) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
-      check(
+      check(opt.axis is not None, "tensor core opts must have an axis")
+      check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 2, "tensor core opts must have tc_select and tc_opt")
+      check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
+      check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
+      check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt), "no tensor core available")
       self.applied_opts.append(opt)
       return

     axis = opt.real_axis(self)
     check(axis < len(self.full_shape), "invalid axis")

-    if opt.op is OptOps.SWAP: amt = cast(int, opt.
-    elif opt.
-
-
-
+    if opt.op is OptOps.SWAP: amt = cast(int, opt.arg) # arg is an axis in the SWAPs
+    elif opt.arg is not None:
+      check(isinstance(opt.arg, int), "arg should be int")
+      amt = arg if (arg:=cast(int, opt.arg)) != 0 else self.full_shape[axis]
+      check(isinstance(amt, int) and amt != 1, f"shift/padto of {amt=}, 1 or symbolic amount is meaningless")
+      if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, f"no longer valid shift {self.full_shape[axis]=}, {amt=}")
     else: amt = -1

     if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
@@ -416,18 +408,10 @@ class Kernel:
       self.upcast()
     elif opt.op is OptOps.UPCAST: # yellow
       check(axis < self.first_reduce, "upcast is for non-reduce")
-      check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.
+      check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.get_local_axes())), "can't upcast TC locals")
       check(amt <= 16, "don't upcast more than 16")
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
-    elif opt.op is OptOps.UPCASTMID: # white
-      check(self.bufs[0].src[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
-      axes = self.sts[0].unit_stride_axes()
-      check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}")
-      check(axes[0] == axis, "wrong axis")
-      check(amt == 4, "don't upcast mid anything but 4")
-      self.shift_to(axis, amt, insert_before=self.first_reduce + self.group_for_reduces)
-      self.group_for_reduces += 1
     elif opt.op is OptOps.NOLOCALS:
       check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
       check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
@@ -441,7 +425,7 @@ class Kernel:
       check(not self.vars, "does not work with symbolic shape")
       check(axis < self.first_upcast, "cannot pad upcasted")
       # ok to pad SUM if all parent ALU ops have f(0) = 0
-      if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r), f"cannot pad {r}")
+      if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r, {}, {}), f"cannot pad {r}")
       padded = False
       for i,st in enumerate(self.sts):
         if (s:=st.shape[axis]) == 1: continue # reduced
@@ -460,8 +444,7 @@ class Kernel:
     if isinstance(self.membufs[0].dtype, ImageDType):
       unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
       assert unit_stride_axes_mul_4, f"needs a unit stride axis in {self.bufs[0]}"
-      if all(x < self.first_upcast for x in unit_stride_axes_mul_4)
-        self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
+      if all(x < self.first_upcast for x in unit_stride_axes_mul_4): self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
     return self

   def hand_coded_optimizations(self) -> Kernel:
@@ -496,19 +479,12 @@ class Kernel:
           break
         except KernelOptError: pass

-    # are we upcasting in mid reduce? (only for images)
-    if self.bufs[0].src[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
-      axes = self.sts[0].unit_stride_axes()
-      assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
-      if self.sts[0].shape[axes[0]]%4 == 0:
-        self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4))
-
     # upcast float4 images
     for buf_index,buf in enumerate(self.bufs):
       unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
       if buf.src[0].dtype.__class__ is ImageDType:
         #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
-        if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4)
+        if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4):
           if unit_stride_axes_mul_4[0] < self.first_reduce:
             self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
           else:
@@ -524,7 +500,7 @@ class Kernel:
     # expression and run test/test_ops.py with IMAGE=2
     # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
     # this can be made much smarter
-    to_upcast:
+    to_upcast: list[int] = []
     # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
     for axis in range(self.first_reduce):
       # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
@@ -577,7 +553,7 @@ class Kernel:
     else:
       # prioritize making expand axes local
      local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))] # noqa: E501
-      to_local:
+      to_local: list[tuple[int, int]] = []
      for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
        local_size = prod(sz for _, sz in to_local)
        local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None) # noqa: E501
@@ -593,11 +569,11 @@ class Kernel:

   # **** kernel outputs ****

-  kernel_cnt: Final[
+  kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
   @functools.cached_property
   def name(self) -> str:
     # kernel name (before late upcast)
-    kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op in GroupOp.Buffer for x in self.ast.
+    kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op is Ops.SINK or x.op in GroupOp.Buffer for x in self.ast.toposort) else "E")
     suffix = colored('_', 'BLACK').join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())])
     name = kernel_type + (f"{len(self.ast.src)}" if len(self.ast.src) > 1 else "") + "_" + suffix

@@ -612,7 +588,10 @@ class Kernel:
       ret = op.replace(src=tuple(fixup_ast(x) for x in op.src))
       if op.op in GroupOp.Buffer and op in self.bufs:
         st_uop = self.sts[self.bufs.index(op)].to_uop()
-
+        # NOTE: if CONST got masked after applying opts, we create a new VALID
+        if op.op is Ops.CONST and any(v.mask is not None for v in unwrap(st_uop.st).views): return op.valid(unwrap(st_uop.st))
+        # otherwise we just replace the VIEW source
+        return ret.replace(src=(st_uop,)) if len(op.src) == 1 else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
       if op.op is Ops.SINK: return ret.replace(arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals))
       if op.op is Ops.REDUCE_AXIS:
         reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
@@ -623,47 +602,43 @@ class Kernel:
         grouped_axes = reduced_axes(self.first_reduce, self.first_reduce + self.group_for_reduces)

         if (tc := self.tensor_core) and (self.use_tensor_cores == 1 or self.use_tensor_cores == 3):
-
-
-
-
-
-
-
-
-
-
-            [y + (wd if x == 0 else tcd) for x,y in tcd_pattern] + list(range(tcd+len(tc.expanded_shape),len(new_shape)))
-          return st.reshape(new_shape).permute(tuple(permaxis)).reshape(st.shape).simplify()
+          wd, tcd = self.global_dims, self.first_upcast
+          def get_upcast_axes(buf): # upcast along non-zero dimensions of (tc_reduce + tc_upcast)
+            upcast_axes = int(math.log2(tc.elements_per_thread[buf]))
+            return tuple((tcd + len(tc.get_reduce_axes()) + len(tc.get_upcast_axes()) - (i+1), 2) for i in range(upcast_axes))
+          def get_tc_swizzle_st(shape, local_perm, upcast_perm):
+            offset = (tcd - (wd + len(local_perm)))
+            permaxis = list(range(wd)) \
+              + [wd + x + (offset if x >= len(local_perm) else 0) for x in local_perm] + list(range(wd + len(local_perm), tcd)) \
+              + [wd + x + (offset if x >= len(local_perm) else 0) for x in upcast_perm] + list(range(tcd + len(upcast_perm), len(shape)))
+            return ShapeTracker.from_shape(shape).permute(tuple(permaxis))

           srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
-          for i,
-            if
+          for i, (src, swizzle) in enumerate(zip(srcs, tc.swizzle)):
+            if swizzle: srcs[i] = src.view(get_tc_swizzle_st((src if src.op is Ops.LOAD else src.src[0]).st_arg.shape, *swizzle))

             if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals
               local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape))
               st = store_st = ShapeTracker.from_shape(local_shape)
-              local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (),
-              if
+              local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), f"temp{i}")
+              if swizzle: store_st = get_tc_swizzle_st(store_st.shape, *swizzle)
               local_store = UOp.store(local_buffer, store_st.to_uop(), srcs[i])
               srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer, st.to_uop(), local_store))

-          tc_reduce_axes = tuple(
-          if self.use_tensor_cores == 1: # real WMMA, use CONTRACT/
-
-            wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device,
-
-
-              UOp(Ops.CONTRACT, dtype=srcs[
-              UOp(
-
-            tc_uop = UOp(Ops.EXPAND, tc.dtype_out, (wmma,), arg=upcast_axes[2])
+          tc_reduce_axes = tuple(tcd + ax for ax, _ in tc.get_reduce_axes())
+          if self.use_tensor_cores == 1: # real WMMA, use CONTRACT/UNROLL to get the vectorization right
+            tc_upcast_axes = (get_upcast_axes(0), get_upcast_axes(1), get_upcast_axes(2))
+            wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, tc_reduce_axes)
+            wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
+              UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]),
+              UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1]),
+              UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg)
+            tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2])

           else: # for TC=3 MUL/SUM instead of WMMA
             tc_uop = UOp(Ops.REDUCE_AXIS, tc.dtype_out, ((srcs[0] * srcs[1]).cast(tc.dtype_out),), (Ops.ADD, tc_reduce_axes))

-
-          return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_reduce_axes)) if new_reduce_axes else tc_uop
+          return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_axes)) if (new_axes := tuple(i for i in axes if i not in tc_reduce_axes)) else tc_uop

         ret = ret.replace(arg = (op.arg[0], axes))
         if self.group_for_reduces and grouped_axes:
@@ -672,7 +647,8 @@ class Kernel:
             for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
             (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
           st_uop = ShapeTracker.from_shape(local_shape).to_uop()
-
+          local_size = st_uop.arg.real_size()
+          local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
           local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, ret)))
           grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
           if op is self.reduceops[-1]: return grouped_reduce
@@ -681,9 +657,7 @@ class Kernel:

       return ret

-    return graph_rewrite(fixup_ast(self.ast),
-      (UPat({*GroupOp.ALU,Ops.CAST,Ops.BITCAST,Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
-      (UPat(Ops.LOAD, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src)))]))
+    return graph_rewrite(fixup_ast(self.ast), view_left)

   # **** this is the lowerer ****

@@ -696,58 +670,26 @@ class Kernel:
       if getenv("RAWAST"): print(self.ast)
       print(modified_ast)
       print(self.applied_opts)
-
+    # verify AST matches the spec after applying opts
+    if __debug__: type_verify(list(modified_ast.toposort))
+    # TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this
+    #if __debug__: type_verify(list(modified_ast.toposort), shape_spec)

-    self.uops:
+    self.uops:list[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
     if DEBUG >= 5: print_uops(self.uops)
     return self

-  def to_program(self, name_override:Optional[str]=None) ->
+  def to_program(self, name_override:Optional[str]=None) -> ProgramSpec:
     self.linearize()
     src = self.opts.render(name:=to_function_name(ansiname:=(name_override if name_override is not None else self.name)), self.uops)

-    if
-
-      diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, name, *get_process_replay_ctx(), src))
+    if CAPTURE_PROCESS_REPLAY:
+      diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, name, ContextVar._cache, src))

     # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
     # TODO: these max and min don't work on symbolic, and results are very wrong.
     mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
-      for _, group in itertools.groupby([x for x in self.ast.
+      for _, group in itertools.groupby([x for x in self.ast.toposort if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
         key=lambda x: (x.op, x.src[0].arg)))
-    return
+    return ProgramSpec(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
       global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
-
-# the living definition of intermediate UOps
-
-def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
-  if not uop.has_st or uop in sts: return
-  # restore globals from the two stage reduce
-  if uop.op is Ops.LOAD and uop.src[0].op is Ops.DEFINE_LOCAL:
-    _assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts)
-    sts[uop] = sts[local_reduce]
-    return
-  for x in uop.src: _assert_valid_uop(x, st, sts)
-  # only reduceuop is allowed to change shape, limited to turning n to 1
-  if uop.op in {Ops.REDUCE_AXIS, Ops.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
-  # movementops are pushed to VIEW
-  elif uop.op is Ops.VIEW:
-    assert len(uop.src) == 0, f"can't swizzle in kernel yet {uop}"
-    st = uop.arg
-  # everything else inherits shape
-  else:
-    st = (src_sts:=[sts[x] for x in uop.src if x.has_st])[0]
-    if not all_same(shapes:=[x.shape for x in src_sts]):
-      if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}")
-      raise AssertionError(f"found implicit expand {sizes} {shapes}")
-  sts[uop] = st
-
-def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
-  assert ast.op is Ops.SINK and all(x.op is Ops.STORE for x in ast.src), "must be SINK"
-  assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size"
-  sts: Dict[UOp, ShapeTracker] = {}
-  for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
-  shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
-  assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
-  type_verify(list(sts))
-  return sts