tinygrad-0.9.1-py3-none-any.whl → tinygrad-0.9.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +248 -115
- tinygrad/codegen/lowerer.py +215 -0
- tinygrad/codegen/transcendental.py +310 -0
- tinygrad/codegen/uopgraph.py +622 -0
- tinygrad/codegen/uops.py +235 -393
- tinygrad/device.py +428 -69
- tinygrad/dtype.py +18 -4
- tinygrad/engine/graph.py +19 -32
- tinygrad/engine/jit.py +148 -70
- tinygrad/engine/realize.py +127 -51
- tinygrad/engine/schedule.py +259 -216
- tinygrad/engine/search.py +29 -22
- tinygrad/function.py +9 -0
- tinygrad/helpers.py +87 -49
- tinygrad/lazy.py +34 -35
- tinygrad/multi.py +41 -36
- tinygrad/nn/__init__.py +39 -22
- tinygrad/nn/state.py +3 -3
- tinygrad/ops.py +63 -62
- tinygrad/renderer/__init__.py +43 -21
- tinygrad/renderer/assembly.py +104 -106
- tinygrad/renderer/cstyle.py +87 -60
- tinygrad/renderer/llvmir.py +21 -30
- tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/kfd.py +32 -0
- tinygrad/runtime/autogen/libc.py +4260 -0
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/graph/clang.py +2 -2
- tinygrad/runtime/graph/cuda.py +8 -11
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +18 -15
- tinygrad/runtime/ops_amd.py +197 -305
- tinygrad/runtime/ops_clang.py +2 -2
- tinygrad/runtime/ops_cuda.py +36 -94
- tinygrad/runtime/ops_disk.py +3 -7
- tinygrad/runtime/ops_gpu.py +4 -2
- tinygrad/runtime/ops_hip.py +70 -0
- tinygrad/runtime/ops_metal.py +38 -27
- tinygrad/runtime/ops_nv.py +283 -363
- tinygrad/runtime/ops_python.py +26 -30
- tinygrad/runtime/support/compiler_cuda.py +78 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/shape/shapetracker.py +5 -14
- tinygrad/shape/symbolic.py +4 -8
- tinygrad/shape/view.py +34 -22
- tinygrad/tensor.py +399 -97
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
- tinygrad-0.9.2.dist-info/RECORD +70 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/runtime/{driver → support}/__init__.py +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
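The per-file added/removed line counts above can be reproduced locally by comparing the contents of the two wheels. Below is a minimal sketch using only the Python standard library; the helper name wheel_diff_stats and the local file paths are illustrative assumptions, not part of any registry tooling, and it assumes both .whl files have already been downloaded (e.g. with `pip download tinygrad==0.9.1 --no-deps`).

# sketch: count added/removed lines per file between two wheels (assumed local paths)
import zipfile, difflib
from collections import defaultdict

def wheel_diff_stats(old_whl: str, new_whl: str) -> dict:
  counts: dict = defaultdict(lambda: [0, 0])  # path -> [lines added, lines removed]
  with zipfile.ZipFile(old_whl) as old, zipfile.ZipFile(new_whl) as new:
    old_names, new_names = set(old.namelist()), set(new.namelist())
    for name in sorted(old_names | new_names):
      a = old.read(name).decode(errors="replace").splitlines() if name in old_names else []
      b = new.read(name).decode(errors="replace").splitlines() if name in new_names else []
      for line in difflib.unified_diff(a, b, lineterm=""):
        if line.startswith("+") and not line.startswith("+++"): counts[name][0] += 1
        elif line.startswith("-") and not line.startswith("---"): counts[name][1] += 1
  return counts

# e.g. the schedule.py rewrite shown below:
# wheel_diff_stats("tinygrad-0.9.1-py3-none-any.whl", "tinygrad-0.9.2-py3-none-any.whl")["tinygrad/engine/schedule.py"]

Note that this naive comparison does not detect renames, so the moved files (e.g. driver/hip_comgr.py → support/compiler_hip.py) will show larger counts than the rename-aware figures above.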
tinygrad/engine/schedule.py
CHANGED
@@ -1,15 +1,17 @@
-import sys, pickle, atexit
+import sys, pickle, atexit, importlib, contextlib
 from collections import defaultdict, deque
-from dataclasses import dataclass
-from typing import Tuple, List, Dict, Optional, Set, DefaultDict,
-from tinygrad.ops import
+from dataclasses import dataclass, field
+from typing import Tuple, List, Dict, Optional, Set, DefaultDict, get_args
+from tinygrad.ops import MetaOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps, reduce_st
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
-from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE,
-
-from tinygrad.
+from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \
+  GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
+from tinygrad.shape.symbolic import Variable, sint
+from tinygrad.dtype import ConstType, ImageDType, dtypes
 from tinygrad.lazy import LazyBuffer
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Buffer
+from tinygrad.shape.view import View, strides_for_shape

 # creation can recurse a lot
 sys.setrecursionlimit(10000)
@@ -21,143 +23,198 @@ logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None

 @dataclass(frozen=True)
 class ScheduleItem:
-  ast:
+  ast: LazyOp
   bufs: Tuple[Buffer, ...]
+  metadata: Optional[List[Metadata]] = None
   @property
   def outputs(self) -> Tuple[Buffer, ...]:
     """Read/write or write only buffers in the schedule."""
-    return self.bufs[:len(self.ast)]
+    return self.bufs[:len(self.ast.src)] if self.ast.op is MetaOps.KERNEL else self.bufs[0:1]
   @property
   def inputs(self) -> Tuple[Buffer, ...]:
     """Read only buffers in the schedule."""
-    return self.bufs[len(self.ast):]
+    return self.bufs[len(self.ast.src):] if self.ast.op is MetaOps.KERNEL else self.bufs[1:]
+
+@dataclass(frozen=True)
+class LBScheduleItem:
+  ast: LazyOp
+  outputs: List[LazyBuffer]
+  inputs: List[LazyBuffer]
+  var_vals: Dict[Variable, int] = field(default_factory=dict)
+  metadata: List[Metadata] = field(default_factory=list)
+  def __hash__(self):
+    """The unique identifier of a schedule item in the toposort."""
+    return hash(self.outputs[0])

 # *** DAG transformation: List[LazyBuffer] -> ScheduleItem ***

-
-
-
-
-  outputs: Tuple[LazyBuffer, ...]
-  inputs: Tuple[LazyBuffer, ...]
-  var_vals: Dict[Variable, int]
-
-def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
-                      realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer], cache) -> LazyOp:
+def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:Dict[LazyBuffer, int],
+                      realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer],
+                      reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]],
+                      cache:Dict[Tuple[LazyBuffer, ShapeTracker], LazyOp]) -> LazyOp:
   """recursively create a lazyop"""
+  if buf is not buf.base: st, buf = buf.st+st, buf.base
   if (buf, st) in cache: return cache[(buf, st)]
-
-
-
-  # all buffers here are base now
-  assert buf.op is not None
-
-  # consts are always fused and generated
-  if buf.op is LoadOps.CONST:
-    unbound_st, st_var_vals = st.simplify().unbind()
-    var_vals.update(st_var_vals)
-    if isinstance(buf.arg, Variable):
-      val, var_val = buf.arg.unbind()
-      var_vals.__setitem__(val, var_val)
-    else:
-      assert isinstance(buf.arg, get_args(ConstType)), f"cannot create ConstBuffer with value {buf.arg}"
-      val = buf.arg
-    return LazyOp(BufferOps.CONST, (), ConstBuffer(val, buf.dtype, unbound_st))
-
-  # if we aren't fusing it, it's a load and we add it to the inputs
+  assert buf.op is not None, "base must be a base itself"
+
+  # buffer ops define ShapeTracker
   if buf.realized is not None or (buf in realizes and buf not in outputs):
     unbound_st, st_var_vals = st.simplify().unbind()
     var_vals.update(st_var_vals)
+    # if it's a const, we generate it
+    if buf.op is MetaOps.CONST:
+      if isinstance(val:=buf.arg, Variable):
+        val, var_val = val.unbind()
+        var_vals[val] = var_val
+      else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}"
+      return LazyOp(BufferOps.CONST, (), ConstBuffer(val, buf.dtype, unbound_st))
+    # otherwise, it's a load and we add it to the inputs
    if buf in assign_targets:
-      #
-      if
-
-
-
-
-
-
-
-    return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outputs)+inputs.index(buf), buf.dtype, unbound_st))
-
-  # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
-  if buf.op is LoadOps.CONTIGUOUS:
-    assert buf in outputs
-    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)
-  if buf.op is LoadOps.ASSIGN:
-    assert buf in outputs
-    assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
-    assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
-    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)
-
-  # if it's a reduce, we have to change the shapetracker
+      # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
+      if unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and\
+          ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
+        return LazyOp(BufferOps.LOAD, (), MemBuffer(outputs.index(assign_targets[buf]), buf.dtype, unbound_st))
+      raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+                         +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+    return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outputs)+inputs.setdefault(buf, len(inputs)), buf.dtype, unbound_st))
+
+  # reduce ops change ShapeTracker
   if buf.op in ReduceOps:
-
-
-
-
-
-
-
-
-
-
-
+    rinfo = reduce_info.get((buf, st))
+    rsrc = _recursive_lazyop(buf.srcs[0], st:=(rinfo[0] if rinfo else st), outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
+    # if we are merging the reduce, skip it
+    if rinfo is None:
+      assert rsrc.op is buf.op, f"can't merge reduceop {buf.op} with {rsrc.op}\n{st}"
+      return rsrc
+    return cache.setdefault((buf, st), LazyOp(buf.op, (rsrc,), rinfo[1]))
+
+  # elementwise ops pass shapetracker
+  in_ops = tuple(_recursive_lazyop(x, st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache) for x in buf.srcs)
+  if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
+    assert buf in outputs, f"{buf.op} must be writable"
+    return in_ops[0]
+  return cache.setdefault((buf, st), LazyOp(buf.op, in_ops, buf.arg))
+
+def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
+  permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis) + axis
+  tmp = input_st.permute(permute_axis)
+  return tmp, tmp.shape[-len(axis):]
+
+def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer],
+                       reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]],
+                       cache:Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]]) -> \
+  Optional[Tuple[LazyBuffer, ShapeTracker]]:
+  if (buf, st) in cache: return cache[(buf, st)]
+  if buf.base.realized is not None or (buf.base in realizes and buf.base not in outs): return None
+  if buf is not buf.base: st, buf = buf.st+st, buf.base
+  input_st = ShapeTracker.from_shape(buf.srcs[0].shape) if buf.op in ReduceOps else st
+  reduce_srcs = [r for x in buf.srcs if (r:=_recurse_reduceops(x, input_st, realizes, outs, reduce_info, cache)) is not None]
+  top_reduce = reduce_srcs[-1] if len(reduce_srcs) != 0 else None
+  if buf.op in ReduceOps:
+    axis = buf.arg
+    if not st.contiguous:
+      # push the movementop to the input
+      tmp, rshape = _permute_reduce(input_st, axis)
+      prshape = prod(rshape)
+      strides = strides_for_shape(rshape)
+      nv: List[View] = []
+      for v in st.views:
+        nv.append(View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
+                              v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None))
+      input_st = tmp + ShapeTracker(tuple(nv))
+      # update the axis
+      _, new_rshape = _permute_reduce(input_st, axis)
+      axis = tuple(range(len(input_st.shape)-len(new_rshape), len(input_st.shape)))
+    elif top_reduce is not None:
+      top_reduce_input_st, top_reduce_axes = reduce_info[top_reduce]
+      if buf.srcs[0] is not buf.srcs[0].base and buf.srcs[0].base is top_reduce[0] and buf.op is top_reduce[0].op:
+        # merge this reduce with its parent
+        new_st = top_reduce[1]+st
+        top_reduce = (top_reduce[0], new_st.reshape(reduce_st(top_reduce_input_st, new_axis:=axis+top_reduce_axes)))
+        reduce_info[top_reduce] = (top_reduce_input_st, new_axis)
+        return None
+      # reshape this reduceop based on the top reduce
+      input_st = input_st.reshape(tuple(1 if i in top_reduce_axes else s for i,s in enumerate(top_reduce_input_st.shape)))
+    st = st.reshape(reduce_st(input_st, axis))
+    reduce_info[(buf, st)] = (input_st, axis)
+    return (buf, st)
+  return cache.setdefault((buf, st), top_reduce)
+
+def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem:
+  """describe the computation for a LazyBuffer with LazyOp + inputs + var_vals"""
+  if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
+    rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
+    wr = LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st))
+    return LBScheduleItem(LazyOp(MetaOps.KERNEL, (wr,)), outs, [x.base for x in out.srcs])
+  if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
+    return LBScheduleItem(LazyOp(MetaOps.EXT, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])
+  # push through all movementops between reduceops
+  reduce_info: Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]] = {}
+  seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]] = {}
+  for out in outs: _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
+  # pad all reduceops to the max of each dimension
+  shape_dims = [sorted(dedup(dims)) for dims in zip(*[input_st.shape for input_st,_ in reduce_info.values()])]
+  for i,dims in enumerate(shape_dims):
+    if len(dims) == 1 or (len(dims) == 2 and dims[0] == 1): continue
+    for (r,view),(input_st,axis) in reduce_info.items():
+      if (dim:=input_st.shape[i]) > 1 and dim != max(dims):
+        input_st = input_st.pad(((0, 0),)*i+((0, max(dims)-dim),))
+        reduce_info[(r, view)] = (input_st, axis)
+  # create the stores
+  var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
+  assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
+  cache: Dict[Tuple[LazyBuffer, ShapeTracker], LazyOp] = {}
   ast: List[LazyOp] = []
-
-
-
-
-
-
-
-
-
-
-
-
-  for i, out in enumerate(outs):
-    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
-    output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
-    lop = _recursive_lazyop(out, inputs, outs, var_vals, output_st, realizes, assign_targets, cache={})
-    output_view, vv = output_view.simplify().unbind()
-    if vv: var_vals.update(vv)
-    ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
-  return _LBScheduleItem(tuple(ast), outs, tuple(inputs), var_vals)
+  inputs: Dict[LazyBuffer, int] = {}
+  for i, out in enumerate(outs):
+    output_st = ShapeTracker.from_shape(reduce_st(*deque(reduce_info.values(), 1).pop()) if reduce_info else out.shape)
+    lop = _recursive_lazyop(out, output_st, tuple(outs), var_vals, inputs, realizes, assign_targets, reduce_info, cache=cache)
+    if out.op is MetaOps.ASSIGN and out.arg:
+      assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
+      output_st = out.arg[0].reshape(output_st.shape)
+    output_st, vv = output_st.simplify().unbind()
+    if vv: var_vals.update(vv)
+    ast.append(LazyOp(BufferOps.STORE, (lop,), MemBuffer(i, out.dtype, output_st)))
+  return LBScheduleItem(LazyOp(MetaOps.KERNEL, tuple(ast)), outs, list(inputs), var_vals,
+                        dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))

 # *** DAG creation: decide which LazyBuffers should realize ***

-def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None],
-
+def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None], simple_pads:Dict[LazyBuffer, None],
+                children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], assign_targets:Dict[LazyBuffer, LazyBuffer],
+                double_reduces:Dict[LazyBuffer, None], scheduled=False) -> None:
   """recursively search the entire graph for all LazyBuffers, insert realizes after expands"""
   if buf in allbufs or buf.base.realized is not None: return
   if GRAPH: log_lazybuffer(buf, scheduled)
-  #
-  if buf
+  # check if we need to realize views
+  if buf is not buf.base:
     # fuse some pads
     if len(buf.st.views) == 1 and buf.st.views[-1].mask is not None and all_int(buf.base.st.shape) and \
         prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
-      simple_pads
+      simple_pads[buf.base] = None
    # realize all expands
    elif prod(buf.base.st.shape) < prod(buf.st.shape):
+      # this was causing "test_lil_model" to fail
      if buf.base.op is UnaryOps.CAST and isinstance(buf.base.srcs[0].dtype, ImageDType) and isinstance(buf.base.arg, ImageDType):
-
-      else:
-        realizes[buf.base] = None
+        simple_pads[buf.base] = None  # don't realize image to image casts. this is part of a larger problem
+      else: realizes[buf.base] = None
    # check all other pads for safe fusion
-    elif any(v.mask is not None for v in buf.st.views): simple_pads
-    return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
-
+    elif any(v.mask is not None for v in buf.st.views): simple_pads[buf.base] = None
+    return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children, assign_targets, double_reduces)
+  if buf.op in ReduceOps and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None
   allbufs[buf] = None
-  if buf.forced_realize: realizes[buf] = None
-  if buf.op
-
+  if buf.forced_realize or buf.op in MetaOps: realizes[buf] = None
+  if buf.op is MetaOps.ASSIGN:
+    assert buf.srcs[1].base is buf.srcs[1], f"assign must be to base {buf.srcs[1]}"
+    assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
+    assign_targets[buf.srcs[1]] = buf
+  if buf.op is MetaOps.COPY:
    assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
    realizes[buf.srcs[0].base] = None
-  if buf.op is
+  if buf.op is MetaOps.VIEW: realizes[buf.srcs[0].base] = None
   for x in buf.srcs:
-    children[x.base][buf] = None
-    _recurse_lb(x, realizes, allbufs, simple_pads, children)
+    if x.base.realized is None: children[x.base][buf] = None
+    _recurse_lb(x, realizes, allbufs, simple_pads, children, assign_targets, double_reduces)

 def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
   if buf in realizes or buf.realized is not None: return True
@@ -166,31 +223,50 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
   return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)

 def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
-                     realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:
+                     realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, None],
+                     cache:Dict[Tuple[LazyBuffer, ShapeTracker], None]) -> None:
   """recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
-  if tr in
+  if (tr, st) in cache: return
+  cache.setdefault((tr, st))
+  if tr in realizes and tr is not r:
    # can only fuse contiguous
    # max one reduceop per kernel
-    if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: group.
-    return group.
+    if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: group.setdefault(r)
+    return group.setdefault(tr)
   for tr_next in children[tr]:
-
-
-
-
-
-
-
-
-
+    # max one reduceop per kernel
+    if tr_next.op in ReduceOps: return group.setdefault(r)
+    # can only fuse contiguous
+    if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.setdefault(r)
+    _recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group, cache)
+
+def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],\
+    realizes:Dict[LazyBuffer, None], group:Dict[LazyBuffer, None]) -> Dict[LazyBuffer, None]:
+  rc_parents, cache = deque(group), set()
+  while rc_parents:
+    if (p:=rc_parents.pop()) in cache: continue
+    cache.add(p)
+    # max one reduceop per kernel
+    if p.op in ReduceOps: return {}
+    rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
+  # search descendants of the reduceop that can cleanly group
+  descendants: Dict[LazyBuffer, None] = {}
+  for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache={})
+  return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
+
+SCHEDULES: List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]] = []
+def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
+    Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]],  # this is the graph
+          DefaultDict[LBScheduleItem, int]]:  # this is the in-degree of the graph
   """create a graph for realizing the outputs"""
   # start by just realizing the buffers passed in
   realizes: Dict[LazyBuffer, None] = {x.base:None for x in outs if x.base.realized is None}
   allbufs: Dict[LazyBuffer, None] = {}
-  simple_pads:
+  simple_pads: Dict[LazyBuffer, None] = {}
   children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
-
-
+  assign_targets: Dict[LazyBuffer, LazyBuffer] = {}
+  double_reduces: Dict[LazyBuffer, None] = {}
+  for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, assign_targets, double_reduces, scheduled=True)

   # check if we have to realize pads
   for p in simple_pads:
@@ -199,40 +275,27 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul

   # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
   reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
+  reduce_of_const: List[LazyBuffer] = []
   for r in allbufs:
    if r.op not in ReduceOps or r in realizes: continue

-    group:
-    _recursive_group(r, r.st, r, children, realizes, reduce_for_op, group)
+    group: Dict[LazyBuffer, None] = {}
+    _recursive_group(r, r.st, r, children, realizes, reduce_for_op, group, cache={})
    # max one reduceop per kernel
    can_chase = all(tr not in reduce_for_op for tr in group)
    # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
    forced_realize = r in group
    if not forced_realize and len(group) > 1:
-
-      rc_parents, rc_children = deque(group), deque(group)
-      while rc_parents and not forced_realize:
-        # max one reduceop per kernel
-        if (p:=rc_parents.pop()).op in ReduceOps: forced_realize = True
-        else: rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
-      # search descendants of the reduceop that can cleanly group
-      realized_descendants: Set[LazyBuffer] = set()
-      while rc_children and not forced_realize:
-        if (c:=rc_children.pop()).op in ReduceOps or not c.st.contiguous or c.st.size != r.st.size or c in reduce_for_op:
-          realized_descendants.clear()
-          break
-        if c in realizes and c not in group: realized_descendants.add(c)
-        rc_children.extend(x for x in children[c] if x.realized is None and x.device == r.device)
-      group.update(realized_descendants)
+      group = _get_isolated_children(r, reduce_for_op, children, realizes, group)
    # can only fuse assign if no other assign_target is used in the kernel
-    if not forced_realize and any(x.op is
+    if not forced_realize and any(x.op is MetaOps.ASSIGN for x in group):
      parents = deque((r, *group))
      while parents and not forced_realize:
        if (p:=parents.pop().base).realized or p in realizes:
          if p in assign_targets and assign_targets[p] not in group: forced_realize, can_chase = True, False
          continue
        parents.extend(p.srcs)
-    if forced_realize:
+    if forced_realize or not group:
      tr = r
      if can_chase:
        # can chase this down to contiguous children
@@ -251,10 +314,26 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
      reduce_for_op[tr] = r
      realizes[tr] = None
    else: reduce_for_op.update((tr, r) for tr in group)
+    if FUSE_ARANGE and r.op is ReduceOps.SUM and r.srcs[0].base.op is MetaOps.CONST: reduce_of_const.append(r)
+
+  # fuse double reduces with no other child
+  if FUSE_CONV_BW:
+    for reduceop in double_reduces:
+      top_reduce = reduceop.base.srcs[0].base
+      if len(children[top_reduce]) == 1: del realizes[top_reduce]
+
+  for r in reduce_of_const:
+    group = {tr:None for tr,rop in reduce_for_op.items() if rop is r}
+    if DEBUG_ARANGE:=(getenv("DEBUG_ARANGE")): print(f"checking {r} {group=}")
+    if any(tr.forced_realize for tr in group) or any(x.base in group for x in outs): continue
+    kernel_children = {c for tr in group for c in children[tr] if c.op not in {MetaOps.COPY, MetaOps.VIEW}}
+    if len(kernel_children) == 0: continue
+    if DEBUG_ARANGE: print(colored(f"folding {r}", "green"))
+    for tr in group: del realizes[tr]

   output_groups: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
   for buf in realizes:
-    if buf.realized is not None or buf.op is
+    if buf.realized is not None or buf.op is MetaOps.CONST or buf in seen: continue
    output_groups[reduce_for_op[buf] if buf in reduce_for_op and MULTIOUTPUT else buf].append(buf)

   # make things that can't be images not images
@@ -269,59 +348,62 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
      buf.buffer.options = None

   # preschedule all buffers in realizes
-  prescheduled =
-  schedule_targets = {out:
+  prescheduled = [_lower_lazybuffer(group, realizes) for group in output_groups.values()]
+  schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}

-  graph: DefaultDict[
-  in_degree: DefaultDict[
-  for
-    if
+  graph: DefaultDict[LBScheduleItem, List[LBScheduleItem]] = defaultdict(list)
+  in_degree: DefaultDict[LBScheduleItem, int] = defaultdict(int)
+  for lsi in prescheduled:
+    if lsi not in in_degree: in_degree[lsi] = 0
    # realize outputs after all parents are realized
-    scheduled_parents =
+    scheduled_parents = dedup(schedule_targets[x] for x in lsi.inputs if x in schedule_targets)
    for x in scheduled_parents:
-      graph[x].append(
-      in_degree[
+      graph[x].append(lsi)
+      in_degree[lsi] += 1
    # realize outputs before a parent is assigned to
-    parents_assigns =
+    parents_assigns = dedup(schedule_targets[assign_targets[x]] for x in lsi.inputs if x in assign_targets)
    for assign in parents_assigns:
-      graph[
+      graph[lsi].append(assign)
      in_degree[assign] += 1

-
+  if SAVE_SCHEDULE:
+    def _save():
+      print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl"))
+      with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
+    if len(SCHEDULES) == 0: atexit.register(_save)
+    SCHEDULES.append((graph, in_degree))
+  return graph, in_degree

 # *** DAG ordering: breadth first search ***

-SCHEDULES: List = []
 def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
   if seen is None: seen = set()
-  graph, in_degree
-
+  graph, in_degree = _graph_schedule(outs, seen)
+  if getenv("RUN_PROCESS_REPLAY") and getenv("COMPARE_SCHEDULE", 1):
+    # NOTE: process relpay needs PYTHONPATH=., remove this once it just pickles LazyBuffers
+    with contextlib.suppress(Exception): importlib.import_module("test.external.process_replay.diff_schedule").process_replay(outs, graph, in_degree)
+
+  queue = deque(lsi for lsi,deg in in_degree.items() if deg == 0)
   schedule: List[ScheduleItem] = []
   var_vals: Dict[Variable, int] = {}
   kernel_number = GlobalCounters.kernel_count
   while queue:
-
-    for buf in
+    lsi = queue.popleft()
+    for buf in lsi.outputs: seen.add(buf)
    if GRAPH:
      kernel_number += 1
-      for out in
-    var_vals = merge_dicts([var_vals,
-    for out in
-    schedule.append(si:=ScheduleItem(
-    if logops and si.ast
-    for x in graph[
+      for out in lsi.outputs: realized_lazybuffer(out, kernel_number)
+    var_vals = merge_dicts([var_vals, lsi.var_vals])
+    for out in lsi.outputs: del out.srcs  # can only schedule once
+    schedule.append(si:=ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata))
+    if logops and si.ast.op is MetaOps.KERNEL and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
+    for x in graph[lsi]:
      in_degree[x] -= 1
-      if in_degree[x] == 0: queue.append(
+      if in_degree[x] == 0: queue.append(x)

-  if SAVE_SCHEDULE:
-    def _save():
-      print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl"))
-      with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
-    if len(SCHEDULES) == 0: atexit.register(_save)
-    SCHEDULES.extend((ps.ast for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)])
   # confirm everything was scheduled correctly
-  if
-    raise RuntimeError(f"cycle detected in graph, prescheduled {len(
+  if any(degree != 0 for degree in in_degree.values()) or len(in_degree) != len(schedule):
+    raise RuntimeError(f"cycle detected in graph, prescheduled {len(in_degree)} but only scheduled {len(schedule)}")
   if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
   return schedule, var_vals

@@ -329,42 +411,3 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
   schedule, var_vals = create_schedule_with_vars(outs, seen)
   assert len(var_vals) == 0
   return schedule
-
-# *** memory planning ***
-
-def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], debug_prefix="") -> Dict[Buffer, Buffer]:
-  if getenv("NO_MEMORY_PLANNER"): return {}
-  last_appearance = {}
-  for i,u in enumerate(buffers):
-    for buf in u: last_appearance[buf] = i
-
-  # LRU algorithm
-  assigned: Dict[Buffer, Buffer] = {}
-  local_cache: DefaultDict[Tuple[str, int, DType], List[Buffer]] = defaultdict(list)
-
-  def handle_buffer(buf):
-    key = (buf.device, buf.size, buf.dtype)
-    if buf not in assigned:
-      if len(ll:=local_cache[key]): assigned[buf] = ll.pop()
-      else: assigned[buf] = Buffer(*key)
-    if i == last_appearance[buf]:
-      if assigned[buf] not in local_cache[key]: local_cache[key].append(assigned[buf])
-
-  for i,u in enumerate(buffers):
-    for buf in u:
-      # all unallocated unparented buffers are fair game to replace
-      if buf.is_allocated() or buf.lb_refcount > 0: continue
-      # handle view buffers
-      if buf._base is not None:
-        assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf._base, buf._base), offset=buf.offset)
-      else:
-        handle_buffer(buf)
-
-  if DEBUG >= 1 and len(ak:=dedup(assigned.keys())) != len(av:=dedup(assigned.values())):
-    print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",
-          f"{len(ak)} -> {len(av)} bufs")
-  return assigned
-
-def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
-  assigned = _internal_memory_planner([si.bufs for si in schedule])
-  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs)) for si in schedule]