tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -0
- tinygrad/codegen/kernel.py +572 -83
- tinygrad/codegen/linearizer.py +415 -395
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +183 -0
- tinygrad/dtype.py +113 -0
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +76 -55
- tinygrad/helpers.py +196 -89
- tinygrad/lazy.py +210 -371
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +202 -22
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +112 -32
- tinygrad/nn/state.py +136 -39
- tinygrad/ops.py +119 -202
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +353 -166
- tinygrad/renderer/llvmir.py +150 -138
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +81 -0
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +75 -0
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +24 -77
- tinygrad/runtime/ops_cuda.py +175 -89
- tinygrad/runtime/ops_disk.py +56 -33
- tinygrad/runtime/ops_gpu.py +92 -95
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +39 -60
- tinygrad/runtime/ops_metal.py +92 -74
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +86 -254
- tinygrad/shape/symbolic.py +166 -141
- tinygrad/shape/view.py +296 -0
- tinygrad/tensor.py +2619 -448
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- tinygrad-0.9.0.dist-info/METADATA +227 -0
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/assembly.py +0 -190
- tinygrad/codegen/optimizer.py +0 -379
- tinygrad/codegen/search.py +0 -72
- tinygrad/graph.py +0 -83
- tinygrad/jit.py +0 -57
- tinygrad/nn/image.py +0 -100
- tinygrad/renderer/assembly_arm64.py +0 -169
- tinygrad/renderer/assembly_ptx.py +0 -98
- tinygrad/renderer/wgsl.py +0 -53
- tinygrad/runtime/lib.py +0 -113
- tinygrad/runtime/ops_cpu.py +0 -51
- tinygrad/runtime/ops_hip.py +0 -82
- tinygrad/runtime/ops_shm.py +0 -29
- tinygrad/runtime/ops_torch.py +0 -30
- tinygrad/runtime/ops_webgpu.py +0 -45
- tinygrad-0.7.0.dist-info/METADATA +0 -212
- tinygrad-0.7.0.dist-info/RECORD +0 -40
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,362 @@
|
|
1
|
+
import sys, pickle, atexit
|
2
|
+
from collections import defaultdict, deque
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union
|
5
|
+
from tinygrad.ops import LoadOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
|
6
|
+
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
|
7
|
+
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, prod, dedup, all_int, merge_dicts, getenv
|
8
|
+
from tinygrad.shape.symbolic import Variable
|
9
|
+
from tinygrad.dtype import ImageDType, dtypes, DType
|
10
|
+
from tinygrad.lazy import LazyBuffer
|
11
|
+
from tinygrad.shape.shapetracker import ShapeTracker
|
12
|
+
from tinygrad.device import Buffer
|
13
|
+
|
14
|
+
# creation can recurse a lot
# (the graph walkers in this module, e.g. _recursive_lazyop / _recurse_lb / _is_padding_okay, recurse once per graph node)
sys.setrecursionlimit(10000)

# optionally log the ops to disk
# LOGOPS=<path>: append str(ast) of each scheduled kernel to <path> (written in create_schedule_with_vars).
# NOTE(review): the handle is opened in append mode at import time and no close() is visible in this module.
logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
|
19
|
+
|
20
|
+
# *** ScheduleItem return type ***
|
21
|
+
|
22
|
+
@dataclass(frozen=True)
class ScheduleItem:
  """A single scheduled unit of work: the ASTs to execute plus the buffers they touch.

  `bufs` is split positionally: its first len(ast) entries are treated as outputs, everything after them as inputs.
  """
  ast: Tuple[LazyOp, ...]
  bufs: Tuple[Buffer, ...]

  @property
  def outputs(self) -> Tuple[Buffer, ...]:
    """Read/write or write only buffers in the schedule (the leading len(ast) entries of bufs)."""
    n_outputs = len(self.ast)
    return self.bufs[:n_outputs]

  @property
  def inputs(self) -> Tuple[Buffer, ...]:
    """Read only buffers in the schedule (every entry of bufs after the outputs)."""
    n_outputs = len(self.ast)
    return self.bufs[n_outputs:]
|
34
|
+
|
35
|
+
# *** DAG transformation: List[LazyBuffer] -> ScheduleItem ***
|
36
|
+
|
37
|
+
# TODO: it's unfortunate this needs to exist, but because of ASSIGN, we have to retain the LazyBuffer structure until post toposort
@dataclass(frozen=True)
class _LBScheduleItem:
  # Pre-toposort version of ScheduleItem that still refers to LazyBuffers; create_schedule_with_vars converts it to a
  # ScheduleItem of Buffers. Field order is the positional constructor order used by _schedule_group.
  ast: Tuple[LazyOp, ...]  # one STORE per output, or a single LoadOps / copy-kernel AST for a one-buffer group
  outputs: Tuple[LazyBuffer, ...]  # written buffers; MemBuffer indices 0..len(outputs)-1
  inputs: Tuple[LazyBuffer, ...]  # read buffers; MemBuffer index is len(outputs)+position for fused kernels
  var_vals: Dict[Variable, int]  # symbolic variable bindings collected while building the ASTs
|
44
|
+
|
45
|
+
def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
                      realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer], cache) -> LazyOp:
  """recursively create a lazyop

  Lower `buf`, viewed through ShapeTracker `st`, into a LazyOp tree, fusing every source that is not materialized.

  buf: buffer to lower; if it is a view, its ShapeTracker is composed with `st` and lowering continues on its base.
  inputs: (mutated) buffers loaded by this kernel; new loads are appended and get MemBuffer index len(outputs)+position.
  outputs: the kernel's outputs; MemBuffer indices 0..len(outputs)-1 refer to them.
  var_vals: (mutated) receives symbolic variable bindings unbound from ShapeTrackers and Variable consts.
  realizes: buffers that will be materialized; those that are not outputs of this kernel become LOADs.
  assign_targets: maps an ASSIGN's target buffer to the ASSIGN LazyBuffer; reading a target loads the ASSIGN's output slot.
  cache: (mutated) memo dict keyed by (buffer, ShapeTracker).
  Returns the LazyOp. Raises RuntimeError if an assign target is read through a non-contiguous view that is not a
  simple masked view of a contiguous layout.
  """
  # NOTE(review): the lookup uses the incoming (buf, st), but entries are stored below under the base buffer and, for
  # reduces, under the reduce's source-shape tracker -- so view/reduce entries only hit when looked up with that key.
  if (buf, st) in cache: return cache[(buf, st)]
  if buf != buf.base:
    st = buf.st + st
    buf = buf.base
  # all buffers here are base now
  assert buf.op is not None

  # consts are always fused and generated
  if buf.op is LoadOps.CONST:
    unbound_st, st_var_vals = st.simplify().unbind()
    var_vals.update(st_var_vals)
    if isinstance(buf.arg, Variable): var_vals.__setitem__(*buf.arg.unbind())
    return LazyOp(BufferOps.CONST, (), ConstBuffer(buf.arg, buf.dtype, unbound_st))

  # if we aren't fusing it, it's a load and we add it to the inputs
  if buf.realized is not None or (buf in realizes and buf not in outputs):
    unbound_st, st_var_vals = st.simplify().unbind()
    var_vals.update(st_var_vals)
    if buf in assign_targets:
      # can only assign to contiguous read+write buffer
      if not unbound_st.contiguous:
        # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
        if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
            ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
          raise RuntimeError(f"must be contiguous for assign {unbound_st}")
      # read the target through the ASSIGN's output slot, not as a separate input
      return LazyOp(BufferOps.LOAD, (), MemBuffer(outputs.index(assign_targets[buf]), buf.dtype, unbound_st))
    if buf not in inputs: inputs.append(buf)
    return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outputs)+inputs.index(buf), buf.dtype, unbound_st))

  # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
  if buf.op is LoadOps.CONTIGUOUS:
    assert buf in outputs
    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)
  if buf.op is LoadOps.ASSIGN:
    assert buf in outputs
    assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
    assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
    # lower the value being assigned (srcs[0]); the target (srcs[1]) is the output buffer itself
    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)

  # if it's a reduce, we have to change the shapetracker
  if buf.op in ReduceOps:
    assert st.contiguous, "ReduceOps late fusion must be contiguous"
    st = ShapeTracker.from_shape(buf.srcs[0].shape)

  # otherwise we fuse it like normal
  cache[(buf, st)] = ret = \
    LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outputs, var_vals, st, realizes, assign_targets, cache) for x in buf.srcs), buf.arg)
  return ret
|
96
|
+
|
97
|
+
def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
  """create a schedule item from a list of outputs

  outs: buffers written by this item (MemBuffer indices 0..len(outs)-1).
  A group whose first buffer is a CUSTOM/COPY/EMPTY/VIEW LoadOp must contain only that buffer. Its AST is that op.
  With USE_COPY_KERNEL, a COPY between devices sharing the same prefix (text before ':') becomes a uint8 LOAD->STORE kernel.
  Every other group gets one STORE per output, with the value built by _recursive_lazyop.
  realizes / reduce_for_op: as computed by _graph_schedule.
  """
  inputs: List[LazyBuffer] = []
  ast: List[LazyOp] = []
  var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
  # single output AST
  if (op:=(out:=outs[0]).op) in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY, LoadOps.VIEW}:
    assert len(outs) == 1, f"can't schedule a group of {op}"
    inputs = [x.base for x in out.srcs]
    if getenv("USE_COPY_KERNEL") and op is LoadOps.COPY and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
      # NOTE(review): out.arg is used as the element count of a uint8 buffer here, presumably the copy size in bytes
      rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
      ast = [LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st))]
    else: ast = [LazyOp(op, (), out.arg)]
  # multi output AST
  else:
    assign_targets = {x.srcs[1]:x for x in outs if x.op is LoadOps.ASSIGN}
    for i, out in enumerate(outs):
      # an output paired with a reduce is computed in the shape of that reduce
      output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
      # an ASSIGN may carry the view it stores through in arg[0]
      output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
      # NOTE: a fresh cache per output, so subexpressions shared between outputs are built once per output
      lop = _recursive_lazyop(out, inputs, outs, var_vals, output_st, realizes, assign_targets, cache={})
      output_view, vv = output_view.simplify().unbind()
      if vv: var_vals.update(vv)
      ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
  return _LBScheduleItem(tuple(ast), outs, tuple(inputs), var_vals)
|
121
|
+
|
122
|
+
# *** DAG creation: decide which LazyBuffers should realize ***
|
123
|
+
|
124
|
+
def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None],
                simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
  """recursively search the entire graph for all LazyBuffers, insert realizes after expands

  Walks every source of `buf` whose base is not yet realized. It populates the following in place:
    allbufs: every unrealized base buffer reached (views are not recorded).
    realizes: buffers that must be materialized: forced_realize buffers, LoadOps, sources of COPY/VIEW, and bases that
      a view expands (except image->image casts).
    simple_pads: bases read through a single masked view whose base holds at least as many elements as the valid region.
    children[base]: the buffers that consume `base`.
  `scheduled` is only forwarded to the GRAPH logger for this call; recursive calls use the default.
  """
  if buf in allbufs or buf.base.realized is not None: return
  if GRAPH: log_lazybuffer(buf, scheduled)
  # view
  if buf.base != buf:
    # fuse some pads
    if len(buf.st.views) == 1 and buf.st.views[-1].mask is not None and all_int(buf.base.st.shape) and \
        prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
      simple_pads.add(buf.base)
    # realize all expands
    elif prod(buf.base.st.shape) < prod(buf.st.shape):
      if buf.base.op is UnaryOps.CAST and isinstance(buf.base.srcs[0].dtype, ImageDType) and isinstance(buf.base.arg, ImageDType):
        pass # don't realize image to image casts. this is part of a larger problem
      else:
        realizes[buf.base] = None
    # views themselves are not recorded; continue the walk at the base
    return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
  # base
  allbufs[buf] = None
  if buf.forced_realize: realizes[buf] = None
  if buf.op in LoadOps: realizes[buf.base] = None
  if buf.op is LoadOps.COPY:
    assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
    realizes[buf.srcs[0].base] = None
  if buf.op is LoadOps.VIEW: realizes[buf.srcs[0].base] = None
  for x in buf.srcs:
    children[x.base][buf] = None
    _recurse_lb(x, realizes, allbufs, simple_pads, children)
|
153
|
+
|
154
|
+
def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
  """Decide whether a pad applied to `buf` may be fused instead of forcing `buf` to realize.

  The walk follows sources (through their bases) and stops at buffers that are realized or scheduled for realization;
  those are considered safe. Reaching any op listed in UNSAFE_PAD_OPS makes the whole answer False.
  """
  already_materialized = buf in realizes or buf.realized is not None
  if already_materialized: return True
  # NOTE: this broke to_image_idx and coder with JIT
  if buf.op in UNSAFE_PAD_OPS: return False
  for src in buf.srcs:
    if not _is_padding_okay(src.base, realizes): return False
  return True
|
159
|
+
|
160
|
+
def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
                     realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Set[LazyBuffer]):
  """recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group

  Starting at `tr`, follow unrealized children while composing the ShapeTracker `st`. Every buffer in `realizes` that is
  reached is added to `group`. The reduce `r` itself is added to `group` to signal that it cannot fuse cleanly. That
  happens when the walk reaches another reduce, a child that reads `tr` through more than one view, or a realize that
  is non-contiguous, differently sized than `r`, or already paired with a reduce.
  Results are accumulated in `group` (mutated). Returns None: `return group.add(...)` is just "add, then stop".
  """
  if tr in realizes:
    # can only fuse contiguous
    # max one reduceop per kernel
    if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: group.add(r)
    return group.add(tr)
  for tr_next in children[tr]:
    if tr_next.realized is None:
      # max one reduceop per kernel
      if tr_next.op in ReduceOps: return group.add(r)
      # can only fuse contiguous
      if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.add(r)
      _recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group)
|
175
|
+
|
176
|
+
def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], DefaultDict[LazyBuffer, int],
                                                                      Dict[LazyBuffer, _LBScheduleItem]]:
  """create a graph for realizing the outputs

  Steps:
    1. walk everything reachable from `outs` and decide what must be realized (_recurse_lb);
    2. force-realize padded bases whose padding is unsafe to fuse;
    3. pair each unrealized reduce with the realized buffers it can fuse into (reduce_for_op), otherwise force-realize
       the reduce (or a contiguous child it can be chased down to);
    4. group realized buffers into kernels (one group per reduce when MULTIOUTPUT, else one per buffer) and preschedule;
    5. build the kernel dependency graph, keyed by each kernel's first output. A kernel that reads an assign target is
       ordered before the kernel performing that assign.
  Buffers in `seen` are not scheduled again. Returns (graph: key -> dependent keys, in_degree: key -> number of
  dependencies, prescheduled: key -> _LBScheduleItem).
  """
  # start by just realizing the buffers passed in
  realizes: Dict[LazyBuffer, None] = {x.base:None for x in outs if x.base.realized is None}
  allbufs: Dict[LazyBuffer, None] = {}
  simple_pads: Set[LazyBuffer] = set()
  children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
  for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
  # assign target buffer -> the pending ASSIGN that writes it
  assign_targets = {x.srcs[1]:x for x in realizes if x.op is LoadOps.ASSIGN and x not in seen and x.realized is None}

  # check if we have to realize pads
  for p in simple_pads:
    if not _is_padding_okay(p, realizes):
      realizes[p] = None

  # find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
  reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
  for r in allbufs:
    if r.op not in ReduceOps or r in realizes: continue

    group: Set[LazyBuffer] = set()
    _recursive_group(r, r.st, r, children, realizes, reduce_for_op, group)
    # max one reduceop per kernel
    can_chase = all(tr not in reduce_for_op for tr in group)
    # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
    # (_recursive_group puts r itself into the group when it can't fuse cleanly)
    forced_realize = r in group
    if not forced_realize and len(group) > 1:
      # create a multi output kernel if the LazyBuffers can cleanly group
      rc_parents, rc_children = deque(group), deque(group)
      while rc_parents and not forced_realize:
        # max one reduceop per kernel
        if (p:=rc_parents.pop()).op in ReduceOps: forced_realize = True
        else: rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
      # search descendants of the reduceop that can cleanly group
      realized_descendants: Set[LazyBuffer] = set()
      while rc_children and not forced_realize:
        if (c:=rc_children.pop()).op in ReduceOps or not c.st.contiguous or c.st.size != r.st.size or c in reduce_for_op:
          realized_descendants.clear()
          break
        if c in realizes and c not in group: realized_descendants.add(c)
        rc_children.extend(x for x in children[c] if x.realized is None and x.device == r.device)
      group.update(realized_descendants)
    # can only fuse assign if no other assign_target is used in the kernel
    if not forced_realize and any(x.op is LoadOps.ASSIGN for x in group):
      parents = deque((r, *group))
      while parents and not forced_realize:
        if (p:=parents.pop().base).realized or p in realizes:
          if p in assign_targets and assign_targets[p] not in group: forced_realize, can_chase = True, False
          continue
        parents.extend(p.srcs)
    if forced_realize:
      tr = r
      if can_chase:
        # can chase this down to contiguous children
        st = tr.st
        while len(children[tr]) == 1:
          tr_next = next(iter(children[tr]))
          st_childs = dedup(s for s in tr_next.srcs if s.base is tr)
          if len(st_childs) > 1: break
          if st.size != st_childs[0].st.size: break
          st = st + st_childs[0].st
          if not st.contiguous or tr_next.op in ReduceOps: break
          tr = tr_next
        # don't cast to higher size before store (tr cannot be realized if forced_realize)
        if tr.op is UnaryOps.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
          tr = tr.srcs[0].base
      reduce_for_op[tr] = r
      realizes[tr] = None
    else: reduce_for_op.update((tr, r) for tr in group)

  # group realized buffers into kernels: with MULTIOUTPUT, all buffers paired with the same reduce share one kernel
  output_groups: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
  for buf in realizes:
    if buf.realized is not None or buf.op is LoadOps.CONST or buf in seen: continue
    output_groups[reduce_for_op[buf] if buf in reduce_for_op and MULTIOUTPUT else buf].append(buf)

    # make things that can't be images not images
    if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
                                              not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
      if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
      buf.dtype = dtypes.float32
      # hack the underlying buffer too
      if buf.base is buf:
        assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer"
        buf.buffer.dtype = dtypes.float32
        buf.buffer.options = None

  # preschedule all buffers in realizes
  prescheduled = {group[0]:_schedule_group(tuple(group), realizes, reduce_for_op) for group in output_groups.values()}
  # output buffer -> the prescheduled item that produces it
  schedule_targets = {out:ps for ps in prescheduled.values() for out in ps.outputs}

  # nodes are keyed by each item's first output (== its prescheduled key)
  graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
  in_degree: DefaultDict[LazyBuffer, int] = defaultdict(int)
  for key, lsi in prescheduled.items():
    if key not in in_degree: in_degree[key] = 0
    # realize outputs after all parents are realized
    scheduled_parents = set(schedule_targets[x].outputs[0] for x in lsi.inputs if x in schedule_targets)
    for x in scheduled_parents:
      graph[x].append(key)
      in_degree[key] += 1
    # realize outputs before a parent is assigned to
    parents_assigns = set(schedule_targets[assign_targets[x]].outputs[0] for x in lsi.inputs if x in assign_targets)
    for assign in parents_assigns:
      graph[key].append(assign)
      in_degree[assign] += 1

  return graph, in_degree, prescheduled
|
283
|
+
|
284
|
+
# *** DAG ordering: breadth first search ***
|
285
|
+
|
286
|
+
SCHEDULES: List = []  # schedule graphs (or ASTs with CAPTURE_AST) accumulated for SAVE_SCHEDULE
def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
  """Turn the LazyBuffer graph behind `outs` into an ordered list of ScheduleItems.

  Items are emitted in breadth-first topological order of the dependency graph built by _graph_schedule.
  `seen` (updated in place) collects every output LazyBuffer scheduled, so a later call skips them.
  Returns (schedule, var_vals); var_vals merges the symbolic variable bindings of all scheduled items.
  Raises RuntimeError if not every prescheduled item could be emitted (a cycle in the dependency graph).
  Side effects: deletes `srcs` from every scheduled output LazyBuffer, may append ASTs to the LOGOPS file, and with
  SAVE_SCHEDULE records the graph in SCHEDULES and pickles SCHEDULES at interpreter exit.
  """
  if seen is None: seen = set()
  graph, in_degree, prescheduled = _graph_schedule(outs, seen)
  queue = deque(si for key, si in prescheduled.items() if in_degree[key] == 0)
  schedule: List[ScheduleItem] = []
  var_vals: Dict[Variable, int] = {}
  kernel_number = GlobalCounters.kernel_count
  while queue:
    ps = queue.popleft()
    for buf in ps.outputs: seen.add(buf)
    if GRAPH:
      kernel_number += 1
      for out in ps.outputs: realized_lazybuffer(out, kernel_number)
    var_vals = merge_dicts([var_vals, ps.var_vals])
    for out in ps.outputs: del out.srcs  # can only schedule once
    schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0)))
    if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
    # graph nodes are keyed by each item's first output
    for x in graph[ps.outputs[0]]:
      in_degree[x] -= 1
      if in_degree[x] == 0: queue.append(prescheduled[x])

  if SAVE_SCHEDULE:
    def _save():
      print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl"))
      # close (and flush) the file deterministically instead of leaking the handle
      with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
    # register the exit hook only on the first recording
    if len(SCHEDULES) == 0: atexit.register(_save)
    SCHEDULES.extend((ps.ast for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)])
  # confirm everything was scheduled correctly
  if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
    raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
  return schedule, var_vals
|
319
|
+
|
320
|
+
def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
  """Like create_schedule_with_vars, but for graphs with no symbolic variables: returns only the ScheduleItems.

  Fails with AssertionError if any Variable binding was collected.
  """
  items, bound_vars = create_schedule_with_vars(outs, seen)
  assert not bound_vars
  return items
|
324
|
+
|
325
|
+
# *** memory planning ***
|
326
|
+
|
327
|
+
def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], debug_prefix="") -> Dict[Buffer, Buffer]:
  """Choose replacement Buffers so that steps whose lifetimes don't overlap can share memory.

  buffers[i] holds the buffers used by step i (e.g. ScheduleItem.bufs). Each buffer that is not allocated yet and is not
  referenced by a LazyBuffer (lb_refcount == 0) gets a replacement.
  A non-view replacement is returned to a pool keyed by (device, size, dtype) after the last step that uses the buffer
  or any view of it. Later buffers with the same key take from that pool. A view buffer gets a new view onto its base's
  replacement, or onto the original base if the base has none.
  Returns {original: replacement}, or {} when NO_MEMORY_PLANNER is set.
  """
  if getenv("NO_MEMORY_PLANNER"): return {}
  # last step that needs each buffer's storage. A view uses its base's storage, so the base must stay alive until the
  # view's last use as well; otherwise the base's replacement could be handed to another buffer while the view still
  # reads or writes it.
  last_appearance: Dict[Buffer, int] = {}
  for i,u in enumerate(buffers):
    for buf in u:
      last_appearance[buf] = i
      if buf._base is not None: last_appearance[buf._base] = i

  # pool of freed replacements; pop() takes the most recently freed one
  assigned: Dict[Buffer, Buffer] = {}
  local_cache: DefaultDict[Tuple[str, int, DType], List[Buffer]] = defaultdict(list)

  def release_if_done(buf:Buffer, step:int):
    """return buf's replacement to the pool if `step` is the last step that needs buf's storage"""
    if buf not in assigned or step != last_appearance[buf]: return
    pool = local_cache[(buf.device, buf.size, buf.dtype)]
    if assigned[buf] not in pool: pool.append(assigned[buf])

  for i,u in enumerate(buffers):
    for buf in u:
      # all unallocated unparented buffers are fair game to replace
      if not (buf.is_allocated() or buf.lb_refcount > 0):
        if buf._base is not None:
          # handle view buffers: view the (possibly replaced) base
          assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf._base, buf._base), offset=buf.offset)
        elif buf not in assigned:
          key = (buf.device, buf.size, buf.dtype)
          assigned[buf] = local_cache[key].pop() if local_cache[key] else Buffer(*key)
      # free the storage owner (the base, for views) once nothing later needs it -- even if this view wasn't replaceable
      release_if_done(buf if buf._base is None else buf._base, i)

  if DEBUG >= 1 and len(ak:=dedup(assigned.keys())) != len(av:=dedup(assigned.values())):
    print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",
          f"{len(ak)} -> {len(av)} bufs")
  return assigned
|
359
|
+
|
360
|
+
def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
  """Return a copy of `schedule` whose buffers are swapped for the replacements chosen by _internal_memory_planner.

  Buffers without a replacement are kept as-is, and every item's ASTs are reused unchanged.
  """
  replacement = _internal_memory_planner([si.bufs for si in schedule])
  planned: List[ScheduleItem] = []
  for si in schedule:
    new_bufs = tuple(replacement.get(b, b) for b in si.bufs)
    planned.append(ScheduleItem(si.ast, new_bufs))
  return planned
|
@@ -0,0 +1,196 @@
|
|
1
|
+
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
|
2
|
+
import itertools, functools, random, math, time, multiprocessing, traceback, signal
|
3
|
+
from collections import defaultdict
|
4
|
+
from dataclasses import replace
|
5
|
+
from tinygrad.device import Device, Buffer, Compiler
|
6
|
+
from tinygrad.ops import MemBuffer
|
7
|
+
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
|
8
|
+
from tinygrad.dtype import ImageDType
|
9
|
+
from tinygrad.codegen.linearizer import Linearizer
|
10
|
+
from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError
|
11
|
+
from tinygrad.codegen.uops import UOpGraph
|
12
|
+
from tinygrad.tensor import Tensor
|
13
|
+
from tinygrad.shape.symbolic import sym_infer
|
14
|
+
from tinygrad.engine.realize import CompiledRunner
|
15
|
+
from tinygrad.renderer import Program
|
16
|
+
|
17
|
+
# Candidate optimizations tried by beam search. Order matters: get_linearizer_actions keys each resulting Linearizer
# by (index in this list) + 1, reserving key 0 for "no action".
# NOTE(review): amt=0 appears to mean "the whole axis" (get_linearizer_actions skips an amt equal to the axis size when
# the amt=0 variant exists) -- confirm in Kernel.apply_opt.
actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4,7] for axis in range(5)]
actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29] for axis in range(5)]
actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]
# BEAM_PADTO=0 removes the PADTO candidates
if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
# NOLOCALS=1 adds the NOLOCALS candidate
if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
|
26
|
+
|
27
|
+
def _get_test_global_size(global_size, max_global_size, var_vals):
  """Shrink a launch's global size so that a timing run stays cheap.

  Symbolic dims are resolved with `var_vals`. Then, while the total size exceeds `max_global_size`, the last dim that is
  still > 16 is halved (floor division). Returns (test_global_size, factor), where factor = 2**(number of halvings) is
  used to scale the measured time back up (approximate when odd dims were halved).
  If every dim is already <= 16 the loop stops, even if the product still exceeds max_global_size, instead of spinning
  forever.
  """
  test_global_size, factor = [sym_infer(sz, var_vals) for sz in global_size], 1
  while prod(test_global_size) > max_global_size:
    for j in range(len(global_size)-1,-1,-1):
      if test_global_size[j] > 16:
        test_global_size[j] //= 2
        factor *= 2
        break
    else: break  # no dim could be halved: no further progress is possible
  return test_global_size, factor
|
36
|
+
|
37
|
+
def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"):
  """Run the precompiled `lib` for program `p` up to `cnt` times and return the measured times.

  If `p` has a global size and `max_global_size` is set, the launch is shrunk with _get_test_global_size. Each measured
  time is then multiplied by the returned factor. Returns [math.inf] * cnt if building the CompiledRunner raises
  AssertionError. Timing stops early once a run takes longer than `early_stop`. `name` is not used here.
  """
  scale = 1
  shrink = p.global_size is not None and max_global_size is not None
  if shrink:
    test_size, scale = _get_test_global_size(p.global_size, max_global_size, var_vals)
    p = replace(p, global_size=test_size)
  try:
    runner = CompiledRunner(p, precompiled=lib)
  except AssertionError:
    return [math.inf] * cnt
  run_bufs = [rawbufs[idx] for idx, _ in runner.p.globals]
  timings = []
  for _ in range(cnt):
    if clear_l2:
      # presumably evicts the GPU L2 cache by realizing an unrelated tensor before the timed run
      with Context(DEBUG=0, BEAM=0, CACHECOLLECTING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
    timings.append(cast(float, runner(run_bufs, var_vals, wait=True)) * scale)
    if early_stop is not None and early_stop < timings[-1]: break
  return timings
|
52
|
+
|
53
|
+
class TimeoutException(Exception):
  """Raised by timeout_handler when the SIGALRM deadline armed by the compile worker expires."""

def timeout_handler(signum, frame):
  """Signal handler (signal.signal signature) that aborts the current work by raising TimeoutException."""
  raise TimeoutException()
|
55
|
+
|
56
|
+
def _try_compile_linearized_w_idx(x:Tuple[int,Linearizer], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
  """Linearize and compile one beam candidate under a BEAM_TIMEOUT_SEC (default 10) second alarm.

  x is (index, linearizer). The index is returned unchanged so results from an unordered pool map can be matched back.
  On success it returns (index, (program, compiled lib, compile time in seconds)). It returns (index, None) if
  linearizing or compiling raised RuntimeError (including the BEAM_UOPS_MAX cap) or the alarm fired. Other exception
  types propagate, and the pending alarm is cancelled in every case.
  NOTE(review): installs a SIGALRM handler in the calling process and does not restore the previous one; signal.signal
  only works from the main thread.
  """
  signal.signal(signal.SIGALRM, timeout_handler)
  # set timeout
  signal.alarm(getenv("BEAM_TIMEOUT_SEC", 10))
  try:
    x[1].linearize()
    # chained comparison: too many uops only when the cap BEAM_UOPS_MAX is positive (<= 0 disables the cap)
    if len(x[1].uops.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
    p = x[1].to_program()
    st = time.perf_counter()
    prog = compiler.compile(p.src)
    et = time.perf_counter() - st
    ret = (p, prog, et)
  except RuntimeError:
    if DEBUG >= 4: traceback.print_exc()
    ret = None
  except TimeoutException:
    ret = None
  finally:
    # cancel the alarm whether we succeeded, failed, or timed out
    signal.alarm(0)
  return x[0], ret
|
76
|
+
|
77
|
+
# workers should ignore ctrl c
|
78
|
+
def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
|
79
|
+
|
80
|
+
def _ensure_buffer_alloc(bufs:List[Buffer]) -> List[Buffer]: return [buf.ensure_allocated() for buf in bufs]
|
81
|
+
|
82
|
+
# *** external API ***
|
83
|
+
|
84
|
+
# get (scrap) buffers for timing the linearizer
def bufs_from_lin(lin:Linearizer, allocate:bool=True) -> List[Buffer]:
  """Create one scratch Buffer per MemBuffer index used by `lin`, on lin's device.

  Each buffer covers every access made through its index. For an ImageDType that is the image's full shape; otherwise
  it is the largest st.real_size() among that index's MemBuffers, with a minimum of 1. Buffers are allocated unless
  `allocate` is False. Assumes the MemBuffer indices are exactly 0..N-1.
  """
  by_idx:DefaultDict[int, List[MemBuffer]] = defaultdict(list)
  for mb in lin.membufs: by_idx[mb.idx].append(mb)
  out:List[Optional[Buffer]] = [None for _ in range(len(by_idx))]
  for idx, group in by_idx.items():
    dt = group[0].dtype
    if isinstance(dt, ImageDType): size = prod(dt.shape)
    else: size = max(mb.st.real_size() for mb in group)
    # create a size 1 buffer if no cell is accessed in kernel. # TODO: remove from kernel input in this case.
    if size == 0: size = 1
    buf = Buffer(lin.opts.device, size, dt)
    out[idx] = buf.allocate() if allocate else buf
  assert all(r is not None for r in out)
  return cast(List[Buffer], out)
|
95
|
+
|
96
|
+
# get dictionary of all possible actions
def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
  """Apply each entry of `actions` to a copy of `lin` and return the copies that are valid.

  Keys are (index into `actions`) + 1. Key 0 maps to `lin` itself when include_0. An action is skipped when:
    - its axis is out of range;
    - its amt equals the full axis size and the amt=0 variant of the same action exists;
    - apply_opt raises KernelOptError;
    - the product of "magenta"/"yellow" dims (divided by the tensor-core factor) exceeds BEAM_UPCAST_MAX (default 256);
    - the product of "cyan"/"green"/"white" dims exceeds BEAM_LOCAL_MAX (default 1024).
  """
  acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
  for i,a in enumerate(actions):
    if a.axis is not None and a.op is not OptOps.TC:
      if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.amt and Opt(a.op, ax, 0) in actions): continue
    lin2 = lin.copy()
    try:
      lin2.apply_opt(a)
      # tc_up: tensor-core dims not covered by its threads; that much upcasting doesn't count toward max_up
      up, lcl, tc_up = 1, 1, prod(tc.dims)//prod([x[1] for x in tc.threads]) if (tc:=lin2.tensor_core) else 1
      for s,c in zip(lin2.full_shape, lin2.colors()):
        if c in {"magenta", "yellow"}: up *= s
        elif c in {"cyan", "green", "white"}: lcl *= s
      if up//tc_up > max_up or lcl > max_lcl: continue
      acted_lins[i+1] = lin2
    except KernelOptError: pass
  return acted_lins
|
113
|
+
|
114
|
+
# beam_pool: process pool created lazily by beam_search and reused by later calls (no close is visible in this module)
beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=True) -> Linearizer:
  """Beam-search over `actions` for the fastest version of `lin`, keeping the `amt` best candidates per round.

  Results are cached on disk, keyed by the AST, amt, allow_test_size, device and suffix. On a cache hit (when
  CACHELEVEL >= 1 and IGNORE_BEAM_CACHE is unset) the cached opts beyond those already applied are replayed and the
  result is returned without searching.
  Each round expands every beam entry with get_linearizer_actions and compiles the candidates, in parallel when a pool
  exists. Candidates whose compiled lib was already seen are dropped; the rest are timed with _time_program.
  The search stops when no candidate runs, the best time is below BEAM_MIN_PROGRESS microseconds, or it improves on the
  current best by less than that.
  Symbolic variables are timed at the midpoint of their [min, max] range. Returns the best Linearizer found.
  NOTE(review): `amt` must be >= 1; with amt <= 0 the beam becomes empty and beam[0] raises IndexError.
  """
  global beam_pool
  key = {"ast": lin.ast[0].key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
  if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
    ret = lin.copy()
    for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
    return ret

  beam: List[Tuple[Linearizer, float]] = [(lin, float("inf"))]
  seen_libs = set()

  # compile in parallel by default only on these GPU backends; PARALLEL overrides the worker count (0 = serial)
  default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "HSA", "AMD", "NV"} else 0
  if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
    beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))

  # BEAM_MIN_PROGRESS is given in microseconds; timings are in seconds
  min_progress = getenv("BEAM_MIN_PROGRESS", 0.01)/1e6
  if BEAM_DEBUG: print(f"BEAM_SEARCH:\n{lin.ast}")
  if DEBUG >= 2: print(f" 0.00s: from 1 -> 1 actions {lin.colored_shape()}")

  try:
    rawbufs = _ensure_buffer_alloc(rawbufs)
    var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
    exiting, st = False, time.perf_counter()
    dev = Device[lin.opts.device]
    while not exiting:
      # every single-action child of every beam entry (the comprehension's `lin` doesn't leak into the outer `lin`)
      acted_lins: List[Linearizer] = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
      timed_lins: List[Tuple[Linearizer, float]] = []
      _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
      for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
        if proc is None: continue
        p, lib, compile_et = proc
        # identical compiled code was already timed
        if lib in seen_libs: continue
        #print(acted_lins[i].colored_shape(), acted_lins[i].applied_opts) # for debugging BEAMs that segfault
        seen_libs.add(lib)
        # abandon a candidate once a run is 3x slower than the current best
        try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
        except RuntimeError: continue # for runtime issues
        timed_lins.append((acted_lins[i], min(tms)))
        if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(UOpGraph, p.uops).uops):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
        elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501

      # done
      opts = sorted(timed_lins, key=lambda x: x[1])
      exiting = len(opts) == 0 or (opts[0][1] < min_progress) or (len(beam) > 0 and ((beam[0][1]-opts[0][1]) < min_progress))
      if not exiting: beam = opts[:amt]
      # when stopping, still keep the new best if it beats the old one
      elif len(opts) > 0 and opts[0][1] < beam[0][1]: beam = opts[:1]
      if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
  except KeyboardInterrupt as e:
    # workers ignore SIGINT (_init_worker), so tear the pool down here
    if beam_pool is not None: beam_pool.terminate()
    raise e

  if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
  if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={beam[0][1]*1e6:0.2f} us, applied_opts={beam[0][0].applied_opts}")
  return beam[0][0]
|
168
|
+
|
169
|
+
def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buffer]) -> List[int]:
  """Empirically choose a local (workgroup) size for `clprg` and return it as a list, one entry per dimension.

  Candidates per dimension are the dimension's global size itself plus 1, 2, 4, ..., 256 and 1024, restricted to
  values <= that global size. Only combinations whose product is <= 1024 are kept. Every combination is tried twice,
  in random order. `clprg` is called with wait=True, and its return value is treated as the cost to minimize
  (presumably the execution time). A call that raises counts as infinitely slow.

  Raises:
    AssertionError: if every candidate raised.
  """
  # If the first buffer also appears among the remaining ones, substitute a freshly allocated buffer of the
  # same device/size/dtype in its place (presumably so that slot is not aliased with an input during timing).
  if rawbufs[0] in rawbufs[1:]:
    first = rawbufs[0]
    bufs = [Buffer(first.device, first.size, first.dtype).allocate(), *rawbufs[1:]]
  else:
    bufs = rawbufs

  max_wg = 1024
  base_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, max_wg]
  per_dim_options = []
  for dim in global_size:
    # the set drops a duplicate when the dimension itself is one of the base sizes
    per_dim_options.append([cand for cand in set([dim, *base_sizes]) if cand <= dim])

  valid = [list(combo) for combo in itertools.product(*per_dim_options) if prod(combo) <= max_wg]
  trials = valid + valid  # try each valid size twice

  def run_once(local_size):
    try:
      args = [b._buf for b in bufs]
      # launch size per dim: exact quotient when divisible, otherwise a float quotient
      launch = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
      return clprg(*args, global_size=launch, local_size=local_size, wait=True)
    except Exception:
      return float('inf')

  shuffled = random.sample(trials, len(trials))
  results = [(run_once(ls), ls) for ls in shuffled]
  best_time, best_size = min(results)
  assert not math.isinf(best_time), "all optimize_local_size exec failed"
  return best_size
|
180
|
+
|
181
|
+
def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:  # noqa: E501
  """Compile `lin`, time it on `rawbufs`, and return the fastest of the measured runs.

  Args:
    lin: linearizer to compile and time.
    rawbufs: buffers to run against; passed through `_ensure_buffer_alloc` first.
    allow_test_size: when False, `_time_program` receives max_global_size=None instead of `max_global_size`.
    max_global_size: forwarded to `_time_program` (only when allow_test_size is True).
    cnt: number of timing runs, forwarded to `_time_program`.
    disable_cache: skip the disk-cache lookup. Results are still written when CACHELEVEL >= 2.
    clear_l2: forwarded to `_time_program`.

  Returns:
    min() of the timings produced by `_time_program` (or of the cached timings under the "time_linearizer" key
    when CACHELEVEL >= 2). NOTE(review): `cnt` is not part of the cache key, so a cached entry recorded with a
    different cnt is returned as-is.
  """
  cache_key = {
    "ast": lin.ast[0].key,
    "opts": str(lin.applied_opts),
    "allow_test_size": allow_test_size,
    "max_global_size": max_global_size,
    "clear_l2": clear_l2,
    "device": lin.opts.device,
    "suffix": lin.opts.suffix,
  }
  if not disable_cache and CACHELEVEL >= 2:
    cached_tms = diskcache_get("time_linearizer", cache_key)
    if cached_tms is not None: return min(cached_tms)

  dev = Device[lin.opts.device]
  compiler = dev.compiler
  assert compiler is not None

  bufs = _ensure_buffer_alloc(rawbufs)
  # bind every symbolic variable to the midpoint of its [min, max] range
  var_vals = {v: (v.max + v.min) // 2 for v in lin.ast[0].vars()}
  prg = lin.to_program()
  lib = compiler.compile(prg.src)
  test_global = max_global_size if allow_test_size else None
  tms = _time_program(prg, lib, var_vals, bufs, max_global_size=test_global, clear_l2=clear_l2, cnt=cnt,
                      name=to_function_name(lin.name))

  if CACHELEVEL >= 2: diskcache_put("time_linearizer", cache_key, tms)
  return min(tms)
|