tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. tinygrad/codegen/kernel.py +114 -172
  2. tinygrad/codegen/linearize.py +211 -81
  3. tinygrad/codegen/lowerer.py +30 -35
  4. tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
  5. tinygrad/codegen/transcendental.py +12 -13
  6. tinygrad/device.py +170 -47
  7. tinygrad/dtype.py +28 -26
  8. tinygrad/engine/jit.py +80 -63
  9. tinygrad/engine/memory.py +4 -5
  10. tinygrad/engine/multi.py +162 -0
  11. tinygrad/engine/realize.py +58 -107
  12. tinygrad/engine/schedule.py +381 -314
  13. tinygrad/engine/search.py +40 -44
  14. tinygrad/gradient.py +70 -0
  15. tinygrad/helpers.py +77 -58
  16. tinygrad/nn/__init__.py +30 -32
  17. tinygrad/nn/datasets.py +1 -2
  18. tinygrad/nn/optim.py +22 -26
  19. tinygrad/nn/state.py +89 -64
  20. tinygrad/ops.py +562 -446
  21. tinygrad/renderer/__init__.py +79 -36
  22. tinygrad/renderer/cstyle.py +70 -84
  23. tinygrad/renderer/llvmir.py +32 -20
  24. tinygrad/renderer/ptx.py +79 -99
  25. tinygrad/renderer/wgsl.py +87 -0
  26. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  27. tinygrad/runtime/autogen/comgr.py +2 -0
  28. tinygrad/runtime/autogen/kfd.py +4 -3
  29. tinygrad/runtime/autogen/kgsl.py +1 -1
  30. tinygrad/runtime/autogen/libpciaccess.py +2023 -0
  31. tinygrad/runtime/autogen/llvm.py +11379 -0
  32. tinygrad/runtime/autogen/vfio.py +891 -0
  33. tinygrad/runtime/graph/cuda.py +8 -9
  34. tinygrad/runtime/graph/hcq.py +84 -79
  35. tinygrad/runtime/graph/metal.py +19 -21
  36. tinygrad/runtime/ops_amd.py +488 -327
  37. tinygrad/runtime/ops_clang.py +15 -28
  38. tinygrad/runtime/ops_cloud.py +34 -34
  39. tinygrad/runtime/ops_cuda.py +30 -27
  40. tinygrad/runtime/ops_disk.py +62 -63
  41. tinygrad/runtime/ops_dsp.py +129 -38
  42. tinygrad/runtime/ops_gpu.py +30 -30
  43. tinygrad/runtime/ops_hip.py +29 -31
  44. tinygrad/runtime/ops_llvm.py +45 -40
  45. tinygrad/runtime/ops_metal.py +93 -73
  46. tinygrad/runtime/ops_npy.py +2 -2
  47. tinygrad/runtime/ops_nv.py +232 -270
  48. tinygrad/runtime/ops_python.py +51 -46
  49. tinygrad/runtime/ops_qcom.py +129 -157
  50. tinygrad/runtime/ops_webgpu.py +63 -0
  51. tinygrad/runtime/support/allocator.py +94 -0
  52. tinygrad/runtime/support/am/__init__.py +0 -0
  53. tinygrad/runtime/support/am/amdev.py +384 -0
  54. tinygrad/runtime/support/am/ip.py +463 -0
  55. tinygrad/runtime/support/compiler_cuda.py +4 -2
  56. tinygrad/runtime/support/elf.py +26 -4
  57. tinygrad/runtime/support/hcq.py +254 -324
  58. tinygrad/runtime/support/llvm.py +32 -0
  59. tinygrad/shape/shapetracker.py +84 -53
  60. tinygrad/shape/view.py +103 -138
  61. tinygrad/spec.py +154 -0
  62. tinygrad/tensor.py +744 -496
  63. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
  64. tinygrad-0.10.1.dist-info/RECORD +86 -0
  65. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
  66. tinygrad/engine/lazy.py +0 -228
  67. tinygrad/function.py +0 -212
  68. tinygrad/multi.py +0 -177
  69. tinygrad/runtime/graph/clang.py +0 -39
  70. tinygrad-0.10.0.dist-info/RECORD +0 -77
  71. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
  72. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
@@ -1,241 +1,170 @@
- import sys, atexit, functools, itertools
+ import sys, functools
  from collections import defaultdict, deque
  from dataclasses import dataclass, field
- from typing import Set, Tuple, List, Dict, Optional, DefaultDict, cast
- from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, sint
- from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
- from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG
- from tinygrad.dtype import ImageDType, dtypes
+ from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
+ from tinygrad.ops import can_pad, identity_element, resolve, symbolic_simple, view_left, merge_views
+ from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap, flatten
+ from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY
+ from tinygrad.dtype import ImageDType
  from tinygrad.shape.shapetracker import ShapeTracker
  from tinygrad.shape.view import View, strides_for_shape
- from tinygrad.engine.lazy import LazyBuffer
  from tinygrad.device import Buffer
+ from tinygrad.spec import type_verify, kernel_spec
 
  # creation can recurse a lot
  sys.setrecursionlimit(10000)
 
- BUF_LIMIT = {"METAL":32}
-
- # **** ScheduleItem return type
-
- @dataclass(frozen=True)
- class ScheduleItem:
- ast: UOp
- bufs: Tuple[Buffer, ...]
- metadata: Tuple[Metadata, ...]
- assign_preloads: Tuple[UOp, ...]
- @property
- def outputs(self) -> Tuple[Buffer, ...]:
- """Read/write or write only buffers in the schedule."""
- return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs)
- @property
- def inputs(self) -> Tuple[Buffer, ...]:
- """Read only buffers in the schedule."""
- return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs)
- @functools.cached_property
- def output_idxs(self) -> Tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
+ # **** schedule simplifier
+
+ def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
+ if not all_int(x.shape): return None
+ # remove reduce on unmasked const
+ prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
+ ret = x.const_arg
+ match reduce.arg[0]:
+ case Ops.ADD: ret *= prshape
+ case Ops.MUL: ret **= prshape
+ case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
+ case _: return None
+ return reduce.const_like(ret)
+
+ def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
+ if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
+ def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
+ new_src = list(alu.src)
+ for i,s in enumerate(alu.src):
+ if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
+ if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
+
+ sym = symbolic_simple+PatternMatcher([
+ # UOp with size 0 is zero
+ (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
+ and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
+ # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
+ (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
+ # reduce of size 0 is the identity element
+ (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
+ lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
+ # reduce of const is collapsed (TODO: make this a generic rule for stride0)
+ (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
+ # COPY(CONST) creates a new CONST on the destination device
+ (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)),
+ # no COPY to same device, except clone (arg is True)
+ (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
+ lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
+ # remove cast to image when it's already a contiguous image
+ (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)),
+ lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
+ # remove contiguous if we can just view the buffer
+ (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
+ lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
+ # contiguous/buffer/copy is already contiguous
+ (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY)),)), lambda root: root.src[0]),
+ # support for using a contiguous permuted view instead of the parent view if one exists
+ (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
+ (UPat(GroupOp.ALU, name="alu"), replace_contiguous),
+ # remove CONST/BIND/BUFFER from SINK
+ (UPat(Ops.SINK, name="root"),
+ lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
+ if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
+ ])
 
- # **** small wrapper for LazyBuffer -> UOp
+ remove_movement_ops = merge_views+PatternMatcher([
+ # NOTE: movement ops are always applied to base
+ (UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
+ # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
+ (UPat(Ops.VIEW, name="view"),
+ lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
+ ])
 
- def UPatSrc(*args, **kwargs): return UPat(Ops.VIEW, src=(UPat.var("b"), UPat(*args, **{**kwargs, "name":"to_store"})), name="base")
- @functools.lru_cache(None)
- def is_scheduled(u:UOp): return u.op is Ops.VIEW and len(u.src) == 2
+ # **** UOp realization
 
  @dataclass(frozen=True)
  class ScheduleContext:
- ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps BUFFER uops to Metadata
- var_vals: Dict[Variable, int] = field(default_factory=dict) # this maps a BIND's DEFINE_VAR to its value
- assigns: Set[UOp] = field(default_factory=set) # this holds all the BUFFER uops we ASSIGN to in this schedule
- allbufs: Dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op
- children: DefaultDict[UOp, Dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
-
- def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], lazybufs:Dict[Buffer, LazyBuffer], cache:Dict[LazyBuffer, UOp]) -> UOp:
+ ops_metadata: dict[UOp, Metadata] # this maps uops in the schedule to the tensor metadata
+ assigns: dict[UOp, None] = field(default_factory=dict) # this holds all the BUFFER uops we ASSIGN to in this schedule
+ realizes: dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule
+ allbufs: dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op
+ children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
+ preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
+
+ # wrap tensor uops around a VIEW(BUFFER, <uop>)
+ # this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
+ def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp:
  if (r:=cache.get(buf)) is not None: return r
+ # SINK is passthrough
+ if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
+ # skip creating buffers for CONST/BIND/DEVICE/BUFFER
+ if buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf
+ if buf.base.op is Ops.BUFFER: return buf.view(unwrap(buf.st))
+ # VIEW is passthrough
  if buf is not buf.base:
- cache[buf] = ret = to_uop(buf.base, ctx, buffers, lazybufs, cache).view(buf.st)
+ cache[buf] = ret = add_buffers(buf.base, buffer_map, cache).view(unwrap(buf.st))
  return ret
  # make things that can't be images not images
- if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
- not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
- if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to {buf.dtype.base}")
- # hack the underlying buffer too
- buf.dtype = buf.buffer.dtype = buf.dtype.base
- assert not buf.is_realized, "can't fixup allocated buffer"
- buf.buffer.options = None
- dtype = buf.dtype if buf.op in GroupOp.Meta else buf.dtype.base
- # consts are always fused and generated
- if buf.op is Ops.CONST:
- if isinstance(val:=buf.arg, UOp): ctx.var_vals.update([val.unbind()])
- return UOp(Ops.VALID, dtypes.bool, (buf.st.to_uop(),)).where(UOp.const(dtype, val), 0)
- # everything else is a VIEW of BUFFER (with an optional op)
- if buf.is_realized:
- buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer
- op = None
- elif buf.op is Ops.ASSIGN:
- target, new_val = [to_uop(x, ctx, buffers, lazybufs, cache) for x in buf.srcs]
- ctx.assigns.add(ubuf:=target.buf_uop)
- op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg)
- else:
- buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer
- op = UOp(cast(Ops, buf.op), dtype, tuple(to_uop(x, ctx, buffers, lazybufs, cache) for x in buf.srcs),
- None if buf.op in {Ops.CAST, Ops.BITCAST} else buf.arg)
- cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
- if op is not None:
- if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
- lazybufs[buf.buffer] = buf
- for x in op.src:
- if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[ubuf] = None
- ctx.allbufs[ubuf] = ret
+ dtype = buf.dtype
+ if isinstance(dtype, ImageDType) and (prod(buf.shape)!=prod(dtype.shape) or not any(buf.shape[x]%4==0 for x in unwrap(buf.st).unit_stride_axes())):
+ if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}")
+ dtype = buf.dtype.base
+ # ASSIGN already has a target buffer, otherwise we create a new one
+ assert isinstance(buf.device, str), f"buf device is str, not {buf.device}"
+ buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
+ op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
+ # track the buffer uop for the simplified uop
+ buffer_map[buf] = buf_uop
+ # (early) bufferize
+ cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
  return ret
 
- # **** AST graph rewrite
-
- # ** helpers for doing movementops on uops
-
- def apply_swizzle(u:UOp, arg:ShapeTracker) -> UOp:
- with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u.view(arg), view_left)
-
- def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
- permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis)+axis
- tmp = input_st.permute(permute_axis)
- return tmp, tmp.shape[-len(axis):]
-
- # ** movementops rewrite rules
-
- def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
- tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(src.st).shape), r.axis_arg)
- prshape = prod(rshape)
- strides = strides_for_shape(rshape)
- nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
- v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
- # update input_st and axis
- new_input_st = tmp + ShapeTracker(tuple(nv))
- _, new_rshape = permute_reduce(new_input_st, r.axis_arg)
- new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
- return apply_swizzle(src, new_input_st).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
-
- def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp, src:UOp) -> UOp:
- swizzle_st, src_st = unwrap(swizzle.st), unwrap(src.st)
- assert swizzle_st.contiguous, "can't push a non contiguous VIEW down to STORE"
- assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE"
- output_shape = swizzle_st.reduce(root.axis_arg)
- new_axis = tuple(i for i,(s,u) in enumerate(zip(src_st.shape, output_shape)) if s != u)
- return swizzle.src[0].r(root.arg[0], new_axis).view(ShapeTracker.from_shape(output_shape))
-
- def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
- swizzles = [x for x in root.src if x.base is not x]
- if len(swizzles) == 0: return None
- swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
- assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
- new_shape, new_input_shape = swizzle_shapes[0]
- ret = root.replace(src=tuple(x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src))
- return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
+ class UPatScheduled(UPat):
+ def __init__(self, *args, **kwargs):
+ super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
 
- def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
- assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
- assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
- return first_reduce.src[0].r(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)
+ def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
 
- merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
-
- # push VIEW to loads
- view_left = merge_views+PatternMatcher([
- # VIEW before elementwise ops
- (UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
- # early merge VIEW buffer ops
- (UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.arg+v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
- ])
-
- # push VIEW to stores
- view_right = merge_views+PatternMatcher([
- # ASSIGN can override st
- (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.ASSIGN, name="a"))),
- lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None),
- # non contiguous VIEW on a reduce creates a new VIEW
- (UPat(Ops.REDUCE_AXIS, src=UPat.var("src"), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
- # push a VIEW down to STORE, through a reduce (ONLY reshapes)
- (UPat(Ops.REDUCE_AXIS, src=(UPat.var(name="src").view(name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
- # push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes)
- (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise),
- (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
- ])
-
- # ** ScheduleItem context builder
-
- @dataclass(frozen=True)
- class ScheduleItemContext:
- var_vals: Dict[Variable, int]
- assigned: Set[UOp]
- sts: Set[ShapeTracker] = field(default_factory=set)
- bufs: List[UOp] = field(default_factory=list)
- assign_preloads: List[UOp] = field(default_factory=list)
-
- def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
- if (st:=unwrap(x.st)) in ctx.sts: return None
- st, var_vals = st.simplify().unbind()
- ctx.var_vals.update(var_vals)
- ctx.sts.add(st)
- return st.to_uop() if st != x.st else None
-
- def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
- ctx.bufs.append(x)
- return UOp(Ops.DEFINE_GLOBAL, x.dtype, (), len(ctx.bufs)-1)
- append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
-
- def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
- if b in ctx.assigned: ctx.assign_preloads.append(b)
- return x.replace(op=Ops.LOAD)
-
- to_si = PatternMatcher([
- (UPat(Ops.VIEW, name="x"), _append_st_vars),
- (UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),
- (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,b,x: x.replace(src=(b, *x.src))),
- ])
+ def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
+ st = unwrap(view.st)
+ # fold simple pads
+ if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
+ return None if can_pad(src, ctx.realizes, dict()) else realize(ctx, b, src)
+ # early realize before expand
+ if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
+ # otherwise safety check pads
+ return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, dict())) else realize(ctx, b, src)
 
- # ** fusion
+ def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
+ if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
+ buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
+ return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
 
- lazy = PatternMatcher([
- (UPatSrc(), lambda ctx,to_store,**kwargs: to_store),
- (UPat(Ops.BUFFER, name="b").view(name="view"), lambda ctx,b,view: UOp(Ops.PRELOAD, view.dtype, (b, view.st.to_uop()))),
- (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
+ do_realize = PatternMatcher([
+ # always realize SINK parents
+ (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
+ # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
+ (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
+ # realize before expand or unsafe pad ops
+ (UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
+ # realize before COPY or BUFFER_VIEW
+ (UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
+ (UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
+ # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
+ (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
  ])
 
- multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.get(b)),])
-
- def full_ast_rewrite(pre:UOp, var_vals:Dict[Variable, int], assigned:Set[UOp]) -> Tuple[UOp, ScheduleItemContext]:
- # fuse and fold store -> loads
- sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, {x.buf_uop:x.src[2] for x in pre.src})
- # assert cyclic dependency
- for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {Ops.PRELOAD,Ops.LOAD} and x.buf_uop in assigned), key=lambda x:x.buf_uop):
- if not all_same([x.op for x in ops]):
- raise RuntimeError(f"cycle detected in kernel.\nhelp: use .contiguous() to break the part loading pre-assign {b} into a different kernel.")
- # do movementops
- sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
- # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
- if len(assign_targets:=[x.buf_uop for x in sink.sparents if x.op is Ops.ASSIGN]) != 0:
- if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
- and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is Ops.PRELOAD and x.buf_uop in assign_targets):
- raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
- +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
- # convert to AST
- sink = graph_rewrite(graph_rewrite(sink, to_si, ctx:=ScheduleItemContext(var_vals, assigned)), append_bufs, ctx)
- if getenv("RUN_PROCESS_REPLAY"): PROCESS_REPLAY_CAPTURE.append(((pre, var_vals, assigned), sink))
- return sink, ctx
-
- PROCESS_REPLAY_CAPTURE: List[Tuple[Tuple, UOp]] = []
- if getenv("RUN_PROCESS_REPLAY"):
- @atexit.register
- def save_process_replay():
- for x,ret in PROCESS_REPLAY_CAPTURE: diskcache_put("schedule_process_replay", str(x[0].key), (x, {}, ret))
-
- # **** Schedule grouping
+ def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
+ ctx.allbufs[buf_uop] = view
+ if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns[buf_uop] = None
+ for x in op.base.src:
+ if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
+ create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
 
+ def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER
  def uval(u:UOp) -> UOp:
  assert is_scheduled(u), f"must be a scheduled op {u}"
- return to_store.src[0] if (to_store:=u.src[1]).is_contiguous_base else to_store
+ return u.src[1]
 
- def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp], realizes:Dict[UOp, UOp],
- reduce_for_op:Dict[UOp, UOp], group:Dict[UOp, None], cache:Dict[Tuple[UOp, ShapeTracker], None]) -> None:
+ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp], realizes:dict[UOp, UOp],
+ reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
  """recursively search the uop for groupable children, realize the UOp if a child can't group"""
  if (tr, st) in cache: return
  cache.setdefault((tr, st))
@@ -252,46 +181,32 @@ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:DefaultDict[UOp, Di
  if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
  recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)
 
- def get_isolated_children(r:UOp, reduce_for_op:Dict[UOp, UOp], children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp],
- realizes:Dict[UOp, UOp], group:Dict[UOp, None]) -> Dict[UOp, None]:
- rc_parents, cache = deque(group), set()
- while rc_parents:
- if (p:=uval(allbufs[rc_parents.pop()])) in cache: continue
- cache.add(p)
- # max one reduceop per kernel
- if p.op is Ops.REDUCE_AXIS: return {}
- rc_parents.extend(x.base.buf_uop for x in p.src if is_scheduled(x.base) and x.base.buf_uop is not r)
- # search descendants of the reduceop that can cleanly group
- descendants: Dict[UOp, None] = {}
- for tr in group: recursive_group(tr, unwrap(allbufs[tr].st), tr, children, allbufs, realizes, reduce_for_op, descendants, cache={})
- return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
-
- def group_realizes(ctx:ScheduleContext, realizes:Dict[UOp, UOp]) -> List[List[UOp]]:
- """search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop"""
+ def group_realizes(sink:UOp, ctx:ScheduleContext) -> dict[UOp, UOp]:
+ # start by adding uops that always realize
+ sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
  # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
- reduce_for_op: Dict[UOp, UOp] = {}
- reduce_of_const: List[UOp] = []
- double_reduces: List[UOp] = []
+ reduce_for_op: dict[UOp, UOp] = {}
+ double_reduces: list[UOp] = []
  for r, r_uop in ctx.allbufs.items():
  if (r_uop:=uval(r_uop)).op is not Ops.REDUCE_AXIS: continue
- if FUSE_CONV_BW and r_uop.op is Ops.REDUCE_AXIS and uval((x:=r_uop.src[0]).base).op is r_uop.op and x.base is not x: double_reduces.append(r)
- if r in realizes: continue
- group: Dict[UOp, None] = {}
- recursive_group(r, unwrap(r_uop.st), r, ctx.children, ctx.allbufs, realizes, reduce_for_op, group, cache={})
+ if FUSE_CONV_BW and is_scheduled((x:=r_uop.src[0]).base) and uval(x.base).op is r_uop.op and x.base is not x: double_reduces.append(r)
+ if r in ctx.realizes: continue
+ group: dict[UOp, None] = {}
+ recursive_group(r, unwrap(r_uop.st), r, ctx.children, ctx.allbufs, ctx.realizes, reduce_for_op, group, cache={})
  # max one reduceop per kernel
  can_chase = all(tr not in reduce_for_op for tr in group)
  # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
  forced_realize = r in group
- if not forced_realize and len(group) > 1:
- group = get_isolated_children(r, reduce_for_op, ctx.children, ctx.allbufs, realizes, group)
+ # can only have one output
+ if not forced_realize and len(group) > 1: forced_realize = True
  # can only fuse assign if no other assign_target is used in the kernel
  if not forced_realize and any(x in ctx.assigns for x in group):
  parents = deque((r, *group))
  while parents and not forced_realize:
  if (p_uop:=ctx.allbufs.get(p:=parents.pop())) is None: continue
  if (p_uop:=uval(p_uop)).op is Ops.ASSIGN and p not in group: forced_realize, can_chase = True, False
- if p in realizes: continue
- parents.extend([x.base.src[0] for x in p_uop.src if x.base.op is Ops.VIEW and len(x.base.src) != 0])
+ if p in ctx.realizes: continue
+ parents.extend([x.base.buf_uop for x in p_uop.src if x.base.is_realized or (x.base.op is Ops.VIEW and len(x.base.src) != 0)])
  if forced_realize or not group:
  tr = r
  if can_chase:
@@ -309,86 +224,241 @@ def group_realizes(ctx:ScheduleContext, realizes:Dict[UOp, UOp]) -> List[List[UO
  if (tr_uop:=uval(ctx.allbufs[tr])).op is Ops.CAST and tr_uop.dtype.base.itemsize > tr_uop.src[0].dtype.base.itemsize:
  tr = tr_uop.src[0].base.buf_uop
  group = {tr: None}
- realizes[tr] = tr
+ ctx.realizes[tr] = tr
  reduce_for_op.update((tr, r) for tr in group)
- if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.WHERE: reduce_of_const.append(r)
+ if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.CONST:
+ # maybe fuse arange with its children
+ if len(flatten(ctx.children[tr] for tr in group)) != 0:
+ for tr in group: del ctx.realizes[tr]
  # fuse double reduces with no other child
  for reduceop in double_reduces:
  top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
- if len(ctx.children[top_reduce]) == 1: del realizes[top_reduce]
- # maybe fuse arange with its children
- for rbuf in reduce_of_const:
- group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
- if any(ctx.allbufs[tr].src[1].is_contiguous_base for tr in group): continue
- kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}}
- if len(kernel_children) == 0: continue
- for tr in group: del realizes[tr]
- # group BUFFER uops into kernels
- output_groups: DefaultDict[UOp, List[UOp]] = defaultdict(list)
- for ubuf in realizes: output_groups[reduce_for_op.get(ubuf, ubuf)].append(ubuf)
- return list(output_groups.values())
-
- # **** Schedule creation and BFS toposort
-
- def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> UOp:
- ctx[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), to_store)
- return UOp(Ops.LOAD, base.dtype, (b, st.to_uop()))
-
- def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> Optional[UOp]:
- base_shape = unwrap(base.st).shape
- st = unwrap(view.st)
- # fold simple pads
- if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(base_shape) and resolve(prod(base_shape) >= prod([y-x for x,y in m])):
- return None if can_pad(base) else realize(ctx, b, to_store, base).view(st)
- # early realize before expand
- if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, b, to_store, base).view(st)
- # otherwise safety check pads
- return None if (all(v.mask is None for v in st.views) or can_pad(base)) else realize(ctx, b, to_store, base).view(st)
+ if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
+ graph_rewrite(sink, break_sched, ctx)
+ return ctx.realizes
+
+ # break the SINK into stores
+
+ def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
+ # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
+ return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
+
+ def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
+ if (m:=ctx.ops_metadata.get(b)) is not None: ctx.ops_metadata[x] = m
+ if b not in ctx.realizes: return x # collapse BUFFER
+ ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
+ return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
+
+ break_sched = PatternMatcher([
+ # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
+ (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
+ (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
+ ])
 
- do_realize = PatternMatcher([
- # always realize meta ops
- (UPatSrc((Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta)), realize),
- # don't realize image to image casts
- (UPatSrc(Ops.CAST, src=(UPat(Ops.LOAD, name="x"),), dtype=dtypes.float).view(name="v"), lambda ctx,x,v,**kwargs: r.src[2].view(v.st)
- if (r:=ctx.get(b:=x.buf_uop)) is not None and r.op is Ops.STORE and isinstance(b.dtype, ImageDType) and r.src[2].op not in GroupOp.Meta else None),
- # realize before expand or unsafe pad ops
- (UPatSrc().view(name="view"), realize_view),
- # realize before COPY or BUFFER_VIEW
- (UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatSrc(), UPatSrc().view(name="view")),), name="root"),
- lambda ctx,root,view=None,**kwargs: root.replace(src=(realize(ctx,**kwargs) if view is None else realize(ctx,**kwargs).view(view.st),)),),
+ # **** convert Kernel to a ScheduleItem (for legacy reasons)
+
+ @dataclass(frozen=True)
+ class ScheduleItem:
+ ast: UOp
+ bufs: tuple[Buffer, ...]
+ metadata: tuple[Metadata, ...]
+ @property
+ def outputs(self) -> tuple[Buffer, ...]:
+ """Read/write or write only buffers in the schedule."""
+ return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs)
+ @property
+ def inputs(self) -> tuple[Buffer, ...]:
+ """Read only buffers in the schedule."""
+ return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs)
+ @functools.cached_property
+ def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
+
+ def kernel_to_si(k:UOp) -> ScheduleItem:
+ assert k.op is Ops.KERNEL, f"must be KERNEL {k}"
+ return ScheduleItem(k.arg.ast, tuple(u.buf_uop.buffer for u in k.src), k.arg.metadata)
+
+ # **** Kernel creation
+
+ @dataclass(frozen=True)
+ class Kernel:
+ ast: UOp
+ metadata: tuple[Metadata, ...]
+ def __repr__(self): return f"<Kernel {len(list(self.ast.toposort))} {self.ast.op} {self.metadata}>"
+
+ @dataclass(frozen=True)
+ class ScheduleItemContext:
+ var_vals: dict[Variable, int]
+ bufs: list[UOp] = field(default_factory=list)
+
+ def apply_swizzle(u:UOp) -> UOp:
+ with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
+
+ def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
+ input_st = ShapeTracker.from_shape(unwrap(src.st).shape)
+ tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
+ prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
+ strides = strides_for_shape(rshape)
+ nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
+ v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
+ # update input_st and axis
+ new_input_st = tmp + ShapeTracker(tuple(nv))
+ new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg)))
+ return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
+
+ def reduceop_view_right(r:UOp, v:UOp, src:UOp) -> UOp:
+ if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}")
+ output_shape = swizzle_st.reduce(r.axis_arg)
+ return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape))
+
+ def elementwise_view_right(root:UOp) -> UOp|None:
+ if len(swizzles:=[x for x in root.src if x.base is not x]) == 0: return None
+ assert all(x.base.st is not None for x in swizzles), f"found shapeless VIEW src in {root}"
+ assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
+ # push the swizzle from src to root
+ output_swizzle = swizzles[0]
+ new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)
+ ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
+ return ret.view(ShapeTracker.from_shape(output_swizzle.shape))
+
+ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
+ assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
+ assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
+ return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
+
+ # push VIEW to children
+ view_right = merge_views+PatternMatcher([
+ # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
+ (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
+ lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
+ # STORE is the last child, so we just merge the ShapeTrackers and store the base
+ (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
+ # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
+ (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
+ # REDUCE(src.view()) -> REDUCE(src).view()
+ (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src").view(name="v"),), name="r"), reduceop_view_right),
+ # ALU(src.view()) -> ALU(src).view()
+ (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), elementwise_view_right),
+ # double reduce op collapses to a single reduce op
+ (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
  ])
- break_sched = PatternMatcher([(UPatSrc(), lambda ctx,b,to_store,base: realize(ctx, b, to_store, base) if b in ctx else None),])
+
+ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
+ st = unwrap(x.st).simplify()
+ if any(x.op is Ops.BIND for x in st.vars()):
+ st, var_vals = st.unbind()
+ ctx.var_vals.update(var_vals)
+ return st.to_uop() if st != x.st else None
+
+ def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
+ ctx.bufs.append(x)
+ return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)
+
+ to_si = PatternMatcher([
+ # BUFFER -> DEFINE_GLOBAL
+ (UPat(Ops.BUFFER, name="x"), _append_buf),
+ # simplify and unbind the final VIEWs
+ (UPat(Ops.VIEW, name="x"), _append_st_vars),
+ # don't need SINK on COPY or BUFFER_VIEW
+ (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x")),)), lambda b,x: x.replace(src=(b, *x.src))),
+ # don't need contiguous or assign anymore
+ (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
+ (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
+ # don't need DEVICE anymore
+ (UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
+ # PRELOAD becomes LOAD
+ (UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
+ # once images are loaded they become the base dtype
+ (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
+ ])
+
+ def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
+ ctx[var.replace(src=())] = val.arg
+ return var
+ unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
+
+ def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> UOp:
+ # unbind_vars + push views to edges
+ sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=var_vals), view_right)
+ # remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL
+ ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(var_vals))
+ # deal with ASSIGN
+ if len(ctx.assigns) != 0:
+ assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer]
+ for x in list(sink.toposort)[::-1]:
+ # we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER
+ if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
+ # PRELOAD tells the toposort this kernel should run before ASSIGN
+ if x.op is Ops.PRELOAD:
+ assign_preloads[x.buf_uop] = None
+ # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
+ if x.buf_uop is pre.src[0].buf_uop and not (st:=x.st_arg).contiguous:
+ # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
+ if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass
+ # if it has a single view and it's equal when you shrink a contig, it's fine
+ elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass
+ # otherwise, it's not fine
+ else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+ +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+ # NOTE: we only add the metadata for fused tensors
+ metadata = tuple(dedup(m for x in pre.toposort if x.op is not Ops.BUFFER and (m:=ctx.ops_metadata.get(x)) is not None))
+ return UOp(Ops.KERNEL, src=tuple(si_ctx.bufs), arg=Kernel(ast, metadata))
+
+ # **** schedule creation and toposort
 
  @track_rewrites(named=True)
- def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
- if len(outs:=dedup(x.base for x in outs if x.realized is None and x.base.op is not Ops.CONST)) == 0: return [], {}
- for out in outs: out.forced_realize = True
- # create the big graph
- ctx = ScheduleContext()
- cache: Dict[LazyBuffer, UOp] = {}
- buffers: Dict[UOp, Buffer] = {}
- lazybufs: Dict[Buffer, LazyBuffer] = {}
- big_graph = UOp.sink(*(to_uop(x, ctx, buffers, lazybufs, cache) for x in outs))
+ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
+ tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
+ # tensors can become an existing buffer or simplify to a const, no ScheduleItem needed
+ becomes_map: dict[UOp, UOp] = {}
+ for k,v in tensor_map.items():
+ # NOOP
+ if k.base is v.base: continue
+ # NOTE: only the base tensors get a BUFFER UOp
+ if v.is_realized and k is k.base: becomes_map[k] = v.view(unwrap(k.st))
+ # otherwise if it simplified to a CONST the UOp just becomes that CONST
+ elif v.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
+
+ # we group the rest of UOps into ScheduleItems
+ buffer_map: dict[UOp, UOp] = {}
+ sink = add_buffers(tensor_map[big_sink], buffer_map, cache={})
  # get realizes
- realizes: Dict[UOp, UOp] = {}
- graph_rewrite(big_graph, do_realize, realizes)
- store_groups = group_realizes(ctx, realizes)
- # split realizes into small graphs
- graph_rewrite(big_graph, break_sched, realizes)
- sinks = [UOp.sink(*(realizes[u] for u in stores)) for stores in store_groups]
- # preschedule all realizes
- prescheduled: List[ScheduleItem] = []
- for sink in sinks:
- metadata = tuple({mx for x in sink.sparents if (x.op is Ops.STORE or is_scheduled(x)) and (mx:=ctx.ubuf_metadata.get(x.buf_uop))})
- ast, ast_ctx = full_ast_rewrite(sink, ctx.var_vals, ctx.assigns)
- prescheduled.append(ScheduleItem(ast, tuple(b for u in ast_ctx.bufs if (b:=buffers[u]).size != 0), metadata, tuple(ast_ctx.assign_preloads)))
- # do BFS
+ buf_tensors: dict[UOp, list[UOp]] = {}
+ ops_metadata: dict[UOp, Metadata] = {}
+ for k,v in tensor_map.items():
+ if (b:=buffer_map.get(v)) is not None:
+ buf_tensors.setdefault(b, []).append(k)
+ ops_metadata[b] = k.metadata
+ realize_map = group_realizes(sink, ctx:=ScheduleContext(ops_metadata))
+
+ # TODO: this should be the break between the "grouper" and the "linearizer"
+ # here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
+ # call into `def linearize_schedule(sched_sink:UOp) -> list[ScheduleItem]`
+
+ # create kernels + map buffers to realized tensors
+ sinks: list[UOp] = []
+ var_vals: dict[Variable, int] = {}
+ for buf_uop,store in realize_map.items():
+ assert store.op is Ops.STORE, f"expected a realized BUFFER to get a STORE {sink}"
+ sinks.append(schedule_uop(store.sink(), ctx, var_vals))
+ # can only schedule once
+ for tensor_uop in buf_tensors[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
+ # increment refcount for this buffer
+ buf_uop.buffer.ref(1)
+ sched_sink = UOp(Ops.SINK, src=tuple(sinks))
+ # display, TODO: this isn't a complete sched_sink yet
+ if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]))
+ type_verify(list(sched_sink.toposort), kernel_spec)
+
+ # convert kernels to ScheduleItem
+ prescheduled = [kernel_to_si(k) for k in sched_sink.src]
+ # add ScheduleItem children
+ # TODO: this should construct the graph directly from the sched_sink
  schedule_targets = {out:si for si in prescheduled for out in si.outputs}
- graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)
- in_degree: DefaultDict[ScheduleItem, int] = defaultdict(int)
+ graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list)
+ in_degree: defaultdict[ScheduleItem, int] = defaultdict(int)
  for si in prescheduled:
  # realize outputs before a parent is assigned to
- parents_assigns = dedup(xsi for x in si.assign_preloads if (xsi:=schedule_targets.get(buffers[x])) and xsi is not si)
+ parents_assigns = dedup(xsi for x in ctx.preloads[si.bufs[0]] if (xsi:=schedule_targets.get(x.buffer)) and xsi is not si)
  for assign in parents_assigns:
  graph[si].append(assign)
  in_degree[assign] += 1
@@ -397,23 +467,20 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
  for x in scheduled_parents:
  graph[x].append(si)
  in_degree[si] += 1
+
+ # do BFS
  queue = deque(si for si in prescheduled if in_degree[si] == 0)
- schedule: List[ScheduleItem] = []
+ schedule: list[ScheduleItem] = []
  while queue:
  schedule.append(si:=queue.popleft())
- for b in si.outputs: del lazybufs[b].srcs # can only schedule once
- if (m:=BUF_LIMIT.get(device:=si.outputs[0].device)) and len(si.bufs) >= m:
- if DEBUG >= 3: print(si)
- raise RuntimeError(f"Kernel for {si.metadata} exceeded the {m} buffer count limit for {device} with {len(si.bufs)} buffers.")
  for x in graph[si]:
  in_degree[x] -= 1
  if in_degree[x] == 0: queue.append(x)
  # confirm everything was scheduled correctly
  if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
- return schedule, ctx.var_vals
-
- def create_schedule(outs:List[LazyBuffer]) -> List[ScheduleItem]:
- schedule, var_vals = create_schedule_with_vars(outs)
- assert len(var_vals) == 0
- return schedule
+ # capture process replay
+ if CAPTURE_PROCESS_REPLAY:
+ with Context(PICKLE_BUFFERS=0):
+ diskcache_put("schedule_process_replay", str(big_sink.key), (big_sink, ContextVar._cache, [x.ast for x in schedule]))
+ return schedule, var_vals, becomes_map
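
For orientation, the stub sketch below restates the scheduler entry-point change visible in the hunks above. Only the signatures are taken from this diff; the version-suffixed function names, the quoted forward references, and the empty bodies are placeholders for illustration, not tinygrad code.

from typing import Dict, List, Tuple

# 0.10.0: scheduling starts from a list of LazyBuffers; create_schedule() is a
# thin wrapper that asserts no unbound Variables remain
def create_schedule_with_vars_0_10_0(outs: List["LazyBuffer"]) -> Tuple[List["ScheduleItem"], Dict["Variable", int]]: ...
def create_schedule_0_10_0(outs: List["LazyBuffer"]) -> List["ScheduleItem"]: ...

# 0.10.1: scheduling starts from one big UOp SINK and additionally returns a
# becomes_map that rewrites tensor UOps to their realized BUFFER (or collapsed
# CONST) form; the create_schedule() wrapper and ScheduleItem.assign_preloads
# are gone in this version
def create_schedule_with_vars_0_10_1(big_sink: "UOp") -> Tuple[List["ScheduleItem"], Dict["Variable", int], Dict["UOp", "UOp"]]: ...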