tinygrad-0.10.2-py3-none-any.whl → tinygrad-0.11.0-py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
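Alongside the large regenerated runtime/autogen bindings, the file list reflects a package reorganization: tinygrad/ops.py and tinygrad/spec.py are removed with counterparts under the new tinygrad/uop/ package, scheduling is split out into a tinygrad/schedule/ package, and the kernel optimizer moves under tinygrad/codegen/opt/. For downstream code that imported these modules directly, the visible effect is an import-path change along the lines of this sketch (based only on the import change visible in the engine/schedule.py hunk below; other symbols may have moved differently):

# tinygrad 0.10.2
from tinygrad.ops import UOp, Variable, Ops

# tinygrad 0.11.0 -- tinygrad/ops.py is gone; the same names live under tinygrad.uop.ops
from tinygrad.uop.ops import UOp, Variable, Ops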
tinygrad/engine/schedule.py
@@ -1,458 +1,83 @@
- import sys, functools, atexit, pickle
- from collections import defaultdict, deque
- from dataclasses import dataclass
- from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
- from tinygrad.ops import can_pad, identity_element, resolve, view_left, merge_views
- from tinygrad.codegen.symbolic import symbolic_simple
- from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten, getenv
- from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND
- from tinygrad.dtype import ImageDType
- from tinygrad.shape.shapetracker import ShapeTracker
- from tinygrad.shape.view import View, strides_for_shape
- from tinygrad.device import Buffer
- from tinygrad.spec import type_verify, kernel_spec
+ from typing import cast
+ from dataclasses import dataclass, field
+ from collections import deque, defaultdict
+ from tinygrad.uop.ops import UOp, Variable, Ops, buffers
+ from tinygrad.device import Device, Buffer, MultiBuffer
+ from tinygrad.helpers import Metadata, all_same

- # creation can recurse a lot
- sys.setrecursionlimit(10000)
-
- # **** schedule simplifier
-
- def simplify_stride0_reduce(reduce:UOp, x:UOp):
-   # must be unmasked (NOTE: can be relaxed if not masked on stride 0 axis)
-   if any(v.mask is not None for v in unwrap(x.st).views): return None
-   # must have all stride 0 in the relevant axis (NOTE: can do partial)
-   if not all(unwrap(x.st).views[-1].strides[axis] == 0 for axis in reduce.arg[1]) or not all_int(x.shape): return None
-   prshape = prod(x.shape[i] for i in reduce.arg[1])
-   ret = x.shrink(tuple((0,s) if i not in reduce.arg[1] else (0,1) for i,s in enumerate(x.shape)))
-   match reduce.arg[0]:
-     case Ops.ADD: return ret*prshape
-     case Ops.MUL: return ret.pow(prshape)
-     case Ops.MAX: return ret # NOTE: Ops.MAX is passthrough
-
- def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
-   if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
- def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
-   new_src = list(alu.src)
-   for i,s in enumerate(alu.src):
-     if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
-   if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
-
- sym = symbolic_simple+PatternMatcher([
-   # UOp with size 0 is zero
-   (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
-     and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
-   # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
-   (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
-   # reduce of size 0 is the identity element
-   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
-    lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
-   # reduce on stride 0 is collapsed
-   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_stride0_reduce),
-   # COPY(CONST) creates a new CONST on the destination device
-   (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.arg)),
-   # no COPY to same device, except clone (arg is True)
-   (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
-    lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
-   # remove cast to image when it's already a contiguous image
-   (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"))),)),
-    lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
-   # make things that can't be images not images
-   (UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType)
-    and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),
-   # remove contiguous if we can just view the buffer
-   (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
-    lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
-   # contiguous/buffer/copy is already contiguous
-   (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY)),)), lambda root: root.src[0]),
-   # support for using a contiguous permuted view instead of the parent view if one exists
-   (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
-   (UPat(GroupOp.ALU, name="alu"), replace_contiguous),
-   # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
-   (UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="root"),
-    lambda root: root.replace(op=Ops.BUFFER_VIEW) if isinstance(root.device, str) and root.device.startswith("DISK") else None),
- ])
-
- remove_movement_ops = merge_views+PatternMatcher([
-   # NOTE: movement ops are always applied to base
-   (UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
-   # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
-   (UPat(Ops.VIEW, name="view"),
-    lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
- ])
-
- # **** UOp realization
-
- @dataclass(frozen=True)
- class GrouperContext:
-   assigns: dict[UOp, UOp] # maps realized buffers to assigns
-   realizes: dict[UOp, None] # all the simplified tensor uops we realize
-   children: defaultdict[UOp, dict[UOp, None]] # children graph of tensor uops
-
- def realize(ctx:GrouperContext, tr:UOp) -> None: ctx.realizes[tr] = None
-
- def realize_before_view(ctx:GrouperContext, view:UOp, src:UOp) -> None:
-   st = unwrap(view.st)
-   # fold simple pads
-   if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
-     return None if can_pad(src, ctx.realizes, cache=dict()) else realize(ctx, src)
-   # early realize before expand
-   if resolve(prod(src.shape) < prod(st.shape)) and not DONT_REALIZE_EXPAND: return realize(ctx, src)
-   # otherwise safety check pads
-   return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, cache=dict())) else realize(ctx, src)
-
- do_realize = PatternMatcher([
-   # always realize SINK parents
-   (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x.base, None) for x in s.src if x.base.op not in {Ops.CONST, Ops.BIND, Ops.BUFFER})),
-   # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
-   (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
-   # realize before expand or unsafe pad ops
-   (UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}, name="src"),)), realize_before_view),
-   # realize before COPY
-   (UPat(Ops.COPY, src=(UPat(), UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW}, name="tr"))), realize),
- ])
-
- def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
-                     reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
-   """recursively search the uop for groupable children, realize the UOp if a child can't group"""
-   if (tr, st) in cache: return
-   cache.setdefault((tr, st))
-   rsize = unwrap(r.st).size
-   if tr in realizes and tr is not r:
-     # can only fuse contiguous
-     # max one reduceop per kernel
-     if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
-     return group.setdefault(tr)
-   for tr_next in children[tr]:
-     # max one reduceop per kernel
-     if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r)
-     # can only fuse contiguous
-     if len(st_childs:=dedup(unwrap(x.st) for x in tr_next.src if x.base == tr)) > 1: return group.setdefault(r)
-     recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache)
-
- def append_uop(ctx:GrouperContext, u:UOp) -> None:
-   if u.op is Ops.ASSIGN: ctx.assigns[u.buf_uop] = u
-   for s in u.src: ctx.children[s.base][u] = None
- create_ctx = PatternMatcher([(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}, name="u"), append_uop)])
-
- def group_realizes(sink:UOp) -> dict[UOp, None]:
-   # start by adding uops that always realize
-   sink = graph_rewrite(sink, do_realize+create_ctx, ctx:=GrouperContext({}, {}, defaultdict(dict)))
-   # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
-   reduce_for_op: dict[UOp, UOp] = {}
-   double_reduces: list[UOp] = []
-   for r in sink.toposort:
-     if r.op is not Ops.REDUCE_AXIS: continue
-     if FUSE_CONV_BW and r.src[0].base.op is Ops.REDUCE_AXIS and r.src[0] is not r.src[0].base: double_reduces.append(r)
-     if r in ctx.realizes: continue
-     group: dict[UOp, None] = {}
-     recursive_group(r, unwrap(r.st), r, ctx.children, ctx.realizes, reduce_for_op, group, cache={})
-     # max one reduceop per kernel
-     can_chase = all(tr not in reduce_for_op for tr in group)
-     # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
-     forced_realize = r in group
-     # can only have one output
-     if not forced_realize and len(group) > 1: forced_realize = True
-     # can only fuse assign if no other assign_target is used in the kernel
-     if not forced_realize and any(x.op is Ops.ASSIGN for x in group):
-       parents = deque((r, *group))
-       while parents and not forced_realize:
-         p = parents.pop().base
-         if (assign:=ctx.assigns.get(p)) is not None and assign not in group: forced_realize, can_chase = True, False
-         if p in ctx.realizes: continue
-         parents.extend(p.src)
-     if forced_realize or not group:
-       tr = r
-       if can_chase:
-         # can chase this down to contiguous children
-         st = unwrap(tr.st)
-         while len(ctx.children[tr]) == 1:
-           tr_next = next(iter(ctx.children[tr]))
-           st_childs = dedup(unwrap(s.st) for s in tr_next.src if s.base is tr)
-           if len(st_childs) > 1: break
-           if st.size != st_childs[0].size: break
-           st = st + st_childs[0]
-           if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break
-           tr = tr_next
-         # don't cast to higher size before store (tr cannot be realized if forced_realize)
-         if tr.op is Ops.CAST and tr.dtype.itemsize > tr.src[0].dtype.itemsize:
-           tr = tr.src[0].base
-       group = {tr: None}
-       ctx.realizes[tr] = None
-     reduce_for_op.update((tr, r) for tr in group)
-     if FUSE_ARANGE and r.arg[0] is Ops.ADD and r.src[0].base.op is Ops.CONST:
-       # maybe fuse arange with its children
-       if len(flatten(ctx.children[tr] for tr in group)) != 0:
-         for tr in group: del ctx.realizes[tr]
-   # fuse double reduces with no other child
-   for reduceop in double_reduces:
-     top_reduce = reduceop.src[0].base
-     if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
-   return ctx.realizes
-
- # break the SINK into kernels
-
- @dataclass(frozen=True)
- class Kernel:
-   ast: UOp
-   metadata: tuple[Metadata, ...]
-   def __repr__(self): return f"<Kernel {len(list(self.ast.toposort))} {self.ast.op} {self.metadata}>"
-
- @dataclass(frozen=True)
- class KernelContext:
-   realizes: dict[UOp, None]
-   ops_metadata: dict[UOp, Metadata]
-
- def create_kernel(ctx:KernelContext, x:UOp):
-   if x not in ctx.realizes: return None
-   assert isinstance(x.device, str), f"buf device in kernel must be string {x.device}"
-   b = x.buf_uop if x.op is Ops.ASSIGN else UOp.new_buffer(x.device, x.size, x.dtype)
-   output_st = ShapeTracker.from_shape(x.shape)
-   # KERNEL nodes become: ASSIGN(VIEW(BUFFER), KERNEL)
-   # TODO: this should be ASSIGN(BUFFER, KERNEL) followed by the output ShapeTracker
-   return b.view(output_st).assign(UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x, (m,) if (m:=ctx.ops_metadata.get(x)) else ())))
-
- def append_to_kernel(ctx:KernelContext, x:UOp):
-   new_srcs: list[UOp] = []
-   new_metadata: dict[Metadata, None] = dict.fromkeys(x.arg.metadata)
-   for s in x.src:
-     if s.op is Ops.BUFFER or (s.op is Ops.ASSIGN and s.src[1].op is Ops.KERNEL) or s in ctx.realizes: new_srcs.append(s)
-     else:
-       new_srcs.extend(s.src)
-       if (m:=ctx.ops_metadata.get(s)) is not None: new_metadata[m] = None
-   return x.replace(src=n, arg=Kernel(x.arg.ast, tuple(new_metadata))) if (n:=tuple(dedup(new_srcs))) != x.src else None
-
- create_kernels = merge_views+PatternMatcher([
-   (UPat(GroupOp.All-{Ops.KERNEL, Ops.BUFFER}, name="x"), create_kernel),
-   (UPat(Ops.KERNEL, name="x"), append_to_kernel),
- ])
-
- # **** convert Kernel to a ScheduleItem (for legacy reasons)
+ # **** ScheduleItem return type

  @dataclass(frozen=True)
  class ScheduleItem:
    ast: UOp
    bufs: tuple[Buffer, ...]
-   metadata: tuple[Metadata, ...]
-   @property
-   def outputs(self) -> tuple[Buffer, ...]:
-     """Read/write or write only buffers in the schedule."""
-     return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs)
-   @property
-   def inputs(self) -> tuple[Buffer, ...]:
-     """Read only buffers in the schedule."""
-     return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs)
-   @functools.cached_property
-   def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
-
- # **** Kernel creation
+   metadata: tuple[Metadata, ...] = ()
+   fixedvars: dict[Variable, int] = field(default_factory=dict)

- def apply_swizzle(u:UOp) -> UOp:
-   with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
+ # **** schedule linearizer

- def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
-   input_st = ShapeTracker.from_shape(unwrap(src.st).shape)
-   tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
-   prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
-   strides = strides_for_shape(rshape)
-   nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
-                     v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
-   # update input_st and axis
-   new_input_st = tmp + ShapeTracker(tuple(nv))
-   new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg)))
-   return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
-
- def reduceop_view_right(r:UOp, v:UOp, src:UOp) -> UOp:
-   if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}")
-   output_shape = swizzle_st.reduce(r.axis_arg)
-   return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape))
-
- def elementwise_view_right(root:UOp) -> UOp|None:
-   if len(swizzles:=[x for x in root.src if x.base is not x]) == 0: return None
-   assert all(x.base.st is not None for x in swizzles), f"found shapeless VIEW src in {root}"
-   assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
-   # push the swizzle from src to root
-   output_swizzle = swizzles[0]
-   new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)
-   ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
-   return ret.view(ShapeTracker.from_shape(output_swizzle.shape))
-
- def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
-   assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
-   assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
-   return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
-
- # push VIEW to children
- view_right = merge_views+PatternMatcher([
-   # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
-   (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
-    lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
-   # STORE is the last child, so we just merge the ShapeTrackers and store the base
-   (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
-   # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
-   (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
-   # REDUCE(src.view()) -> REDUCE(src).view()
-   (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src").view(name="v"),), name="r"), reduceop_view_right),
-   # ALU(src.view()) -> ALU(src).view()
-   (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), elementwise_view_right),
-   # double reduce op collapses to a single reduce op
-   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
- ])
-
- def _append_st_vars(ctx:dict[Variable, int], x:UOp) -> UOp|None:
-   st = unwrap(x.st).simplify()
-   if any(x.op is Ops.BIND for x in st.vars()):
-     st, var_vals = st.unbind()
-     ctx.update(var_vals)
-   return st.to_uop() if st != x.st else None
-
- def check_load_st(glbl:UOp, view:UOp):
-   if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
-   # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
-   if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
-   # if it has a single view and it's equal when you shrink a contig, it's fine
-   if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
-   # otherwise, it's not fine
-   raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
-                      +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
-
- fix_kernel_ops = PatternMatcher([
-   # BIND in shapetracker becomes DEFINE_VAR
-   (UPat(Ops.VIEW, name="x"), _append_st_vars),
-   # remove CONTIGUOUS/ASSIGN/DEVICE
-   (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
-   (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
-   (UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
-   # no ImageDType after load
-   (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
-   # if this kernel also assigns to the loaded buffer, ensure we can index it correctly
-   (UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
- ])
-
- def load_buf(ctx:list[UOp], x:UOp):
-   if x not in ctx: ctx.append(x)
-   return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), unwrap(x.st).to_uop()))
-
- add_buffer_ops = PatternMatcher([
-   # LOAD
-   (UPat(Ops.BUFFER, name="x"), load_buf),
-   # STORE (except for COPY/BUFFER_VIEW)
-   (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
-   (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
-    lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
- ])
-
- def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
-   ctx[var.replace(src=())] = val.arg
-   return var
- unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
-
- def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
-   assert sink.op is Ops.ASSIGN and sink.src[1].op is Ops.KERNEL, f"{sink} must be ASSIGN"
-   # substitute kernel sources for the target buffer
-   ast = sink.src[1].arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in sink.src[1].src if s.op is Ops.ASSIGN}).sink()
-   # add buffer ops
-   ast = graph_rewrite(ast, add_buffer_ops, bufs:=[sink.buf_uop], bottom_up=True)
-   # unbind_vars + push views to edges
-   ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
-   # fix_kernel_ops
-   ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
-   # create subbuffer
-   if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = bufs[1].buffer.view(ast.size, ast.dtype, (x:=ast.src[0]).st_arg.views[0].offset*x.dtype.itemsize)
-   return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), sink.src[1].arg.metadata)
-
- PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
- if CAPTURE_PROCESS_REPLAY:
-   @atexit.register
-   def save_process_replay():
-     for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
-
- # **** schedule creation and toposort
-
- @track_rewrites(named=True)
- def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
-   # remove_movement_ops + sym
-   tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
-
-   # display the cleaned up tensor graph
-   if getenv("VIZ"): graph_rewrite(tensor_map[big_sink], PatternMatcher([]), name="View Tensor Graph")
-
-   # do_realize + group_realizes
-   sink = tensor_map[big_sink]
-   realize_map = group_realizes(sink)
-
-   # map tensors to new uops
-   becomes_map: dict[UOp, UOp] = {}
-   rev_tensor_map: dict[UOp, list[UOp]] = {}
-   ops_metadata: dict[UOp, Metadata] = {}
-   for k,v in tensor_map.items():
-     rev_tensor_map.setdefault(v, []).append(k)
-     if k is v: continue
-     if v.base.op is Ops.BUFFER:
-       # VIEW isn't a valid tensor uop, we need to backtrack to the movement op that created it
-       if v.op is Ops.VIEW:
-         mop = [x for x in k.toposort if (xs:=tensor_map[x]).base is v.base and xs.st == v.st][0]
-         if k is not mop: becomes_map[k] = mop
-       else: becomes_map[k] = v
-     elif v.base.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
-     # if we're not realizing this tensor, map its metadata to the simplified uop
-     elif isinstance(k.metadata, Metadata): ops_metadata[v] = k.metadata
-
-   # create kernels
-   if len(realize_map) == 0: return [], {}, becomes_map
-   kernel_map = graph_rewrite_map(sink, create_kernels, ctx=KernelContext(realize_map, ops_metadata), bottom_up=True)
-   sched_sink = kernel_map[sink]
-   type_verify(list(sched_sink.toposort), kernel_spec)
-
-   # map realized tensors to buffers
-   for k,v in kernel_map.items():
-     if k is v or v.op is not Ops.ASSIGN: continue
-     for t in rev_tensor_map[k]: becomes_map[t] = t.src[0] if t.op is Ops.ASSIGN else v.buf_uop.reshape(t.shape)
-
-   # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
-   kernel_assign: dict[UOp, UOp] = {}
-   assign_rep: dict[UOp, UOp] = {}
-   for u in sched_sink.toposort:
-     if u.op is not Ops.ASSIGN: continue
-     kernel_assign[u.buf_uop] = u
-     for s in u.src[1].src:
-       if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
-       if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort):
-         raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
-       assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
-   if assign_rep:
-     sched_sink = sched_sink.substitute(assign_rep)
-     type_verify(list(sched_sink.toposort), kernel_spec)
-
-   # display the final graph
-   if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
-
-   # final toposort (bfs)
-   children: dict[UOp, list[UOp]] = {}
+ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int]]:
+   # construct the KERNEL children graph based on assigns
+   children: defaultdict[UOp, list[UOp]] = defaultdict(list)
    in_degree: dict[UOp, int] = {}
-   for u in sched_sink.toposort:
-     if u.op is not Ops.ASSIGN: continue
-     in_degree[u] = 0
-     for s in u.src[1].src:
-       if s.op is not Ops.ASSIGN: continue
-       children.setdefault(s, []).append(u)
-       in_degree[u] += 1
+   var_vals: dict[Variable, int] = {}
+   for u in sched_sink.toposort():
+     if u.op is not Ops.ASSIGN: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip
+     k = u.src[1]
+     in_degree.setdefault(k, 0)
+     for s in k.src:
+       if s.op is Ops.ASSIGN:
+         children[s.src[1]].append(k)
+         in_degree[k] += 1
+       elif s.op in {Ops.MSELECT, Ops.MSTACK}:
+         for ss in s.src:
+           if ss.op is Ops.MSELECT: ss = ss.src[0]
+           if ss.op is not Ops.BUFFER:
+             assert ss.op is Ops.ASSIGN
+             children[ss.src[1]].append(k)
+             in_degree[k] += 1
+       elif s.op is Ops.BUFFER:
+         pass # a BUFFER is already realized, nothing to do here
+       elif s.op is Ops.BIND:
+         var, val = s.unbind()
+         assert var not in var_vals or var_vals[var] == val, f"bind mismatch on {var}, {var_vals[var]} != {val}"
+         var_vals[var] = val
+       else:
+         raise RuntimeError(f"input to kernel must be ASSIGN or BUFFER, not {s.op}")
+
+   # linearize KERNEL UOps into ScheduleItems in BFS order
+
+   def _heuristic(k: UOp):
+     if k.arg.ast.op is Ops.COPY and not all_same([Device[cast(Buffer, s.buf_uop.buffer).device].group_id for s in k.src]): return 1000
+     return 0
+
+   last_heuristic: int = 0
+   queues: defaultdict[int, deque[UOp]] = defaultdict(deque)
+   last_queue: deque[UOp] = deque()
+   for k,v in in_degree.items():
+     if v == 0: queues[_heuristic(k)].append(k)

-   queue = deque(k for k,v in in_degree.items() if v == 0)
    schedule: list[ScheduleItem] = []
-   var_vals: dict[Variable, int] = {}
-   while queue:
-     u = queue.popleft()
-     schedule.append(schedule_uop(u, var_vals))
-     # increment the refcount of the target buf (this is required by the JIT and memory planner)
-     u.buf_uop.buffer.ref(1)
-     for x in children.get(u, []):
+   while last_queue or any(queues.values()):
+     if not last_queue: last_heuristic, last_queue = min((it for it in queues.items() if it[1]), key=lambda x: abs(x[0]-last_heuristic))
+     k = last_queue.popleft()
+     ast = k.arg.ast
+     # create subbuffers if needed
+     if ast.op is Ops.BUFFER_VIEW:
+       base = k.src[1].buf_uop.buffer
+       assert isinstance(base, Buffer), "base can't be MultiBuffer"
+       buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
+     ubufs = tuple(s.buf_uop.buffer for s in k.src if s.op is not Ops.BIND)
+     if any(isinstance(x, MultiBuffer) for x in ubufs):
+       assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
+       dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
+       for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
+         schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0]:i} if len(dnums) else {}))
+     else:
+       # ONE -> ONE
+       schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata))
+     for x in children[k]:
        in_degree[x] -= 1
-       if in_degree[x] == 0: queue.append(x)
+       if in_degree[x] == 0: queues[_heuristic(x)].append(x)

-   # confirm everything was scheduled correctly
-   if len(schedule) != (kc:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, created {kc} kernels but only scheduled {len(schedule)}")
-   if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
-   # capture process replay
-   if CAPTURE_PROCESS_REPLAY:
-     with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, [x.ast for x in schedule]))
-   return schedule, var_vals, becomes_map
+   return schedule, var_vals
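The rewritten engine/schedule.py no longer groups UOps into kernels itself (that work appears to move into the new tinygrad/schedule/ package per the file list); it only linearizes an already-kernelized sink. Its toposort replaces 0.10.2's single FIFO with one queue per heuristic value and, whenever the current queue drains, switches to the non-empty queue whose key is closest to the last one used, so kernels in the same cost class (cross-device COPYs score 1000, everything else 0) tend to be emitted together. Below is a standalone sketch of that queueing strategy on a toy dependency graph; the node names, costs, and helper structure are hypothetical, not tinygrad APIs.

from collections import defaultdict, deque

# toy kernel DAG: node -> nodes that depend on it (all names and costs are made up)
edges = {"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": []}
cost = {"a": 0, "b": 1000, "c": 0, "d": 0}  # pretend "b" is a cross-device copy

# Kahn's algorithm, but ready nodes are bucketed by cost instead of sharing one FIFO
in_degree = {n: 0 for n in edges}
for n, outs in edges.items():
  for o in outs: in_degree[o] += 1

queues: defaultdict[int, deque[str]] = defaultdict(deque)
for n, d in in_degree.items():
  if d == 0: queues[cost[n]].append(n)

order: list[str] = []
last_key, last_queue = 0, deque()
while last_queue or any(queues.values()):
  # switch buckets only when the current one is drained, to the closest non-empty key
  if not last_queue:
    last_key, last_queue = min((kv for kv in queues.items() if kv[1]), key=lambda kv: abs(kv[0] - last_key))
  n = last_queue.popleft()
  order.append(n)
  for o in edges[n]:
    in_degree[o] -= 1
    if in_degree[o] == 0: queues[cost[o]].append(o)

print(order)  # ['a', 'c', 'b', 'd'] -- a valid topological order with same-cost nodes grouped

Run as-is, the sketch drains both cost-0 roots before switching to the cost-1000 bucket, which mirrors how the scheduler above keeps expensive cross-device copies from being interleaved one at a time with ordinary kernels.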