tinygrad 0.10.0-py3-none-any.whl → 0.10.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +114 -172
- tinygrad/codegen/linearize.py +211 -81
- tinygrad/codegen/lowerer.py +30 -35
- tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
- tinygrad/codegen/transcendental.py +12 -13
- tinygrad/device.py +170 -47
- tinygrad/dtype.py +28 -26
- tinygrad/engine/jit.py +80 -63
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +162 -0
- tinygrad/engine/realize.py +58 -107
- tinygrad/engine/schedule.py +381 -314
- tinygrad/engine/search.py +40 -44
- tinygrad/gradient.py +70 -0
- tinygrad/helpers.py +77 -58
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +89 -64
- tinygrad/ops.py +562 -446
- tinygrad/renderer/__init__.py +79 -36
- tinygrad/renderer/cstyle.py +70 -84
- tinygrad/renderer/llvmir.py +32 -20
- tinygrad/renderer/ptx.py +79 -99
- tinygrad/renderer/wgsl.py +87 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libpciaccess.py +2023 -0
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +19 -21
- tinygrad/runtime/ops_amd.py +488 -327
- tinygrad/runtime/ops_clang.py +15 -28
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +129 -38
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +45 -40
- tinygrad/runtime/ops_metal.py +93 -73
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +232 -270
- tinygrad/runtime/ops_python.py +51 -46
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +63 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +384 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +26 -4
- tinygrad/runtime/support/hcq.py +254 -324
- tinygrad/runtime/support/llvm.py +32 -0
- tinygrad/shape/shapetracker.py +84 -53
- tinygrad/shape/view.py +103 -138
- tinygrad/spec.py +154 -0
- tinygrad/tensor.py +744 -496
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
- tinygrad-0.10.1.dist-info/RECORD +86 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/engine/schedule.py
CHANGED
@@ -1,241 +1,170 @@
-import sys,
+import sys, functools
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
-from
-from tinygrad.ops import
-from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put,
-from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG
-from tinygrad.dtype import ImageDType
+from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
+from tinygrad.ops import can_pad, identity_element, resolve, symbolic_simple, view_left, merge_views
+from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap, flatten
+from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY
+from tinygrad.dtype import ImageDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View, strides_for_shape
-from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.device import Buffer
+from tinygrad.spec import type_verify, kernel_spec

 # creation can recurse a lot
 sys.setrecursionlimit(10000)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+# **** schedule simplifier
+
+def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
+  if not all_int(x.shape): return None
+  # remove reduce on unmasked const
+  prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
+  ret = x.const_arg
+  match reduce.arg[0]:
+    case Ops.ADD: ret *= prshape
+    case Ops.MUL: ret **= prshape
+    case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
+    case _: return None
+  return reduce.const_like(ret)
+
+def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
+  if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
+def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
+  new_src = list(alu.src)
+  for i,s in enumerate(alu.src):
+    if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
+  if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
+
+sym = symbolic_simple+PatternMatcher([
+  # UOp with size 0 is zero
+  (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
+    and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
+  # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
+  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
+  # reduce of size 0 is the identity element
+  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
+   lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
+  # reduce of const is collapsed (TODO: make this a generic rule for stride0)
+  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
+  # COPY(CONST) creates a new CONST on the destination device
+  (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)),
+  # no COPY to same device, except clone (arg is True)
+  (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
+   lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
+  # remove cast to image when it's already a contiguous image
+  (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)),
+   lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
+  # remove contiguous if we can just view the buffer
+  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
+   lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
+  # contiguous/buffer/copy is already contiguous
+  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY)),)), lambda root: root.src[0]),
+  # support for using a contiguous permuted view instead of the parent view if one exists
+  (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
+  (UPat(GroupOp.ALU, name="alu"), replace_contiguous),
+  # remove CONST/BIND/BUFFER from SINK
+  (UPat(Ops.SINK, name="root"),
+   lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
+    if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
+])

-
+remove_movement_ops = merge_views+PatternMatcher([
+  # NOTE: movement ops are always applied to base
+  (UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
+  # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
+  (UPat(Ops.VIEW, name="view"),
+   lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
+])

-
-@functools.lru_cache(None)
-def is_scheduled(u:UOp): return u.op is Ops.VIEW and len(u.src) == 2
+# **** UOp realization

 @dataclass(frozen=True)
 class ScheduleContext:
-
-
-
-  allbufs:
-  children:
-
-
+  ops_metadata: dict[UOp, Metadata]  # this maps uops in the schedule to the tensor metadata
+  assigns: dict[UOp, None] = field(default_factory=dict)  # this holds all the BUFFER uops we ASSIGN to in this schedule
+  realizes: dict[UOp, UOp] = field(default_factory=dict)  # this holds all the BUFFER uops we mutate in this schedule
+  allbufs: dict[UOp, UOp] = field(default_factory=dict)  # this maps BUFFER uops to the actual op
+  children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
+  preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
+
+# wrap tensor uops around a VIEW(BUFFER, <uop>)
+# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
+def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp:
   if (r:=cache.get(buf)) is not None: return r
+  # SINK is passthrough
+  if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
+  # skip creating buffers for CONST/BIND/DEVICE/BUFFER
+  if buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf
+  if buf.base.op is Ops.BUFFER: return buf.view(unwrap(buf.st))
+  # VIEW is passthrough
   if buf is not buf.base:
-    cache[buf] = ret =
+    cache[buf] = ret = add_buffers(buf.base, buffer_map, cache).view(unwrap(buf.st))
     return ret
   # make things that can't be images not images
-
-
-    if DEBUG >= 2: print(f"forcing image {
-
-
-
-
-
-  #
-
-
-
-  # everything else is a VIEW of BUFFER (with an optional op)
-  if buf.is_realized:
-    buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer
-    op = None
-  elif buf.op is Ops.ASSIGN:
-    target, new_val = [to_uop(x, ctx, buffers, lazybufs, cache) for x in buf.srcs]
-    ctx.assigns.add(ubuf:=target.buf_uop)
-    op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg)
-  else:
-    buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer
-    op = UOp(cast(Ops, buf.op), dtype, tuple(to_uop(x, ctx, buffers, lazybufs, cache) for x in buf.srcs),
-             None if buf.op in {Ops.CAST, Ops.BITCAST} else buf.arg)
-  cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
-  if op is not None:
-    if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
-    lazybufs[buf.buffer] = buf
-    for x in op.src:
-      if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[ubuf] = None
-  ctx.allbufs[ubuf] = ret
+  dtype = buf.dtype
+  if isinstance(dtype, ImageDType) and (prod(buf.shape)!=prod(dtype.shape) or not any(buf.shape[x]%4==0 for x in unwrap(buf.st).unit_stride_axes())):
+    if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}")
+    dtype = buf.dtype.base
+  # ASSIGN already has a target buffer, otherwise we create a new one
+  assert isinstance(buf.device, str), f"buf device is str, not {buf.device}"
+  buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
+  op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
+  # track the buffer uop for the simplified uop
+  buffer_map[buf] = buf_uop
+  # (early) bufferize
+  cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
   return ret

-
-
-
-
-def apply_swizzle(u:UOp, arg:ShapeTracker) -> UOp:
-  with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u.view(arg), view_left)
-
-def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
-  permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis)+axis
-  tmp = input_st.permute(permute_axis)
-  return tmp, tmp.shape[-len(axis):]
-
-# ** movementops rewrite rules
-
-def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
-  tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(src.st).shape), r.axis_arg)
-  prshape = prod(rshape)
-  strides = strides_for_shape(rshape)
-  nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
-                    v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
-  # update input_st and axis
-  new_input_st = tmp + ShapeTracker(tuple(nv))
-  _, new_rshape = permute_reduce(new_input_st, r.axis_arg)
-  new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
-  return apply_swizzle(src, new_input_st).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
-
-def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp, src:UOp) -> UOp:
-  swizzle_st, src_st = unwrap(swizzle.st), unwrap(src.st)
-  assert swizzle_st.contiguous, "can't push a non contiguous VIEW down to STORE"
-  assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE"
-  output_shape = swizzle_st.reduce(root.axis_arg)
-  new_axis = tuple(i for i,(s,u) in enumerate(zip(src_st.shape, output_shape)) if s != u)
-  return swizzle.src[0].r(root.arg[0], new_axis).view(ShapeTracker.from_shape(output_shape))
-
-def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
-  swizzles = [x for x in root.src if x.base is not x]
-  if len(swizzles) == 0: return None
-  swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
-  assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
-  new_shape, new_input_shape = swizzle_shapes[0]
-  ret = root.replace(src=tuple(x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src))
-  return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
+class UPatScheduled(UPat):
+  def __init__(self, *args, **kwargs):
+    super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))

-def
-  assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
-  assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
-  return first_reduce.src[0].r(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)
+def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store

-
-
-#
-
-
-
-
-
-
-
-# push VIEW to stores
-view_right = merge_views+PatternMatcher([
-  # ASSIGN can override st
-  (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.ASSIGN, name="a"))),
-   lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None),
-  # non contiguous VIEW on a reduce creates a new VIEW
-  (UPat(Ops.REDUCE_AXIS, src=UPat.var("src"), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
-  # push a VIEW down to STORE, through a reduce (ONLY reshapes)
-  (UPat(Ops.REDUCE_AXIS, src=(UPat.var(name="src").view(name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
-  # push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes)
-  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise),
-  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
-])
-
-# ** ScheduleItem context builder
-
-@dataclass(frozen=True)
-class ScheduleItemContext:
-  var_vals: Dict[Variable, int]
-  assigned: Set[UOp]
-  sts: Set[ShapeTracker] = field(default_factory=set)
-  bufs: List[UOp] = field(default_factory=list)
-  assign_preloads: List[UOp] = field(default_factory=list)
-
-def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
-  if (st:=unwrap(x.st)) in ctx.sts: return None
-  st, var_vals = st.simplify().unbind()
-  ctx.var_vals.update(var_vals)
-  ctx.sts.add(st)
-  return st.to_uop() if st != x.st else None
-
-def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
-  ctx.bufs.append(x)
-  return UOp(Ops.DEFINE_GLOBAL, x.dtype, (), len(ctx.bufs)-1)
-append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
-
-def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
-  if b in ctx.assigned: ctx.assign_preloads.append(b)
-  return x.replace(op=Ops.LOAD)
-
-to_si = PatternMatcher([
-  (UPat(Ops.VIEW, name="x"), _append_st_vars),
-  (UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),
-  (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,b,x: x.replace(src=(b, *x.src))),
-])
+def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
+  st = unwrap(view.st)
+  # fold simple pads
+  if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
+    return None if can_pad(src, ctx.realizes, dict()) else realize(ctx, b, src)
+  # early realize before expand
+  if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
+  # otherwise safety check pads
+  return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, dict())) else realize(ctx, b, src)

-
+def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
+  if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
+  buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
+  return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))

-
-
-  (UPat(Ops.
-
+do_realize = PatternMatcher([
+  # always realize SINK parents
+  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
+  # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
+  (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
+  # realize before expand or unsafe pad ops
+  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
+  # realize before COPY or BUFFER_VIEW
+  (UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
+  (UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
+  # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
+  (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
 ])

-
-
-
-
-
-
-  for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {Ops.PRELOAD,Ops.LOAD} and x.buf_uop in assigned), key=lambda x:x.buf_uop):
-    if not all_same([x.op for x in ops]):
-      raise RuntimeError(f"cycle detected in kernel.\nhelp: use .contiguous() to break the part loading pre-assign {b} into a different kernel.")
-  # do movementops
-  sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
-  # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
-  if len(assign_targets:=[x.buf_uop for x in sink.sparents if x.op is Ops.ASSIGN]) != 0:
-    if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
-        and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is Ops.PRELOAD and x.buf_uop in assign_targets):
-      raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
-                         +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
-  # convert to AST
-  sink = graph_rewrite(graph_rewrite(sink, to_si, ctx:=ScheduleItemContext(var_vals, assigned)), append_bufs, ctx)
-  if getenv("RUN_PROCESS_REPLAY"): PROCESS_REPLAY_CAPTURE.append(((pre, var_vals, assigned), sink))
-  return sink, ctx
-
-PROCESS_REPLAY_CAPTURE: List[Tuple[Tuple, UOp]] = []
-if getenv("RUN_PROCESS_REPLAY"):
-  @atexit.register
-  def save_process_replay():
-    for x,ret in PROCESS_REPLAY_CAPTURE: diskcache_put("schedule_process_replay", str(x[0].key), (x, {}, ret))
-
-# **** Schedule grouping
+def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
+  ctx.allbufs[buf_uop] = view
+  if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns[buf_uop] = None
+  for x in op.base.src:
+    if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
+create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])

+def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER
 def uval(u:UOp) -> UOp:
   assert is_scheduled(u), f"must be a scheduled op {u}"
-  return
+  return u.src[1]

-def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:
-                    reduce_for_op:
+def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp], realizes:dict[UOp, UOp],
+                    reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
   """recursively search the uop for groupable children, realize the UOp if a child can't group"""
   if (tr, st) in cache: return
   cache.setdefault((tr, st))
@@ -252,46 +181,32 @@ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:DefaultDict[UOp, Di
     if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
     recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)

-def
-
-
-  while rc_parents:
-    if (p:=uval(allbufs[rc_parents.pop()])) in cache: continue
-    cache.add(p)
-    # max one reduceop per kernel
-    if p.op is Ops.REDUCE_AXIS: return {}
-    rc_parents.extend(x.base.buf_uop for x in p.src if is_scheduled(x.base) and x.base.buf_uop is not r)
-  # search descendants of the reduceop that can cleanly group
-  descendants: Dict[UOp, None] = {}
-  for tr in group: recursive_group(tr, unwrap(allbufs[tr].st), tr, children, allbufs, realizes, reduce_for_op, descendants, cache={})
-  return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
-
-def group_realizes(ctx:ScheduleContext, realizes:Dict[UOp, UOp]) -> List[List[UOp]]:
-  """search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop"""
+def group_realizes(sink:UOp, ctx:ScheduleContext) -> dict[UOp, UOp]:
+  # start by adding uops that always realize
+  sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
   # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
-  reduce_for_op:
-
-  double_reduces: List[UOp] = []
+  reduce_for_op: dict[UOp, UOp] = {}
+  double_reduces: list[UOp] = []
   for r, r_uop in ctx.allbufs.items():
     if (r_uop:=uval(r_uop)).op is not Ops.REDUCE_AXIS: continue
-    if FUSE_CONV_BW and
-    if r in realizes: continue
-    group:
-    recursive_group(r, unwrap(r_uop.st), r, ctx.children, ctx.allbufs, realizes, reduce_for_op, group, cache={})
+    if FUSE_CONV_BW and is_scheduled((x:=r_uop.src[0]).base) and uval(x.base).op is r_uop.op and x.base is not x: double_reduces.append(r)
+    if r in ctx.realizes: continue
+    group: dict[UOp, None] = {}
+    recursive_group(r, unwrap(r_uop.st), r, ctx.children, ctx.allbufs, ctx.realizes, reduce_for_op, group, cache={})
     # max one reduceop per kernel
     can_chase = all(tr not in reduce_for_op for tr in group)
     # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
     forced_realize = r in group
-
-
+    # can only have one output
+    if not forced_realize and len(group) > 1: forced_realize = True
     # can only fuse assign if no other assign_target is used in the kernel
     if not forced_realize and any(x in ctx.assigns for x in group):
       parents = deque((r, *group))
       while parents and not forced_realize:
        if (p_uop:=ctx.allbufs.get(p:=parents.pop())) is None: continue
        if (p_uop:=uval(p_uop)).op is Ops.ASSIGN and p not in group: forced_realize, can_chase = True, False
-        if p in realizes: continue
-        parents.extend([x.base.
+        if p in ctx.realizes: continue
+        parents.extend([x.base.buf_uop for x in p_uop.src if x.base.is_realized or (x.base.op is Ops.VIEW and len(x.base.src) != 0)])
     if forced_realize or not group:
       tr = r
       if can_chase:
@@ -309,86 +224,241 @@ def group_realizes(ctx:ScheduleContext, realizes:Dict[UOp, UOp]) -> List[List[UO
         if (tr_uop:=uval(ctx.allbufs[tr])).op is Ops.CAST and tr_uop.dtype.base.itemsize > tr_uop.src[0].dtype.base.itemsize:
           tr = tr_uop.src[0].base.buf_uop
       group = {tr: None}
-      realizes[tr] = tr
+      ctx.realizes[tr] = tr
     reduce_for_op.update((tr, r) for tr in group)
-    if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.
+    if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.CONST:
+      # maybe fuse arange with its children
+      if len(flatten(ctx.children[tr] for tr in group)) != 0:
+        for tr in group: del ctx.realizes[tr]
   # fuse double reduces with no other child
   for reduceop in double_reduces:
     top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
-    if len(ctx.children[top_reduce]) == 1: del realizes[top_reduce]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-  st = unwrap(view.st)
-  # fold simple pads
-  if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(base_shape) and resolve(prod(base_shape) >= prod([y-x for x,y in m])):
-    return None if can_pad(base) else realize(ctx, b, to_store, base).view(st)
-  # early realize before expand
-  if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, b, to_store, base).view(st)
-  # otherwise safety check pads
-  return None if (all(v.mask is None for v in st.views) or can_pad(base)) else realize(ctx, b, to_store, base).view(st)
+    if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
+  graph_rewrite(sink, break_sched, ctx)
+  return ctx.realizes
+
+# break the SINK into stores
+
+def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
+  # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
+  return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
+
+def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
+  if (m:=ctx.ops_metadata.get(b)) is not None: ctx.ops_metadata[x] = m
+  if b not in ctx.realizes: return x # collapse BUFFER
+  ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
+  return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
+
+break_sched = PatternMatcher([
+  # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
+  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
+  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
+])

-
-
-
-
-
-
-
-
-
-
-
+# **** convert Kernel to a ScheduleItem (for legacy reasons)
+
+@dataclass(frozen=True)
+class ScheduleItem:
+  ast: UOp
+  bufs: tuple[Buffer, ...]
+  metadata: tuple[Metadata, ...]
+  @property
+  def outputs(self) -> tuple[Buffer, ...]:
+    """Read/write or write only buffers in the schedule."""
+    return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs)
+  @property
+  def inputs(self) -> tuple[Buffer, ...]:
+    """Read only buffers in the schedule."""
+    return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs)
+  @functools.cached_property
+  def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
+
+def kernel_to_si(k:UOp) -> ScheduleItem:
+  assert k.op is Ops.KERNEL, f"must be KERNEL {k}"
+  return ScheduleItem(k.arg.ast, tuple(u.buf_uop.buffer for u in k.src), k.arg.metadata)
+
+# **** Kernel creation
+
+@dataclass(frozen=True)
+class Kernel:
+  ast: UOp
+  metadata: tuple[Metadata, ...]
+  def __repr__(self): return f"<Kernel {len(list(self.ast.toposort))} {self.ast.op} {self.metadata}>"
+
+@dataclass(frozen=True)
+class ScheduleItemContext:
+  var_vals: dict[Variable, int]
+  bufs: list[UOp] = field(default_factory=list)
+
+def apply_swizzle(u:UOp) -> UOp:
+  with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
+
+def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
+  input_st = ShapeTracker.from_shape(unwrap(src.st).shape)
+  tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
+  prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
+  strides = strides_for_shape(rshape)
+  nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
+                    v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
+  # update input_st and axis
+  new_input_st = tmp + ShapeTracker(tuple(nv))
+  new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg)))
+  return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
+
+def reduceop_view_right(r:UOp, v:UOp, src:UOp) -> UOp:
+  if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}")
+  output_shape = swizzle_st.reduce(r.axis_arg)
+  return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape))
+
+def elementwise_view_right(root:UOp) -> UOp|None:
+  if len(swizzles:=[x for x in root.src if x.base is not x]) == 0: return None
+  assert all(x.base.st is not None for x in swizzles), f"found shapeless VIEW src in {root}"
+  assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
+  # push the swizzle from src to root
+  output_swizzle = swizzles[0]
+  new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)
+  ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
+  return ret.view(ShapeTracker.from_shape(output_swizzle.shape))
+
+def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
+  assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
+  assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
+  return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
+
+# push VIEW to children
+view_right = merge_views+PatternMatcher([
+  # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
+  (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
+   lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
+  # STORE is the last child, so we just merge the ShapeTrackers and store the base
+  (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
+  # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
+  (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
+  # REDUCE(src.view()) -> REDUCE(src).view()
+  (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src").view(name="v"),), name="r"), reduceop_view_right),
+  # ALU(src.view()) -> ALU(src).view()
+  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), elementwise_view_right),
+  # double reduce op collapses to a single reduce op
+  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
 ])
-
+
+def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
+  st = unwrap(x.st).simplify()
+  if any(x.op is Ops.BIND for x in st.vars()):
+    st, var_vals = st.unbind()
+    ctx.var_vals.update(var_vals)
+  return st.to_uop() if st != x.st else None
+
+def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
+  ctx.bufs.append(x)
+  return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)
+
+to_si = PatternMatcher([
+  # BUFFER -> DEFINE_GLOBAL
+  (UPat(Ops.BUFFER, name="x"), _append_buf),
+  # simplify and unbind the final VIEWs
+  (UPat(Ops.VIEW, name="x"), _append_st_vars),
+  # don't need SINK on COPY or BUFFER_VIEW
+  (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x")),)), lambda b,x: x.replace(src=(b, *x.src))),
+  # don't need contiguous or assign anymore
+  (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
+  (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
+  # don't need DEVICE anymore
+  (UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
+  # PRELOAD becomes LOAD
+  (UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
+  # once images are loaded they become the base dtype
+  (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
+])
+
+def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
+  ctx[var.replace(src=())] = val.arg
+  return var
+unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
+
+def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> UOp:
+  # unbind_vars + push views to edges
+  sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=var_vals), view_right)
+  # remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL
+  ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(var_vals))
+  # deal with ASSIGN
+  if len(ctx.assigns) != 0:
+    assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer]
+    for x in list(sink.toposort)[::-1]:
+      # we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER
+      if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
+      # PRELOAD tells the toposort this kernel should run before ASSIGN
+      if x.op is Ops.PRELOAD:
+        assign_preloads[x.buf_uop] = None
+        # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
+        if x.buf_uop is pre.src[0].buf_uop and not (st:=x.st_arg).contiguous:
+          # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
+          if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass
+          # if it has a single view and it's equal when you shrink a contig, it's fine
+          elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass
+          # otherwise, it's not fine
+          else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+                                   +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+  # NOTE: we only add the metadata for fused tensors
+  metadata = tuple(dedup(m for x in pre.toposort if x.op is not Ops.BUFFER and (m:=ctx.ops_metadata.get(x)) is not None))
+  return UOp(Ops.KERNEL, src=tuple(si_ctx.bufs), arg=Kernel(ast, metadata))
+
+# **** schedule creation and toposort

 @track_rewrites(named=True)
-def create_schedule_with_vars(
-
-
-
-
-
-
-
-
+def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
+  tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
+  # tensors can become an existing buffer or simplify to a const, no ScheduleItem needed
+  becomes_map: dict[UOp, UOp] = {}
+  for k,v in tensor_map.items():
+    # NOOP
+    if k.base is v.base: continue
+    # NOTE: only the base tensors get a BUFFER UOp
+    if v.is_realized and k is k.base: becomes_map[k] = v.view(unwrap(k.st))
+    # otherwise if it simplified to a CONST the UOp just becomes that CONST
+    elif v.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
+
+  # we group the rest of UOps into ScheduleItems
+  buffer_map: dict[UOp, UOp] = {}
+  sink = add_buffers(tensor_map[big_sink], buffer_map, cache={})
   # get realizes
-
-
-
-
-
-
-
-
-
-
-
-
-  #
+  buf_tensors: dict[UOp, list[UOp]] = {}
+  ops_metadata: dict[UOp, Metadata] = {}
+  for k,v in tensor_map.items():
+    if (b:=buffer_map.get(v)) is not None:
+      buf_tensors.setdefault(b, []).append(k)
+      ops_metadata[b] = k.metadata
+  realize_map = group_realizes(sink, ctx:=ScheduleContext(ops_metadata))
+
+  # TODO: this should be the break between the "grouper" and the "linearizer"
+  # here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
+  # call into `def linearize_schedule(sched_sink:UOp) -> list[ScheduleItem]`
+
+  # create kernels + map buffers to realized tensors
+  sinks: list[UOp] = []
+  var_vals: dict[Variable, int] = {}
+  for buf_uop,store in realize_map.items():
+    assert store.op is Ops.STORE, f"expected a realized BUFFER to get a STORE {sink}"
+    sinks.append(schedule_uop(store.sink(), ctx, var_vals))
+    # can only schedule once
+    for tensor_uop in buf_tensors[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
+    # increment refcount for this buffer
+    buf_uop.buffer.ref(1)
+  sched_sink = UOp(Ops.SINK, src=tuple(sinks))
+  # display, TODO: this isn't a complete sched_sink yet
+  if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]))
+  type_verify(list(sched_sink.toposort), kernel_spec)
+
+  # convert kernels to ScheduleItem
+  prescheduled = [kernel_to_si(k) for k in sched_sink.src]
+  # add ScheduleItem children
+  # TODO: this should construct the graph directly from the sched_sink
   schedule_targets = {out:si for si in prescheduled for out in si.outputs}
-  graph:
-  in_degree:
+  graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list)
+  in_degree: defaultdict[ScheduleItem, int] = defaultdict(int)
   for si in prescheduled:
     # realize outputs before a parent is assigned to
-    parents_assigns = dedup(xsi for x in si.
+    parents_assigns = dedup(xsi for x in ctx.preloads[si.bufs[0]] if (xsi:=schedule_targets.get(x.buffer)) and xsi is not si)
     for assign in parents_assigns:
       graph[si].append(assign)
       in_degree[assign] += 1
@@ -397,23 +467,20 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
     for x in scheduled_parents:
       graph[x].append(si)
       in_degree[si] += 1
+
+  # do BFS
   queue = deque(si for si in prescheduled if in_degree[si] == 0)
-  schedule:
+  schedule: list[ScheduleItem] = []
   while queue:
     schedule.append(si:=queue.popleft())
-    for b in si.outputs: del lazybufs[b].srcs # can only schedule once
-    if (m:=BUF_LIMIT.get(device:=si.outputs[0].device)) and len(si.bufs) >= m:
-      if DEBUG >= 3: print(si)
-      raise RuntimeError(f"Kernel for {si.metadata} exceeded the {m} buffer count limit for {device} with {len(si.bufs)} buffers.")
    for x in graph[si]:
      in_degree[x] -= 1
      if in_degree[x] == 0: queue.append(x)
   # confirm everything was scheduled correctly
   if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
   if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
-
-
-
-
-
-  return schedule
+  # capture process replay
+  if CAPTURE_PROCESS_REPLAY:
+    with Context(PICKLE_BUFFERS=0):
+      diskcache_put("schedule_process_replay", str(big_sink.key), (big_sink, ContextVar._cache, [x.ast for x in schedule]))
+  return schedule, var_vals, becomes_map