tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
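Several modules moved or were replaced in this release: `tinygrad/multi.py` is gone and `tinygrad/engine/multi.py` appears, `tinygrad/codegen/uopgraph.py` is removed while `tinygrad/codegen/devectorizer.py` is added, and `runtime/ops_clang.py` is dropped with a new `runtime/ops_cpu.py` in its place. A minimal, hypothetical compatibility sketch for code that imported the moved multi module by path (the try/except shim and the assumption that the module's public names are unchanged are mine, not from the diff):

```python
# Hypothetical import shim for the 0.10.0 -> 0.10.2 module move shown in the file list above.
# Assumes only the path moved; the module's public names are an assumption.
try:
    from tinygrad.engine import multi   # 0.10.2: tinygrad/engine/multi.py
except ImportError:
    from tinygrad import multi          # 0.10.0: tinygrad/multi.py
```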
tinygrad/codegen/kernel.py
CHANGED
@@ -1,51 +1,34 @@
 from __future__ import annotations
-import itertools, functools
+import itertools, functools, math
 from dataclasses import dataclass
 from collections import defaultdict
-from typing import Optional,
-from enum import Enum, auto
+from typing import Optional, cast, Final, Callable, Sequence
 
-from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops,
-
+from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, view_left, print_uops
+from tinygrad.ops import PatternMatcher
+from tinygrad.spec import type_verify, shape_spec
 from tinygrad.device import Device
-from tinygrad.renderer import Renderer, TensorCore,
+from tinygrad.renderer import Renderer, TensorCore, ProgramSpec, Opt, OptOps
 from tinygrad.dtype import ImageDType
-from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap
-from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
+from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap, ContextVar
+from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, USE_TC, AMX, CAPTURE_PROCESS_REPLAY
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import strides_for_shape
 from tinygrad.codegen.linearize import linearize_uop
-from tinygrad.codegen.
+from tinygrad.codegen.devectorizer import full_graph_rewrite
 from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction
 
-class OptOps(Enum):
-  TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
-  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
-  def __lt__(self, x:OptOps): return self.value < x.value
-
 class KernelOptError(Exception): pass
 
 def check(cond:bool, msg:str=""):
   if not cond: raise KernelOptError(msg)
 
-@dataclass(frozen=True, order=True)
-class Opt:
-  op: OptOps
-  axis: Optional[int] = None
-  amt: Optional[int] = None
-  def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})"
-  def real_axis(self, k:Kernel):
-    if self.axis is None: return -1
-    if self.op is OptOps.UNROLL: return k.first_reduce+self.axis
-    if self.op in {OptOps.GROUP, OptOps.GROUPTOP}: return k.first_reduce+k.group_for_reduces+self.axis
-    return self.axis
-
 @dataclass
 class TensorCoreOptions:
-  axes:
-  axes_exist:
-  axis_pads:
-  def fix_axes(self, removed_axis:int): # adjust the TC axes if
+  axes: tuple[int, ...] # the location of the original N and M axes if still in the shape
+  axes_exist: tuple[bool, ...] # true if the original N and M axes are still in the shape
+  axis_pads: tuple[tuple[int, int], ...]
+  def fix_axes(self, removed_axis:int): # adjust the TC axes if necessary when a dimension is removed
     axes, axes_exist = list(self.axes), list(self.axes_exist)
     for tc_dim in [i for i in range(2) if axes_exist[i]]:
       if removed_axis < axes[tc_dim]: axes[tc_dim] -= 1
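The hunk above deletes the local `OptOps` enum and `Opt` dataclass from kernel.py and instead imports them (together with `ProgramSpec`) from `tinygrad.renderer`. A minimal sketch of adapting a call site to the new location (the try/except fallback is hypothetical, written only to illustrate the move; the positional `Opt(op, axis, value)` construction matches the usage visible later in this diff):

```python
# Sketch: Opt/OptOps now live in tinygrad.renderer (see the import hunk above).
try:
    from tinygrad.renderer import Opt, OptOps        # 0.10.2 location
except ImportError:
    from tinygrad.codegen.kernel import Opt, OptOps  # 0.10.0 location

upcast_first_axis = Opt(OptOps.UPCAST, 0, 4)  # simple opts are still built positionally
```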
@@ -57,32 +40,28 @@ class Kernel:
     if ast.op is Ops.SINK: self.ast = ast
 
     self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
-
-
-      print("INVALID AST")
-      print(self.ast)
-      raise e
+    # verify AST matches the spec
+    if __debug__: type_verify(list(self.ast.toposort), shape_spec)
 
-
-    def ordered_parents(op:UOp) -> List[UOp]: return dedup([item for x in op.src for item in ordered_parents(x)] + [op])
-    self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is Ops.REDUCE_AXIS])
+    self.reduceops = [x for x in self.ast.toposort if x.op is Ops.REDUCE_AXIS]
 
-    self.vars:
-
+    self.vars: list[Variable] = self.ast.variables()
+    # NOTE: this requires a specific order with the [::-1], this is likely a bug
+    self.bufs: list[UOp] = [x for x in self.ast.toposort if x.op in GroupOp.Buffer][::-1]
 
     # get earlybufs, before any reduceops
-    earlybufs:
+    earlybufs: list[UOp] = [x for reduceop in self.reduceops for x in reduceop.src[0].toposort if x.op in GroupOp.Buffer]
     self.full_buf_index: int = self.bufs.index(earlybufs[0]) if earlybufs else 0
     # NOTE: full_shape can be wrong if there's a tree of reduces
 
     # create new shapetrackers inside this kernel, we will permute them
-    self.sts:
+    self.sts: list[ShapeTracker] = [x.st_arg for x in self.bufs]
 
     # add the shapetrackers for each reduce
     # we use this to track which axes are reduced in each reduce
     for x in self.reduceops:
-      self.sts.append(
-      self.sts.append(
+      self.sts.append(unwrap(x.st))
+      self.sts.append(unwrap(x.src[0].st))
 
     # move all reduce axes to the end
     reduce = list(enumerate(zip(self.full_shape, self.output_shape)))
@@ -90,15 +69,13 @@ class Kernel:
     self.reshape_and_permute(None, permute)
 
     # parameters for optimization
-    self.applied_opts:
+    self.applied_opts: list[Opt] = []
     self.group_for_reduces: int = 0
     self.upcasted: int = 0
     self.local_dims: int = 0
     self.tensor_core: Optional[TensorCore] = None
     self.tensor_core_opts: Optional[TensorCoreOptions] = None
     self.use_tensor_cores: int = 0
-    # the local aliased buffers for A and B
-    self.bufs_for_tensor_core: Dict[UOp, Tuple[int, int]] = {}
     self.dont_use_locals: bool = False
 
     # group simplifies
@@ -112,25 +89,23 @@ class Kernel:
     ret.opts, ret.ast = self.opts, self.ast
 
     # things downstream of the AST
-    ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index =
-      self.reduceops, self.vars, self.bufs, self.full_buf_index
+    ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = self.reduceops, self.vars, self.bufs, self.full_buf_index
     ret.sts = self.sts[:len(ret.bufs)+len(ret.reduceops)*2] # NOTE: must redo the local buffers with TC in beam
 
     # parameters for optimizations
     ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
       self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
-    ret.tensor_core, ret.tensor_core_opts, ret.
-      self.tensor_core, self.tensor_core_opts, self.bufs_for_tensor_core, self.use_tensor_cores
+    ret.tensor_core, ret.tensor_core_opts, ret.use_tensor_cores = self.tensor_core, self.tensor_core_opts, self.use_tensor_cores
 
     return ret
 
   @property
-  def membufs(self) ->
+  def membufs(self) -> list[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])
 
   # TODO: these need more tests or it might silently be no-op
   def float4_axis(self, i:int): return [x-self.first_upcast for x in self.sts[i].unit_stride_axes() if x >= self.first_upcast and self.sts[i].shape[x]%4 == 0] # noqa: E501
 
-  def upcasted_axis(self, i:int) ->
+  def upcasted_axis(self, i:int) -> list[tuple[int, Optional[sint], bool]]:
     upcasted_shape, upcasted_stride = self.sts[i].shape[self.first_upcast:], self.sts[i].real_strides()[self.first_upcast:]
     assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
     return list(zip(upcasted_shape, upcasted_stride,
@@ -144,24 +119,20 @@ class Kernel:
   def first_upcast(self) -> int: return self.shape_len-self.upcasted
 
   @property
-  def reduceop(self) ->
+  def reduceop(self) -> UOp|None: return self.reduceops[0] if len(self.reduceops) > 0 else None
 
   @property
-  def output_shape(self) ->
+  def output_shape(self) -> tuple[sint, ...]: return self.sts[0].shape
 
   @property
-  def full_shape(self) ->
+  def full_shape(self) -> tuple[sint, ...]: return self.sts[self.full_buf_index].shape
 
   @property
-  def full_unupcasted_shape(self) ->
+  def full_unupcasted_shape(self) -> tuple[sint, ...]: return self.full_shape[:self.first_upcast]
 
   @property
   def shape_len(self) -> int: return len(self.sts[0].shape)
 
-  @property
-  def upcast_in_mid_reduce_axes(self) -> List[int]:
-    return [j for j in range(self.first_reduce, self.first_reduce+self.group_for_reduces) if self.full_shape[j] == self.sts[0].shape[j]]
-
   @property
   def global_dims(self) -> int: return self.first_reduce-self.local_dims
 
@@ -170,18 +141,17 @@ class Kernel:
   # cyan -- local dims (warp ones first)
   # *** self.first_reduce
   # green -- reduce-local dims
-  # white -- reduce-late upcasted dim (self.upcast_in_mid_reduce_axes)
   # red -- reduce loops
   # *** self.upcasted
   # purple -- reduce upcasted
   # yellow -- normal upcasted dimensions
-  def colors(self) ->
+  def colors(self) -> list[str]:
     # first non local non reduce dims are global (blue)
     colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims
     # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
     colors += ["cyan"] * self.local_dims
-    # between first_reduce and first_reduce + group_for_reduces, they are
-    colors += ["
+    # between first_reduce and first_reduce + group_for_reduces, they are late upcasted (green)
+    colors += ["green"] * self.group_for_reduces
     # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
     colors += ["red"] * (self.first_upcast - (self.first_reduce + self.group_for_reduces))
     # upcasted dimensions are reduce (magenta) or normal (yellow)
@@ -198,7 +168,7 @@ class Kernel:
   # ******************** base simplifiers ********************
 
   # apply reshape and permute to all shapetrackers
-  def reshape_and_permute(self, new_shape_fxn:Optional[Callable[[
+  def reshape_and_permute(self, new_shape_fxn:Optional[Callable[[tuple[sint, ...]], Sequence[sint]]], axis:Optional[Sequence[int]]):
     def reshape(st:ShapeTracker): return st.reshape(tuple(new_shape_fxn(st.shape))) if new_shape_fxn is not None else st
     def permute(st:ShapeTracker): return st.permute(tuple(axis)) if axis is not None else st
     self.sts = [permute(reshape(st)) for st in self.sts]
@@ -240,7 +210,7 @@ class Kernel:
     if isinstance(self.membufs[0].dtype, ImageDType):
       base_shape = self.membufs[0].dtype.shape
       if shape_idx_groups := get_contraction(self.output_shape, base_shape):
-        special_strides:
+        special_strides: tuple[sint, ...] = tuple()
         for i,g in enumerate(shape_idx_groups):
           shape_piece = tuple(self.output_shape[x] for x in g)
          assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
@@ -298,37 +268,34 @@ class Kernel:
     s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
     axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if resolve(self.full_shape[x]%tc.dims[i] != 0))
     if axis_pads and (opt_level < 2): return None
-    self.bufs_for_tensor_core[reduceop] = (buf0, buf1)
     if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
     return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)
 
-  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
+  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> bool:
     if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD:
-
+      tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]]
+      for tc in tensor_cores:
         tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
         # can only fuse reduces with the same tc options
         assert all_same(tensor_core_opts)
         if tensor_core_opts[0] is None: continue
-        # tensor core -- unroll the reduce dim, upcast input and local the correct thread pattern
         self.tensor_core_opts = tc_opts = tensor_core_opts[0]
 
         # attempt to pad the tensor axes that require it
         try:
           for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
         except KernelOptError: continue
-
-        for
-
-          for tc_dim, amt in tc.early_upcast_axes: self.apply_opt(Opt(OptOps.UPCAST,tc_opts.axes[tc_dim],amt), append_opt=False)
-        elif opt == "LC":
-          for tc_dim, amt in tc.threads: self.apply_opt(Opt(OptOps.LOCAL,tc_opts.axes[tc_dim],amt), append_opt=False)
+        # tensor core -- unroll the reduce dim (K), upcast and local the inner and outer dims (N, M)
+        for dim, amt in tc.get_reduce_axes(): self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, amt), append_opt=False)
+        for opt in tc.opts: self.apply_opt(Opt({"u":OptOps.UPCAST, "l":OptOps.LOCAL}[opt[0]], tc_opts.axes[int(opt[1])], 2), append_opt=False)
         self.tensor_core = tc
        self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA
        return True
     return False
 
-  def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[
-
+  def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[list[Opt]]=None, axis:int=0, tc_select:Optional[int]=None,
+                         tc_opt:Optional[int]=None) -> bool:
+    """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
     Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
 
     Keyword arguments:
@@ -337,21 +304,25 @@ class Kernel:
       1: enable tensor cores
       2: apply tensor core shape but don't use UOp.WMMA
     extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
+    tc_select -- specifies which tensor core(s) to use for optimization (default -1)
+      -1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes)
+      [0-N]: uses only the n'th tensor core available; useful for search
     tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
-      0: applies to only kernels with a single reduce axis and direct
-      1: allows kernels with multiple reduce axes and also multiplication of
+      0: applies to only kernels with a single reduce axis and direct Ops.LOAD into Ops.MUL
+      1: allows kernels with multiple reduce axes and also multiplication of Ops.CAST'd buffers
      2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
     """
+    if tc_select is None: tc_select = TC_SELECT.value
     if tc_opt is None: tc_opt = TC_OPT.value
     if not self.opts.tensor_cores and use_tensor_cores != 2: return False
     try: # check TC first and apply hand-coded opts if successful
-      self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
+      self.apply_opt(Opt(OptOps.TC, axis, (tc_select, tc_opt)))
 
       if (tc_opts:=self.tensor_core_opts) is not None:
         if extra_opts is not None:
           for opt in extra_opts: self.apply_opt(opt)
         else:
-          if
+          if AMX: return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
           # hand-coded TC opts
           for tc_dim in [tc_dim for tc_dim in [1,0] if tc_opts.axes_exist[tc_dim]]: # attempt to upcast M and N
             szs = [sz for sz in [5,4,3,2] if self.full_shape[tc_opts.axes[tc_dim]] % sz == 0]
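Per the updated signature and docstring above, `apply_tensor_cores` now takes a `tc_select` argument (defaulting to the `TC_SELECT` context variable) alongside `tc_opt`. A short sketch of a call site, assuming `k` is a `Kernel` built from a matmul-like reduction and that a matching tensor core exists on the device:

```python
# Sketch of the 0.10.2 apply_tensor_cores call (k is an assumed Kernel instance).
applied = k.apply_tensor_cores(
    use_tensor_cores=1,  # 1: enable tensor cores, 2: apply the shape opts without WMMA
    tc_select=-1,        # -1: try every available tensor core in order
    tc_opt=2,            # 2: allow padding M/N/K up to the tensor core dims
)
if not applied:
    k.hand_coded_optimizations()  # fall back to the generic heuristics
```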
@@ -363,25 +334,35 @@ class Kernel:
     except KernelOptError:
       return False
 
+  def real_axis(self, opt:Opt):
+    if opt.axis is None: return -1
+    if opt.op is OptOps.UNROLL: return self.first_reduce+opt.axis
+    if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.first_reduce+self.group_for_reduces+opt.axis
+    return opt.axis
+
   def apply_opt(self, opt:Opt, append_opt:bool=True):
-    if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP
+    if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}, "not using locals")
 
     if opt.op is OptOps.TC:
       check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
-      check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt")
       check((use_tensor_cores:=USE_TC.value) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
-      check(
+      check(opt.axis is not None, "tensor core opts must have an axis")
+      check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 2, "tensor core opts must have tc_select and tc_opt")
+      check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
+      check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
+      check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt), "no tensor core available")
       self.applied_opts.append(opt)
       return
 
-    axis =
+    axis = self.real_axis(opt)
     check(axis < len(self.full_shape), "invalid axis")
 
-    if opt.op is OptOps.SWAP: amt = cast(int, opt.
-    elif opt.
-
-
-
+    if opt.op is OptOps.SWAP: amt = cast(int, opt.arg) # arg is an axis in the SWAPs
+    elif opt.arg is not None:
+      check(isinstance(opt.arg, int), "arg should be int")
+      amt = arg if (arg:=cast(int, opt.arg)) != 0 else self.full_shape[axis]
+      check(isinstance(amt, int) and amt != 1, f"shift/padto of {amt=}, 1 or symbolic amount is meaningless")
+      if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, f"no longer valid shift {self.full_shape[axis]=}, {amt=}")
     else: amt = -1
 
     if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
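Two related changes appear in this hunk: `real_axis` is now a method on `Kernel` that takes the whole `Opt`, and the third field of `Opt` is read as `opt.arg`; for `OptOps.TC` it must be a `(tc_select, tc_opt)` tuple, while for shift-style opts an integer `0` still means "use the full axis size". A small illustration, assuming the positional `Opt(op, axis, arg)` layout implied by the checks above:

```python
# Sketch of Opt arguments as validated by the new apply_opt/real_axis code above.
from tinygrad.renderer import Opt, OptOps

opts = [
    Opt(OptOps.TC, 0, (-1, 2)),   # arg is now a (tc_select, tc_opt) tuple
    Opt(OptOps.UNROLL, 0, 4),     # axis 0 here maps to first_reduce + 0 (see real_axis)
    Opt(OptOps.UPCAST, 1, 0),     # arg == 0 expands to the full size of that axis
]
```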
@@ -393,6 +374,8 @@ class Kernel:
       check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
 
     if opt.op is OptOps.LOCAL: # cyan
+      # NOTE: LLVM/CPU can use locals too, but they are treated the same as globals (still helpful for L1 cache)
+      # it's disabled for now since it makes BEAM slow for little gain
       check(self.opts.has_local, "target does not support local")
       check(axis < self.global_dims, "local is for globals")
       self.shift_to(axis, amt, insert_before=self.first_reduce)
@@ -416,18 +399,10 @@ class Kernel:
       self.upcast()
     elif opt.op is OptOps.UPCAST: # yellow
       check(axis < self.first_reduce, "upcast is for non-reduce")
-      check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.
-      check(amt <= 16, "don't upcast more than 16")
+      check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.get_local_axes())), "can't upcast TC locals")
+      check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16")
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
-    elif opt.op is OptOps.UPCASTMID: # white
-      check(self.bufs[0].src[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
-      axes = self.sts[0].unit_stride_axes()
-      check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}")
-      check(axes[0] == axis, "wrong axis")
-      check(amt == 4, "don't upcast mid anything but 4")
-      self.shift_to(axis, amt, insert_before=self.first_reduce + self.group_for_reduces)
-      self.group_for_reduces += 1
     elif opt.op is OptOps.NOLOCALS:
       check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
       check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
@@ -441,7 +416,7 @@ class Kernel:
       check(not self.vars, "does not work with symbolic shape")
       check(axis < self.first_upcast, "cannot pad upcasted")
       # ok to pad SUM if all parent ALU ops have f(0) = 0
-      if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r), f"cannot pad {r}")
+      if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r, {}, cache={}), f"cannot pad {r}")
       padded = False
       for i,st in enumerate(self.sts):
         if (s:=st.shape[axis]) == 1: continue # reduced
@@ -460,8 +435,7 @@ class Kernel:
     if isinstance(self.membufs[0].dtype, ImageDType):
       unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
       assert unit_stride_axes_mul_4, f"needs a unit stride axis in {self.bufs[0]}"
-      if all(x < self.first_upcast for x in unit_stride_axes_mul_4)
-        self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
+      if all(x < self.first_upcast for x in unit_stride_axes_mul_4): self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
     return self
 
   def hand_coded_optimizations(self) -> Kernel:
@@ -496,19 +470,12 @@ class Kernel:
           break
         except KernelOptError: pass
 
-    # are we upcasting in mid reduce? (only for images)
-    if self.bufs[0].src[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
-      axes = self.sts[0].unit_stride_axes()
-      assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
-      if self.sts[0].shape[axes[0]]%4 == 0:
-        self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4))
-
     # upcast float4 images
     for buf_index,buf in enumerate(self.bufs):
       unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
       if buf.src[0].dtype.__class__ is ImageDType:
         #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
-        if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4)
+        if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4):
          if unit_stride_axes_mul_4[0] < self.first_reduce:
            self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
          else:
@@ -524,7 +491,7 @@ class Kernel:
     # expression and run test/test_ops.py with IMAGE=2
     # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
     # this can be made much smarter
-    to_upcast:
+    to_upcast: list[int] = []
     # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
     for axis in range(self.first_reduce):
       # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
@@ -536,7 +503,7 @@ class Kernel:
     for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
 
     # potentially do more upcasts of non reduce axes based on a heuristic
-    upcasted_axis = set()
+    upcasted_axis: set[int] = set()
     while resolve(prod(self.sts[0].shape[:self.first_reduce]) >= 1024):
       xb_choices = []
       for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
@@ -577,7 +544,7 @@ class Kernel:
      else:
        # prioritize making expand axes local
        local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))] # noqa: E501
-       to_local:
+       to_local: list[tuple[int, int]] = []
        for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
          local_size = prod(sz for _, sz in to_local)
          local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None) # noqa: E501
@@ -593,11 +560,11 @@ class Kernel:
 
  # **** kernel outputs ****
 
-  kernel_cnt: Final[
+  kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
  @functools.cached_property
  def name(self) -> str:
    # kernel name (before late upcast)
-    kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op in GroupOp.Buffer for x in self.ast.
+    kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op is Ops.SINK or x.op in GroupOp.Buffer for x in self.ast.toposort) else "E")
    suffix = colored('_', 'BLACK').join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())])
    name = kernel_type + (f"{len(self.ast.src)}" if len(self.ast.src) > 1 else "") + "_" + suffix
 
@@ -606,14 +573,19 @@ class Kernel:
    num = f"n{Kernel.kernel_cnt[function_name]-1}" if Kernel.kernel_cnt[function_name] > 1 else ""
    return name + colored(num, 'BLACK')
 
-  def get_optimized_ast(self) -> UOp:
+  def get_optimized_ast(self, name_override:Optional[str]=None) -> UOp:
    @functools.lru_cache(None)
    def fixup_ast(op:UOp) -> UOp:
      ret = op.replace(src=tuple(fixup_ast(x) for x in op.src))
      if op.op in GroupOp.Buffer and op in self.bufs:
        st_uop = self.sts[self.bufs.index(op)].to_uop()
-
-
+        # NOTE: if CONST got masked after applying opts, we create a new VALID
+        if op.op is Ops.CONST and any(v.mask is not None for v in unwrap(st_uop.st).views): return op.valid(unwrap(st_uop.st))
+        # otherwise we just replace the VIEW source
+        return ret.replace(src=(st_uop,)) if len(op.src) == 1 else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
+      if op.op is Ops.SINK:
+        return ret.replace(arg = KernelInfo(to_function_name(self.name) if name_override is None else name_override,
+                                            self.local_dims, self.upcasted, self.dont_use_locals))
      if op.op is Ops.REDUCE_AXIS:
        reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
 
@@ -623,47 +595,43 @@ class Kernel:
        grouped_axes = reduced_axes(self.first_reduce, self.first_reduce + self.group_for_reduces)
 
        if (tc := self.tensor_core) and (self.use_tensor_cores == 1 or self.use_tensor_cores == 3):
-
-
-
-
-
-
-
-
-
-
-            [y + (wd if x == 0 else tcd) for x,y in tcd_pattern] + list(range(tcd+len(tc.expanded_shape),len(new_shape)))
-            return st.reshape(new_shape).permute(tuple(permaxis)).reshape(st.shape).simplify()
+          wd, tcd = self.global_dims, self.first_upcast
+          def get_upcast_axes(buf): # upcast along non-zero dimensions of (tc_reduce + tc_upcast)
+            upcast_axes = int(math.log2(tc.elements_per_thread[buf]))
+            return tuple((tcd + len(tc.get_reduce_axes()) + len(tc.get_upcast_axes()) - (i+1), 2) for i in range(upcast_axes))
+          def get_tc_swizzle_st(shape, local_perm, upcast_perm):
+            offset = (tcd - (wd + len(local_perm)))
+            permaxis = list(range(wd)) \
+              + [wd + x + (offset if x >= len(local_perm) else 0) for x in local_perm] + list(range(wd + len(local_perm), tcd)) \
+              + [wd + x + (offset if x >= len(local_perm) else 0) for x in upcast_perm] + list(range(tcd + len(upcast_perm), len(shape)))
+            return ShapeTracker.from_shape(shape).permute(tuple(permaxis))
 
          srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
-          for i,
-            if
+          for i, (src, swizzle) in enumerate(zip(srcs, tc.swizzle)):
+            if swizzle: srcs[i] = src.view(get_tc_swizzle_st((src if src.op is Ops.LOAD else src.src[0]).st_arg.shape, *swizzle))
 
          if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals
            local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape))
            st = store_st = ShapeTracker.from_shape(local_shape)
-            local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (),
-            if
+            local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), f"temp{i}")
+            if swizzle: store_st = get_tc_swizzle_st(store_st.shape, *swizzle)
            local_store = UOp.store(local_buffer, store_st.to_uop(), srcs[i])
            srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer, st.to_uop(), local_store))
 
-          tc_reduce_axes = tuple(
-          if self.use_tensor_cores == 1: # real WMMA, use CONTRACT/
-
-            wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device,
-
-
-              UOp(Ops.CONTRACT, dtype=srcs[
-              UOp(
-
-            tc_uop = UOp(Ops.EXPAND, tc.dtype_out, (wmma,), arg=upcast_axes[2])
+          tc_reduce_axes = tuple(tcd + ax for ax, _ in tc.get_reduce_axes())
+          if self.use_tensor_cores == 1: # real WMMA, use CONTRACT/UNROLL to get the vectorization right
+            tc_upcast_axes = (get_upcast_axes(0), get_upcast_axes(1), get_upcast_axes(2))
+            wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, tc_reduce_axes)
+            wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
+              UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]),
+              UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1]),
+              UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg)
+            tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2])
 
          else: # for TC=3 MUL/SUM instead of WMMA
            tc_uop = UOp(Ops.REDUCE_AXIS, tc.dtype_out, ((srcs[0] * srcs[1]).cast(tc.dtype_out),), (Ops.ADD, tc_reduce_axes))
 
-
-          return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_reduce_axes)) if new_reduce_axes else tc_uop
+          return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_axes)) if (new_axes := tuple(i for i in axes if i not in tc_reduce_axes)) else tc_uop
 
        ret = ret.replace(arg = (op.arg[0], axes))
        if self.group_for_reduces and grouped_axes:
@@ -672,7 +640,8 @@ class Kernel:
            for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
            (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
          st_uop = ShapeTracker.from_shape(local_shape).to_uop()
-
+          local_size = st_uop.arg.real_size()
+          local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
          local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, ret)))
          grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
          if op is self.reduceops[-1]: return grouped_reduce
@@ -681,73 +650,44 @@ class Kernel:
 
      return ret
 
-    return graph_rewrite(fixup_ast(self.ast),
-      (UPat({*GroupOp.ALU,Ops.CAST,Ops.BITCAST,Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
-      (UPat(Ops.LOAD, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src)))]))
+    return graph_rewrite(fixup_ast(self.ast), view_left)
 
  # **** this is the lowerer ****
 
  @track_rewrites()
-  def linearize(self) -> Kernel:
-
+  def linearize(self, name_override:Optional[str]=None) -> Kernel:
+    # display the AST
+    if getenv("VIZ"): graph_rewrite(self.ast, PatternMatcher([]), name="View Base AST")
+
+    modified_ast = self.get_optimized_ast(name_override)
 
    if DEBUG >= 3:
      print(self.name)
      if getenv("RAWAST"): print(self.ast)
-
+      for i,(buf,st) in enumerate([(buf,st) for buf,st in zip(self.bufs, self.sts) if buf.op not in {Ops.CONST, Ops.VALID}]):
+        print(f"{i:2d}: {str(st.shape):25s} {str(buf.src[0].dtype).replace('dtypes.',''):20s}", st.real_strides())
      print(self.applied_opts)
-
+    # verify AST matches the spec after applying opts
+    if __debug__: type_verify(list(modified_ast.toposort))
+    # TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this
+    #if __debug__: type_verify(list(modified_ast.toposort), shape_spec)
 
-    self.uops:
+    self.uops:list[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
    if DEBUG >= 5: print_uops(self.uops)
    return self
 
-  def to_program(self, name_override:Optional[str]=None) ->
-    self.linearize()
-
+  def to_program(self, name_override:Optional[str]=None) -> ProgramSpec:
+    self.linearize(name_override)
+    assert self.uops[0].op is Ops.NAME, "first uop must be name"
+    src = self.opts.render(self.uops)
 
-    if
-
-      diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, name, *get_process_replay_ctx(), src))
+    if CAPTURE_PROCESS_REPLAY:
+      diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, self.uops[0].arg, ContextVar._cache, src))
 
    # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
    # TODO: these max and min don't work on symbolic, and results are very wrong.
    mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
-      for _, group in itertools.groupby([x for x in self.ast.
+      for _, group in itertools.groupby([x for x in self.ast.toposort if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
        key=lambda x: (x.op, x.src[0].arg)))
-    return
-
-
-# the living definition of intermediate UOps
-
-def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
-  if not uop.has_st or uop in sts: return
-  # restore globals from the two stage reduce
-  if uop.op is Ops.LOAD and uop.src[0].op is Ops.DEFINE_LOCAL:
-    _assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts)
-    sts[uop] = sts[local_reduce]
-    return
-  for x in uop.src: _assert_valid_uop(x, st, sts)
-  # only reduceuop is allowed to change shape, limited to turning n to 1
-  if uop.op in {Ops.REDUCE_AXIS, Ops.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
-  # movementops are pushed to VIEW
-  elif uop.op is Ops.VIEW:
-    assert len(uop.src) == 0, f"can't swizzle in kernel yet {uop}"
-    st = uop.arg
-  # everything else inherits shape
-  else:
-    st = (src_sts:=[sts[x] for x in uop.src if x.has_st])[0]
-    if not all_same(shapes:=[x.shape for x in src_sts]):
-      if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}")
-      raise AssertionError(f"found implicit expand {sizes} {shapes}")
-  sts[uop] = st
-
-def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
-  assert ast.op is Ops.SINK and all(x.op is Ops.STORE for x in ast.src), "must be SINK"
-  assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size"
-  sts: Dict[UOp, ShapeTracker] = {}
-  for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
-  shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
-  assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
-  type_verify(list(sts))
-  return sts
+    return ProgramSpec(self.name if not name_override else name_override, src, self.opts.device, self.ast, self.uops, self.applied_opts, mem_bytes,
+                       global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)