tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. tinygrad/__init__.py +6 -0
  2. tinygrad/codegen/kernel.py +572 -83
  3. tinygrad/codegen/linearizer.py +415 -395
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +183 -0
  6. tinygrad/dtype.py +113 -0
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +76 -55
  14. tinygrad/helpers.py +196 -89
  15. tinygrad/lazy.py +210 -371
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +202 -22
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +112 -32
  20. tinygrad/nn/state.py +136 -39
  21. tinygrad/ops.py +119 -202
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +353 -166
  25. tinygrad/renderer/llvmir.py +150 -138
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +81 -0
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +75 -0
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +24 -77
  43. tinygrad/runtime/ops_cuda.py +175 -89
  44. tinygrad/runtime/ops_disk.py +56 -33
  45. tinygrad/runtime/ops_gpu.py +92 -95
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +39 -60
  48. tinygrad/runtime/ops_metal.py +92 -74
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +86 -254
  53. tinygrad/shape/symbolic.py +166 -141
  54. tinygrad/shape/view.py +296 -0
  55. tinygrad/tensor.py +2619 -448
  56. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. tinygrad-0.9.0.dist-info/METADATA +227 -0
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/codegen/assembly.py +0 -190
  61. tinygrad/codegen/optimizer.py +0 -379
  62. tinygrad/codegen/search.py +0 -72
  63. tinygrad/graph.py +0 -83
  64. tinygrad/jit.py +0 -57
  65. tinygrad/nn/image.py +0 -100
  66. tinygrad/renderer/assembly_arm64.py +0 -169
  67. tinygrad/renderer/assembly_ptx.py +0 -98
  68. tinygrad/renderer/wgsl.py +0 -53
  69. tinygrad/runtime/lib.py +0 -113
  70. tinygrad/runtime/ops_cpu.py +0 -51
  71. tinygrad/runtime/ops_hip.py +0 -82
  72. tinygrad/runtime/ops_shm.py +0 -29
  73. tinygrad/runtime/ops_torch.py +0 -30
  74. tinygrad/runtime/ops_webgpu.py +0 -45
  75. tinygrad-0.7.0.dist-info/METADATA +0 -212
  76. tinygrad-0.7.0.dist-info/RECORD +0 -40
  77. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
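
Taken together, the list shows a restructuring of the package: scheduling, the JIT, and kernel execution move under tinygrad/engine/, generated ctypes bindings for the GPU backends land under tinygrad/runtime/autogen/, and tinygrad/mlops.py is renamed to tinygrad/function.py. A rough sketch of what that means for import paths (illustrative only; the top-level re-exports are an assumption based on the six lines added to tinygrad/__init__.py):

# 0.7.0-era modules that no longer exist in 0.9.0:
#   tinygrad/jit.py, tinygrad/graph.py, tinygrad/mlops.py, tinygrad/runtime/lib.py
# 0.9.0 layout, per the file list above:
from tinygrad import Tensor, TinyJit, dtypes             # assumed top-level re-exports (tinygrad/__init__.py)
from tinygrad.engine.jit import TinyJit                   # engine/jit.py replaces tinygrad/jit.py
from tinygrad.engine.schedule import create_schedule, memory_planner
from tinygrad.engine.search import beam_search, time_linearizer
import tinygrad.function as function                      # renamed from tinygrad/mlops.py

The two hunks reproduced below are the new tinygrad/engine/schedule.py and tinygrad/engine/search.py.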
tinygrad/engine/schedule.py (new file)
@@ -0,0 +1,362 @@
+ import sys, pickle, atexit
+ from collections import defaultdict, deque
+ from dataclasses import dataclass
+ from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union
+ from tinygrad.ops import LoadOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
+ from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
+ from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, prod, dedup, all_int, merge_dicts, getenv
+ from tinygrad.shape.symbolic import Variable
+ from tinygrad.dtype import ImageDType, dtypes, DType
+ from tinygrad.lazy import LazyBuffer
+ from tinygrad.shape.shapetracker import ShapeTracker
+ from tinygrad.device import Buffer
+
+ # creation can recurse a lot
+ sys.setrecursionlimit(10000)
+
+ # optionally log the ops to disk
+ logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
+
+ # *** ScheduleItem return type ***
+
+ @dataclass(frozen=True)
+ class ScheduleItem:
+   ast: Tuple[LazyOp, ...]
+   bufs: Tuple[Buffer, ...]
+   @property
+   def outputs(self) -> Tuple[Buffer, ...]:
+     """Read/write or write only buffers in the schedule."""
+     return self.bufs[:len(self.ast)]
+   @property
+   def inputs(self) -> Tuple[Buffer, ...]:
+     """Read only buffers in the schedule."""
+     return self.bufs[len(self.ast):]
+
+ # *** DAG transformation: List[LazyBuffer] -> ScheduleItem ***
+
+ # TODO: it's unfortunate this needs to exist, but because of ASSIGN, we have to retain the LazyBuffer structure until post toposort
+ @dataclass(frozen=True)
+ class _LBScheduleItem:
+   ast: Tuple[LazyOp, ...]
+   outputs: Tuple[LazyBuffer, ...]
+   inputs: Tuple[LazyBuffer, ...]
+   var_vals: Dict[Variable, int]
+
+ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
+                       realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer], cache) -> LazyOp:
+   """recursively create a lazyop"""
+   if (buf, st) in cache: return cache[(buf, st)]
+   if buf != buf.base:
+     st = buf.st + st
+     buf = buf.base
+   # all buffers here are base now
+   assert buf.op is not None
+
+   # consts are always fused and generated
+   if buf.op is LoadOps.CONST:
+     unbound_st, st_var_vals = st.simplify().unbind()
+     var_vals.update(st_var_vals)
+     if isinstance(buf.arg, Variable): var_vals.__setitem__(*buf.arg.unbind())
+     return LazyOp(BufferOps.CONST, (), ConstBuffer(buf.arg, buf.dtype, unbound_st))
+
+   # if we aren't fusing it, it's a load and we add it to the inputs
+   if buf.realized is not None or (buf in realizes and buf not in outputs):
+     unbound_st, st_var_vals = st.simplify().unbind()
+     var_vals.update(st_var_vals)
+     if buf in assign_targets:
+       # can only assign to contiguous read+write buffer
+       if not unbound_st.contiguous:
+         # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
+         if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
+             ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
+           raise RuntimeError(f"must be contiguous for assign {unbound_st}")
+       return LazyOp(BufferOps.LOAD, (), MemBuffer(outputs.index(assign_targets[buf]), buf.dtype, unbound_st))
+     if buf not in inputs: inputs.append(buf)
+     return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outputs)+inputs.index(buf), buf.dtype, unbound_st))
+
+   # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
+   if buf.op is LoadOps.CONTIGUOUS:
+     assert buf in outputs
+     return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)
+   if buf.op is LoadOps.ASSIGN:
+     assert buf in outputs
+     assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
+     assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
+     return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)
+
+   # if it's a reduce, we have to change the shapetracker
+   if buf.op in ReduceOps:
+     assert st.contiguous, "ReduceOps late fusion must be contiguous"
+     st = ShapeTracker.from_shape(buf.srcs[0].shape)
+
+   # otherwise we fuse it like normal
+   cache[(buf, st)] = ret = \
+     LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outputs, var_vals, st, realizes, assign_targets, cache) for x in buf.srcs), buf.arg)
+   return ret
+
+ def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
+   """create a schedule item from a list of outputs"""
+   inputs: List[LazyBuffer] = []
+   ast: List[LazyOp] = []
+   var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
+   # single output AST
+   if (op:=(out:=outs[0]).op) in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY, LoadOps.VIEW}:
+     assert len(outs) == 1, f"can't schedule a group of {op}"
+     inputs = [x.base for x in out.srcs]
+     if getenv("USE_COPY_KERNEL") and op is LoadOps.COPY and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
+       rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
+       ast = [LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st))]
+     else: ast = [LazyOp(op, (), out.arg)]
+   # multi output AST
+   else:
+     assign_targets = {x.srcs[1]:x for x in outs if x.op is LoadOps.ASSIGN}
+     for i, out in enumerate(outs):
+       output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
+       output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
+       lop = _recursive_lazyop(out, inputs, outs, var_vals, output_st, realizes, assign_targets, cache={})
+       output_view, vv = output_view.simplify().unbind()
+       if vv: var_vals.update(vv)
+       ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
+   return _LBScheduleItem(tuple(ast), outs, tuple(inputs), var_vals)
+
+ # *** DAG creation: decide which LazyBuffers should realize ***
+
+ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None],
+                 simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
+   """recursively search the entire graph for all LazyBuffers, insert realizes after expands"""
+   if buf in allbufs or buf.base.realized is not None: return
+   if GRAPH: log_lazybuffer(buf, scheduled)
+   # view
+   if buf.base != buf:
+     # fuse some pads
+     if len(buf.st.views) == 1 and buf.st.views[-1].mask is not None and all_int(buf.base.st.shape) and \
+         prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
+       simple_pads.add(buf.base)
+     # realize all expands
+     elif prod(buf.base.st.shape) < prod(buf.st.shape):
+       if buf.base.op is UnaryOps.CAST and isinstance(buf.base.srcs[0].dtype, ImageDType) and isinstance(buf.base.arg, ImageDType):
+         pass # don't realize image to image casts. this is part of a larger problem
+       else:
+         realizes[buf.base] = None
+     return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
+   # base
+   allbufs[buf] = None
+   if buf.forced_realize: realizes[buf] = None
+   if buf.op in LoadOps: realizes[buf.base] = None
+   if buf.op is LoadOps.COPY:
+     assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
+     realizes[buf.srcs[0].base] = None
+   if buf.op is LoadOps.VIEW: realizes[buf.srcs[0].base] = None
+   for x in buf.srcs:
+     children[x.base][buf] = None
+     _recurse_lb(x, realizes, allbufs, simple_pads, children)
+
+ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
+   if buf in realizes or buf.realized is not None: return True
+   # NOTE: this broke to_image_idx and coder with JIT
+   if buf.op in UNSAFE_PAD_OPS: return False
+   return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
+
+ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
+                      realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Set[LazyBuffer]):
+   """recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
+   if tr in realizes:
+     # can only fuse contiguous
+     # max one reduceop per kernel
+     if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: group.add(r)
+     return group.add(tr)
+   for tr_next in children[tr]:
+     if tr_next.realized is None:
+       # max one reduceop per kernel
+       if tr_next.op in ReduceOps: return group.add(r)
+       # can only fuse contiguous
+       if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.add(r)
+       _recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group)
+
+ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], DefaultDict[LazyBuffer, int],
+                                                                           Dict[LazyBuffer, _LBScheduleItem]]:
+   """create a graph for realizing the outputs"""
+   # start by just realizing the buffers passed in
+   realizes: Dict[LazyBuffer, None] = {x.base:None for x in outs if x.base.realized is None}
+   allbufs: Dict[LazyBuffer, None] = {}
+   simple_pads: Set[LazyBuffer] = set()
+   children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
+   for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
+   assign_targets = {x.srcs[1]:x for x in realizes if x.op is LoadOps.ASSIGN and x not in seen and x.realized is None}
+
+   # check if we have to realize pads
+   for p in simple_pads:
+     if not _is_padding_okay(p, realizes):
+       realizes[p] = None
+
+   # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
+   reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
+   for r in allbufs:
+     if r.op not in ReduceOps or r in realizes: continue
+
+     group: Set[LazyBuffer] = set()
+     _recursive_group(r, r.st, r, children, realizes, reduce_for_op, group)
+     # max one reduceop per kernel
+     can_chase = all(tr not in reduce_for_op for tr in group)
+     # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
+     forced_realize = r in group
+     if not forced_realize and len(group) > 1:
+       # create a multi output kernel if the LazyBufferss can cleanly group
+       rc_parents, rc_children = deque(group), deque(group)
+       while rc_parents and not forced_realize:
+         # max one reduceop per kernel
+         if (p:=rc_parents.pop()).op in ReduceOps: forced_realize = True
+         else: rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
+       # search descendants of the reduceop that can cleanly group
+       realized_descendants: Set[LazyBuffer] = set()
+       while rc_children and not forced_realize:
+         if (c:=rc_children.pop()).op in ReduceOps or not c.st.contiguous or c.st.size != r.st.size or c in reduce_for_op:
+           realized_descendants.clear()
+           break
+         if c in realizes and c not in group: realized_descendants.add(c)
+         rc_children.extend(x for x in children[c] if x.realized is None and x.device == r.device)
+       group.update(realized_descendants)
+     # can only fuse assign if no other assign_target is used in the kernel
+     if not forced_realize and any(x.op is LoadOps.ASSIGN for x in group):
+       parents = deque((r, *group))
+       while parents and not forced_realize:
+         if (p:=parents.pop().base).realized or p in realizes:
+           if p in assign_targets and assign_targets[p] not in group: forced_realize, can_chase = True, False
+           continue
+         parents.extend(p.srcs)
+     if forced_realize:
+       tr = r
+       if can_chase:
+         # can chase this down to contiguous children
+         st = tr.st
+         while len(children[tr]) == 1:
+           tr_next = next(iter(children[tr]))
+           st_childs = dedup(s for s in tr_next.srcs if s.base is tr)
+           if len(st_childs) > 1: break
+           if st.size != st_childs[0].st.size: break
+           st = st + st_childs[0].st
+           if not st.contiguous or tr_next.op in ReduceOps: break
+           tr = tr_next
+         # don't cast to higher size before store (tr cannot be realized if forced_realize)
+         if tr.op is UnaryOps.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
+           tr = tr.srcs[0].base
+       reduce_for_op[tr] = r
+       realizes[tr] = None
+     else: reduce_for_op.update((tr, r) for tr in group)
+
+   output_groups: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
+   for buf in realizes:
+     if buf.realized is not None or buf.op is LoadOps.CONST or buf in seen: continue
+     output_groups[reduce_for_op[buf] if buf in reduce_for_op and MULTIOUTPUT else buf].append(buf)
+
+     # make things that can't be images not images
+     if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
+                                               not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
+       if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
+       buf.dtype = dtypes.float32
+       # hack the underlying buffer too
+       if buf.base is buf:
+         assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer"
+         buf.buffer.dtype = dtypes.float32
+         buf.buffer.options = None
+
+   # preschedule all buffers in realizes
+   prescheduled = {group[0]:_schedule_group(tuple(group), realizes, reduce_for_op) for group in output_groups.values()}
+   schedule_targets = {out:ps for ps in prescheduled.values() for out in ps.outputs}
+
+   graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
+   in_degree: DefaultDict[LazyBuffer, int] = defaultdict(int)
+   for key, lsi in prescheduled.items():
+     if key not in in_degree: in_degree[key] = 0
+     # realize outputs after all parents are realized
+     scheduled_parents = set(schedule_targets[x].outputs[0] for x in lsi.inputs if x in schedule_targets)
+     for x in scheduled_parents:
+       graph[x].append(key)
+       in_degree[key] += 1
+     # realize outputs before a parent is assigned to
+     parents_assigns = set(schedule_targets[assign_targets[x]].outputs[0] for x in lsi.inputs if x in assign_targets)
+     for assign in parents_assigns:
+       graph[key].append(assign)
+       in_degree[assign] += 1
+
+   return graph, in_degree, prescheduled
+
+ # *** DAG ordering: breadth first search ***
+
+ SCHEDULES: List = []
+ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+   if seen is None: seen = set()
+   graph, in_degree, prescheduled = _graph_schedule(outs, seen)
+   queue = deque(si for key, si in prescheduled.items() if in_degree[key] == 0)
+   schedule: List[ScheduleItem] = []
+   var_vals: Dict[Variable, int] = {}
+   kernel_number = GlobalCounters.kernel_count
+   while queue:
+     ps = queue.popleft()
+     for buf in ps.outputs: seen.add(buf)
+     if GRAPH:
+       kernel_number += 1
+       for out in ps.outputs: realized_lazybuffer(out, kernel_number)
+     var_vals = merge_dicts([var_vals, ps.var_vals])
+     for out in ps.outputs: del out.srcs # can only schedule once
+     schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0)))
+     if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
+     for x in graph[ps.outputs[0]]:
+       in_degree[x] -= 1
+       if in_degree[x] == 0: queue.append(prescheduled[x])
+
+   if SAVE_SCHEDULE:
+     def _save():
+       print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl"))
+       pickle.dump(SCHEDULES, open(fp, "wb"))
+     if len(SCHEDULES) == 0: atexit.register(_save)
+     SCHEDULES.extend((ps.ast for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)])
+   # confirm everything was scheduled correctly
+   if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
+     raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
+   if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
+   return schedule, var_vals
+
+ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
+   schedule, var_vals = create_schedule_with_vars(outs, seen)
+   assert len(var_vals) == 0
+   return schedule
+
+ # *** memory planning ***
+
+ def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], debug_prefix="") -> Dict[Buffer, Buffer]:
+   if getenv("NO_MEMORY_PLANNER"): return {}
+   last_appearance = {}
+   for i,u in enumerate(buffers):
+     for buf in u: last_appearance[buf] = i
+
+   # LRU algorithm
+   assigned: Dict[Buffer, Buffer] = {}
+   local_cache: DefaultDict[Tuple[str, int, DType], List[Buffer]] = defaultdict(list)
+
+   def handle_buffer(buf):
+     key = (buf.device, buf.size, buf.dtype)
+     if buf not in assigned:
+       if len(ll:=local_cache[key]): assigned[buf] = ll.pop()
+       else: assigned[buf] = Buffer(*key)
+     if i == last_appearance[buf]:
+       if assigned[buf] not in local_cache[key]: local_cache[key].append(assigned[buf])
+
+   for i,u in enumerate(buffers):
+     for buf in u:
+       # all unallocated unparented buffers are fair game to replace
+       if buf.is_allocated() or buf.lb_refcount > 0: continue
+       # handle view buffers
+       if buf._base is not None:
+         assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf._base, buf._base), offset=buf.offset)
+       else:
+         handle_buffer(buf)
+
+   if DEBUG >= 1 and len(ak:=dedup(assigned.keys())) != len(av:=dedup(assigned.values())):
+     print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",
+           f"{len(ak)} -> {len(av)} bufs")
+   return assigned
+
+ def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
+   assigned = _internal_memory_planner([si.bufs for si in schedule])
+   return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs)) for si in schedule]
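
Not part of the diff, but for orientation: a minimal sketch of how the scheduler defined above is driven, assuming a single-device Tensor whose lazydata is a plain LazyBuffer (multi-device tensors hold a MultiLazyBuffer and pass its component buffers instead).

from tinygrad import Tensor
from tinygrad.engine.schedule import create_schedule, memory_planner

out = (Tensor.empty(16, 16) + 1).sum()        # build a lazy graph; nothing runs yet
schedule = create_schedule([out.lazydata])    # toposorted List[ScheduleItem]
schedule = memory_planner(schedule)           # reuse buffers whose lifetimes don't overlap
for si in schedule:
  print(si.ast[0].op, len(si.outputs), len(si.inputs))

Compiling each ScheduleItem's ast and launching it on its Buffers is handled by tinygrad/engine/realize.py, which is also new in this release.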
tinygrad/engine/search.py (new file)
@@ -0,0 +1,196 @@
+ from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
+ import itertools, functools, random, math, time, multiprocessing, traceback, signal
+ from collections import defaultdict
+ from dataclasses import replace
+ from tinygrad.device import Device, Buffer, Compiler
+ from tinygrad.ops import MemBuffer
+ from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
+ from tinygrad.dtype import ImageDType
+ from tinygrad.codegen.linearizer import Linearizer
+ from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError
+ from tinygrad.codegen.uops import UOpGraph
+ from tinygrad.tensor import Tensor
+ from tinygrad.shape.symbolic import sym_infer
+ from tinygrad.engine.realize import CompiledRunner
+ from tinygrad.renderer import Program
+
+ actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
+ actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4,7] for axis in range(5)]
+ actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29] for axis in range(5)]
+ actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
+ actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]
+ if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
+ actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
+ actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
+ if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
+
+ def _get_test_global_size(global_size, max_global_size, var_vals):
+   test_global_size, factor = [sym_infer(sz, var_vals) for sz in global_size], 1
+   while prod(test_global_size) > max_global_size:
+     for j in range(len(global_size)-1,-1,-1):
+       if test_global_size[j] > 16:
+         test_global_size[j] //= 2
+         factor *= 2
+         break
+   return test_global_size, factor
+
+ def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"):
+   factor = 1
+   if p.global_size is not None and max_global_size is not None:
+     global_size, factor = _get_test_global_size(p.global_size, max_global_size, var_vals)
+     p = replace(p, global_size=global_size)
+   try: car = CompiledRunner(p, precompiled=lib)
+   except AssertionError: return [math.inf] * cnt
+   tms = []
+   input_bufs = [rawbufs[i] for i,_ in car.p.globals]
+   for _ in range(cnt):
+     if clear_l2:
+       with Context(DEBUG=0, BEAM=0, CACHECOLLECTING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
+     tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
+     if early_stop is not None and early_stop < tms[-1]: break
+   return tms
+
+ class TimeoutException(Exception): pass
+ def timeout_handler(signum, frame): raise TimeoutException()
+
+ def _try_compile_linearized_w_idx(x:Tuple[int,Linearizer], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
+   signal.signal(signal.SIGALRM, timeout_handler)
+   # set timeout
+   signal.alarm(getenv("BEAM_TIMEOUT_SEC", 10))
+   try:
+     x[1].linearize()
+     if len(x[1].uops.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
+     p = x[1].to_program()
+     st = time.perf_counter()
+     prog = compiler.compile(p.src)
+     et = time.perf_counter() - st
+     ret = (p, prog, et)
+   except RuntimeError:
+     if DEBUG >= 4: traceback.print_exc()
+     ret = None
+   except TimeoutException:
+     ret = None
+   finally:
+     signal.alarm(0)
+   return x[0], ret
+
+ # workers should ignore ctrl c
+ def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
+
+ def _ensure_buffer_alloc(bufs:List[Buffer]) -> List[Buffer]: return [buf.ensure_allocated() for buf in bufs]
+
+ # *** external API ***
+
+ # get (scrap) buffers for timing the linearizer
+ def bufs_from_lin(lin:Linearizer, allocate:bool=True) -> List[Buffer]:
+   bufsts:DefaultDict[int, List[MemBuffer]] = defaultdict(list)
+   for x in lin.membufs: bufsts[x.idx].append(x)
+   rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
+   for k,lx in bufsts.items():
+     buf_size = prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.real_size() for y in lx)
+     if buf_size == 0: buf_size = 1 # create a size 1 buffer if no cell is accessed in kernel. # TODO: remove from kernel input in this case.
+     rawbufs[k] = Buffer(lin.opts.device, buf_size, lx[0].dtype).allocate() if allocate else Buffer(lin.opts.device, buf_size, lx[0].dtype)
+   assert all(r is not None for r in rawbufs)
+   return cast(List[Buffer], rawbufs)
+
+ # get dictionary of all possible actions
+ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
+   acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
+   for i,a in enumerate(actions):
+     if a.axis is not None and a.op is not OptOps.TC:
+       if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.amt and Opt(a.op, ax, 0) in actions): continue
+     lin2 = lin.copy()
+     try:
+       lin2.apply_opt(a)
+       up, lcl, tc_up = 1, 1, prod(tc.dims)//prod([x[1] for x in tc.threads]) if (tc:=lin2.tensor_core) else 1
+       for s,c in zip(lin2.full_shape, lin2.colors()):
+         if c in {"magenta", "yellow"}: up *= s
+         elif c in {"cyan", "green", "white"}: lcl *= s
+       if up//tc_up > max_up or lcl > max_lcl: continue
+       acted_lins[i+1] = lin2
+     except KernelOptError: pass
+   return acted_lins
+
+ beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
+ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=True) -> Linearizer:
+   global beam_pool
+   key = {"ast": lin.ast[0].key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
+   if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
+     ret = lin.copy()
+     for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
+     return ret
+
+   beam: List[Tuple[Linearizer, float]] = [(lin, float("inf"))]
+   seen_libs = set()
+
+   default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "HSA", "AMD", "NV"} else 0
+   if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
+     beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
+
+   min_progress = getenv("BEAM_MIN_PROGRESS", 0.01)/1e6
+   if BEAM_DEBUG: print(f"BEAM_SEARCH:\n{lin.ast}")
+   if DEBUG >= 2: print(f" 0.00s: from 1 -> 1 actions {lin.colored_shape()}")
+
+   try:
+     rawbufs = _ensure_buffer_alloc(rawbufs)
+     var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
+     exiting, st = False, time.perf_counter()
+     dev = Device[lin.opts.device]
+     while not exiting:
+       acted_lins: List[Linearizer] = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
+       timed_lins: List[Tuple[Linearizer, float]] = []
+       _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
+       for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
+         if proc is None: continue
+         p, lib, compile_et = proc
+         if lib in seen_libs: continue
+         #print(acted_lins[i].colored_shape(), acted_lins[i].applied_opts) # for debugging BEAMs that segfault
+         seen_libs.add(lib)
+         try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
+         except RuntimeError: continue # for runtime issues
+         timed_lins.append((acted_lins[i], min(tms)))
+         if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(UOpGraph, p.uops).uops):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
+         elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
+
+       # done
+       opts = sorted(timed_lins, key=lambda x: x[1])
+       exiting = len(opts) == 0 or (opts[0][1] < min_progress) or (len(beam) > 0 and ((beam[0][1]-opts[0][1]) < min_progress))
+       if not exiting: beam = opts[:amt]
+       elif len(opts) > 0 and opts[0][1] < beam[0][1]: beam = opts[:1]
+       if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
+   except KeyboardInterrupt as e:
+     if beam_pool is not None: beam_pool.terminate()
+     raise e
+
+   if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
+   if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={beam[0][1]*1e6:0.2f} us, applied_opts={beam[0][0].applied_opts}")
+   return beam[0][0]
+
+ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buffer]) -> List[int]:
+   test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype).allocate(), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
+   MAX_WORKGROUP = 1024
+   local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
+   local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
+   def try_exec(local_size):
+     try: return clprg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True) # noqa: E501
+     except Exception: return float('inf')
+   ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
+   assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
+   return ret[1]
+
+ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
+   key = {"ast": lin.ast[0].key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
+          "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
+   if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
+
+   dev = Device[lin.opts.device]
+   assert dev.compiler is not None
+
+   rawbufs = _ensure_buffer_alloc(rawbufs)
+   var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
+   p = lin.to_program()
+   tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
+                       max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
+
+   if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
+   return min(tms)
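
Likewise, not part of the diff: a rough sketch of how the search helpers above compose, assuming a Linearizer can be built from a ScheduleItem's ast plus the device renderer (mirroring what the BEAM path in tinygrad/engine/realize.py does); the exact Linearizer constructor used here is an assumption.

from tinygrad import Tensor
from tinygrad.device import Device
from tinygrad.ops import LoadOps
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import bufs_from_lin, beam_search, time_linearizer

out = (Tensor.empty(64, 64) @ Tensor.empty(64, 64)).sum()
si = [s for s in create_schedule([out.lazydata]) if s.ast[0].op not in LoadOps][-1]  # last compute kernel

lin = Linearizer(*si.ast, opts=Device[Device.DEFAULT].renderer)  # assumed constructor
rawbufs = bufs_from_lin(lin)              # scratch Buffers sized from the kernel's membufs
best = beam_search(lin, rawbufs, amt=4)   # try `actions`, keep the 4 fastest candidates per step
print(best.applied_opts, time_linearizer(best, rawbufs))

This is the same machinery that the BEAM environment variable enables inside the engine at kernel-compile time.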