tinygrad-0.10.2-py3-none-any.whl → tinygrad-0.11.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
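Most of the churn in this list is an internal reorganization: tinygrad/ops.py and tinygrad/spec.py move under tinygrad/uop/, the scheduler is split out into tinygrad/schedule/, and kernel optimization moves under tinygrad/codegen/opt/. For code that imported these internals directly, the file moves above translate roughly to the import changes sketched below; the mapping is inferred only from the renamed paths in this list and from the imports visible in the new files further down, and the public tinygrad.Tensor API is unaffected.

# 0.10.2-era internal imports -> 0.11.0 locations, per the file moves above
from tinygrad.uop.ops import UOp, Ops, PatternMatcher   # was tinygrad/ops.py
from tinygrad.uop.spec import type_verify               # was tinygrad/spec.py
from tinygrad.uop.symbolic import symbolic_simple       # was tinygrad/codegen/symbolic.py
from tinygrad.schedule.multi import multi_pm            # was tinygrad/engine/multi.py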
tinygrad/schedule/grouper.py
@@ -0,0 +1,119 @@
+ from tinygrad.uop.ops import Ops, UOp, resolve, can_pad, GroupOp, UPat, PatternMatcher, graph_rewrite
+ from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, FUSE_CONV_BW
+ from tinygrad.shape.shapetracker import ShapeTracker
+
+ ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
+                                Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
+                                Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD}
+
+ # **** Grouper decides which of the UOps realize
+
+ def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
+
+ def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None:
+   for s in rb.src:
+     if s.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
+
+ def realize_before_view(ctx:dict[UOp, None], view:UOp, tr:UOp) -> None:
+   st = unwrap(view.st)
+   # always realize unsafe pad ops before masked view
+   if any(v.mask is not None for v in st.views) and not can_pad(tr, ctx): return realize(ctx, tr)
+   # fold simple pads
+   if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(tr.shape) and resolve(prod(tr.shape) >= prod([y-x for x,y in m])): return
+   # realize before expand
+   if resolve(prod(tr.shape) < prod(st.shape)) and not DONT_REALIZE_EXPAND: return realize(ctx, tr)
+
+ do_realize = PatternMatcher([
+   # always realize SINK parents
+   (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
+   # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
+   (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
+   # realize before expand or unsafe pad ops
+   (UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view),
+   # realize parents of COPY, MSELECT, MSTACK
+   (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_parents),
+ ])
+
+ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:dict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
+                     reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
+   if (tr, st) in cache: return
+   cache.setdefault((tr, st))
+   rsize = unwrap(r.st).size
+   if tr in realizes and tr is not r:
+     # can only fuse contiguous
+     # max one reduceop per kernel
+     if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
+     return group.setdefault(tr)
+   for tr_next in children.get(tr, {}):
+     # max one reduceop per kernel
+     if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r)
+     # can only fuse contiguous
+     if len(st_childs:=dedup(unwrap(x.st) for x in tr_next.src if x.base == tr)) > 1: return group.setdefault(r)
+     recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache)
+
+ def group_realizes(sink:UOp) -> dict[UOp, None]:
+   # start by adding uops that always realize
+   realizes: dict[UOp, None] = {}
+   sink = graph_rewrite(sink, do_realize, ctx=realizes, name="do_realize")
+   if DONT_GROUP_REDUCES: return realizes
+
+   # construct children graph (only for bases)
+   children: dict[UOp, dict[UOp, None]] = {}
+   assigns: dict[UOp, None] = {}
+   for u in (toposort:=sink.toposort()):
+     if u.op in {Ops.VIEW, Ops.SINK}: continue
+     if u.op is Ops.ASSIGN: assigns[u.buf_uop] = None
+     for s in u.src: children.setdefault(s.base, {})[u] = None
+
+   # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
+   reduce_for_op: dict[UOp, UOp] = {}
+   double_reduces: list[UOp] = []
+   for r in toposort:
+     if r.op is not Ops.REDUCE_AXIS: continue
+     if len(r.arg) == 3 and r.arg[2] is True: continue
+     if FUSE_CONV_BW and r.src[0].base.op is Ops.REDUCE_AXIS and r.src[0] is not r.src[0].base: double_reduces.append(r)
+     if r in realizes: continue
+     group: dict[UOp, None] = {}
+     recursive_group(r, unwrap(r.st), r, children, realizes, reduce_for_op, group, cache={})
+     # max one reduceop per kernel
+     can_chase = all(tr not in reduce_for_op for tr in group)
+     for u in r.toposort(gate=lambda u: u not in realizes):
+       if u.op is Ops.REDUCE_AXIS and u.src[0].base.op is Ops.CONST:
+         can_chase = False
+         break
+     # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
+     forced_realize = r in group
+     # can only have one output
+     if not forced_realize and len(group) > 1: forced_realize = True
+     # can only fuse assign if no other assign_target is used in the kernel
+     if not forced_realize and (assign_targets:={x.buf_uop for x in group if x.op is Ops.ASSIGN}):
+       parents = [r, *group]
+       while parents and not forced_realize:
+         p = parents.pop().base
+         if p.op is Ops.BUFFER and p in assigns and p not in assign_targets: forced_realize, can_chase = True, False
+         if p in realizes: continue
+         parents.extend(p.src)
+     if forced_realize or not group:
+       tr = r
+       if can_chase:
+         # can chase this down to contiguous children
+         st = unwrap(tr.st)
+         while len(lst:=children.get(tr, {})) == 1:
+           tr_next = next(iter(lst))
+           st_childs = dedup(unwrap(s.st) for s in tr_next.src if s.base is tr)
+           if len(st_childs) > 1: break
+           if st.size != st_childs[0].size: break
+           st = st + st_childs[0]
+           if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break
+           tr = tr_next
+         # don't cast to higher size before store (tr cannot be realized if forced_realize)
+         if tr.op is Ops.CAST and tr.dtype.itemsize > tr.src[0].dtype.itemsize:
+           tr = tr.src[0].base
+       group = {tr: None}
+       realizes[tr] = None
+     reduce_for_op.update((tr, r) for tr in group)
+   # fuse double reduces with no other child
+   for reduceop in double_reduces:
+     top_reduce = reduceop.src[0].base
+     if len(children.get(top_reduce, {})) == 1: del realizes[top_reduce]
+   return realizes
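group_realizes is the entry point of this new grouper: it first marks every UOp that must always hit memory, then tries to pair each REDUCE_AXIS with contiguous elementwise children so they land in a single kernel, forcing a realize whenever the pairing is ambiguous. One way to observe its effect from user code is to flip the DONT_GROUP_REDUCES flag it checks and compare scheduled kernel counts between runs; the sketch below is illustrative only, assumes tinygrad 0.11.0, and uses just the public Tensor API (the exact kernel split depends on the device and other flags).

# hedged sketch: count scheduled kernels with reduce grouping disabled.
# DONT_GROUP_REDUCES is the helpers ContextVar read above; setting the environment
# variable before importing tinygrad is an assumption about how it is initialized.
import os
os.environ["DONT_GROUP_REDUCES"] = "1"

from tinygrad import Tensor
x = Tensor.rand(256, 256)
y = ((x * 2 + 1).sum(axis=1)).relu()   # elementwise -> reduce -> elementwise chain
print(len(y.schedule()), "kernels with DONT_GROUP_REDUCES=1")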
tinygrad/schedule/kernelize.py
@@ -0,0 +1,368 @@
+ from dataclasses import dataclass
+ from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
+ from tinygrad.uop.ops import track_rewrites, _substitute
+ from tinygrad.uop.spec import type_verify, tensor_uop_spec
+ from tinygrad.uop.symbolic import symbolic_simple
+ from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
+ from tinygrad.dtype import ImageDType
+ from tinygrad.schedule.multi import multi_pm
+ from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
+ from tinygrad.codegen.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop
+
+ # creation can recurse a lot
+ import sys
+ sys.setrecursionlimit(10000)
+
+ # **** schedule simplifier
+
+ def simplify_stride0_reduce(reduce:UOp, x:UOp):
+   # must be unmasked (NOTE: can be relaxed if not masked on stride 0 axis)
+   if any(v.mask is not None for v in unwrap(x.st).views): return None
+   # must have all stride 0 in the relevant axis (NOTE: can do partial)
+   if not all(unwrap(x.st).views[-1].strides[axis] == 0 for axis in reduce.arg[1]) or not all_int(x.shape): return None
+   prshape = prod(x.shape[i] for i in reduce.arg[1])
+   ret = x.shrink(tuple((0,s) if i not in reduce.arg[1] else (0,1) for i,s in enumerate(x.shape)))
+   match reduce.arg[0]:
+     case Ops.ADD: return ret*prshape
+     case Ops.MUL: return ret.pow(prshape)
+     case Ops.MAX: return ret # NOTE: Ops.MAX is passthrough
+
+ def split_reduceop(reduce:UOp, x:UOp):
+   if not SPLIT_REDUCEOP or not all_int(x.shape) or (prod(x.shape)//prod(reduce.shape))<getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return None
+   # if there are few globals, make some reduces into globals by splitting into two kernels
+   # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
+   # ~2**10 should be enough if GROUP is used
+   # 256 split maximum should be "negligible reduce" for low prod(reduce.shape), 8 split minimum.
+   # split is moved to the end to provide maximum locality for the second phase reduce.
+   real_strides = unwrap(x.st).real_strides(ignore_valid=True)
+   if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(reduce.shape)),8-1,-1)
+                              if x.shape[i]%d==0 and real_strides[i]!=0]): return None
+   dim_to_split, divisor = split_candidates[0]
+   splitted_shape = x.shape[:dim_to_split]+(divisor,)+(x.shape[dim_to_split]//divisor,)+x.shape[dim_to_split+1:]
+   splitted = x.reshape(splitted_shape).permute(tuple([d for d in range(len(splitted_shape)) if d!=dim_to_split]+[dim_to_split]))
+   if DEBUG >= 3: print(f"split {divisor}: {x.shape} -> {splitted.shape} -> {reduce.shape}")
+   # reduce original axes, then split
+   return splitted.r(*reduce.arg).r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape)
+
+ def copy_reorder_view(copy:UOp, view:UOp, base:UOp):
+   if prod(view.shape) < prod(base.shape): return view.contiguous().copy_to_device(copy.device)
+   return base.copy_to_device(copy.device).view(view.arg)
+
+ kernelize_sym = symbolic_simple+PatternMatcher([
+   # UOp with size 0 is zero
+   (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None),
+   # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
+   (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
+   # reduce of size 0 is the identity element
+   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
+    lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
+   # reduce on stride 0 is collapsed
+   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_stride0_reduce),
+   # split_reduceop
+   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
+   # COPY(CONST) creates a new CONST on the destination device
+   (UPat(Ops.COPY, name="root", src=(UPat.cvar("x"), UPat(Ops.DEVICE))), lambda root,x: root.const_like(x.arg)),
+   # non device changing COPY is a NOOP
+   (UPat(Ops.COPY, name="c", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda c,x: x if c.device == x.device else None),
+   # store a shrink before COPY, otherwise view after the COPY
+   (UPat(Ops.COPY, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"), UPat(Ops.DEVICE)), name="copy"), copy_reorder_view),
+   # remove cast to image when it's already a contiguous image
+   (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),
+    lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
+   # CAST before masking constants
+   (UPat.cvar("x").view().cast(name="c"), lambda x,c: x.cast(c.dtype).view(c.src[0].arg)),
+   # make things that can't be images not images
+   (UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType)
+    and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),
+   # remove contiguous if we can just view the buffer
+   (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
+    lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
+   # contiguous/buffer/copy/assign is already contiguous
+   (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.ASSIGN)),)), lambda root: root.src[0]),
+   # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
+   (UPat((Ops.BITCAST, Ops.CONTIGUOUS), src=(UPat.var("x"),), name="t"), lambda x,t: UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,),
+    (t.size, x.st.views[0].offset)).reshape(t.shape) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
+   # double ASSIGN to same target is one ASSIGN
+   (UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat.var("x"))))), lambda x,t: t.assign(x.contiguous())),
+   # ASSIGN to unrealized replaces the UOp
+   (UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat.var("x"))), lambda x,t: x.contiguous() if t.base.op not in {Ops.BUFFER, Ops.BUFFER_VIEW} and
+    not (t.base.op is Ops.MSTACK and all(x.op is Ops.BUFFER for x in t.base.src)) else None),
+   # put CAST to smaller dtype before EXPAND
+   (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm"),)), lambda cast,vm: vm.base.cast(cast.dtype).view(vm.st)
+    if cast.dtype.itemsize <= vm.dtype.itemsize and resolve(prod(vm.shape) > vm.st.real_size()) else None),
+   # put UnaryOps before EXPANDs, if it can fuse with the input
+   (UPat(GroupOp.Unary, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="inp"),), name="v"),), name="alu"),
+    lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
+ ])
+
+ # support for using a contiguous permuted view instead of the parent view if one exists
+
+ def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
+   if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
+
+ replace_contiguous = PatternMatcher([
+   (UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, name="src"),), name="contig"), found_contiguous),
+   (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
+ ])
+
+ # **** create kernels
+
+ @dataclass(frozen=True)
+ class Kernel:
+   ast: UOp
+   metadata: tuple[Metadata, ...] = ()
+   def __repr__(self):
+     ast_rep = f"SINK{tuple(s.op for s in self.ast.src)}" if self.ast.op is Ops.SINK else repr(self.ast.op)
+     return f"<Kernel {len(list(self.ast.toposort()))} {ast_rep} {self.metadata}>"
+
+ def create_kernel(x:UOp, b:UOp|None=None):
+   if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype)
+   kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), m if (m:=x.metadata) else ()))
+   buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
+   return buffer.assign(kernel).reshape(x.shape)
+
+ DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.MULTI, Ops.BIND}
+ def append_to_kernel(x:UOp):
+   new_srcs: list[UOp] = []
+   metadata = x.arg.metadata
+   for s in x.src:
+     if s.op in DONT_PLACE_IN_KERNEL: new_srcs.append(s)
+     else:
+       new_srcs.extend(s.src)
+       # NOTE: because const and device are shared UOps they don't change metadata
+       # NOTE: if it's a reshape after ASSIGN we're not fusing that parent kernel
+       if s.base.op not in {Ops.CONST, Ops.DEVICE} and (not (s.op is Ops.RESHAPE and s.base.op is Ops.ASSIGN)) and (m:=s.metadata): metadata += m
+   if (new_src:=tuple(dedup(new_srcs))) != x.src: return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(dedup(metadata))))
+
+ create_kernels = PatternMatcher([
+   # always give assign/contiguous a kernel
+   (UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
+   (UPat(Ops.CONTIGUOUS, name="x"), create_kernel),
+   # walk back the local graph until we reach a realized source
+   (UPat(Ops.KERNEL, name="x"), append_to_kernel),
+   # push RESHAPE through MSELECT
+   (UPat(Ops.MSELECT, src=(UPat(Ops.RESHAPE, name="r"),), name="ms"), lambda ms,r: r.src[0].mselect(ms.arg).reshape(r.arg)),
+   # push RESHAPE through MSTACK
+   (UPat(Ops.MSTACK, src=UPat(Ops.RESHAPE), name="ms"),
+    lambda ms: UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).reshape(ms.src[0].arg)),
+ ])
+
+ # **** fix kernel AST
+
+ def unbind_view(x:UOp):
+   if any(x.op is Ops.BIND for x in x.arg.vars()): return x.replace(arg=x.arg.unbind()[0])
+   return None
+
+ replace_buffers = PatternMatcher([
+   # replace ASSIGN with the target BUFFER
+   (UPat(Ops.ASSIGN, src=(UPat((Ops.BUFFER, Ops.LOAD)), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]),
+   # HACK: select the 0 branch of MSTACK (the device is wrong after this, is that okay?)
+   (UPat(Ops.MSTACK, name="x"), lambda x: x.src[0]),
+   # LOAD
+   (UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).load()),
+   # no SINK for meta ops
+   (UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
+   # STORE (except for meta ops)
+   (UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda ctx,sink:
+    UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)],
+             arg=sink.arg)),
+   # remove CONTIGUOUS/DEVICE from kernel AST
+   (UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
+   (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
+   # passthrough ASSIGN (but let MSTACK process first)
+   (UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.MSTACK}), UPat()), name="x"), lambda x: x.src[1]),
+   # remove any BINDs from VIEWS
+   (UPat(Ops.VIEW, src=(UPat(), UPat((Ops.BIND, Ops.DEFINE_VAR))), allow_any_len=True, name="x"), lambda x: x.replace(src=x.src[0:1])),
+   # remove any BINDs from DEFINE_VARs
+   (UPat(Ops.BIND, name="x"), lambda x: x.src[0]),
+   # remove BINDs from ShapeTrackers
+   (UPat(Ops.VIEW, name="x"), unbind_view),
+ ])
+
+ def fix_kernel_ast(k:UOp) -> UOp|None:
+   if k.arg.ast.op in GroupOp.Meta or all(s.op is Ops.STORE for s in k.arg.ast.src): return None
+   # replace buffer with define_global + add load/store last
+   bufs = []
+   for s in k.src:
+     if s.op is Ops.BIND: continue
+     s = s.buf_uop
+     # traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
+     while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
+     bufs.append(s)
+   # replace global memory ops with the BUFFER they write to
+   # NOTE: merge_views is needed to unbind the reshapes
+   ast = graph_rewrite(k.arg.ast, merge_views+replace_buffers, bufs, bottom_up=True, name="replace buffers")
+   if ast.op is Ops.SINK and not all_same([x.device for x in k.src if x.op is not Ops.BIND]):
+     raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
+   return k.replace(arg=Kernel(ast, k.arg.metadata))
+
+ create_ast = PatternMatcher([
+   (UPat(Ops.KERNEL, name="k"), fix_kernel_ast),
+   (UPat(Ops.DEFINE_VAR, src=(UPat(),), allow_any_len=True, name="x"), lambda x: x.replace(src=())),
+ ])
+
+ # ** add metadata of KERNEL outputs
+
+ def append_metadata(root:UOp, k:UOp):
+   if not root.metadata or (new_metadata:=tuple(dedup(k.arg.metadata+root.metadata))) == k.arg.metadata: return None
+   return root.replace(src=(root.src[0], k.replace(arg=Kernel(k.arg.ast, new_metadata)))+root.src[2:])
+
+ replace_metadata = PatternMatcher([(UPat(Ops.ASSIGN, src=(UPat(), UPat(Ops.KERNEL, name="k")), name="root", allow_any_len=True), append_metadata),])
+
+ pm_fuse = PatternMatcher([
+   # FUSE on CONTIGUOUS removes FUSE
+   (UPat(Ops.CONTIGUOUS, name="c").fuse(), lambda c: c),
+
+   # FUSE triggers swizzle on reduceop
+   (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").or_casted(),), name="view").fuse(),
+    lambda r,src,view: ret.cast(view.dtype) if (ret:=swizzle_reduceop(r, src, view, fuse=True)) is not None else None),
+
+   # FUSE on reduce (without view) adds fuse marker to grouper
+   (UPat(Ops.REDUCE_AXIS, name="r").fuse(),
+    lambda r: r.replace(src=(r.src[0].fuse(),), arg=r.arg+(True,)) if len(r.arg) == 2 else None),
+
+   # remove FUSE and insert CONTIGUOUS if it's an unsafe pad
+   (UPat(Ops.VIEW, src=(UPat(GroupOp.UnsafePad, name="alu"),), name="view").fuse(),
+    lambda alu, view: alu.contiguous().view(view.st) if any(v.mask is not None for v in view.st.views) else None),
+
+   # FUSE elementwise.
+   (UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST}, name="alu"),), name="view").fuse(),
+    lambda alu, view: alu.replace(src=tuple(apply_swizzle(x.view(view.arg)).fuse() for x in alu.src))),
+
+   # push FUSE through to srcs
+   (UPat(Ops.FUSE, name="x"), lambda x: x.src[0].replace(src=tuple(y.fuse() for y in x.src[0].src))),
+ ])
+
+ def do_fusion(x:UOp):
+   found_contiguous = {}
+   def gate_contiguous(x):
+     if is_contiguous:=(x.op is Ops.CONTIGUOUS): found_contiguous[x] = x.replace(src=(UOp(Ops.VIEW, arg=x.st), UOp.unique()))
+     return not is_contiguous
+   x.toposort(gate=gate_contiguous)
+   del gate_contiguous
+   return graph_rewrite(x.substitute(found_contiguous), pm_fuse, name="local fusion").substitute({v:k for k,v in found_contiguous.items()})
+
+ def fuse_arange(root:UOp):
+   # skip if root is arange
+   if not FUSE_ARANGE or root.src[0].base.op is Ops.CONST: return None
+   # gather all local aranges (including any fused ones)
+   local_arange: list[UOp] = []
+   def gate_reduce(u):
+     if u.op is Ops.REDUCE_AXIS and u.src[0].base.op is Ops.CONST: local_arange.append(u)
+     return u.op not in {*ALWAYS_CONTIGUOUS, Ops.REDUCE_AXIS} or u is root
+   toposort = root.toposort(gate=gate_reduce)
+   if not local_arange: return None
+   # fuse the nearest expand child of arange
+   local_children: dict[UOp, list[UOp]] = {}
+   for u in toposort:
+     for s in u.src: local_children.setdefault(s, []).append(u)
+   fuse_rep: dict[UOp, UOp] = {}
+   for r in local_arange:
+     # skip if already fused
+     if len(r.arg) > 2: continue
+     q = list(local_children[r])
+     while q:
+       u = q.pop()
+       if not (curr_children:=local_children.get(u, [])): continue
+       for child in curr_children:
+         other_paths = {s for s in child.toposort() if s.op in {Ops.REDUCE_AXIS, Ops.BUFFER} and s not in {root, r}}
+         fuse_rep[child] = child.replace(src=tuple(s.fuse() if s is u else s for s in child.src))
+         if other_paths: break
+       else: q.extend(curr_children)
+   return root.substitute(fuse_rep, name="fuse_arange") if fuse_rep else None
+
+ do_fuse = PatternMatcher([
+   (UPat(Ops.FUSE, name="x"), do_fusion),
+   (UPat(Ops.REDUCE_AXIS, name="root"), fuse_arange),
+ ])
+
+ add_contiguous = PatternMatcher([(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"),
+   lambda ctx,x: x.replace(tag=1).contiguous() if x in ctx and x.tag is None else None)])
+
+ # TODO: get this from the device through GrouperOpts
+ DEVICE_MAX_BUFS = {"METAL":32, "WEBGPU":8}
+
+ def limit_bufs(root:UOp):
+   # check if backend has a buffer limit
+   device = root.device if isinstance(root.device, str) else root.device[0].split(":")[0]
+   if not (MAX_BUFS:=getenv("MAX_KERNEL_BUFFERS", DEVICE_MAX_BUFS.get(device, 0))): return None
+   # count number of unique buffers flowing into this op
+   bufs: set[UOp] = set()
+   def gate_input(u:UOp):
+     if (is_load:=(u.op in {Ops.BUFFER, Ops.CONTIGUOUS, Ops.ASSIGN, Ops.MSTACK})): bufs.add(u)
+     return not is_load
+   root.toposort(gate=gate_input)
+   # NOTE: this -1 is for the output buffer
+   if len(bufs)>=MAX_BUFS-1:
+     return root.replace(src=tuple(s if s.base in bufs else s.replace(tag=1).contiguous() for s in root.src))
+
+ def view_add_srcs(x:UOp):
+   if len(avars:=x.arg.vars()) and len(x.src) == 1:
+     return x.replace(src=x.src+tuple(avars))
+   return None
+
+ finalize_contiguous = PatternMatcher([
+   # if an op takes more than one input, check combined LOADs don't exceed device limits
+   (UPat(set.union(GroupOp.Binary, GroupOp.Ternary), name="root"), limit_bufs),
+   # merge contiguous
+   (UPat(Ops.CONTIGUOUS, src=(UPat(Ops.CONTIGUOUS),), name="x"), lambda x: x.src[0]),
+   # simplify views
+   (UPat(Ops.VIEW, src=(UPat.var('x')), name="v"), lambda x,v: x.view(new_st) if (new_st:=v.arg.simplify()) != v.arg else None),
+   # vars to views srcs
+   (UPat(Ops.VIEW, name="x"), view_add_srcs),
+ ])
+
+ remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
+
+ @track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}", replay=True)
+ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
+   """
+   Function to transform the Tensor UOp graph into a version with Ops.KERNEL
+
+   Args:
+     sink: The Ops.SINK rooting the Tensor graph.
+
+   Returns:
+     Map transforming each UOp in the sink to the Ops.KERNEL graph.
+   """
+
+   # multi + merge_views + simplify
+   tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+kernelize_sym+replace_contiguous, ctx={}, name="merge_views")
+
+   # display the cleaned up tensor graph
+   if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")
+
+   # insert contiguous in places determined by the realize map
+   realize_map = group_realizes(tensor_map[sink])
+   tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add_contiguous")
+   tensor_map = graph_rewrite_map(tensor_map[sink], finalize_contiguous+remove_tags, input_map=tensor_map, name="finalize_contiguous")
+
+   # group into kernels (this is context-free)
+   tensor_map = graph_rewrite_map(tensor_map[sink], create_kernels, input_map=tensor_map, name="create_kernels")
+
+   # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
+   kernel_assign: dict[UOp, UOp] = {}
+   assign_rep: dict[UOp, UOp] = {}
+   for u in tensor_map[sink].toposort():
+     if u.op is not Ops.ASSIGN: continue
+     kernel_assign[u.buf_uop] = u
+     for s in u.src[1].src:
+       # TODO: this is probably broken for MSELECT/MSTACK
+       if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
+       if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort()):
+         raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
+       assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
+   if assign_rep:
+     tensor_map = graph_rewrite_map(tensor_map[sink], _substitute, ctx=assign_rep, bottom_up=True, input_map=tensor_map, name="fix_assign")
+
+   # finally, create the AST for kernels
+   tensor_map = graph_rewrite_map(tensor_map[sink], create_ast+replace_metadata, bottom_up=True, input_map=tensor_map, name="create_ast")
+
+   # display the final graph
+   sched_sink = tensor_map[sink]
+   if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
+
+   # verify Kernels match the spec
+   if __debug__: type_verify(list(sched_sink.toposort()), tensor_uop_spec)
+
+   return tensor_map
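get_kernelize_map is not called directly by user code: it runs inside the scheduler whenever a Tensor is scheduled or realized, applying the rewrite passes above in order (merge_views, add_contiguous, create_kernels, create_ast) and verifying the result against tensor_uop_spec. The sketch below is a minimal way to trigger that pipeline through the public API and inspect the resulting kernels; it assumes tinygrad 0.11.0 and that ScheduleItem exposes ast and bufs fields as in earlier releases.

# hedged sketch: schedule a small expression and look at the kernels the pipeline produced.
# Running with VIZ=1 should additionally open the graph viewer on the
# "View Tensor Graph" / "View Kernel Graph" rewrites shown above.
from tinygrad import Tensor

a = Tensor.arange(64).reshape(8, 8).float()
b = (a @ a.T).softmax(axis=1)
for si in b.schedule():   # schedule() drives the scheduler, which calls get_kernelize_map
  print(si.ast.op, [buf.size for buf in si.bufs])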