tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
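Most of the hunk below is the rewrite of tinygrad/engine/schedule.py (item 13), replacing the LazyOp-based scheduler with UOp graph rewriting. As a minimal sketch (not part of the diff) of driving that scheduler, assuming Tensor.schedule() in tinygrad/tensor.py (changed in item 62 but not shown here) remains the public wrapper around create_schedule_with_vars():

# a minimal sketch, assuming Tensor.schedule() stays the public entry point
from tinygrad import Tensor

a = Tensor.rand(16, 16)
out = (a @ a).relu().sum(axis=0)
for si in out.schedule():           # List[ScheduleItem], one per fused kernel
  # ScheduleItem.ast is now a single UOp (usually an Ops.SINK) instead of a Tuple[LazyOp, ...]
  print(si.ast.op, len(si.bufs), si.metadata)

As the diff shows, each ScheduleItem also carries metadata and assign_preloads, and its outputs/inputs are derived from output_idxs rather than from the length of the old ast tuple.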
@@ -1,370 +1,419 @@
1
- import sys, pickle, atexit
1
+ import sys, atexit, functools, itertools
2
2
  from collections import defaultdict, deque
3
- from dataclasses import dataclass
4
- from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union, get_args
5
- from tinygrad.ops import LoadOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
6
- from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
7
- from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv
8
- from tinygrad.shape.symbolic import Variable
9
- from tinygrad.dtype import ConstType, ImageDType, dtypes, DType
10
- from tinygrad.lazy import LazyBuffer
3
+ from dataclasses import dataclass, field
4
+ from typing import Set, Tuple, List, Dict, Optional, DefaultDict, cast
5
+ from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, sint
6
+ from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
7
+ from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG
8
+ from tinygrad.dtype import ImageDType, dtypes
11
9
  from tinygrad.shape.shapetracker import ShapeTracker
10
+ from tinygrad.shape.view import View, strides_for_shape
11
+ from tinygrad.engine.lazy import LazyBuffer
12
12
  from tinygrad.device import Buffer
13
13
 
14
14
  # creation can recurse a lot
15
15
  sys.setrecursionlimit(10000)
16
16
 
17
- # optionally log the ops to disk
18
- logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
17
+ BUF_LIMIT = {"METAL":32}
19
18
 
20
- # *** ScheduleItem return type ***
19
+ # **** ScheduleItem return type
21
20
 
22
21
  @dataclass(frozen=True)
23
22
  class ScheduleItem:
24
- ast: Tuple[LazyOp, ...]
23
+ ast: UOp
25
24
  bufs: Tuple[Buffer, ...]
25
+ metadata: Tuple[Metadata, ...]
26
+ assign_preloads: Tuple[UOp, ...]
26
27
  @property
27
28
  def outputs(self) -> Tuple[Buffer, ...]:
28
29
  """Read/write or write only buffers in the schedule."""
29
- return self.bufs[:len(self.ast)]
30
+ return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs)
30
31
  @property
31
32
  def inputs(self) -> Tuple[Buffer, ...]:
32
33
  """Read only buffers in the schedule."""
33
- return self.bufs[len(self.ast):]
34
+ return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs)
35
+ @functools.cached_property
36
+ def output_idxs(self) -> Tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
34
37
 
35
- # *** DAG transformation: List[LazyBuffer] -> ScheduleItem ***
38
+ # **** small wrapper for LazyBuffer -> UOp
36
39
 
37
- # TODO: it's unfortunate this needs to exist, but because of ASSIGN, we have to retain the LazyBuffer structure until post toposort
38
- @dataclass(frozen=True)
39
- class _LBScheduleItem:
40
- ast: Tuple[LazyOp, ...]
41
- outputs: Tuple[LazyBuffer, ...]
42
- inputs: Tuple[LazyBuffer, ...]
43
- var_vals: Dict[Variable, int]
44
-
45
- def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
46
- realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer], cache) -> LazyOp:
47
- """recursively create a lazyop"""
48
- if (buf, st) in cache: return cache[(buf, st)]
49
- if buf != buf.base:
50
- st = buf.st + st
51
- buf = buf.base
52
- # all buffers here are base now
53
- assert buf.op is not None
40
+ def UPatSrc(*args, **kwargs): return UPat(Ops.VIEW, src=(UPat.var("b"), UPat(*args, **{**kwargs, "name":"to_store"})), name="base")
41
+ @functools.lru_cache(None)
42
+ def is_scheduled(u:UOp): return u.op is Ops.VIEW and len(u.src) == 2
54
43
 
44
+ @dataclass(frozen=True)
45
+ class ScheduleContext:
46
+ ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps BUFFER uops to Metadata
47
+ var_vals: Dict[Variable, int] = field(default_factory=dict) # this maps a BIND's DEFINE_VAR to its value
48
+ assigns: Set[UOp] = field(default_factory=set) # this holds all the BUFFER uops we ASSIGN to in this schedule
49
+ allbufs: Dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op
50
+ children: DefaultDict[UOp, Dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
51
+
52
+ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], lazybufs:Dict[Buffer, LazyBuffer], cache:Dict[LazyBuffer, UOp]) -> UOp:
53
+ if (r:=cache.get(buf)) is not None: return r
54
+ if buf is not buf.base:
55
+ cache[buf] = ret = to_uop(buf.base, ctx, buffers, lazybufs, cache).view(buf.st)
56
+ return ret
57
+ # make things that can't be images not images
58
+ if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
59
+ not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
60
+ if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to {buf.dtype.base}")
61
+ # hack the underlying buffer too
62
+ buf.dtype = buf.buffer.dtype = buf.dtype.base
63
+ assert not buf.is_realized, "can't fixup allocated buffer"
64
+ buf.buffer.options = None
65
+ dtype = buf.dtype if buf.op in GroupOp.Meta else buf.dtype.base
55
66
  # consts are always fused and generated
56
- if buf.op is LoadOps.CONST:
57
- unbound_st, st_var_vals = st.simplify().unbind()
58
- var_vals.update(st_var_vals)
59
- if isinstance(buf.arg, Variable):
60
- val, var_val = buf.arg.unbind()
61
- var_vals.__setitem__(val, var_val)
62
- else:
63
- assert isinstance(buf.arg, get_args(ConstType)), f"cannot create ConstBuffer with value {buf.arg}"
64
- val = buf.arg
65
- return LazyOp(BufferOps.CONST, (), ConstBuffer(val, buf.dtype, unbound_st))
66
-
67
- # if we aren't fusing it, it's a load and we add it to the inputs
68
- if buf.realized is not None or (buf in realizes and buf not in outputs):
69
- unbound_st, st_var_vals = st.simplify().unbind()
70
- var_vals.update(st_var_vals)
71
- if buf in assign_targets:
72
- # can only assign to contiguous read+write buffer
73
- if not unbound_st.contiguous:
74
- # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
75
- if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
76
- ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
77
- raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
78
- +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
79
- return LazyOp(BufferOps.LOAD, (), MemBuffer(outputs.index(assign_targets[buf]), buf.dtype, unbound_st))
80
- if buf not in inputs: inputs.append(buf)
81
- return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outputs)+inputs.index(buf), buf.dtype, unbound_st))
82
-
83
- # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
84
- if buf.op is LoadOps.CONTIGUOUS:
85
- assert buf in outputs
86
- return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)
87
- if buf.op is LoadOps.ASSIGN:
88
- assert buf in outputs
89
- assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
90
- assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
91
- return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)
92
-
93
- # if it's a reduce, we have to change the shapetracker
94
- if buf.op in ReduceOps:
95
- assert st.contiguous, "ReduceOps late fusion must be contiguous"
96
- st = ShapeTracker.from_shape(buf.srcs[0].shape)
97
-
98
- # otherwise we fuse it like normal
99
- cache[(buf, st)] = ret = \
100
- LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outputs, var_vals, st, realizes, assign_targets, cache) for x in buf.srcs), buf.arg)
67
+ if buf.op is Ops.CONST:
68
+ if isinstance(val:=buf.arg, UOp): ctx.var_vals.update([val.unbind()])
69
+ return UOp(Ops.VALID, dtypes.bool, (buf.st.to_uop(),)).where(UOp.const(dtype, val), 0)
70
+ # everything else is a VIEW of BUFFER (with an optional op)
71
+ if buf.is_realized:
72
+ buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer
73
+ op = None
74
+ elif buf.op is Ops.ASSIGN:
75
+ target, new_val = [to_uop(x, ctx, buffers, lazybufs, cache) for x in buf.srcs]
76
+ ctx.assigns.add(ubuf:=target.buf_uop)
77
+ op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg)
78
+ else:
79
+ buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer
80
+ op = UOp(cast(Ops, buf.op), dtype, tuple(to_uop(x, ctx, buffers, lazybufs, cache) for x in buf.srcs),
81
+ None if buf.op in {Ops.CAST, Ops.BITCAST} else buf.arg)
82
+ cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
83
+ if op is not None:
84
+ if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
85
+ lazybufs[buf.buffer] = buf
86
+ for x in op.src:
87
+ if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[ubuf] = None
88
+ ctx.allbufs[ubuf] = ret
101
89
  return ret
102
90
 
103
- def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
104
- """create a schedule item from a list of outputs"""
105
- inputs: List[LazyBuffer] = []
106
- ast: List[LazyOp] = []
107
- var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
108
- # single output AST
109
- if (op:=(out:=outs[0]).op) in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY, LoadOps.VIEW}:
110
- assert len(outs) == 1, f"can't schedule a group of {op}"
111
- inputs = [x.base for x in out.srcs]
112
- if getenv("USE_COPY_KERNEL") and op is LoadOps.COPY and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
113
- rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
114
- ast = [LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st))]
115
- else: ast = [LazyOp(op, (), out.arg)]
116
- # multi output AST
117
- else:
118
- assign_targets = {x.srcs[1]:x for x in outs if x.op is LoadOps.ASSIGN}
119
- for i, out in enumerate(outs):
120
- output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
121
- output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
122
- lop = _recursive_lazyop(out, inputs, outs, var_vals, output_st, realizes, assign_targets, cache={})
123
- output_view, vv = output_view.simplify().unbind()
124
- if vv: var_vals.update(vv)
125
- ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
126
- return _LBScheduleItem(tuple(ast), outs, tuple(inputs), var_vals)
127
-
128
- # *** DAG creation: decide which LazyBuffers should realize ***
129
-
130
- def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None],
131
- simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
132
- """recursively search the entire graph for all LazyBuffers, insert realizes after expands"""
133
- if buf in allbufs or buf.base.realized is not None: return
134
- if GRAPH: log_lazybuffer(buf, scheduled)
135
- # view
136
- if buf.base != buf:
137
- # fuse some pads
138
- if len(buf.st.views) == 1 and buf.st.views[-1].mask is not None and all_int(buf.base.st.shape) and \
139
- prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
140
- simple_pads.add(buf.base)
141
- # realize all expands
142
- elif prod(buf.base.st.shape) < prod(buf.st.shape):
143
- if buf.base.op is UnaryOps.CAST and isinstance(buf.base.srcs[0].dtype, ImageDType) and isinstance(buf.base.arg, ImageDType):
144
- pass # don't realize image to image casts. this is part of a larger problem
145
- else:
146
- realizes[buf.base] = None
147
- # check all other pads for safe fusion
148
- elif any(v.mask is not None for v in buf.st.views): simple_pads.add(buf.base)
149
- return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
150
- # base
151
- allbufs[buf] = None
152
- if buf.forced_realize: realizes[buf] = None
153
- if buf.op in LoadOps: realizes[buf.base] = None
154
- if buf.op is LoadOps.COPY:
155
- assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
156
- realizes[buf.srcs[0].base] = None
157
- if buf.op is LoadOps.VIEW: realizes[buf.srcs[0].base] = None
158
- for x in buf.srcs:
159
- children[x.base][buf] = None
160
- _recurse_lb(x, realizes, allbufs, simple_pads, children)
161
-
162
- def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
163
- if buf in realizes or buf.realized is not None: return True
164
- # NOTE: this broke to_image_idx and coder with JIT
165
- if buf.op in UNSAFE_PAD_OPS: return False
166
- return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
167
-
168
- def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
169
- realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Set[LazyBuffer]):
170
- """recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
171
- if tr in realizes:
91
+ # **** AST graph rewrite
92
+
93
+ # ** helpers for doing movementops on uops
94
+
95
+ def apply_swizzle(u:UOp, arg:ShapeTracker) -> UOp:
96
+ with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u.view(arg), view_left)
97
+
98
+ def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
99
+ permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis)+axis
100
+ tmp = input_st.permute(permute_axis)
101
+ return tmp, tmp.shape[-len(axis):]
102
+
103
+ # ** movementops rewrite rules
104
+
105
+ def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
106
+ tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(src.st).shape), r.axis_arg)
107
+ prshape = prod(rshape)
108
+ strides = strides_for_shape(rshape)
109
+ nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
110
+ v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
111
+ # update input_st and axis
112
+ new_input_st = tmp + ShapeTracker(tuple(nv))
113
+ _, new_rshape = permute_reduce(new_input_st, r.axis_arg)
114
+ new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
115
+ return apply_swizzle(src, new_input_st).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
116
+
117
+ def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp, src:UOp) -> UOp:
118
+ swizzle_st, src_st = unwrap(swizzle.st), unwrap(src.st)
119
+ assert swizzle_st.contiguous, "can't push a non contiguous VIEW down to STORE"
120
+ assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE"
121
+ output_shape = swizzle_st.reduce(root.axis_arg)
122
+ new_axis = tuple(i for i,(s,u) in enumerate(zip(src_st.shape, output_shape)) if s != u)
123
+ return swizzle.src[0].r(root.arg[0], new_axis).view(ShapeTracker.from_shape(output_shape))
124
+
125
+ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
126
+ swizzles = [x for x in root.src if x.base is not x]
127
+ if len(swizzles) == 0: return None
128
+ swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
129
+ assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
130
+ new_shape, new_input_shape = swizzle_shapes[0]
131
+ ret = root.replace(src=tuple(x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src))
132
+ return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
133
+
134
+ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
135
+ assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
136
+ assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
137
+ return first_reduce.src[0].r(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)
138
+
139
+ merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
140
+
141
+ # push VIEW to loads
142
+ view_left = merge_views+PatternMatcher([
143
+ # VIEW before elementwise ops
144
+ (UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
145
+ # early merge VIEW buffer ops
146
+ (UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.arg+v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
147
+ ])
148
+
149
+ # push VIEW to stores
150
+ view_right = merge_views+PatternMatcher([
151
+ # ASSIGN can override st
152
+ (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.ASSIGN, name="a"))),
153
+ lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None),
154
+ # non contiguous VIEW on a reduce creates a new VIEW
155
+ (UPat(Ops.REDUCE_AXIS, src=UPat.var("src"), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
156
+ # push a VIEW down to STORE, through a reduce (ONLY reshapes)
157
+ (UPat(Ops.REDUCE_AXIS, src=(UPat.var(name="src").view(name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
158
+ # push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes)
159
+ (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise),
160
+ (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
161
+ ])
162
+
163
+ # ** ScheduleItem context builder
164
+
165
+ @dataclass(frozen=True)
166
+ class ScheduleItemContext:
167
+ var_vals: Dict[Variable, int]
168
+ assigned: Set[UOp]
169
+ sts: Set[ShapeTracker] = field(default_factory=set)
170
+ bufs: List[UOp] = field(default_factory=list)
171
+ assign_preloads: List[UOp] = field(default_factory=list)
172
+
173
+ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
174
+ if (st:=unwrap(x.st)) in ctx.sts: return None
175
+ st, var_vals = st.simplify().unbind()
176
+ ctx.var_vals.update(var_vals)
177
+ ctx.sts.add(st)
178
+ return st.to_uop() if st != x.st else None
179
+
180
+ def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
181
+ ctx.bufs.append(x)
182
+ return UOp(Ops.DEFINE_GLOBAL, x.dtype, (), len(ctx.bufs)-1)
183
+ append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
184
+
185
+ def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
186
+ if b in ctx.assigned: ctx.assign_preloads.append(b)
187
+ return x.replace(op=Ops.LOAD)
188
+
189
+ to_si = PatternMatcher([
190
+ (UPat(Ops.VIEW, name="x"), _append_st_vars),
191
+ (UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),
192
+ (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,b,x: x.replace(src=(b, *x.src))),
193
+ ])
194
+
195
+ # ** fusion
196
+
197
+ lazy = PatternMatcher([
198
+ (UPatSrc(), lambda ctx,to_store,**kwargs: to_store),
199
+ (UPat(Ops.BUFFER, name="b").view(name="view"), lambda ctx,b,view: UOp(Ops.PRELOAD, view.dtype, (b, view.st.to_uop()))),
200
+ (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
201
+ ])
202
+
203
+ multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.get(b)),])
204
+
205
+ def full_ast_rewrite(pre:UOp, var_vals:Dict[Variable, int], assigned:Set[UOp]) -> Tuple[UOp, ScheduleItemContext]:
206
+ # fuse and fold store -> loads
207
+ sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, {x.buf_uop:x.src[2] for x in pre.src})
208
+ # assert cyclic dependency
209
+ for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {Ops.PRELOAD,Ops.LOAD} and x.buf_uop in assigned), key=lambda x:x.buf_uop):
210
+ if not all_same([x.op for x in ops]):
211
+ raise RuntimeError(f"cycle detected in kernel.\nhelp: use .contiguous() to break the part loading pre-assign {b} into a different kernel.")
212
+ # do movementops
213
+ sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
214
+ # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
215
+ if len(assign_targets:=[x.buf_uop for x in sink.sparents if x.op is Ops.ASSIGN]) != 0:
216
+ if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
217
+ and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is Ops.PRELOAD and x.buf_uop in assign_targets):
218
+ raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
219
+ +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
220
+ # convert to AST
221
+ sink = graph_rewrite(graph_rewrite(sink, to_si, ctx:=ScheduleItemContext(var_vals, assigned)), append_bufs, ctx)
222
+ if getenv("RUN_PROCESS_REPLAY"): PROCESS_REPLAY_CAPTURE.append(((pre, var_vals, assigned), sink))
223
+ return sink, ctx
224
+
225
+ PROCESS_REPLAY_CAPTURE: List[Tuple[Tuple, UOp]] = []
226
+ if getenv("RUN_PROCESS_REPLAY"):
227
+ @atexit.register
228
+ def save_process_replay():
229
+ for x,ret in PROCESS_REPLAY_CAPTURE: diskcache_put("schedule_process_replay", str(x[0].key), (x, {}, ret))
230
+
231
+ # **** Schedule grouping
232
+
233
+ def uval(u:UOp) -> UOp:
234
+ assert is_scheduled(u), f"must be a scheduled op {u}"
235
+ return to_store.src[0] if (to_store:=u.src[1]).is_contiguous_base else to_store
236
+
237
+ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp], realizes:Dict[UOp, UOp],
238
+ reduce_for_op:Dict[UOp, UOp], group:Dict[UOp, None], cache:Dict[Tuple[UOp, ShapeTracker], None]) -> None:
239
+ """recursively search the uop for groupable children, realize the UOp if a child can't group"""
240
+ if (tr, st) in cache: return
241
+ cache.setdefault((tr, st))
242
+ rsize = unwrap(allbufs[r].st).size
243
+ if tr in realizes and tr is not r:
172
244
  # can only fuse contiguous
173
245
  # max one reduceop per kernel
174
- if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: group.add(r)
175
- return group.add(tr)
246
+ if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
247
+ return group.setdefault(tr)
176
248
  for tr_next in children[tr]:
177
- if tr_next.realized is None:
178
- # max one reduceop per kernel
179
- if tr_next.op in ReduceOps: return group.add(r)
180
- # can only fuse contiguous
181
- if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.add(r)
182
- _recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group)
183
-
184
- def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], DefaultDict[LazyBuffer, int],
185
- Dict[LazyBuffer, _LBScheduleItem]]:
186
- """create a graph for realizing the outputs"""
187
- # start by just realizing the buffers passed in
188
- realizes: Dict[LazyBuffer, None] = {x.base:None for x in outs if x.base.realized is None}
189
- allbufs: Dict[LazyBuffer, None] = {}
190
- simple_pads: Set[LazyBuffer] = set()
191
- children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
192
- for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
193
- assign_targets = {x.srcs[1]:x for x in realizes if x.op is LoadOps.ASSIGN and x not in seen and x.realized is None}
194
-
195
- # check if we have to realize pads
196
- for p in simple_pads:
197
- if not _is_padding_okay(p, realizes):
198
- realizes[p] = None
199
-
249
+ # max one reduceop per kernel
250
+ if (tr_next_uop:=uval(allbufs[tr_next]).base).op is Ops.REDUCE_AXIS: return group.setdefault(r)
251
+ # can only fuse contiguous
252
+ if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
253
+ recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)
254
+
255
+ def get_isolated_children(r:UOp, reduce_for_op:Dict[UOp, UOp], children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp],
256
+ realizes:Dict[UOp, UOp], group:Dict[UOp, None]) -> Dict[UOp, None]:
257
+ rc_parents, cache = deque(group), set()
258
+ while rc_parents:
259
+ if (p:=uval(allbufs[rc_parents.pop()])) in cache: continue
260
+ cache.add(p)
261
+ # max one reduceop per kernel
262
+ if p.op is Ops.REDUCE_AXIS: return {}
263
+ rc_parents.extend(x.base.buf_uop for x in p.src if is_scheduled(x.base) and x.base.buf_uop is not r)
264
+ # search descendants of the reduceop that can cleanly group
265
+ descendants: Dict[UOp, None] = {}
266
+ for tr in group: recursive_group(tr, unwrap(allbufs[tr].st), tr, children, allbufs, realizes, reduce_for_op, descendants, cache={})
267
+ return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
268
+
269
+ def group_realizes(ctx:ScheduleContext, realizes:Dict[UOp, UOp]) -> List[List[UOp]]:
270
+ """search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop"""
200
271
  # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
201
- reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
202
- for r in allbufs:
203
- if r.op not in ReduceOps or r in realizes: continue
204
-
205
- group: Set[LazyBuffer] = set()
206
- _recursive_group(r, r.st, r, children, realizes, reduce_for_op, group)
272
+ reduce_for_op: Dict[UOp, UOp] = {}
273
+ reduce_of_const: List[UOp] = []
274
+ double_reduces: List[UOp] = []
275
+ for r, r_uop in ctx.allbufs.items():
276
+ if (r_uop:=uval(r_uop)).op is not Ops.REDUCE_AXIS: continue
277
+ if FUSE_CONV_BW and r_uop.op is Ops.REDUCE_AXIS and uval((x:=r_uop.src[0]).base).op is r_uop.op and x.base is not x: double_reduces.append(r)
278
+ if r in realizes: continue
279
+ group: Dict[UOp, None] = {}
280
+ recursive_group(r, unwrap(r_uop.st), r, ctx.children, ctx.allbufs, realizes, reduce_for_op, group, cache={})
207
281
  # max one reduceop per kernel
208
282
  can_chase = all(tr not in reduce_for_op for tr in group)
209
283
  # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
210
284
  forced_realize = r in group
211
285
  if not forced_realize and len(group) > 1:
212
- # create a multi output kernel if the LazyBufferss can cleanly group
213
- rc_parents, rc_children = deque(group), deque(group)
214
- while rc_parents and not forced_realize:
215
- # max one reduceop per kernel
216
- if (p:=rc_parents.pop()).op in ReduceOps: forced_realize = True
217
- else: rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
218
- # search descendants of the reduceop that can cleanly group
219
- realized_descendants: Set[LazyBuffer] = set()
220
- while rc_children and not forced_realize:
221
- if (c:=rc_children.pop()).op in ReduceOps or not c.st.contiguous or c.st.size != r.st.size or c in reduce_for_op:
222
- realized_descendants.clear()
223
- break
224
- if c in realizes and c not in group: realized_descendants.add(c)
225
- rc_children.extend(x for x in children[c] if x.realized is None and x.device == r.device)
226
- group.update(realized_descendants)
286
+ group = get_isolated_children(r, reduce_for_op, ctx.children, ctx.allbufs, realizes, group)
227
287
  # can only fuse assign if no other assign_target is used in the kernel
228
- if not forced_realize and any(x.op is LoadOps.ASSIGN for x in group):
288
+ if not forced_realize and any(x in ctx.assigns for x in group):
229
289
  parents = deque((r, *group))
230
290
  while parents and not forced_realize:
231
- if (p:=parents.pop().base).realized or p in realizes:
232
- if p in assign_targets and assign_targets[p] not in group: forced_realize, can_chase = True, False
233
- continue
234
- parents.extend(p.srcs)
235
- if forced_realize:
291
+ if (p_uop:=ctx.allbufs.get(p:=parents.pop())) is None: continue
292
+ if (p_uop:=uval(p_uop)).op is Ops.ASSIGN and p not in group: forced_realize, can_chase = True, False
293
+ if p in realizes: continue
294
+ parents.extend([x.base.src[0] for x in p_uop.src if x.base.op is Ops.VIEW and len(x.base.src) != 0])
295
+ if forced_realize or not group:
236
296
  tr = r
237
297
  if can_chase:
238
298
  # can chase this down to contiguous children
239
- st = tr.st
240
- while len(children[tr]) == 1:
241
- tr_next = next(iter(children[tr]))
242
- st_childs = dedup(s for s in tr_next.srcs if s.base is tr)
299
+ st = unwrap(r_uop.st)
300
+ while len(ctx.children[tr]) == 1:
301
+ tr_next_uop = uval(ctx.allbufs[(tr_next:=next(iter(ctx.children[tr])))])
302
+ st_childs = dedup([unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop is tr])
243
303
  if len(st_childs) > 1: break
244
- if st.size != st_childs[0].st.size: break
245
- st = st + st_childs[0].st
246
- if not st.contiguous or tr_next.op in ReduceOps: break
304
+ if st.size != st_childs[0].size: break
305
+ st = st + st_childs[0]
306
+ if not st.contiguous or tr_next_uop.op is Ops.REDUCE_AXIS: break
247
307
  tr = tr_next
248
308
  # don't cast to higher size before store (tr cannot be realized if forced_realize)
249
- if tr.op is UnaryOps.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
250
- tr = tr.srcs[0].base
251
- reduce_for_op[tr] = r
252
- realizes[tr] = None
253
- else: reduce_for_op.update((tr, r) for tr in group)
254
-
255
- output_groups: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
256
- for buf in realizes:
257
- if buf.realized is not None or buf.op is LoadOps.CONST or buf in seen: continue
258
- output_groups[reduce_for_op[buf] if buf in reduce_for_op and MULTIOUTPUT else buf].append(buf)
259
-
260
- # make things that can't be images not images
261
- if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
262
- not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
263
- if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
264
- buf.dtype = dtypes.float32
265
- # hack the underlying buffer too
266
- if buf.base is buf:
267
- assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer"
268
- buf.buffer.dtype = dtypes.float32
269
- buf.buffer.options = None
270
-
271
- # preschedule all buffers in realizes
272
- prescheduled = {group[0]:_schedule_group(tuple(group), realizes, reduce_for_op) for group in output_groups.values()}
273
- schedule_targets = {out:ps for ps in prescheduled.values() for out in ps.outputs}
274
-
275
- graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
276
- in_degree: DefaultDict[LazyBuffer, int] = defaultdict(int)
277
- for key, lsi in prescheduled.items():
278
- if key not in in_degree: in_degree[key] = 0
279
- # realize outputs after all parents are realized
280
- scheduled_parents = set(schedule_targets[x].outputs[0] for x in lsi.inputs if x in schedule_targets)
281
- for x in scheduled_parents:
282
- graph[x].append(key)
283
- in_degree[key] += 1
309
+ if (tr_uop:=uval(ctx.allbufs[tr])).op is Ops.CAST and tr_uop.dtype.base.itemsize > tr_uop.src[0].dtype.base.itemsize:
310
+ tr = tr_uop.src[0].base.buf_uop
311
+ group = {tr: None}
312
+ realizes[tr] = tr
313
+ reduce_for_op.update((tr, r) for tr in group)
314
+ if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.WHERE: reduce_of_const.append(r)
315
+ # fuse double reduces with no other child
316
+ for reduceop in double_reduces:
317
+ top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
318
+ if len(ctx.children[top_reduce]) == 1: del realizes[top_reduce]
319
+ # maybe fuse arange with its children
320
+ for rbuf in reduce_of_const:
321
+ group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
322
+ if any(ctx.allbufs[tr].src[1].is_contiguous_base for tr in group): continue
323
+ kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}}
324
+ if len(kernel_children) == 0: continue
325
+ for tr in group: del realizes[tr]
326
+ # group BUFFER uops into kernels
327
+ output_groups: DefaultDict[UOp, List[UOp]] = defaultdict(list)
328
+ for ubuf in realizes: output_groups[reduce_for_op.get(ubuf, ubuf)].append(ubuf)
329
+ return list(output_groups.values())
330
+
331
+ # **** Schedule creation and BFS toposort
332
+
333
+ def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> UOp:
334
+ ctx[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), to_store)
335
+ return UOp(Ops.LOAD, base.dtype, (b, st.to_uop()))
336
+
337
+ def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> Optional[UOp]:
338
+ base_shape = unwrap(base.st).shape
339
+ st = unwrap(view.st)
340
+ # fold simple pads
341
+ if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(base_shape) and resolve(prod(base_shape) >= prod([y-x for x,y in m])):
342
+ return None if can_pad(base) else realize(ctx, b, to_store, base).view(st)
343
+ # early realize before expand
344
+ if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, b, to_store, base).view(st)
345
+ # otherwise safety check pads
346
+ return None if (all(v.mask is None for v in st.views) or can_pad(base)) else realize(ctx, b, to_store, base).view(st)
347
+
348
+ do_realize = PatternMatcher([
349
+ # always realize meta ops
350
+ (UPatSrc((Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta)), realize),
351
+ # don't realize image to image casts
352
+ (UPatSrc(Ops.CAST, src=(UPat(Ops.LOAD, name="x"),), dtype=dtypes.float).view(name="v"), lambda ctx,x,v,**kwargs: r.src[2].view(v.st)
353
+ if (r:=ctx.get(b:=x.buf_uop)) is not None and r.op is Ops.STORE and isinstance(b.dtype, ImageDType) and r.src[2].op not in GroupOp.Meta else None),
354
+ # realize before expand or unsafe pad ops
355
+ (UPatSrc().view(name="view"), realize_view),
356
+ # realize before COPY or BUFFER_VIEW
357
+ (UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatSrc(), UPatSrc().view(name="view")),), name="root"),
358
+ lambda ctx,root,view=None,**kwargs: root.replace(src=(realize(ctx,**kwargs) if view is None else realize(ctx,**kwargs).view(view.st),)),),
359
+ ])
360
+ break_sched = PatternMatcher([(UPatSrc(), lambda ctx,b,to_store,base: realize(ctx, b, to_store, base) if b in ctx else None),])
361
+
362
+ @track_rewrites(named=True)
363
+ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
364
+ if len(outs:=dedup(x.base for x in outs if x.realized is None and x.base.op is not Ops.CONST)) == 0: return [], {}
365
+ for out in outs: out.forced_realize = True
366
+ # create the big graph
367
+ ctx = ScheduleContext()
368
+ cache: Dict[LazyBuffer, UOp] = {}
369
+ buffers: Dict[UOp, Buffer] = {}
370
+ lazybufs: Dict[Buffer, LazyBuffer] = {}
371
+ big_graph = UOp.sink(*(to_uop(x, ctx, buffers, lazybufs, cache) for x in outs))
372
+ # get realizes
373
+ realizes: Dict[UOp, UOp] = {}
374
+ graph_rewrite(big_graph, do_realize, realizes)
375
+ store_groups = group_realizes(ctx, realizes)
376
+ # split realizes into small graphs
377
+ graph_rewrite(big_graph, break_sched, realizes)
378
+ sinks = [UOp.sink(*(realizes[u] for u in stores)) for stores in store_groups]
379
+ # preschedule all realizes
380
+ prescheduled: List[ScheduleItem] = []
381
+ for sink in sinks:
382
+ metadata = tuple({mx for x in sink.sparents if (x.op is Ops.STORE or is_scheduled(x)) and (mx:=ctx.ubuf_metadata.get(x.buf_uop))})
383
+ ast, ast_ctx = full_ast_rewrite(sink, ctx.var_vals, ctx.assigns)
384
+ prescheduled.append(ScheduleItem(ast, tuple(b for u in ast_ctx.bufs if (b:=buffers[u]).size != 0), metadata, tuple(ast_ctx.assign_preloads)))
385
+ # do BFS
386
+ schedule_targets = {out:si for si in prescheduled for out in si.outputs}
387
+ graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)
388
+ in_degree: DefaultDict[ScheduleItem, int] = defaultdict(int)
389
+ for si in prescheduled:
284
390
  # realize outputs before a parent is assigned to
285
- parents_assigns = set(schedule_targets[assign_targets[x]].outputs[0] for x in lsi.inputs if x in assign_targets)
391
+ parents_assigns = dedup(xsi for x in si.assign_preloads if (xsi:=schedule_targets.get(buffers[x])) and xsi is not si)
286
392
  for assign in parents_assigns:
287
- graph[key].append(assign)
393
+ graph[si].append(assign)
288
394
  in_degree[assign] += 1
289
-
290
- return graph, in_degree, prescheduled
291
-
292
- # *** DAG ordering: breadth first search ***
293
-
294
- SCHEDULES: List = []
295
- def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
296
- if seen is None: seen = set()
297
- graph, in_degree, prescheduled = _graph_schedule(outs, seen)
298
- queue = deque(si for key, si in prescheduled.items() if in_degree[key] == 0)
395
+ # realize outputs after all parents are realized
396
+ scheduled_parents = dedup(xsi for x in si.inputs if (xsi:=schedule_targets.get(x)) is not None and xsi not in parents_assigns)
397
+ for x in scheduled_parents:
398
+ graph[x].append(si)
399
+ in_degree[si] += 1
400
+ queue = deque(si for si in prescheduled if in_degree[si] == 0)
299
401
  schedule: List[ScheduleItem] = []
300
- var_vals: Dict[Variable, int] = {}
301
- kernel_number = GlobalCounters.kernel_count
302
402
  while queue:
303
- ps = queue.popleft()
304
- for buf in ps.outputs: seen.add(buf)
305
- if GRAPH:
306
- kernel_number += 1
307
- for out in ps.outputs: realized_lazybuffer(out, kernel_number)
308
- var_vals = merge_dicts([var_vals, ps.var_vals])
309
- for out in ps.outputs: del out.srcs # can only schedule once
310
- schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0)))
311
- if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
312
- for x in graph[ps.outputs[0]]:
403
+ schedule.append(si:=queue.popleft())
404
+ for b in si.outputs: del lazybufs[b].srcs # can only schedule once
405
+ if (m:=BUF_LIMIT.get(device:=si.outputs[0].device)) and len(si.bufs) >= m:
406
+ if DEBUG >= 3: print(si)
407
+ raise RuntimeError(f"Kernel for {si.metadata} exceeded the {m} buffer count limit for {device} with {len(si.bufs)} buffers.")
408
+ for x in graph[si]:
313
409
  in_degree[x] -= 1
314
- if in_degree[x] == 0: queue.append(prescheduled[x])
315
-
316
- if SAVE_SCHEDULE:
317
- def _save():
318
- print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl"))
319
- with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
320
- if len(SCHEDULES) == 0: atexit.register(_save)
321
- SCHEDULES.extend((ps.ast for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)])
410
+ if in_degree[x] == 0: queue.append(x)
322
411
  # confirm everything was scheduled correctly
323
- if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
324
- raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
412
+ if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
325
413
  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
326
- return schedule, var_vals
414
+ return schedule, ctx.var_vals
327
415
 
328
- def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
329
- schedule, var_vals = create_schedule_with_vars(outs, seen)
416
+ def create_schedule(outs:List[LazyBuffer]) -> List[ScheduleItem]:
417
+ schedule, var_vals = create_schedule_with_vars(outs)
330
418
  assert len(var_vals) == 0
331
419
  return schedule
332
-
333
- # *** memory planning ***
334
-
335
- def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], debug_prefix="") -> Dict[Buffer, Buffer]:
336
- if getenv("NO_MEMORY_PLANNER"): return {}
337
- last_appearance = {}
338
- for i,u in enumerate(buffers):
339
- for buf in u: last_appearance[buf] = i
340
-
341
- # LRU algorithm
342
- assigned: Dict[Buffer, Buffer] = {}
343
- local_cache: DefaultDict[Tuple[str, int, DType], List[Buffer]] = defaultdict(list)
344
-
345
- def handle_buffer(buf):
346
- key = (buf.device, buf.size, buf.dtype)
347
- if buf not in assigned:
348
- if len(ll:=local_cache[key]): assigned[buf] = ll.pop()
349
- else: assigned[buf] = Buffer(*key)
350
- if i == last_appearance[buf]:
351
- if assigned[buf] not in local_cache[key]: local_cache[key].append(assigned[buf])
352
-
353
- for i,u in enumerate(buffers):
354
- for buf in u:
355
- # all unallocated unparented buffers are fair game to replace
356
- if buf.is_allocated() or buf.lb_refcount > 0: continue
357
- # handle view buffers
358
- if buf._base is not None:
359
- assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf._base, buf._base), offset=buf.offset)
360
- else:
361
- handle_buffer(buf)
362
-
363
- if DEBUG >= 1 and len(ak:=dedup(assigned.keys())) != len(av:=dedup(assigned.values())):
364
- print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",
365
- f"{len(ak)} -> {len(av)} bufs")
366
- return assigned
367
-
368
- def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
369
- assigned = _internal_memory_planner([si.bufs for si in schedule])
370
- return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs)) for si in schedule]
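The scheduling order itself comes from the Kahn-style BFS near the end of the new create_schedule_with_vars above (the in_degree/queue loop over ScheduleItems). A minimal standalone sketch of that pattern (not from the diff), using strings in place of ScheduleItems:

from collections import defaultdict, deque

def toposort(graph: dict[str, list[str]]) -> list[str]:
  # count incoming edges for every node
  in_degree: defaultdict[str, int] = defaultdict(int)
  for node, children in graph.items():
    in_degree.setdefault(node, 0)
    for c in children: in_degree[c] += 1
  # start from the nodes with no unscheduled parents
  queue = deque(n for n, d in in_degree.items() if d == 0)
  order: list[str] = []
  while queue:
    order.append(n := queue.popleft())
    for c in graph[n]:
      in_degree[c] -= 1
      if in_degree[c] == 0: queue.append(c)
  # mirrors the scheduler's check: anything left unscheduled means the graph had a cycle
  if len(order) != len(in_degree): raise RuntimeError(f"cycle detected in graph, grouped {len(in_degree)} but only scheduled {len(order)}")
  return order

print(toposort({"copy_in": ["matmul"], "matmul": ["relu"], "relu": []}))   # ['copy_in', 'matmul', 'relu']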