tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,510 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
import itertools, functools, math
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from collections import defaultdict
|
5
|
+
from typing import cast, Final, Callable, Sequence
|
6
|
+
from enum import Enum, auto
|
7
|
+
|
8
|
+
from tinygrad.uop.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, AxisType
|
9
|
+
from tinygrad.uop.spec import type_verify, ast_spec
|
10
|
+
from tinygrad.device import Device
|
11
|
+
from tinygrad.codegen.opt.tc import TensorCore
|
12
|
+
from tinygrad.renderer import Renderer
|
13
|
+
from tinygrad.dtype import ImageDType, AddrSpace
|
14
|
+
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
|
15
|
+
from tinygrad.shape.shapetracker import ShapeTracker
|
16
|
+
from tinygrad.shape.view import strides_for_shape, get_contraction
|
17
|
+
from tinygrad.codegen.opt.swizzler import view_left, view_left_through_load
|
18
|
+
|
19
|
+
class OptOps(Enum):
  """The kinds of optimization an `Opt` can name.

  Members receive consecutive `auto()` values in declaration order; `__lt__` compares those
  values so that members can be ordered.
  """
  TC = auto()
  UPCAST = auto()
  UNROLL = auto()
  LOCAL = auto()
  GROUP = auto()
  GROUPTOP = auto()
  NOLOCALS = auto()
  PADTO = auto()
  SWAP = auto()

  def __lt__(self, x:OptOps):
    # order by the auto-assigned integer value
    return self.value < x.value
|
23
|
+
|
24
|
+
@dataclass(frozen=True, order=True)
class Opt:
  """One optimization step: the operation, the axis it targets and an op-specific argument.

  Instances are immutable and ordered field by field (`op`, then `axis`, then `arg`).
  """
  op: OptOps
  axis: int|None = None
  arg: int|tuple|None = None

  def __repr__(self):
    fields = f"op={self.op}, axis={self.axis}, arg={self.arg}"
    return f"Opt({fields})"
|
30
|
+
|
31
|
+
# one-character tag per axis type
axis_letters = {
  AxisType.GLOBAL: "g",
  AxisType.LOCAL: "l",
  AxisType.LOOP: "L",
  AxisType.UPCAST: "u",
  AxisType.GROUP_REDUCE: "G",
  AxisType.REDUCE: "R",
  AxisType.UNROLL: "r",
}
# color name per axis type (looked up by Kernel.colors)
axis_colors = {
  AxisType.GLOBAL: "blue",
  AxisType.LOCAL: "cyan",
  AxisType.LOOP: "WHITE",
  AxisType.UPCAST: "yellow",
  AxisType.GROUP_REDUCE: "green",
  AxisType.REDUCE: "red",
  AxisType.UNROLL: "magenta",
}
|
35
|
+
|
36
|
+
class KernelOptError(Exception):
  """Raised when an optimization cannot be applied to a kernel."""
|
37
|
+
def check(cond:bool, msg:str=""):
  """Raise `KernelOptError(msg)` unless `cond` is truthy."""
  if cond: return
  raise KernelOptError(msg)
|
39
|
+
|
40
|
+
@dataclass
class TensorCoreOptions:
  """Where the tensor-core axes currently live in the kernel shape, plus any padding they need."""
  axes: tuple[int, ...] # the location of the original N and M axes if still in the shape
  axes_exist: tuple[bool, ...] # true if the original N and M axes are still in the shape
  axis_pads: tuple[tuple[int, int], ...]

  def fix_axes(self, removed_axis:int):
    """Adjust the TC axes, in place, after the dimension `removed_axis` was removed from the shape.

    Only the first two entries (N and M) are tracked: an entry after the removed axis moves down
    by one, and an entry equal to the removed axis is marked as no longer existing.
    """
    new_axes, still_there = list(self.axes), list(self.axes_exist)
    for tc_dim in range(2):
      # entries that were already gone before this call are left untouched
      if not self.axes_exist[tc_dim]: continue
      if removed_axis < new_axes[tc_dim]: new_axes[tc_dim] -= 1
      elif removed_axis == new_axes[tc_dim]: still_there[tc_dim] = False
    self.axes = tuple(new_axes)
    self.axes_exist = tuple(still_there)
|
51
|
+
|
52
|
+
class Kernel:
|
53
|
+
def __init__(self, ast:UOp, opts:Renderer|None=None):
  """Set up the optimization state for the kernel described by `ast`.

  `ast` must be a SINK UOp. `opts` is the target renderer; when it is None the renderer of the
  default device is used. Raises RuntimeError if the reduce axes are not all at the end of the shape.
  """
  assert ast.op is Ops.SINK, ast.op
  self.ast = ast

  if opts is None: opts = Device[Device.DEFAULT].renderer
  self.opts = opts
  # verify AST matches the spec
  if __debug__: type_verify(list(self.ast.toposort()), ast_spec)

  self.vars: list[Variable] = self.ast.variables()
  # NOTE: this requires a specific order with the [::-1], this is likely a bug
  self.bufs: list[UOp] = [u for u in self.ast.toposort() if u.op in GroupOp.Buffer][::-1]

  # create new shapetrackers inside this kernel, we will permute them
  self.sts: list[ShapeTracker] = [buf.st_arg for buf in self.bufs]

  # add the shapetrackers for each reduce
  # we use this to track which axes are reduced in each reduce
  self.reduceops = [u for u in self.ast.toposort() if u.op is Ops.REDUCE_AXIS]
  for red in self.reduceops:
    self.sts.append(unwrap(red.st))
    self.sts.append(unwrap(red.src[0].st))

  # add a shapetracker to the end to track the full shape, with 0 strides so it can merge
  full_shape = ast.full_shape
  zero_strides = (0,)*len(full_shape)
  self.sts.append(ShapeTracker.from_shape(full_shape, zero_strides))

  # parameters for optimization
  self.tensor_core: TensorCore|None = None
  self.tensor_core_opts: TensorCoreOptions|None = None
  self.use_tensor_cores: int = 0
  self.applied_opts: list[Opt] = []
  self.dont_use_locals = False
  self.finalized: bool = False

  # group simplifies (these must run before axis_types exists: simplify_merge_adjacent asserts on it)
  self.simplify_ones()
  self.simplify_merge_adjacent()

  # axis types: an axis whose output and full sizes differ is a reduce, everything else is global (or a plain loop)
  global_loops = AxisType.GLOBAL if self.opts.has_local else AxisType.LOOP
  self.axis_types: list[AxisType] = [AxisType.REDUCE if resolve(out_sz!=full_sz) else global_loops
                                     for out_sz,full_sz in zip(self.output_shape, self.full_shape)]

  # confirm all reduce axes are at the end
  final_reduces = [t for t in self.axis_types if t == AxisType.REDUCE]
  if final_reduces and final_reduces != self.axis_types[-len(final_reduces):]:
    raise RuntimeError(f"reduces are not at the end of the shape {self.full_shape} -> {self.output_shape}")
|
98
|
+
|
99
|
+
def copy(self):
|
100
|
+
ret = type(self).__new__(type(self))
|
101
|
+
|
102
|
+
# base linearizer params
|
103
|
+
ret.opts, ret.ast = self.opts, self.ast
|
104
|
+
|
105
|
+
# things downstream of the AST
|
106
|
+
ret.reduceops, ret.vars, ret.bufs = self.reduceops, self.vars, self.bufs
|
107
|
+
ret.sts = self.sts[:]
|
108
|
+
ret.axis_types = self.axis_types[:]
|
109
|
+
|
110
|
+
# parameters for optimizations
|
111
|
+
ret.applied_opts, ret.dont_use_locals = self.applied_opts[:], self.dont_use_locals
|
112
|
+
ret.tensor_core, ret.tensor_core_opts, ret.use_tensor_cores = self.tensor_core, self.tensor_core_opts, self.use_tensor_cores
|
113
|
+
ret.finalized = self.finalized
|
114
|
+
|
115
|
+
return ret
|
116
|
+
|
117
|
+
@property
|
118
|
+
def reduceop(self) -> UOp|None: return self.reduceops[0] if len(self.reduceops) > 0 else None
|
119
|
+
@property
|
120
|
+
def full_shape(self) -> tuple[sint, ...]: return self.sts[-1].shape
|
121
|
+
|
122
|
+
@property
|
123
|
+
def output_shape(self) -> tuple[sint, ...]: return self.sts[0].shape
|
124
|
+
@property
|
125
|
+
def shape_len(self) -> int: return len(self.sts[0].shape)
|
126
|
+
|
127
|
+
def axes_of(self, *axis_type:AxisType) -> list[int]:
  """Return the indices, in ascending order, of the axes whose type is one of `axis_type`."""
  wanted = argfix(axis_type)
  return [idx for idx,kind in enumerate(self.axis_types) if kind in wanted]
|
128
|
+
@property
def upcasted(self) -> int:
  """Number of axes typed UPCAST or UNROLL."""
  upcast_axes = self.axes_of(AxisType.UPCAST, AxisType.UNROLL)
  return len(upcast_axes)
|
130
|
+
@property
def group_for_reduces(self) -> int:
  """Number of axes typed GROUP_REDUCE."""
  grouped_axes = self.axes_of(AxisType.GROUP_REDUCE)
  return len(grouped_axes)
|
132
|
+
|
133
|
+
# heuristic helpers
|
134
|
+
@property
def upcastable_dims(self) -> list[int]:
  """Indices of GLOBAL, LOCAL and LOOP axes whose full-shape size is a plain int greater than 1."""
  dims = []
  for i in self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP):
    sz = self.full_shape[i]
    # non-int (symbolic) sizes and size-1 axes are skipped
    if isinstance(sz, int) and sz > 1: dims.append(i)
  return dims
|
137
|
+
@property
def unrollable_dims(self) -> list[int]:
  """Indices of GROUP_REDUCE and REDUCE axes whose full-shape size is a plain int greater than 1."""
  dims = []
  for i in self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE):
    sz = self.full_shape[i]
    # non-int (symbolic) sizes and size-1 axes are skipped
    if isinstance(sz, int) and sz > 1: dims.append(i)
  return dims
|
140
|
+
|
141
|
+
# ******************** colors and names ********************
|
142
|
+
|
143
|
+
def colors(self) -> list[str]:
  """Return one color name per axis, taken from `axis_colors`.

  GLOBAL axes are reported as "BLUE" instead when `dont_use_locals` is set.
  """
  assert len(self.axis_types) == self.shape_len, "colors size mismatch"
  return ["BLUE" if self.dont_use_locals and x == AxisType.GLOBAL else axis_colors[x] for x in self.axis_types]
|
146
|
+
|
147
|
+
def colored_shape(self, pad:int|None=None, dense=False) -> str:
  """Return the full shape as a space-separated string with each dimension colored by its axis type.

  Int dimensions are formatted with width 4 unless `dense` is set; other dimensions use their `render()`.
  When `pad` is truthy, spaces are appended until the visible (ANSI-stripped) length reaches `pad`.
  """
  def fmt(s):
    if not isinstance(s, int): return s.render()
    return s if dense else f"{s:4d}"
  shape_strs = [fmt(s) for s in self.full_shape]
  ret = ' '.join([colored(s, color) for s,color in zip(shape_strs, self.colors())])
  if pad: ret += ' '*(pad-ansilen(ret))
  return ret
|
152
|
+
|
153
|
+
# class-wide count of how many times each function name has been generated
kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
@functools.cached_property
def name(self) -> str:
  """Build the colored kernel name; computed once per instance.

  The name is a type letter ("r" when there is a reduce, "C" when the AST holds only SINK and buffer
  ops, otherwise "E"), the number of SINK sources when there is more than one, then "_" and the
  colored full shape. Every computation bumps `Kernel.kernel_cnt` for the function name, and repeats
  get an "n<k>" suffix.
  """
  # kernel name (before late upcast)
  if self.reduceop is not None: kernel_type = "r"
  elif all(x.op is Ops.SINK or x.op in GroupOp.Buffer for x in self.ast.toposort()): kernel_type = "C"
  else: kernel_type = "E"
  sep = colored('_', 'BLACK')
  suffix = sep.join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())])
  n_srcs = len(self.ast.src)
  name = kernel_type + (f"{n_srcs}" if n_srcs > 1 else "") + "_" + suffix

  # name the function something unique
  function_name = to_function_name(name)
  Kernel.kernel_cnt[function_name] += 1
  seen = Kernel.kernel_cnt[function_name]
  num = f"n{seen-1}" if seen > 1 else ""
  return name + colored(num, 'BLACK')
|
165
|
+
|
166
|
+
# ******************** base simplifiers ********************
|
167
|
+
|
168
|
+
# apply reshape and permute to all shapetrackers
|
169
|
+
def reshape(self, new_shape_fxn:Callable[[tuple[sint, ...]], Sequence[sint]]):
|
170
|
+
self.sts = [st.reshape(tuple(new_shape_fxn(st.shape))) for st in self.sts]
|
171
|
+
def permute(self, new_axes:Sequence[int]): self.sts = [st.permute(tuple(new_axes)) for st in self.sts]
|
172
|
+
|
173
|
+
# axis : the axis to pull from
|
174
|
+
# amount : the amount to take
|
175
|
+
# top : if you want to pull that amount from the top
|
176
|
+
# insert_at : place to insert the new stuff
|
177
|
+
def shift_to(self, axis:int, amount:int, new_type:AxisType, top:bool=False, insert_at:int|None=None):
|
178
|
+
if insert_at is None: insert_at = self.shape_len
|
179
|
+
self.axis_types.insert(insert_at, new_type)
|
180
|
+
move_axis = axis if top else axis+1
|
181
|
+
if move_axis < insert_at: insert_at += 1
|
182
|
+
def new_shape_fxn(x): return x[0:axis] + (((amount,x[axis]//amount) if top else (x[axis]//amount,amount)) if x[axis] > 1 else (1,1)) + x[axis+1:]
|
183
|
+
new_axes = [i for i in range(insert_at) if i != move_axis]+[move_axis]+[i for i in range(insert_at, self.shape_len+1) if i != move_axis]
|
184
|
+
self.reshape(new_shape_fxn)
|
185
|
+
self.permute(new_axes)
|
186
|
+
|
187
|
+
# ******************** complex simplifiers ********************
|
188
|
+
|
189
|
+
def simplify_ones(self) -> bool:
|
190
|
+
# remove places where the shape is all ones
|
191
|
+
if any(all_ones:=[s==1 for s in self.full_shape]):
|
192
|
+
if hasattr(self, 'axis_types'):
|
193
|
+
self.axis_types = [x for i,x in enumerate(self.axis_types) if not all_ones[i]]
|
194
|
+
self.reshape(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]])
|
195
|
+
return True
|
196
|
+
return False
|
197
|
+
|
198
|
+
def simplify_merge_adjacent(self):
  """Merge adjacent axes that can be merged in every shapetracker; only valid during `__init__`.

  Axis `i` is folded into the axis before it when the merge test holds for all shapetrackers and
  `i` is not the first reduce axis. When the first memory buffer has an image dtype, an extra
  (shape, strides) entry is added so merges do not cross image axes.
  """
  assert not hasattr(self, 'axis_types'), "don't call this after init"
  if self.shape_len == 0: return
  shapes = [tracker.shape for tracker in self.sts]
  strides = [tracker.real_strides() for tracker in self.sts]
  # NOTE: we can't use self.first_reduce yet
  # the appended (0,)/(1,) sentinels guarantee a True entry, so index() cannot fail
  differs = [resolve(out_sz!=full_sz) for out_sz,full_sz in zip(self.output_shape+(0,), self.full_shape+(1,))]
  first_reduce = differs.index(True)

  # if it's an image, insert fake strides such that this fusion doesn't happen across image axes
  # TODO: remove membufs
  membufs = dedup([buf.src[0].base for buf in self.bufs if buf.op in {Ops.LOAD, Ops.STORE}])
  if isinstance(membufs[0].base.dtype, ImageDType):
    base_shape = membufs[0].base.dtype.shape
    shape_idx_groups = get_contraction(self.output_shape, base_shape)
    if shape_idx_groups:
      special_strides: tuple[sint, ...] = tuple()
      for grp_idx,group in enumerate(shape_idx_groups):
        shape_piece = tuple(self.output_shape[x] for x in group)
        assert prod(shape_piece) == base_shape[grp_idx], f"get_contraction was wrong? {shape_piece} != {base_shape[grp_idx]}"
        special_strides += strides_for_shape(shape_piece)
      # adding the fake image shape
      shapes.append(self.output_shape)
      strides.append(special_strides)

  # merge dimensions if we can, multi _merge_dims
  # NOTE: this does not always preserve the reduce dimension
  # TODO: move this into shapetracker, with tests!
  # TODO: how does this work with multi-reduce?
  # rets[j] is the list of (size, stride) pairs built so far for entry j
  rets = [[(s[0], st[0])] for s,st in zip(shapes, strides)]
  for i in range(1, len(shapes[0])):
    can_merge = []
    for s,st,ret in zip(shapes, strides, rets):
      # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
      si, sti, last_st = s[i], st[i], ret[-1][1]
      can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0)))
    # more can merge than this
    mergeable = all(can_merge) and i != first_reduce
    for merged,s,st in zip(rets, shapes, strides):
      if mergeable: merged[-1] = (merged[-1][0] * s[i], st[i])
      else: merged.append((s[i], st[i]))

  # do the reshapes (the fake image entry, if any, is past len(self.sts) and is dropped here)
  for idx,merged in enumerate(rets[:len(self.sts)]):
    self.sts[idx] = self.sts[idx].reshape(tuple([dim for dim,_ in merged]))
|
239
|
+
|
240
|
+
# ******************** apply optimizations ********************
|
241
|
+
|
242
|
+
def real_axis(self, op:OptOps, axis:int|None):
|
243
|
+
try:
|
244
|
+
if axis is None: return -1
|
245
|
+
if op is OptOps.UNROLL: return self.unrollable_dims[axis]
|
246
|
+
if op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[axis]
|
247
|
+
check(axis < self.shape_len, "invalid axis")
|
248
|
+
return axis
|
249
|
+
except IndexError as e: raise KernelOptError from e
|
250
|
+
|
251
|
+
def apply_opt(self, opt:Opt, append_opt:bool=True):
|
252
|
+
if self.finalized: raise RuntimeError("can't optimize Kernel after it's finalized")
|
253
|
+
if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}, "not using locals")
|
254
|
+
|
255
|
+
if opt.op is OptOps.TC:
|
256
|
+
check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
|
257
|
+
check(len(self.opts.tensor_cores) > 0, "must have tensor cores")
|
258
|
+
check(opt.axis is not None, "tensor core opts must have an axis")
|
259
|
+
check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 3, "tensor core opts must have valid arg")
|
260
|
+
check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
|
261
|
+
check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
|
262
|
+
check(0 < (use_tensor_cores:=cast(tuple, opt.arg)[2]) <= 2, "use_tensor_cores value is not valid")
|
263
|
+
check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt), "no tensor core available")
|
264
|
+
self.applied_opts.append(opt)
|
265
|
+
return
|
266
|
+
|
267
|
+
axis = self.real_axis(opt.op, opt.axis)
|
268
|
+
|
269
|
+
if opt.op is OptOps.SWAP: amt = self.real_axis(opt.op, cast(int, opt.arg)) # arg is an axis in the SWAPs
|
270
|
+
elif opt.arg is not None:
|
271
|
+
check(isinstance(opt.arg, int), "arg should be int")
|
272
|
+
amt = arg if (arg:=cast(int, opt.arg)) != 0 else self.full_shape[axis]
|
273
|
+
check(isinstance(amt, int) and amt != 1, f"shift/padto of {amt=}, 1 or symbolic amount is meaningless")
|
274
|
+
if opt.op is not OptOps.PADTO:
|
275
|
+
# we check both the full_shape and each shape
|
276
|
+
check(self.full_shape[axis] % amt == 0, f"no longer valid shift {self.full_shape[axis]=}, {amt=}")
|
277
|
+
for st in self.sts: check(st.shape[axis] == 1 or st.shape[axis] % amt == 0, f"no longer valid shift {st.shape[axis]=}, {amt=}")
|
278
|
+
else: amt = -1
|
279
|
+
|
280
|
+
if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
|
281
|
+
(self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
|
282
|
+
acc_sz = self.reduceop.dtype.itemsize
|
283
|
+
upcast_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.UPCAST)])
|
284
|
+
local_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.LOCAL)])
|
285
|
+
smem_sz = amt*acc_sz*upcast_sz*local_sz
|
286
|
+
check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
|
287
|
+
|
288
|
+
if opt.op is OptOps.LOCAL: # cyan
|
289
|
+
# NOTE: LLVM/CPU can use locals too, but they are treated the same as globals (still helpful for L1 cache)
|
290
|
+
# it's disabled for now since it makes BEAM slow for little gain
|
291
|
+
check(self.opts.has_local, "target does not support local")
|
292
|
+
check(self.axis_types[axis] is AxisType.GLOBAL, "local is for globals")
|
293
|
+
self.shift_to(axis, amt, AxisType.LOCAL, insert_at=max(self.axes_of(AxisType.GLOBAL, AxisType.LOCAL))+1)
|
294
|
+
elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green
|
295
|
+
check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
|
296
|
+
check(self.axis_types[axis] is AxisType.REDUCE, "must be reduce axis to group")
|
297
|
+
check(not self.tensor_core, "can't group with tensor cores")
|
298
|
+
check(len(reduce_axes:=[i for r in self.reduceops for i in r.axis_arg]) == len(set(reduce_axes)), "can't group with parallel reduces")
|
299
|
+
self.shift_to(axis, amt, AxisType.GROUP_REDUCE, top=(opt.op is OptOps.GROUPTOP), insert_at=min(self.axes_of(AxisType.REDUCE)))
|
300
|
+
elif opt.op is OptOps.UNROLL: # purple
|
301
|
+
check(self.axis_types[axis] not in (AxisType.UPCAST, AxisType.UNROLL), "can't upcasted already upcasted")
|
302
|
+
check(amt <= 32, "don't unroll more than 32")
|
303
|
+
self.shift_to(axis, amt, AxisType.UNROLL, insert_at=None)
|
304
|
+
elif opt.op is OptOps.UPCAST: # yellow
|
305
|
+
check(axis in self.upcastable_dims, f"{axis=} not in {self.upcastable_dims=}")
|
306
|
+
# NOTE: assume the first get_local_axes() LOCAL are for TC
|
307
|
+
check(not (self.tensor_core and axis in self.axes_of(AxisType.LOCAL)[:len(self.tensor_core.get_local_axes())]), "can't upcast TC locals")
|
308
|
+
check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16")
|
309
|
+
self.shift_to(axis, amt, AxisType.UPCAST, insert_at=max(self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP, AxisType.UPCAST))+1)
|
310
|
+
elif opt.op is OptOps.NOLOCALS:
|
311
|
+
check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
|
312
|
+
check(AxisType.LOCAL not in self.axis_types and self.group_for_reduces == 0, "can't have no locals with locals")
|
313
|
+
self.dont_use_locals = True
|
314
|
+
elif opt.op is OptOps.SWAP:
|
315
|
+
check(axis < amt, f"swap is only for axis < amt, getting {amt=}, {axis=}")
|
316
|
+
check(self.axis_types[axis]==self.axis_types[amt]==AxisType.GLOBAL, f"swap is for globals {self.axis_types[axis]=}, {self.axis_types[amt]=}")
|
317
|
+
permute = list(range(self.shape_len))
|
318
|
+
permute[axis], permute[amt] = permute[amt], permute[axis]
|
319
|
+
self.permute(tuple(permute))
|
320
|
+
elif opt.op is OptOps.PADTO:
|
321
|
+
check(not self.vars, "does not work with symbolic shape")
|
322
|
+
check(self.axis_types[axis] not in (AxisType.UPCAST, AxisType.UNROLL), "cannot pad upcasted")
|
323
|
+
# ok to pad SUM if all parent ALU ops have f(0) = 0
|
324
|
+
if (r:=self.reduceop) is not None and self.axis_types[axis] in (AxisType.GROUP_REDUCE, AxisType.REDUCE):
|
325
|
+
check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}")
|
326
|
+
padded = False
|
327
|
+
for i,st in enumerate(self.sts):
|
328
|
+
if (s:=st.shape[axis]) == 1: continue # reduced
|
329
|
+
check(s > amt//4, f"pad adds more than quadruple the work {st.shape[axis]=} > {amt//4=}")
|
330
|
+
if (ru := round_up(cast(int, s), amt) - s):
|
331
|
+
# pad right seems to be faster
|
332
|
+
self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
|
333
|
+
padded = True
|
334
|
+
check(padded, "nothing was padded")
|
335
|
+
|
336
|
+
if append_opt: self.applied_opts.append(opt)
|
337
|
+
if self.simplify_ones() and self.tensor_core_opts:
|
338
|
+
self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()
|
339
|
+
|
340
|
+
def apply_opts(self, opts:Sequence[Opt]) -> Kernel:
|
341
|
+
for opt in opts: self.apply_opt(opt)
|
342
|
+
return self
|
343
|
+
|
344
|
+
# **** kernel outputs, mostly tensor cores ****
|
345
|
+
|
346
|
+
def _create_tc_opts(self, reduceop:UOp, tc:TensorCore, axis:int, opt_level:int) -> TensorCoreOptions|None:
  """Work out which kernel axes the tensor core `tc` would use for `reduceop`, or None if it does not fit.

  `axis` selects one of the candidate (N, M, K) axis combinations, counted from the end of the
  candidate list. `opt_level` relaxes the match: >= 1 accepts CAST'd sources and more than one
  reduce axis, >= 2 accepts axes that need padding.
  """
  reduce_src = reduceop.src[0]
  has_cast = tc.dtype_in != tc.dtype_out
  if has_cast:
    # with differing in/out dtypes the reduce must consume a CAST to the output dtype
    if not (reduce_src.op is Ops.CAST and reduce_src.dtype == tc.dtype_out): return None
    mul_op = reduce_src.src[0]
  else:
    mul_op = reduce_src
  if mul_op.op is not Ops.MUL: return None

  def buf_index(src:UOp) -> int|None:
    # TODO: apply tc even if the sources are not from LOAD
    if src.op is Ops.LOAD and src.dtype == tc.dtype_in: return self.bufs.index(src)
    if opt_level >= 1 and src.op is Ops.CAST and src.dtype == tc.dtype_in:
      # the CAST's source may not be one of the kernel's buffers
      try: return self.bufs.index(src.src[0])
      except ValueError: return None
    return None
  buf0 = buf_index(mul_op.src[0])
  if buf0 is None: return None
  buf1 = buf_index(mul_op.src[1])
  if buf1 is None: return None

  buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
  # candidates for each side: upcastable axes along which that buffer has stride 0
  axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i in self.upcastable_dims if buf0_strides[i] == 0]
  axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i in self.upcastable_dims if buf1_strides[i] == 0]
  if not axis_buf0 or not axis_buf1: return None
  reduce_axes = self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)
  if len(reduce_axes) != 1 and not (opt_level >= 1): return None

  axis_choices = list(itertools.product(axis_buf0, axis_buf1, reduce_axes))
  if not (axis < len(axis_choices)): return None

  # s0 is n, s1 is m, s2 is k
  (s0, _, _), (s1, _, _), s2 = axis_choices[-(axis+1)]
  axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if resolve(self.full_shape[x]%tc.dims[i] != 0))
  if axis_pads and (opt_level < 2): return None
  if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
  return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)
|
375
|
+
|
376
|
+
def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> bool:
|
377
|
+
if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD:
|
378
|
+
tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]]
|
379
|
+
for tc in tensor_cores:
|
380
|
+
tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
|
381
|
+
if tensor_core_opts[0] is None: continue
|
382
|
+
# can only fuse reduces with the same tc options
|
383
|
+
assert all_same(tensor_core_opts)
|
384
|
+
self.tensor_core_opts = tc_opts = tensor_core_opts[0]
|
385
|
+
|
386
|
+
# attempt to pad the tensor axes that require it
|
387
|
+
try:
|
388
|
+
for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
|
389
|
+
except KernelOptError: continue
|
390
|
+
# tensor core -- unroll the reduce dim (K), upcast and local the inner and outer dims (N, M)
|
391
|
+
for opt in tc.opts: self.apply_opt(Opt({"u":OptOps.UPCAST, "l":OptOps.LOCAL}[opt[0]], tc_opts.axes[int(opt[1])], 2), append_opt=False)
|
392
|
+
for dim, amt in tc.get_reduce_axes(): self.apply_opt(Opt(OptOps.UNROLL, 0, amt), append_opt=False) # TODO: this should be the reduce, not 0
|
393
|
+
self.tensor_core = tc
|
394
|
+
self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA
|
395
|
+
return True
|
396
|
+
return False
|
397
|
+
|
398
|
+
def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:list[Opt]|None=None, axis:int=0, tc_select:int|None=None, tc_opt:int|None=None) -> bool:
  """ Attempts to apply a tensor core optimization to the kernel. Returns True when the TC opt (and any follow-up opts) applied, and
  False when the renderer has no tensor cores or any applied opt raised KernelOptError.
  Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).

  Keyword arguments:
  use_tensor_cores -- controls how tensor cores are applied (default 1)
    0: will disable any tensor core matching
    1: enable tensor cores
    2: apply tensor core shape but don't use UOp.WMMA
  extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
  axis -- axis argument of the OptOps.TC opt that gets applied (default 0)
  tc_select -- specifies which tensor core(s) to use for optimization (default None, which reads TC_SELECT.value)
    -1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes)
    [0-N]: uses only the n'th tensor core available; useful for search
  tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default None, which reads TC_OPT.value)
    0: applies to only kernels with a single reduce axis and direct Ops.LOAD into Ops.MUL
    1: allows kernels with multiple reduce axes and also multiplication of Ops.CAST'd buffers
    2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by padding those axes as needed
  """
  tc_select = TC_SELECT.value if tc_select is None else tc_select
  tc_opt = TC_OPT.value if tc_opt is None else tc_opt
  if not self.opts.tensor_cores: return False
  try:
    # the TC opt goes first, everything below only runs if it applied without raising
    self.apply_opt(Opt(OptOps.TC, axis, (tc_select, tc_opt, use_tensor_cores)))
    picked = self.tensor_core_opts
    if picked is None: return True
    if extra_opts is not None:
      # caller supplied follow-up opts replace the hand-coded ones
      self.apply_opts(extra_opts)
      return True
    if AMX: return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower

    # hand-coded TC opts: attempt to upcast M and N
    # the dims to visit are fixed before the first apply_opt call, as in a plain "filter then loop"
    upcast_dims = [d for d in (1, 0) if picked.axes_exist[d]]
    for d in upcast_dims:
      fits = [amt for amt in (5, 4, 3, 2) if self.full_shape[picked.axes[d]] % amt == 0]
      if fits: self.apply_opt(Opt(OptOps.UPCAST, picked.axes[d], fits[0]))

    # attempt to local N
    if picked.axes_exist[0]:
      fits = [amt for amt in (4, 2) if self.full_shape[picked.axes[0]] % amt == 0]
      if fits: self.apply_opt(Opt(OptOps.LOCAL, picked.axes[0], fits[0]))
    return True
  except KernelOptError:
    return False
|
436
|
+
|
437
|
+
# strings like ['g0', 'g1', 'l0', 'l1', 'l2', 'l3', 'l4', 'l5', 'R0', 'r0', 'r1', 'r2', 'u0', 'u1', 'u2']
|
438
|
+
def shape_str(self) -> list[str]:
  """ Names every axis in self.axis_types as the axis_letters entry for its type followed by a running index.

  The index counts from 0 separately for each axis type, in the order the axes appear.
  """
  seen: dict[AxisType, int] = {}
  names: list[str] = []
  for kind in self.axis_types:
    idx = seen.get(kind, 0)
    names.append(f"{axis_letters[kind]}{idx}")
    seen[kind] = idx + 1
  return names
|
445
|
+
def shape_str_to_axis(self, nms:list[str]) -> tuple[int, ...]:
  """ Maps axis names as produced by self.shape_str() to their positions, in the order given by nms.

  Raises ValueError (from list.index) if a name is not present in self.shape_str().
  """
  # build the name list once instead of once per requested name
  names = self.shape_str()
  return tuple([names.index(x) for x in nms])
|
446
|
+
|
447
|
+
def get_optimized_ast(self, name_override:str|None=None) -> UOp:
  """ Returns self.ast rewritten to reflect the kernel's current state.

  Sets self.finalized to True, rebuilds the AST bottom-up with fixup_ast, then runs graph_rewrite on the result
  with the view_left+view_left_through_load rules.

  Keyword arguments:
  name_override -- kernel name used in place of self.name when the SINK arg does not supply one
                   (the SINK arg is None or its name is "test"); has no effect when None (default None)
  """
  # memoized on the UOp, so a node that is reached through several parents is only rewritten once
  @functools.cache
  def fixup_ast(op:UOp) -> UOp:
    # sources are rewritten first, the checks below then patch the node itself
    ret = op.replace(src=tuple(fixup_ast(x) for x in op.src)) # noqa: F821
    if op.op in GroupOp.Buffer and op in self.bufs:
      # the ShapeTracker stored in self.sts at the same index as this buffer op in self.bufs
      st = self.sts[self.bufs.index(op)]
      # replace the VIEW source
      return ret.replace(src=(ret.src[0].replace(arg=st),)+ret.src[1:])
    if op.op is Ops.SINK:
      # NOTE: should group_for_reduces be added to the local_dims?
      # TODO: arg.name should be able to be None
      # name priority: the existing arg name (unless it is "test"), then name_override, then self.name
      kernel_name = ret.arg.name if ret.arg is not None and ret.arg.name != "test" else self.name if name_override is None else name_override
      return ret.replace(arg=KernelInfo(kernel_name, tuple(self.axis_types), self.dont_use_locals, tuple(self.applied_opts)))
    if op.op is Ops.REDUCE_AXIS:
      # each reduceop owns two consecutive entries of self.sts placed after the buffer entries
      # NOTE(review): presumably these are the reduce's input and output shapes -- confirm against where self.sts is built
      reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
      # axes on which those two shapes differ
      changed = tuple(i for i in range(self.shape_len) if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
      axes = tuple(i for i in self.axes_of(AxisType.REDUCE, AxisType.UNROLL) if i in changed)
      grouped_axes = tuple(i for i in self.axes_of(AxisType.GROUP_REDUCE) if i in changed)
      # only use_tensor_cores == 1 emits the WMMA, any other value keeps the plain reduce below
      if (tc := self.tensor_core) and self.use_tensor_cores == 1:
        # get reduce/upcast axes for the tensor cores
        # positions of the axes named r0..r{n-1} in shape_str(), with n = len(tc.get_reduce_axes())
        tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))])
        # every base upcast axis is paired with a size of 2
        base_upcast_axes = tuple([(s,2) for s in self.shape_str_to_axis(tc.base_upcast_axes())])
        # one entry per WMMA operand: the first log2(elements_per_thread[i]) base upcast axes
        tc_upcast_axes = tuple([base_upcast_axes[:int(math.log2(tc.elements_per_thread[i]))] for i in range(3)])

        # permute the srcs
        # take the sources of the reduced op, looking through a CAST if one sits directly under the reduce
        srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
        for i, (src, permaxis) in enumerate(zip(srcs, tc.permutes_for_shape_str(self.shape_str()))):
          # a source that is not a LOAD takes its shape from its own first source
          src_st = (src if src.op is Ops.LOAD else src.src[0]).st_arg
          srcs[i] = src.view(ShapeTracker.from_shape(src_st.shape).permute(permaxis))

        # construct the op
        wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, tc_reduce_axes)
        # two CONTRACTed operands plus a zero constant of the output vector dtype
        # NOTE(review): the zero constant is presumably the initial accumulator -- confirm
        wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
          UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]),
          UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1]),
          UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg)
        tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2])

        # preserve any other reduce
        # if reduce axes remain besides the tensor core ones, keep an ADD reduce over them on top of the WMMA result
        return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_axes)) if (new_axes := tuple(i for i in axes if i not in tc_reduce_axes)) else tc_uop

      # no WMMA: the reduce keeps its op and is narrowed to the changed REDUCE/UNROLL axes
      ret = ret.replace(arg = (op.arg[0], axes))
      if self.group_for_reduces and grouped_axes:
        # grouped reduce: store the first reduce in a local buffer, load it back and reduce over the grouped axes
        local_axes = tuple([i for i,t in enumerate(self.axis_types) if t in (AxisType.LOCAL, AxisType.UPCAST) or i in grouped_axes])
        slocal, supcast, sgroup = sorted(self.axes_of(AxisType.LOCAL)), sorted(self.axes_of(AxisType.UPCAST)), sorted(grouped_axes)
        # NOTE: start with UPCAST at the end so it has stride 1 and can merge
        base_shape = tuple([self.full_shape[i] for i in slocal] + [self.full_shape[i] for i in sgroup] + [self.full_shape[i] for i in supcast])
        permute_axes = tuple([local_axes.index(i) for i in slocal+sgroup+supcast])
        # full-rank shape that is 1 on every axis outside local_axes
        local_shape = tuple([s if i in local_axes else 1 for i,s in enumerate(self.full_shape)])
        # same shape with the GLOBAL axes expanded back to their full size
        local_src_shape = tuple([self.full_shape[i] if i in self.axes_of(AxisType.GLOBAL) else s for i,s in enumerate(local_shape)])
        st = ShapeTracker.from_shape(base_shape).permute(permute_axes).reshape(local_shape).expand(local_src_shape)
        local_size = st.real_size()
        # one DEFINE_LOCAL per reduceop, named temp<index of the reduceop>
        local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, addrspace=AddrSpace.LOCAL), (), f"temp{self.reduceops.index(op)}")
        local_load = local_buffer.view(st).load(local_buffer.view(st).store(ret))
        grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
        # the last reduceop returns its grouped result directly
        if op is self.reduceops[-1]: return grouped_reduce
        # earlier reduceops store the grouped result back into the same local buffer (grouped axes collapsed to 1) and load it again
        st = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else s for i,s in enumerate(local_shape)]))
        return local_buffer.view(st).load(local_buffer.view(st).store(grouped_reduce))

    # every other op, and a reduce without grouping, is returned with only its sources (and reduce arg) rewritten
    return ret
  self.finalized = True
  fixed_ast = fixup_ast(self.ast)
  # drop the local name of the cached closure
  # NOTE(review): presumably so the cache and the self-referencing closure can be freed -- confirm
  del fixup_ast
  return graph_rewrite(fixed_ast, view_left+view_left_through_load, name="fixup optimized AST")
|