tinygrad 0.10.1__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
Files changed (62)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +35 -37
  4. tinygrad/codegen/linearize.py +19 -10
  5. tinygrad/codegen/lowerer.py +31 -8
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +10 -0
  8. tinygrad/device.py +28 -11
  9. tinygrad/dtype.py +12 -3
  10. tinygrad/engine/jit.py +3 -2
  11. tinygrad/engine/multi.py +0 -1
  12. tinygrad/engine/realize.py +7 -4
  13. tinygrad/engine/schedule.py +227 -255
  14. tinygrad/engine/search.py +20 -27
  15. tinygrad/gradient.py +3 -0
  16. tinygrad/helpers.py +7 -4
  17. tinygrad/nn/state.py +2 -2
  18. tinygrad/ops.py +64 -329
  19. tinygrad/renderer/__init__.py +19 -3
  20. tinygrad/renderer/cstyle.py +39 -18
  21. tinygrad/renderer/llvmir.py +55 -18
  22. tinygrad/renderer/ptx.py +6 -2
  23. tinygrad/renderer/wgsl.py +20 -12
  24. tinygrad/runtime/autogen/libc.py +404 -71
  25. tinygrad/runtime/autogen/{libpciaccess.py → pci.py} +25 -715
  26. tinygrad/runtime/autogen/webgpu.py +6985 -0
  27. tinygrad/runtime/graph/metal.py +28 -29
  28. tinygrad/runtime/ops_amd.py +37 -34
  29. tinygrad/runtime/{ops_clang.py → ops_cpu.py} +4 -2
  30. tinygrad/runtime/ops_disk.py +1 -1
  31. tinygrad/runtime/ops_dsp.py +59 -33
  32. tinygrad/runtime/ops_llvm.py +14 -12
  33. tinygrad/runtime/ops_metal.py +78 -62
  34. tinygrad/runtime/ops_nv.py +9 -6
  35. tinygrad/runtime/ops_python.py +5 -5
  36. tinygrad/runtime/ops_webgpu.py +200 -38
  37. tinygrad/runtime/support/am/amdev.py +23 -11
  38. tinygrad/runtime/support/am/ip.py +10 -10
  39. tinygrad/runtime/support/elf.py +2 -0
  40. tinygrad/runtime/support/hcq.py +7 -5
  41. tinygrad/runtime/support/llvm.py +8 -14
  42. tinygrad/shape/shapetracker.py +3 -2
  43. tinygrad/shape/view.py +2 -3
  44. tinygrad/spec.py +21 -20
  45. tinygrad/tensor.py +150 -90
  46. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  47. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  48. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  49. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  50. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  51. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  52. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  53. tinygrad/viz/index.html +544 -0
  54. tinygrad/viz/perfetto.html +178 -0
  55. tinygrad/viz/serve.py +205 -0
  56. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/METADATA +20 -8
  57. tinygrad-0.10.2.dist-info/RECORD +99 -0
  58. tinygrad/codegen/rewriter.py +0 -516
  59. tinygrad-0.10.1.dist-info/RECORD +0 -86
  60. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  61. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +0 -0
  62. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,11 @@
- import sys, functools
+ import sys, functools, atexit, pickle
  from collections import defaultdict, deque
- from dataclasses import dataclass, field
+ from dataclasses import dataclass
  from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
- from tinygrad.ops import can_pad, identity_element, resolve, symbolic_simple, view_left, merge_views
- from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap, flatten
- from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY
+ from tinygrad.ops import can_pad, identity_element, resolve, view_left, merge_views
+ from tinygrad.codegen.symbolic import symbolic_simple
+ from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten, getenv
+ from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND
  from tinygrad.dtype import ImageDType
  from tinygrad.shape.shapetracker import ShapeTracker
  from tinygrad.shape.view import View, strides_for_shape
@@ -16,17 +17,17 @@ sys.setrecursionlimit(10000)

  # **** schedule simplifier

- def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
- if not all_int(x.shape): return None
- # remove reduce on unmasked const
- prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
- ret = x.const_arg
+ def simplify_stride0_reduce(reduce:UOp, x:UOp):
+ # must be unmasked (NOTE: can be relaxed if not masked on stride 0 axis)
+ if any(v.mask is not None for v in unwrap(x.st).views): return None
+ # must have all stride 0 in the relevant axis (NOTE: can do partial)
+ if not all(unwrap(x.st).views[-1].strides[axis] == 0 for axis in reduce.arg[1]) or not all_int(x.shape): return None
+ prshape = prod(x.shape[i] for i in reduce.arg[1])
+ ret = x.shrink(tuple((0,s) if i not in reduce.arg[1] else (0,1) for i,s in enumerate(x.shape)))
  match reduce.arg[0]:
- case Ops.ADD: ret *= prshape
- case Ops.MUL: ret **= prshape
- case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
- case _: return None
- return reduce.const_like(ret)
+ case Ops.ADD: return ret*prshape
+ case Ops.MUL: return ret.pow(prshape)
+ case Ops.MAX: return ret # NOTE: Ops.MAX is passthrough

  def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
  if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
@@ -39,22 +40,25 @@ def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
  sym = symbolic_simple+PatternMatcher([
  # UOp with size 0 is zero
  (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
- and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
+ and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
  # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
  # reduce of size 0 is the identity element
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
  lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
- # reduce of const is collapsed (TODO: make this a generic rule for stride0)
- (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
+ # reduce on stride 0 is collapsed
+ (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_stride0_reduce),
  # COPY(CONST) creates a new CONST on the destination device
- (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)),
+ (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.arg)),
  # no COPY to same device, except clone (arg is True)
  (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
  lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
  # remove cast to image when it's already a contiguous image
- (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)),
- lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
+ (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"))),)),
+ lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
+ # make things that can't be images not images
+ (UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType)
+ and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),
  # remove contiguous if we can just view the buffer
  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
  lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
@@ -63,10 +67,9 @@ sym = symbolic_simple+PatternMatcher([
  # support for using a contiguous permuted view instead of the parent view if one exists
  (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
  (UPat(GroupOp.ALU, name="alu"), replace_contiguous),
- # remove CONST/BIND/BUFFER from SINK
- (UPat(Ops.SINK, name="root"),
- lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
- if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
+ # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
+ (UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="root"),
+ lambda root: root.replace(op=Ops.BUFFER_VIEW) if isinstance(root.device, str) and root.device.startswith("DISK") else None),
  ])

  remove_movement_ops = merge_views+PatternMatcher([
@@ -80,95 +83,40 @@ remove_movement_ops = merge_views+PatternMatcher([
  # **** UOp realization

  @dataclass(frozen=True)
- class ScheduleContext:
- ops_metadata: dict[UOp, Metadata] # this maps uops in the schedule to the tensor metadata
- assigns: dict[UOp, None] = field(default_factory=dict) # this holds all the BUFFER uops we ASSIGN to in this schedule
- realizes: dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule
- allbufs: dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op
- children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
- preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
-
- # wrap tensor uops around a VIEW(BUFFER, <uop>)
- # this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
- def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp:
- if (r:=cache.get(buf)) is not None: return r
- # SINK is passthrough
- if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
- # skip creating buffers for CONST/BIND/DEVICE/BUFFER
- if buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf
- if buf.base.op is Ops.BUFFER: return buf.view(unwrap(buf.st))
- # VIEW is passthrough
- if buf is not buf.base:
- cache[buf] = ret = add_buffers(buf.base, buffer_map, cache).view(unwrap(buf.st))
- return ret
- # make things that can't be images not images
- dtype = buf.dtype
- if isinstance(dtype, ImageDType) and (prod(buf.shape)!=prod(dtype.shape) or not any(buf.shape[x]%4==0 for x in unwrap(buf.st).unit_stride_axes())):
- if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}")
- dtype = buf.dtype.base
- # ASSIGN already has a target buffer, otherwise we create a new one
- assert isinstance(buf.device, str), f"buf device is str, not {buf.device}"
- buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
- op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
- # track the buffer uop for the simplified uop
- buffer_map[buf] = buf_uop
- # (early) bufferize
- cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
- return ret
-
- class UPatScheduled(UPat):
- def __init__(self, *args, **kwargs):
- super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
-
- def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
-
- def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
+ class GrouperContext:
+ assigns: dict[UOp, UOp] # maps realized buffers to assigns
+ realizes: dict[UOp, None] # all the simplified tensor uops we realize
+ children: defaultdict[UOp, dict[UOp, None]] # children graph of tensor uops
+
+ def realize(ctx:GrouperContext, tr:UOp) -> None: ctx.realizes[tr] = None
+
+ def realize_before_view(ctx:GrouperContext, view:UOp, src:UOp) -> None:
  st = unwrap(view.st)
  # fold simple pads
  if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
- return None if can_pad(src, ctx.realizes, dict()) else realize(ctx, b, src)
+ return None if can_pad(src, ctx.realizes, cache=dict()) else realize(ctx, src)
  # early realize before expand
- if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
+ if resolve(prod(src.shape) < prod(st.shape)) and not DONT_REALIZE_EXPAND: return realize(ctx, src)
  # otherwise safety check pads
- return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, dict())) else realize(ctx, b, src)
-
- def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
- if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
- buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
- return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
+ return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, cache=dict())) else realize(ctx, src)

  do_realize = PatternMatcher([
  # always realize SINK parents
- (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
+ (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x.base, None) for x in s.src if x.base.op not in {Ops.CONST, Ops.BIND, Ops.BUFFER})),
  # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
- (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
+ (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
  # realize before expand or unsafe pad ops
- (UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
- # realize before COPY or BUFFER_VIEW
- (UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
- (UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
- # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
- (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
+ (UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}, name="src"),)), realize_before_view),
+ # realize before COPY
+ (UPat(Ops.COPY, src=(UPat(), UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW}, name="tr"))), realize),
  ])

- def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
- ctx.allbufs[buf_uop] = view
- if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns[buf_uop] = None
- for x in op.base.src:
- if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
- create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
-
- def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER
- def uval(u:UOp) -> UOp:
- assert is_scheduled(u), f"must be a scheduled op {u}"
- return u.src[1]
-
- def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp], realizes:dict[UOp, UOp],
- reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
+ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
+ reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
  """recursively search the uop for groupable children, realize the UOp if a child can't group"""
  if (tr, st) in cache: return
  cache.setdefault((tr, st))
- rsize = unwrap(allbufs[r].st).size
+ rsize = unwrap(r.st).size
  if tr in realizes and tr is not r:
  # can only fuse contiguous
  # max one reduceop per kernel
@@ -176,23 +124,28 @@ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, di
  return group.setdefault(tr)
  for tr_next in children[tr]:
  # max one reduceop per kernel
- if (tr_next_uop:=uval(allbufs[tr_next]).base).op is Ops.REDUCE_AXIS: return group.setdefault(r)
+ if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r)
  # can only fuse contiguous
- if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
- recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)
+ if len(st_childs:=dedup(unwrap(x.st) for x in tr_next.src if x.base == tr)) > 1: return group.setdefault(r)
+ recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache)
+
+ def append_uop(ctx:GrouperContext, u:UOp) -> None:
+ if u.op is Ops.ASSIGN: ctx.assigns[u.buf_uop] = u
+ for s in u.src: ctx.children[s.base][u] = None
+ create_ctx = PatternMatcher([(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}, name="u"), append_uop)])

- def group_realizes(sink:UOp, ctx:ScheduleContext) -> dict[UOp, UOp]:
+ def group_realizes(sink:UOp) -> dict[UOp, None]:
  # start by adding uops that always realize
- sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
+ sink = graph_rewrite(sink, do_realize+create_ctx, ctx:=GrouperContext({}, {}, defaultdict(dict)))
  # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
  reduce_for_op: dict[UOp, UOp] = {}
  double_reduces: list[UOp] = []
- for r, r_uop in ctx.allbufs.items():
- if (r_uop:=uval(r_uop)).op is not Ops.REDUCE_AXIS: continue
- if FUSE_CONV_BW and is_scheduled((x:=r_uop.src[0]).base) and uval(x.base).op is r_uop.op and x.base is not x: double_reduces.append(r)
+ for r in sink.toposort:
+ if r.op is not Ops.REDUCE_AXIS: continue
+ if FUSE_CONV_BW and r.src[0].base.op is Ops.REDUCE_AXIS and r.src[0] is not r.src[0].base: double_reduces.append(r)
  if r in ctx.realizes: continue
  group: dict[UOp, None] = {}
- recursive_group(r, unwrap(r_uop.st), r, ctx.children, ctx.allbufs, ctx.realizes, reduce_for_op, group, cache={})
+ recursive_group(r, unwrap(r.st), r, ctx.children, ctx.realizes, reduce_for_op, group, cache={})
  # max one reduceop per kernel
  can_chase = all(tr not in reduce_for_op for tr in group)
  # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
@@ -200,59 +153,77 @@ def group_realizes(sink:UOp, ctx:ScheduleContext) -> dict[UOp, UOp]:
  # can only have one output
  if not forced_realize and len(group) > 1: forced_realize = True
  # can only fuse assign if no other assign_target is used in the kernel
- if not forced_realize and any(x in ctx.assigns for x in group):
+ if not forced_realize and any(x.op is Ops.ASSIGN for x in group):
  parents = deque((r, *group))
  while parents and not forced_realize:
- if (p_uop:=ctx.allbufs.get(p:=parents.pop())) is None: continue
- if (p_uop:=uval(p_uop)).op is Ops.ASSIGN and p not in group: forced_realize, can_chase = True, False
+ p = parents.pop().base
+ if (assign:=ctx.assigns.get(p)) is not None and assign not in group: forced_realize, can_chase = True, False
  if p in ctx.realizes: continue
- parents.extend([x.base.buf_uop for x in p_uop.src if x.base.is_realized or (x.base.op is Ops.VIEW and len(x.base.src) != 0)])
+ parents.extend(p.src)
  if forced_realize or not group:
  tr = r
  if can_chase:
  # can chase this down to contiguous children
- st = unwrap(r_uop.st)
+ st = unwrap(tr.st)
  while len(ctx.children[tr]) == 1:
- tr_next_uop = uval(ctx.allbufs[(tr_next:=next(iter(ctx.children[tr])))])
- st_childs = dedup([unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop is tr])
+ tr_next = next(iter(ctx.children[tr]))
+ st_childs = dedup(unwrap(s.st) for s in tr_next.src if s.base is tr)
  if len(st_childs) > 1: break
  if st.size != st_childs[0].size: break
  st = st + st_childs[0]
- if not st.contiguous or tr_next_uop.op is Ops.REDUCE_AXIS: break
+ if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break
  tr = tr_next
  # don't cast to higher size before store (tr cannot be realized if forced_realize)
- if (tr_uop:=uval(ctx.allbufs[tr])).op is Ops.CAST and tr_uop.dtype.base.itemsize > tr_uop.src[0].dtype.base.itemsize:
- tr = tr_uop.src[0].base.buf_uop
+ if tr.op is Ops.CAST and tr.dtype.itemsize > tr.src[0].dtype.itemsize:
+ tr = tr.src[0].base
  group = {tr: None}
- ctx.realizes[tr] = tr
+ ctx.realizes[tr] = None
  reduce_for_op.update((tr, r) for tr in group)
- if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.CONST:
+ if FUSE_ARANGE and r.arg[0] is Ops.ADD and r.src[0].base.op is Ops.CONST:
  # maybe fuse arange with its children
  if len(flatten(ctx.children[tr] for tr in group)) != 0:
  for tr in group: del ctx.realizes[tr]
  # fuse double reduces with no other child
  for reduceop in double_reduces:
- top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
+ top_reduce = reduceop.src[0].base
  if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
- graph_rewrite(sink, break_sched, ctx)
  return ctx.realizes

- # break the SINK into stores
+ # break the SINK into kernels

- def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
- # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
- return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
-
- def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
- if (m:=ctx.ops_metadata.get(b)) is not None: ctx.ops_metadata[x] = m
- if b not in ctx.realizes: return x # collapse BUFFER
- ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
- return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
+ @dataclass(frozen=True)
+ class Kernel:
+ ast: UOp
+ metadata: tuple[Metadata, ...]
+ def __repr__(self): return f"<Kernel {len(list(self.ast.toposort))} {self.ast.op} {self.metadata}>"

- break_sched = PatternMatcher([
- # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
- (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
- (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
+ @dataclass(frozen=True)
+ class KernelContext:
+ realizes: dict[UOp, None]
+ ops_metadata: dict[UOp, Metadata]
+
+ def create_kernel(ctx:KernelContext, x:UOp):
+ if x not in ctx.realizes: return None
+ assert isinstance(x.device, str), f"buf device in kernel must be string {x.device}"
+ b = x.buf_uop if x.op is Ops.ASSIGN else UOp.new_buffer(x.device, x.size, x.dtype)
+ output_st = ShapeTracker.from_shape(x.shape)
+ # KERNEL nodes become: ASSIGN(VIEW(BUFFER), KERNEL)
+ # TODO: this should be ASSIGN(BUFFER, KERNEL) followed by the output ShapeTracker
+ return b.view(output_st).assign(UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x, (m,) if (m:=ctx.ops_metadata.get(x)) else ())))
+
+ def append_to_kernel(ctx:KernelContext, x:UOp):
+ new_srcs: list[UOp] = []
+ new_metadata: dict[Metadata, None] = dict.fromkeys(x.arg.metadata)
+ for s in x.src:
+ if s.op is Ops.BUFFER or (s.op is Ops.ASSIGN and s.src[1].op is Ops.KERNEL) or s in ctx.realizes: new_srcs.append(s)
+ else:
+ new_srcs.extend(s.src)
+ if (m:=ctx.ops_metadata.get(s)) is not None: new_metadata[m] = None
+ return x.replace(src=n, arg=Kernel(x.arg.ast, tuple(new_metadata))) if (n:=tuple(dedup(new_srcs))) != x.src else None
+
+ create_kernels = merge_views+PatternMatcher([
+ (UPat(GroupOp.All-{Ops.KERNEL, Ops.BUFFER}, name="x"), create_kernel),
+ (UPat(Ops.KERNEL, name="x"), append_to_kernel),
  ])

  # **** convert Kernel to a ScheduleItem (for legacy reasons)
@@ -273,23 +244,8 @@ class ScheduleItem:
  @functools.cached_property
  def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)

- def kernel_to_si(k:UOp) -> ScheduleItem:
- assert k.op is Ops.KERNEL, f"must be KERNEL {k}"
- return ScheduleItem(k.arg.ast, tuple(u.buf_uop.buffer for u in k.src), k.arg.metadata)
-
  # **** Kernel creation

- @dataclass(frozen=True)
- class Kernel:
- ast: UOp
- metadata: tuple[Metadata, ...]
- def __repr__(self): return f"<Kernel {len(list(self.ast.toposort))} {self.ast.op} {self.metadata}>"
-
- @dataclass(frozen=True)
- class ScheduleItemContext:
- var_vals: dict[Variable, int]
- bufs: list[UOp] = field(default_factory=list)
-
  def apply_swizzle(u:UOp) -> UOp:
  with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)

@@ -342,33 +298,47 @@ view_right = merge_views+PatternMatcher([
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
  ])

- def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
+ def _append_st_vars(ctx:dict[Variable, int], x:UOp) -> UOp|None:
  st = unwrap(x.st).simplify()
  if any(x.op is Ops.BIND for x in st.vars()):
  st, var_vals = st.unbind()
- ctx.var_vals.update(var_vals)
+ ctx.update(var_vals)
  return st.to_uop() if st != x.st else None

- def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
- ctx.bufs.append(x)
- return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)
-
- to_si = PatternMatcher([
- # BUFFER -> DEFINE_GLOBAL
- (UPat(Ops.BUFFER, name="x"), _append_buf),
- # simplify and unbind the final VIEWs
+ def check_load_st(glbl:UOp, view:UOp):
+ if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
+ # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
+ if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
+ # if it has a single view and it's equal when you shrink a contig, it's fine
+ if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
+ # otherwise, it's not fine
+ raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+ +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+
+ fix_kernel_ops = PatternMatcher([
+ # BIND in shapetracker becomes DEFINE_VAR
  (UPat(Ops.VIEW, name="x"), _append_st_vars),
- # don't need SINK on COPY or BUFFER_VIEW
- (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x")),)), lambda b,x: x.replace(src=(b, *x.src))),
- # don't need contiguous or assign anymore
+ # remove CONTIGUOUS/ASSIGN/DEVICE
  (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
  (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
- # don't need DEVICE anymore
  (UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
- # PRELOAD becomes LOAD
- (UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
- # once images are loaded they become the base dtype
+ # no ImageDType after load
  (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
+ # if this kernel also assigns to the loaded buffer, ensure we can index it correctly
+ (UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
+ ])
+
+ def load_buf(ctx:list[UOp], x:UOp):
+ if x not in ctx: ctx.append(x)
+ return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), unwrap(x.st).to_uop()))
+
+ add_buffer_ops = PatternMatcher([
+ # LOAD
+ (UPat(Ops.BUFFER, name="x"), load_buf),
+ # STORE (except for COPY/BUFFER_VIEW)
+ (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
+ (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
+ lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
  ])

  def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
@@ -376,111 +346,113 @@ def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
  return var
  unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])

- def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> UOp:
+ def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
+ assert sink.op is Ops.ASSIGN and sink.src[1].op is Ops.KERNEL, f"{sink} must be ASSIGN"
+ # substitute kernel sources for the target buffer
+ ast = sink.src[1].arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in sink.src[1].src if s.op is Ops.ASSIGN}).sink()
+ # add buffer ops
+ ast = graph_rewrite(ast, add_buffer_ops, bufs:=[sink.buf_uop], bottom_up=True)
  # unbind_vars + push views to edges
- sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=var_vals), view_right)
- # remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL
- ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(var_vals))
- # deal with ASSIGN
- if len(ctx.assigns) != 0:
- assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer]
- for x in list(sink.toposort)[::-1]:
- # we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER
- if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
- # PRELOAD tells the toposort this kernel should run before ASSIGN
- if x.op is Ops.PRELOAD:
- assign_preloads[x.buf_uop] = None
- # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
- if x.buf_uop is pre.src[0].buf_uop and not (st:=x.st_arg).contiguous:
- # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
- if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass
- # if it has a single view and it's equal when you shrink a contig, it's fine
- elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass
- # otherwise, it's not fine
- else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
- +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
- # NOTE: we only add the metadata for fused tensors
- metadata = tuple(dedup(m for x in pre.toposort if x.op is not Ops.BUFFER and (m:=ctx.ops_metadata.get(x)) is not None))
- return UOp(Ops.KERNEL, src=tuple(si_ctx.bufs), arg=Kernel(ast, metadata))
+ ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
+ # fix_kernel_ops
+ ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
+ # create subbuffer
+ if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = bufs[1].buffer.view(ast.size, ast.dtype, (x:=ast.src[0]).st_arg.views[0].offset*x.dtype.itemsize)
+ return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), sink.src[1].arg.metadata)
+
+ PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
+ if CAPTURE_PROCESS_REPLAY:
+ @atexit.register
+ def save_process_replay():
+ for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)

  # **** schedule creation and toposort

  @track_rewrites(named=True)
  def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
+ # remove_movement_ops + sym
  tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
- # tensors can become an existing buffer or simplify to a const, no ScheduleItem needed
+
+ # display the cleaned up tensor graph
+ if getenv("VIZ"): graph_rewrite(tensor_map[big_sink], PatternMatcher([]), name="View Tensor Graph")
+
+ # do_realize + group_realizes
+ sink = tensor_map[big_sink]
+ realize_map = group_realizes(sink)
+
+ # map tensors to new uops
  becomes_map: dict[UOp, UOp] = {}
- for k,v in tensor_map.items():
- # NOOP
- if k.base is v.base: continue
- # NOTE: only the base tensors get a BUFFER UOp
- if v.is_realized and k is k.base: becomes_map[k] = v.view(unwrap(k.st))
- # otherwise if it simplified to a CONST the UOp just becomes that CONST
- elif v.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
-
- # we group the rest of UOps into ScheduleItems
- buffer_map: dict[UOp, UOp] = {}
- sink = add_buffers(tensor_map[big_sink], buffer_map, cache={})
- # get realizes
- buf_tensors: dict[UOp, list[UOp]] = {}
+ rev_tensor_map: dict[UOp, list[UOp]] = {}
  ops_metadata: dict[UOp, Metadata] = {}
  for k,v in tensor_map.items():
- if (b:=buffer_map.get(v)) is not None:
- buf_tensors.setdefault(b, []).append(k)
- ops_metadata[b] = k.metadata
- realize_map = group_realizes(sink, ctx:=ScheduleContext(ops_metadata))
-
- # TODO: this should be the break between the "grouper" and the "linearizer"
- # here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
- # call into `def linearize_schedule(sched_sink:UOp) -> list[ScheduleItem]`
-
- # create kernels + map buffers to realized tensors
- sinks: list[UOp] = []
- var_vals: dict[Variable, int] = {}
- for buf_uop,store in realize_map.items():
- assert store.op is Ops.STORE, f"expected a realized BUFFER to get a STORE {sink}"
- sinks.append(schedule_uop(store.sink(), ctx, var_vals))
- # can only schedule once
- for tensor_uop in buf_tensors[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
- # increment refcount for this buffer
- buf_uop.buffer.ref(1)
- sched_sink = UOp(Ops.SINK, src=tuple(sinks))
- # display, TODO: this isn't a complete sched_sink yet
- if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]))
+ rev_tensor_map.setdefault(v, []).append(k)
+ if k is v: continue
+ if v.base.op is Ops.BUFFER:
+ # VIEW isn't a valid tensor uop, we need to backtrack to the movement op that created it
+ if v.op is Ops.VIEW:
+ mop = [x for x in k.toposort if (xs:=tensor_map[x]).base is v.base and xs.st == v.st][0]
+ if k is not mop: becomes_map[k] = mop
+ else: becomes_map[k] = v
+ elif v.base.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
+ # if we're not realizing this tensor, map its metadata to the simplified uop
+ elif isinstance(k.metadata, Metadata): ops_metadata[v] = k.metadata
+
+ # create kernels
+ if len(realize_map) == 0: return [], {}, becomes_map
+ kernel_map = graph_rewrite_map(sink, create_kernels, ctx=KernelContext(realize_map, ops_metadata), bottom_up=True)
+ sched_sink = kernel_map[sink]
  type_verify(list(sched_sink.toposort), kernel_spec)

- # convert kernels to ScheduleItem
- prescheduled = [kernel_to_si(k) for k in sched_sink.src]
- # add ScheduleItem children
- # TODO: this should construct the graph directly from the sched_sink
- schedule_targets = {out:si for si in prescheduled for out in si.outputs}
- graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list)
- in_degree: defaultdict[ScheduleItem, int] = defaultdict(int)
- for si in prescheduled:
- # realize outputs before a parent is assigned to
- parents_assigns = dedup(xsi for x in ctx.preloads[si.bufs[0]] if (xsi:=schedule_targets.get(x.buffer)) and xsi is not si)
- for assign in parents_assigns:
- graph[si].append(assign)
- in_degree[assign] += 1
- # realize outputs after all parents are realized
- scheduled_parents = dedup(xsi for x in si.inputs if (xsi:=schedule_targets.get(x)) is not None and xsi not in parents_assigns)
- for x in scheduled_parents:
- graph[x].append(si)
- in_degree[si] += 1
-
- # do BFS
- queue = deque(si for si in prescheduled if in_degree[si] == 0)
+ # map realized tensors to buffers
+ for k,v in kernel_map.items():
+ if k is v or v.op is not Ops.ASSIGN: continue
+ for t in rev_tensor_map[k]: becomes_map[t] = t.src[0] if t.op is Ops.ASSIGN else v.buf_uop.reshape(t.shape)
+
+ # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
+ kernel_assign: dict[UOp, UOp] = {}
+ assign_rep: dict[UOp, UOp] = {}
+ for u in sched_sink.toposort:
+ if u.op is not Ops.ASSIGN: continue
+ kernel_assign[u.buf_uop] = u
+ for s in u.src[1].src:
+ if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
+ if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort):
+ raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
+ assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
+ if assign_rep:
+ sched_sink = sched_sink.substitute(assign_rep)
+ type_verify(list(sched_sink.toposort), kernel_spec)
+
+ # display the final graph
+ if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
+
+ # final toposort (bfs)
+ children: dict[UOp, list[UOp]] = {}
+ in_degree: dict[UOp, int] = {}
+ for u in sched_sink.toposort:
+ if u.op is not Ops.ASSIGN: continue
+ in_degree[u] = 0
+ for s in u.src[1].src:
+ if s.op is not Ops.ASSIGN: continue
+ children.setdefault(s, []).append(u)
+ in_degree[u] += 1
+
+ queue = deque(k for k,v in in_degree.items() if v == 0)
  schedule: list[ScheduleItem] = []
+ var_vals: dict[Variable, int] = {}
  while queue:
- schedule.append(si:=queue.popleft())
- for x in graph[si]:
+ u = queue.popleft()
+ schedule.append(schedule_uop(u, var_vals))
+ # increment the refcount of the target buf (this is required by the JIT and memory planner)
+ u.buf_uop.buffer.ref(1)
+ for x in children.get(u, []):
  in_degree[x] -= 1
  if in_degree[x] == 0: queue.append(x)
+
  # confirm everything was scheduled correctly
- if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
+ if len(schedule) != (kc:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, created {kc} kernels but only scheduled {len(schedule)}")
  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
  # capture process replay
  if CAPTURE_PROCESS_REPLAY:
- with Context(PICKLE_BUFFERS=0):
- diskcache_put("schedule_process_replay", str(big_sink.key), (big_sink, ContextVar._cache, [x.ast for x in schedule]))
+ with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, [x.ast for x in schedule]))
  return schedule, var_vals, becomes_map
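
Note on the new scheduler ordering: as the last hunk shows, create_schedule_with_vars now orders kernels with a Kahn-style BFS over ASSIGN nodes, building a children map and in_degree counts from the kernel graph and raising if a cycle leaves nodes unscheduled. The snippet below is a minimal, self-contained sketch of that pattern on plain strings rather than tinygrad UOps; the deps graph and node names are hypothetical and not part of the package.

from collections import deque

# hypothetical dependency graph: kernel -> kernels it depends on (not from tinygrad)
deps = {"k0": [], "k1": ["k0"], "k2": ["k0"], "k3": ["k1", "k2"]}

# build the children graph and in-degree counts, like the scheduler does for ASSIGN uops
children: dict[str, list[str]] = {}
in_degree = {k: 0 for k in deps}
for k, parents in deps.items():
  for p in parents:
    children.setdefault(p, []).append(k)
    in_degree[k] += 1

# Kahn-style BFS: pop nodes with no unscheduled parents, then release their children
queue = deque(k for k, v in in_degree.items() if v == 0)
schedule = []
while queue:
  u = queue.popleft()
  schedule.append(u)  # the real scheduler calls schedule_uop(u, var_vals) here
  for x in children.get(u, []):
    in_degree[x] -= 1
    if in_degree[x] == 0: queue.append(x)

# a leftover unscheduled node means a cycle, mirroring the RuntimeError in the diff
assert len(schedule) == len(in_degree), "cycle detected"
print(schedule)  # e.g. ['k0', 'k1', 'k2', 'k3']

Running this prints a valid order such as ['k0', 'k1', 'k2', 'k3']; in the released code each popped ASSIGN is converted with schedule_uop and its target buffer's refcount is incremented before its children are released.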