tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131) hide show
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,718 @@
1
// ** graph helpers

// Show only the ".view" elements that carry class `cls` (as flex) and hide every other view.
const displayGraph = (cls) => {
  const views = document.getElementsByClassName("view");
  for (let i = 0; i < views.length; i++) {
    const view = views[i];
    view.style.display = view.classList.contains(cls) ? "flex" : "none";
  }
}
6
+
7
// Foreground palette indexed by ANSI SGR code: 30-37 and the bright variants 90-97 both map onto these 8 entries.
const ANSI_COLORS = ["#b3b3b3", "#ff6666", "#66b366", "#ffff66", "#6666ff", "#ff66ff", "#66ffff", "#ffffff"];
// Split `name` into [{st, color}] segments.
// Text wrapped in "\u001b[<code>m ... \u001b[0m" gets the palette color for <code>.
// Plain text, and codes with no palette entry (e.g. 1 = bold), get `defaultColor`.
const parseColors = (name, defaultColor="#ffffff") => Array.from(name.matchAll(/(?:\u001b\[(\d+)m([\s\S]*?)\u001b\[0m)|([^\u001b]+)/g),
  ([_, code, colored_st, st]) => ({
    st: colored_st ?? st,
    // unmapped SGR codes used to yield `undefined` here; fall back to the default instead
    color: code != null ? (ANSI_COLORS[(parseInt(code)-30+60)%60] ?? defaultColor) : defaultColor,
  }));
10
+
11
// Bounding client rect of a DOM element, or of the first element matching a CSS selector string.
const rect = (s) => {
  const el = typeof s === "string" ? document.querySelector(s) : s;
  return el.getBoundingClientRect();
};
12
+
13
let timeout = null;
// Toggle the "#progress-message" banner.
// show=true (the default) arms a 2s timer that reveals it, so fast renders never flash the banner.
// show=false cancels any pending timer and hides it.
// The argument object may now be omitted entirely; destructuring `undefined` used to throw a TypeError.
const updateProgress = ({ show=true }={}) => {
  clearTimeout(timeout);
  const msg = document.getElementById("progress-message");
  if (show) {
    msg.innerText = "Rendering new graph...";
    timeout = setTimeout(() => { msg.style.display = "block"; }, 2000);
  } else msg.style.display = "none";
}
22
+
23
// ** UOp graph

// Return the point where the segment from the center of r1 toward point r2 leaves r1's border.
// r1 is {x, y, width, height} with (x, y) at its center. Throws when r2 coincides with that center.
function intersectRect(r1, r2) {
  const [dx, dy] = [r2.x-r1.x, r2.y-r1.y];
  if (dx === 0 && dy === 0) throw new Error("Invalid node coordinates, rects must not overlap");
  const [halfW, halfH] = [r1.width/2, r1.height/2];
  // the smaller ratio is the side (vertical or horizontal) the segment hits first
  const scale = Math.min(dx === 0 ? Infinity : halfW/Math.abs(dx), dy === 0 ? Infinity : halfH/Math.abs(dy));
  return {x:r1.x+dx*scale, y:r1.y+dy*scale};
}
34
+
35
// Inside every element of the d3 selection `root`, keep exactly one <circle r=5>
// and one <text> whose content is that element's bound datum.
function addTags(root) {
  const single = d => [d];
  root.selectAll("circle").data(single).join("circle").attr("r", 5);
  root.selectAll("text").data(single).join("text").text(d => d).attr("dy", "0.35em");
}
39
+
40
let [workerUrl, worker] = [null, null];
// Lay out a UOp graph in a Web Worker and draw it into the #nodes, #edges and #edge-labels SVG groups.
// graph/additions: forwarded unchanged to the worker together with `ctxs`.
// The worker replies with a graph serialized for dagre.graphlib.json.read.
// recenter: when true, the zoom-to-fit button is clicked once the new layout has been drawn.
async function renderDag(graph, additions, recenter=false) {
  // start calculating the new layout (non-blocking)
  updateProgress({ show:true });
  if (worker == null) {
    // first use: concatenate dagre + worker.js into one blob url for the worker script
    const resp = await Promise.all(["/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js","/js/worker.js"].map(u => fetch(u)));
    workerUrl = URL.createObjectURL(new Blob([(await Promise.all(resp.map((r) => r.text()))).join("\n")], { type: "application/javascript" }));
  } else {
    // abandon any layout still running for a previously requested graph
    worker.terminate();
  }
  worker = new Worker(workerUrl);
  worker.postMessage({graph, additions, ctxs});
  worker.onmessage = (e) => {
    displayGraph("graph");
    updateProgress({ show:false });
    const g = dagre.graphlib.json.read(e.data);
    // draw nodes
    const STROKE_WIDTH = 1.4;
    const nodes = d3.select("#nodes").selectAll("g").data(g.nodes().map(id => g.node(id)), d => d).join("g")
      .attr("transform", d => `translate(${d.x},${d.y})`).classed("clickable", d => d.ref != null)
      .on("click", (_,d) => setCtxWithHistory(d.ref));
    nodes.selectAll("rect").data(d => [d]).join("rect").attr("width", d => d.width).attr("height", d => d.height).attr("fill", d => d.color)
      .attr("x", d => -d.width/2).attr("y", d => -d.height/2).attr("style", d => d.style ?? `stroke:#4a4b57; stroke-width:${STROKE_WIDTH}px;`);
    nodes.selectAll("g.label").data(d => [d]).join("g").attr("class", "label").attr("transform", d => {
      const x = (d.width-d.padding*2)/2;
      const y = (d.height-d.padding*2)/2+STROKE_WIDTH;
      return `translate(-${x}, -${y})`;
    }).selectAll("text").data(d => {
      // split the colored label into lines, each a list of {st, color} segments
      const lines = [[]];
      // default color passed positionally; `defaultColor="initial"` here used to create an implicit global
      for (const { st, color } of parseColors(d.label, "initial")) {
        for (const [i, l] of st.split("\n").entries()) {
          if (i > 0) lines.push([]);
          lines.at(-1).push({ st:l, color });
        }
      }
      return [lines];
    }).join("text").selectAll("tspan").data(d => d).join("tspan").attr("x", "0").attr("dy", 14).selectAll("tspan").data(d => d).join("tspan")
      .attr("fill", d => d.color).text(d => d.st).attr("xml:space", "preserve");
    addTags(nodes.selectAll("g.tag").data(d => d.tag != null ? [d] : []).join("g").attr("class", "tag")
      .attr("transform", d => `translate(${-d.width/2+8}, ${-d.height/2+8})`).datum(e => e.tag));
    // draw edges
    const line = d3.line().x(d => d.x).y(d => d.y).curve(d3.curveBasis);
    d3.select("#edges").selectAll("path.edgePath").data(g.edges()).join("path").attr("class", "edgePath").attr("d", (e) => {
      const edge = g.edge(e);
      const [src, dst] = [g.node(e.v), g.node(e.w)];
      // keep the interior bend points and re-clip both ends against the node rectangles
      const points = edge.points.slice(1, -1);
      // with no interior points, aim each end at the other node's center instead of crashing on `undefined`
      const first = points[0] ?? dst;
      const last = points.at(-1) ?? src;
      points.unshift(intersectRect(src, first));
      points.push(intersectRect(dst, last));
      return line(points);
    }).attr("marker-end", "url(#arrowhead)");
    addTags(d3.select("#edge-labels").selectAll("g").data(g.edges().filter(e => g.edge(e).label != null)).join("g").attr("transform", (e) => {
      // get a point near the end
      const [p1, p2] = g.edge(e).points.slice(-2);
      const dx = p2.x-p1.x;
      const dy = p2.y-p1.y;
      // normalize to the unit vector; a zero-length segment leaves the label on the end point instead of NaN
      const len = Math.sqrt(dx*dx + dy*dy) || 1;
      const ux = dx / len;
      const uy = dy / len;
      // avoid overlap with the arrowhead
      const offset = 17;
      const x = p2.x - ux * offset;
      const y = p2.y - uy * offset;
      return `translate(${x}, ${y})`;
    }).attr("class", "tag").datum(e => g.edge(e).label));
    if (recenter) document.getElementById("zoom-to-fit-btn").click();
  };
}
109
+
110
// ** profiler graph

// Format `ts` (microseconds) with 2 decimals.
// The unit is chosen from the magnitude `dur` (defaults to ts): up to 1e3 -> "us", up to 1e6 -> "ms", otherwise "s".
function formatTime(ts, dur=ts) {
  if (dur <= 1e3) return `${ts.toFixed(2)}us`;
  const inMs = dur <= 1e6;
  return `${(ts*(inMs ? 1e-3 : 1e-6)).toFixed(2)}${inMs ? "ms" : "s"}`;
}
// SI-prefixed value with up to 3 significant digits (trailing zeros trimmed), followed by `unit`.
const formatUnit = (d, unit="") => `${d3.format(".3~s")(d)}${unit}`;
118
+
119
// Fill palettes.
// TINY / DEFAULT: timeline events; a track whose key is "TINY" uses TINY, all others use DEFAULT.
// BUFFER: memory polygons. CATEGORICAL: disassembly usage bars and legend swatches.
// NOTE(review): TINY repeats "#354f52" and CATEGORICAL repeats "#F4A261", so those slots are indistinguishable — confirm intended.
const colorScheme = {
  TINY: ["#1b5745", "#354f52", "#354f52", "#1d2e62", "#63b0cd"],
  DEFAULT: ["#2b2e39", "#2c2f3a", "#31343f", "#323544", "#2d303a", "#2e313c", "#343746", "#353847", "#3c4050", "#404459", "#444862", "#4a4e65"],
  BUFFER: ["#3A57B7", "#5066C1", "#6277CD", "#7488D8", "#8A9BE3", "#A3B4F2"],
  CATEGORICAL: ["#ff8080", "#F4A261", "#C8F9D4", "#8D99AE", "#F4A261", "#ffffa2", "#ffffc0", "#87CEEB"],
};
// Pick entry `i` of `lst`, wrapping around past the end.
const cycleColors = (lst, i) => {
  const n = lst.length;
  return lst[i % n];
};
124
+
125
// Scale a memory track vertically by `k`.
// Multiplies every polygon's y0/y1 in place and the track's height, and resizes the device-list row with DOM id `tid`.
// Returns the pixel height change, which callers add to the offsets of the tracks below.
const rescaleTrack = (source, tid, k) => {
  for (const { y0, y1 } of source.shapes) {
    for (let j = 0; j < y0.length; j++) {
      y0[j] *= k;
      y1[j] *= k;
    }
  }
  const oldHeight = source.height;
  const delta = oldHeight*k - oldHeight;
  const row = document.getElementById(tid);
  row.style.height = `${rect(row).height + delta}px`;
  source.height = oldHeight*k;
  return delta;
}
138
+
139
// Stroke one segment from (x[0], y[0]) to (x[1], y[1]) in a light color.
// fillStyle is set to the same color, so text drawn afterwards (axis labels) matches the line.
const drawLine = (ctx, x, y) => {
  const [x0, x1] = x;
  const [y0, y1] = y;
  const color = "#f0f0f5";
  ctx.strokeStyle = color;
  ctx.fillStyle = color;
  ctx.beginPath();
  ctx.moveTo(x0, y0);
  ctx.lineTo(x1, y1);
  ctx.stroke();
}
146
+
147
var data, focusedDevice, canvasZoom, zoomLevel = d3.zoomIdentity;
// Switch to the profiler view.
// The first call fetches /get_profile and builds the device list, timeline canvas, zoom and mouse handlers.
// Later calls only switch the view, because `data` is already set.
// Tracks whose shapes carry `dur` are drawn as rects stacked by depth. Other tracks are memory graphs
// drawn as polygons; clicking such a track's device name expands it and shows a y axis.
async function renderProfiler() {
  displayGraph("profiler");
  d3.select(".metadata").html("");
  // layout once!
  if (data != null) return;
  const profiler = d3.select(".profiler").html("");
  const { layout, st, et } = await (await fetch("/get_profile")).json();
  // place devices on the y axis and set vertical positions
  const [tickSize, padding] = [10, 8];
  // tallest an unfocused memory track gets (was an implicit global assigned inside .range())
  const maxheight = 100;
  const deviceList = profiler.append("div").attr("id", "device-list").style("padding-top", tickSize+padding+"px");
  const canvas = profiler.append("canvas").attr("id", "timeline").node();
  // NOTE: scrolling via mouse can only zoom the graph
  canvas.addEventListener("wheel", e => (e.stopPropagation(), e.preventDefault()), { passive:false });
  const ctx = canvas.getContext("2d");
  const canvasTop = rect(canvas).top;
  // color by key (name/category/device)
  const colorMap = new Map();
  data = {tracks:new Map(), axes:{}, st, et};
  const heightScale = d3.scaleLinear().domain([0, Object.entries(layout).reduce((peak, [_,d]) => Math.max(peak, d.peak||0), 0)]).range([4, maxheight]);
  for (const [k, v] of Object.entries(layout)) {
    if (v.shapes.length === 0) continue;
    const div = deviceList.append("div").attr("id", k).text(k).style("padding", padding+"px");
    const { y:baseY, height:baseHeight } = rect(div.node());
    const offsetY = baseY-canvasTop+padding/2;
    if (v.shapes[0].dur != null) {
      // timeline track: one rect per event, stacked by depth
      const levelHeight = baseHeight-padding;
      const shapes = [];
      data.tracks.set(k, { shapes, offsetY });
      let colorKey, ref;
      for (const e of v.shapes) {
        if (e.depth === 0) colorKey = e.cat ?? e.name;
        if (!colorMap.has(colorKey)) colorMap.set(colorKey, cycleColors(colorScheme[k] ?? colorScheme.DEFAULT, colorMap.size));
        const fillColor = d3.color(colorMap.get(colorKey)).brighter(e.depth).toString();
        const label = parseColors(e.name).map(({ color, st }) => ({ color, st, width:ctx.measureText(st).width }));
        // An event with `ref` links to context ref (step 0). Each following event links to the next step
        // in that context with the same name, or the link is dropped when no such step exists.
        if (e.ref != null) ref = {ctx:e.ref, step:0};
        else if (ref != null) {
          const start = ref.step>0 ? ref.step+1 : 0;
          const stepIdx = ctxs[ref.ctx+1].steps.findIndex((s, i) => i >= start && s.name == e.name);
          ref = stepIdx === -1 ? null : {ctx:ref.ctx, step:stepIdx};
        }
        const arg = { tooltipText:formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), ...ref };
        // offset y by depth
        shapes.push({x:e.st-st, y:levelHeight*e.depth, width:e.dur, height:levelHeight, arg, label, fillColor });
      }
      // d3's style() forwards to CSSStyleDeclaration.setProperty, which needs the hyphenated name ("pointerEvents" was ignored)
      div.style("height", levelHeight*v.maxDepth+padding+"px").style("pointer-events", "none");
    } else {
      // memory track: one stacked polygon per buffer, y scaled so v.peak fills `height`
      const height = heightScale(v.peak);
      const yscale = d3.scaleLinear().domain([0, v.peak]).range([height, 0]);
      const shapes = [];
      for (const [i,e] of v.shapes.entries()) {
        const x = e.x.map(tsIdx => v.timestamps[tsIdx]-st);
        const arg = {tooltipText:`${e.arg.dtype} len:${formatUnit(e.arg.sz)}\n${formatUnit(e.arg.nbytes, "B")}`};
        shapes.push({ x, y0:e.y.map(yscale), y1:e.y.map(y => yscale(y+e.arg.nbytes)), arg, fillColor:cycleColors(colorScheme.BUFFER, i) });
      }
      data.tracks.set(k, { shapes, offsetY, height, peak:v.peak, scaleFactor:maxheight*4/height });
      // clicking the device name toggles focus; the focused track is scaled by scaleFactor and gets a y axis
      div.style("height", height+padding+"px").style("cursor", "pointer").on("click", (e) => {
        const newFocus = e.currentTarget.id === focusedDevice ? null : e.currentTarget.id;
        let offset = 0;
        for (const [tid, track] of data.tracks) {
          track.offsetY += offset;
          if (tid === newFocus) offset += rescaleTrack(track, tid, track.scaleFactor);
          else if (tid === focusedDevice) offset += rescaleTrack(track, tid, 1/track.scaleFactor);
        }
        // `t` used to be an implicit global here
        const focused = newFocus != null ? data.tracks.get(newFocus) : null;
        data.axes.y = focused != null ? { domain:[0, focused.peak], range:[focused.offsetY+focused.height, focused.offsetY], fmt:"B" } : null;
        focusedDevice = newFocus;
        return resize();
      });
    }
  }
  updateProgress({ "show":false });
  // draw events on a timeline
  const dpr = window.devicePixelRatio || 1;
  const ellipsisWidth = ctx.measureText("...").width;
  // hit-test boxes (CSS pixels) of everything drawn by the last render()
  const rectLst = [];
  function render(transform) {
    zoomLevel = transform;
    rectLst.length = 0;
    ctx.save();
    ctx.clearRect(0, 0, canvas.clientWidth, canvas.clientHeight);
    // rescale to match current zoom
    const xscale = d3.scaleLinear().domain([0, et-st]).range([0, canvas.clientWidth]);
    xscale.domain(xscale.range().map(zoomLevel.invertX, zoomLevel).map(xscale.invert, xscale));
    const zoomDomain = transform != null ? xscale.domain() : null;
    const yscale = data.axes.y != null ? d3.scaleLinear().domain(data.axes.y.domain).range(data.axes.y.range) : null;
    // draw shapes
    for (const [_, { offsetY, shapes }] of data.tracks) {
      for (const e of shapes) {
        const [start, end] = e.width != null ? [e.x, e.x+e.width] : [e.x[0], e.x[e.x.length-1]];
        if (zoomDomain != null && (start>zoomDomain[1]|| end<zoomDomain[0])) continue;
        ctx.fillStyle = e.fillColor;
        // generic polygon
        if (e.width == null) {
          const x = e.x.map(xscale);
          ctx.beginPath();
          ctx.moveTo(x[0], offsetY+e.y0[0]);
          for (let i=1; i<x.length; i++) ctx.lineTo(x[i], offsetY+e.y0[i]);
          for (let i=x.length-1; i>=0; i--) ctx.lineTo(x[i], offsetY+e.y1[i]);
          ctx.closePath();
          ctx.fill();
          // NOTE: y coordinates are in reverse order
          for (let i = 0; i < x.length - 1; i++) {
            let tooltipText = e.arg.tooltipText;
            if (yscale != null) {
              // `yaxisVal` used to be an implicit global here
              const yaxisVal = yscale.invert(offsetY+e.y1[i]);
              if (yaxisVal > 0) tooltipText += `\nTotal: ${formatUnit(yaxisVal, data.axes.y.fmt)}`;
            }
            rectLst.push({ x0:x[i], x1:x[i+1], y0:offsetY+e.y1[i], y1:offsetY+e.y0[i], arg:{...e.arg, tooltipText} });
          }
          continue;
        }
        // contiguous rect
        const x = xscale(start);
        const width = xscale(end)-x;
        ctx.fillRect(x, offsetY+e.y, width, e.height);
        rectLst.push({ y0:offsetY+e.y, y1:offsetY+e.y+e.height, x0:x, x1:x+width, arg:e.arg });
        // add label, truncated with "..." when it does not fit inside the rect
        if (e.label == null) continue;
        ctx.textAlign = "left";
        ctx.textBaseline = "middle";
        let [labelX, labelWidth] = [x+2, 0];
        const labelY = offsetY+e.y+e.height/2;
        for (const [i,l] of e.label.entries()) {
          if (labelWidth+l.width+(i===e.label.length-1 ? 0 : ellipsisWidth)+2 > width) {
            if (labelWidth !== 0) ctx.fillText("...", labelX, labelY);
            break;
          }
          ctx.fillStyle = l.color;
          ctx.fillText(l.st, labelX, labelY);
          labelWidth += l.width;
          labelX += l.width;
        }
      }
    }
    // draw axes
    drawLine(ctx, xscale.range(), [0, 0]);
    for (const tick of xscale.ticks()) {
      // tick line
      const x = xscale(tick);
      drawLine(ctx, [x, x], [0, tickSize]);
      // tick label
      ctx.textBaseline = "top";
      ctx.textAlign = "left";
      ctx.fillText(formatTime(tick, et-st), x+ctx.lineWidth+2, tickSize);
    }
    if (yscale != null) {
      drawLine(ctx, [0, 0], yscale.range());
      for (const tick of yscale.ticks()) {
        const y = yscale(tick);
        drawLine(ctx, [0, tickSize], [y, y]);
        ctx.textAlign = "left";
        ctx.textBaseline = "middle";
        ctx.fillText(formatUnit(tick, data.axes.y.fmt), tickSize+2, y);
      }
    }
    ctx.restore();
  }

  function resize() {
    const profiler = document.querySelector(".profiler");
    // NOTE: use clientWidth to account for the scrollbar
    let [width, height] = [profiler.clientWidth, profiler.scrollHeight];
    width -= rect("#device-list").width+padding;
    canvas.width = width*dpr;
    canvas.height = height*dpr;
    canvas.style.height = `${height}px`;
    canvas.style.width = `${width}px`;
    // resizing the canvas resets its transform, so the devicePixelRatio scale is re-applied
    ctx.scale(dpr, dpr);
    d3.select(canvas).call(canvasZoom.transform, zoomLevel);
  }

  canvasZoom = d3.zoom().filter(e => (!e.ctrlKey || e.type === 'wheel' || e.type === 'mousedown') && !e.button)
    .scaleExtent([1, Infinity]).translateExtent([[0,0], [Infinity,0]]).on("zoom", e => render(e.transform));
  d3.select(canvas).call(canvasZoom);
  document.addEventListener("contextmenu", e => e.ctrlKey && e.preventDefault());

  resize();
  window.addEventListener("resize", resize);

  // map client coordinates to canvas CSS pixels and return the `arg` of the first hit box, if any
  function findRectAtPosition(x, y) {
    const { top, left, width, height } = rect(canvas);
    const X = ((x-left) * (canvas.width/width))/dpr;
    const Y = ((y-top) * (canvas.height/height))/dpr;
    for (const r of rectLst) {
      if (Y>=r.y0 && Y<=r.y1 && X>=r.x0 && X<=r.x1) return r.arg;
    }
  }

  canvas.addEventListener("click", e => {
    e.preventDefault();
    const foundRect = findRectAtPosition(e.clientX, e.clientY);
    if (foundRect?.step != null) return setCtxWithHistory(foundRect.ctx, foundRect.step);
  });

  canvas.addEventListener("mousemove", e => {
    // look the tooltip up once so the hide branch does not rely on the window named-element fallback
    const tooltip = document.getElementById("tooltip");
    const foundRect = findRectAtPosition(e.clientX, e.clientY);
    if (foundRect?.tooltipText != null) {
      tooltip.style.display = "block";
      tooltip.style.left = (e.pageX+10)+"px";
      tooltip.style.top = (e.pageY)+"px";
      tooltip.innerText = foundRect.tooltipText;
    } else tooltip.style.display = "none";
  });
  canvas.addEventListener("mouseleave", () => document.getElementById("tooltip").style.display = "none");
}
355
+
356
// ** zoom and recentering

const svgZoom = d3.zoom().on("zoom", (e) => d3.select("#render").attr("transform", e.transform));
d3.select("#graph-svg").call(svgZoom);

// zoom to fit into view.
// If the profiler timeline is visible, reset its zoom.
// Otherwise scale and center #render in the gap between the left and right sidebars.
document.getElementById("zoom-to-fit-btn").addEventListener("click", () => {
  const timeline = d3.select("#timeline");
  const timelineVisible = !timeline.empty() && rect(timeline.node()).width !== 0;
  if (timelineVisible) return timeline.call(canvasZoom.transform, d3.zoomIdentity);
  const svg = d3.select("#graph-svg");
  // measure the graph under the identity transform
  svg.call(svgZoom.transform, d3.zoomIdentity);
  const mainBox = rect(".main-container");
  const leftEdge = rect(".ctx-list-parent").right;
  const rightEdge = rect(".metadata-parent").left;
  const pad = 16;
  const target = {
    x: leftEdge+pad,
    y: mainBox.top+pad,
    width: (rightEdge > 0 ? rightEdge-leftEdge : mainBox.width)-2*pad,
    height: mainBox.height-2*pad,
  };
  const graphBox = rect("#render");
  if (graphBox.width === 0) return;
  const scale = Math.min(target.width/graphBox.width, target.height/graphBox.height);
  const tx = target.x + (target.width - graphBox.width*scale)/2 - graphBox.left*scale;
  // NOTE(review): unlike tx, ty does not subtract graphBox.top*scale — confirm this asymmetry is intended
  const ty = target.y + (target.height - graphBox.height*scale)/2;
  svg.call(svgZoom.transform, d3.zoomIdentity.translate(tx, ty).scale(scale));
});
380
+
381
+ // **** main VIZ interface
382
+
383
// Build a <pre> containing `st` highlighted by highlight.js as `language`.
// loc: optional array; adds a "vscode://file/<loc joined by ':'>" link labelled "<basename of loc[0]>:<loc[1]>" above the code.
// wrap: when truthy, the <pre> gets the "wrap" class.
function codeBlock(st, language, { loc, wrap }={}) {
  const pre = document.createElement("pre");
  if (wrap) pre.className = "wrap";
  if (loc != null) {
    const anchor = document.createElement("a");
    anchor.href = "vscode://file/"+loc.join(":");
    const fileName = loc[0].split("/").at(-1);
    anchor.textContent = `${fileName}:${loc[1]}`+"\n\n";
    pre.appendChild(anchor);
  }
  const code = document.createElement("code");
  code.className = "hljs";
  code.innerHTML = hljs.highlight(st, { language }).value;
  pre.appendChild(code);
  return pre;
}
397
+
398
// Append a <td> to `tr` showing `value`.
// unit "us" formats the value with formatTime. Otherwise non-integer numbers get 2 decimals and `unit` (if any) is appended.
function appendTd(tr, value, unit=null) {
  const td = tr.appendChild(document.createElement("td"));
  if (unit == "us") {
    td.innerText = formatTime(value);
    return;
  }
  const isFraction = typeof value === "number" && !Number.isInteger(value);
  td.innerText = (isFraction ? value.toFixed(2) : value)+(unit ?? "");
}
402
+
403
// Append a <tr class=cls> to `table` with a name cell and a value cell (formatted by appendTd), and return the row.
function appendRow(table, name, value, unit=null, cls="main-row") {
  const row = table.appendChild(document.createElement("tr"));
  row.className = cls;
  const nameCell = document.createElement("td");
  nameCell.innerText = name;
  row.appendChild(nameCell);
  appendTd(row, value, unit);
  return row;
}
410
+
411
// Give element `e` the "active" class and scroll it into view (nearest edge) on the next animation frame.
// A null/undefined element is ignored.
function setActive(e) {
  if (e != null) {
    e.classList.add("active");
    requestAnimationFrame(() => e.scrollIntoView({ behavior: "auto", block: "nearest" }));
  }
}
416
+
417
// ** hljs extra definitions for UOps and float4
// Python: match dtypes.* as types, dotted names as operators and Capitalized( calls as sections.
// These rules are tried before the stock Python rules, which are kept after them.
hljs.registerLanguage("python", (hljs) => {
  const base = hljs.getLanguage("python");
  const dtypeRules = [
    { begin: 'dtypes\\.[a-zA-Z_][a-zA-Z0-9_-]*(\\.[a-zA-Z_][a-zA-Z0-9_-]*)*' + '(?=[.\\s\\n[:,(])', className: "type" },
    // NOTE(review): ".vec*" means "any char, 've', zero or more 'c'", and dtypes.x.vec( is already matched above — confirm intent
    { begin: 'dtypes\\.[a-zA-Z_][a-zA-Z0-9_-].vec*' + '(?=[.\\s\\n[:,(])', className: "type" },
  ];
  return {
    ...base,
    case_insensitive: false,
    contains: [
      ...dtypeRules,
      { begin: '[a-zA-Z_][a-zA-Z0-9_-]*\\.[a-zA-Z_][a-zA-Z0-9_-]*' + '(?=[.\\s\\n[:,()])', className: "operator" },
      { begin: '[A-Z][a-zA-Z0-9_]*(?=\\()', className: "section", ignoreEnd: true },
      ...base.contains,
    ],
  };
});
// C++: additionally highlight vector types such as float4 / half8 as types.
hljs.registerLanguage("cpp", (hljs) => {
  const base = hljs.getLanguage('cpp');
  return { ...base, contains: [{ begin: '\\b(?:float|half)[0-9]+\\b', className: 'type' }, ...base.contains] };
});
433
+
434
// Global view state shared by main() and the sidebar handlers.
var ret = [];           // payload currently shown for the selected step (also kept in `cache`)
var cache = {};         // step query url -> fetched or streamed payload
var ctxs = null;        // sidebar contexts; null until main() first loads them
const evtSources = [];  // EventSource streams opened for rewrite steps
// VIZ displays graph rewrites in 3 levels, from bottom-up:
// rewrite: a single UOp transformation
// step: collection of rewrites
// context: collection of steps
const state = {
  currentCtx: -1,
  currentStep: 0,
  currentRewrite: 0,
  expandSteps: false,
};
443
// Merge `ns` into the global view state, then update the sidebar classes:
// "expanded" follows expandSteps on the current context, and "active" moves from the previous ctx/step to the new one.
// Finally re-render via main().
function setState(ns) {
  const prevCtx = state.currentCtx;
  const prevStep = state.currentStep;
  Object.assign(state, ns);
  const { currentCtx, currentStep, expandSteps } = state;
  const byId = (id) => document.getElementById(id);
  byId(`ctx-${currentCtx}`)?.classList.toggle("expanded", expandSteps);
  const ctxChanged = currentCtx !== prevCtx;
  if (ctxChanged) {
    byId(`ctx-${prevCtx}`)?.classList.remove("active", "expanded");
    setActive(byId(`ctx-${currentCtx}`));
  }
  if (ctxChanged || currentStep !== prevStep) {
    byId(`step-${prevCtx}-${prevStep}`)?.classList.remove("active");
    setActive(byId(`step-${currentCtx}-${currentStep}`));
  }
  // re-render
  main();
}
459
+
460
// set a new context and keep the old one in browser history
// Switch to context `newCtx` at `step`. newCtx is shifted by one because ctxs[0] is the synthetic "Profiler" entry.
// The current view is saved into the current history entry and the NEW view is pushed as the next entry,
// so the popstate listener restores the right view for both Back and Forward. A null/undefined newCtx is ignored.
function setCtxWithHistory(newCtx, step=0) {
  if (newCtx == null) return;
  const next = { expandSteps:true, currentCtx:newCtx+1, currentStep:step, currentRewrite:0 };
  // NOTE: browser does a structured clone, passing a mutable object is safe.
  history.replaceState(state, "");
  // The pushed entry must describe the view being moved TO.
  // Pushing `state` here (still the old view) made Forward navigation return to the old view.
  history.pushState({ ...state, ...next }, "");
  setState(next);
}
468
+
469
// On browser Back/Forward, restore the view state stored in the history entry (entries without state are ignored).
window.addEventListener("popstate", ({ state: saved }) => {
  if (saved != null) setState(saved);
});
472
+
473
+ async function main() {
474
+ // ** left sidebar context list
475
+ if (ctxs == null) {
476
+ ctxs = [{ name:"Profiler", steps:[] }];
477
+ for (const r of (await (await fetch("/ctxs")).json())) ctxs.push(r);
478
+ const ctxList = document.querySelector(".ctx-list");
479
+ for (const [i,{name, steps}] of ctxs.entries()) {
480
+ const ul = ctxList.appendChild(document.createElement("ul"));
481
+ ul.id = `ctx-${i}`;
482
+ const p = ul.appendChild(document.createElement("p"));
483
+ p.innerHTML = parseColors(name).map(c => `<span style="color: ${c.color}">${c.st}</span>`).join("");
484
+ p.onclick = () => {
485
+ setState(i === state.currentCtx ? { expandSteps:!state.expandSteps } : { expandSteps:true, currentCtx:i, currentStep:0, currentRewrite:0 });
486
+ }
487
+ for (const [j,u] of steps.entries()) {
488
+ const inner = ul.appendChild(document.createElement("ul"));
489
+ inner.id = `step-${i}-${j}`;
490
+ inner.innerText = `${u.name ?? u.loc[0].replaceAll("\\", "/").split("/").pop()+':'+u.loc[1]}`+(u.match_count ? ` - ${u.match_count}` : '');
491
+ inner.style.marginLeft = `${8*u.depth}px`;
492
+ inner.onclick = (e) => {
493
+ e.stopPropagation();
494
+ setState({ currentStep:j, currentCtx:i, currentRewrite:0 });
495
+ }
496
+ }
497
+ }
498
+ return setState({ currentCtx:-1 });
499
+ }
500
+ // ** center graph
501
+ const { currentCtx, currentStep, currentRewrite, expandSteps } = state;
502
+ if (currentCtx == -1) return;
503
+ const ctx = ctxs[currentCtx];
504
+ const step = ctx.steps[currentStep];
505
+ const ckey = step?.query;
506
+ // close any pending event sources
507
+ let activeSrc = null;
508
+ for (const e of evtSources) {
509
+ const url = new URL(e.url);
510
+ if (url.pathname+url.search !== ckey) e.close();
511
+ else if (e.readyState === EventSource.OPEN) activeSrc = e;
512
+ }
513
+ if (ctx.name === "Profiler") return renderProfiler();
514
+ if (ckey in cache) {
515
+ ret = cache[ckey];
516
+ }
517
+ // ** Disassembly view
518
+ if (ckey.startsWith("/disasm")) {
519
+ if (!(ckey in cache)) cache[ckey] = ret = await (await fetch(ckey)).json();
520
+ displayGraph("profiler");
521
+ const root = document.createElement("div");
522
+ root.className = "raw-text";
523
+ const metadata = document.querySelector(".metadata");
524
+ metadata.innerHTML = "";
525
+ // detailed assembly view
526
+ if (ret.cols != null) {
527
+ const asm = root.appendChild(document.createElement("table"));
528
+ const thead = asm.appendChild(document.createElement("thead"));
529
+ for (const c of ret.cols) thead.appendChild(document.createElement("th")).innerText = c.title ?? c;
530
+ for (const r of ret.rows) {
531
+ const tr = asm.appendChild(document.createElement("tr"));
532
+ tr.className = "main-row code-row";
533
+ for (const [i,value] of r.entries()) {
534
+ // string format scalar values
535
+ if (!Array.isArray(value)) appendTd(tr, value);
536
+ // display arrays in a bar graph
537
+ else {
538
+ const segmentsTd = tr.appendChild(document.createElement("td"));
539
+ segmentsTd.className = "pct-row";
540
+ const usageBar = segmentsTd.appendChild(document.createElement("div"));
541
+ for (const [k, v, width] of value) {
542
+ const seg = usageBar.appendChild(document.createElement("div"));
543
+ seg.style.width = width+"%";
544
+ seg.title = `${ret.cols[i].labels[k]} ${v}`;
545
+ seg.style.background = cycleColors(colorScheme.CATEGORICAL, parseInt(k));
546
+ }
547
+ }
548
+ }
549
+ }
550
+ const summary = metadata.appendChild(document.createElement("table"));
551
+ for (const s of ret.summary) {
552
+ const tr = summary.appendChild(document.createElement("tr"));
553
+ tr.className = "main-row";
554
+ const td = tr.appendChild(document.createElement("td"));
555
+ const div = td.appendChild(document.createElement("div"));
556
+ div.className = "legend";
557
+ div.appendChild(document.createElement("div")).style.background = cycleColors(colorScheme.CATEGORICAL, s.idx);
558
+ div.appendChild(document.createElement("p")).textContent = s.label;
559
+ appendTd(tr, s.value);
560
+ }
561
+ } else root.appendChild(codeBlock(ret.src, "x86asm"));
562
+ return document.querySelector(".profiler").replaceChildren(root);
563
+ }
564
+ // ** UOp view (default)
565
+ // if we don't have a complete cache yet we start streaming rewrites in this step
566
+ if (!(ckey in cache) || (cache[ckey].length !== step.match_count+1 && activeSrc == null)) {
567
+ ret = [];
568
+ cache[ckey] = ret;
569
+ const eventSource = new EventSource(ckey);
570
+ evtSources.push(eventSource);
571
+ eventSource.onmessage = (e) => {
572
+ if (e.data === "END") return eventSource.close();
573
+ const chunk = JSON.parse(e.data);
574
+ ret.push(chunk);
575
+ // if it's the first one render this new rgaph
576
+ if (ret.length === 1) return main();
577
+ // otherwise just enable the graph selector
578
+ const ul = document.getElementById(`rewrite-${ret.length-1}`);
579
+ if (ul != null) ul.classList.remove("disabled");
580
+ };
581
+ }
582
+ if (ret.length === 0) return;
583
+ renderDag(ret[currentRewrite].graph, ret[currentRewrite].changed_nodes || [], recenter=currentRewrite === 0);
584
+ // ** right sidebar code blocks
585
+ const metadata = document.querySelector(".metadata");
586
+ const [code, lang] = ctx.fmt != null ? [ctx.fmt, "cpp"] : [ret[currentRewrite].uop, "python"];
587
+ metadata.replaceChildren(codeBlock(step.code_line, "python", { loc:step.loc, wrap:true }), codeBlock(code, lang, { wrap:false }));
588
+ if (ctx.runtime_stats != null) {
589
+ const div = metadata.appendChild(document.createElement("div"));
590
+ div.className = "stats-list";
591
+ for (const [i, s] of ctx.runtime_stats.entries()) {
592
+ const p = div.appendChild(document.createElement("p"));
593
+ if (ctx.runtime_stats.length > 1) p.innerText = `Run ${i+1}/${ctx.runtime_stats.length}`;
594
+ const table = div.appendChild(document.createElement("table"));
595
+ const tbody = table.appendChild(document.createElement("tbody"));
596
+ for (const { name, value, unit, subunits } of s.data) {
597
+ const mainRow = appendRow(tbody, name, value, unit, "main-row");
598
+ if (!subunits?.length) continue;
599
+ const subunitRow = tbody.appendChild(document.createElement("tr"));
600
+ subunitRow.style.display = "none";
601
+ mainRow.onclick = () => subunitRow.style.display = subunitRow.style.display === "none" ? "table-row" : "none";
602
+ mainRow.style.cursor = "pointer";
603
+ const td = subunitRow.appendChild(document.createElement("td"));
604
+ td.colSpan = 2;
605
+ const table = td.appendChild(document.createElement("table"));
606
+ for (const u of subunits) appendRow(table, u.name, u.value, unit, "sub-row");
607
+ }
608
+ }
609
+ }
610
+ // ** rewrite steps
611
+ if (step.match_count >= 1) {
612
+ const rewriteList = metadata.appendChild(document.createElement("div"));
613
+ rewriteList.className = "rewrite-list";
614
+ for (let s=0; s<=step.match_count; s++) {
615
+ const ul = rewriteList.appendChild(document.createElement("ul"));
616
+ ul.innerText = s;
617
+ ul.id = `rewrite-${s}`;
618
+ ul.onclick = () => setState({ currentRewrite:s });
619
+ ul.className = s > ret.length-1 ? "disabled" : s === currentRewrite ? "active" : "";
620
+ if (s > 0 && s === currentRewrite) {
621
+ const { upat, diff } = ret[s];
622
+ metadata.appendChild(codeBlock(upat[1], "python", { loc:upat[0], wrap:true }));
623
+ const diffCode = metadata.appendChild(document.createElement("pre")).appendChild(document.createElement("code"));
624
+ for (const line of diff) {
625
+ const span = diffCode.appendChild(document.createElement("span"));
626
+ span.style.color = line.startsWith("+") ? "#3aa56d" : line.startsWith("-") ? "#d14b4b" : "#f0f0f5";
627
+ span.innerText = line;
628
+ diffCode.appendChild(document.createElement("br"));
629
+ }
630
+ diffCode.className = "wrap";
631
+ }
632
+ }
633
+ }
634
+ }
635
+
636
+ // **** collapse/expand
637
+
638
// **** collapse/expand: the collapse button toggles the "collapsed" class on .main-container,
// rotates itself to reflect the new state, and fires a synthetic resize so resize listeners re-measure.
let isCollapsed = false;
document.querySelector(".collapse-btn").addEventListener("click", (e) => {
  const btn = e.currentTarget;
  isCollapsed = !isCollapsed;
  document.querySelector(".main-container").classList.toggle("collapsed", isCollapsed);
  // drop focus so the button doesn't keep keyboard focus after a click
  btn.blur();
  let rotation;
  if (isCollapsed) rotation = "rotate(180deg)";
  else rotation = "rotate(0deg)";
  btn.style.transform = rotation;
  // the layout width changed without a real window resize; notify resize listeners
  window.dispatchEvent(new Event("resize"));
});
646
+
647
+ // **** resizer
648
+
649
// Attach a draggable "resize-handle" to `element` that changes its width horizontally.
// The new width is written as a percentage of .main-container's width, clamped to [minWidth, maxWidth].
// left=true: the handle sits on the element's right edge and dragging right grows the element.
// left=false: the handle sits on the left edge and dragging left grows the element.
function appendResizer(element, { minWidth, maxWidth }, left=false) {
  const handle = document.createElement("div");
  handle.className = "resize-handle";
  handle.style.cssText = left ? "right: 0" : "left: 0; margin-top: 0";
  element.appendChild(handle);
  // drag start values are stashed on the element's dataset (stored as strings, hence Number())
  const { dataset } = element;
  const onMove = (ev) => {
    const delta = ev.clientX - dataset.startX;
    const grown = Number(dataset.startWidth) + (left ? delta : -delta);
    const pct = (grown / Number(dataset.containerWidth)) * 100;
    element.style.width = `${Math.max(minWidth, Math.min(maxWidth, pct))}%`;
  };
  handle.addEventListener("mousedown", (ev) => {
    // prevent text selection / native drag while resizing
    ev.preventDefault();
    dataset.startX = ev.clientX;
    dataset.containerWidth = rect(".main-container").width;
    dataset.startWidth = element.getBoundingClientRect().width;
    const root = document.documentElement;
    // listen on the whole document so the drag keeps tracking when the pointer leaves the handle
    root.addEventListener("mousemove", onMove, false);
    root.addEventListener("mouseup", () => {
      root.removeEventListener("mousemove", onMove, false);
      element.style.userSelect = "initial";
    }, { once: true });
  });
}
669
// Make both side panels resizable. Pass the flag positionally: the previous `left=true` was an
// assignment expression that leaked an implicit global `left` (and throws in strict mode / ES modules).
// .ctx-list-parent: left=true (handle on its right edge), width clamped to 15%..50%.
appendResizer(document.querySelector(".ctx-list-parent"), { minWidth: 15, maxWidth: 50 }, true);
// .metadata-parent: default left=false (handle on its left edge), width clamped to 20%..50%.
appendResizer(document.querySelector(".metadata-parent"), { minWidth: 20, maxWidth: 50 });
671
+
672
+ // **** keyboard shortcuts
673
+
674
// **** Global keyboard shortcuts
//   ArrowUp / ArrowDown    previous/next step when the current context's steps are expanded,
//                          otherwise previous/next context (collapsing its steps)
//   Enter                  toggle the step list; selects context 0 first when currentCtx is -1
//   ArrowLeft / ArrowRight previous/next rewrite of the current step
//   Space                  click the zoom-to-fit button to recenter the graph
document.addEventListener("keydown", async function(event) {
  const { currentCtx, currentStep, currentRewrite, expandSteps } = state;
  // currentCtx can be -1 (no context selected, see the Enter case); guard the lookup so an
  // out-of-range index yields undefined instead of throwing a TypeError
  const changeStep = expandSteps && ctxs[currentCtx]?.steps?.length;
  switch (event.key) {
    // up and down change the step or context from the list
    case "ArrowUp": {
      event.preventDefault();
      if (changeStep) return setState({ currentRewrite:0, currentStep:Math.max(0, currentStep-1) });
      return setState({ currentStep:0, currentRewrite:0, currentCtx:Math.max(0, currentCtx-1), expandSteps:false });
    }
    case "ArrowDown": {
      event.preventDefault();
      if (changeStep) {
        const lastStep = ctxs[currentCtx].steps.length-1;
        return setState({ currentRewrite:0, currentStep:Math.min(lastStep, currentStep+1) });
      }
      return setState({ currentStep:0, currentRewrite:0, currentCtx:Math.min(ctxs.length-1, currentCtx+1), expandSteps:false });
    }
    // enter toggles focus on a single rewrite stage
    case "Enter": {
      event.preventDefault();
      if (currentCtx === -1) return setState({ currentCtx:0, expandSteps:true });
      return setState({ expandSteps:!expandSteps });
    }
    // left and right go through rewrites in a single UOp
    case "ArrowLeft": {
      event.preventDefault();
      return setState({ currentRewrite:Math.max(0, currentRewrite-1) });
    }
    case "ArrowRight": {
      event.preventDefault();
      // `ret` is an array of rewrites only in the UOp view; the disasm view stores a JSON object there,
      // and while streaming it may still be []. Clamp so currentRewrite never becomes NaN or -1,
      // either of which would make the next render index ret[...] out of range.
      const lastRewrite = Array.isArray(ret) ? Math.max(0, ret.length-1) : 0;
      return setState({ currentRewrite:Math.min(lastRewrite, currentRewrite+1) });
    }
    // space recenters the graph
    case " ": {
      event.preventDefault();
      document.getElementById("zoom-to-fit-btn").click();
    }
  }
});
717
+
718
// kick off the initial render; the returned value is intentionally discarded
void main();