tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (88)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
@@ -1,245 +1,122 @@
- import sys, atexit, functools, itertools
+ import sys, functools, atexit, pickle
  from collections import defaultdict, deque
- from dataclasses import dataclass, field
- from typing import Set, Tuple, List, Dict, Optional, DefaultDict, cast
- from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, sint
- from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
- from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG
- from tinygrad.dtype import ImageDType, dtypes
+ from dataclasses import dataclass
+ from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
+ from tinygrad.ops import can_pad, identity_element, resolve, view_left, merge_views
+ from tinygrad.codegen.symbolic import symbolic_simple
+ from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten, getenv
+ from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND
+ from tinygrad.dtype import ImageDType
  from tinygrad.shape.shapetracker import ShapeTracker
  from tinygrad.shape.view import View, strides_for_shape
- from tinygrad.engine.lazy import LazyBuffer
  from tinygrad.device import Buffer
+ from tinygrad.spec import type_verify, kernel_spec

  # creation can recurse a lot
  sys.setrecursionlimit(10000)

- BUF_LIMIT = {"METAL":32}
-
- # **** ScheduleItem return type
-
- @dataclass(frozen=True)
- class ScheduleItem:
- ast: UOp
- bufs: Tuple[Buffer, ...]
- metadata: Tuple[Metadata, ...]
- assign_preloads: Tuple[UOp, ...]
- @property
- def outputs(self) -> Tuple[Buffer, ...]:
- """Read/write or write only buffers in the schedule."""
- return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs)
- @property
- def inputs(self) -> Tuple[Buffer, ...]:
- """Read only buffers in the schedule."""
- return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs)
- @functools.cached_property
- def output_idxs(self) -> Tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
-
- # **** small wrapper for LazyBuffer -> UOp
-
- def UPatSrc(*args, **kwargs): return UPat(Ops.VIEW, src=(UPat.var("b"), UPat(*args, **{**kwargs, "name":"to_store"})), name="base")
- @functools.lru_cache(None)
- def is_scheduled(u:UOp): return u.op is Ops.VIEW and len(u.src) == 2
-
- @dataclass(frozen=True)
- class ScheduleContext:
- ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps BUFFER uops to Metadata
- var_vals: Dict[Variable, int] = field(default_factory=dict) # this maps a BIND's DEFINE_VAR to its value
- assigns: Set[UOp] = field(default_factory=set) # this holds all the BUFFER uops we ASSIGN to in this schedule
- allbufs: Dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op
- children: DefaultDict[UOp, Dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
-
- def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], lazybufs:Dict[Buffer, LazyBuffer], cache:Dict[LazyBuffer, UOp]) -> UOp:
- if (r:=cache.get(buf)) is not None: return r
- if buf is not buf.base:
- cache[buf] = ret = to_uop(buf.base, ctx, buffers, lazybufs, cache).view(buf.st)
- return ret
+ # **** schedule simplifier
+
+ def simplify_stride0_reduce(reduce:UOp, x:UOp):
+ # must be unmasked (NOTE: can be relaxed if not masked on stride 0 axis)
+ if any(v.mask is not None for v in unwrap(x.st).views): return None
+ # must have all stride 0 in the relevant axis (NOTE: can do partial)
+ if not all(unwrap(x.st).views[-1].strides[axis] == 0 for axis in reduce.arg[1]) or not all_int(x.shape): return None
+ prshape = prod(x.shape[i] for i in reduce.arg[1])
+ ret = x.shrink(tuple((0,s) if i not in reduce.arg[1] else (0,1) for i,s in enumerate(x.shape)))
+ match reduce.arg[0]:
+ case Ops.ADD: return ret*prshape
+ case Ops.MUL: return ret.pow(prshape)
+ case Ops.MAX: return ret # NOTE: Ops.MAX is passthrough
+
+ def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
+ if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
+ def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
+ new_src = list(alu.src)
+ for i,s in enumerate(alu.src):
+ if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
+ if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
+
+ sym = symbolic_simple+PatternMatcher([
+ # UOp with size 0 is zero
+ (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
+ and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
+ # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
+ (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
+ # reduce of size 0 is the identity element
+ (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
+ lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
+ # reduce on stride 0 is collapsed
+ (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_stride0_reduce),
+ # COPY(CONST) creates a new CONST on the destination device
+ (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.arg)),
+ # no COPY to same device, except clone (arg is True)
+ (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
+ lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
+ # remove cast to image when it's already a contiguous image
+ (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"))),)),
+ lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
  # make things that can't be images not images
- if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
- not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
- if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to {buf.dtype.base}")
- # hack the underlying buffer too
- buf.dtype = buf.buffer.dtype = buf.dtype.base
- assert not buf.is_realized, "can't fixup allocated buffer"
- buf.buffer.options = None
- dtype = buf.dtype if buf.op in GroupOp.Meta else buf.dtype.base
- # consts are always fused and generated
- if buf.op is Ops.CONST:
- if isinstance(val:=buf.arg, UOp): ctx.var_vals.update([val.unbind()])
- return UOp(Ops.VALID, dtypes.bool, (buf.st.to_uop(),)).where(UOp.const(dtype, val), 0)
- # everything else is a VIEW of BUFFER (with an optional op)
- if buf.is_realized:
- buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer
- op = None
- elif buf.op is Ops.ASSIGN:
- target, new_val = [to_uop(x, ctx, buffers, lazybufs, cache) for x in buf.srcs]
- ctx.assigns.add(ubuf:=target.buf_uop)
- op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg)
- else:
- buffers[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(buffers))] = buf.buffer
- op = UOp(cast(Ops, buf.op), dtype, tuple(to_uop(x, ctx, buffers, lazybufs, cache) for x in buf.srcs),
- None if buf.op in {Ops.CAST, Ops.BITCAST} else buf.arg)
- cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
- if op is not None:
- if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
- lazybufs[buf.buffer] = buf
- for x in op.src:
- if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[ubuf] = None
- ctx.allbufs[ubuf] = ret
- return ret
-
- # **** AST graph rewrite
-
- # ** helpers for doing movementops on uops
-
- def apply_swizzle(u:UOp, arg:ShapeTracker) -> UOp:
- with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u.view(arg), view_left)
-
- def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
- permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis)+axis
- tmp = input_st.permute(permute_axis)
- return tmp, tmp.shape[-len(axis):]
-
- # ** movementops rewrite rules
-
- def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
- tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(src.st).shape), r.axis_arg)
- prshape = prod(rshape)
- strides = strides_for_shape(rshape)
- nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
- v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
- # update input_st and axis
- new_input_st = tmp + ShapeTracker(tuple(nv))
- _, new_rshape = permute_reduce(new_input_st, r.axis_arg)
- new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
- return apply_swizzle(src, new_input_st).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
-
- def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp, src:UOp) -> UOp:
- swizzle_st, src_st = unwrap(swizzle.st), unwrap(src.st)
- assert swizzle_st.contiguous, "can't push a non contiguous VIEW down to STORE"
- assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE"
- output_shape = swizzle_st.reduce(root.axis_arg)
- new_axis = tuple(i for i,(s,u) in enumerate(zip(src_st.shape, output_shape)) if s != u)
- return swizzle.src[0].r(root.arg[0], new_axis).view(ShapeTracker.from_shape(output_shape))
-
- def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
- swizzles = [x for x in root.src if x.base is not x]
- if len(swizzles) == 0: return None
- swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
- assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
- new_shape, new_input_shape = swizzle_shapes[0]
- ret = root.replace(src=tuple(x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src))
- return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
-
- def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
- assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
- assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
- return first_reduce.src[0].r(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)
-
- merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
-
- # push VIEW to loads
- view_left = merge_views+PatternMatcher([
- # VIEW before elementwise ops
- (UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
- # early merge VIEW buffer ops
- (UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.arg+v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
+ (UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType)
+ and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),
+ # remove contiguous if we can just view the buffer
+ (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
+ lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
+ # contiguous/buffer/copy is already contiguous
+ (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY)),)), lambda root: root.src[0]),
+ # support for using a contiguous permuted view instead of the parent view if one exists
+ (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
+ (UPat(GroupOp.ALU, name="alu"), replace_contiguous),
+ # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
+ (UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="root"),
+ lambda root: root.replace(op=Ops.BUFFER_VIEW) if isinstance(root.device, str) and root.device.startswith("DISK") else None),
  ])

- # push VIEW to stores
- view_right = merge_views+PatternMatcher([
- # ASSIGN can override st
- (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.ASSIGN, name="a"))),
- lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None),
- # non contiguous VIEW on a reduce creates a new VIEW
- (UPat(Ops.REDUCE_AXIS, src=UPat.var("src"), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
- # push a VIEW down to STORE, through a reduce (ONLY reshapes)
- (UPat(Ops.REDUCE_AXIS, src=(UPat.var(name="src").view(name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
- # push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes)
- (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise),
- (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
+ remove_movement_ops = merge_views+PatternMatcher([
+ # NOTE: movement ops are always applied to base
+ (UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
+ # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
+ (UPat(Ops.VIEW, name="view"),
+ lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
  ])

- # ** ScheduleItem context builder
+ # **** UOp realization

  @dataclass(frozen=True)
- class ScheduleItemContext:
- var_vals: Dict[Variable, int]
- assigned: Set[UOp]
- sts: Set[ShapeTracker] = field(default_factory=set)
- bufs: List[UOp] = field(default_factory=list)
- assign_preloads: List[UOp] = field(default_factory=list)
-
- def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
- if (st:=unwrap(x.st)) in ctx.sts: return None
- st, var_vals = st.simplify().unbind()
- ctx.var_vals.update(var_vals)
- ctx.sts.add(st)
- return st.to_uop() if st != x.st else None
-
- def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
- ctx.bufs.append(x)
- return UOp(Ops.DEFINE_GLOBAL, x.dtype, (), len(ctx.bufs)-1)
- append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
+ class GrouperContext:
+ assigns: dict[UOp, UOp] # maps realized buffers to assigns
+ realizes: dict[UOp, None] # all the simplified tensor uops we realize
+ children: defaultdict[UOp, dict[UOp, None]] # children graph of tensor uops

- def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
- if b in ctx.assigned: ctx.assign_preloads.append(b)
- return x.replace(op=Ops.LOAD)
+ def realize(ctx:GrouperContext, tr:UOp) -> None: ctx.realizes[tr] = None

- to_si = PatternMatcher([
- (UPat(Ops.VIEW, name="x"), _append_st_vars),
- (UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),
- (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,b,x: x.replace(src=(b, *x.src))),
- ])
-
- # ** fusion
+ def realize_before_view(ctx:GrouperContext, view:UOp, src:UOp) -> None:
+ st = unwrap(view.st)
+ # fold simple pads
+ if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
+ return None if can_pad(src, ctx.realizes, cache=dict()) else realize(ctx, src)
+ # early realize before expand
+ if resolve(prod(src.shape) < prod(st.shape)) and not DONT_REALIZE_EXPAND: return realize(ctx, src)
+ # otherwise safety check pads
+ return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, cache=dict())) else realize(ctx, src)

- lazy = PatternMatcher([
- (UPatSrc(), lambda ctx,to_store,**kwargs: to_store),
- (UPat(Ops.BUFFER, name="b").view(name="view"), lambda ctx,b,view: UOp(Ops.PRELOAD, view.dtype, (b, view.st.to_uop()))),
- (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
+ do_realize = PatternMatcher([
+ # always realize SINK parents
+ (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x.base, None) for x in s.src if x.base.op not in {Ops.CONST, Ops.BIND, Ops.BUFFER})),
+ # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
+ (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
+ # realize before expand or unsafe pad ops
+ (UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}, name="src"),)), realize_before_view),
+ # realize before COPY
+ (UPat(Ops.COPY, src=(UPat(), UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW}, name="tr"))), realize),
  ])

- multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.get(b)),])
-
- def full_ast_rewrite(pre:UOp, var_vals:Dict[Variable, int], assigned:Set[UOp]) -> Tuple[UOp, ScheduleItemContext]:
- # fuse and fold store -> loads
- sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, {x.buf_uop:x.src[2] for x in pre.src})
- # assert cyclic dependency
- for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {Ops.PRELOAD,Ops.LOAD} and x.buf_uop in assigned), key=lambda x:x.buf_uop):
- if not all_same([x.op for x in ops]):
- raise RuntimeError(f"cycle detected in kernel.\nhelp: use .contiguous() to break the part loading pre-assign {b} into a different kernel.")
- # do movementops
- sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
- # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
- if len(assign_targets:=[x.buf_uop for x in sink.sparents if x.op is Ops.ASSIGN]) != 0:
- if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
- and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is Ops.PRELOAD and x.buf_uop in assign_targets):
- raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
- +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
- # convert to AST
- sink = graph_rewrite(graph_rewrite(sink, to_si, ctx:=ScheduleItemContext(var_vals, assigned)), append_bufs, ctx)
- if getenv("RUN_PROCESS_REPLAY"): PROCESS_REPLAY_CAPTURE.append(((pre, var_vals, assigned), sink))
- return sink, ctx
-
- PROCESS_REPLAY_CAPTURE: List[Tuple[Tuple, UOp]] = []
- if getenv("RUN_PROCESS_REPLAY"):
- @atexit.register
- def save_process_replay():
- for x,ret in PROCESS_REPLAY_CAPTURE: diskcache_put("schedule_process_replay", str(x[0].key), (x, {}, ret))
-
- # **** Schedule grouping
-
- def uval(u:UOp) -> UOp:
- assert is_scheduled(u), f"must be a scheduled op {u}"
- return to_store.src[0] if (to_store:=u.src[1]).is_contiguous_base else to_store
-
- def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp], realizes:Dict[UOp, UOp],
- reduce_for_op:Dict[UOp, UOp], group:Dict[UOp, None], cache:Dict[Tuple[UOp, ShapeTracker], None]) -> None:
+ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
+ reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
  """recursively search the uop for groupable children, realize the UOp if a child can't group"""
  if (tr, st) in cache: return
  cache.setdefault((tr, st))
- rsize = unwrap(allbufs[r].st).size
+ rsize = unwrap(r.st).size
  if tr in realizes and tr is not r:
  # can only fuse contiguous
  # max one reduceop per kernel
@@ -247,173 +124,335 @@ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:DefaultDict[UOp, Di
  return group.setdefault(tr)
  for tr_next in children[tr]:
  # max one reduceop per kernel
- if (tr_next_uop:=uval(allbufs[tr_next]).base).op is Ops.REDUCE_AXIS: return group.setdefault(r)
+ if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r)
  # can only fuse contiguous
- if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
- recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)
-
- def get_isolated_children(r:UOp, reduce_for_op:Dict[UOp, UOp], children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp],
- realizes:Dict[UOp, UOp], group:Dict[UOp, None]) -> Dict[UOp, None]:
- rc_parents, cache = deque(group), set()
- while rc_parents:
- if (p:=uval(allbufs[rc_parents.pop()])) in cache: continue
- cache.add(p)
- # max one reduceop per kernel
- if p.op is Ops.REDUCE_AXIS: return {}
- rc_parents.extend(x.base.buf_uop for x in p.src if is_scheduled(x.base) and x.base.buf_uop is not r)
- # search descendants of the reduceop that can cleanly group
- descendants: Dict[UOp, None] = {}
- for tr in group: recursive_group(tr, unwrap(allbufs[tr].st), tr, children, allbufs, realizes, reduce_for_op, descendants, cache={})
- return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
-
- def group_realizes(ctx:ScheduleContext, realizes:Dict[UOp, UOp]) -> List[List[UOp]]:
- """search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop"""
+ if len(st_childs:=dedup(unwrap(x.st) for x in tr_next.src if x.base == tr)) > 1: return group.setdefault(r)
+ recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache)
+
+ def append_uop(ctx:GrouperContext, u:UOp) -> None:
+ if u.op is Ops.ASSIGN: ctx.assigns[u.buf_uop] = u
+ for s in u.src: ctx.children[s.base][u] = None
+ create_ctx = PatternMatcher([(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}, name="u"), append_uop)])
+
+ def group_realizes(sink:UOp) -> dict[UOp, None]:
+ # start by adding uops that always realize
+ sink = graph_rewrite(sink, do_realize+create_ctx, ctx:=GrouperContext({}, {}, defaultdict(dict)))
  # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
- reduce_for_op: Dict[UOp, UOp] = {}
- reduce_of_const: List[UOp] = []
- double_reduces: List[UOp] = []
- for r, r_uop in ctx.allbufs.items():
- if (r_uop:=uval(r_uop)).op is not Ops.REDUCE_AXIS: continue
- if FUSE_CONV_BW and r_uop.op is Ops.REDUCE_AXIS and uval((x:=r_uop.src[0]).base).op is r_uop.op and x.base is not x: double_reduces.append(r)
- if r in realizes: continue
- group: Dict[UOp, None] = {}
- recursive_group(r, unwrap(r_uop.st), r, ctx.children, ctx.allbufs, realizes, reduce_for_op, group, cache={})
+ reduce_for_op: dict[UOp, UOp] = {}
+ double_reduces: list[UOp] = []
+ for r in sink.toposort:
+ if r.op is not Ops.REDUCE_AXIS: continue
+ if FUSE_CONV_BW and r.src[0].base.op is Ops.REDUCE_AXIS and r.src[0] is not r.src[0].base: double_reduces.append(r)
+ if r in ctx.realizes: continue
+ group: dict[UOp, None] = {}
+ recursive_group(r, unwrap(r.st), r, ctx.children, ctx.realizes, reduce_for_op, group, cache={})
  # max one reduceop per kernel
  can_chase = all(tr not in reduce_for_op for tr in group)
  # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
  forced_realize = r in group
- if not forced_realize and len(group) > 1:
- group = get_isolated_children(r, reduce_for_op, ctx.children, ctx.allbufs, realizes, group)
+ # can only have one output
+ if not forced_realize and len(group) > 1: forced_realize = True
  # can only fuse assign if no other assign_target is used in the kernel
- if not forced_realize and any(x in ctx.assigns for x in group):
+ if not forced_realize and any(x.op is Ops.ASSIGN for x in group):
  parents = deque((r, *group))
  while parents and not forced_realize:
- if (p_uop:=ctx.allbufs.get(p:=parents.pop())) is None: continue
- if (p_uop:=uval(p_uop)).op is Ops.ASSIGN and p not in group: forced_realize, can_chase = True, False
- if p in realizes: continue
- parents.extend([x.base.src[0] for x in p_uop.src if x.base.op is Ops.VIEW and len(x.base.src) != 0])
+ p = parents.pop().base
+ if (assign:=ctx.assigns.get(p)) is not None and assign not in group: forced_realize, can_chase = True, False
+ if p in ctx.realizes: continue
+ parents.extend(p.src)
  if forced_realize or not group:
  tr = r
  if can_chase:
  # can chase this down to contiguous children
- st = unwrap(r_uop.st)
+ st = unwrap(tr.st)
  while len(ctx.children[tr]) == 1:
- tr_next_uop = uval(ctx.allbufs[(tr_next:=next(iter(ctx.children[tr])))])
- st_childs = dedup([unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop is tr])
+ tr_next = next(iter(ctx.children[tr]))
+ st_childs = dedup(unwrap(s.st) for s in tr_next.src if s.base is tr)
  if len(st_childs) > 1: break
  if st.size != st_childs[0].size: break
  st = st + st_childs[0]
- if not st.contiguous or tr_next_uop.op is Ops.REDUCE_AXIS: break
+ if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break
  tr = tr_next
  # don't cast to higher size before store (tr cannot be realized if forced_realize)
- if (tr_uop:=uval(ctx.allbufs[tr])).op is Ops.CAST and tr_uop.dtype.base.itemsize > tr_uop.src[0].dtype.base.itemsize:
- tr = tr_uop.src[0].base.buf_uop
+ if tr.op is Ops.CAST and tr.dtype.itemsize > tr.src[0].dtype.itemsize:
+ tr = tr.src[0].base
  group = {tr: None}
- realizes[tr] = tr
+ ctx.realizes[tr] = None
  reduce_for_op.update((tr, r) for tr in group)
- if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.WHERE: reduce_of_const.append(r)
+ if FUSE_ARANGE and r.arg[0] is Ops.ADD and r.src[0].base.op is Ops.CONST:
+ # maybe fuse arange with its children
+ if len(flatten(ctx.children[tr] for tr in group)) != 0:
+ for tr in group: del ctx.realizes[tr]
  # fuse double reduces with no other child
  for reduceop in double_reduces:
- top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
- if len(ctx.children[top_reduce]) == 1: del realizes[top_reduce]
- # maybe fuse arange with its children
- for rbuf in reduce_of_const:
- group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
- if any(ctx.allbufs[tr].src[1].is_contiguous_base for tr in group): continue
- kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}}
- if len(kernel_children) == 0: continue
- for tr in group: del realizes[tr]
- # group BUFFER uops into kernels
- output_groups: DefaultDict[UOp, List[UOp]] = defaultdict(list)
- for ubuf in realizes: output_groups[reduce_for_op.get(ubuf, ubuf)].append(ubuf)
- return list(output_groups.values())
-
- # **** Schedule creation and BFS toposort
-
- def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> UOp:
- ctx[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), to_store)
- return UOp(Ops.LOAD, base.dtype, (b, st.to_uop()))
-
- def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> Optional[UOp]:
- base_shape = unwrap(base.st).shape
- st = unwrap(view.st)
- # fold simple pads
- if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(base_shape) and resolve(prod(base_shape) >= prod([y-x for x,y in m])):
- return None if can_pad(base) else realize(ctx, b, to_store, base).view(st)
- # early realize before expand
- if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, b, to_store, base).view(st)
- # otherwise safety check pads
- return None if (all(v.mask is None for v in st.views) or can_pad(base)) else realize(ctx, b, to_store, base).view(st)
+ top_reduce = reduceop.src[0].base
+ if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
+ return ctx.realizes

- do_realize = PatternMatcher([
- # always realize meta ops
- (UPatSrc((Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta)), realize),
- # don't realize image to image casts
- (UPatSrc(Ops.CAST, src=(UPat(Ops.LOAD, name="x"),), dtype=dtypes.float).view(name="v"), lambda ctx,x,v,**kwargs: r.src[2].view(v.st)
- if (r:=ctx.get(b:=x.buf_uop)) is not None and r.op is Ops.STORE and isinstance(b.dtype, ImageDType) and r.src[2].op not in GroupOp.Meta else None),
- # realize before expand or unsafe pad ops
- (UPatSrc().view(name="view"), realize_view),
- # realize before COPY or BUFFER_VIEW
- (UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatSrc(), UPatSrc().view(name="view")),), name="root"),
- lambda ctx,root,view=None,**kwargs: root.replace(src=(realize(ctx,**kwargs) if view is None else realize(ctx,**kwargs).view(view.st),)),),
+ # break the SINK into kernels
+
+ @dataclass(frozen=True)
+ class Kernel:
+ ast: UOp
+ metadata: tuple[Metadata, ...]
+ def __repr__(self): return f"<Kernel {len(list(self.ast.toposort))} {self.ast.op} {self.metadata}>"
+
+ @dataclass(frozen=True)
+ class KernelContext:
+ realizes: dict[UOp, None]
+ ops_metadata: dict[UOp, Metadata]
+
+ def create_kernel(ctx:KernelContext, x:UOp):
+ if x not in ctx.realizes: return None
+ assert isinstance(x.device, str), f"buf device in kernel must be string {x.device}"
+ b = x.buf_uop if x.op is Ops.ASSIGN else UOp.new_buffer(x.device, x.size, x.dtype)
+ output_st = ShapeTracker.from_shape(x.shape)
+ # KERNEL nodes become: ASSIGN(VIEW(BUFFER), KERNEL)
+ # TODO: this should be ASSIGN(BUFFER, KERNEL) followed by the output ShapeTracker
+ return b.view(output_st).assign(UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x, (m,) if (m:=ctx.ops_metadata.get(x)) else ())))
+
+ def append_to_kernel(ctx:KernelContext, x:UOp):
+ new_srcs: list[UOp] = []
+ new_metadata: dict[Metadata, None] = dict.fromkeys(x.arg.metadata)
+ for s in x.src:
+ if s.op is Ops.BUFFER or (s.op is Ops.ASSIGN and s.src[1].op is Ops.KERNEL) or s in ctx.realizes: new_srcs.append(s)
+ else:
+ new_srcs.extend(s.src)
+ if (m:=ctx.ops_metadata.get(s)) is not None: new_metadata[m] = None
+ return x.replace(src=n, arg=Kernel(x.arg.ast, tuple(new_metadata))) if (n:=tuple(dedup(new_srcs))) != x.src else None
+
+ create_kernels = merge_views+PatternMatcher([
+ (UPat(GroupOp.All-{Ops.KERNEL, Ops.BUFFER}, name="x"), create_kernel),
+ (UPat(Ops.KERNEL, name="x"), append_to_kernel),
+ ])
+
+ # **** convert Kernel to a ScheduleItem (for legacy reasons)
+
+ @dataclass(frozen=True)
+ class ScheduleItem:
+ ast: UOp
+ bufs: tuple[Buffer, ...]
+ metadata: tuple[Metadata, ...]
+ @property
+ def outputs(self) -> tuple[Buffer, ...]:
+ """Read/write or write only buffers in the schedule."""
+ return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs)
+ @property
+ def inputs(self) -> tuple[Buffer, ...]:
+ """Read only buffers in the schedule."""
+ return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs)
+ @functools.cached_property
+ def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
+
+ # **** Kernel creation
+
+ def apply_swizzle(u:UOp) -> UOp:
+ with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
+
+ def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
+ input_st = ShapeTracker.from_shape(unwrap(src.st).shape)
+ tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
+ prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
+ strides = strides_for_shape(rshape)
+ nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
+ v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
+ # update input_st and axis
+ new_input_st = tmp + ShapeTracker(tuple(nv))
+ new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg)))
+ return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
+
+ def reduceop_view_right(r:UOp, v:UOp, src:UOp) -> UOp:
+ if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}")
+ output_shape = swizzle_st.reduce(r.axis_arg)
+ return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape))
+
+ def elementwise_view_right(root:UOp) -> UOp|None:
+ if len(swizzles:=[x for x in root.src if x.base is not x]) == 0: return None
+ assert all(x.base.st is not None for x in swizzles), f"found shapeless VIEW src in {root}"
+ assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
+ # push the swizzle from src to root
+ output_swizzle = swizzles[0]
+ new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)
+ ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
+ return ret.view(ShapeTracker.from_shape(output_swizzle.shape))
+
+ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
+ assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
+ assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
+ return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
+
+ # push VIEW to children
+ view_right = merge_views+PatternMatcher([
+ # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
+ (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
+ lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
+ # STORE is the last child, so we just merge the ShapeTrackers and store the base
+ (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
+ # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
+ (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
+ # REDUCE(src.view()) -> REDUCE(src).view()
+ (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src").view(name="v"),), name="r"), reduceop_view_right),
+ # ALU(src.view()) -> ALU(src).view()
+ (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), elementwise_view_right),
+ # double reduce op collapses to a single reduce op
+ (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
  ])
- break_sched = PatternMatcher([(UPatSrc(), lambda ctx,b,to_store,base: realize(ctx, b, to_store, base) if b in ctx else None),])
+
+ def _append_st_vars(ctx:dict[Variable, int], x:UOp) -> UOp|None:
+ st = unwrap(x.st).simplify()
+ if any(x.op is Ops.BIND for x in st.vars()):
+ st, var_vals = st.unbind()
+ ctx.update(var_vals)
+ return st.to_uop() if st != x.st else None
+
+ def check_load_st(glbl:UOp, view:UOp):
+ if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
+ # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
+ if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
+ # if it has a single view and it's equal when you shrink a contig, it's fine
+ if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
+ # otherwise, it's not fine
+ raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+ +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+
+ fix_kernel_ops = PatternMatcher([
+ # BIND in shapetracker becomes DEFINE_VAR
+ (UPat(Ops.VIEW, name="x"), _append_st_vars),
+ # remove CONTIGUOUS/ASSIGN/DEVICE
+ (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
+ (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
+ (UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
+ # no ImageDType after load
+ (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
+ # if this kernel also assigns to the loaded buffer, ensure we can index it correctly
+ (UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
+ ])
+
+ def load_buf(ctx:list[UOp], x:UOp):
+ if x not in ctx: ctx.append(x)
+ return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), unwrap(x.st).to_uop()))
+
+ add_buffer_ops = PatternMatcher([
+ # LOAD
+ (UPat(Ops.BUFFER, name="x"), load_buf),
+ # STORE (except for COPY/BUFFER_VIEW)
+ (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
+ (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
+ lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
+ ])
+
+ def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
+ ctx[var.replace(src=())] = val.arg
+ return var
+ unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
+
+ def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
+ assert sink.op is Ops.ASSIGN and sink.src[1].op is Ops.KERNEL, f"{sink} must be ASSIGN"
+ # substitute kernel sources for the target buffer
+ ast = sink.src[1].arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in sink.src[1].src if s.op is Ops.ASSIGN}).sink()
+ # add buffer ops
+ ast = graph_rewrite(ast, add_buffer_ops, bufs:=[sink.buf_uop], bottom_up=True)
+ # unbind_vars + push views to edges
+ ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
+ # fix_kernel_ops
+ ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
+ # create subbuffer
+ if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = bufs[1].buffer.view(ast.size, ast.dtype, (x:=ast.src[0]).st_arg.views[0].offset*x.dtype.itemsize)
+ return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), sink.src[1].arg.metadata)
+
+ PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
+ if CAPTURE_PROCESS_REPLAY:
+ @atexit.register
+ def save_process_replay():
+ for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
+
+ # **** schedule creation and toposort

  @track_rewrites(named=True)
- def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
- if len(outs:=dedup(x.base for x in outs if x.realized is None and x.base.op is not Ops.CONST)) == 0: return [], {}
- for out in outs: out.forced_realize = True
- # create the big graph
- ctx = ScheduleContext()
- cache: Dict[LazyBuffer, UOp] = {}
- buffers: Dict[UOp, Buffer] = {}
- lazybufs: Dict[Buffer, LazyBuffer] = {}
- big_graph = UOp.sink(*(to_uop(x, ctx, buffers, lazybufs, cache) for x in outs))
- # get realizes
- realizes: Dict[UOp, UOp] = {}
- graph_rewrite(big_graph, do_realize, realizes)
- store_groups = group_realizes(ctx, realizes)
- # split realizes into small graphs
- graph_rewrite(big_graph, break_sched, realizes)
- sinks = [UOp.sink(*(realizes[u] for u in stores)) for stores in store_groups]
- # preschedule all realizes
- prescheduled: List[ScheduleItem] = []
- for sink in sinks:
- metadata = tuple({mx for x in sink.sparents if (x.op is Ops.STORE or is_scheduled(x)) and (mx:=ctx.ubuf_metadata.get(x.buf_uop))})
- ast, ast_ctx = full_ast_rewrite(sink, ctx.var_vals, ctx.assigns)
- prescheduled.append(ScheduleItem(ast, tuple(b for u in ast_ctx.bufs if (b:=buffers[u]).size != 0), metadata, tuple(ast_ctx.assign_preloads)))
- # do BFS
- schedule_targets = {out:si for si in prescheduled for out in si.outputs}
- graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)
- in_degree: DefaultDict[ScheduleItem, int] = defaultdict(int)
- for si in prescheduled:
- # realize outputs before a parent is assigned to
- parents_assigns = dedup(xsi for x in si.assign_preloads if (xsi:=schedule_targets.get(buffers[x])) and xsi is not si)
- for assign in parents_assigns:
- graph[si].append(assign)
- in_degree[assign] += 1
- # realize outputs after all parents are realized
- scheduled_parents = dedup(xsi for x in si.inputs if (xsi:=schedule_targets.get(x)) is not None and xsi not in parents_assigns)
- for x in scheduled_parents:
- graph[x].append(si)
- in_degree[si] += 1
- queue = deque(si for si in prescheduled if in_degree[si] == 0)
- schedule: List[ScheduleItem] = []
+ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
+ # remove_movement_ops + sym
+ tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
+
+ # display the cleaned up tensor graph
+ if getenv("VIZ"): graph_rewrite(tensor_map[big_sink], PatternMatcher([]), name="View Tensor Graph")
+
+ # do_realize + group_realizes
+ sink = tensor_map[big_sink]
+ realize_map = group_realizes(sink)
+
+ # map tensors to new uops
+ becomes_map: dict[UOp, UOp] = {}
+ rev_tensor_map: dict[UOp, list[UOp]] = {}
+ ops_metadata: dict[UOp, Metadata] = {}
+ for k,v in tensor_map.items():
+ rev_tensor_map.setdefault(v, []).append(k)
+ if k is v: continue
+ if v.base.op is Ops.BUFFER:
+ # VIEW isn't a valid tensor uop, we need to backtrack to the movement op that created it
+ if v.op is Ops.VIEW:
+ mop = [x for x in k.toposort if (xs:=tensor_map[x]).base is v.base and xs.st == v.st][0]
+ if k is not mop: becomes_map[k] = mop
+ else: becomes_map[k] = v
+ elif v.base.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
+ # if we're not realizing this tensor, map its metadata to the simplified uop
+ elif isinstance(k.metadata, Metadata): ops_metadata[v] = k.metadata
+
+ # create kernels
+ if len(realize_map) == 0: return [], {}, becomes_map
+ kernel_map = graph_rewrite_map(sink, create_kernels, ctx=KernelContext(realize_map, ops_metadata), bottom_up=True)
+ sched_sink = kernel_map[sink]
+ type_verify(list(sched_sink.toposort), kernel_spec)
+
+ # map realized tensors to buffers
+ for k,v in kernel_map.items():
+ if k is v or v.op is not Ops.ASSIGN: continue
+ for t in rev_tensor_map[k]: becomes_map[t] = t.src[0] if t.op is Ops.ASSIGN else v.buf_uop.reshape(t.shape)
+
+ # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
+ kernel_assign: dict[UOp, UOp] = {}
+ assign_rep: dict[UOp, UOp] = {}
+ for u in sched_sink.toposort:
+ if u.op is not Ops.ASSIGN: continue
+ kernel_assign[u.buf_uop] = u
+ for s in u.src[1].src:
+ if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
+ if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort):
+ raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
+ assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
+ if assign_rep:
+ sched_sink = sched_sink.substitute(assign_rep)
+ type_verify(list(sched_sink.toposort), kernel_spec)
+
+ # display the final graph
+ if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
+
+ # final toposort (bfs)
+ children: dict[UOp, list[UOp]] = {}
+ in_degree: dict[UOp, int] = {}
+ for u in sched_sink.toposort:
+ if u.op is not Ops.ASSIGN: continue
+ in_degree[u] = 0
+ for s in u.src[1].src:
+ if s.op is not Ops.ASSIGN: continue
+ children.setdefault(s, []).append(u)
+ in_degree[u] += 1
+
+ queue = deque(k for k,v in in_degree.items() if v == 0)
+ schedule: list[ScheduleItem] = []
+ var_vals: dict[Variable, int] = {}
  while queue:
- schedule.append(si:=queue.popleft())
- for b in si.outputs: del lazybufs[b].srcs # can only schedule once
- if (m:=BUF_LIMIT.get(device:=si.outputs[0].device)) and len(si.bufs) >= m:
- if DEBUG >= 3: print(si)
- raise RuntimeError(f"Kernel for {si.metadata} exceeded the {m} buffer count limit for {device} with {len(si.bufs)} buffers.")
- for x in graph[si]:
+ u = queue.popleft()
+ schedule.append(schedule_uop(u, var_vals))
+ # increment the refcount of the target buf (this is required by the JIT and memory planner)
+ u.buf_uop.buffer.ref(1)
+ for x in children.get(u, []):
  in_degree[x] -= 1
  if in_degree[x] == 0: queue.append(x)
+
  # confirm everything was scheduled correctly
- if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
+ if len(schedule) != (kc:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, created {kc} kernels but only scheduled {len(schedule)}")
  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
- return schedule, ctx.var_vals
-
- def create_schedule(outs:List[LazyBuffer]) -> List[ScheduleItem]:
- schedule, var_vals = create_schedule_with_vars(outs)
- assert len(var_vals) == 0
- return schedule
+ # capture process replay
+ if CAPTURE_PROCESS_REPLAY:
+ with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, [x.ast for x in schedule]))
+ return schedule, var_vals, becomes_map