tinygrad-0.10.0-py3-none-any.whl → tinygrad-0.10.2-py3-none-any.whl
This diff compares two publicly available versions of the package as released to a supported public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
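The per-file counts above can be regenerated locally from the two wheels. A minimal sketch using only the Python standard library (zipfile and difflib); the wheel filenames match this page, but the download step and the exact counting rule are assumptions, so totals may differ slightly from the registry's own diff:

import zipfile, difflib

OLD, NEW = "tinygrad-0.10.0-py3-none-any.whl", "tinygrad-0.10.2-py3-none-any.whl"

def wheel_texts(path: str) -> dict[str, list[str]]:
  # a wheel is a plain zip archive: map member name -> decoded lines
  with zipfile.ZipFile(path) as z:
    return {n: z.read(n).decode("utf-8", errors="replace").splitlines() for n in z.namelist()}

old, new = wheel_texts(OLD), wheel_texts(NEW)
for name in sorted(set(old) | set(new)):
  a, b = old.get(name, []), new.get(name, [])
  if a == b: continue
  diff = list(difflib.ndiff(a, b))
  added, removed = sum(l.startswith("+ ") for l in diff), sum(l.startswith("- ") for l in diff)
  print(f"- {name} +{added} -{removed}")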
tinygrad/engine/schedule.py
CHANGED
@@ -1,245 +1,122 @@
-import sys,
+import sys, functools, atexit, pickle
from collections import defaultdict, deque
-from dataclasses import dataclass
-from
-from tinygrad.ops import
-from tinygrad.
-from tinygrad.helpers import
-from tinygrad.
+from dataclasses import dataclass
+from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
+from tinygrad.ops import can_pad, identity_element, resolve, view_left, merge_views
+from tinygrad.codegen.symbolic import symbolic_simple
+from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten, getenv
+from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND
+from tinygrad.dtype import ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape
-from tinygrad.engine.lazy import LazyBuffer
from tinygrad.device import Buffer
+from tinygrad.spec import type_verify, kernel_spec

# creation can recurse a lot
sys.setrecursionlimit(10000)

- [old lines 17-56 removed; content not shown]
+# **** schedule simplifier
+
+def simplify_stride0_reduce(reduce:UOp, x:UOp):
+  # must be unmasked (NOTE: can be relaxed if not masked on stride 0 axis)
+  if any(v.mask is not None for v in unwrap(x.st).views): return None
+  # must have all stride 0 in the relevant axis (NOTE: can do partial)
+  if not all(unwrap(x.st).views[-1].strides[axis] == 0 for axis in reduce.arg[1]) or not all_int(x.shape): return None
+  prshape = prod(x.shape[i] for i in reduce.arg[1])
+  ret = x.shrink(tuple((0,s) if i not in reduce.arg[1] else (0,1) for i,s in enumerate(x.shape)))
+  match reduce.arg[0]:
+    case Ops.ADD: return ret*prshape
+    case Ops.MUL: return ret.pow(prshape)
+    case Ops.MAX: return ret # NOTE: Ops.MAX is passthrough
+
+def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
+  if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
+def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
+  new_src = list(alu.src)
+  for i,s in enumerate(alu.src):
+    if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
+  if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
+
+sym = symbolic_simple+PatternMatcher([
+  # UOp with size 0 is zero
+  (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
+    and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
+  # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
+  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
+  # reduce of size 0 is the identity element
+  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
+   lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
+  # reduce on stride 0 is collapsed
+  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_stride0_reduce),
+  # COPY(CONST) creates a new CONST on the destination device
+  (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.arg)),
+  # no COPY to same device, except clone (arg is True)
+  (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
+   lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
+  # remove cast to image when it's already a contiguous image
+  (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"))),)),
+   lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
  # make things that can't be images not images
- [old lines 58-70 removed; content not shown]
-  if buf.is_realized:
-    buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer
-    op = None
-  elif buf.op is Ops.ASSIGN:
-    target, new_val = [to_uop(x, ctx, buffers, lazybufs, cache) for x in buf.srcs]
-    ctx.assigns.add(ubuf:=target.buf_uop)
-    op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg)
-  else:
-    buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer
-    op = UOp(cast(Ops, buf.op), dtype, tuple(to_uop(x, ctx, buffers, lazybufs, cache) for x in buf.srcs),
-             None if buf.op in {Ops.CAST, Ops.BITCAST} else buf.arg)
-  cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
-  if op is not None:
-    if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
-    lazybufs[buf.buffer] = buf
-    for x in op.src:
-      if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[ubuf] = None
-    ctx.allbufs[ubuf] = ret
-  return ret
-
-# **** AST graph rewrite
-
-# ** helpers for doing movementops on uops
-
-def apply_swizzle(u:UOp, arg:ShapeTracker) -> UOp:
-  with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u.view(arg), view_left)
-
-def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
-  permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis)+axis
-  tmp = input_st.permute(permute_axis)
-  return tmp, tmp.shape[-len(axis):]
-
-# ** movementops rewrite rules
-
-def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
-  tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(src.st).shape), r.axis_arg)
-  prshape = prod(rshape)
-  strides = strides_for_shape(rshape)
-  nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
-                    v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
-  # update input_st and axis
-  new_input_st = tmp + ShapeTracker(tuple(nv))
-  _, new_rshape = permute_reduce(new_input_st, r.axis_arg)
-  new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
-  return apply_swizzle(src, new_input_st).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
-
-def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp, src:UOp) -> UOp:
-  swizzle_st, src_st = unwrap(swizzle.st), unwrap(src.st)
-  assert swizzle_st.contiguous, "can't push a non contiguous VIEW down to STORE"
-  assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE"
-  output_shape = swizzle_st.reduce(root.axis_arg)
-  new_axis = tuple(i for i,(s,u) in enumerate(zip(src_st.shape, output_shape)) if s != u)
-  return swizzle.src[0].r(root.arg[0], new_axis).view(ShapeTracker.from_shape(output_shape))
-
-def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
-  swizzles = [x for x in root.src if x.base is not x]
-  if len(swizzles) == 0: return None
-  swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
-  assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
-  new_shape, new_input_shape = swizzle_shapes[0]
-  ret = root.replace(src=tuple(x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src))
-  return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
-
-def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
-  assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
-  assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
-  return first_reduce.src[0].r(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)
-
-merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
-
-# push VIEW to loads
-view_left = merge_views+PatternMatcher([
-  # VIEW before elementwise ops
-  (UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
-  # early merge VIEW buffer ops
-  (UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.arg+v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
+  (UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType)
+   and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),
+  # remove contiguous if we can just view the buffer
+  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
+   lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
+  # contiguous/buffer/copy is already contiguous
+  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY)),)), lambda root: root.src[0]),
+  # support for using a contiguous permuted view instead of the parent view if one exists
+  (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
+  (UPat(GroupOp.ALU, name="alu"), replace_contiguous),
+  # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
+  (UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="root"),
+   lambda root: root.replace(op=Ops.BUFFER_VIEW) if isinstance(root.device, str) and root.device.startswith("DISK") else None),
])

- [old lines 149-154 removed; content not shown]
-  (UPat(Ops.REDUCE_AXIS, src=UPat.var("src"), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
-  # push a VIEW down to STORE, through a reduce (ONLY reshapes)
-  (UPat(Ops.REDUCE_AXIS, src=(UPat.var(name="src").view(name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
-  # push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes)
-  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise),
-  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
+remove_movement_ops = merge_views+PatternMatcher([
+  # NOTE: movement ops are always applied to base
+  (UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
+  # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
+  (UPat(Ops.VIEW, name="view"),
+   lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
])

-#
+# **** UOp realization

@dataclass(frozen=True)
-class
- [old lines 167-169 removed; content not shown]
-  bufs: List[UOp] = field(default_factory=list)
-  assign_preloads: List[UOp] = field(default_factory=list)
-
-def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
-  if (st:=unwrap(x.st)) in ctx.sts: return None
-  st, var_vals = st.simplify().unbind()
-  ctx.var_vals.update(var_vals)
-  ctx.sts.add(st)
-  return st.to_uop() if st != x.st else None
-
-def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
-  ctx.bufs.append(x)
-  return UOp(Ops.DEFINE_GLOBAL, x.dtype, (), len(ctx.bufs)-1)
-append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
+class GrouperContext:
+  assigns: dict[UOp, UOp] # maps realized buffers to assigns
+  realizes: dict[UOp, None] # all the simplified tensor uops we realize
+  children: defaultdict[UOp, dict[UOp, None]] # children graph of tensor uops

-def
-  if b in ctx.assigned: ctx.assign_preloads.append(b)
-  return x.replace(op=Ops.LOAD)
+def realize(ctx:GrouperContext, tr:UOp) -> None: ctx.realizes[tr] = None

- [old lines 189-191 removed; content not shown]
-  (
- [old lines 193-195 removed; content not shown]
+def realize_before_view(ctx:GrouperContext, view:UOp, src:UOp) -> None:
+  st = unwrap(view.st)
+  # fold simple pads
+  if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
+    return None if can_pad(src, ctx.realizes, cache=dict()) else realize(ctx, src)
+  # early realize before expand
+  if resolve(prod(src.shape) < prod(st.shape)) and not DONT_REALIZE_EXPAND: return realize(ctx, src)
+  # otherwise safety check pads
+  return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, cache=dict())) else realize(ctx, src)

-
-
-  (UPat(Ops.
-
+do_realize = PatternMatcher([
+  # always realize SINK parents
+  (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x.base, None) for x in s.src if x.base.op not in {Ops.CONST, Ops.BIND, Ops.BUFFER})),
+  # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
+  (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
+  # realize before expand or unsafe pad ops
+  (UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}, name="src"),)), realize_before_view),
+  # realize before COPY
+  (UPat(Ops.COPY, src=(UPat(), UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW}, name="tr"))), realize),
])

-
-
-def full_ast_rewrite(pre:UOp, var_vals:Dict[Variable, int], assigned:Set[UOp]) -> Tuple[UOp, ScheduleItemContext]:
-  # fuse and fold store -> loads
-  sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, {x.buf_uop:x.src[2] for x in pre.src})
-  # assert cyclic dependency
-  for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {Ops.PRELOAD,Ops.LOAD} and x.buf_uop in assigned), key=lambda x:x.buf_uop):
-    if not all_same([x.op for x in ops]):
-      raise RuntimeError(f"cycle detected in kernel.\nhelp: use .contiguous() to break the part loading pre-assign {b} into a different kernel.")
-  # do movementops
-  sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
-  # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
-  if len(assign_targets:=[x.buf_uop for x in sink.sparents if x.op is Ops.ASSIGN]) != 0:
-    if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
-        and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is Ops.PRELOAD and x.buf_uop in assign_targets):
-      raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
-                         +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
-  # convert to AST
-  sink = graph_rewrite(graph_rewrite(sink, to_si, ctx:=ScheduleItemContext(var_vals, assigned)), append_bufs, ctx)
-  if getenv("RUN_PROCESS_REPLAY"): PROCESS_REPLAY_CAPTURE.append(((pre, var_vals, assigned), sink))
-  return sink, ctx
-
-PROCESS_REPLAY_CAPTURE: List[Tuple[Tuple, UOp]] = []
-if getenv("RUN_PROCESS_REPLAY"):
-  @atexit.register
-  def save_process_replay():
-    for x,ret in PROCESS_REPLAY_CAPTURE: diskcache_put("schedule_process_replay", str(x[0].key), (x, {}, ret))
-
-# **** Schedule grouping
-
-def uval(u:UOp) -> UOp:
-  assert is_scheduled(u), f"must be a scheduled op {u}"
-  return to_store.src[0] if (to_store:=u.src[1]).is_contiguous_base else to_store
-
-def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp], realizes:Dict[UOp, UOp],
-                    reduce_for_op:Dict[UOp, UOp], group:Dict[UOp, None], cache:Dict[Tuple[UOp, ShapeTracker], None]) -> None:
+def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
+                    reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
  """recursively search the uop for groupable children, realize the UOp if a child can't group"""
  if (tr, st) in cache: return
  cache.setdefault((tr, st))
-  rsize = unwrap(
+  rsize = unwrap(r.st).size
  if tr in realizes and tr is not r:
    # can only fuse contiguous
    # max one reduceop per kernel
@@ -247,173 +124,335 @@ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:DefaultDict[UOp, Di
    return group.setdefault(tr)
  for tr_next in children[tr]:
    # max one reduceop per kernel
-    if
+    if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r)
    # can only fuse contiguous
-    if len(st_childs:=dedup(unwrap(x.st) for x in
-    recursive_group(tr_next, st+st_childs[0], r, children,
-
-def
- [old lines 256-262 removed; content not shown]
-    rc_parents.extend(x.base.buf_uop for x in p.src if is_scheduled(x.base) and x.base.buf_uop is not r)
-  # search descendants of the reduceop that can cleanly group
-  descendants: Dict[UOp, None] = {}
-  for tr in group: recursive_group(tr, unwrap(allbufs[tr].st), tr, children, allbufs, realizes, reduce_for_op, descendants, cache={})
-  return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
-
-def group_realizes(ctx:ScheduleContext, realizes:Dict[UOp, UOp]) -> List[List[UOp]]:
-  """search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop"""
+    if len(st_childs:=dedup(unwrap(x.st) for x in tr_next.src if x.base == tr)) > 1: return group.setdefault(r)
+    recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache)
+
+def append_uop(ctx:GrouperContext, u:UOp) -> None:
+  if u.op is Ops.ASSIGN: ctx.assigns[u.buf_uop] = u
+  for s in u.src: ctx.children[s.base][u] = None
+create_ctx = PatternMatcher([(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}, name="u"), append_uop)])
+
+def group_realizes(sink:UOp) -> dict[UOp, None]:
+  # start by adding uops that always realize
+  sink = graph_rewrite(sink, do_realize+create_ctx, ctx:=GrouperContext({}, {}, defaultdict(dict)))
  # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
-  reduce_for_op:
- [old lines 273-275 removed; content not shown]
-    if
-    if
-
-
-    recursive_group(r, unwrap(r_uop.st), r, ctx.children, ctx.allbufs, realizes, reduce_for_op, group, cache={})
+  reduce_for_op: dict[UOp, UOp] = {}
+  double_reduces: list[UOp] = []
+  for r in sink.toposort:
+    if r.op is not Ops.REDUCE_AXIS: continue
+    if FUSE_CONV_BW and r.src[0].base.op is Ops.REDUCE_AXIS and r.src[0] is not r.src[0].base: double_reduces.append(r)
+    if r in ctx.realizes: continue
+    group: dict[UOp, None] = {}
+    recursive_group(r, unwrap(r.st), r, ctx.children, ctx.realizes, reduce_for_op, group, cache={})
    # max one reduceop per kernel
    can_chase = all(tr not in reduce_for_op for tr in group)
    # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
    forced_realize = r in group
-
-
+    # can only have one output
+    if not forced_realize and len(group) > 1: forced_realize = True
    # can only fuse assign if no other assign_target is used in the kernel
-    if not forced_realize and any(x
+    if not forced_realize and any(x.op is Ops.ASSIGN for x in group):
      parents = deque((r, *group))
      while parents and not forced_realize:
-
-        if (
-        if p in realizes: continue
-        parents.extend(
+        p = parents.pop().base
+        if (assign:=ctx.assigns.get(p)) is not None and assign not in group: forced_realize, can_chase = True, False
+        if p in ctx.realizes: continue
+        parents.extend(p.src)
    if forced_realize or not group:
      tr = r
      if can_chase:
        # can chase this down to contiguous children
-        st = unwrap(
+        st = unwrap(tr.st)
        while len(ctx.children[tr]) == 1:
-
-          st_childs = dedup(
+          tr_next = next(iter(ctx.children[tr]))
+          st_childs = dedup(unwrap(s.st) for s in tr_next.src if s.base is tr)
          if len(st_childs) > 1: break
          if st.size != st_childs[0].size: break
          st = st + st_childs[0]
-          if not st.contiguous or
+          if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break
          tr = tr_next
        # don't cast to higher size before store (tr cannot be realized if forced_realize)
-        if
-          tr =
+        if tr.op is Ops.CAST and tr.dtype.itemsize > tr.src[0].dtype.itemsize:
+          tr = tr.src[0].base
      group = {tr: None}
-      realizes[tr] =
+      ctx.realizes[tr] = None
    reduce_for_op.update((tr, r) for tr in group)
-    if FUSE_ARANGE and
+    if FUSE_ARANGE and r.arg[0] is Ops.ADD and r.src[0].base.op is Ops.CONST:
+      # maybe fuse arange with its children
+      if len(flatten(ctx.children[tr] for tr in group)) != 0:
+        for tr in group: del ctx.realizes[tr]
  # fuse double reduces with no other child
  for reduceop in double_reduces:
-    top_reduce =
-    if len(ctx.children[top_reduce]) == 1: del realizes[top_reduce]
-
-  for rbuf in reduce_of_const:
-    group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
-    if any(ctx.allbufs[tr].src[1].is_contiguous_base for tr in group): continue
-    kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}}
-    if len(kernel_children) == 0: continue
-    for tr in group: del realizes[tr]
-  # group BUFFER uops into kernels
-  output_groups: DefaultDict[UOp, List[UOp]] = defaultdict(list)
-  for ubuf in realizes: output_groups[reduce_for_op.get(ubuf, ubuf)].append(ubuf)
-  return list(output_groups.values())
-
-# **** Schedule creation and BFS toposort
-
-def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> UOp:
-  ctx[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), to_store)
-  return UOp(Ops.LOAD, base.dtype, (b, st.to_uop()))
-
-def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> Optional[UOp]:
-  base_shape = unwrap(base.st).shape
-  st = unwrap(view.st)
-  # fold simple pads
-  if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(base_shape) and resolve(prod(base_shape) >= prod([y-x for x,y in m])):
-    return None if can_pad(base) else realize(ctx, b, to_store, base).view(st)
-  # early realize before expand
-  if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, b, to_store, base).view(st)
-  # otherwise safety check pads
-  return None if (all(v.mask is None for v in st.views) or can_pad(base)) else realize(ctx, b, to_store, base).view(st)
+    top_reduce = reduceop.src[0].base
+    if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
+  return ctx.realizes

- [old lines 348-358 removed; content not shown]
+# break the SINK into kernels
+
+@dataclass(frozen=True)
+class Kernel:
+  ast: UOp
+  metadata: tuple[Metadata, ...]
+  def __repr__(self): return f"<Kernel {len(list(self.ast.toposort))} {self.ast.op} {self.metadata}>"
+
+@dataclass(frozen=True)
+class KernelContext:
+  realizes: dict[UOp, None]
+  ops_metadata: dict[UOp, Metadata]
+
+def create_kernel(ctx:KernelContext, x:UOp):
+  if x not in ctx.realizes: return None
+  assert isinstance(x.device, str), f"buf device in kernel must be string {x.device}"
+  b = x.buf_uop if x.op is Ops.ASSIGN else UOp.new_buffer(x.device, x.size, x.dtype)
+  output_st = ShapeTracker.from_shape(x.shape)
+  # KERNEL nodes become: ASSIGN(VIEW(BUFFER), KERNEL)
+  # TODO: this should be ASSIGN(BUFFER, KERNEL) followed by the output ShapeTracker
+  return b.view(output_st).assign(UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x, (m,) if (m:=ctx.ops_metadata.get(x)) else ())))
+
+def append_to_kernel(ctx:KernelContext, x:UOp):
+  new_srcs: list[UOp] = []
+  new_metadata: dict[Metadata, None] = dict.fromkeys(x.arg.metadata)
+  for s in x.src:
+    if s.op is Ops.BUFFER or (s.op is Ops.ASSIGN and s.src[1].op is Ops.KERNEL) or s in ctx.realizes: new_srcs.append(s)
+    else:
+      new_srcs.extend(s.src)
+      if (m:=ctx.ops_metadata.get(s)) is not None: new_metadata[m] = None
+  return x.replace(src=n, arg=Kernel(x.arg.ast, tuple(new_metadata))) if (n:=tuple(dedup(new_srcs))) != x.src else None
+
+create_kernels = merge_views+PatternMatcher([
+  (UPat(GroupOp.All-{Ops.KERNEL, Ops.BUFFER}, name="x"), create_kernel),
+  (UPat(Ops.KERNEL, name="x"), append_to_kernel),
+])
+
+# **** convert Kernel to a ScheduleItem (for legacy reasons)
+
+@dataclass(frozen=True)
+class ScheduleItem:
+  ast: UOp
+  bufs: tuple[Buffer, ...]
+  metadata: tuple[Metadata, ...]
+  @property
+  def outputs(self) -> tuple[Buffer, ...]:
+    """Read/write or write only buffers in the schedule."""
+    return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs)
+  @property
+  def inputs(self) -> tuple[Buffer, ...]:
+    """Read only buffers in the schedule."""
+    return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs)
+  @functools.cached_property
+  def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
+
+# **** Kernel creation
+
+def apply_swizzle(u:UOp) -> UOp:
+  with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
+
+def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
+  input_st = ShapeTracker.from_shape(unwrap(src.st).shape)
+  tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
+  prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
+  strides = strides_for_shape(rshape)
+  nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
+                    v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
+  # update input_st and axis
+  new_input_st = tmp + ShapeTracker(tuple(nv))
+  new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg)))
+  return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
+
+def reduceop_view_right(r:UOp, v:UOp, src:UOp) -> UOp:
+  if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}")
+  output_shape = swizzle_st.reduce(r.axis_arg)
+  return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape))
+
+def elementwise_view_right(root:UOp) -> UOp|None:
+  if len(swizzles:=[x for x in root.src if x.base is not x]) == 0: return None
+  assert all(x.base.st is not None for x in swizzles), f"found shapeless VIEW src in {root}"
+  assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
+  # push the swizzle from src to root
+  output_swizzle = swizzles[0]
+  new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)
+  ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
+  return ret.view(ShapeTracker.from_shape(output_swizzle.shape))
+
+def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
+  assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
+  assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
+  return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
+
+# push VIEW to children
+view_right = merge_views+PatternMatcher([
+  # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
+  (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
+   lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
+  # STORE is the last child, so we just merge the ShapeTrackers and store the base
+  (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
+  # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
+  (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
+  # REDUCE(src.view()) -> REDUCE(src).view()
+  (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src").view(name="v"),), name="r"), reduceop_view_right),
+  # ALU(src.view()) -> ALU(src).view()
+  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), elementwise_view_right),
+  # double reduce op collapses to a single reduce op
+  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
-
+
+def _append_st_vars(ctx:dict[Variable, int], x:UOp) -> UOp|None:
+  st = unwrap(x.st).simplify()
+  if any(x.op is Ops.BIND for x in st.vars()):
+    st, var_vals = st.unbind()
+    ctx.update(var_vals)
+  return st.to_uop() if st != x.st else None
+
+def check_load_st(glbl:UOp, view:UOp):
+  if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
+  # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
+  if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
+  # if it has a single view and it's equal when you shrink a contig, it's fine
+  if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
+  # otherwise, it's not fine
+  raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+                     +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+
+fix_kernel_ops = PatternMatcher([
+  # BIND in shapetracker becomes DEFINE_VAR
+  (UPat(Ops.VIEW, name="x"), _append_st_vars),
+  # remove CONTIGUOUS/ASSIGN/DEVICE
+  (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
+  (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
+  (UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
+  # no ImageDType after load
+  (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
+  # if this kernel also assigns to the loaded buffer, ensure we can index it correctly
+  (UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
+])
+
+def load_buf(ctx:list[UOp], x:UOp):
+  if x not in ctx: ctx.append(x)
+  return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), unwrap(x.st).to_uop()))
+
+add_buffer_ops = PatternMatcher([
+  # LOAD
+  (UPat(Ops.BUFFER, name="x"), load_buf),
+  # STORE (except for COPY/BUFFER_VIEW)
+  (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
+  (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
+   lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
+])
+
+def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
+  ctx[var.replace(src=())] = val.arg
+  return var
+unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
+
+def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
+  assert sink.op is Ops.ASSIGN and sink.src[1].op is Ops.KERNEL, f"{sink} must be ASSIGN"
+  # substitute kernel sources for the target buffer
+  ast = sink.src[1].arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in sink.src[1].src if s.op is Ops.ASSIGN}).sink()
+  # add buffer ops
+  ast = graph_rewrite(ast, add_buffer_ops, bufs:=[sink.buf_uop], bottom_up=True)
+  # unbind_vars + push views to edges
+  ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
+  # fix_kernel_ops
+  ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
+  # create subbuffer
+  if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = bufs[1].buffer.view(ast.size, ast.dtype, (x:=ast.src[0]).st_arg.views[0].offset*x.dtype.itemsize)
+  return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), sink.src[1].arg.metadata)
+
+PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
+if CAPTURE_PROCESS_REPLAY:
+  @atexit.register
+  def save_process_replay():
+    for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
+
+# **** schedule creation and toposort

@track_rewrites(named=True)
-def create_schedule_with_vars(
- [old lines 364-401 removed; content not shown]
+def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
+  # remove_movement_ops + sym
+  tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
+
+  # display the cleaned up tensor graph
+  if getenv("VIZ"): graph_rewrite(tensor_map[big_sink], PatternMatcher([]), name="View Tensor Graph")
+
+  # do_realize + group_realizes
+  sink = tensor_map[big_sink]
+  realize_map = group_realizes(sink)
+
+  # map tensors to new uops
+  becomes_map: dict[UOp, UOp] = {}
+  rev_tensor_map: dict[UOp, list[UOp]] = {}
+  ops_metadata: dict[UOp, Metadata] = {}
+  for k,v in tensor_map.items():
+    rev_tensor_map.setdefault(v, []).append(k)
+    if k is v: continue
+    if v.base.op is Ops.BUFFER:
+      # VIEW isn't a valid tensor uop, we need to backtrack to the movement op that created it
+      if v.op is Ops.VIEW:
+        mop = [x for x in k.toposort if (xs:=tensor_map[x]).base is v.base and xs.st == v.st][0]
+        if k is not mop: becomes_map[k] = mop
+      else: becomes_map[k] = v
+    elif v.base.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
+    # if we're not realizing this tensor, map its metadata to the simplified uop
+    elif isinstance(k.metadata, Metadata): ops_metadata[v] = k.metadata
+
+  # create kernels
+  if len(realize_map) == 0: return [], {}, becomes_map
+  kernel_map = graph_rewrite_map(sink, create_kernels, ctx=KernelContext(realize_map, ops_metadata), bottom_up=True)
+  sched_sink = kernel_map[sink]
+  type_verify(list(sched_sink.toposort), kernel_spec)
+
+  # map realized tensors to buffers
+  for k,v in kernel_map.items():
+    if k is v or v.op is not Ops.ASSIGN: continue
+    for t in rev_tensor_map[k]: becomes_map[t] = t.src[0] if t.op is Ops.ASSIGN else v.buf_uop.reshape(t.shape)
+
+  # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
+  kernel_assign: dict[UOp, UOp] = {}
+  assign_rep: dict[UOp, UOp] = {}
+  for u in sched_sink.toposort:
+    if u.op is not Ops.ASSIGN: continue
+    kernel_assign[u.buf_uop] = u
+    for s in u.src[1].src:
+      if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
+      if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort):
+        raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
+      assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
+  if assign_rep:
+    sched_sink = sched_sink.substitute(assign_rep)
+    type_verify(list(sched_sink.toposort), kernel_spec)
+
+  # display the final graph
+  if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
+
+  # final toposort (bfs)
+  children: dict[UOp, list[UOp]] = {}
+  in_degree: dict[UOp, int] = {}
+  for u in sched_sink.toposort:
+    if u.op is not Ops.ASSIGN: continue
+    in_degree[u] = 0
+    for s in u.src[1].src:
+      if s.op is not Ops.ASSIGN: continue
+      children.setdefault(s, []).append(u)
+      in_degree[u] += 1
+
+  queue = deque(k for k,v in in_degree.items() if v == 0)
+  schedule: list[ScheduleItem] = []
+  var_vals: dict[Variable, int] = {}
  while queue:
- [old lines 403-407 removed; content not shown]
-    for x in graph[si]:
+    u = queue.popleft()
+    schedule.append(schedule_uop(u, var_vals))
+    # increment the refcount of the target buf (this is required by the JIT and memory planner)
+    u.buf_uop.buffer.ref(1)
+    for x in children.get(u, []):
      in_degree[x] -= 1
      if in_degree[x] == 0: queue.append(x)
+
  # confirm everything was scheduled correctly
-  if len(schedule) != (
+  if len(schedule) != (kc:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, created {kc} kernels but only scheduled {len(schedule)}")
  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
- [old lines 414-416 removed; content not shown]
-  schedule, var_vals
-  assert len(var_vals) == 0
-  return schedule
+  # capture process replay
+  if CAPTURE_PROCESS_REPLAY:
+    with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, [x.ast for x in schedule]))
+  return schedule, var_vals, becomes_map