tinygrad 0.9.1-py3-none-any.whl → 0.9.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. tinygrad/codegen/kernel.py +248 -115
  2. tinygrad/codegen/lowerer.py +215 -0
  3. tinygrad/codegen/transcendental.py +310 -0
  4. tinygrad/codegen/uopgraph.py +622 -0
  5. tinygrad/codegen/uops.py +235 -393
  6. tinygrad/device.py +428 -69
  7. tinygrad/dtype.py +18 -4
  8. tinygrad/engine/graph.py +19 -32
  9. tinygrad/engine/jit.py +148 -70
  10. tinygrad/engine/realize.py +127 -51
  11. tinygrad/engine/schedule.py +259 -216
  12. tinygrad/engine/search.py +29 -22
  13. tinygrad/function.py +9 -0
  14. tinygrad/helpers.py +87 -49
  15. tinygrad/lazy.py +34 -35
  16. tinygrad/multi.py +41 -36
  17. tinygrad/nn/__init__.py +39 -22
  18. tinygrad/nn/state.py +3 -3
  19. tinygrad/ops.py +63 -62
  20. tinygrad/renderer/__init__.py +43 -21
  21. tinygrad/renderer/assembly.py +104 -106
  22. tinygrad/renderer/cstyle.py +87 -60
  23. tinygrad/renderer/llvmir.py +21 -30
  24. tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
  25. tinygrad/runtime/autogen/cuda.py +6 -162
  26. tinygrad/runtime/autogen/kfd.py +32 -0
  27. tinygrad/runtime/autogen/libc.py +4260 -0
  28. tinygrad/runtime/autogen/nvrtc.py +579 -0
  29. tinygrad/runtime/graph/clang.py +2 -2
  30. tinygrad/runtime/graph/cuda.py +8 -11
  31. tinygrad/runtime/graph/hcq.py +120 -107
  32. tinygrad/runtime/graph/metal.py +18 -15
  33. tinygrad/runtime/ops_amd.py +197 -305
  34. tinygrad/runtime/ops_clang.py +2 -2
  35. tinygrad/runtime/ops_cuda.py +36 -94
  36. tinygrad/runtime/ops_disk.py +3 -7
  37. tinygrad/runtime/ops_gpu.py +4 -2
  38. tinygrad/runtime/ops_hip.py +70 -0
  39. tinygrad/runtime/ops_metal.py +38 -27
  40. tinygrad/runtime/ops_nv.py +283 -363
  41. tinygrad/runtime/ops_python.py +26 -30
  42. tinygrad/runtime/support/compiler_cuda.py +78 -0
  43. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
  44. tinygrad/runtime/support/elf.py +38 -0
  45. tinygrad/shape/shapetracker.py +5 -14
  46. tinygrad/shape/symbolic.py +4 -8
  47. tinygrad/shape/view.py +34 -22
  48. tinygrad/tensor.py +399 -97
  49. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
  50. tinygrad-0.9.2.dist-info/RECORD +70 -0
  51. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
  52. tinygrad/codegen/linearizer.py +0 -528
  53. tinygrad-0.9.1.dist-info/RECORD +0 -63
  54. /tinygrad/runtime/{driver → support}/__init__.py +0 -0
  55. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
  56. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/engine/schedule.py
@@ -1,15 +1,17 @@
-import sys, pickle, atexit
+import sys, pickle, atexit, importlib, contextlib
 from collections import defaultdict, deque
-from dataclasses import dataclass
-from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union, get_args
-from tinygrad.ops import LoadOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
+from dataclasses import dataclass, field
+from typing import Tuple, List, Dict, Optional, Set, DefaultDict, get_args
+from tinygrad.ops import MetaOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps, reduce_st
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
-from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv
-from tinygrad.shape.symbolic import Variable
-from tinygrad.dtype import ConstType, ImageDType, dtypes, DType
+from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \
+  GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
+from tinygrad.shape.symbolic import Variable, sint
+from tinygrad.dtype import ConstType, ImageDType, dtypes
 from tinygrad.lazy import LazyBuffer
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Buffer
+from tinygrad.shape.view import View, strides_for_shape
 
 # creation can recurse a lot
 sys.setrecursionlimit(10000)
@@ -21,143 +23,198 @@ logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
 
 @dataclass(frozen=True)
 class ScheduleItem:
-  ast: Tuple[LazyOp, ...]
+  ast: LazyOp
   bufs: Tuple[Buffer, ...]
+  metadata: Optional[List[Metadata]] = None
   @property
   def outputs(self) -> Tuple[Buffer, ...]:
     """Read/write or write only buffers in the schedule."""
-    return self.bufs[:len(self.ast)]
+    return self.bufs[:len(self.ast.src)] if self.ast.op is MetaOps.KERNEL else self.bufs[0:1]
   @property
   def inputs(self) -> Tuple[Buffer, ...]:
     """Read only buffers in the schedule."""
-    return self.bufs[len(self.ast):]
+    return self.bufs[len(self.ast.src):] if self.ast.op is MetaOps.KERNEL else self.bufs[1:]
+
+@dataclass(frozen=True)
+class LBScheduleItem:
+  ast: LazyOp
+  outputs: List[LazyBuffer]
+  inputs: List[LazyBuffer]
+  var_vals: Dict[Variable, int] = field(default_factory=dict)
+  metadata: List[Metadata] = field(default_factory=list)
+  def __hash__(self):
+    """The unique identifier of a schedule item in the toposort."""
+    return hash(self.outputs[0])
 
 # *** DAG transformation: List[LazyBuffer] -> ScheduleItem ***
 
-# TODO: it's unfortunate this needs to exist, but because of ASSIGN, we have to retain the LazyBuffer structure until post toposort
-@dataclass(frozen=True)
-class _LBScheduleItem:
-  ast: Tuple[LazyOp, ...]
-  outputs: Tuple[LazyBuffer, ...]
-  inputs: Tuple[LazyBuffer, ...]
-  var_vals: Dict[Variable, int]
-
-def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
-                      realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer], cache) -> LazyOp:
+def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:Dict[LazyBuffer, int],
+                      realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer],
+                      reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]],
+                      cache:Dict[Tuple[LazyBuffer, ShapeTracker], LazyOp]) -> LazyOp:
   """recursively create a lazyop"""
+  if buf is not buf.base: st, buf = buf.st+st, buf.base
   if (buf, st) in cache: return cache[(buf, st)]
-  if buf != buf.base:
-    st = buf.st + st
-    buf = buf.base
-  # all buffers here are base now
-  assert buf.op is not None
-
-  # consts are always fused and generated
-  if buf.op is LoadOps.CONST:
-    unbound_st, st_var_vals = st.simplify().unbind()
-    var_vals.update(st_var_vals)
-    if isinstance(buf.arg, Variable):
-      val, var_val = buf.arg.unbind()
-      var_vals.__setitem__(val, var_val)
-    else:
-      assert isinstance(buf.arg, get_args(ConstType)), f"cannot create ConstBuffer with value {buf.arg}"
-      val = buf.arg
-    return LazyOp(BufferOps.CONST, (), ConstBuffer(val, buf.dtype, unbound_st))
-
-  # if we aren't fusing it, it's a load and we add it to the inputs
+  assert buf.op is not None, "base must be a base itself"
+
+  # buffer ops define ShapeTracker
   if buf.realized is not None or (buf in realizes and buf not in outputs):
     unbound_st, st_var_vals = st.simplify().unbind()
     var_vals.update(st_var_vals)
+    # if it's a const, we generate it
+    if buf.op is MetaOps.CONST:
+      if isinstance(val:=buf.arg, Variable):
+        val, var_val = val.unbind()
+        var_vals[val] = var_val
+      else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}"
+      return LazyOp(BufferOps.CONST, (), ConstBuffer(val, buf.dtype, unbound_st))
+    # otherwise, it's a load and we add it to the inputs
     if buf in assign_targets:
-      # can only assign to contiguous read+write buffer
-      if not unbound_st.contiguous:
-        # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
-        if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
-                ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
-          raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
-                             +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
-      return LazyOp(BufferOps.LOAD, (), MemBuffer(outputs.index(assign_targets[buf]), buf.dtype, unbound_st))
-    if buf not in inputs: inputs.append(buf)
-    return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outputs)+inputs.index(buf), buf.dtype, unbound_st))
-
-  # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
-  if buf.op is LoadOps.CONTIGUOUS:
-    assert buf in outputs
-    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)
-  if buf.op is LoadOps.ASSIGN:
-    assert buf in outputs
-    assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
-    assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
-    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)
-
-  # if it's a reduce, we have to change the shapetracker
+      # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
+      if unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and\
+          ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
+        return LazyOp(BufferOps.LOAD, (), MemBuffer(outputs.index(assign_targets[buf]), buf.dtype, unbound_st))
+      raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+                         +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+    return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outputs)+inputs.setdefault(buf, len(inputs)), buf.dtype, unbound_st))
+
+  # reduce ops change ShapeTracker
   if buf.op in ReduceOps:
-    assert st.contiguous, "ReduceOps late fusion must be contiguous"
-    st = ShapeTracker.from_shape(buf.srcs[0].shape)
-
-  # otherwise we fuse it like normal
-  cache[(buf, st)] = ret = \
-    LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outputs, var_vals, st, realizes, assign_targets, cache) for x in buf.srcs), buf.arg)
-  return ret
-
-def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
-  """create a schedule item from a list of outputs"""
-  inputs: List[LazyBuffer] = []
+    rinfo = reduce_info.get((buf, st))
+    rsrc = _recursive_lazyop(buf.srcs[0], st:=(rinfo[0] if rinfo else st), outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
+    # if we are merging the reduce, skip it
+    if rinfo is None:
+      assert rsrc.op is buf.op, f"can't merge reduceop {buf.op} with {rsrc.op}\n{st}"
+      return rsrc
+    return cache.setdefault((buf, st), LazyOp(buf.op, (rsrc,), rinfo[1]))
+
+  # elementwise ops pass shapetracker
+  in_ops = tuple(_recursive_lazyop(x, st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache) for x in buf.srcs)
+  if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
+    assert buf in outputs, f"{buf.op} must be writable"
+    return in_ops[0]
+  return cache.setdefault((buf, st), LazyOp(buf.op, in_ops, buf.arg))
+
+def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
+  permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis) + axis
+  tmp = input_st.permute(permute_axis)
+  return tmp, tmp.shape[-len(axis):]
+
+def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer],
+                       reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]],
+                       cache:Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]]) -> \
+                         Optional[Tuple[LazyBuffer, ShapeTracker]]:
+  if (buf, st) in cache: return cache[(buf, st)]
+  if buf.base.realized is not None or (buf.base in realizes and buf.base not in outs): return None
+  if buf is not buf.base: st, buf = buf.st+st, buf.base
+  input_st = ShapeTracker.from_shape(buf.srcs[0].shape) if buf.op in ReduceOps else st
+  reduce_srcs = [r for x in buf.srcs if (r:=_recurse_reduceops(x, input_st, realizes, outs, reduce_info, cache)) is not None]
+  top_reduce = reduce_srcs[-1] if len(reduce_srcs) != 0 else None
+  if buf.op in ReduceOps:
+    axis = buf.arg
+    if not st.contiguous:
+      # push the movementop to the input
+      tmp, rshape = _permute_reduce(input_st, axis)
+      prshape = prod(rshape)
+      strides = strides_for_shape(rshape)
+      nv: List[View] = []
+      for v in st.views:
+        nv.append(View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
+                              v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None))
+      input_st = tmp + ShapeTracker(tuple(nv))
+      # update the axis
+      _, new_rshape = _permute_reduce(input_st, axis)
+      axis = tuple(range(len(input_st.shape)-len(new_rshape), len(input_st.shape)))
+    elif top_reduce is not None:
+      top_reduce_input_st, top_reduce_axes = reduce_info[top_reduce]
+      if buf.srcs[0] is not buf.srcs[0].base and buf.srcs[0].base is top_reduce[0] and buf.op is top_reduce[0].op:
+        # merge this reduce with its parent
+        new_st = top_reduce[1]+st
+        top_reduce = (top_reduce[0], new_st.reshape(reduce_st(top_reduce_input_st, new_axis:=axis+top_reduce_axes)))
+        reduce_info[top_reduce] = (top_reduce_input_st, new_axis)
+        return None
+      # reshape this reduceop based on the top reduce
+      input_st = input_st.reshape(tuple(1 if i in top_reduce_axes else s for i,s in enumerate(top_reduce_input_st.shape)))
+    st = st.reshape(reduce_st(input_st, axis))
+    reduce_info[(buf, st)] = (input_st, axis)
+    return (buf, st)
+  return cache.setdefault((buf, st), top_reduce)
+
+def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem:
+  """describe the computation for a LazyBuffer with LazyOp + inputs + var_vals"""
+  if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
+    rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
+    wr = LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st))
+    return LBScheduleItem(LazyOp(MetaOps.KERNEL, (wr,)), outs, [x.base for x in out.srcs])
+  if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
+    return LBScheduleItem(LazyOp(MetaOps.EXT, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])
+  # push through all movementops between reduceops
+  reduce_info: Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]] = {}
+  seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]] = {}
+  for out in outs: _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
+  # pad all reduceops to the max of each dimension
+  shape_dims = [sorted(dedup(dims)) for dims in zip(*[input_st.shape for input_st,_ in reduce_info.values()])]
+  for i,dims in enumerate(shape_dims):
+    if len(dims) == 1 or (len(dims) == 2 and dims[0] == 1): continue
+    for (r,view),(input_st,axis) in reduce_info.items():
+      if (dim:=input_st.shape[i]) > 1 and dim != max(dims):
+        input_st = input_st.pad(((0, 0),)*i+((0, max(dims)-dim),))
+        reduce_info[(r, view)] = (input_st, axis)
+  # create the stores
+  var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
+  assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
+  cache: Dict[Tuple[LazyBuffer, ShapeTracker], LazyOp] = {}
   ast: List[LazyOp] = []
-  var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
-  # single output AST
-  if (op:=(out:=outs[0]).op) in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY, LoadOps.VIEW}:
-    assert len(outs) == 1, f"can't schedule a group of {op}"
-    inputs = [x.base for x in out.srcs]
-    if getenv("USE_COPY_KERNEL") and op is LoadOps.COPY and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
-      rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
-      ast = [LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st))]
-    else: ast = [LazyOp(op, (), out.arg)]
-  # multi output AST
-  else:
-    assign_targets = {x.srcs[1]:x for x in outs if x.op is LoadOps.ASSIGN}
-    for i, out in enumerate(outs):
-      output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
-      output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
-      lop = _recursive_lazyop(out, inputs, outs, var_vals, output_st, realizes, assign_targets, cache={})
-      output_view, vv = output_view.simplify().unbind()
-      if vv: var_vals.update(vv)
-      ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
-  return _LBScheduleItem(tuple(ast), outs, tuple(inputs), var_vals)
+  inputs: Dict[LazyBuffer, int] = {}
+  for i, out in enumerate(outs):
+    output_st = ShapeTracker.from_shape(reduce_st(*deque(reduce_info.values(), 1).pop()) if reduce_info else out.shape)
+    lop = _recursive_lazyop(out, output_st, tuple(outs), var_vals, inputs, realizes, assign_targets, reduce_info, cache=cache)
+    if out.op is MetaOps.ASSIGN and out.arg:
+      assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
+      output_st = out.arg[0].reshape(output_st.shape)
+    output_st, vv = output_st.simplify().unbind()
+    if vv: var_vals.update(vv)
+    ast.append(LazyOp(BufferOps.STORE, (lop,), MemBuffer(i, out.dtype, output_st)))
+  return LBScheduleItem(LazyOp(MetaOps.KERNEL, tuple(ast)), outs, list(inputs), var_vals,
+                        dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))
 
 # *** DAG creation: decide which LazyBuffers should realize ***
 
-def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None],
-                simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
+def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None], simple_pads:Dict[LazyBuffer, None],
+                children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], assign_targets:Dict[LazyBuffer, LazyBuffer],
+                double_reduces:Dict[LazyBuffer, None], scheduled=False) -> None:
   """recursively search the entire graph for all LazyBuffers, insert realizes after expands"""
   if buf in allbufs or buf.base.realized is not None: return
   if GRAPH: log_lazybuffer(buf, scheduled)
-  # view
-  if buf.base != buf:
+  # check if we need to realize views
+  if buf is not buf.base:
     # fuse some pads
     if len(buf.st.views) == 1 and buf.st.views[-1].mask is not None and all_int(buf.base.st.shape) and \
         prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
-      simple_pads.add(buf.base)
+      simple_pads[buf.base] = None
     # realize all expands
     elif prod(buf.base.st.shape) < prod(buf.st.shape):
+      # this was causing "test_lil_model" to fail
      if buf.base.op is UnaryOps.CAST and isinstance(buf.base.srcs[0].dtype, ImageDType) and isinstance(buf.base.arg, ImageDType):
-        pass # don't realize image to image casts. this is part of a larger problem
-      else:
-        realizes[buf.base] = None
+        simple_pads[buf.base] = None # don't realize image to image casts. this is part of a larger problem
+      else: realizes[buf.base] = None
     # check all other pads for safe fusion
-    elif any(v.mask is not None for v in buf.st.views): simple_pads.add(buf.base)
-    return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
-  # base
+    elif any(v.mask is not None for v in buf.st.views): simple_pads[buf.base] = None
+    return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children, assign_targets, double_reduces)
+  if buf.op in ReduceOps and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None
   allbufs[buf] = None
-  if buf.forced_realize: realizes[buf] = None
-  if buf.op in LoadOps: realizes[buf.base] = None
-  if buf.op is LoadOps.COPY:
+  if buf.forced_realize or buf.op in MetaOps: realizes[buf] = None
+  if buf.op is MetaOps.ASSIGN:
+    assert buf.srcs[1].base is buf.srcs[1], f"assign must be to base {buf.srcs[1]}"
+    assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
+    assign_targets[buf.srcs[1]] = buf
+  if buf.op is MetaOps.COPY:
     assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
     realizes[buf.srcs[0].base] = None
-  if buf.op is LoadOps.VIEW: realizes[buf.srcs[0].base] = None
+  if buf.op is MetaOps.VIEW: realizes[buf.srcs[0].base] = None
   for x in buf.srcs:
-    children[x.base][buf] = None
-    _recurse_lb(x, realizes, allbufs, simple_pads, children)
+    if x.base.realized is None: children[x.base][buf] = None
+    _recurse_lb(x, realizes, allbufs, simple_pads, children, assign_targets, double_reduces)
 
 def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
   if buf in realizes or buf.realized is not None: return True
@@ -166,31 +223,50 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
   return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
 
 def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
-                     realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Set[LazyBuffer]):
+                     realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, None],
+                     cache:Dict[Tuple[LazyBuffer, ShapeTracker], None]) -> None:
   """recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
-  if tr in realizes:
+  if (tr, st) in cache: return
+  cache.setdefault((tr, st))
+  if tr in realizes and tr is not r:
     # can only fuse contiguous
     # max one reduceop per kernel
-    if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: group.add(r)
-    return group.add(tr)
+    if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: group.setdefault(r)
+    return group.setdefault(tr)
   for tr_next in children[tr]:
-    if tr_next.realized is None:
-      # max one reduceop per kernel
-      if tr_next.op in ReduceOps: return group.add(r)
-      # can only fuse contiguous
-      if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.add(r)
-      _recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group)
-
-def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], DefaultDict[LazyBuffer, int],
-                                                                          Dict[LazyBuffer, _LBScheduleItem]]:
+    # max one reduceop per kernel
+    if tr_next.op in ReduceOps: return group.setdefault(r)
+    # can only fuse contiguous
+    if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.setdefault(r)
+    _recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group, cache)
+
+def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],\
+    realizes:Dict[LazyBuffer, None], group:Dict[LazyBuffer, None]) -> Dict[LazyBuffer, None]:
+  rc_parents, cache = deque(group), set()
+  while rc_parents:
+    if (p:=rc_parents.pop()) in cache: continue
+    cache.add(p)
+    # max one reduceop per kernel
+    if p.op in ReduceOps: return {}
+    rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
+  # search descendants of the reduceop that can cleanly group
+  descendants: Dict[LazyBuffer, None] = {}
+  for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache={})
+  return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
+
+SCHEDULES: List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]] = []
+def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
+    Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]],  # this is the graph
+          DefaultDict[LBScheduleItem, int]]:                  # this is the in-degree of the graph
   """create a graph for realizing the outputs"""
   # start by just realizing the buffers passed in
   realizes: Dict[LazyBuffer, None] = {x.base:None for x in outs if x.base.realized is None}
   allbufs: Dict[LazyBuffer, None] = {}
-  simple_pads: Set[LazyBuffer] = set()
+  simple_pads: Dict[LazyBuffer, None] = {}
   children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
-  for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
-  assign_targets = {x.srcs[1]:x for x in realizes if x.op is LoadOps.ASSIGN and x not in seen and x.realized is None}
+  assign_targets: Dict[LazyBuffer, LazyBuffer] = {}
+  double_reduces: Dict[LazyBuffer, None] = {}
+  for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, assign_targets, double_reduces, scheduled=True)
 
   # check if we have to realize pads
   for p in simple_pads:
@@ -199,40 +275,27 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
 
   # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
   reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
+  reduce_of_const: List[LazyBuffer] = []
   for r in allbufs:
     if r.op not in ReduceOps or r in realizes: continue
 
-    group: Set[LazyBuffer] = set()
-    _recursive_group(r, r.st, r, children, realizes, reduce_for_op, group)
+    group: Dict[LazyBuffer, None] = {}
+    _recursive_group(r, r.st, r, children, realizes, reduce_for_op, group, cache={})
     # max one reduceop per kernel
     can_chase = all(tr not in reduce_for_op for tr in group)
     # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
     forced_realize = r in group
     if not forced_realize and len(group) > 1:
-      # create a multi output kernel if the LazyBufferss can cleanly group
-      rc_parents, rc_children = deque(group), deque(group)
-      while rc_parents and not forced_realize:
-        # max one reduceop per kernel
-        if (p:=rc_parents.pop()).op in ReduceOps: forced_realize = True
-        else: rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
-      # search descendants of the reduceop that can cleanly group
-      realized_descendants: Set[LazyBuffer] = set()
-      while rc_children and not forced_realize:
-        if (c:=rc_children.pop()).op in ReduceOps or not c.st.contiguous or c.st.size != r.st.size or c in reduce_for_op:
-          realized_descendants.clear()
-          break
-        if c in realizes and c not in group: realized_descendants.add(c)
-        rc_children.extend(x for x in children[c] if x.realized is None and x.device == r.device)
-      group.update(realized_descendants)
+      group = _get_isolated_children(r, reduce_for_op, children, realizes, group)
     # can only fuse assign if no other assign_target is used in the kernel
-    if not forced_realize and any(x.op is LoadOps.ASSIGN for x in group):
+    if not forced_realize and any(x.op is MetaOps.ASSIGN for x in group):
       parents = deque((r, *group))
       while parents and not forced_realize:
        if (p:=parents.pop().base).realized or p in realizes:
          if p in assign_targets and assign_targets[p] not in group: forced_realize, can_chase = True, False
          continue
        parents.extend(p.srcs)
-    if forced_realize:
+    if forced_realize or not group:
       tr = r
       if can_chase:
         # can chase this down to contiguous children
@@ -251,10 +314,26 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
       reduce_for_op[tr] = r
       realizes[tr] = None
     else: reduce_for_op.update((tr, r) for tr in group)
+    if FUSE_ARANGE and r.op is ReduceOps.SUM and r.srcs[0].base.op is MetaOps.CONST: reduce_of_const.append(r)
+
+  # fuse double reduces with no other child
+  if FUSE_CONV_BW:
+    for reduceop in double_reduces:
+      top_reduce = reduceop.base.srcs[0].base
+      if len(children[top_reduce]) == 1: del realizes[top_reduce]
+
+  for r in reduce_of_const:
+    group = {tr:None for tr,rop in reduce_for_op.items() if rop is r}
+    if DEBUG_ARANGE:=(getenv("DEBUG_ARANGE")): print(f"checking {r} {group=}")
+    if any(tr.forced_realize for tr in group) or any(x.base in group for x in outs): continue
+    kernel_children = {c for tr in group for c in children[tr] if c.op not in {MetaOps.COPY, MetaOps.VIEW}}
+    if len(kernel_children) == 0: continue
+    if DEBUG_ARANGE: print(colored(f"folding {r}", "green"))
+    for tr in group: del realizes[tr]
 
   output_groups: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
   for buf in realizes:
-    if buf.realized is not None or buf.op is LoadOps.CONST or buf in seen: continue
+    if buf.realized is not None or buf.op is MetaOps.CONST or buf in seen: continue
     output_groups[reduce_for_op[buf] if buf in reduce_for_op and MULTIOUTPUT else buf].append(buf)
 
   # make things that can't be images not images
@@ -269,59 +348,62 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
         buf.buffer.options = None
 
   # preschedule all buffers in realizes
-  prescheduled = {group[0]:_schedule_group(tuple(group), realizes, reduce_for_op) for group in output_groups.values()}
-  schedule_targets = {out:ps for ps in prescheduled.values() for out in ps.outputs}
+  prescheduled = [_lower_lazybuffer(group, realizes) for group in output_groups.values()]
+  schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}
 
-  graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
-  in_degree: DefaultDict[LazyBuffer, int] = defaultdict(int)
-  for key, lsi in prescheduled.items():
-    if key not in in_degree: in_degree[key] = 0
+  graph: DefaultDict[LBScheduleItem, List[LBScheduleItem]] = defaultdict(list)
+  in_degree: DefaultDict[LBScheduleItem, int] = defaultdict(int)
+  for lsi in prescheduled:
+    if lsi not in in_degree: in_degree[lsi] = 0
     # realize outputs after all parents are realized
-    scheduled_parents = set(schedule_targets[x].outputs[0] for x in lsi.inputs if x in schedule_targets)
+    scheduled_parents = dedup(schedule_targets[x] for x in lsi.inputs if x in schedule_targets)
     for x in scheduled_parents:
-      graph[x].append(key)
-      in_degree[key] += 1
+      graph[x].append(lsi)
+      in_degree[lsi] += 1
     # realize outputs before a parent is assigned to
-    parents_assigns = set(schedule_targets[assign_targets[x]].outputs[0] for x in lsi.inputs if x in assign_targets)
+    parents_assigns = dedup(schedule_targets[assign_targets[x]] for x in lsi.inputs if x in assign_targets)
     for assign in parents_assigns:
-      graph[key].append(assign)
+      graph[lsi].append(assign)
       in_degree[assign] += 1
 
-  return graph, in_degree, prescheduled
+  if SAVE_SCHEDULE:
+    def _save():
+      print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl"))
+      with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
+    if len(SCHEDULES) == 0: atexit.register(_save)
+    SCHEDULES.append((graph, in_degree))
+  return graph, in_degree
 
 # *** DAG ordering: breadth first search ***
 
-SCHEDULES: List = []
 def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
   if seen is None: seen = set()
-  graph, in_degree, prescheduled = _graph_schedule(outs, seen)
-  queue = deque(si for key, si in prescheduled.items() if in_degree[key] == 0)
+  graph, in_degree = _graph_schedule(outs, seen)
+  if getenv("RUN_PROCESS_REPLAY") and getenv("COMPARE_SCHEDULE", 1):
+    # NOTE: process relpay needs PYTHONPATH=., remove this once it just pickles LazyBuffers
+    with contextlib.suppress(Exception): importlib.import_module("test.external.process_replay.diff_schedule").process_replay(outs, graph, in_degree)
+
+  queue = deque(lsi for lsi,deg in in_degree.items() if deg == 0)
   schedule: List[ScheduleItem] = []
   var_vals: Dict[Variable, int] = {}
   kernel_number = GlobalCounters.kernel_count
   while queue:
-    ps = queue.popleft()
-    for buf in ps.outputs: seen.add(buf)
+    lsi = queue.popleft()
+    for buf in lsi.outputs: seen.add(buf)
    if GRAPH:
      kernel_number += 1
-      for out in ps.outputs: realized_lazybuffer(out, kernel_number)
-    var_vals = merge_dicts([var_vals, ps.var_vals])
-    for out in ps.outputs: del out.srcs # can only schedule once
-    schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0)))
-    if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
-    for x in graph[ps.outputs[0]]:
+      for out in lsi.outputs: realized_lazybuffer(out, kernel_number)
+    var_vals = merge_dicts([var_vals, lsi.var_vals])
+    for out in lsi.outputs: del out.srcs # can only schedule once
+    schedule.append(si:=ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata))
+    if logops and si.ast.op is MetaOps.KERNEL and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
+    for x in graph[lsi]:
      in_degree[x] -= 1
-      if in_degree[x] == 0: queue.append(prescheduled[x])
+      if in_degree[x] == 0: queue.append(x)
 
-  if SAVE_SCHEDULE:
-    def _save():
-      print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl"))
-      with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
-    if len(SCHEDULES) == 0: atexit.register(_save)
-    SCHEDULES.extend((ps.ast for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)])
   # confirm everything was scheduled correctly
-  if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
-    raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
+  if any(degree != 0 for degree in in_degree.values()) or len(in_degree) != len(schedule):
+    raise RuntimeError(f"cycle detected in graph, prescheduled {len(in_degree)} but only scheduled {len(schedule)}")
  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
  return schedule, var_vals
 
@@ -329,42 +411,3 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
   schedule, var_vals = create_schedule_with_vars(outs, seen)
   assert len(var_vals) == 0
   return schedule
-
-# *** memory planning ***
-
-def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], debug_prefix="") -> Dict[Buffer, Buffer]:
-  if getenv("NO_MEMORY_PLANNER"): return {}
-  last_appearance = {}
-  for i,u in enumerate(buffers):
-    for buf in u: last_appearance[buf] = i
-
-  # LRU algorithm
-  assigned: Dict[Buffer, Buffer] = {}
-  local_cache: DefaultDict[Tuple[str, int, DType], List[Buffer]] = defaultdict(list)
-
-  def handle_buffer(buf):
-    key = (buf.device, buf.size, buf.dtype)
-    if buf not in assigned:
-      if len(ll:=local_cache[key]): assigned[buf] = ll.pop()
-      else: assigned[buf] = Buffer(*key)
-    if i == last_appearance[buf]:
-      if assigned[buf] not in local_cache[key]: local_cache[key].append(assigned[buf])
-
-  for i,u in enumerate(buffers):
-    for buf in u:
-      # all unallocated unparented buffers are fair game to replace
-      if buf.is_allocated() or buf.lb_refcount > 0: continue
-      # handle view buffers
-      if buf._base is not None:
-        assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf._base, buf._base), offset=buf.offset)
-      else:
-        handle_buffer(buf)
-
-  if DEBUG >= 1 and len(ak:=dedup(assigned.keys())) != len(av:=dedup(assigned.values())):
-    print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",
-          f"{len(ak)} -> {len(av)} bufs")
-  return assigned
-
-def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
-  assigned = _internal_memory_planner([si.bufs for si in schedule])
-  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs)) for si in schedule]