tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/engine/schedule.py
CHANGED
@@ -1,370 +1,419 @@
-import sys,
+import sys, atexit, functools, itertools
 from collections import defaultdict, deque
-from dataclasses import dataclass
-from typing import Tuple, List, Dict, Optional,
-from tinygrad.ops import
-from tinygrad.
-from tinygrad.helpers import
-from tinygrad.
-from tinygrad.dtype import ConstType, ImageDType, dtypes, DType
-from tinygrad.lazy import LazyBuffer
+from dataclasses import dataclass, field
+from typing import Set, Tuple, List, Dict, Optional, DefaultDict, cast
+from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, sint
+from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
+from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG
+from tinygrad.dtype import ImageDType, dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.shape.view import View, strides_for_shape
+from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.device import Buffer
 
 # creation can recurse a lot
 sys.setrecursionlimit(10000)
 
-
-logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
+BUF_LIMIT = {"METAL":32}
 
-#
+# **** ScheduleItem return type
 
 @dataclass(frozen=True)
 class ScheduleItem:
-  ast:
+  ast: UOp
   bufs: Tuple[Buffer, ...]
+  metadata: Tuple[Metadata, ...]
+  assign_preloads: Tuple[UOp, ...]
   @property
   def outputs(self) -> Tuple[Buffer, ...]:
     """Read/write or write only buffers in the schedule."""
-    return self.bufs
+    return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs)
   @property
   def inputs(self) -> Tuple[Buffer, ...]:
     """Read only buffers in the schedule."""
-    return self.bufs
+    return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs)
+  @functools.cached_property
+  def output_idxs(self) -> Tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
 
-#
+# **** small wrapper for LazyBuffer -> UOp
 
-
-@
-
-  ast: Tuple[LazyOp, ...]
-  outputs: Tuple[LazyBuffer, ...]
-  inputs: Tuple[LazyBuffer, ...]
-  var_vals: Dict[Variable, int]
-
-def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
-                      realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer], cache) -> LazyOp:
-  """recursively create a lazyop"""
-  if (buf, st) in cache: return cache[(buf, st)]
-  if buf != buf.base:
-    st = buf.st + st
-    buf = buf.base
-  # all buffers here are base now
-  assert buf.op is not None
+def UPatSrc(*args, **kwargs): return UPat(Ops.VIEW, src=(UPat.var("b"), UPat(*args, **{**kwargs, "name":"to_store"})), name="base")
+@functools.lru_cache(None)
+def is_scheduled(u:UOp): return u.op is Ops.VIEW and len(u.src) == 2
 
+@dataclass(frozen=True)
+class ScheduleContext:
+  ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict)  # this maps BUFFER uops to Metadata
+  var_vals: Dict[Variable, int] = field(default_factory=dict)       # this maps a BIND's DEFINE_VAR to its value
+  assigns: Set[UOp] = field(default_factory=set)                    # this holds all the BUFFER uops we ASSIGN to in this schedule
+  allbufs: Dict[UOp, UOp] = field(default_factory=dict)             # this maps BUFFER uops the actual op
+  children: DefaultDict[UOp, Dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
+
+def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], lazybufs:Dict[Buffer, LazyBuffer], cache:Dict[LazyBuffer, UOp]) -> UOp:
+  if (r:=cache.get(buf)) is not None: return r
+  if buf is not buf.base:
+    cache[buf] = ret = to_uop(buf.base, ctx, buffers, lazybufs, cache).view(buf.st)
+    return ret
+  # make things that can't be images not images
+  if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
+                                            not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
+    if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to {buf.dtype.base}")
+    # hack the underlying buffer too
+    buf.dtype = buf.buffer.dtype = buf.dtype.base
+    assert not buf.is_realized, "can't fixup allocated buffer"
+    buf.buffer.options = None
+  dtype = buf.dtype if buf.op in GroupOp.Meta else buf.dtype.base
   # consts are always fused and generated
-  if buf.op is
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
-      return LazyOp(BufferOps.LOAD, (), MemBuffer(outputs.index(assign_targets[buf]), buf.dtype, unbound_st))
-    if buf not in inputs: inputs.append(buf)
-    return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outputs)+inputs.index(buf), buf.dtype, unbound_st))
-
-  # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
-  if buf.op is LoadOps.CONTIGUOUS:
-    assert buf in outputs
-    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)
-  if buf.op is LoadOps.ASSIGN:
-    assert buf in outputs
-    assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
-    assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
-    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)
-
-  # if it's a reduce, we have to change the shapetracker
-  if buf.op in ReduceOps:
-    assert st.contiguous, "ReduceOps late fusion must be contiguous"
-    st = ShapeTracker.from_shape(buf.srcs[0].shape)
-
-  # otherwise we fuse it like normal
-  cache[(buf, st)] = ret = \
-    LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outputs, var_vals, st, realizes, assign_targets, cache) for x in buf.srcs), buf.arg)
+  if buf.op is Ops.CONST:
+    if isinstance(val:=buf.arg, UOp): ctx.var_vals.update([val.unbind()])
+    return UOp(Ops.VALID, dtypes.bool, (buf.st.to_uop(),)).where(UOp.const(dtype, val), 0)
+  # everything else is a VIEW of BUFFER (with an optional op)
+  if buf.is_realized:
+    buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer
+    op = None
+  elif buf.op is Ops.ASSIGN:
+    target, new_val = [to_uop(x, ctx, buffers, lazybufs, cache) for x in buf.srcs]
+    ctx.assigns.add(ubuf:=target.buf_uop)
+    op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg)
+  else:
+    buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer
+    op = UOp(cast(Ops, buf.op), dtype, tuple(to_uop(x, ctx, buffers, lazybufs, cache) for x in buf.srcs),
+             None if buf.op in {Ops.CAST, Ops.BITCAST} else buf.arg)
+  cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
+  if op is not None:
+    if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
+    lazybufs[buf.buffer] = buf
+    for x in op.src:
+      if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[ubuf] = None
+    ctx.allbufs[ubuf] = ret
   return ret
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-  for
-
-
-
-
-
-
-
-
-
-
-
-
-
+# **** AST graph rewrite
+
+# ** helpers for doing movementops on uops
+
+def apply_swizzle(u:UOp, arg:ShapeTracker) -> UOp:
+  with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u.view(arg), view_left)
+
+def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
+  permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis)+axis
+  tmp = input_st.permute(permute_axis)
+  return tmp, tmp.shape[-len(axis):]
+
+# ** movementops rewrite rules
+
+def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
+  tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(src.st).shape), r.axis_arg)
+  prshape = prod(rshape)
+  strides = strides_for_shape(rshape)
+  nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
+                    v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
+  # update input_st and axis
+  new_input_st = tmp + ShapeTracker(tuple(nv))
+  _, new_rshape = permute_reduce(new_input_st, r.axis_arg)
+  new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
+  return apply_swizzle(src, new_input_st).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
+
+def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp, src:UOp) -> UOp:
+  swizzle_st, src_st = unwrap(swizzle.st), unwrap(src.st)
+  assert swizzle_st.contiguous, "can't push a non contiguous VIEW down to STORE"
+  assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE"
+  output_shape = swizzle_st.reduce(root.axis_arg)
+  new_axis = tuple(i for i,(s,u) in enumerate(zip(src_st.shape, output_shape)) if s != u)
+  return swizzle.src[0].r(root.arg[0], new_axis).view(ShapeTracker.from_shape(output_shape))
+
+def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
+  swizzles = [x for x in root.src if x.base is not x]
+  if len(swizzles) == 0: return None
+  swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
+  assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
+  new_shape, new_input_shape = swizzle_shapes[0]
+  ret = root.replace(src=tuple(x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src))
+  return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
+
+def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
+  assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
+  assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
+  return first_reduce.src[0].r(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)
+
+merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
+
+# push VIEW to loads
+view_left = merge_views+PatternMatcher([
+  # VIEW before elementwise ops
+  (UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
+  # early merge VIEW buffer ops
+  (UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.arg+v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
+])
+
+# push VIEW to stores
+view_right = merge_views+PatternMatcher([
+  # ASSIGN can override st
+  (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.ASSIGN, name="a"))),
+   lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None),
+  # non contiguous VIEW on a reduce creates a new VIEW
+  (UPat(Ops.REDUCE_AXIS, src=UPat.var("src"), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
+  # push a VIEW down to STORE, through a reduce (ONLY reshapes)
+  (UPat(Ops.REDUCE_AXIS, src=(UPat.var(name="src").view(name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
+  # push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes)
+  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise),
+  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
+])
+
+# ** ScheduleItem context builder
+
+@dataclass(frozen=True)
+class ScheduleItemContext:
+  var_vals: Dict[Variable, int]
+  assigned: Set[UOp]
+  sts: Set[ShapeTracker] = field(default_factory=set)
+  bufs: List[UOp] = field(default_factory=list)
+  assign_preloads: List[UOp] = field(default_factory=list)
+
+def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
+  if (st:=unwrap(x.st)) in ctx.sts: return None
+  st, var_vals = st.simplify().unbind()
+  ctx.var_vals.update(var_vals)
+  ctx.sts.add(st)
+  return st.to_uop() if st != x.st else None
+
+def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
+  ctx.bufs.append(x)
+  return UOp(Ops.DEFINE_GLOBAL, x.dtype, (), len(ctx.bufs)-1)
+append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
+
+def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
+  if b in ctx.assigned: ctx.assign_preloads.append(b)
+  return x.replace(op=Ops.LOAD)
+
+to_si = PatternMatcher([
+  (UPat(Ops.VIEW, name="x"), _append_st_vars),
+  (UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),
+  (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,b,x: x.replace(src=(b, *x.src))),
+])
+
+# ** fusion
+
+lazy = PatternMatcher([
+  (UPatSrc(), lambda ctx,to_store,**kwargs: to_store),
+  (UPat(Ops.BUFFER, name="b").view(name="view"), lambda ctx,b,view: UOp(Ops.PRELOAD, view.dtype, (b, view.st.to_uop()))),
+  (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
+])
+
+multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.get(b)),])
+
+def full_ast_rewrite(pre:UOp, var_vals:Dict[Variable, int], assigned:Set[UOp]) -> Tuple[UOp, ScheduleItemContext]:
+  # fuse and fold store -> loads
+  sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, {x.buf_uop:x.src[2] for x in pre.src})
+  # assert cyclic dependency
+  for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {Ops.PRELOAD,Ops.LOAD} and x.buf_uop in assigned), key=lambda x:x.buf_uop):
+    if not all_same([x.op for x in ops]):
+      raise RuntimeError(f"cycle detected in kernel.\nhelp: use .contiguous() to break the part loading pre-assign {b} into a different kernel.")
+  # do movementops
+  sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
+  # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
+  if len(assign_targets:=[x.buf_uop for x in sink.sparents if x.op is Ops.ASSIGN]) != 0:
+    if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
+        and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is Ops.PRELOAD and x.buf_uop in assign_targets):
+      raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+                         +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+  # convert to AST
+  sink = graph_rewrite(graph_rewrite(sink, to_si, ctx:=ScheduleItemContext(var_vals, assigned)), append_bufs, ctx)
+  if getenv("RUN_PROCESS_REPLAY"): PROCESS_REPLAY_CAPTURE.append(((pre, var_vals, assigned), sink))
+  return sink, ctx
+
+PROCESS_REPLAY_CAPTURE: List[Tuple[Tuple, UOp]] = []
+if getenv("RUN_PROCESS_REPLAY"):
+  @atexit.register
+  def save_process_replay():
+    for x,ret in PROCESS_REPLAY_CAPTURE: diskcache_put("schedule_process_replay", str(x[0].key), (x, {}, ret))
+
+# **** Schedule grouping
+
+def uval(u:UOp) -> UOp:
+  assert is_scheduled(u), f"must be a scheduled op {u}"
+  return to_store.src[0] if (to_store:=u.src[1]).is_contiguous_base else to_store
+
+def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp], realizes:Dict[UOp, UOp],
+                    reduce_for_op:Dict[UOp, UOp], group:Dict[UOp, None], cache:Dict[Tuple[UOp, ShapeTracker], None]) -> None:
+  """recursively search the uop for groupable children, realize the UOp if a child can't group"""
+  if (tr, st) in cache: return
+  cache.setdefault((tr, st))
+  rsize = unwrap(allbufs[r].st).size
+  if tr in realizes and tr is not r:
     # can only fuse contiguous
     # max one reduceop per kernel
-    if not st.contiguous or st.size !=
-    return group.
+    if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
+    return group.setdefault(tr)
   for tr_next in children[tr]:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # max one reduceop per kernel
+    if (tr_next_uop:=uval(allbufs[tr_next]).base).op is Ops.REDUCE_AXIS: return group.setdefault(r)
+    # can only fuse contiguous
+    if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
+    recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)
+
+def get_isolated_children(r:UOp, reduce_for_op:Dict[UOp, UOp], children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp],
+                          realizes:Dict[UOp, UOp], group:Dict[UOp, None]) -> Dict[UOp, None]:
+  rc_parents, cache = deque(group), set()
+  while rc_parents:
+    if (p:=uval(allbufs[rc_parents.pop()])) in cache: continue
+    cache.add(p)
+    # max one reduceop per kernel
+    if p.op is Ops.REDUCE_AXIS: return {}
+    rc_parents.extend(x.base.buf_uop for x in p.src if is_scheduled(x.base) and x.base.buf_uop is not r)
+  # search descendants of the reduceop that can cleanly group
+  descendants: Dict[UOp, None] = {}
+  for tr in group: recursive_group(tr, unwrap(allbufs[tr].st), tr, children, allbufs, realizes, reduce_for_op, descendants, cache={})
+  return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
+
+def group_realizes(ctx:ScheduleContext, realizes:Dict[UOp, UOp]) -> List[List[UOp]]:
+  """search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop"""
   # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
-  reduce_for_op: Dict[
-
-
-
-
-
+  reduce_for_op: Dict[UOp, UOp] = {}
+  reduce_of_const: List[UOp] = []
+  double_reduces: List[UOp] = []
+  for r, r_uop in ctx.allbufs.items():
+    if (r_uop:=uval(r_uop)).op is not Ops.REDUCE_AXIS: continue
+    if FUSE_CONV_BW and r_uop.op is Ops.REDUCE_AXIS and uval((x:=r_uop.src[0]).base).op is r_uop.op and x.base is not x: double_reduces.append(r)
+    if r in realizes: continue
+    group: Dict[UOp, None] = {}
+    recursive_group(r, unwrap(r_uop.st), r, ctx.children, ctx.allbufs, realizes, reduce_for_op, group, cache={})
    # max one reduceop per kernel
    can_chase = all(tr not in reduce_for_op for tr in group)
    # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
    forced_realize = r in group
    if not forced_realize and len(group) > 1:
-
-      rc_parents, rc_children = deque(group), deque(group)
-      while rc_parents and not forced_realize:
-        # max one reduceop per kernel
-        if (p:=rc_parents.pop()).op in ReduceOps: forced_realize = True
-        else: rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
-      # search descendants of the reduceop that can cleanly group
-      realized_descendants: Set[LazyBuffer] = set()
-      while rc_children and not forced_realize:
-        if (c:=rc_children.pop()).op in ReduceOps or not c.st.contiguous or c.st.size != r.st.size or c in reduce_for_op:
-          realized_descendants.clear()
-          break
-        if c in realizes and c not in group: realized_descendants.add(c)
-        rc_children.extend(x for x in children[c] if x.realized is None and x.device == r.device)
-      group.update(realized_descendants)
+      group = get_isolated_children(r, reduce_for_op, ctx.children, ctx.allbufs, realizes, group)
    # can only fuse assign if no other assign_target is used in the kernel
-    if not forced_realize and any(x
+    if not forced_realize and any(x in ctx.assigns for x in group):
      parents = deque((r, *group))
      while parents and not forced_realize:
-        if (p:=parents.pop()
-
-
-        parents.extend(
-    if forced_realize:
+        if (p_uop:=ctx.allbufs.get(p:=parents.pop())) is None: continue
+        if (p_uop:=uval(p_uop)).op is Ops.ASSIGN and p not in group: forced_realize, can_chase = True, False
+        if p in realizes: continue
+        parents.extend([x.base.src[0] for x in p_uop.src if x.base.op is Ops.VIEW and len(x.base.src) != 0])
+    if forced_realize or not group:
      tr = r
      if can_chase:
        # can chase this down to contiguous children
-        st =
-        while len(children[tr]) == 1:
-
-          st_childs = dedup(
+        st = unwrap(r_uop.st)
+        while len(ctx.children[tr]) == 1:
+          tr_next_uop = uval(ctx.allbufs[(tr_next:=next(iter(ctx.children[tr])))])
+          st_childs = dedup([unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop is tr])
          if len(st_childs) > 1: break
-          if st.size != st_childs[0].
-          st = st + st_childs[0]
-          if not st.contiguous or
+          if st.size != st_childs[0].size: break
+          st = st + st_childs[0]
+          if not st.contiguous or tr_next_uop.op is Ops.REDUCE_AXIS: break
          tr = tr_next
        # don't cast to higher size before store (tr cannot be realized if forced_realize)
-      if tr.op is
-        tr =
-
-      realizes[tr] =
-
-
-
-  for
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if (tr_uop:=uval(ctx.allbufs[tr])).op is Ops.CAST and tr_uop.dtype.base.itemsize > tr_uop.src[0].dtype.base.itemsize:
+          tr = tr_uop.src[0].base.buf_uop
+      group = {tr: None}
+      realizes[tr] = tr
+    reduce_for_op.update((tr, r) for tr in group)
+    if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.WHERE: reduce_of_const.append(r)
+  # fuse double reduces with no other child
+  for reduceop in double_reduces:
+    top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
+    if len(ctx.children[top_reduce]) == 1: del realizes[top_reduce]
+  # maybe fuse arange with its children
+  for rbuf in reduce_of_const:
+    group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
+    if any(ctx.allbufs[tr].src[1].is_contiguous_base for tr in group): continue
+    kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}}
+    if len(kernel_children) == 0: continue
+    for tr in group: del realizes[tr]
+  # group BUFFER uops into kernels
+  output_groups: DefaultDict[UOp, List[UOp]] = defaultdict(list)
+  for ubuf in realizes: output_groups[reduce_for_op.get(ubuf, ubuf)].append(ubuf)
+  return list(output_groups.values())
+
+# **** Schedule creation and BFS toposort
+
+def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> UOp:
+  ctx[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), to_store)
+  return UOp(Ops.LOAD, base.dtype, (b, st.to_uop()))
+
+def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> Optional[UOp]:
+  base_shape = unwrap(base.st).shape
+  st = unwrap(view.st)
+  # fold simple pads
+  if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(base_shape) and resolve(prod(base_shape) >= prod([y-x for x,y in m])):
+    return None if can_pad(base) else realize(ctx, b, to_store, base).view(st)
+  # early realize before expand
+  if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, b, to_store, base).view(st)
+  # otherwise safety check pads
+  return None if (all(v.mask is None for v in st.views) or can_pad(base)) else realize(ctx, b, to_store, base).view(st)
+
+do_realize = PatternMatcher([
+  # always realize meta ops
+  (UPatSrc((Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta)), realize),
+  # don't realize image to image casts
+  (UPatSrc(Ops.CAST, src=(UPat(Ops.LOAD, name="x"),), dtype=dtypes.float).view(name="v"), lambda ctx,x,v,**kwargs: r.src[2].view(v.st)
+    if (r:=ctx.get(b:=x.buf_uop)) is not None and r.op is Ops.STORE and isinstance(b.dtype, ImageDType) and r.src[2].op not in GroupOp.Meta else None),
+  # realize before expand or unsafe pad ops
+  (UPatSrc().view(name="view"), realize_view),
+  # realize before COPY or BUFFER_VIEW
+  (UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatSrc(), UPatSrc().view(name="view")),), name="root"),
+   lambda ctx,root,view=None,**kwargs: root.replace(src=(realize(ctx,**kwargs) if view is None else realize(ctx,**kwargs).view(view.st),)),),
+])
+break_sched = PatternMatcher([(UPatSrc(), lambda ctx,b,to_store,base: realize(ctx, b, to_store, base) if b in ctx else None),])
+
+@track_rewrites(named=True)
+def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+  if len(outs:=dedup(x.base for x in outs if x.realized is None and x.base.op is not Ops.CONST)) == 0: return [], {}
+  for out in outs: out.forced_realize = True
+  # create the big graph
+  ctx = ScheduleContext()
+  cache: Dict[LazyBuffer, UOp] = {}
+  buffers: Dict[UOp, Buffer] = {}
+  lazybufs: Dict[Buffer, LazyBuffer] = {}
+  big_graph = UOp.sink(*(to_uop(x, ctx, buffers, lazybufs, cache) for x in outs))
+  # get realizes
+  realizes: Dict[UOp, UOp] = {}
+  graph_rewrite(big_graph, do_realize, realizes)
+  store_groups = group_realizes(ctx, realizes)
+  # split realizes into small graphs
+  graph_rewrite(big_graph, break_sched, realizes)
+  sinks = [UOp.sink(*(realizes[u] for u in stores)) for stores in store_groups]
+  # preschedule all realizes
+  prescheduled: List[ScheduleItem] = []
+  for sink in sinks:
+    metadata = tuple({mx for x in sink.sparents if (x.op is Ops.STORE or is_scheduled(x)) and (mx:=ctx.ubuf_metadata.get(x.buf_uop))})
+    ast, ast_ctx = full_ast_rewrite(sink, ctx.var_vals, ctx.assigns)
+    prescheduled.append(ScheduleItem(ast, tuple(b for u in ast_ctx.bufs if (b:=buffers[u]).size != 0), metadata, tuple(ast_ctx.assign_preloads)))
+  # do BFS
+  schedule_targets = {out:si for si in prescheduled for out in si.outputs}
+  graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)
+  in_degree: DefaultDict[ScheduleItem, int] = defaultdict(int)
+  for si in prescheduled:
    # realize outputs before a parent is assigned to
-    parents_assigns =
+    parents_assigns = dedup(xsi for x in si.assign_preloads if (xsi:=schedule_targets.get(buffers[x])) and xsi is not si)
    for assign in parents_assigns:
-      graph[
+      graph[si].append(assign)
      in_degree[assign] += 1
-
-
-
-
-
-
-def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
-  if seen is None: seen = set()
-  graph, in_degree, prescheduled = _graph_schedule(outs, seen)
-  queue = deque(si for key, si in prescheduled.items() if in_degree[key] == 0)
+    # realize outputs after all parents are realized
+    scheduled_parents = dedup(xsi for x in si.inputs if (xsi:=schedule_targets.get(x)) is not None and xsi not in parents_assigns)
+    for x in scheduled_parents:
+      graph[x].append(si)
+      in_degree[si] += 1
+  queue = deque(si for si in prescheduled if in_degree[si] == 0)
  schedule: List[ScheduleItem] = []
-  var_vals: Dict[Variable, int] = {}
-  kernel_number = GlobalCounters.kernel_count
  while queue:
-
-    for
-    if
-
-
-
-    for out in ps.outputs: del out.srcs # can only schedule once
-    schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0)))
-    if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
-    for x in graph[ps.outputs[0]]:
+    schedule.append(si:=queue.popleft())
+    for b in si.outputs: del lazybufs[b].srcs  # can only schedule once
+    if (m:=BUF_LIMIT.get(device:=si.outputs[0].device)) and len(si.bufs) >= m:
+      if DEBUG >= 3: print(si)
+      raise RuntimeError(f"Kernel for {si.metadata} exceeded the {m} buffer count limit for {device} with {len(si.bufs)} buffers.")
+    for x in graph[si]:
      in_degree[x] -= 1
-      if in_degree[x] == 0: queue.append(
-
-  if SAVE_SCHEDULE:
-    def _save():
-      print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl"))
-      with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
-    if len(SCHEDULES) == 0: atexit.register(_save)
-    SCHEDULES.extend((ps.ast for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)])
+      if in_degree[x] == 0: queue.append(x)
  # confirm everything was scheduled correctly
-  if
-    raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
+  if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
-  return schedule, var_vals
+  return schedule, ctx.var_vals
 
-def create_schedule(outs:List[LazyBuffer]
-  schedule, var_vals = create_schedule_with_vars(outs
+def create_schedule(outs:List[LazyBuffer]) -> List[ScheduleItem]:
+  schedule, var_vals = create_schedule_with_vars(outs)
  assert len(var_vals) == 0
  return schedule
-
-# *** memory planning ***
-
-def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], debug_prefix="") -> Dict[Buffer, Buffer]:
-  if getenv("NO_MEMORY_PLANNER"): return {}
-  last_appearance = {}
-  for i,u in enumerate(buffers):
-    for buf in u: last_appearance[buf] = i
-
-  # LRU algorithm
-  assigned: Dict[Buffer, Buffer] = {}
-  local_cache: DefaultDict[Tuple[str, int, DType], List[Buffer]] = defaultdict(list)
-
-  def handle_buffer(buf):
-    key = (buf.device, buf.size, buf.dtype)
-    if buf not in assigned:
-      if len(ll:=local_cache[key]): assigned[buf] = ll.pop()
-      else: assigned[buf] = Buffer(*key)
-    if i == last_appearance[buf]:
-      if assigned[buf] not in local_cache[key]: local_cache[key].append(assigned[buf])
-
-  for i,u in enumerate(buffers):
-    for buf in u:
-      # all unallocated unparented buffers are fair game to replace
-      if buf.is_allocated() or buf.lb_refcount > 0: continue
-      # handle view buffers
-      if buf._base is not None:
-        assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf._base, buf._base), offset=buf.offset)
-      else:
-        handle_buffer(buf)
-
-  if DEBUG >= 1 and len(ak:=dedup(assigned.keys())) != len(av:=dedup(assigned.values())):
-    print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",
-          f"{len(ak)} -> {len(av)} bufs")
-  return assigned
-
-def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
-  assigned = _internal_memory_planner([si.bufs for si in schedule])
-  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs)) for si in schedule]