tinygrad-0.10.1-py3-none-any.whl → tinygrad-0.10.2-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +35 -37
- tinygrad/codegen/linearize.py +19 -10
- tinygrad/codegen/lowerer.py +31 -8
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +10 -0
- tinygrad/device.py +28 -11
- tinygrad/dtype.py +12 -3
- tinygrad/engine/jit.py +3 -2
- tinygrad/engine/multi.py +0 -1
- tinygrad/engine/realize.py +7 -4
- tinygrad/engine/schedule.py +227 -255
- tinygrad/engine/search.py +20 -27
- tinygrad/gradient.py +3 -0
- tinygrad/helpers.py +7 -4
- tinygrad/nn/state.py +2 -2
- tinygrad/ops.py +64 -329
- tinygrad/renderer/__init__.py +19 -3
- tinygrad/renderer/cstyle.py +39 -18
- tinygrad/renderer/llvmir.py +55 -18
- tinygrad/renderer/ptx.py +6 -2
- tinygrad/renderer/wgsl.py +20 -12
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/{libpciaccess.py → pci.py} +25 -715
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/metal.py +28 -29
- tinygrad/runtime/ops_amd.py +37 -34
- tinygrad/runtime/{ops_clang.py → ops_cpu.py} +4 -2
- tinygrad/runtime/ops_disk.py +1 -1
- tinygrad/runtime/ops_dsp.py +59 -33
- tinygrad/runtime/ops_llvm.py +14 -12
- tinygrad/runtime/ops_metal.py +78 -62
- tinygrad/runtime/ops_nv.py +9 -6
- tinygrad/runtime/ops_python.py +5 -5
- tinygrad/runtime/ops_webgpu.py +200 -38
- tinygrad/runtime/support/am/amdev.py +23 -11
- tinygrad/runtime/support/am/ip.py +10 -10
- tinygrad/runtime/support/elf.py +2 -0
- tinygrad/runtime/support/hcq.py +7 -5
- tinygrad/runtime/support/llvm.py +8 -14
- tinygrad/shape/shapetracker.py +3 -2
- tinygrad/shape/view.py +2 -3
- tinygrad/spec.py +21 -20
- tinygrad/tensor.py +150 -90
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/METADATA +20 -8
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- tinygrad/codegen/rewriter.py +0 -516
- tinygrad-0.10.1.dist-info/RECORD +0 -86
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +0 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
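A quick way to confirm which of the two versions above is actually installed (a minimal sketch, not part of the diff, using only the standard library):

```python
# Sketch: print the installed tinygrad version; expect "0.10.2" after upgrading.
import importlib.metadata

print(importlib.metadata.version("tinygrad"))
```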
tinygrad/engine/schedule.py
CHANGED
```diff
@@ -1,10 +1,11 @@
-import sys, functools
+import sys, functools, atexit, pickle
 from collections import defaultdict, deque
-from dataclasses import dataclass
+from dataclasses import dataclass
 from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
-from tinygrad.ops import can_pad, identity_element, resolve,
-from tinygrad.
-from tinygrad.helpers import
+from tinygrad.ops import can_pad, identity_element, resolve, view_left, merge_views
+from tinygrad.codegen.symbolic import symbolic_simple
+from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten, getenv
+from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND
 from tinygrad.dtype import ImageDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View, strides_for_shape
@@ -16,17 +17,17 @@ sys.setrecursionlimit(10000)
 
 # **** schedule simplifier
 
-def
-
-
-
-
+def simplify_stride0_reduce(reduce:UOp, x:UOp):
+  # must be unmasked (NOTE: can be relaxed if not masked on stride 0 axis)
+  if any(v.mask is not None for v in unwrap(x.st).views): return None
+  # must have all stride 0 in the relevant axis (NOTE: can do partial)
+  if not all(unwrap(x.st).views[-1].strides[axis] == 0 for axis in reduce.arg[1]) or not all_int(x.shape): return None
+  prshape = prod(x.shape[i] for i in reduce.arg[1])
+  ret = x.shrink(tuple((0,s) if i not in reduce.arg[1] else (0,1) for i,s in enumerate(x.shape)))
   match reduce.arg[0]:
-    case Ops.ADD: ret
-    case Ops.MUL: ret
-    case Ops.MAX:
-    case _: return None
-  return reduce.const_like(ret)
+    case Ops.ADD: return ret*prshape
+    case Ops.MUL: return ret.pow(prshape)
+    case Ops.MAX: return ret  # NOTE: Ops.MAX is passthrough
 
 def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
   if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
@@ -39,22 +40,25 @@ def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
 sym = symbolic_simple+PatternMatcher([
   # UOp with size 0 is zero
   (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
-
+    and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
   # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
   (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
   # reduce of size 0 is the identity element
   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
    lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
-  # reduce
-  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.
+  # reduce on stride 0 is collapsed
+  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_stride0_reduce),
   # COPY(CONST) creates a new CONST on the destination device
-  (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.
+  (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.arg)),
   # no COPY to same device, except clone (arg is True)
   (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
    lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
   # remove cast to image when it's already a contiguous image
-  (UPat(Ops.
-   lambda cast,base,
+  (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"))),)),
+   lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
+  # make things that can't be images not images
+  (UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType)
+   and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),
   # remove contiguous if we can just view the buffer
   (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
    lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
@@ -63,10 +67,9 @@ sym = symbolic_simple+PatternMatcher([
   # support for using a contiguous permuted view instead of the parent view if one exists
   (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
   (UPat(GroupOp.ALU, name="alu"), replace_contiguous),
-  #
-  (UPat(Ops.
-
-   if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
+  # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
+  (UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="root"),
+   lambda root: root.replace(op=Ops.BUFFER_VIEW) if isinstance(root.device, str) and root.device.startswith("DISK") else None),
 ])
 
 remove_movement_ops = merge_views+PatternMatcher([
@@ -80,95 +83,40 @@ remove_movement_ops = merge_views+PatternMatcher([
 # **** UOp realization
 
 @dataclass(frozen=True)
-class
-
-
-
-
-
-
-
-# wrap tensor uops around a VIEW(BUFFER, <uop>)
-# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
-def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp:
-  if (r:=cache.get(buf)) is not None: return r
-  # SINK is passthrough
-  if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
-  # skip creating buffers for CONST/BIND/DEVICE/BUFFER
-  if buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf
-  if buf.base.op is Ops.BUFFER: return buf.view(unwrap(buf.st))
-  # VIEW is passthrough
-  if buf is not buf.base:
-    cache[buf] = ret = add_buffers(buf.base, buffer_map, cache).view(unwrap(buf.st))
-    return ret
-  # make things that can't be images not images
-  dtype = buf.dtype
-  if isinstance(dtype, ImageDType) and (prod(buf.shape)!=prod(dtype.shape) or not any(buf.shape[x]%4==0 for x in unwrap(buf.st).unit_stride_axes())):
-    if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}")
-    dtype = buf.dtype.base
-  # ASSIGN already has a target buffer, otherwise we create a new one
-  assert isinstance(buf.device, str), f"buf device is str, not {buf.device}"
-  buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
-  op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
-  # track the buffer uop for the simplified uop
-  buffer_map[buf] = buf_uop
-  # (early) bufferize
-  cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
-  return ret
-
-class UPatScheduled(UPat):
-  def __init__(self, *args, **kwargs):
-    super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
-
-def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
-
-def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
+class GrouperContext:
+  assigns: dict[UOp, UOp]                      # maps realized buffers to assigns
+  realizes: dict[UOp, None]                    # all the simplified tensor uops we realize
+  children: defaultdict[UOp, dict[UOp, None]]  # children graph of tensor uops
+
+def realize(ctx:GrouperContext, tr:UOp) -> None: ctx.realizes[tr] = None
+
+def realize_before_view(ctx:GrouperContext, view:UOp, src:UOp) -> None:
   st = unwrap(view.st)
   # fold simple pads
   if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
-    return None if can_pad(src, ctx.realizes, dict()) else realize(ctx,
+    return None if can_pad(src, ctx.realizes, cache=dict()) else realize(ctx, src)
   # early realize before expand
-  if resolve(prod(src.shape) < prod(st.shape)) and not
+  if resolve(prod(src.shape) < prod(st.shape)) and not DONT_REALIZE_EXPAND: return realize(ctx, src)
   # otherwise safety check pads
-  return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, dict())) else realize(ctx,
-
-def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
-  if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
-  buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
-  return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
+  return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, cache=dict())) else realize(ctx, src)
 
 do_realize = PatternMatcher([
   # always realize SINK parents
-  (UPat(Ops.SINK, name="
+  (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x.base, None) for x in s.src if x.base.op not in {Ops.CONST, Ops.BIND, Ops.BUFFER})),
   # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
-  (
+  (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
   # realize before expand or unsafe pad ops
-  (UPat(Ops.VIEW, name="view", src=(
-  # realize before COPY
-  (UPat(Ops.COPY, src=(UPat(), UPat.
-  (UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
-  # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
-  (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
+  (UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}, name="src"),)), realize_before_view),
+  # realize before COPY
+  (UPat(Ops.COPY, src=(UPat(), UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW}, name="tr"))), realize),
 ])
 
-def
-
-  if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns[buf_uop] = None
-  for x in op.base.src:
-    if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
-create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
-
-def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER
-def uval(u:UOp) -> UOp:
-  assert is_scheduled(u), f"must be a scheduled op {u}"
-  return u.src[1]
-
-def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp], realizes:dict[UOp, UOp],
-                    reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
+def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
+                    reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
   """recursively search the uop for groupable children, realize the UOp if a child can't group"""
   if (tr, st) in cache: return
   cache.setdefault((tr, st))
-  rsize = unwrap(
+  rsize = unwrap(r.st).size
   if tr in realizes and tr is not r:
     # can only fuse contiguous
     # max one reduceop per kernel
@@ -176,23 +124,28 @@ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, di
     return group.setdefault(tr)
   for tr_next in children[tr]:
     # max one reduceop per kernel
-    if
+    if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r)
     # can only fuse contiguous
-    if len(st_childs:=dedup(unwrap(x.st) for x in
-    recursive_group(tr_next, st+st_childs[0], r, children,
+    if len(st_childs:=dedup(unwrap(x.st) for x in tr_next.src if x.base == tr)) > 1: return group.setdefault(r)
+    recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache)
+
+def append_uop(ctx:GrouperContext, u:UOp) -> None:
+  if u.op is Ops.ASSIGN: ctx.assigns[u.buf_uop] = u
+  for s in u.src: ctx.children[s.base][u] = None
+create_ctx = PatternMatcher([(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}, name="u"), append_uop)])
 
-def group_realizes(sink:UOp
+def group_realizes(sink:UOp) -> dict[UOp, None]:
   # start by adding uops that always realize
-  sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
+  sink = graph_rewrite(sink, do_realize+create_ctx, ctx:=GrouperContext({}, {}, defaultdict(dict)))
   # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
   reduce_for_op: dict[UOp, UOp] = {}
   double_reduces: list[UOp] = []
-  for r
-    if
-    if FUSE_CONV_BW and
+  for r in sink.toposort:
+    if r.op is not Ops.REDUCE_AXIS: continue
+    if FUSE_CONV_BW and r.src[0].base.op is Ops.REDUCE_AXIS and r.src[0] is not r.src[0].base: double_reduces.append(r)
     if r in ctx.realizes: continue
     group: dict[UOp, None] = {}
-    recursive_group(r, unwrap(
+    recursive_group(r, unwrap(r.st), r, ctx.children, ctx.realizes, reduce_for_op, group, cache={})
     # max one reduceop per kernel
     can_chase = all(tr not in reduce_for_op for tr in group)
     # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
@@ -200,59 +153,77 @@ def group_realizes(sink:UOp, ctx:ScheduleContext) -> dict[UOp, UOp]:
     # can only have one output
     if not forced_realize and len(group) > 1: forced_realize = True
     # can only fuse assign if no other assign_target is used in the kernel
-    if not forced_realize and any(x
+    if not forced_realize and any(x.op is Ops.ASSIGN for x in group):
       parents = deque((r, *group))
       while parents and not forced_realize:
-
-        if (
+        p = parents.pop().base
+        if (assign:=ctx.assigns.get(p)) is not None and assign not in group: forced_realize, can_chase = True, False
        if p in ctx.realizes: continue
-        parents.extend(
+        parents.extend(p.src)
     if forced_realize or not group:
       tr = r
       if can_chase:
        # can chase this down to contiguous children
-        st = unwrap(
+        st = unwrap(tr.st)
        while len(ctx.children[tr]) == 1:
-
-          st_childs = dedup(
+          tr_next = next(iter(ctx.children[tr]))
+          st_childs = dedup(unwrap(s.st) for s in tr_next.src if s.base is tr)
          if len(st_childs) > 1: break
          if st.size != st_childs[0].size: break
          st = st + st_childs[0]
-          if not st.contiguous or
+          if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break
          tr = tr_next
        # don't cast to higher size before store (tr cannot be realized if forced_realize)
-        if
-          tr =
+        if tr.op is Ops.CAST and tr.dtype.itemsize > tr.src[0].dtype.itemsize:
+          tr = tr.src[0].base
      group = {tr: None}
-      ctx.realizes[tr] =
+      ctx.realizes[tr] = None
    reduce_for_op.update((tr, r) for tr in group)
-    if FUSE_ARANGE and
+    if FUSE_ARANGE and r.arg[0] is Ops.ADD and r.src[0].base.op is Ops.CONST:
      # maybe fuse arange with its children
      if len(flatten(ctx.children[tr] for tr in group)) != 0:
        for tr in group: del ctx.realizes[tr]
  # fuse double reduces with no other child
  for reduceop in double_reduces:
-    top_reduce =
+    top_reduce = reduceop.src[0].base
    if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
-  graph_rewrite(sink, break_sched, ctx)
  return ctx.realizes
 
-# break the SINK into
+# break the SINK into kernels
 
-
-
-
-
-def
-  if (m:=ctx.ops_metadata.get(b)) is not None: ctx.ops_metadata[x] = m
-  if b not in ctx.realizes: return x # collapse BUFFER
-  ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
-  return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
+@dataclass(frozen=True)
+class Kernel:
+  ast: UOp
+  metadata: tuple[Metadata, ...]
+  def __repr__(self): return f"<Kernel {len(list(self.ast.toposort))} {self.ast.op} {self.metadata}>"
 
-
-
-
-
+@dataclass(frozen=True)
+class KernelContext:
+  realizes: dict[UOp, None]
+  ops_metadata: dict[UOp, Metadata]
+
+def create_kernel(ctx:KernelContext, x:UOp):
+  if x not in ctx.realizes: return None
+  assert isinstance(x.device, str), f"buf device in kernel must be string {x.device}"
+  b = x.buf_uop if x.op is Ops.ASSIGN else UOp.new_buffer(x.device, x.size, x.dtype)
+  output_st = ShapeTracker.from_shape(x.shape)
+  # KERNEL nodes become: ASSIGN(VIEW(BUFFER), KERNEL)
+  # TODO: this should be ASSIGN(BUFFER, KERNEL) followed by the output ShapeTracker
+  return b.view(output_st).assign(UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x, (m,) if (m:=ctx.ops_metadata.get(x)) else ())))
+
+def append_to_kernel(ctx:KernelContext, x:UOp):
+  new_srcs: list[UOp] = []
+  new_metadata: dict[Metadata, None] = dict.fromkeys(x.arg.metadata)
+  for s in x.src:
+    if s.op is Ops.BUFFER or (s.op is Ops.ASSIGN and s.src[1].op is Ops.KERNEL) or s in ctx.realizes: new_srcs.append(s)
+    else:
+      new_srcs.extend(s.src)
+      if (m:=ctx.ops_metadata.get(s)) is not None: new_metadata[m] = None
+  return x.replace(src=n, arg=Kernel(x.arg.ast, tuple(new_metadata))) if (n:=tuple(dedup(new_srcs))) != x.src else None
+
+create_kernels = merge_views+PatternMatcher([
+  (UPat(GroupOp.All-{Ops.KERNEL, Ops.BUFFER}, name="x"), create_kernel),
+  (UPat(Ops.KERNEL, name="x"), append_to_kernel),
 ])
 
 # **** convert Kernel to a ScheduleItem (for legacy reasons)
@@ -273,23 +244,8 @@ class ScheduleItem:
   @functools.cached_property
   def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
 
-def kernel_to_si(k:UOp) -> ScheduleItem:
-  assert k.op is Ops.KERNEL, f"must be KERNEL {k}"
-  return ScheduleItem(k.arg.ast, tuple(u.buf_uop.buffer for u in k.src), k.arg.metadata)
-
 # **** Kernel creation
 
-@dataclass(frozen=True)
-class Kernel:
-  ast: UOp
-  metadata: tuple[Metadata, ...]
-  def __repr__(self): return f"<Kernel {len(list(self.ast.toposort))} {self.ast.op} {self.metadata}>"
-
-@dataclass(frozen=True)
-class ScheduleItemContext:
-  var_vals: dict[Variable, int]
-  bufs: list[UOp] = field(default_factory=list)
-
 def apply_swizzle(u:UOp) -> UOp:
   with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
 
@@ -342,33 +298,47 @@ view_right = merge_views+PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
 ])
 
-def _append_st_vars(ctx:
+def _append_st_vars(ctx:dict[Variable, int], x:UOp) -> UOp|None:
   st = unwrap(x.st).simplify()
   if any(x.op is Ops.BIND for x in st.vars()):
     st, var_vals = st.unbind()
-    ctx.
+    ctx.update(var_vals)
   return st.to_uop() if st != x.st else None
 
-def
-
-
-
-
-
-
-
+def check_load_st(glbl:UOp, view:UOp):
+  if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
+  # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
+  if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
+  # if it has a single view and it's equal when you shrink a contig, it's fine
+  if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
+  # otherwise, it's not fine
+  raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+                     +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+
+fix_kernel_ops = PatternMatcher([
+  # BIND in shapetracker becomes DEFINE_VAR
   (UPat(Ops.VIEW, name="x"), _append_st_vars),
-  #
-  (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x")),)), lambda b,x: x.replace(src=(b, *x.src))),
-  # don't need contiguous or assign anymore
+  # remove CONTIGUOUS/ASSIGN/DEVICE
   (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
   (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
-  # don't need DEVICE anymore
   (UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
-  #
-  (UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
-  # once images are loaded they become the base dtype
+  # no ImageDType after load
   (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
+  # if this kernel also assigns to the loaded buffer, ensure we can index it correctly
+  (UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
+])
+
+def load_buf(ctx:list[UOp], x:UOp):
+  if x not in ctx: ctx.append(x)
+  return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), unwrap(x.st).to_uop()))
+
+add_buffer_ops = PatternMatcher([
+  # LOAD
+  (UPat(Ops.BUFFER, name="x"), load_buf),
+  # STORE (except for COPY/BUFFER_VIEW)
+  (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
+  (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
+   lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
 ])
 
 def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
@@ -376,111 +346,113 @@ def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
   return var
 unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
 
-def schedule_uop(
+def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
+  assert sink.op is Ops.ASSIGN and sink.src[1].op is Ops.KERNEL, f"{sink} must be ASSIGN"
+  # substitute kernel sources for the target buffer
+  ast = sink.src[1].arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in sink.src[1].src if s.op is Ops.ASSIGN}).sink()
+  # add buffer ops
+  ast = graph_rewrite(ast, add_buffer_ops, bufs:=[sink.buf_uop], bottom_up=True)
   # unbind_vars + push views to edges
-
-  #
-  ast = graph_rewrite(
-  #
-  if
-
-
-
-
-
-
-
-  # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
-  if x.buf_uop is pre.src[0].buf_uop and not (st:=x.st_arg).contiguous:
-    # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
-    if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass
-    # if it has a single view and it's equal when you shrink a contig, it's fine
-    elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass
-    # otherwise, it's not fine
-    else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
-                             +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
-  # NOTE: we only add the metadata for fused tensors
-  metadata = tuple(dedup(m for x in pre.toposort if x.op is not Ops.BUFFER and (m:=ctx.ops_metadata.get(x)) is not None))
-  return UOp(Ops.KERNEL, src=tuple(si_ctx.bufs), arg=Kernel(ast, metadata))
+  ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
+  # fix_kernel_ops
+  ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
+  # create subbuffer
+  if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = bufs[1].buffer.view(ast.size, ast.dtype, (x:=ast.src[0]).st_arg.views[0].offset*x.dtype.itemsize)
+  return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), sink.src[1].arg.metadata)
+
+PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
+if CAPTURE_PROCESS_REPLAY:
+  @atexit.register
+  def save_process_replay():
+    for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
 
 # **** schedule creation and toposort
 
 @track_rewrites(named=True)
 def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
+  # remove_movement_ops + sym
   tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
-
+
+  # display the cleaned up tensor graph
+  if getenv("VIZ"): graph_rewrite(tensor_map[big_sink], PatternMatcher([]), name="View Tensor Graph")
+
+  # do_realize + group_realizes
+  sink = tensor_map[big_sink]
+  realize_map = group_realizes(sink)
+
+  # map tensors to new uops
   becomes_map: dict[UOp, UOp] = {}
-
-    # NOOP
-    if k.base is v.base: continue
-    # NOTE: only the base tensors get a BUFFER UOp
-    if v.is_realized and k is k.base: becomes_map[k] = v.view(unwrap(k.st))
-    # otherwise if it simplified to a CONST the UOp just becomes that CONST
-    elif v.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
-
-  # we group the rest of UOps into ScheduleItems
-  buffer_map: dict[UOp, UOp] = {}
-  sink = add_buffers(tensor_map[big_sink], buffer_map, cache={})
-  # get realizes
-  buf_tensors: dict[UOp, list[UOp]] = {}
+  rev_tensor_map: dict[UOp, list[UOp]] = {}
   ops_metadata: dict[UOp, Metadata] = {}
   for k,v in tensor_map.items():
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    for tensor_uop in buf_tensors[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
-    # increment refcount for this buffer
-    buf_uop.buffer.ref(1)
-  sched_sink = UOp(Ops.SINK, src=tuple(sinks))
-  # display, TODO: this isn't a complete sched_sink yet
-  if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]))
+    rev_tensor_map.setdefault(v, []).append(k)
+    if k is v: continue
+    if v.base.op is Ops.BUFFER:
+      # VIEW isn't a valid tensor uop, we need to backtrack to the movement op that created it
+      if v.op is Ops.VIEW:
+        mop = [x for x in k.toposort if (xs:=tensor_map[x]).base is v.base and xs.st == v.st][0]
+        if k is not mop: becomes_map[k] = mop
+      else: becomes_map[k] = v
+    elif v.base.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
+    # if we're not realizing this tensor, map its metadata to the simplified uop
+    elif isinstance(k.metadata, Metadata): ops_metadata[v] = k.metadata
+
+  # create kernels
+  if len(realize_map) == 0: return [], {}, becomes_map
+  kernel_map = graph_rewrite_map(sink, create_kernels, ctx=KernelContext(realize_map, ops_metadata), bottom_up=True)
+  sched_sink = kernel_map[sink]
   type_verify(list(sched_sink.toposort), kernel_spec)
 
-  #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+  # map realized tensors to buffers
+  for k,v in kernel_map.items():
+    if k is v or v.op is not Ops.ASSIGN: continue
+    for t in rev_tensor_map[k]: becomes_map[t] = t.src[0] if t.op is Ops.ASSIGN else v.buf_uop.reshape(t.shape)
+
+  # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
+  kernel_assign: dict[UOp, UOp] = {}
+  assign_rep: dict[UOp, UOp] = {}
+  for u in sched_sink.toposort:
+    if u.op is not Ops.ASSIGN: continue
+    kernel_assign[u.buf_uop] = u
+    for s in u.src[1].src:
+      if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
+      if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort):
+        raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
+      assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
+  if assign_rep:
+    sched_sink = sched_sink.substitute(assign_rep)
+    type_verify(list(sched_sink.toposort), kernel_spec)
+
+  # display the final graph
+  if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
+
+  # final toposort (bfs)
+  children: dict[UOp, list[UOp]] = {}
+  in_degree: dict[UOp, int] = {}
+  for u in sched_sink.toposort:
+    if u.op is not Ops.ASSIGN: continue
+    in_degree[u] = 0
+    for s in u.src[1].src:
+      if s.op is not Ops.ASSIGN: continue
+      children.setdefault(s, []).append(u)
+      in_degree[u] += 1
+
+  queue = deque(k for k,v in in_degree.items() if v == 0)
   schedule: list[ScheduleItem] = []
+  var_vals: dict[Variable, int] = {}
   while queue:
-
-
+    u = queue.popleft()
+    schedule.append(schedule_uop(u, var_vals))
+    # increment the refcount of the target buf (this is required by the JIT and memory planner)
+    u.buf_uop.buffer.ref(1)
+    for x in children.get(u, []):
      in_degree[x] -= 1
      if in_degree[x] == 0: queue.append(x)
+
  # confirm everything was scheduled correctly
-  if len(schedule) != (
+  if len(schedule) != (kc:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, created {kc} kernels but only scheduled {len(schedule)}")
  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
  # capture process replay
  if CAPTURE_PROCESS_REPLAY:
-    with Context(PICKLE_BUFFERS=0):
-      diskcache_put("schedule_process_replay", str(big_sink.key), (big_sink, ContextVar._cache, [x.ast for x in schedule]))
+    with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, [x.ast for x in schedule]))
  return schedule, var_vals, becomes_map
```
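One user-visible effect of the new `simplify_stride0_reduce` rewrite in the schedule simplifier above: a reduction over a broadcast (stride-0) axis is collapsed instead of re-reading the expanded data, e.g. an ADD reduce becomes an elementwise multiply by the product of the reduced shape. A minimal sketch of the effect through the public `Tensor` API (assumed unchanged by this diff, not shown in it):

```python
# Sketch: sum over an expanded (stride-0) axis; with the new rewrite the ADD
# reduce over axis 0 collapses to an elementwise multiply by 4.
from tinygrad import Tensor

x = Tensor([[1.0, 2.0, 3.0]])   # shape (1, 3)
y = x.expand(4, 3)              # broadcast along axis 0 with stride 0
print(y.sum(axis=0).numpy())    # [ 4.  8. 12.]
```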