tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,8 @@
1
+ /* Vendored highlight.js theme (@highlightjs/cdn-assets 11.10.0, styles/tokyo-night-dark.min.css), loaded by tinygrad/viz/index.html; keep it minified and identical to upstream rather than editing by hand. */
+ pre code.hljs{display:block;overflow-x:auto;padding:1em}code.hljs{padding:3px 5px}/*!
2
+ Theme: Tokyo-night-Dark
3
+ origin: https://github.com/enkia/tokyo-night-vscode-theme
4
+ Description: Original highlight.js style
5
+ Author: (c) Henri Vandersleyen <hvandersleyen@gmail.com>
6
+ License: see project LICENSE
7
+ Touched: 2022
8
+ */.hljs-comment,.hljs-meta{color:#565f89}.hljs-deletion,.hljs-doctag,.hljs-regexp,.hljs-selector-attr,.hljs-selector-class,.hljs-selector-id,.hljs-selector-pseudo,.hljs-tag,.hljs-template-tag,.hljs-variable.language_{color:#f7768e}.hljs-link,.hljs-literal,.hljs-number,.hljs-params,.hljs-template-variable,.hljs-type,.hljs-variable{color:#ff9e64}.hljs-attribute,.hljs-built_in{color:#e0af68}.hljs-keyword,.hljs-property,.hljs-subst,.hljs-title,.hljs-title.class_,.hljs-title.class_.inherited__,.hljs-title.function_{color:#7dcfff}.hljs-selector-tag{color:#73daca}.hljs-addition,.hljs-bullet,.hljs-quote,.hljs-string,.hljs-symbol{color:#9ece6a}.hljs-code,.hljs-formula,.hljs-section{color:#7aa2f7}.hljs-attr,.hljs-char.escape_,.hljs-keyword,.hljs-name,.hljs-operator{color:#bb9af7}.hljs-punctuation{color:#c0caf5}.hljs{background:#1a1b26;color:#9aa5ce}.hljs-emphasis{font-style:italic}.hljs-strong{font-weight:700}
@@ -0,0 +1,544 @@
1
<!DOCTYPE html>
<html>
<head>
  <title>tinygrad viz</title>
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <!-- inline placeholder icon, presumably so the browser doesn't request /favicon.ico -->
  <link rel="icon" href="data:;base64,iVBORw0KGgo=">
  <!-- graph layout + rendering (d3 v5 and dagre-d3) -->
  <script src="assets/d3js.org/d3.v5.min.js" charset="utf-8"></script>
  <script src="assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js"></script>
  <!-- syntax highlighting (highlight.js 11.10.0 with python + cpp grammars, tokyo-night-dark theme) -->
  <link rel="stylesheet" href="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css">
  <script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js"></script>
  <script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js"></script>
  <script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js"></script>
  <link rel="stylesheet" href="assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css" />
  <style>
    /* ---- base ---- */
    * { box-sizing: border-box; margin-block-start: initial; margin-block-end: initial; }
    button { outline: none; }
    html, body {
      color: #f0f0f5; margin: 0; padding: 0; width: 100%; height: 100%;
      font-family: sans-serif; font-optical-sizing: auto; font-weight: 400; font-style: normal;
      font-variation-settings: "wdth" 100; font-size: 14px; overflow: hidden;
    }
    a { color: #4a90e2; }
    /* list entries (kernels, UOps, rewrite steps) */
    ul { padding: 0; opacity: 0.6; white-space: nowrap; cursor: pointer; }
    ul.active { opacity: 1; }
    ul.disabled { opacity: 0.4; pointer-events: none; }
    /* ---- graph (svg drawn by dagre-d3) ---- */
    svg { width: 100%; height: 100%; }
    svg * { cursor: default; user-select: none; }
    .node rect { stroke: #4a4b57; stroke-width: 1.4px; rx: 8px; ry: 8px; }
    .label :is(text, p) { color: #08090e; font-weight: 350; }
    .edgePath path { stroke: #4a4b57; fill: #4a4b57; stroke-width: 1.4px; }
    /* ---- layout ---- */
    .main-container { display: flex; width: 100%; height: 100%; position: relative; }
    .container { background-color: #0f1018; }
    .container > * + *, .rewrite-container > * + * { margin-top: 12px; }
    .graph { background-color: #08090e; position: absolute; inset: 0; z-index: 1; }
    .kernel-list-parent { position: relative; width: 15%; padding: 50px 20px 20px 20px; border-right: 1px solid #4a4b56; z-index: 2; }
    .kernel-list { width: 100%; height: 100%; overflow-y: auto; }
    .kernel-list > ul > * + * { margin-top: 4px; }
    .metadata {
      position: relative; width: 20%; padding: 20px; background-color: #0f1018; border-left: 1px solid #4a4b56;
      z-index: 2; margin-left: auto; height: 100%; overflow-y: auto;
    }
    /* ---- drag handles for resizing the side panels ---- */
    .resize-handle {
      position: absolute; top: 0; bottom: 0; width: 20px; height: 100%;
      cursor: col-resize; z-index: 3; background-color: transparent;
    }
    #kernel-resize-handle { right: 0; }
    #metadata-resize-handle { margin-top: 0; left: 0; }
    /* ---- floating controls ---- */
    .floating-container { position: fixed; top: 10px; left: 20px; z-index: 4; display: flex; flex-direction: row; gap: 8px; }
    .nav-btn {
      background-color: #1a1b26; border: 1px solid #4a4b56; color: #f0f0f5; height: 32px; border-radius: 8px;
      cursor: pointer; text-decoration: none; display: flex; align-items: center; padding: 0 6px; font-weight: bold;
    }
    .collapse-btn { width: 32px; padding: 6px; }
    .btn {
      height: 32px; background-color: #1a1b26; border: 1px solid #4a4b56; color: #f0f0f5;
      border-radius: 8px; cursor: pointer; transition-duration: .5s;
    }
    .btn:hover { background-color: #2a2b36; border-color: #5a5b66; color: #ffffff; transform: translateY(-1px); }
    .collapsed .kernel-list, .collapsed .metadata { width: 0; padding: 0; overflow: hidden; }
    /* ---- metadata panel contents ---- */
    .rewrite-list { display: flex; flex-wrap: wrap; }
    .rewrite-list > * + * { margin-left: 4px; }
    .wrap { word-wrap: break-word; white-space: pre-wrap; }
    .code-block.hljs { overflow-y: auto; max-height: 30vh; border-radius: 8px; padding: 8px; }
  </style>
</head>
202
<body>
  <div class="main-container">
    <!-- top-left controls: collapse both side panels, link to the profiler -->
    <div class="floating-container">
      <button class="btn collapse-btn">
        <svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M15 19l-7-7 7-7"/></svg>
      </button>
      <a class="btn nav-btn" href="/profiler">Profiler</a>
    </div>
    <!-- LHS: kernel / UOp list, filled by main() -->
    <div class="container kernel-list-parent">
      <div class="container kernel-list"></div>
    </div>
    <!-- center: rewrite graph, rendered into #render by renderGraph() -->
    <div class="graph">
      <svg id="graph-svg">
        <g id="render"></g>
      </svg>
    </div>
    <!-- RHS: source location, code and rewrite steps for the selected UOp -->
    <div class="container metadata"></div>
  </div>
218
+ <script>
219
// **** hljs extra definitions for UOps and float4
// Extend the bundled python grammar so UOp reprs highlight dtypes, dotted enum members (e.g. Ops.X) and constructors.
// hljs tries `contains` rules in order at each position, so more specific rules must come first.
hljs.registerLanguage("python", (hljs) => ({
  ...hljs.getLanguage("python"),
  case_insensitive: false,
  contains: [
    // vector dtypes such as dtypes.float.vec(4); must precede the generic dtypes rule below,
    // whose lookahead accepts "(" and would otherwise stop the match at ".vec"
    { begin: 'dtypes\\.[a-zA-Z_][a-zA-Z0-9_-]*\\.vec\\([0-9]+\\)', className: "type" },
    // scalar / dotted dtypes, e.g. dtypes.float or dtypes.int.ptr
    { begin: 'dtypes\\.[a-zA-Z_][a-zA-Z0-9_-]*(\\.[a-zA-Z_][a-zA-Z0-9_-]*)*' + '(?=[.\\s\\n[:,(])', className: "type" },
    // any other dotted name, e.g. Ops.LOAD
    { begin: '[a-zA-Z_][a-zA-Z0-9_-]*\\.[a-zA-Z_][a-zA-Z0-9_-]*' + '(?=[.\\s\\n[:,()])', className: "operator" },
    // capitalized call, e.g. UOp(
    { begin: '[A-Z][a-zA-Z0-9_]*(?=\\()', className: "section", ignoreEnd: true },
    ...hljs.getLanguage("python").contains,
  ]
}));
// Highlight vector type names such as float4 / half8 in generated kernel code.
hljs.registerLanguage("cpp", (hljs) => ({
  ...hljs.getLanguage('cpp'),
  contains: [{ begin: '\\b(?:float|half)[0-9]+\\b', className: 'type' }, ...hljs.getLanguage('cpp').contains]
}));
235
+
236
// **** D3
// Reset the zoom/pan to identity unless at least one <rect> inside the svg lies fully within the svg's viewport.
function recenterRects(svg, zoom) {
  const viewport = svg.node().getBoundingClientRect();
  const fullyVisible = (el) => {
    const box = el.getBoundingClientRect();
    return box.top >= viewport.top && box.left >= viewport.left &&
           box.bottom <= viewport.bottom && box.right <= viewport.right;
  };
  // some() stops at the first visible rect, exactly like an early return
  if (Array.from(svg.node().querySelectorAll("rect")).some(fullyVisible)) return;
  svg.call(zoom.transform, d3.zoomIdentity);
}
248
// Lay out and draw a rewrite graph with dagre-d3.
//   graph:     object mapping node id -> [label, source ids, fill color]
//   additions: node ids (numbers) changed by this rewrite; grouped into a highlighted "addition" cluster
function renderGraph(graph, additions) {
  const g = new dagreD3.graphlib.Graph({ compound: true })
    .setGraph({ rankdir: "LR" })
    .setDefaultEdgeLabel(() => ({}));
  const clusterStyle = additions.length !== 0 ? "fill: rgba(26, 27, 38, 0.5);" : "display: none;";
  g.setNode("addition", { label: "", clusterLabelPos: "top", style: clusterStyle });

  for (const [id, [labelText, sources, color]] of Object.entries(graph)) {
    let node;
    if (labelText.includes("PROGRAM")) {
      // PROGRAM UOps carry source code: first line is the title, the rest is rendered as a highlighted block
      const [title, ...codeLines] = labelText.split("\n");
      const label = document.createElement("div");
      label.appendChild(Object.assign(document.createElement("p"), { innerText: title, className: "label", style: "margin-bottom: 2px" }));
      label.appendChild(highlightedCodeBlock(codeLines.join("\n"), "cpp", true));
      node = { label, labelType: "html", style: `fill: ${color}` };
    } else {
      node = { label: labelText, labelType: "text", style: `fill: ${color};` };
    }
    g.setNode(id, node);
    sources.forEach((src) => g.setEdge(src, id, { curve: d3.curveBasis }));
    if (additions.includes(parseInt(id))) g.setParent(id, "addition");
  }

  const svg = d3.select("#graph-svg");
  const inner = svg.select("g");
  const zoom = d3.zoom()
    .scaleExtent([0.05, 2])
    .on("zoom", () => { inner.attr("transform", d3.event.transform); });
  recenterRects(svg, zoom);
  svg.call(zoom);
  const renderer = new dagreD3.render();
  renderer(inner, g);
}
282
+
283
+ // **** extra helpers
284
+ const toPath = ([fp, lineno]) => `${fp.replaceAll("\\", "/").split("/").pop()}:${lineno}`;
285
// Build a vscode://file link from path segments; the visible text is the last segment followed by a blank line.
const vsCodeOpener = (parts) => {
  const link = document.createElement("a");
  link.textContent = parts[parts.length-1] + "\n\n";
  link.href = "vscode://file" + parts.join("/");
  link.style = "font-family: monospace; margin: 4px 0;";
  return link;
};
287
// Return a <pre><code> element containing `code` highlighted by hljs as `lang`; `wrap` enables the .wrap (pre-wrap) class.
const highlightedCodeBlock = (code, lang, wrap) => {
  const codeEl = document.createElement("code");
  codeEl.className = `language-${lang} code-block`;
  // code goes in via textContent (never innerHTML), so no sanitizer such as DOMPurify is needed
  codeEl.textContent = code;
  const pre = document.createElement("pre");
  pre.className = wrap ? "wrap" : "";
  pre.appendChild(codeEl);
  hljs.highlightElement(codeEl);
  return pre;
};
295
+ const coloredToHTML = (str) => {
296
+ const colors = ['gray','red','green','yellow','blue','magenta','cyan','white'];
297
+ return str.replace(/\u001b\[(\d+)m(.*?)\u001b\[0m/g, (_, code, st) => {
298
+ return `<span style="${`color: color-mix(in srgb, ${colors[(parseInt(code)-30+60)%60]} 60%, white)`}">${st}</span>`;
299
+ })
300
+ }
301
+
302
+ // **** main loop
303
+ var ret = [];
304
+ var cache = {};
305
+ var kernels = null;
306
+ var currentUOp = 0;
307
+ var currentKernel = -1;
308
+ var currentRewrite = 0;
309
+ var expandKernel = true;
310
+ const evtSources = [];
311
// Re-render the whole UI from the global state (kernels, currentKernel, currentUOp, currentRewrite, expandKernel).
// Called on every navigation and when the first rewrite step of a stream arrives; everything it installs
// (listeners, resize handles, event sources) must therefore be idempotent across calls.
async function main() {
  const mainContainer = document.querySelector('.main-container');
  const kernelListParent = document.querySelector(".container.kernel-list-parent");
  const kernelList = document.querySelector(".container.kernel-list");
  const metadata = document.querySelector(".container.metadata");
  // ***** LHS kernels list
  if (kernels == null) {
    kernels = await (await fetch("/kernels")).json();
    currentKernel = -1;
  }
  kernelList.innerHTML = "";
  kernels.forEach(([key, items], i) => {
    const kernelUl = Object.assign(document.createElement("ul"), { key: `kernel-${i}`, className: i === currentKernel ? "active" : "",
      style: "overflow-x: auto; cursor: initial;" });
    if (i === currentKernel) {
      requestAnimationFrame(() => kernelUl.scrollIntoView({ behavior: "auto", block: "nearest" }));
    }
    const p = Object.assign(document.createElement("p"), { id: `kernel-${key}`, innerHTML: coloredToHTML(key), style: "cursor: pointer;" });
    kernelUl.appendChild(p);
    items.forEach((u, j) => {
      const isSelected = i === currentKernel && j === currentUOp;
      const rwUl = Object.assign(document.createElement("ul"), {
        innerText: u.name ? `${u.name} - ${u.match_count}` : `${toPath(u.loc)} - ${u.match_count}`, key: `uop-rewrite-${j}`,
        className: isSelected ? "active" : "" });
      // only the selected kernel's UOps can be visible, so only its selected entry needs scrolling
      if (isSelected) {
        requestAnimationFrame(() => rwUl.scrollIntoView({ behavior: "auto", block: "nearest" }));
      }
      rwUl.style.display = i === currentKernel && expandKernel ? "block" : "none";
      rwUl.onclick = (e) => {
        e.stopPropagation();
        currentUOp = j;
        currentKernel = i;
        currentRewrite = 0;
        main();
      };
      kernelUl.appendChild(rwUl);
    });
    p.onclick = () => {
      // clicking the selected kernel toggles its UOp list; any other kernel becomes selected and expanded
      if (i === currentKernel) {
        expandKernel = !expandKernel;
        main();
        return;
      }
      currentKernel = i;
      currentUOp = 0;
      currentRewrite = 0;
      expandKernel = true;
      main();
    };
    kernelList.appendChild(kernelUl);
  });
  // ***** collapse/expand + left resizer (installed once, and usable before any kernel is selected)
  bindCollapseButton();
  createResizer(kernelListParent, { minWidth: 15, maxWidth: 50 }, "kernel"); // left resizer

  // ***** UOp graph
  if (currentKernel == -1) return;
  const kernel = kernels[currentKernel][1][currentUOp];
  const cacheKey = `kernel=${currentKernel}&idx=${currentUOp}`;
  // close event sources streaming any other UOp; a source for this UOp that is still connecting or open stays active
  let activeSrc = null;
  for (const e of evtSources) {
    if (e.url.split("?")[1] !== cacheKey) e.close();
    else if (e.readyState !== EventSource.CLOSED) activeSrc = e;
  }
  // forget closed sources so evtSources doesn't grow with every navigation
  for (let k = evtSources.length - 1; k >= 0; k--) {
    if (evtSources[k].readyState === EventSource.CLOSED) evtSources.splice(k, 1);
  }
  if (cacheKey in cache) {
    ret = cache[cacheKey];
  }
  // if we don't have a complete cache yet we start streaming this kernel
  if (!(cacheKey in cache) || (cache[cacheKey].length !== kernel.match_count+1 && activeSrc == null)) {
    // the stream writes into its own array: the global `ret` can be re-pointed before a message arrives
    const chunks = [];
    ret = chunks;
    cache[cacheKey] = chunks;
    const eventSource = new EventSource(`/kernels?kernel=${currentKernel}&idx=${currentUOp}`);
    evtSources.push(eventSource);
    eventSource.onmessage = (e) => {
      if (e.data === "END") return eventSource.close();
      chunks.push(JSON.parse(e.data));
      // if it's the first one render this new graph
      if (chunks.length === 1) return main();
      // otherwise just enable the graph selector
      const gUl = document.getElementById(`rewrite-${chunks.length-1}`);
      if (gUl != null) gUl.classList.remove("disabled");
    };
    // NOTE(review): stop the browser's automatic reconnect; a reconnect presumably restarts the stream from
    // step 0 (server not shown here) and would append duplicates. With no active source left, the next
    // main() call restarts the stream into a fresh array.
    eventSource.onerror = () => eventSource.close();
  }
  if (ret.length === 0) return;
  // the selected step may not exist in a restarted or still-streaming list
  currentRewrite = Math.min(Math.max(currentRewrite, 0), ret.length - 1);
  renderGraph(ret[currentRewrite].graph, ret[currentRewrite].changed_nodes || []);

  // ***** RHS metadata
  metadata.innerHTML = "";
  metadata.appendChild(vsCodeOpener(kernel.loc.join(":").split("/")));
  metadata.appendChild(highlightedCodeBlock(kernel.code_line, "python", true));
  // ** code block: the kernel source when the server provides it, otherwise the UOp at this step
  let code = ret[currentRewrite].uop;
  let lang = "python";
  if (kernel.kernel_code != null) {
    code = kernel.kernel_code;
    lang = "cpp";
  }
  metadata.appendChild(highlightedCodeBlock(code, lang, false));
  // ** rewrite list: one entry per step (0 = before any rewrite); steps not streamed yet are disabled
  if (kernel.match_count >= 1) {
    const rewriteList = Object.assign(document.createElement("div"), { className: "rewrite-list" });
    metadata.appendChild(rewriteList);
    for (let i=0; i<=kernel.match_count; i++) {
      const gUl = Object.assign(document.createElement("ul"), { innerText: i, id: `rewrite-${i}` });
      rewriteList.appendChild(gUl);
      if (i > ret.length-1) gUl.classList.add("disabled");
      if (i === currentRewrite) {
        gUl.classList.add("active");
        if (i !== 0) {
          const diff = ret[i].diff;
          const [loc, pattern] = ret[i].upat;
          const div = Object.assign(document.createElement("div"), { className: "rewrite-container" });
          div.appendChild(vsCodeOpener(loc.join(":").split("/")));
          div.appendChild(highlightedCodeBlock(pattern, "python", true));
          metadata.appendChild(div);
          metadata.appendChild(renderDiff(diff));
        }
      }
      gUl.addEventListener("click", () => {
        currentRewrite = i;
        main();
      });
    }
  } else {
    metadata.appendChild(Object.assign(document.createElement("p"), { textContent: `No rewrites in ${toPath(kernel.loc)}.` }));
  }
  // metadata.innerHTML = "" above removed the previous handle, so this re-creates it
  createResizer(metadata, { minWidth: 20, maxWidth: 50 }, "metadata"); // right resizer

  // Color diff lines by their leading "+"/"-" and join them with <br>. Text goes in via textContent because
  // diff lines are server-provided text that may contain "<" or "&".
  function renderDiff(lines) {
    const codeEl = document.createElement("code");
    lines.forEach((line, k) => {
      if (k > 0) codeEl.appendChild(document.createElement("br"));
      const color = line.startsWith("+") ? "#3aa56d" : line.startsWith("-") ? "#d14b4b" : "#f0f0f5";
      codeEl.appendChild(Object.assign(document.createElement("span"), { textContent: line, style: `color: ${color};` }));
    });
    const pre = Object.assign(document.createElement("pre"), { className: "wrap" });
    pre.appendChild(codeEl);
    return pre;
  }

  // Bind the collapse button exactly once. The collapsed state is read from .main-container rather than a
  // per-call closure, so it cannot drift out of sync across main() calls.
  function bindCollapseButton() {
    const collapseBtn = document.querySelector(".collapse-btn");
    if (collapseBtn.dataset.bound === "true") return;
    collapseBtn.dataset.bound = "true";
    collapseBtn.addEventListener("click", () => {
      const isCollapsed = !mainContainer.classList.contains("collapsed");
      mainContainer.classList.toggle("collapsed", isCollapsed);
      kernelListParent.style.display = isCollapsed ? "none" : "block";
      metadata.style.display = isCollapsed ? "none" : "block";
      collapseBtn.style.transform = isCollapsed ? "rotate(180deg)" : "rotate(0deg)";
    });
  }

  // Append a drag handle that resizes `element` to between minWidth% and maxWidth% of .main-container.
  // No-op while a handle for this `type` already exists in the document (avoids duplicate ids/listeners).
  function createResizer(element, width, type) {
    const handleId = `${type}-resize-handle`;
    if (document.getElementById(handleId) != null) return;
    const { minWidth, maxWidth } = width;
    const handle = Object.assign(document.createElement("div"), { id: handleId, className: "resize-handle" });
    element.appendChild(handle);

    const resize = (e) => {
      const change = e.clientX - element.dataset.startX;
      // the kernel panel sits on the left (dragging right widens it); metadata sits on the right
      const adjustedChange = type === "kernel" ? change : -change;
      const newWidth = ((Number(element.dataset.startWidth) + adjustedChange) / Number(element.dataset.containerWidth)) * 100;
      if (newWidth >= minWidth && newWidth <= maxWidth) {
        element.style.width = `${newWidth}%`;
      }
    };

    handle.addEventListener("mousedown", (e) => {
      e.preventDefault();
      element.dataset.startX = e.clientX;
      element.dataset.containerWidth = mainContainer.getBoundingClientRect().width;
      element.dataset.startWidth = element.getBoundingClientRect().width;

      document.documentElement.addEventListener("mousemove", resize, false);
      document.documentElement.addEventListener("mouseup", () => {
        document.documentElement.removeEventListener("mousemove", resize, false);
        element.style.userSelect = "initial";
      }, { once: true });
    });
  }
}
482
+
483
// **** keyboard shortcuts
// Enter: select the first kernel / toggle the selected kernel's UOp list.
// Up/Down: move between kernels when the list is collapsed, otherwise between UOps of the selected kernel.
// Left/Right: step through the rewrites of the selected UOp.
document.addEventListener("keydown", async function(event) {
  // nothing to navigate until /kernels has been fetched, or if it returned no kernels
  if (kernels == null || kernels.length === 0) return;
  // up and down change the UOp or kernel from the list
  if (!expandKernel) {
    if (event.key == "ArrowUp") {
      event.preventDefault();
      currentUOp = 0;
      currentRewrite = 0;
      currentKernel = Math.max(0, currentKernel-1);
      return main();
    }
    if (event.key == "ArrowDown") {
      event.preventDefault();
      currentUOp = 0;
      currentRewrite = 0;
      currentKernel = Math.min(kernels.length-1, currentKernel+1);
      return main();
    }
  }
  if (event.key == "Enter") {
    event.preventDefault();
    if (currentKernel === -1) {
      currentKernel = 0;
      expandKernel = true;
    } else {
      expandKernel = !expandKernel;
    }
    currentUOp = 0;
    currentRewrite = 0;
    main();
  }
  if (event.key == "ArrowUp") {
    event.preventDefault();
    currentRewrite = 0;
    currentUOp = Math.max(0, currentUOp-1);
    main();
  }
  if (event.key == "ArrowDown") {
    event.preventDefault();
    // no kernel selected yet: kernels[-1] doesn't exist, so there is no UOp list to move through
    if (currentKernel === -1) return;
    currentRewrite = 0;
    const totalUOps = kernels[currentKernel][1].length-1;
    currentUOp = Math.min(totalUOps, currentUOp+1);
    main();
  }
  // left and right go through rewrites in a single UOp
  if (event.key == "ArrowLeft") {
    event.preventDefault();
    currentRewrite = Math.max(0, currentRewrite-1);
    main();
  }
  if (event.key == "ArrowRight") {
    event.preventDefault();
    // ret is empty until the first step streams in; never step below 0
    currentRewrite = Math.max(0, Math.min(ret.length-1, currentRewrite+1));
    main();
  }
});
main();
542
+ </script>
543
+ </body>
544
+ </html>