tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/__init__.py +0 -0
  3. tinygrad/codegen/kernel.py +253 -225
  4. tinygrad/codegen/linearizer.py +398 -436
  5. tinygrad/codegen/uops.py +451 -0
  6. tinygrad/device.py +268 -274
  7. tinygrad/dtype.py +56 -40
  8. tinygrad/engine/__init__.py +0 -0
  9. tinygrad/engine/graph.py +100 -0
  10. tinygrad/engine/jit.py +198 -0
  11. tinygrad/engine/realize.py +192 -0
  12. tinygrad/engine/schedule.py +370 -0
  13. tinygrad/engine/search.py +199 -0
  14. tinygrad/{mlops.py → function.py} +40 -32
  15. tinygrad/helpers.py +144 -46
  16. tinygrad/lazy.py +143 -242
  17. tinygrad/multi.py +173 -0
  18. tinygrad/nn/__init__.py +180 -9
  19. tinygrad/nn/datasets.py +8 -0
  20. tinygrad/nn/optim.py +106 -28
  21. tinygrad/nn/state.py +87 -19
  22. tinygrad/ops.py +104 -45
  23. tinygrad/renderer/__init__.py +65 -0
  24. tinygrad/renderer/assembly.py +269 -0
  25. tinygrad/renderer/cstyle.py +308 -210
  26. tinygrad/renderer/llvmir.py +119 -124
  27. tinygrad/runtime/__init__.py +0 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +13403 -0
  29. tinygrad/runtime/autogen/comgr.py +891 -0
  30. tinygrad/runtime/autogen/cuda.py +5923 -0
  31. tinygrad/runtime/autogen/hip.py +5909 -0
  32. tinygrad/runtime/autogen/hsa.py +5893 -0
  33. tinygrad/runtime/autogen/io_uring.py +1486 -0
  34. tinygrad/runtime/autogen/kfd.py +812 -0
  35. tinygrad/runtime/autogen/nv_gpu.py +33597 -0
  36. tinygrad/runtime/autogen/opencl.py +1795 -0
  37. tinygrad/runtime/driver/__init__.py +0 -0
  38. tinygrad/runtime/driver/hip_comgr.py +56 -0
  39. tinygrad/runtime/graph/__init__.py +0 -0
  40. tinygrad/runtime/graph/clang.py +39 -0
  41. tinygrad/runtime/graph/cuda.py +59 -54
  42. tinygrad/runtime/graph/hcq.py +187 -0
  43. tinygrad/runtime/graph/metal.py +37 -41
  44. tinygrad/runtime/ops_amd.py +550 -0
  45. tinygrad/runtime/ops_clang.py +16 -14
  46. tinygrad/runtime/ops_cuda.py +129 -37
  47. tinygrad/runtime/ops_disk.py +111 -43
  48. tinygrad/runtime/ops_gpu.py +52 -50
  49. tinygrad/runtime/ops_llvm.py +36 -56
  50. tinygrad/runtime/ops_metal.py +41 -24
  51. tinygrad/runtime/ops_npy.py +9 -0
  52. tinygrad/runtime/ops_nv.py +625 -0
  53. tinygrad/runtime/ops_python.py +208 -0
  54. tinygrad/shape/__init__.py +0 -0
  55. tinygrad/shape/shapetracker.py +46 -107
  56. tinygrad/shape/symbolic.py +99 -98
  57. tinygrad/shape/view.py +162 -45
  58. tinygrad/tensor.py +2492 -483
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
  60. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
  61. tinygrad-0.9.1.dist-info/RECORD +63 -0
  62. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  63. tinygrad/features/image.py +0 -93
  64. tinygrad/features/multi.py +0 -103
  65. tinygrad/features/search.py +0 -160
  66. tinygrad/graph.py +0 -106
  67. tinygrad/jit.py +0 -152
  68. tinygrad/realize.py +0 -50
  69. tinygrad/runtime/graph/hip.py +0 -24
  70. tinygrad/runtime/ops_cpu.py +0 -45
  71. tinygrad/runtime/ops_hip.py +0 -97
  72. tinygrad/runtime/ops_torch.py +0 -49
  73. tinygrad-0.8.0.dist-info/RECORD +0 -41
  74. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
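
The two hunks reproduced below appear to be the new tinygrad/engine/schedule.py (+370) and tinygrad/engine/search.py (+199) from the table above. The table also shows the broader 0.9.1 reorganization: tinygrad/jit.py, tinygrad/realize.py, tinygrad/graph.py, tinygrad/features/search.py and tinygrad/features/multi.py are replaced by tinygrad/engine/* (and tinygrad/multi.py), and tinygrad/mlops.py is renamed to tinygrad/function.py. A rough sketch of how internal import paths move between the two versions; the exact set of top-level re-exports from tinygrad/__init__.py is an assumption, not something visible in this diff:

# 0.8.0-era internal imports (removed files above)
#   from tinygrad.jit import TinyJit
#   from tinygrad.realize import run_schedule
#   from tinygrad.features.search import beam_search

# 0.9.1 locations, following the file moves in the table
from tinygrad.engine.jit import TinyJit
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.search import beam_search
from tinygrad.engine.schedule import create_schedule, memory_planner

# the package root is assumed to keep re-exporting the user-facing API
from tinygrad import Tensor, TinyJit, dtypes, Device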
@@ -0,0 +1,370 @@
+ import sys, pickle, atexit
+ from collections import defaultdict, deque
+ from dataclasses import dataclass
+ from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union, get_args
+ from tinygrad.ops import LoadOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
+ from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
+ from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv
+ from tinygrad.shape.symbolic import Variable
+ from tinygrad.dtype import ConstType, ImageDType, dtypes, DType
+ from tinygrad.lazy import LazyBuffer
+ from tinygrad.shape.shapetracker import ShapeTracker
+ from tinygrad.device import Buffer
+
+ # creation can recurse a lot
+ sys.setrecursionlimit(10000)
+
+ # optionally log the ops to disk
+ logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
+
+ # *** ScheduleItem return type ***
+
+ @dataclass(frozen=True)
+ class ScheduleItem:
+   ast: Tuple[LazyOp, ...]
+   bufs: Tuple[Buffer, ...]
+   @property
+   def outputs(self) -> Tuple[Buffer, ...]:
+     """Read/write or write only buffers in the schedule."""
+     return self.bufs[:len(self.ast)]
+   @property
+   def inputs(self) -> Tuple[Buffer, ...]:
+     """Read only buffers in the schedule."""
+     return self.bufs[len(self.ast):]
+
+ # *** DAG transformation: List[LazyBuffer] -> ScheduleItem ***
+
+ # TODO: it's unfortunate this needs to exist, but because of ASSIGN, we have to retain the LazyBuffer structure until post toposort
+ @dataclass(frozen=True)
+ class _LBScheduleItem:
+   ast: Tuple[LazyOp, ...]
+   outputs: Tuple[LazyBuffer, ...]
+   inputs: Tuple[LazyBuffer, ...]
+   var_vals: Dict[Variable, int]
+
+ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
+                       realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer], cache) -> LazyOp:
+   """recursively create a lazyop"""
+   if (buf, st) in cache: return cache[(buf, st)]
+   if buf != buf.base:
+     st = buf.st + st
+     buf = buf.base
+   # all buffers here are base now
+   assert buf.op is not None
+
+   # consts are always fused and generated
+   if buf.op is LoadOps.CONST:
+     unbound_st, st_var_vals = st.simplify().unbind()
+     var_vals.update(st_var_vals)
+     if isinstance(buf.arg, Variable):
+       val, var_val = buf.arg.unbind()
+       var_vals.__setitem__(val, var_val)
+     else:
+       assert isinstance(buf.arg, get_args(ConstType)), f"cannot create ConstBuffer with value {buf.arg}"
+       val = buf.arg
+     return LazyOp(BufferOps.CONST, (), ConstBuffer(val, buf.dtype, unbound_st))
+
+   # if we aren't fusing it, it's a load and we add it to the inputs
+   if buf.realized is not None or (buf in realizes and buf not in outputs):
+     unbound_st, st_var_vals = st.simplify().unbind()
+     var_vals.update(st_var_vals)
+     if buf in assign_targets:
+       # can only assign to contiguous read+write buffer
+       if not unbound_st.contiguous:
+         # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
+         if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
+                 ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
+           raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+                              +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+       return LazyOp(BufferOps.LOAD, (), MemBuffer(outputs.index(assign_targets[buf]), buf.dtype, unbound_st))
+     if buf not in inputs: inputs.append(buf)
+     return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outputs)+inputs.index(buf), buf.dtype, unbound_st))
+
+   # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
+   if buf.op is LoadOps.CONTIGUOUS:
+     assert buf in outputs
+     return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)
+   if buf.op is LoadOps.ASSIGN:
+     assert buf in outputs
+     assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
+     assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
+     return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)
+
+   # if it's a reduce, we have to change the shapetracker
+   if buf.op in ReduceOps:
+     assert st.contiguous, "ReduceOps late fusion must be contiguous"
+     st = ShapeTracker.from_shape(buf.srcs[0].shape)
+
+   # otherwise we fuse it like normal
+   cache[(buf, st)] = ret = \
+     LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outputs, var_vals, st, realizes, assign_targets, cache) for x in buf.srcs), buf.arg)
+   return ret
+
+ def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
+   """create a schedule item from a list of outputs"""
+   inputs: List[LazyBuffer] = []
+   ast: List[LazyOp] = []
+   var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
+   # single output AST
+   if (op:=(out:=outs[0]).op) in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY, LoadOps.VIEW}:
+     assert len(outs) == 1, f"can't schedule a group of {op}"
+     inputs = [x.base for x in out.srcs]
+     if getenv("USE_COPY_KERNEL") and op is LoadOps.COPY and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
+       rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
+       ast = [LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st))]
+     else: ast = [LazyOp(op, (), out.arg)]
+   # multi output AST
+   else:
+     assign_targets = {x.srcs[1]:x for x in outs if x.op is LoadOps.ASSIGN}
+     for i, out in enumerate(outs):
+       output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
+       output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
+       lop = _recursive_lazyop(out, inputs, outs, var_vals, output_st, realizes, assign_targets, cache={})
+       output_view, vv = output_view.simplify().unbind()
+       if vv: var_vals.update(vv)
+       ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
+   return _LBScheduleItem(tuple(ast), outs, tuple(inputs), var_vals)
+
+ # *** DAG creation: decide which LazyBuffers should realize ***
+
+ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None],
+                 simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
+   """recursively search the entire graph for all LazyBuffers, insert realizes after expands"""
+   if buf in allbufs or buf.base.realized is not None: return
+   if GRAPH: log_lazybuffer(buf, scheduled)
+   # view
+   if buf.base != buf:
+     # fuse some pads
+     if len(buf.st.views) == 1 and buf.st.views[-1].mask is not None and all_int(buf.base.st.shape) and \
+         prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
+       simple_pads.add(buf.base)
+     # realize all expands
+     elif prod(buf.base.st.shape) < prod(buf.st.shape):
+       if buf.base.op is UnaryOps.CAST and isinstance(buf.base.srcs[0].dtype, ImageDType) and isinstance(buf.base.arg, ImageDType):
+         pass # don't realize image to image casts. this is part of a larger problem
+       else:
+         realizes[buf.base] = None
+     # check all other pads for safe fusion
+     elif any(v.mask is not None for v in buf.st.views): simple_pads.add(buf.base)
+     return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
+   # base
+   allbufs[buf] = None
+   if buf.forced_realize: realizes[buf] = None
+   if buf.op in LoadOps: realizes[buf.base] = None
+   if buf.op is LoadOps.COPY:
+     assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
+     realizes[buf.srcs[0].base] = None
+   if buf.op is LoadOps.VIEW: realizes[buf.srcs[0].base] = None
+   for x in buf.srcs:
+     children[x.base][buf] = None
+     _recurse_lb(x, realizes, allbufs, simple_pads, children)
+
+ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
+   if buf in realizes or buf.realized is not None: return True
+   # NOTE: this broke to_image_idx and coder with JIT
+   if buf.op in UNSAFE_PAD_OPS: return False
+   return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
+
+ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
+                      realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Set[LazyBuffer]):
+   """recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
+   if tr in realizes:
+     # can only fuse contiguous
+     # max one reduceop per kernel
+     if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: group.add(r)
+     return group.add(tr)
+   for tr_next in children[tr]:
+     if tr_next.realized is None:
+       # max one reduceop per kernel
+       if tr_next.op in ReduceOps: return group.add(r)
+       # can only fuse contiguous
+       if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.add(r)
+       _recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group)
+
+ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], DefaultDict[LazyBuffer, int],
+                                                                           Dict[LazyBuffer, _LBScheduleItem]]:
+   """create a graph for realizing the outputs"""
+   # start by just realizing the buffers passed in
+   realizes: Dict[LazyBuffer, None] = {x.base:None for x in outs if x.base.realized is None}
+   allbufs: Dict[LazyBuffer, None] = {}
+   simple_pads: Set[LazyBuffer] = set()
+   children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
+   for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
+   assign_targets = {x.srcs[1]:x for x in realizes if x.op is LoadOps.ASSIGN and x not in seen and x.realized is None}
+
+   # check if we have to realize pads
+   for p in simple_pads:
+     if not _is_padding_okay(p, realizes):
+       realizes[p] = None
+
+   # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
+   reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
+   for r in allbufs:
+     if r.op not in ReduceOps or r in realizes: continue
+
+     group: Set[LazyBuffer] = set()
+     _recursive_group(r, r.st, r, children, realizes, reduce_for_op, group)
+     # max one reduceop per kernel
+     can_chase = all(tr not in reduce_for_op for tr in group)
+     # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
+     forced_realize = r in group
+     if not forced_realize and len(group) > 1:
+       # create a multi output kernel if the LazyBufferss can cleanly group
+       rc_parents, rc_children = deque(group), deque(group)
+       while rc_parents and not forced_realize:
+         # max one reduceop per kernel
+         if (p:=rc_parents.pop()).op in ReduceOps: forced_realize = True
+         else: rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
+       # search descendants of the reduceop that can cleanly group
+       realized_descendants: Set[LazyBuffer] = set()
+       while rc_children and not forced_realize:
+         if (c:=rc_children.pop()).op in ReduceOps or not c.st.contiguous or c.st.size != r.st.size or c in reduce_for_op:
+           realized_descendants.clear()
+           break
+         if c in realizes and c not in group: realized_descendants.add(c)
+         rc_children.extend(x for x in children[c] if x.realized is None and x.device == r.device)
+       group.update(realized_descendants)
+     # can only fuse assign if no other assign_target is used in the kernel
+     if not forced_realize and any(x.op is LoadOps.ASSIGN for x in group):
+       parents = deque((r, *group))
+       while parents and not forced_realize:
+         if (p:=parents.pop().base).realized or p in realizes:
+           if p in assign_targets and assign_targets[p] not in group: forced_realize, can_chase = True, False
+           continue
+         parents.extend(p.srcs)
+     if forced_realize:
+       tr = r
+       if can_chase:
+         # can chase this down to contiguous children
+         st = tr.st
+         while len(children[tr]) == 1:
+           tr_next = next(iter(children[tr]))
+           st_childs = dedup(s for s in tr_next.srcs if s.base is tr)
+           if len(st_childs) > 1: break
+           if st.size != st_childs[0].st.size: break
+           st = st + st_childs[0].st
+           if not st.contiguous or tr_next.op in ReduceOps: break
+           tr = tr_next
+         # don't cast to higher size before store (tr cannot be realized if forced_realize)
+         if tr.op is UnaryOps.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
+           tr = tr.srcs[0].base
+       reduce_for_op[tr] = r
+       realizes[tr] = None
+     else: reduce_for_op.update((tr, r) for tr in group)
+
+   output_groups: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
+   for buf in realizes:
+     if buf.realized is not None or buf.op is LoadOps.CONST or buf in seen: continue
+     output_groups[reduce_for_op[buf] if buf in reduce_for_op and MULTIOUTPUT else buf].append(buf)
+
+     # make things that can't be images not images
+     if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
+                                               not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
+       if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
+       buf.dtype = dtypes.float32
+       # hack the underlying buffer too
+       if buf.base is buf:
+         assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer"
+         buf.buffer.dtype = dtypes.float32
+         buf.buffer.options = None
+
+   # preschedule all buffers in realizes
+   prescheduled = {group[0]:_schedule_group(tuple(group), realizes, reduce_for_op) for group in output_groups.values()}
+   schedule_targets = {out:ps for ps in prescheduled.values() for out in ps.outputs}
+
+   graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
+   in_degree: DefaultDict[LazyBuffer, int] = defaultdict(int)
+   for key, lsi in prescheduled.items():
+     if key not in in_degree: in_degree[key] = 0
+     # realize outputs after all parents are realized
+     scheduled_parents = set(schedule_targets[x].outputs[0] for x in lsi.inputs if x in schedule_targets)
+     for x in scheduled_parents:
+       graph[x].append(key)
+       in_degree[key] += 1
+     # realize outputs before a parent is assigned to
+     parents_assigns = set(schedule_targets[assign_targets[x]].outputs[0] for x in lsi.inputs if x in assign_targets)
+     for assign in parents_assigns:
+       graph[key].append(assign)
+       in_degree[assign] += 1
+
+   return graph, in_degree, prescheduled
+
+ # *** DAG ordering: breadth first search ***
+
+ SCHEDULES: List = []
+ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+   if seen is None: seen = set()
+   graph, in_degree, prescheduled = _graph_schedule(outs, seen)
+   queue = deque(si for key, si in prescheduled.items() if in_degree[key] == 0)
+   schedule: List[ScheduleItem] = []
+   var_vals: Dict[Variable, int] = {}
+   kernel_number = GlobalCounters.kernel_count
+   while queue:
+     ps = queue.popleft()
+     for buf in ps.outputs: seen.add(buf)
+     if GRAPH:
+       kernel_number += 1
+       for out in ps.outputs: realized_lazybuffer(out, kernel_number)
+     var_vals = merge_dicts([var_vals, ps.var_vals])
+     for out in ps.outputs: del out.srcs # can only schedule once
+     schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0)))
+     if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
+     for x in graph[ps.outputs[0]]:
+       in_degree[x] -= 1
+       if in_degree[x] == 0: queue.append(prescheduled[x])
+
+   if SAVE_SCHEDULE:
+     def _save():
+       print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl"))
+       with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
+     if len(SCHEDULES) == 0: atexit.register(_save)
+     SCHEDULES.extend((ps.ast for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)])
+   # confirm everything was scheduled correctly
+   if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
+     raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
+   if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
+   return schedule, var_vals
+
+ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
+   schedule, var_vals = create_schedule_with_vars(outs, seen)
+   assert len(var_vals) == 0
+   return schedule
+
+ # *** memory planning ***
+
+ def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], debug_prefix="") -> Dict[Buffer, Buffer]:
+   if getenv("NO_MEMORY_PLANNER"): return {}
+   last_appearance = {}
+   for i,u in enumerate(buffers):
+     for buf in u: last_appearance[buf] = i
+
+   # LRU algorithm
+   assigned: Dict[Buffer, Buffer] = {}
+   local_cache: DefaultDict[Tuple[str, int, DType], List[Buffer]] = defaultdict(list)
+
+   def handle_buffer(buf):
+     key = (buf.device, buf.size, buf.dtype)
+     if buf not in assigned:
+       if len(ll:=local_cache[key]): assigned[buf] = ll.pop()
+       else: assigned[buf] = Buffer(*key)
+     if i == last_appearance[buf]:
+       if assigned[buf] not in local_cache[key]: local_cache[key].append(assigned[buf])
+
+   for i,u in enumerate(buffers):
+     for buf in u:
+       # all unallocated unparented buffers are fair game to replace
+       if buf.is_allocated() or buf.lb_refcount > 0: continue
+       # handle view buffers
+       if buf._base is not None:
+         assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf._base, buf._base), offset=buf.offset)
+       else:
+         handle_buffer(buf)
+
+   if DEBUG >= 1 and len(ak:=dedup(assigned.keys())) != len(av:=dedup(assigned.values())):
+     print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",
+           f"{len(ak)} -> {len(av)} bufs")
+   return assigned
+
+ def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
+   assigned = _internal_memory_planner([si.bufs for si in schedule])
+   return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs)) for si in schedule]
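
For orientation, the file above is the scheduler's whole pipeline: _graph_schedule decides which LazyBuffers must realize and groups them into _LBScheduleItems, create_schedule_with_vars toposorts them into ScheduleItems (create_schedule is the variable-free wrapper), and memory_planner rewrites the scheduled buffers so allocations get reused. A minimal usage sketch follows; Tensor.lazydata being a single LazyBuffer and run_schedule living in tinygrad.engine.realize are assumptions about the surrounding 0.9.1 API, not shown in this hunk:

from tinygrad import Tensor
from tinygrad.engine.schedule import create_schedule, memory_planner
from tinygrad.engine.realize import run_schedule  # assumed location of the executor

t = (Tensor.rand(16, 16) + 1).sum()       # builds a lazy graph, nothing runs yet
schedule = create_schedule([t.lazydata])  # DAG -> topologically ordered ScheduleItems
schedule = memory_planner(schedule)       # optional buffer-reuse pass
run_schedule(schedule)                    # compile and execute each item, roughly what Tensor.realize() does under the hood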
@@ -0,0 +1,199 @@
+ from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
+ import itertools, functools, random, math, time, multiprocessing, traceback, signal
+ from collections import defaultdict
+ from dataclasses import replace
+ from tinygrad.device import Device, Buffer, Compiler
+ from tinygrad.ops import MemBuffer
+ from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
+ from tinygrad.dtype import ImageDType
+ from tinygrad.codegen.linearizer import Linearizer
+ from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError
+ from tinygrad.codegen.uops import UOpGraph
+ from tinygrad.tensor import Tensor
+ from tinygrad.shape.symbolic import sym_infer
+ from tinygrad.engine.realize import CompiledRunner
+ from tinygrad.renderer import Program
+
+ actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
+ actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4,7] for axis in range(5)]
+ actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29] for axis in range(5)]
+ actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
+ actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]
+ if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
+ actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
+ actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
+ if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
+
+ def _get_test_global_size(global_size, max_global_size, var_vals):
+   test_global_size, factor = [sym_infer(sz, var_vals) for sz in global_size], 1
+   while prod(test_global_size) > max_global_size:
+     for j in range(len(global_size)-1,-1,-1):
+       if test_global_size[j] > 16:
+         test_global_size[j] //= 2
+         factor *= 2
+         break
+   return test_global_size, factor
+
+ def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"):
+   factor = 1
+   if p.global_size is not None and max_global_size is not None:
+     global_size, factor = _get_test_global_size(p.global_size, max_global_size, var_vals)
+     p = replace(p, global_size=global_size)
+   try: car = CompiledRunner(p, precompiled=lib)
+   except AssertionError: return [math.inf] * cnt
+   tms = []
+   input_bufs = [rawbufs[i] for i,_ in car.p.globals]
+   for _ in range(cnt):
+     if clear_l2:
+       with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
+     tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
+     if early_stop is not None and early_stop < tms[-1]: break
+   return tms
+
+ class TimeoutException(Exception): pass
+ def timeout_handler(signum, frame): raise TimeoutException()
+
+ def _try_compile_linearized_w_idx(x:Tuple[int,Linearizer], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
+   signal.signal(signal.SIGALRM, timeout_handler)
+   # set timeout
+   signal.alarm(getenv("BEAM_TIMEOUT_SEC", 10))
+   try:
+     x[1].linearize()
+     if len(x[1].uops.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
+     p = x[1].to_program()
+     st = time.perf_counter()
+     prog = compiler.compile(p.src)
+     et = time.perf_counter() - st
+     ret = (p, prog, et)
+   except RuntimeError:
+     if DEBUG >= 4: traceback.print_exc()
+     ret = None
+   except TimeoutException:
+     ret = None
+   except Exception as e:
+     if getenv("BEAM_STRICT_MODE"): raise e
+     ret = None
+   finally:
+     signal.alarm(0)
+   return x[0], ret
+
+ # workers should ignore ctrl c
+ def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
+
+ def _ensure_buffer_alloc(bufs:List[Buffer]) -> List[Buffer]: return [buf.ensure_allocated() for buf in bufs]
+
+ # *** external API ***
+
+ # get (scrap) buffers for timing the linearizer
+ def bufs_from_lin(lin:Linearizer, allocate:bool=True) -> List[Buffer]:
+   bufsts:DefaultDict[int, List[MemBuffer]] = defaultdict(list)
+   for x in lin.membufs: bufsts[x.idx].append(x)
+   rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
+   for k,lx in bufsts.items():
+     buf_size = prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.real_size() for y in lx)
+     if buf_size == 0: buf_size = 1 # create a size 1 buffer if no cell is accessed in kernel. # TODO: remove from kernel input in this case.
+     rawbufs[k] = Buffer(lin.opts.device, buf_size, lx[0].dtype).allocate() if allocate else Buffer(lin.opts.device, buf_size, lx[0].dtype)
+   assert all(r is not None for r in rawbufs)
+   return cast(List[Buffer], rawbufs)
+
+ # get dictionary of all possible actions
+ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
+   acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
+   for i,a in enumerate(actions):
+     if a.axis is not None and a.op is not OptOps.TC:
+       if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.amt and Opt(a.op, ax, 0) in actions): continue
+     lin2 = lin.copy()
+     try:
+       lin2.apply_opt(a)
+       up, lcl, tc_up = 1, 1, prod(tc.dims)//prod([x[1] for x in tc.threads]) if (tc:=lin2.tensor_core) else 1
+       for s,c in zip(lin2.full_shape, lin2.colors()):
+         if c in {"magenta", "yellow"}: up *= s
+         elif c in {"cyan", "green", "white"}: lcl *= s
+       if up//tc_up > max_up or lcl > max_lcl: continue
+       acted_lins[i+1] = lin2
+     except KernelOptError: pass
+   return acted_lins
+
+ beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
+ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=True) -> Linearizer:
+   global beam_pool
+   key = {"ast": lin.ast[0].key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
+   if not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
+     ret = lin.copy()
+     for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
+     return ret
+
+   beam: List[Tuple[Linearizer, float]] = [(lin, float("inf"))]
+   seen_libs = set()
+
+   default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV"} else 0
+   if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
+     beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
+
+   min_progress = getenv("BEAM_MIN_PROGRESS", 0.01)/1e6
+   if BEAM_DEBUG: print(f"BEAM_SEARCH:\n{lin.ast}")
+   if DEBUG >= 2: print(f" 0.00s: from 1 -> 1 actions {lin.colored_shape()}")
+
+   try:
+     rawbufs = _ensure_buffer_alloc(rawbufs)
+     var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
+     exiting, st = False, time.perf_counter()
+     dev = Device[lin.opts.device]
+     while not exiting:
+       acted_lins: List[Linearizer] = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
+       timed_lins: List[Tuple[Linearizer, float]] = []
+       _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
+       for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
+         if proc is None: continue
+         p, lib, compile_et = proc
+         if lib in seen_libs: continue
+         #print(acted_lins[i].colored_shape(), acted_lins[i].applied_opts) # for debugging BEAMs that segfault
+         seen_libs.add(lib)
+         try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
+         except RuntimeError: continue # for runtime issues
+         timed_lins.append((acted_lins[i], min(tms)))
+         if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(UOpGraph, p.uops).uops):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
+         elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
+
+       # done
+       opts = sorted(timed_lins, key=lambda x: x[1])
+       exiting = len(opts) == 0 or (opts[0][1] < min_progress) or (len(beam) > 0 and ((beam[0][1]-opts[0][1]) < min_progress))
+       if not exiting: beam = opts[:amt]
+       elif len(opts) > 0 and opts[0][1] < beam[0][1]: beam = opts[:1]
+       if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
+   except KeyboardInterrupt as e:
+     if beam_pool is not None: beam_pool.terminate()
+     raise e
+
+   if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
+   if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={beam[0][1]*1e6:0.2f} us, applied_opts={beam[0][0].applied_opts}")
+   return beam[0][0]
+
+ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buffer]) -> List[int]:
+   test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype).allocate(), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
+   MAX_WORKGROUP = 1024
+   local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
+   local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
+   def try_exec(local_size):
+     try: return clprg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True) # noqa: E501
+     except Exception: return float('inf')
+   ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
+   assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
+   return ret[1]
+
+ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
+   key = {"ast": lin.ast[0].key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
+          "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
+   if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
+
+   dev = Device[lin.opts.device]
+   assert dev.compiler is not None
+
+   rawbufs = _ensure_buffer_alloc(rawbufs)
+   var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
+   p = lin.to_program()
+   tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
+                       max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
+
+   if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
+   return min(tms)
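
The search functions above are what the BEAM kernel-tuning path uses: get_linearizer_actions enumerates legal Opts for a kernel, _time_program and time_linearizer benchmark a candidate, and beam_search keeps the best amt candidates per round. They can also be driven by hand on one kernel. A sketch under stated assumptions: the Linearizer(*ast, opts=...) constructor, the Device[...].renderer attribute, and the last ScheduleItem being the compute kernel are assumptions about the wider 0.9.1 API, not shown in this hunk:

from tinygrad import Tensor, Device
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import bufs_from_lin, beam_search, time_linearizer

out = Tensor.empty(64, 64) @ Tensor.empty(64, 64)                # a single GEMM
si = create_schedule([out.lazydata])[-1]                         # assumed: last item is the GEMM kernel
lin = Linearizer(*si.ast, opts=Device[Device.DEFAULT].renderer)  # assumed constructor and renderer arg

rawbufs = bufs_from_lin(lin)               # scratch buffers sized from the kernel's MemBuffers
baseline = time_linearizer(lin, rawbufs)   # seconds for the unoptimized kernel
best = beam_search(lin, rawbufs, amt=4)    # beam width 4 over the actions list defined above
print(best.applied_opts, time_linearizer(best, rawbufs))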