tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/engine/schedule.py
CHANGED
@@ -1,458 +1,83 @@
|
|
1
|
-
|
2
|
-
from
|
3
|
-
from
|
4
|
-
from tinygrad.ops import UOp, Variable, Ops,
|
5
|
-
from tinygrad.
|
6
|
-
from tinygrad.
|
7
|
-
from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten, getenv
|
8
|
-
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND
|
9
|
-
from tinygrad.dtype import ImageDType
|
10
|
-
from tinygrad.shape.shapetracker import ShapeTracker
|
11
|
-
from tinygrad.shape.view import View, strides_for_shape
|
12
|
-
from tinygrad.device import Buffer
|
13
|
-
from tinygrad.spec import type_verify, kernel_spec
|
1
|
+
from typing import cast
|
2
|
+
from dataclasses import dataclass, field
|
3
|
+
from collections import deque, defaultdict
|
4
|
+
from tinygrad.uop.ops import UOp, Variable, Ops, buffers
|
5
|
+
from tinygrad.device import Device, Buffer, MultiBuffer
|
6
|
+
from tinygrad.helpers import Metadata, all_same
|
14
7
|
|
15
|
-
#
|
16
|
-
sys.setrecursionlimit(10000)
|
17
|
-
|
18
|
-
# **** schedule simplifier
|
19
|
-
|
20
|
-
def simplify_stride0_reduce(reduce:UOp, x:UOp):
|
21
|
-
# must be unmasked (NOTE: can be relaxed if not masked on stride 0 axis)
|
22
|
-
if any(v.mask is not None for v in unwrap(x.st).views): return None
|
23
|
-
# must have all stride 0 in the relevant axis (NOTE: can do partial)
|
24
|
-
if not all(unwrap(x.st).views[-1].strides[axis] == 0 for axis in reduce.arg[1]) or not all_int(x.shape): return None
|
25
|
-
prshape = prod(x.shape[i] for i in reduce.arg[1])
|
26
|
-
ret = x.shrink(tuple((0,s) if i not in reduce.arg[1] else (0,1) for i,s in enumerate(x.shape)))
|
27
|
-
match reduce.arg[0]:
|
28
|
-
case Ops.ADD: return ret*prshape
|
29
|
-
case Ops.MUL: return ret.pow(prshape)
|
30
|
-
case Ops.MAX: return ret # NOTE: Ops.MAX is passthrough
|
31
|
-
|
32
|
-
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
|
33
|
-
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
|
34
|
-
def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
|
35
|
-
new_src = list(alu.src)
|
36
|
-
for i,s in enumerate(alu.src):
|
37
|
-
if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
|
38
|
-
if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
|
39
|
-
|
40
|
-
sym = symbolic_simple+PatternMatcher([
|
41
|
-
# UOp with size 0 is zero
|
42
|
-
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
|
43
|
-
and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
|
44
|
-
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here
|
45
|
-
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
|
46
|
-
# reduce of size 0 is the identity element
|
47
|
-
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
|
48
|
-
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
|
49
|
-
# reduce on stride 0 is collapsed
|
50
|
-
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_stride0_reduce),
|
51
|
-
# COPY(CONST) creates a new CONST on the destination device
|
52
|
-
(UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.arg)),
|
53
|
-
# no COPY to same device, except clone (arg is True)
|
54
|
-
(UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
|
55
|
-
lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
|
56
|
-
# remove cast to image when it's already a contiguous image
|
57
|
-
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"))),)),
|
58
|
-
lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
|
59
|
-
# make things that can't be images not images
|
60
|
-
(UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType)
|
61
|
-
and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),
|
62
|
-
# remove contiguous if we can just view the buffer
|
63
|
-
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
|
64
|
-
lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
|
65
|
-
# contiguous/buffer/copy is already contiguous
|
66
|
-
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY)),)), lambda root: root.src[0]),
|
67
|
-
# support for using a contiguous permuted view instead of the parent view if one exists
|
68
|
-
(UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
|
69
|
-
(UPat(GroupOp.ALU, name="alu"), replace_contiguous),
|
70
|
-
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
|
71
|
-
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="root"),
|
72
|
-
lambda root: root.replace(op=Ops.BUFFER_VIEW) if isinstance(root.device, str) and root.device.startswith("DISK") else None),
|
73
|
-
])
|
74
|
-
|
75
|
-
remove_movement_ops = merge_views+PatternMatcher([
|
76
|
-
# NOTE: movement ops are always applied to base
|
77
|
-
(UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
|
78
|
-
# some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
|
79
|
-
(UPat(Ops.VIEW, name="view"),
|
80
|
-
lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
|
81
|
-
])
|
82
|
-
|
83
|
-
# **** UOp realization
|
84
|
-
|
85
|
-
@dataclass(frozen=True)
|
86
|
-
class GrouperContext:
|
87
|
-
assigns: dict[UOp, UOp] # maps realized buffers to assigns
|
88
|
-
realizes: dict[UOp, None] # all the simplified tensor uops we realize
|
89
|
-
children: defaultdict[UOp, dict[UOp, None]] # children graph of tensor uops
|
90
|
-
|
91
|
-
def realize(ctx:GrouperContext, tr:UOp) -> None: ctx.realizes[tr] = None
|
92
|
-
|
93
|
-
def realize_before_view(ctx:GrouperContext, view:UOp, src:UOp) -> None:
|
94
|
-
st = unwrap(view.st)
|
95
|
-
# fold simple pads
|
96
|
-
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
|
97
|
-
return None if can_pad(src, ctx.realizes, cache=dict()) else realize(ctx, src)
|
98
|
-
# early realize before expand
|
99
|
-
if resolve(prod(src.shape) < prod(st.shape)) and not DONT_REALIZE_EXPAND: return realize(ctx, src)
|
100
|
-
# otherwise safety check pads
|
101
|
-
return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, cache=dict())) else realize(ctx, src)
|
102
|
-
|
103
|
-
do_realize = PatternMatcher([
|
104
|
-
# always realize SINK parents
|
105
|
-
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x.base, None) for x in s.src if x.base.op not in {Ops.CONST, Ops.BIND, Ops.BUFFER})),
|
106
|
-
# always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
|
107
|
-
(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
|
108
|
-
# realize before expand or unsafe pad ops
|
109
|
-
(UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}, name="src"),)), realize_before_view),
|
110
|
-
# realize before COPY
|
111
|
-
(UPat(Ops.COPY, src=(UPat(), UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW}, name="tr"))), realize),
|
112
|
-
])
|
113
|
-
|
114
|
-
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
|
115
|
-
reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
|
116
|
-
"""recursively search the uop for groupable children, realize the UOp if a child can't group"""
|
117
|
-
if (tr, st) in cache: return
|
118
|
-
cache.setdefault((tr, st))
|
119
|
-
rsize = unwrap(r.st).size
|
120
|
-
if tr in realizes and tr is not r:
|
121
|
-
# can only fuse contiguous
|
122
|
-
# max one reduceop per kernel
|
123
|
-
if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
|
124
|
-
return group.setdefault(tr)
|
125
|
-
for tr_next in children[tr]:
|
126
|
-
# max one reduceop per kernel
|
127
|
-
if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r)
|
128
|
-
# can only fuse contiguous
|
129
|
-
if len(st_childs:=dedup(unwrap(x.st) for x in tr_next.src if x.base == tr)) > 1: return group.setdefault(r)
|
130
|
-
recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache)
|
131
|
-
|
132
|
-
def append_uop(ctx:GrouperContext, u:UOp) -> None:
|
133
|
-
if u.op is Ops.ASSIGN: ctx.assigns[u.buf_uop] = u
|
134
|
-
for s in u.src: ctx.children[s.base][u] = None
|
135
|
-
create_ctx = PatternMatcher([(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}, name="u"), append_uop)])
|
136
|
-
|
137
|
-
def group_realizes(sink:UOp) -> dict[UOp, None]:
|
138
|
-
# start by adding uops that always realize
|
139
|
-
sink = graph_rewrite(sink, do_realize+create_ctx, ctx:=GrouperContext({}, {}, defaultdict(dict)))
|
140
|
-
# find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
|
141
|
-
reduce_for_op: dict[UOp, UOp] = {}
|
142
|
-
double_reduces: list[UOp] = []
|
143
|
-
for r in sink.toposort:
|
144
|
-
if r.op is not Ops.REDUCE_AXIS: continue
|
145
|
-
if FUSE_CONV_BW and r.src[0].base.op is Ops.REDUCE_AXIS and r.src[0] is not r.src[0].base: double_reduces.append(r)
|
146
|
-
if r in ctx.realizes: continue
|
147
|
-
group: dict[UOp, None] = {}
|
148
|
-
recursive_group(r, unwrap(r.st), r, ctx.children, ctx.realizes, reduce_for_op, group, cache={})
|
149
|
-
# max one reduceop per kernel
|
150
|
-
can_chase = all(tr not in reduce_for_op for tr in group)
|
151
|
-
# TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
|
152
|
-
forced_realize = r in group
|
153
|
-
# can only have one output
|
154
|
-
if not forced_realize and len(group) > 1: forced_realize = True
|
155
|
-
# can only fuse assign if no other assign_target is used in the kernel
|
156
|
-
if not forced_realize and any(x.op is Ops.ASSIGN for x in group):
|
157
|
-
parents = deque((r, *group))
|
158
|
-
while parents and not forced_realize:
|
159
|
-
p = parents.pop().base
|
160
|
-
if (assign:=ctx.assigns.get(p)) is not None and assign not in group: forced_realize, can_chase = True, False
|
161
|
-
if p in ctx.realizes: continue
|
162
|
-
parents.extend(p.src)
|
163
|
-
if forced_realize or not group:
|
164
|
-
tr = r
|
165
|
-
if can_chase:
|
166
|
-
# can chase this down to contiguous children
|
167
|
-
st = unwrap(tr.st)
|
168
|
-
while len(ctx.children[tr]) == 1:
|
169
|
-
tr_next = next(iter(ctx.children[tr]))
|
170
|
-
st_childs = dedup(unwrap(s.st) for s in tr_next.src if s.base is tr)
|
171
|
-
if len(st_childs) > 1: break
|
172
|
-
if st.size != st_childs[0].size: break
|
173
|
-
st = st + st_childs[0]
|
174
|
-
if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break
|
175
|
-
tr = tr_next
|
176
|
-
# don't cast to higher size before store (tr cannot be realized if forced_realize)
|
177
|
-
if tr.op is Ops.CAST and tr.dtype.itemsize > tr.src[0].dtype.itemsize:
|
178
|
-
tr = tr.src[0].base
|
179
|
-
group = {tr: None}
|
180
|
-
ctx.realizes[tr] = None
|
181
|
-
reduce_for_op.update((tr, r) for tr in group)
|
182
|
-
if FUSE_ARANGE and r.arg[0] is Ops.ADD and r.src[0].base.op is Ops.CONST:
|
183
|
-
# maybe fuse arange with its children
|
184
|
-
if len(flatten(ctx.children[tr] for tr in group)) != 0:
|
185
|
-
for tr in group: del ctx.realizes[tr]
|
186
|
-
# fuse double reduces with no other child
|
187
|
-
for reduceop in double_reduces:
|
188
|
-
top_reduce = reduceop.src[0].base
|
189
|
-
if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
|
190
|
-
return ctx.realizes
|
191
|
-
|
192
|
-
# break the SINK into kernels
|
193
|
-
|
194
|
-
@dataclass(frozen=True)
|
195
|
-
class Kernel:
|
196
|
-
ast: UOp
|
197
|
-
metadata: tuple[Metadata, ...]
|
198
|
-
def __repr__(self): return f"<Kernel {len(list(self.ast.toposort))} {self.ast.op} {self.metadata}>"
|
199
|
-
|
200
|
-
@dataclass(frozen=True)
|
201
|
-
class KernelContext:
|
202
|
-
realizes: dict[UOp, None]
|
203
|
-
ops_metadata: dict[UOp, Metadata]
|
204
|
-
|
205
|
-
def create_kernel(ctx:KernelContext, x:UOp):
|
206
|
-
if x not in ctx.realizes: return None
|
207
|
-
assert isinstance(x.device, str), f"buf device in kernel must be string {x.device}"
|
208
|
-
b = x.buf_uop if x.op is Ops.ASSIGN else UOp.new_buffer(x.device, x.size, x.dtype)
|
209
|
-
output_st = ShapeTracker.from_shape(x.shape)
|
210
|
-
# KERNEL nodes become: ASSIGN(VIEW(BUFFER), KERNEL)
|
211
|
-
# TODO: this should be ASSIGN(BUFFER, KERNEL) followed by the output ShapeTracker
|
212
|
-
return b.view(output_st).assign(UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x, (m,) if (m:=ctx.ops_metadata.get(x)) else ())))
|
213
|
-
|
214
|
-
def append_to_kernel(ctx:KernelContext, x:UOp):
|
215
|
-
new_srcs: list[UOp] = []
|
216
|
-
new_metadata: dict[Metadata, None] = dict.fromkeys(x.arg.metadata)
|
217
|
-
for s in x.src:
|
218
|
-
if s.op is Ops.BUFFER or (s.op is Ops.ASSIGN and s.src[1].op is Ops.KERNEL) or s in ctx.realizes: new_srcs.append(s)
|
219
|
-
else:
|
220
|
-
new_srcs.extend(s.src)
|
221
|
-
if (m:=ctx.ops_metadata.get(s)) is not None: new_metadata[m] = None
|
222
|
-
return x.replace(src=n, arg=Kernel(x.arg.ast, tuple(new_metadata))) if (n:=tuple(dedup(new_srcs))) != x.src else None
|
223
|
-
|
224
|
-
create_kernels = merge_views+PatternMatcher([
|
225
|
-
(UPat(GroupOp.All-{Ops.KERNEL, Ops.BUFFER}, name="x"), create_kernel),
|
226
|
-
(UPat(Ops.KERNEL, name="x"), append_to_kernel),
|
227
|
-
])
|
228
|
-
|
229
|
-
# **** convert Kernel to a ScheduleItem (for legacy reasons)
|
8
|
+
# **** ScheduleItem return type
|
230
9
|
|
231
10
|
@dataclass(frozen=True)
|
232
11
|
class ScheduleItem:
|
233
12
|
ast: UOp
|
234
13
|
bufs: tuple[Buffer, ...]
|
235
|
-
metadata: tuple[Metadata, ...]
|
236
|
-
|
237
|
-
def outputs(self) -> tuple[Buffer, ...]:
|
238
|
-
"""Read/write or write only buffers in the schedule."""
|
239
|
-
return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs)
|
240
|
-
@property
|
241
|
-
def inputs(self) -> tuple[Buffer, ...]:
|
242
|
-
"""Read only buffers in the schedule."""
|
243
|
-
return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs)
|
244
|
-
@functools.cached_property
|
245
|
-
def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
|
246
|
-
|
247
|
-
# **** Kernel creation
|
14
|
+
metadata: tuple[Metadata, ...] = ()
|
15
|
+
fixedvars: dict[Variable, int] = field(default_factory=dict)
|
248
16
|
|
249
|
-
|
250
|
-
with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
|
17
|
+
# **** schedule linearizer
|
251
18
|
|
252
|
-
def
|
253
|
-
|
254
|
-
|
255
|
-
prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
|
256
|
-
strides = strides_for_shape(rshape)
|
257
|
-
nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
|
258
|
-
v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
|
259
|
-
# update input_st and axis
|
260
|
-
new_input_st = tmp + ShapeTracker(tuple(nv))
|
261
|
-
new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg)))
|
262
|
-
return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
|
263
|
-
|
264
|
-
def reduceop_view_right(r:UOp, v:UOp, src:UOp) -> UOp:
|
265
|
-
if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}")
|
266
|
-
output_shape = swizzle_st.reduce(r.axis_arg)
|
267
|
-
return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape))
|
268
|
-
|
269
|
-
def elementwise_view_right(root:UOp) -> UOp|None:
|
270
|
-
if len(swizzles:=[x for x in root.src if x.base is not x]) == 0: return None
|
271
|
-
assert all(x.base.st is not None for x in swizzles), f"found shapeless VIEW src in {root}"
|
272
|
-
assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
|
273
|
-
# push the swizzle from src to root
|
274
|
-
output_swizzle = swizzles[0]
|
275
|
-
new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)
|
276
|
-
ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
|
277
|
-
return ret.view(ShapeTracker.from_shape(output_swizzle.shape))
|
278
|
-
|
279
|
-
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
280
|
-
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
|
281
|
-
assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
|
282
|
-
return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
|
283
|
-
|
284
|
-
# push VIEW to children
|
285
|
-
view_right = merge_views+PatternMatcher([
|
286
|
-
# STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
|
287
|
-
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
|
288
|
-
lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
|
289
|
-
# STORE is the last child, so we just merge the ShapeTrackers and store the base
|
290
|
-
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
|
291
|
-
# REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
|
292
|
-
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
|
293
|
-
# REDUCE(src.view()) -> REDUCE(src).view()
|
294
|
-
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src").view(name="v"),), name="r"), reduceop_view_right),
|
295
|
-
# ALU(src.view()) -> ALU(src).view()
|
296
|
-
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), elementwise_view_right),
|
297
|
-
# double reduce op collapses to a single reduce op
|
298
|
-
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
|
299
|
-
])
|
300
|
-
|
301
|
-
def _append_st_vars(ctx:dict[Variable, int], x:UOp) -> UOp|None:
|
302
|
-
st = unwrap(x.st).simplify()
|
303
|
-
if any(x.op is Ops.BIND for x in st.vars()):
|
304
|
-
st, var_vals = st.unbind()
|
305
|
-
ctx.update(var_vals)
|
306
|
-
return st.to_uop() if st != x.st else None
|
307
|
-
|
308
|
-
def check_load_st(glbl:UOp, view:UOp):
|
309
|
-
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
|
310
|
-
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
|
311
|
-
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
|
312
|
-
# if it has a single view and it's equal when you shrink a contig, it's fine
|
313
|
-
if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
|
314
|
-
# otherwise, it's not fine
|
315
|
-
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
316
|
-
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
317
|
-
|
318
|
-
fix_kernel_ops = PatternMatcher([
|
319
|
-
# BIND in shapetracker becomes DEFINE_VAR
|
320
|
-
(UPat(Ops.VIEW, name="x"), _append_st_vars),
|
321
|
-
# remove CONTIGUOUS/ASSIGN/DEVICE
|
322
|
-
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
|
323
|
-
(UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
|
324
|
-
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
|
325
|
-
# no ImageDType after load
|
326
|
-
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
|
327
|
-
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
|
328
|
-
(UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
|
329
|
-
])
|
330
|
-
|
331
|
-
def load_buf(ctx:list[UOp], x:UOp):
|
332
|
-
if x not in ctx: ctx.append(x)
|
333
|
-
return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), unwrap(x.st).to_uop()))
|
334
|
-
|
335
|
-
add_buffer_ops = PatternMatcher([
|
336
|
-
# LOAD
|
337
|
-
(UPat(Ops.BUFFER, name="x"), load_buf),
|
338
|
-
# STORE (except for COPY/BUFFER_VIEW)
|
339
|
-
(UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
|
340
|
-
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
|
341
|
-
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
|
342
|
-
])
|
343
|
-
|
344
|
-
def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
|
345
|
-
ctx[var.replace(src=())] = val.arg
|
346
|
-
return var
|
347
|
-
unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
|
348
|
-
|
349
|
-
def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
|
350
|
-
assert sink.op is Ops.ASSIGN and sink.src[1].op is Ops.KERNEL, f"{sink} must be ASSIGN"
|
351
|
-
# substitute kernel sources for the target buffer
|
352
|
-
ast = sink.src[1].arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in sink.src[1].src if s.op is Ops.ASSIGN}).sink()
|
353
|
-
# add buffer ops
|
354
|
-
ast = graph_rewrite(ast, add_buffer_ops, bufs:=[sink.buf_uop], bottom_up=True)
|
355
|
-
# unbind_vars + push views to edges
|
356
|
-
ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
|
357
|
-
# fix_kernel_ops
|
358
|
-
ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
|
359
|
-
# create subbuffer
|
360
|
-
if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = bufs[1].buffer.view(ast.size, ast.dtype, (x:=ast.src[0]).st_arg.views[0].offset*x.dtype.itemsize)
|
361
|
-
return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), sink.src[1].arg.metadata)
|
362
|
-
|
363
|
-
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
|
364
|
-
if CAPTURE_PROCESS_REPLAY:
|
365
|
-
@atexit.register
|
366
|
-
def save_process_replay():
|
367
|
-
for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
|
368
|
-
|
369
|
-
# **** schedule creation and toposort
|
370
|
-
|
371
|
-
@track_rewrites(named=True)
|
372
|
-
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
|
373
|
-
# remove_movement_ops + sym
|
374
|
-
tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
|
375
|
-
|
376
|
-
# display the cleaned up tensor graph
|
377
|
-
if getenv("VIZ"): graph_rewrite(tensor_map[big_sink], PatternMatcher([]), name="View Tensor Graph")
|
378
|
-
|
379
|
-
# do_realize + group_realizes
|
380
|
-
sink = tensor_map[big_sink]
|
381
|
-
realize_map = group_realizes(sink)
|
382
|
-
|
383
|
-
# map tensors to new uops
|
384
|
-
becomes_map: dict[UOp, UOp] = {}
|
385
|
-
rev_tensor_map: dict[UOp, list[UOp]] = {}
|
386
|
-
ops_metadata: dict[UOp, Metadata] = {}
|
387
|
-
for k,v in tensor_map.items():
|
388
|
-
rev_tensor_map.setdefault(v, []).append(k)
|
389
|
-
if k is v: continue
|
390
|
-
if v.base.op is Ops.BUFFER:
|
391
|
-
# VIEW isn't a valid tensor uop, we need to backtrack to the movement op that created it
|
392
|
-
if v.op is Ops.VIEW:
|
393
|
-
mop = [x for x in k.toposort if (xs:=tensor_map[x]).base is v.base and xs.st == v.st][0]
|
394
|
-
if k is not mop: becomes_map[k] = mop
|
395
|
-
else: becomes_map[k] = v
|
396
|
-
elif v.base.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
|
397
|
-
# if we're not realizing this tensor, map its metadata to the simplified uop
|
398
|
-
elif isinstance(k.metadata, Metadata): ops_metadata[v] = k.metadata
|
399
|
-
|
400
|
-
# create kernels
|
401
|
-
if len(realize_map) == 0: return [], {}, becomes_map
|
402
|
-
kernel_map = graph_rewrite_map(sink, create_kernels, ctx=KernelContext(realize_map, ops_metadata), bottom_up=True)
|
403
|
-
sched_sink = kernel_map[sink]
|
404
|
-
type_verify(list(sched_sink.toposort), kernel_spec)
|
405
|
-
|
406
|
-
# map realized tensors to buffers
|
407
|
-
for k,v in kernel_map.items():
|
408
|
-
if k is v or v.op is not Ops.ASSIGN: continue
|
409
|
-
for t in rev_tensor_map[k]: becomes_map[t] = t.src[0] if t.op is Ops.ASSIGN else v.buf_uop.reshape(t.shape)
|
410
|
-
|
411
|
-
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
|
412
|
-
kernel_assign: dict[UOp, UOp] = {}
|
413
|
-
assign_rep: dict[UOp, UOp] = {}
|
414
|
-
for u in sched_sink.toposort:
|
415
|
-
if u.op is not Ops.ASSIGN: continue
|
416
|
-
kernel_assign[u.buf_uop] = u
|
417
|
-
for s in u.src[1].src:
|
418
|
-
if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
|
419
|
-
if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort):
|
420
|
-
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
|
421
|
-
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
|
422
|
-
if assign_rep:
|
423
|
-
sched_sink = sched_sink.substitute(assign_rep)
|
424
|
-
type_verify(list(sched_sink.toposort), kernel_spec)
|
425
|
-
|
426
|
-
# display the final graph
|
427
|
-
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
|
428
|
-
|
429
|
-
# final toposort (bfs)
|
430
|
-
children: dict[UOp, list[UOp]] = {}
|
19
|
+
def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int]]:
|
20
|
+
# construct the KERNEL children graph based on assigns
|
21
|
+
children: defaultdict[UOp, list[UOp]] = defaultdict(list)
|
431
22
|
in_degree: dict[UOp, int] = {}
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
23
|
+
var_vals: dict[Variable, int] = {}
|
24
|
+
for u in sched_sink.toposort():
|
25
|
+
if u.op is not Ops.ASSIGN: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip
|
26
|
+
k = u.src[1]
|
27
|
+
in_degree.setdefault(k, 0)
|
28
|
+
for s in k.src:
|
29
|
+
if s.op is Ops.ASSIGN:
|
30
|
+
children[s.src[1]].append(k)
|
31
|
+
in_degree[k] += 1
|
32
|
+
elif s.op in {Ops.MSELECT, Ops.MSTACK}:
|
33
|
+
for ss in s.src:
|
34
|
+
if ss.op is Ops.MSELECT: ss = ss.src[0]
|
35
|
+
if ss.op is not Ops.BUFFER:
|
36
|
+
assert ss.op is Ops.ASSIGN
|
37
|
+
children[ss.src[1]].append(k)
|
38
|
+
in_degree[k] += 1
|
39
|
+
elif s.op is Ops.BUFFER:
|
40
|
+
pass # a BUFFER is already realized, nothing to do here
|
41
|
+
elif s.op is Ops.BIND:
|
42
|
+
var, val = s.unbind()
|
43
|
+
assert var not in var_vals or var_vals[var] == val, f"bind mismatch on {var}, {var_vals[var]} != {val}"
|
44
|
+
var_vals[var] = val
|
45
|
+
else:
|
46
|
+
raise RuntimeError(f"input to kernel must be ASSIGN or BUFFER, not {s.op}")
|
47
|
+
|
48
|
+
# linearize KERNEL UOps into ScheduleItems in BFS order
|
49
|
+
|
50
|
+
def _heuristic(k: UOp):
|
51
|
+
if k.arg.ast.op is Ops.COPY and not all_same([Device[cast(Buffer, s.buf_uop.buffer).device].group_id for s in k.src]): return 1000
|
52
|
+
return 0
|
53
|
+
|
54
|
+
last_heuristic: int = 0
|
55
|
+
queues: defaultdict[int, deque[UOp]] = defaultdict(deque)
|
56
|
+
last_queue: deque[UOp] = deque()
|
57
|
+
for k,v in in_degree.items():
|
58
|
+
if v == 0: queues[_heuristic(k)].append(k)
|
439
59
|
|
440
|
-
queue = deque(k for k,v in in_degree.items() if v == 0)
|
441
60
|
schedule: list[ScheduleItem] = []
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
#
|
447
|
-
|
448
|
-
|
61
|
+
while last_queue or any(queues.values()):
|
62
|
+
if not last_queue: last_heuristic, last_queue = min((it for it in queues.items() if it[1]), key=lambda x: abs(x[0]-last_heuristic))
|
63
|
+
k = last_queue.popleft()
|
64
|
+
ast = k.arg.ast
|
65
|
+
# create subbuffers if needed
|
66
|
+
if ast.op is Ops.BUFFER_VIEW:
|
67
|
+
base = k.src[1].buf_uop.buffer
|
68
|
+
assert isinstance(base, Buffer), "base can't be MultiBuffer"
|
69
|
+
buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
|
70
|
+
ubufs = tuple(s.buf_uop.buffer for s in k.src if s.op is not Ops.BIND)
|
71
|
+
if any(isinstance(x, MultiBuffer) for x in ubufs):
|
72
|
+
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
|
73
|
+
dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
|
74
|
+
for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
|
75
|
+
schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0]:i} if len(dnums) else {}))
|
76
|
+
else:
|
77
|
+
# ONE -> ONE
|
78
|
+
schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata))
|
79
|
+
for x in children[k]:
|
449
80
|
in_degree[x] -= 1
|
450
|
-
if in_degree[x] == 0:
|
81
|
+
if in_degree[x] == 0: queues[_heuristic(x)].append(x)
|
451
82
|
|
452
|
-
|
453
|
-
if len(schedule) != (kc:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, created {kc} kernels but only scheduled {len(schedule)}")
|
454
|
-
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
|
455
|
-
# capture process replay
|
456
|
-
if CAPTURE_PROCESS_REPLAY:
|
457
|
-
with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, [x.ast for x in schedule]))
|
458
|
-
return schedule, var_vals, becomes_map
|
83
|
+
return schedule, var_vals
|
File without changes
|