tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131) hide show
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/viz/index.html CHANGED
@@ -4,21 +4,19 @@
4
4
  <title>tinygrad viz</title>
5
5
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
6
  <link rel="icon" href="data:;base64,iVBORw0KGgo=">
7
- <script src="assets/d3js.org/d3.v5.min.js" charset="utf-8"></script>
8
- <script src="assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js"></script>
7
+ <script src="assets/d3js.org/d3.v7.min.js" charset="utf-8"></script>
8
+ <script src="assets/dagrejs.github.io/project/dagre/latest/dagre.min.js"></script>
9
9
  <link rel="stylesheet" href="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css">
10
10
  <script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js"></script>
11
11
  <script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js"></script>
12
12
  <script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js"></script>
13
+ <script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js"></script>
13
14
  <link rel="stylesheet" href="assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css" />
14
15
  <style>
15
16
  * {
16
17
  box-sizing: border-box;
17
- margin-block-start: initial;
18
- margin-block-end: initial;
19
- }
20
- button {
21
- outline: none;
18
+ margin: 0;
19
+ padding: 0;
22
20
  }
23
21
  html, body {
24
22
  color: #f0f0f5;
@@ -33,6 +31,7 @@
33
31
  font-variation-settings: "wdth" 100;
34
32
  font-size: 14px;
35
33
  overflow: hidden;
34
+ background-color: #08090e;
36
35
  }
37
36
  a {
38
37
  color: #4a90e2;
@@ -46,11 +45,22 @@
46
45
  ul.active {
47
46
  opacity: 1;
48
47
  }
48
+ ul > ul {
49
+ display: none;
50
+ }
51
+ ul.expanded > ul {
52
+ display: block;
53
+ }
49
54
  ul.disabled {
50
55
  opacity: 0.4;
51
56
  pointer-events: none;
52
57
  }
53
- svg {
58
+ label {
59
+ display: inline-flex;
60
+ align-items: center;
61
+ gap: 4px;
62
+ }
63
+ .graph svg {
54
64
  width: 100%;
55
65
  height: 100%;
56
66
  }
@@ -58,19 +68,26 @@
58
68
  cursor: default;
59
69
  user-select: none;
60
70
  }
61
- .node rect {
62
- stroke: #4a4b57;
63
- stroke-width: 1.4px;
64
- rx: 8px;
65
- ry: 8px;
71
+ g.clickable * {
72
+ cursor: pointer;
73
+ user-select: auto;
74
+ }
75
+ g.tag circle {
76
+ fill: #FFD700;
77
+ stroke: #B8860B;
78
+ stroke-width: 0.8;
79
+ }
80
+ g.tag text {
81
+ text-anchor: middle;
82
+ font-size: 6px;
83
+ fill: #08090e;
66
84
  }
67
85
  .label :is(text, p) {
68
- color: #08090e;
69
86
  font-weight: 350;
70
87
  }
71
- .edgePath path {
88
+ .edgePath {
72
89
  stroke: #4a4b57;
73
- fill: #4a4b57;
90
+ fill: none;
74
91
  stroke-width: 1.4px;
75
92
  }
76
93
  .main-container {
@@ -80,42 +97,60 @@
80
97
  position: relative;
81
98
  }
82
99
  .container {
100
+ flex: 0 0 auto;
83
101
  background-color: #0f1018;
102
+ padding: 20px;
103
+ z-index: 2;
104
+ position: relative;
105
+ height: 100%;
84
106
  }
85
- .container > * + *, .rewrite-container > * + * {
107
+ .metadata > * + *, .rewrite-container > * + *, .ctx-list > * + * {
86
108
  margin-top: 12px;
87
109
  }
110
+ .stats-list > * + * {
111
+ margin-top: 8px;
112
+ }
113
+ .stats-list > p > * + * {
114
+ margin-top: 12px;
115
+ }
116
+ .stats-list {
117
+ width: 100%;
118
+ max-height: 240px;
119
+ overflow: auto;
120
+ }
121
+ .ctx-list > ul > * + * {
122
+ margin-top: 4px;
123
+ }
88
124
  .graph {
89
- background-color: #08090e;
90
125
  position: absolute;
91
126
  inset: 0;
92
127
  z-index: 1;
93
128
  }
94
- .kernel-list-parent {
95
- position: relative;
129
+ .profiler {
130
+ flex: 1 1 auto;
131
+ min-width: 0;
132
+ width: 100%;
133
+ height: calc(100% - 50px);
134
+ margin-top: 50px;
135
+ overflow-y: auto;
136
+ overflow-x: hidden;
137
+ scrollbar-gutter: stable;
138
+ }
139
+ .ctx-list-parent {
96
140
  width: 15%;
97
- padding: 50px 20px 20px 20px;
141
+ padding-top: 50px;
98
142
  border-right: 1px solid #4a4b56;
99
- z-index: 2;
100
143
  }
101
- .kernel-list {
144
+ .ctx-list, .metadata {
102
145
  width: 100%;
103
146
  height: 100%;
104
147
  overflow-y: auto;
148
+ scrollbar-gutter: stable;
105
149
  }
106
- .kernel-list > ul > * + * {
107
- margin-top: 4px;
108
- }
109
- .metadata {
110
- position: relative;
150
+ .metadata-parent {
111
151
  width: 20%;
112
- padding: 20px;
113
- background-color: #0f1018;
114
152
  border-left: 1px solid #4a4b56;
115
- z-index: 2;
116
153
  margin-left: auto;
117
- height: 100%;
118
- overflow-y: auto;
119
154
  }
120
155
  .resize-handle {
121
156
  position: absolute;
@@ -127,13 +162,6 @@
127
162
  z-index: 3;
128
163
  background-color: transparent;
129
164
  }
130
- #kernel-resize-handle {
131
- right: 0;
132
- }
133
- #metadata-resize-handle {
134
- margin-top: 0;
135
- left: 0;
136
- }
137
165
  .floating-container {
138
166
  position: fixed;
139
167
  top: 10px;
@@ -143,402 +171,174 @@
143
171
  flex-direction: row;
144
172
  gap: 8px;
145
173
  }
146
- .nav-btn {
174
+ .btn {
175
+ outline: none;
147
176
  background-color: #1a1b26;
148
177
  border: 1px solid #4a4b56;
149
178
  color: #f0f0f5;
150
- height: 32px;
151
- border-radius: 8px;
179
+ border-radius: 4px;
180
+ padding: 6px;
152
181
  cursor: pointer;
153
- text-decoration: none;
182
+ height: 32px;
154
183
  display: flex;
155
184
  align-items: center;
156
- padding: 0 6px;
157
- font-weight: bold;
158
- }
159
- .collapse-btn {
160
- width: 32px;
161
- padding: 6px;
162
- }
163
- .btn {
164
- height: 32px;
165
- background-color: #1a1b26;
166
- border: 1px solid #4a4b56;
167
- color: #f0f0f5;
168
- border-radius: 8px;
169
- cursor: pointer;
170
- transition-duration: .5s;
185
+ justify-content: center;
186
+ text-decoration: none;
171
187
  }
172
188
  .btn:hover {
173
189
  background-color: #2a2b36;
174
- border-color: #5a5b66;
175
- color: #ffffff;
176
- transform: translateY(-1px);
177
190
  }
178
- .collapsed .kernel-list, .collapsed .metadata {
179
- width: 0;
180
- padding: 0;
181
- overflow: hidden;
191
+ .collapsed .container {
192
+ display: none;
182
193
  }
183
194
  .rewrite-list {
184
195
  display: flex;
185
196
  flex-wrap: wrap;
186
197
  }
187
- .rewrite-list > * + * {
188
- margin-left: 4px;
198
+ .rewrite-list > ul {
199
+ padding: 2px;
189
200
  }
190
201
  .wrap {
191
202
  word-wrap: break-word;
192
203
  white-space: pre-wrap;
193
204
  }
194
- .code-block.hljs {
205
+ pre code.hljs {
195
206
  overflow-y: auto;
196
207
  max-height: 30vh;
197
- border-radius: 8px;
198
208
  padding: 8px;
199
209
  }
210
+ #progress-message {
211
+ position: absolute;
212
+ z-index: 2;
213
+ left: 50%;
214
+ top: 2%;
215
+ color: #ffd230;
216
+ display: none;
217
+ }
218
+ #tooltip {
219
+ position: absolute;
220
+ z-index: 4;
221
+ background-color: #1e2029;
222
+ padding: 4px 8px;
223
+ border-radius: 4px;
224
+ pointer-events: none;
225
+ display: none;
226
+ font-size: 10px;
227
+ white-space: pre;
228
+ }
229
+ #device-list > div {
230
+ min-height: 32px;
231
+ max-width: 132px;
232
+ overflow-x: auto;
233
+ overflow-y: hidden;
234
+ white-space: nowrap;
235
+ display: flex;
236
+ }
237
+ #device-list > div:hover {
238
+ background-color: rgba(20, 23, 35, 0.3);
239
+ }
240
+ .raw-text {
241
+ padding: 0 8px;
242
+ width: 100%;
243
+ height: 100%;
244
+ max-height: 100vh;
245
+ overflow-x: auto;
246
+ }
247
+ .raw-text code {
248
+ max-height: none !important;
249
+ }
250
+ table {
251
+ width: 100%;
252
+ border-collapse: separate;
253
+ border-spacing: 0;
254
+ background-color: #1a1b26;
255
+ color: #f0f0f5;
256
+ font-size: 0.95em;
257
+ }
258
+ table td {
259
+ border-bottom: 1px solid #4a4b56;
260
+ vertical-align: top;
261
+ }
262
+ table tr:last-child > td {
263
+ border-bottom: none;
264
+ }
265
+ tr.main-row:hover {
266
+ background-color: #2a2d3a;
267
+ }
268
+ tr.sub-row {
269
+ max-width: 150px;
270
+ }
271
+ tr.main-row > td, tr.sub-row > td {
272
+ padding: 8px 12px;
273
+ }
274
+ tr.code-row > td:first-child {
275
+ font-family: monospace;
276
+ }
277
+ td.pct-row > div {
278
+ height: 12px;
279
+ width: 100%;
280
+ display: flex;
281
+ }
282
+ td.pct-row > div > div {
283
+ height: 100%;
284
+ }
285
+ thead {
286
+ position: sticky;
287
+ top: 0;
288
+ z-index: 10;
289
+ background-color: #20222e;
290
+ }
291
+ thead th {
292
+ text-align: left;
293
+ padding: 10px 12px;
294
+ font-weight: 600;
295
+ border-bottom: 1px solid #4a4b56;
296
+ font-size: 0.95em;
297
+ letter-spacing: 0.03em;
298
+ }
299
+ .legend {
300
+ display: flex;
301
+ align-items: center;
302
+ }
303
+ .legend > div {
304
+ width: 0.95em;
305
+ height: 0.95em;
306
+ margin-right: 4px;
307
+ }
200
308
  </style>
201
309
  </head>
202
310
  <body>
203
311
  <div class="main-container">
204
312
  <div class="floating-container">
205
313
  <button class="btn collapse-btn">
206
- <svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M15 19l-7-7 7-7"/></svg>
314
+ <svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" width="20"><path d="M15 19l-7-7 7-7"/></svg>
315
+ </button>
316
+ <button class="btn" id="zoom-to-fit-btn" aria-label="Fit graph">
317
+ <svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" width="20">
318
+ <path stroke-linecap="round" stroke-linejoin="round" d="M7.5 3.75H6A2.25 2.25 0 0 0 3.75 6v1.5M16.5 3.75H18A2.25 2.25 0 0 1 20.25 6v1.5m0 9V18A2.25 2.25 0 0 1 18 20.25h-1.5m-9 0H6A2.25 2.25 0 0 1 3.75 18v-1.5M15 12a3 3 0 1 1-6 0 3 3 0 0 1 6 0Z" />
319
+ </svg>
207
320
  </button>
208
- <a class="btn nav-btn" href="/profiler">Profiler</a>
209
321
  </div>
210
- <div class="container kernel-list-parent"><div class="container kernel-list"></div></div>
211
- <div class="graph">
212
- <svg id="graph-svg">
213
- <g id="render"></g>
322
+ <div id="progress-message"></div>
323
+ <div class="container ctx-list-parent"><div class="ctx-list"></div></div>
324
+ <div class="view profiler"></div>
325
+ <div class="view graph">
326
+ <svg id="graph-svg" preserveAspectRatio="xMidYMid meet">
327
+ <g id="render">
328
+ <g id="edges"></g>
329
+ <g id="nodes"></g>
330
+ <g id="edge-labels"></g> <!-- NOTE: this ensures edge labels are always on top -->
331
+ </g>
332
+ <defs>
333
+ <marker id="arrowhead" viewBox="0 -5 10 10" refX="10" refY="0" markerWidth="6" markerHeight="6" orient="auto">
334
+ <path d="M0,-5L10,0L0,5" fill="#4a4b57"></path>
335
+ </marker>
336
+ </defs>
214
337
  </svg>
215
338
  </div>
216
- <div class="container metadata"></div>
339
+ <div class="container metadata-parent"><div class="metadata"></div></div>
217
340
  </div>
218
- <script>
219
- // **** hljs extra definitions for UOps and float4
220
- hljs.registerLanguage("python", (hljs) => ({
221
- ...hljs.getLanguage("python"),
222
- case_insensitive: false,
223
- contains: [
224
- { begin: 'dtypes\\.[a-zA-Z_][a-zA-Z0-9_-]*(\\.[a-zA-Z_][a-zA-Z0-9_-]*)*' + '(?=[.\\s\\n[:,(])', className: "type" },
225
- { begin: 'dtypes\\.[a-zA-Z_][a-zA-Z0-9_-].vec*' + '(?=[.\\s\\n[:,(])', className: "type" },
226
- { begin: '[a-zA-Z_][a-zA-Z0-9_-]*\\.[a-zA-Z_][a-zA-Z0-9_-]*' + '(?=[.\\s\\n[:,()])', className: "operator" },
227
- { begin: '[A-Z][a-zA-Z0-9_]*(?=\\()', className: "section", ignoreEnd: true },
228
- ...hljs.getLanguage("python").contains,
229
- ]
230
- }));
231
- hljs.registerLanguage("cpp", (hljs) => ({
232
- ...hljs.getLanguage('cpp'),
233
- contains: [{ begin: '\\b(?:float|half)[0-9]+\\b', className: 'type' }, ...hljs.getLanguage('cpp').contains]
234
- }));
235
-
236
- // **** D3
237
- function recenterRects(svg, zoom) {
238
- const svgBounds = svg.node().getBoundingClientRect();
239
- for (const rect of svg.node().querySelectorAll("rect")) {
240
- const rectBounds = rect.getBoundingClientRect();
241
- const outOfBounds = rectBounds.top < svgBounds.top || rectBounds.left < svgBounds.left ||
242
- rectBounds.bottom > svgBounds.bottom || rectBounds.right > svgBounds.right;
243
- // if there's at least one rect in view we don't do anything
244
- if (!outOfBounds) return;
245
- }
246
- svg.call(zoom.transform, d3.zoomIdentity)
247
- }
248
- function renderGraph(graph, additions) {
249
- const g = new dagreD3.graphlib.Graph({ compound: true }).setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
250
- g.setNode("addition", {label: "", clusterLabelPos: "top", style: additions.length !== 0 ? "fill: rgba(26, 27, 38, 0.5);" : "display: none;"});
251
- for (const [k,u] of Object.entries(graph)) {
252
- let node = {label: u[0], labelType: "text", style: `fill: ${u[2]};`};
253
- // for PROGRAM UOp we render the node with a code block
254
- if (u[0].includes("PROGRAM")) {
255
- const [name, ...rest] = u[0].split("\n");
256
- const label = Object.assign(document.createElement("div"));
257
- label.appendChild(Object.assign(document.createElement("p"), {innerText: name, className: "label", style: "margin-bottom: 2px" }))
258
- label.appendChild(highlightedCodeBlock(rest.join("\n"), "cpp", true));
259
- node = {label, labelType: "html", style: `fill: ${u[2]}`};
260
- }
261
- g.setNode(k, node);
262
- for (const src of u[1]) {
263
- g.setEdge(src, k, {curve: d3.curveBasis})
264
- }
265
- if (additions.includes(parseInt(k))) {
266
- g.setParent(k, "addition");
267
- }
268
- }
269
- const svg = d3.select("#graph-svg");
270
- const inner = svg.select("g");
271
- var zoom = d3.zoom()
272
- .scaleExtent([0.05, 2])
273
- .on("zoom", () => {
274
- const transform = d3.event.transform;
275
- inner.attr("transform", transform);
276
- });
277
- recenterRects(svg, zoom);
278
- svg.call(zoom);
279
- const render = new dagreD3.render();
280
- render(inner, g);
281
- }
282
-
283
- // **** extra helpers
284
- const toPath = ([fp, lineno]) => `${fp.replaceAll("\\", "/").split("/").pop()}:${lineno}`;
285
- const vsCodeOpener = (parts) => Object.assign(document.createElement("a"), { textContent: parts[parts.length-1]+"\n\n",
286
- href: "vscode://file"+parts.join("/"), style: "font-family: monospace; margin: 4px 0;" });
287
- const highlightedCodeBlock = (code, lang, wrap) => {
288
- const pre = Object.assign(document.createElement("pre"), {className: wrap ? "wrap" : ""});
289
- // NOTE: since code is in textContent, we don't need DOMPurify
290
- const codeEl = Object.assign(document.createElement("code"), { className: `language-${lang} code-block`, textContent: code});
291
- pre.appendChild(codeEl);
292
- hljs.highlightElement(codeEl);
293
- return pre;
294
- };
295
- const coloredToHTML = (str) => {
296
- const colors = ['gray','red','green','yellow','blue','magenta','cyan','white'];
297
- return str.replace(/\u001b\[(\d+)m(.*?)\u001b\[0m/g, (_, code, st) => {
298
- return `<span style="${`color: color-mix(in srgb, ${colors[(parseInt(code)-30+60)%60]} 60%, white)`}">${st}</span>`;
299
- })
300
- }
301
-
302
- // **** main loop
303
- var ret = [];
304
- var cache = {};
305
- var kernels = null;
306
- var currentUOp = 0;
307
- var currentKernel = -1;
308
- var currentRewrite = 0;
309
- var expandKernel = true;
310
- const evtSources = [];
311
- async function main() {
312
- const mainContainer = document.querySelector('.main-container');
313
- // ***** LHS kernels list
314
- if (kernels == null) {
315
- kernels = await (await fetch("/kernels")).json();
316
- currentKernel = -1;
317
- }
318
- const kernelListParent = document.querySelector(".container.kernel-list-parent");
319
- const kernelList = document.querySelector(".container.kernel-list");
320
- kernelList.innerHTML = "";
321
- kernels.forEach(([key, items], i) => {
322
- const kernelUl = Object.assign(document.createElement("ul"), { key: `kernel-${i}`, className: i === currentKernel ? "active" : "",
323
- style: "overflow-x: auto; cursor: initial;" });
324
- if (i === currentKernel) {
325
- requestAnimationFrame(() => kernelUl.scrollIntoView({ behavior: "auto", block: "nearest" }));
326
- }
327
- const p = Object.assign(document.createElement("p"), { id: `kernel-${key}`, innerHTML: coloredToHTML(key), style: "cursor: pointer;"});
328
- kernelUl.appendChild(p)
329
- items.forEach((u, j) => {
330
- const rwUl = Object.assign(document.createElement("ul"), {
331
- innerText: u.name ? `${u.name} - ${u.match_count}` : `${toPath(u.loc)} - ${u.match_count}`, key: `uop-rewrite-${j}`,
332
- className: (j === currentUOp && i == currentKernel) ? "active" : "" })
333
- if (j === currentUOp) {
334
- requestAnimationFrame(() => rwUl.scrollIntoView({ behavior: "auto", block: "nearest" }));
335
- }
336
- rwUl.style.display = i === currentKernel && expandKernel ? "block" : "none";
337
- rwUl.onclick = (e) => {
338
- e.stopPropagation();
339
- currentUOp = j;
340
- currentKernel = i;
341
- currentRewrite = 0;
342
- main();
343
- }
344
- kernelUl.appendChild(rwUl)
345
- })
346
- p.onclick = () => {
347
- if (i === currentKernel) {
348
- expandKernel = !expandKernel;
349
- main();
350
- return;
351
- }
352
- currentKernel = i;
353
- currentUOp = 0;
354
- currentRewrite = 0;
355
- expandKernel = true;
356
- main();
357
- }
358
- kernelList.appendChild(kernelUl);
359
- });
360
- // ***** UOp graph
361
- if (currentKernel == -1) return;
362
- const kernel = kernels[currentKernel][1][currentUOp];
363
- const cacheKey = `kernel=${currentKernel}&idx=${currentUOp}`;
364
- // close any pending event sources
365
- let activeSrc = null;
366
- for (const e of evtSources) {
367
- if (e.url.split("?")[1] !== cacheKey) e.close();
368
- else if (e.readyState === EventSource.OPEN) activeSrc = e;
369
- }
370
- if (cacheKey in cache) {
371
- ret = cache[cacheKey];
372
- }
373
- // if we don't have a complete cache yet we start streaming this kernel
374
- if (!(cacheKey in cache) || (cache[cacheKey].length !== kernel.match_count+1 && activeSrc == null)) {
375
- ret = [];
376
- cache[cacheKey] = ret;
377
- const eventSource = new EventSource(`/kernels?kernel=${currentKernel}&idx=${currentUOp}`);
378
- evtSources.push(eventSource);
379
- eventSource.onmessage = (e) => {
380
- if (e.data === "END") return eventSource.close();
381
- const chunk = JSON.parse(e.data);
382
- ret.push(chunk);
383
- // if it's the first one render this new rgaph
384
- if (ret.length === 1) return main();
385
- // otherwise just enable the graph selector
386
- const gUl = document.getElementById(`rewrite-${ret.length-1}`);
387
- if (gUl != null) gUl.classList.remove("disabled");
388
- };
389
- }
390
- if (ret.length === 0) return;
391
- renderGraph(ret[currentRewrite].graph, ret[currentRewrite].changed_nodes || []);
392
- // ***** RHS metadata
393
- const metadata = document.querySelector(".container.metadata");
394
- metadata.innerHTML = "";
395
- metadata.appendChild(vsCodeOpener(kernel.loc.join(":").split("/")));
396
- metadata.appendChild(highlightedCodeBlock(kernel.code_line, "python", true));
397
- // ** code blocks
398
- let code = ret[currentRewrite].uop;
399
- let lang = "python"
400
- if (kernel.kernel_code != null) {
401
- code = kernel.kernel_code;
402
- lang = "cpp";
403
- }
404
- const codeBlock = highlightedCodeBlock(code, lang, false);
405
- metadata.appendChild(codeBlock);
406
- // ** rewrite list
407
- if (kernel.match_count >= 1) {
408
- const rewriteList = Object.assign(document.createElement("div"), { className: "rewrite-list" })
409
- metadata.appendChild(rewriteList);
410
- for (let i=0; i<=kernel.match_count; i++) {
411
- const gUl = Object.assign(document.createElement("ul"), { innerText: i, id: `rewrite-${i}` });
412
- rewriteList.appendChild(gUl);
413
- if (i > ret.length-1) gUl.classList.add("disabled");
414
- if (i === currentRewrite) {
415
- gUl.classList.add("active");
416
- if (i !== 0) {
417
- const diff = ret[i].diff;
418
- const [loc, pattern] = ret[i].upat;
419
- const parts = loc.join(":").split("/");
420
- const div = Object.assign(document.createElement("div"), { className: "rewrite-container" });
421
- const link = vsCodeOpener(parts);
422
- div.appendChild(link);
423
- const pre = highlightedCodeBlock(pattern, "python", true);
424
- div.appendChild(pre);
425
- metadata.appendChild(div);
426
- const diffHtml = diff.map((line) => {
427
- const color = line.startsWith("+") ? "#3aa56d" : line.startsWith("-") ? "#d14b4b" : "#f0f0f5";
428
- return `<span style="color: ${color};">${line}</span>`;
429
- }).join("<br>");
430
- metadata.appendChild(Object.assign(document.createElement("pre"), { innerHTML: `<code>${diffHtml}</code>`, className: "wrap" }));
431
- }
432
- }
433
- gUl.addEventListener("click", () => {
434
- currentRewrite = i;
435
- main();
436
- });
437
- }
438
- } else {
439
- metadata.appendChild(Object.assign(document.createElement("p"), { textContent: `No rewrites in ${toPath(kernel.loc)}.` }));
440
- }
441
- // ***** collapse/expand
442
- let isCollapsed = false;
443
- const collapseBtn = document.querySelector(".collapse-btn");
444
- collapseBtn.addEventListener("click", () => {
445
- isCollapsed = !isCollapsed;
446
- mainContainer.classList.toggle("collapsed", isCollapsed);
447
- kernelListParent.style.display = isCollapsed ? "none" : "block";
448
- metadata.style.display = isCollapsed ? "none" : "block";
449
- collapseBtn.style.transform = isCollapsed ? "rotate(180deg)" : "rotate(0deg)";
450
- });
451
- // ***** resizer
452
- function createResizer(element, width, type) {
453
- const { minWidth, maxWidth } = width;
454
- const handle = Object.assign(document.createElement("div"), { id: `${type}-resize-handle`, className: "resize-handle" });
455
- element.appendChild(handle);
456
-
457
- const resize = (e) => {
458
- const change = e.clientX - element.dataset.startX;
459
- const adjustedChange = type === "kernel" ? change : -change;
460
- const newWidth = ((Number(element.dataset.startWidth) + adjustedChange) / Number(element.dataset.containerWidth)) * 100;
461
- if (newWidth >= minWidth && newWidth <= maxWidth) {
462
- element.style.width = `${newWidth}%`;
463
- }
464
- };
465
-
466
- handle.addEventListener("mousedown", (e) => {
467
- e.preventDefault();
468
- element.dataset.startX = e.clientX;
469
- element.dataset.containerWidth = mainContainer.getBoundingClientRect().width;
470
- element.dataset.startWidth = element.getBoundingClientRect().width;
471
-
472
- document.documentElement.addEventListener("mousemove", resize, false);
473
- document.documentElement.addEventListener("mouseup", () => {
474
- document.documentElement.removeEventListener("mousemove", resize, false);
475
- element.style.userSelect = "initial";
476
- }, { once: true });
477
- });
478
- }
479
- createResizer(kernelListParent, { minWidth: 15, maxWidth: 50 }, "kernel"); // left resizer
480
- createResizer(metadata, { minWidth: 20, maxWidth: 50 }, "metadata"); // right resizer
481
- }
482
-
483
- // **** keyboard shortcuts
484
- document.addEventListener("keydown", async function(event) {
485
- // up and down change the UOp or kernel from the list
486
- if (!expandKernel) {
487
- if (event.key == "ArrowUp") {
488
- event.preventDefault()
489
- currentUOp = 0;
490
- currentRewrite = 0;
491
- currentKernel = Math.max(0, currentKernel-1)
492
- return main()
493
- }
494
- if (event.key == "ArrowDown") {
495
- event.preventDefault()
496
- currentUOp = 0;
497
- currentRewrite = 0;
498
- currentKernel = Math.min(kernels.length-1, currentKernel+1);
499
- return main()
500
- }
501
- }
502
- if (event.key == "Enter") {
503
- event.preventDefault()
504
- if (currentKernel === -1) {
505
- currentKernel = 0;
506
- expandKernel = true;
507
- }
508
- else {
509
- expandKernel = !expandKernel;
510
- }
511
- currentUOp = 0;
512
- currentRewrite = 0;
513
- main();
514
- }
515
- if (event.key == "ArrowUp") {
516
- event.preventDefault()
517
- currentRewrite = 0;
518
- currentUOp = Math.max(0, currentUOp-1)
519
- main()
520
- }
521
- if (event.key == "ArrowDown") {
522
- event.preventDefault()
523
- currentRewrite = 0;
524
- const totalUOps = kernels[currentKernel][1].length-1;
525
- currentUOp = Math.min(totalUOps, currentUOp+1)
526
- main()
527
- }
528
- // left and right go through rewrites in a single UOp
529
- if (event.key == "ArrowLeft") {
530
- event.preventDefault()
531
- currentRewrite = Math.max(0, currentRewrite-1)
532
- main()
533
- }
534
- if (event.key == "ArrowRight") {
535
- event.preventDefault()
536
- const totalRewrites = ret.length-1;
537
- currentRewrite = Math.min(totalRewrites, currentRewrite+1)
538
- main()
539
- }
540
- })
541
- main()
542
- </script>
543
- </body>
341
+ <div id="tooltip"></div>
342
+ <script src="/js/index.js"></script>
343
+ </body>
544
344
  </html>