@jax-js/jax 0.0.1 → 0.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,1978 +0,0 @@
1
// Alias for Object.defineProperty used by the bundler helpers below.
var { defineProperty: __defProp } = Object;
// Resolve a well-known symbol (e.g. "dispose"), falling back to a registry
// symbol named "Symbol.<name>" when the runtime does not define it.
var __knownSymbol = (name, symbol) => {
  symbol = Symbol[name];
  if (symbol) return symbol;
  return Symbol.for("Symbol." + name);
};
// Throw a TypeError carrying `msg`; used as an expression-position throw.
var __typeError = (msg) => {
  const err = TypeError(msg);
  throw err;
};
// Define each key of `all` on `target` as an enumerable getter that calls the
// corresponding function in `all`.
var __export = (target, all) => {
  for (const key in all) {
    const getter = all[key];
    __defProp(target, key, { get: getter, enumerable: true });
  }
};
// Register `value` for later disposal by pushing an entry onto `stack`
// (bundler helper for `using` / `await using` declarations).
//
// stack - array that receives entries of the form [async, dispose, value],
//         or [async] when `value` is null/undefined in async mode.
// value - the resource; null/undefined is allowed and returned unchanged.
// async - truthy for `await using`: Symbol.asyncDispose is looked up first.
// Returns `value`. Throws TypeError (via __typeError) when `value` is a
// primitive or has no callable dispose method.
var __using = (stack, value, async) => {
  if (value != null) {
    if (typeof value !== "object" && typeof value !== "function") __typeError("Object expected");
    var dispose, inner;
    if (async) dispose = value[__knownSymbol("asyncDispose")];
    // Fall back to the synchronous dispose method when no async one exists.
    if (dispose === void 0) {
      dispose = value[__knownSymbol("dispose")];
      if (async) inner = dispose;
    }
    if (typeof dispose !== "function") __typeError("Object not disposable");
    // In async mode a sync dispose is wrapped so that a synchronous throw is
    // turned into a rejected promise instead of propagating directly.
    if (inner) dispose = function() {
      try {
        inner.call(this);
      } catch (e) {
        return Promise.reject(e);
      }
    };
    stack.push([async, dispose, value]);
  } else if (async) {
    // Marker entry with no dispose function, pushed for a nullish `await using` value.
    stack.push([async]);
  }
  return value;
};
// Dispose every entry recorded on `stack` by __using, in LIFO order.
//
// stack    - entries [async, dispose, value] (or [async]) pushed by __using.
// error    - the error thrown by the guarded block, if any.
// hasError - whether `error` is meaningful.
// Returns undefined when every entry is synchronous. Once an async entry is
// popped, returns a Promise and continues disposing inside its callbacks.
// Errors are accumulated: the first is kept as-is and each later one wraps the
// previous in a SuppressedError. At the end the accumulated error is thrown
// (or the returned promise rejects with it).
var __callDispose = (stack, error, hasError) => {
  // Use the native SuppressedError when present, otherwise a minimal stand-in
  // Error with name/error/suppressed fields.
  var E = typeof SuppressedError === "function" ? SuppressedError : function(e, s, m, _) {
    return _ = Error(m), _.name = "SuppressedError", _.error = e, _.suppressed = s, _;
  };
  // Record a disposal failure, chaining it onto any earlier error.
  var fail = (e) => error = hasError ? new E(e, error, "An error was suppressed during disposal") : (hasError = true, e);
  var next = (it) => {
    while (it = stack.pop()) {
      try {
        // Marker entries have no dispose function; `value` is used as `this`.
        var result = it[1] && it[1].call(it[2]);
        // Async entry: wait for its result, then resume with the remaining entries.
        if (it[0]) return Promise.resolve(result).then(next, (e) => (fail(e), next()));
      } catch (e) {
        fail(e);
      }
    }
    if (hasError) throw error;
  };
  return next();
};

// src/utils.ts
// Module-level debug verbosity level.
// NOTE(review): what each level enables is not visible in this chunk — confirm
// against the code that reads DEBUG.
var DEBUG = 3;
/**
 * Split an iterable of 2-element pairs into two parallel arrays.
 * @returns {[Array, Array]} [all first elements, all second elements]
 */
function unzip2(pairs) {
  const firsts = [];
  const seconds = [];
  for (const pair of pairs) {
    const [first, second] = pair;
    firsts.push(first);
    seconds.push(second);
  }
  return [firsts, seconds];
}
/**
 * Pair up `xs[i]` with `ys[i]`; the result has the length of `xs`
 * (missing `ys` entries become undefined).
 */
function zip(xs, ys) {
  return xs.map((left, index) => {
    const right = ys[index];
    return [left, right];
  });
}
/**
 * Build an array of `length` elements. If `value` is a function it is called
 * with each index to produce the element; otherwise every slot holds `value`.
 */
function rep(length, value) {
  if (!(value instanceof Function)) {
    return new Array(length).fill(value);
  }
  const out = new Array(length);
  for (let index = 0; index < out.length; index++) {
    out[index] = value(index);
  }
  return out;
}
/** Product of all elements of `arr` (1 for an empty array). */
function prod(arr) {
  let product = 1;
  arr.forEach((factor) => {
    product *= factor;
  });
  return product;
}
/** Floor division: the largest integer not greater than a / b. */
function intdiv(a, b) {
  const quotient = a / b;
  return Math.floor(quotient);
}
/** Restrict `x` to the interval [min, max]; the lower bound wins if min > max. */
function clamp(x, min, max) {
  const capped = Math.min(max, x);
  return Math.max(min, capped);
}
/**
 * Structural equality: identical values are equal; otherwise both must be
 * non-null objects with the same number of own enumerable keys, each of whose
 * values is recursively deepEqual.
 */
function deepEqual(a, b) {
  if (a === b) return true;
  const bothObjects =
    typeof a === "object" && typeof b === "object" && a !== null && b !== null;
  if (!bothObjects) return false;
  const keysOfA = Object.keys(a);
  if (keysOfA.length !== Object.keys(b).length) return false;
  return keysOfA.every((key) => deepEqual(a[key], b[key]));
}
/**
 * Split `array` by the truthiness of the parallel mask `which`.
 * @returns {[Array, Array]} [elements with falsy mask, elements with truthy mask]
 */
function partitionList(which, array) {
  const buckets = [[], []];
  for (let i = 0; i < which.length; i++) {
    buckets[which[i] ? 1 : 0].push(array[i]);
  }
  return buckets;
}
/**
 * Lexicographic comparison of two sequences. Returns -1 or 1 at the first
 * differing position; otherwise a shorter prefix sorts first (length difference).
 */
function lexCompare(a, b) {
  const common = Math.min(a.length, b.length);
  let pos = 0;
  while (pos < common) {
    const x = a[pos];
    const y = b[pos];
    if (x < y) return -1;
    if (x > y) return 1;
    pos++;
  }
  return a.length - b.length;
}
/**
 * Python-style integer range.
 *
 * `range(stop)` yields 0..stop-1, and `range(start, stop, step)` yields values
 * from `start` up to (but excluding) `stop`. A negative `step` counts down
 * while the value is greater than `stop`.
 *
 * @param {number} start - first value, or the exclusive stop when `stop` is omitted.
 * @param {number} [stop] - exclusive bound.
 * @param {number} [step=1] - nonzero increment.
 * @returns {number[]}
 * @throws {Error} if `step` is zero or NaN (the loop would never terminate).
 */
function range(start, stop, step = 1) {
  let from = start;
  let to = stop;
  if (to === void 0) {
    to = from;
    from = 0;
  }
  if (!(step > 0 || step < 0)) {
    throw new Error("range: step must be a nonzero number");
  }
  const result = [];
  if (step > 0) {
    for (let i = from; i < to; i += step) result.push(i);
  } else {
    for (let i = from; i > to; i += step) result.push(i);
  }
  return result;
}
/**
 * Check whether `axis` is a permutation of 0..n-1.
 *
 * Every entry must be an integer in [0, n), and entries must be distinct.
 * Non-integer entries (e.g. 0.5) are rejected; the previous version let them
 * through because they satisfy the range check.
 *
 * @param {number[]} axis
 * @param {number} n
 * @returns {boolean}
 */
function isPermutation(axis, n) {
  if (axis.length !== n) return false;
  const seen = /* @__PURE__ */ new Set();
  for (const x of axis) {
    if (!Number.isInteger(x) || x < 0 || x >= n) return false;
    seen.add(x);
  }
  return seen.size === n;
}
/**
 * Inverse of a permutation: `inverse[axis[i]] === i`.
 * @throws {Error} if `axis` is not a permutation of 0..axis.length-1.
 */
function invertPermutation(axis) {
  const size = axis.length;
  if (!isPermutation(axis, size)) {
    throw new Error("invertPermutation: axis is not a permutation");
  }
  const inverse = new Array(size);
  for (let source = 0; source < size; source++) {
    const target = axis[source];
    inverse[target] = source;
  }
  return inverse;
}
/**
 * Topologically sort the graph reachable from `terminals`.
 *
 * Each node appears exactly once, after all of its parents. Duplicate
 * terminals are ignored.
 *
 * Implementation (Kahn's algorithm over reversed edges):
 * 1. Depth-first walk from the terminals, visiting each node once and counting
 *    how many child edges point at every parent.
 * 2. Start from terminals that no reachable node depends on.
 * 3. Release a parent once all of its child edges have been processed.
 * 4. Reverse the resulting order.
 *
 * The previous version had three defects: it decremented terminal counts
 * unconditionally, so a terminal that is also a parent could be emitted too
 * early; it re-walked terminals reached twice, so their parents were
 * double-counted and then dropped; and it stopped on falsy nodes.
 *
 * @param {Array} terminals - output nodes of the graph.
 * @param {(node) => Iterable} parents - returns the direct parents of a node
 *   (duplicates allowed; each occurrence counts as one edge).
 * @returns {Array} nodes with every parent before its children.
 */
function toposort(terminals, parents) {
  const uniqueTerminals = [...new Set(terminals)];
  const childCounts = /* @__PURE__ */ new Map();
  const visited = /* @__PURE__ */ new Set();
  const stack = [...uniqueTerminals];
  while (stack.length > 0) {
    const node = stack.pop();
    if (visited.has(node)) continue;
    visited.add(node);
    for (const parent of parents(node)) {
      childCounts.set(parent, (childCounts.get(parent) ?? 0) + 1);
      if (!visited.has(parent)) stack.push(parent);
    }
  }
  const order = [];
  // Terminals that are parents of other reachable nodes wait for their children.
  const frontier = uniqueTerminals.filter((n) => !childCounts.get(n));
  while (frontier.length > 0) {
    const node = frontier.pop();
    order.push(node);
    for (const parent of parents(node)) {
      const c = childCounts.get(parent) - 1;
      childCounts.set(parent, c);
      if (c === 0) frontier.push(parent);
    }
  }
  return order.reverse();
}
/**
 * Smallest power of two that is >= `hint`, capped at the largest power of two
 * that is <= `max`.
 * @throws {Error} if `max` is less than 1.
 */
function findPow2(hint, max) {
  if (max < 1) {
    throw new Error("max must be a positive integer");
  }
  let power = 1;
  while (power < hint) {
    const doubled = 2 * power;
    if (!(doubled <= max)) break;
    power = doubled;
  }
  return power;
}
/** Flatten arbitrarily nested arrays; a non-array value becomes a 1-element array. */
function recursiveFlatten(ar) {
  return Array.isArray(ar) ? ar.flat(Infinity) : [ar];
}
/**
 * Remove one pair of enclosing parentheses, but only when the leading "("
 * is matched by the trailing ")".
 *
 * The previous version only checked the first and last characters, so
 * "(a) + (b)" became the unbalanced "a) + (b".
 *
 * @param {string} str
 * @returns {string}
 */
function strip1(str) {
  if (str.length < 2 || str[0] !== "(" || str[str.length - 1] !== ")") {
    return str;
  }
  let depth = 0;
  for (let i = 0; i < str.length - 1; i++) {
    if (str[i] === "(") depth++;
    else if (str[i] === ")") depth--;
    // The opening paren closed before the end: the outer parens are not a pair.
    if (depth === 0) return str;
  }
  return str.slice(1, -1);
}
/**
 * Incremental polynomial hash over BigInts.
 *
 * Feeding values with `update()` advances `value` as
 * value = (value * base + x) mod modulus. Strings, numbers, booleans, bigints,
 * null, undefined and objects exposing a `hash(state)` method are supported.
 * Any other value (including plain arrays and functions) is ignored.
 */
var FpHash = class _FpHash {
  value = 8773157n;
  #update(x) {
    const base = 873192869n;
    const modulus = 3189051996290219n;
    this.value = (this.value * base + x) % modulus;
  }
  update(...values) {
    for (const item of values) {
      if (item === null) {
        this.#update(37832657n);
        continue;
      }
      switch (typeof item) {
        case "string":
          for (const ch of item) this.#update(BigInt(199 + ch.charCodeAt(0)));
          break;
        case "number":
          if (Number.isInteger(item)) {
            this.#update(68265653n ^ BigInt(item));
          } else {
            // Hash non-integers by their IEEE-754 bit pattern.
            const bits = new DataView(new Float64Array([item]).buffer);
            this.#update(bits.getBigUint64(0, true));
          }
          break;
        case "boolean":
          this.#update(item ? 69069841n : 63640693n);
          break;
        case "bigint":
          this.#update(item ^ 71657401n);
          break;
        case "undefined":
          this.#update(18145117n);
          break;
        case "object":
          if ("hash" in item) item.hash(this);
          break;
        default:
          break;
      }
    }
    return this;
  }
  static hash(...values) {
    const hasher = new _FpHash();
    hasher.update(...values);
    return hasher.value;
  }
};
/**
 * Memoize `thunk()` in `cache` under `key`. The thunk runs only on a miss;
 * a present key is returned as-is, even when its stored value is undefined.
 */
function runWithCache(cache, key, thunk) {
  if (!cache.has(key)) {
    const computed = thunk();
    cache.set(key, computed);
    return computed;
  }
  return cache.get(key);
}

// src/alu.ts
/** Supported element dtypes: an enum-like object mapping member name to dtype string. */
var DType = /* @__PURE__ */ ((members) => {
  const entries = [
    ["Float32", "float32"],
    ["Int32", "int32"],
    ["Bool", "bool"],
    ["Complex64", "complex64"],
  ];
  for (const [memberName, dtypeName] of entries) members[memberName] = dtypeName;
  return members;
})(DType || {});
/** True for floating-point dtypes ("float32" and "complex64"). */
var isFloatDtype = (dtype) => {
  switch (dtype) {
    case "float32":
    case "complex64":
      return true;
    default:
      return false;
  }
};
/**
 * Node of a scalar ALU expression tree.
 *
 * `op` names the operation and `dtype` is the dtype of the produced value.
 * `src` holds the child expressions. `arg` is an op-specific payload: the
 * number for Const, the name for Variable, `[name, n]` for Special, the gid
 * for GlobalIndex, and `[gid, st]` for GlobalView.
 *
 * The hash, simplified form and value range of each node are memoized in
 * private fields.
 */
var AluExp = class _AluExp {
  /**
   * @param op - ALU operation of this node.
   * @param dtype - dtype of the value this node produces.
   * @param src - child expressions.
   * @param arg - op-specific payload (see the class comment).
   * @throws {TypeError} if `op` only supports float dtypes and `dtype` is not one.
   */
  constructor(op, dtype, src, arg = void 0) {
    this.op = op;
    this.dtype = dtype;
    this.src = src;
    this.arg = arg;
    if (AluGroup.RequiredFloat.has(op) && !isFloatDtype(dtype)) {
      throw new TypeError(`Unsupported dtype for ${op}: ${dtype}`);
    }
  }
  #hash;
  #simplified;
  #range;
  static add(a, b) {
    return new _AluExp("Add", a.dtype, [a, b]);
  }
  static sub(a, b) {
    return new _AluExp("Sub", a.dtype, [a, b]);
  }
  static mul(a, b) {
    return new _AluExp("Mul", a.dtype, [a, b]);
  }
  static idiv(a, b) {
    return new _AluExp("Idiv", a.dtype, [a, b]);
  }
  static mod(a, b) {
    return new _AluExp("Mod", a.dtype, [a, b]);
  }
  static min(a, b) {
    return new _AluExp("Min", a.dtype, [a, b]);
  }
  static max(a, b) {
    return new _AluExp("Max", a.dtype, [a, b]);
  }
  static sin(a) {
    return new _AluExp("Sin", a.dtype, [a]);
  }
  static cos(a) {
    return new _AluExp("Cos", a.dtype, [a]);
  }
  static exp(a) {
    return new _AluExp("Exp", a.dtype, [a]);
  }
  static log(a) {
    return new _AluExp("Log", a.dtype, [a]);
  }
  static reciprocal(a) {
    return new _AluExp("Reciprocal", a.dtype, [a]);
  }
  /** Cast `a` to `dtype`; returns `a` itself when it already has that dtype. */
  static cast(dtype, a) {
    if (a.dtype === dtype) return a;
    return new _AluExp("Cast", dtype, [a]);
  }
  static cmplt(a, b) {
    return new _AluExp("Cmplt", "bool", [a, b]);
  }
  static cmpne(a, b) {
    return new _AluExp("Cmpne", "bool", [a, b]);
  }
  /** Select `a` where `cond` is true, else `b`; the result takes `a`'s dtype. */
  static where(cond, a, b) {
    return new _AluExp("Where", a.dtype, [cond, a, b]);
  }
  /**
   * Constant node. Bool values are stored as 0/1 and int32 values are
   * truncated toward zero.
   * @throws {TypeError} if the normalized value is not a number.
   */
  static const(dtype, value) {
    if (dtype === "bool") {
      value = Number(Boolean(value));
    } else if (dtype === "int32") {
      value = Math.trunc(value);
    }
    if (typeof value !== "number") {
      throw new TypeError(
        `Expected a number for constant, got ${typeof value}: ${value}`
      );
    }
    return new _AluExp("Const", dtype, [], value);
  }
  /** Special index `name` whose values range over [0, n). */
  static special(dtype, name, n) {
    return new _AluExp("Special", dtype, [], [name, n]);
  }
  /** Named free variable, bound via `substitute()` or the `evaluate()` context. */
  static variable(dtype, name) {
    return new _AluExp("Variable", dtype, [], name);
  }
  /** Read element `bufidx` of global buffer `gid`. */
  static globalIndex(dtype, gid, bufidx) {
    return new _AluExp("GlobalIndex", dtype, [bufidx], gid);
  }
  /** Read global buffer `gid` through view `st` at `indices` (0 where the view is invalid). */
  static globalView(dtype, gid, st, indices) {
    return new _AluExp("GlobalView", dtype, indices, [gid, st]);
  }
  static i32(value) {
    return _AluExp.const("int32", value);
  }
  static f32(value) {
    return _AluExp.const("float32", value);
  }
  static bool(value) {
    return _AluExp.const("bool", Number(value));
  }
  /** Logical negation of a bool expression. */
  not() {
    if (this.dtype !== "bool") {
      throw new Error("not() can only be called on boolean expressions");
    }
    return _AluExp.cmpne(this, _AluExp.const("bool", true));
  }
  /** Compute a reasonable expression hash with low collision rate. */
  getHash() {
    if (this.#hash !== void 0) return this.#hash;
    const hasher = new FpHash();
    hasher.update(this.op, this.dtype, JSON.stringify(this.arg));
    hasher.update(this.src.length, ...this.src);
    this.#hash = hasher.value;
    return this.#hash;
  }
  hash(state) {
    state.update(this.getHash());
  }
  /** Substitute variables in this AluExp to values. */
  substitute(variables) {
    return this.rewrite((exp) => {
      if (exp.op === "Variable" && Object.hasOwn(variables, exp.arg)) {
        if (exp.dtype !== variables[exp.arg].dtype) {
          throw new Error(
            `Type mismatch: ${exp.dtype} vs ${variables[exp.arg].dtype}`
          );
        }
        return variables[exp.arg];
      }
    });
  }
  /** Reindex gid values in this expression as needed. */
  reindexGids(gidMap) {
    return this.rewrite((exp) => {
      if (exp.op === "GlobalIndex") {
        const gid = exp.arg;
        const newGid = gidMap.get(gid);
        if (newGid !== void 0 && newGid !== gid) {
          return _AluExp.globalIndex(exp.dtype, newGid, exp.src[0]);
        }
      } else if (exp.op === "GlobalView") {
        const gid = exp.arg[0];
        const newGid = gidMap.get(gid);
        if (newGid !== void 0 && newGid !== gid) {
          return _AluExp.globalView(exp.dtype, newGid, exp.arg[1], exp.src);
        }
      }
    });
  }
  /**
   * Conservative [min, max] bounds on the values this node can take, memoized
   * per node.
   *
   * Falls back to [-Infinity, Infinity] when no bound is known or a computed
   * bound is NaN. Bool results are clamped to [0, 1].
   */
  #computeRange() {
    if (this.#range !== void 0) return this.#range;
    const src = this.src;
    // For ops monotonic in each argument over the box spanned by the two child
    // ranges, the extremes are attained at the four corners.
    const minMax4 = (f) => {
      const [r1, r2] = [src[0].#computeRange(), src[1].#computeRange()];
      const values = [
        f(r1[0], r2[0]),
        f(r1[0], r2[1]),
        f(r1[1], r2[0]),
        f(r1[1], r2[1])
      ];
      return [Math.min(...values), Math.max(...values)];
    };
    let ret;
    switch (this.op) {
      case "Add":
        ret = [src[0].min + src[1].min, src[0].max + src[1].max];
        break;
      case "Sub":
        ret = [src[0].min - src[1].max, src[0].max - src[1].min];
        break;
      case "Mul": {
        ret = minMax4((a, b) => a * b);
        break;
      }
      case "Idiv": {
        // Truncating division, matching evaluate() (floor here could make the
        // min === max constant-folding in simplify produce a wrong value).
        ret = minMax4((a, b) => Math.trunc(a / b));
        break;
      }
      case "Mod": {
        let divisorRange = src[1].#computeRange();
        if (divisorRange[0] <= 0 && divisorRange[1] >= 0) {
          divisorRange = [0, Math.max(-divisorRange[0], divisorRange[1])];
        }
        const maxDivisor = isFloatDtype(this.dtype) ? divisorRange[1] : divisorRange[1] - 1;
        ret = [
          clamp(src[0].min, -maxDivisor, 0),
          clamp(src[0].max, 0, maxDivisor)
        ];
        break;
      }
      case "Min":
        ret = [
          Math.min(src[0].min, src[1].min),
          Math.min(src[0].max, src[1].max)
        ];
        break;
      case "Max":
        ret = [
          Math.max(src[0].min, src[1].min),
          Math.max(src[0].max, src[1].max)
        ];
        break;
      case "Sin":
      case "Cos":
        // Not monotonic: values at the input endpoints do not bound the output
        // (e.g. sin over [0, pi] peaks at 1), so use the global bounds.
        ret = [-1, 1];
        break;
      case "Exp":
        ret = [Math.exp(src[0].min), Math.exp(src[0].max)];
        break;
      case "Log":
        ret = [Math.log(src[0].min), Math.log(src[0].max)];
        break;
      case "Reciprocal":
        if (src[0].min <= 0 && src[0].max >= 0) {
          ret = [-Infinity, Infinity];
        } else {
          ret = [1 / src[0].max, 1 / src[0].min];
        }
        break;
      case "Cast":
        if (this.dtype === "bool") {
          const canBeZero = src[0].min <= 0 && src[0].max >= 0;
          const mustBeZero = src[0].min === 0 && src[0].max === 0;
          ret = mustBeZero ? [0, 0] : canBeZero ? [0, 1] : [1, 1];
        } else if (this.dtype === "int32") {
          ret = [Math.trunc(src[0].min), Math.trunc(src[0].max)];
        } else {
          ret = [src[0].min, src[0].max];
        }
        break;
      case "Cmplt":
        ret = [0, 1];
        break;
      case "Cmpne":
        ret = [0, 1];
        break;
      case "Where":
        ret = [
          Math.min(src[1].min, src[2].min),
          Math.max(src[1].max, src[2].max)
        ];
        break;
      case "Const":
        ret = [this.arg, this.arg];
        break;
      case "Special":
        ret = [0, this.arg[1] - 1];
        break;
      default:
        ret = [-Infinity, Infinity];
    }
    if (isNaN(ret[0]) || isNaN(ret[1])) {
      ret = [-Infinity, Infinity];
    }
    if (this.dtype === "bool") {
      ret[0] = clamp(ret[0], 0, 1);
      ret[1] = clamp(ret[1], 0, 1);
    }
    this.#range = ret;
    return ret;
  }
  get min() {
    return this.#computeRange()[0];
  }
  get max() {
    return this.#computeRange()[1];
  }
  #isConstInt() {
    return this.op === "Const" && this.dtype === "int32";
  }
  /**
   * Simplify the expression by replacing any known patterns and deduping
   * identical subexpressions.
   */
  simplify(cache = /* @__PURE__ */ new Map()) {
    if (this.#simplified !== void 0) return this.#simplified;
    const hash = this.getHash();
    if (cache.has(hash)) {
      return this.#simplified = cache.get(hash);
    }
    const simplified = this.#simplifyInner(cache);
    const simplifiedHash = simplified.getHash();
    if (cache.has(simplifiedHash)) {
      const prevSimplified = cache.get(simplifiedHash);
      cache.set(hash, prevSimplified);
      this.#simplified = prevSimplified;
      return prevSimplified;
    } else {
      cache.set(hash, simplified);
      cache.set(simplifiedHash, simplified);
      this.#simplified = simplified;
      return simplified;
    }
  }
  /** One round of rewrite rules applied after simplifying the children. */
  #simplifyInner(cache) {
    const src = this.src.map((x) => x.simplify(cache));
    const { op } = this;
    // Constant folding.
    if (src.every((x) => x.op === "Const") && !AluGroup.Variable.has(op)) {
      const newExp2 = new _AluExp(op, this.dtype, src, this.arg);
      return _AluExp.const(this.dtype, newExp2.evaluate({}));
    }
    // A node whose range collapses to a single value is that constant.
    if (op !== "Const" && this.min === this.max) {
      return _AluExp.const(this.dtype, this.min);
    }
    // Identities with a constant operand.
    if (AluGroup.Binary.has(op)) {
      for (let i = 0; i < 2; i++) {
        if (src[i].op !== "Const") continue;
        const x = src[i].arg;
        if (op === "Add" && x === 0) return src[1 - i];
        if (op === "Sub" && i === 1 && x === 0) return src[1 - i];
        if (op === "Mul" && x === 1) return src[1 - i];
        if (op === "Mul" && x === 0) return _AluExp.const(this.dtype, 0);
        if (op === "Idiv" && i === 1 && x === 1) return src[1 - i];
      }
    }
    // a + (-1 * b) => a - b, and a - (-1 * b) => a + b.
    if ((op === "Add" || op === "Sub") && src[1].op === "Mul") {
      const [a, b] = src[1].src;
      const opNeg = op === "Add" ? "Sub" : "Add";
      if (a.op === "Const" && a.arg === -1) {
        return new _AluExp(opNeg, this.dtype, [src[0], b]);
      } else if (b.op === "Const" && b.arg === -1) {
        return new _AluExp(opNeg, this.dtype, [src[0], a]);
      }
    }
    // x % B => x when 0 <= x < B.
    if (op === "Mod" && src[1].op === "Const" && src[0].min >= 0 && src[0].max < src[1].arg) {
      return src[0];
    }
    // (x / B) * B + x % B => x, and ((x / B) % y) * B + x % B => x % (B * y).
    if (op === "Add" && src[0].op === "Mul" && src[0].src[1].#isConstInt() && src[1].op === "Mod" && src[1].src[1].#isConstInt() && src[0].src[1].arg === src[1].src[1].arg) {
      const [mul, mod] = src;
      const check = (exp) => {
        return exp.op === "Idiv" && exp.src[1].#isConstInt() && exp.src[1].arg === mod.src[1].arg && exp.src[0] === mod.src[0];
      };
      if (check(mul.src[0])) return mod.src[0];
      if (mul.src[0].op === "Mod") {
        const [x, y] = mul.src[0].src;
        if (check(x)) {
          return _AluExp.mod(mod.src[0], _AluExp.mul(mod.src[1], y)).simplify(
            cache
          );
        }
      }
    }
    if (op === "Idiv" && src[1].#isConstInt()) {
      const [numer, denom] = src;
      const B = denom.arg;
      for (let i = 0; i < 2; i++) {
        // (x * A) / B => x * (A / B) when B divides A.
        if (numer.op === "Mul" && numer.src[i].#isConstInt()) {
          const A = numer.src[i].arg;
          if (A % B === 0) {
            let ret = numer.src[1 - i];
            if (A / B !== 1) ret = _AluExp.mul(ret, _AluExp.i32(A / B));
            return ret.simplify(cache);
          }
        }
        // (x * A + y) / B => x * (A / B) + y / B when B divides A.
        // NOTE(review): with truncating division this requires x * A and y to
        // share a sign; presumably index expressions are nonnegative — confirm.
        for (let j = 0; j < 2; j++) {
          if (numer.op === "Add" && numer.src[j].op === "Mul" && numer.src[j].src[i].#isConstInt()) {
            const A = numer.src[j].src[i].arg;
            if (A % B === 0) {
              let ret = numer.src[j].src[1 - i];
              if (A / B !== 1) ret = _AluExp.mul(ret, _AluExp.i32(A / B));
              // Divide by the constant node `denom`; passing the raw number B
              // built a node whose child has no simplify() and threw.
              ret = _AluExp.add(ret, _AluExp.idiv(numer.src[1 - j], denom));
              return ret.simplify(cache);
            }
          }
        }
      }
    }
    // (x + A) % B => (x + A % B) % B for B > 0.
    // Truncated `%` only agrees when both numerators are nonnegative. The rule
    // is skipped when A % B === A, because the rewrite would reproduce this
    // same node and recurse forever. (The previous version dropped the outer
    // Mod entirely.)
    if (op === "Mod" && src[1].#isConstInt() && src[1].arg > 0 && src[0].min >= 0) {
      const [numer, denom] = src;
      const B = denom.arg;
      if (numer.op === "Add") {
        for (let i = 0; i < 2; i++) {
          if (!numer.src[i].#isConstInt()) continue;
          const A = numer.src[i].arg;
          const rest = numer.src[1 - i];
          const reduced = A % B;
          if (reduced === A || rest.min + reduced < 0) continue;
          const newNumer = reduced === 0 ? rest : _AluExp.add(rest, _AluExp.i32(reduced));
          return _AluExp.mod(newNumer, denom).simplify(cache);
        }
      }
    }
    if (op === "Cmplt") {
      if (src[0].min >= src[1].max) return _AluExp.const("bool", false);
      if (src[0].max < src[1].min) return _AluExp.const("bool", true);
    }
    if (op === "Cmpne") {
      if (src[0].max < src[1].min || src[0].min > src[1].max)
        return _AluExp.const("bool", true);
    }
    if (op === "Where") {
      if (src[0].max === 0) return src[2];
      if (src[0].min === 1) return src[1];
    }
    const newExp = src.every((s, i) => s === this.src[i]) ? this : new _AluExp(op, this.dtype, src, this.arg);
    return newExp;
  }
  /** Resolve this to a value, or `undefined` if not possible. */
  resolve() {
    const x = this.simplify();
    if (x.op === "Const") return x.arg;
    return void 0;
  }
  /**
   * Evaluate the expression on CPU, returning the result.
   *
   * Typically you would compile the AluExp as a representation to a lower-level
   * language. This is just to define the semantics and help debug.
   *
   * Note that the representation of Bool is as a number (0 or 1) here.
   */
  evaluate(context, globals) {
    if (AluGroup.Binary.has(this.op) || AluGroup.Compare.has(this.op)) {
      const x = this.src[0].evaluate(context, globals);
      const y = this.src[1].evaluate(context, globals);
      switch (this.op) {
        case "Add":
          return this.dtype === "bool" ? Number(x || y) : x + y;
        case "Sub":
          return x - y;
        case "Mul":
          return this.dtype === "bool" ? Number(x && y) : x * y;
        case "Idiv":
          // Truncating division, consistent with signed (truncated) Mod.
          return Math.trunc(x / y);
        case "Mod":
          return x % y;
        case "Min":
          return Math.min(x, y);
        case "Max":
          return Math.max(x, y);
        case "Cmplt":
          // Bools are numbers (0 or 1), as documented above.
          return Number(x < y);
        case "Cmpne":
          return Number(x != y);
        default:
          throw new Error(`Missing implemementation for ${this.op}`);
      }
    }
    if (AluGroup.Unary.has(this.op)) {
      const x = this.src[0].evaluate(context, globals);
      switch (this.op) {
        case "Sin":
          return Math.sin(x);
        case "Cos":
          return Math.cos(x);
        case "Exp":
          return Math.exp(x);
        case "Log":
          return Math.log(x);
        case "Reciprocal":
          return 1 / x;
        case "Cast":
          // Truncate toward zero, matching AluExp.const() and the range analysis.
          if (this.dtype === "int32") return Math.trunc(x);
          else if (this.dtype === "float32") return x;
          else if (this.dtype === "bool") return Number(Boolean(x));
          else throw new Error(`Unsupported cast to ${this.dtype}`);
        default:
          throw new Error(`Missing implemementation for ${this.op}`);
      }
    }
    switch (this.op) {
      case "Where":
        return this.src[0].evaluate(context, globals) ? this.src[1].evaluate(context, globals) : this.src[2].evaluate(context, globals);
      case "Const":
        return this.arg;
      case "Special":
        return context[this.arg[0]];
      case "Variable":
        return context[this.arg];
      case "GlobalIndex": {
        if (!globals) throw new Error("Missing globals function");
        const gid = this.arg;
        const bufidx = this.src[0].evaluate(context, globals);
        return globals(gid, bufidx);
      }
      case "GlobalView": {
        if (!globals) throw new Error("Missing globals function");
        const gid = this.arg[0];
        const st = this.arg[1];
        const [iexpr, vexpr] = st.toAluExp(this.src);
        if (vexpr.evaluate(context, globals)) {
          const bufidx = iexpr.evaluate(context, globals);
          return globals(gid, bufidx);
        } else {
          return 0;
        }
      }
      default:
        throw new Error(`Missing implemementation for ${this.op}`);
    }
  }
  /** Get this expression in debug format as a string. */
  toString() {
    const BIN_SYM = {
      ["Add"]: "+",
      ["Sub"]: "-",
      ["Mul"]: "*",
      ["Idiv"]: "/",
      ["Mod"]: "%"
    };
    const CMP_SYM = {
      ["Cmplt"]: "<",
      ["Cmpne"]: "!="
    };
    const UNARY_SYM = {
      ["Sin"]: "sin",
      ["Cos"]: "cos",
      ["Exp"]: "exp",
      ["Log"]: "log",
      ["Reciprocal"]: "1/"
    };
    return this.fold((node, parts) => {
      switch (node.op) {
        case "Const":
          return "" + (node.dtype === "bool" ? Boolean(node.arg) : node.arg);
        case "Variable":
          return `$${node.arg}:${node.dtype}`;
        case "Special": {
          const [name, n] = node.arg;
          return `#${name}{${n}}`;
        }
        case "GlobalIndex":
          return `G_${node.arg}<${node.dtype}>[${strip1(parts[0])}]`;
        case "GlobalView": {
          const [gid, st] = node.arg;
          const shape = st.shape.join(",");
          // Mark non-contiguous views with "*". The previous code tested a
          // always-truthy string, so the marker never appeared.
          const marker = st.contiguous ? "" : "*";
          return `GV_${gid}<${node.dtype}>{${shape}${marker}}[${parts.map(strip1).join(", ")}]`;
        }
      }
      if (BIN_SYM[node.op]) {
        return `(${parts[0]} ${BIN_SYM[node.op]} ${parts[1]})`;
      }
      if (CMP_SYM[node.op]) {
        return `(${parts[0]} ${CMP_SYM[node.op]} ${parts[1]})`;
      }
      if (UNARY_SYM[node.op]) {
        return `${UNARY_SYM[node.op]}${parts[0]}`;
      }
      if (node.op === "Cast") {
        return `Cast<${node.dtype}>(${strip1(parts[0])})`;
      }
      return `${node.op}(${parts.map(strip1).join(", ")})`;
    });
  }
  /** Generic fold() operation with a reducer over the expression tree. */
  fold(reducer) {
    // Memoized by node identity, so shared subtrees are reduced once.
    const visited = /* @__PURE__ */ new Map();
    const recurse = (exp) => {
      if (visited.has(exp)) return visited.get(exp);
      const mappedSrc = exp.src.map((s) => recurse(s));
      const result = reducer(exp, mappedSrc);
      visited.set(exp, result);
      return result;
    };
    return recurse(this);
  }
  /** Rewrite the expression recursively using a visitor. */
  rewrite(visitor) {
    return this.fold((exp, newSrc) => {
      if (newSrc.length === exp.src.length && newSrc.every((s, i) => s === exp.src[i])) {
        return visitor(exp) ?? exp;
      } else {
        const newExp = new _AluExp(exp.op, exp.dtype, newSrc, exp.arg);
        return visitor(newExp) ?? newExp;
      }
    });
  }
  /** Collect all nodes that satisfy a predicate. */
  collect(predicate) {
    const result = [];
    this.fold((exp) => {
      if (predicate(exp)) result.push(exp);
    });
    return result;
  }
};
/** Named sets of ALU op names used to classify and validate expression nodes. */
var AluGroup = (() => {
  const setOf = (...ops) => new Set(ops);
  return {
    Binary: setOf("Add", "Sub", "Mul", "Idiv", "Mod", "Min", "Max"),
    Unary: setOf("Sin", "Cos", "Exp", "Log", "Reciprocal", "Cast"),
    Compare: setOf("Cmplt", "Cmpne"),
    Variable: setOf("Special", "Variable", "GlobalIndex", "GlobalView"),
    Reduce: setOf("Add", "Mul", "Min", "Max"),
    RequiredFloat: setOf("Sin", "Cos", "Exp", "Log", "Reciprocal"),
  };
})();
/** Predefined named variables used when building kernel expressions. */
var AluVar = (() => {
  const int32Var = (name) => AluExp.variable("int32", name);
  return {
    /** Global index. */
    gidx: int32Var("gidx"),
    /** Reduction index. */
    ridx: int32Var("ridx"),
    /** Accumulator variable of the requested dtype. */
    acc: (dtype) => AluExp.variable(dtype, "acc"),
    /** Virtual "array index". */
    idx: int32Var("idx"),
    /** Unroll index, inside loop. */
    unroll: int32Var("unroll"),
    /** Upcast index, inside loop. */
    upcast: int32Var("upcast"),
  };
})();
/**
 * Kernel description: argument count, output size, per-element expression
 * (stored in simplified form) and an optional reduction.
 */
var Kernel = class {
  constructor(nargs, size, exp, reduction) {
    Object.assign(this, { nargs, size, exp, reduction });
    this.exp = exp.simplify();
  }
  /** Feed every field into the hash `state`. */
  hash(state) {
    const { nargs, size, exp, reduction } = this;
    state.update(nargs, size, exp, reduction);
  }
};
/**
 * Reduction of `size` values of `dtype` using `op` ("Add", "Mul", "Min" or "Max").
 *
 * `fusion` defaults to the accumulator variable for `dtype`.
 * NOTE(review): how backends apply `fusion` is not visible here — confirm.
 *
 * @throws {TypeError} from the constructor if `op` is not a reduce op.
 */
var Reduction = class {
  constructor(dtype, op, size, fusion = AluVar.acc(dtype)) {
    this.dtype = dtype;
    this.op = op;
    this.size = size;
    this.fusion = fusion;
    if (!AluGroup.Reduce.has(op)) {
      throw new TypeError(`Unsupported reduction: ${op}`);
    }
  }
  hash(state) {
    state.update(this.dtype, this.op, this.size, this.fusion);
  }
  /**
   * Get the identity for this reduction operation.
   * For bool, Add/Max act as OR (identity false) and Mul/Min as AND (identity true).
   * @throws {TypeError} for dtypes without a defined identity.
   */
  get identity() {
    if (this.dtype === "bool") {
      return this.op === "Add" || this.op === "Max" ? false : true;
    } else if (this.dtype === "int32") {
      if (this.op === "Add") return 0;
      else if (this.op === "Mul") return 1;
      else if (this.op === "Min") return -1 >>> 1;
      else if (this.op === "Max") return 1 << 31;
    } else if (this.dtype === "float32") {
      if (this.op === "Add") return 0;
      else if (this.op === "Mul") return 1;
      else if (this.op === "Min") return Infinity;
      else if (this.op === "Max") return -Infinity;
    }
    throw new TypeError(`Unsupported reduction: ${this.op} ${this.dtype}`);
  }
  /**
   * Evaluate this operation on CPU.
   *
   * Every fold starts from the same value as `identity`. The bool OR reduction
   * previously started from `true`, which made it always return true.
   * int32 Add/Mul wrap to 32 bits.
   */
  evaluate(...values) {
    if (this.dtype === "bool") {
      if (this.op === "Add" || this.op === "Max") {
        return values.reduce((a, b) => a || b, false);
      } else if (this.op === "Mul" || this.op === "Min") {
        return values.reduce((a, b) => a && b, true);
      }
    } else if (this.dtype === "int32") {
      if (this.op === "Add") {
        return values.reduce((a, b) => a + b | 0, 0);
      } else if (this.op === "Mul") {
        return values.reduce((a, b) => a * b | 0, 1);
      } else if (this.op === "Min") {
        return values.reduce(
          (a, b) => Math.min(a, b),
          -1 >>> 1
        );
      } else if (this.op === "Max") {
        return values.reduce((a, b) => Math.max(a, b), 1 << 31);
      }
    } else if (this.dtype === "float32") {
      if (this.op === "Add") {
        return values.reduce((a, b) => a + b, 0);
      } else if (this.op === "Mul") {
        return values.reduce((a, b) => a * b, 1);
      } else if (this.op === "Min") {
        return values.reduce(
          (a, b) => Math.min(a, b),
          Infinity
        );
      } else if (this.op === "Max") {
        return values.reduce(
          (a, b) => Math.max(a, b),
          -Infinity
        );
      }
    }
    throw new TypeError(`Unsupported reduction: ${this.op} ${this.dtype}`);
  }
};
/**
 * Build an expression that loads the element at `indices` from global buffer
 * `gid` through shape tracker `st`. Positions that the tracker marks invalid
 * yield a zero of `dtype` instead of a load.
 */
function accessorGlobal(dtype, gid, st, indices) {
  const [flatIndex, isValid] = st.toAluExp(indices);
  const load = AluExp.globalIndex(dtype, gid, flatIndex);
  const zero = AluExp.const(dtype, 0);
  return AluExp.where(isValid, load, zero);
}
969
/**
 * Evaluate `exp`, an expression in the variable `idx`, through shape tracker
 * `st` at `indices`. Positions the tracker marks invalid yield zero.
 *
 * The zero has `exp`'s own dtype, as in `accessorGlobal`, so both branches of
 * the `where` agree. Previously a float32 zero was used for every dtype.
 */
function accessorAluExp(exp, st, indices) {
  const [index, valid] = st.toAluExp(indices);
  return AluExp.where(
    valid,
    exp.substitute({ idx: index }),
    AluExp.const(exp.dtype, 0)
  );
}
973
-
974
- // src/shape.ts
975
// Short alias for JSON.stringify, used to format shapes and arguments in the
// View error messages below.
var jstr = JSON.stringify;
976
/**
 * Return a copy of `strides` in which every axis of length 1 has stride 0.
 * Such an axis only ever has index 0, so its stride never affects addressing.
 */
function canonicalizeStrides(shape, strides) {
  return shape.map((dim, axis) => (dim === 1 ? 0 : strides[axis]));
}
984
/**
 * Row-major (last axis fastest) strides for `shape`, in canonical form:
 * axes of length 1 get stride 0.
 */
function defaultStrides(shape) {
  const strides = Array(shape.length).fill(0);
  let running = 1;
  for (let axis = shape.length - 1; axis >= 0; axis--) {
    strides[axis] = shape[axis] === 1 ? 0 : running;
    running *= shape[axis];
  }
  return strides;
}
992
/**
 * Merge runs of adjacent axes that can be addressed with a single stride.
 *
 * Two neighbouring axes merge when the outer stride equals `inner size *
 * inner stride` (this also covers two zero-stride axes). They also merge when
 * the outer axis' mask admits exactly one index (or, with no mask, the first
 * axis has length 1). Axes of length 1 after the first are skipped.
 *
 * A zero-stride (broadcast) axis is NOT merged with a following non-zero
 * stride axis. No single stride addresses that pair: [2,3] with strides
 * [0,1] maps flat index k to k % 3, not to k. Merging it made
 * `View.reshape` and `View.minify` produce wrong strides.
 *
 * @param {number[]} shape
 * @param {number[]} strides - Same length as `shape`.
 * @param {Array<[number, number]> | null} [mask] - Optional per-axis masks.
 * @returns {Array<[number, number, number]>} `[mergedSize, stride, realSize]`
 *   per merged run, outermost first. `realSize` is the run's size not counting
 *   a leading zero-stride part (0 if the run starts with one).
 * @throws {Error} If the argument lengths disagree.
 */
function mergeDims(shape, strides, mask) {
  if (shape.length === 0) return [];
  if (shape.length !== strides.length || mask && shape.length !== mask.length) {
    throw new Error("internal: invalid args to mergeDims");
  }
  const ret = [
    [shape[0], strides[0], strides[0] !== 0 ? shape[0] : 0]
  ];
  // Whether the previous axis can be folded into the next one regardless of
  // strides (it effectively has a single valid index).
  let merging = mask ? mask[0][1] - mask[0][0] === 1 : shape[0] === 1;
  for (let i = 1; i < shape.length; i++) {
    const s = shape[i];
    const st = strides[i];
    if (s === 1) continue;
    const [lastS, lastSt, lastPreExpandS] = ret[ret.length - 1];
    if (merging || lastSt === s * st) {
      ret[ret.length - 1] = [lastS * s, st, merging ? s : lastPreExpandS * s];
    } else {
      ret.push([s, st, s]);
    }
    merging = mask ? mask[i][1] - mask[i][0] === 1 : false;
  }
  return ret;
}
1014
/**
 * Reshape a view mask from `oldShape` to `newShape`.
 *
 * Both shapes are walked from the innermost axis outward. Once either shape
 * (or the mask) runs out, it is treated as continuing with size-1 axes and
 * full [0, 1] masks. Each old mask interval is split as its axis is divided
 * among new axes, and merged as old axes are combined into one new axis.
 *
 * @param maskInput - Per-axis [begin, end) ranges for `oldShape` (end
 *   exclusive, as in `View.toAluExp`).
 * @param oldShape - Shape the mask currently describes.
 * @param newShape - Target shape; presumably has the same total size as
 *   `oldShape` (the caller, `View.reshape`, checks this).
 * @returns The mask for `newShape`, or `null` when the masked region cannot be
 *   written as a single [begin, end) range per new axis.
 */
function reshapeMask(maskInput, oldShape, newShape) {
  const newMask = [];
  let rMasksI = maskInput.length;
  let rShapeI = oldShape.length;
  let rNewShapeI = newShape.length;
  // Reverse iterators; exhausted sequences yield a size-1 axis / full mask.
  const rMasks = () => rMasksI ? maskInput[--rMasksI] : [0, 1];
  const rShape = () => rShapeI ? oldShape[--rShapeI] : 1;
  const rNewShape = () => rNewShapeI ? newShape[--rNewShapeI] : 1;
  // Product of the new axes already carved out of the current old axis.
  let currStride = 1;
  let [oldDim, newDim, mask] = [rShape(), rNewShape(), rMasks()];
  while (newMask.length < newShape.length) {
    const [l, r] = mask;
    const nextStride = newDim * currStride;
    if (oldDim === nextStride) {
      // This new axis finishes the current old axis: copy its mask, scaled
      // down by the inner part already split off.
      newMask.push([intdiv(l, currStride), intdiv(r - 1, currStride) + 1]);
      currStride = 1;
      [oldDim, newDim, mask] = [rShape(), rNewShape(), rMasks()];
    } else if (oldDim > nextStride) {
      // The old axis is split further. This needs an exact division, and the
      // range must either be aligned to `nextStride` or stay inside one
      // block of `nextStride` elements.
      if (oldDim % nextStride !== 0) return null;
      if ((l % nextStride !== 0 || r % nextStride !== 0) && intdiv(l, nextStride) !== intdiv(r - 1, nextStride))
        return null;
      newMask.push([
        intdiv(l % nextStride, currStride),
        intdiv((r - 1) % nextStride, currStride) + 1
      ]);
      [currStride, newDim] = [nextStride, rNewShape()];
    } else {
      // The new axis spans several old axes: fold the next old axis (and its
      // mask) into the current one. This only works if the current mask is
      // full or empty, or the next axis' mask selects exactly one index.
      const nextMask = rMasks();
      if (!deepEqual(mask, [0, oldDim]) && l !== r && nextMask[1] - nextMask[0] !== 1)
        return null;
      mask = [nextMask[0] * oldDim + l, (nextMask[1] - 1) * oldDim + r];
      oldDim *= rShape();
    }
  }
  // Masks were collected innermost-first.
  return newMask.reverse();
}
1050
/**
 * One strided, optionally masked, view of a flat buffer.
 *
 * The element at per-axis indices `idx` lives at flat position
 * `offset + sum(idx[i] * strides[i])`. It is valid only when
 * `mask[i][0] <= idx[i] < mask[i][1]` on every axis; a `null` mask means every
 * element is valid (see `toAluExp`).
 *
 * Build instances through `View.create`, which canonicalizes the arguments.
 * The constructor stores them as given.
 */
var View = class _View {
  constructor(shape, strides, offset, mask) {
    this.shape = shape;
    this.strides = strides;
    this.offset = offset;
    this.mask = mask;
  }
  // Cached, computed property values.
  #size;
  #contiguous;
  /**
   * Create a view in canonical form:
   * - `strides` defaults to row-major strides; axes of length 1 get stride 0.
   * - A shape containing 0 gets zero strides, offset 0 and no mask.
   * - A mask covering the whole shape is dropped.
   * - An axis whose mask admits at most one index gets stride 0, and the
   *   offset absorbs that index. If any axis admits no index at all, the view
   *   is first reset to zero strides, offset 0 and an all-[0, 0] mask.
   * @throws {Error} If any dimension of `shape` is negative.
   */
  static create(shape, strides, offset = 0, mask = null) {
    if (shape.some((s) => s < 0))
      throw new Error("View shape must be non-negative");
    strides = strides ? canonicalizeStrides(shape, strides) : defaultStrides(shape);
    if (shape.includes(0)) {
      return new _View(shape, rep(shape.length, 0), 0, null);
    }
    if (mask !== null && mask.every(([b, e], i) => b === 0 && e === shape[i])) {
      mask = null;
    }
    if (mask !== null) {
      const elimDims = [];
      let hasNoData = false;
      for (let i = 0; i < shape.length; i++) {
        const [b, e] = mask[i];
        if (b + 1 >= e) elimDims.push(i);
        if (b >= e) hasNoData = true;
      }
      if (elimDims.length) {
        if (hasNoData) {
          strides = rep(shape.length, 0);
          offset = 0;
          mask = rep(shape.length, () => [0, 0]);
        }
        // `strides` is a fresh array here (canonicalizeStrides/rep), so
        // mutating it does not touch the caller's array.
        for (const i of elimDims) {
          offset += strides[i] * mask[i][0];
          strides[i] = 0;
        }
      }
    }
    return new _View(shape, strides, offset, mask);
  }
  /** Number of axes. */
  get ndim() {
    return this.shape.length;
  }
  /** Total number of elements (product of `shape`), cached. */
  get size() {
    if (this.#size === void 0) this.#size = prod(this.shape);
    return this.#size;
  }
  /** Whether this is a default, contiguous, unaltered view of the data (identity). */
  get contiguous() {
    if (this.#contiguous === void 0) {
      this.#contiguous = this.size === 0 || this.offset === 0 && this.mask === null && deepEqual(this.strides, defaultStrides(this.shape));
    }
    return this.#contiguous;
  }
  /**
   * Produce an AluExp for evaluating this view at an index.
   * @param idxs - One index expression per axis.
   * @returns `[flatIndexExp, validExp]`. Axes of length 1 or with stride 0 do
   *   not contribute to the index; mask bounds equal to 0 / the axis length
   *   are not checked.
   */
  toAluExp(idxs) {
    let iexpr = AluExp.i32(this.offset);
    let vexpr = AluExp.bool(true);
    for (let i = this.ndim - 1; i >= 0; i--) {
      const idx = idxs[i];
      if (this.shape[i] !== 1 && this.strides[i] !== 0) {
        iexpr = AluExp.add(AluExp.mul(idx, AluExp.i32(this.strides[i])), iexpr);
      }
      if (this.mask) {
        // Lower bound: idx >= mask[i][0].
        if (this.mask[i][0] !== 0)
          vexpr = AluExp.mul(
            AluExp.cmplt(idx, AluExp.i32(this.mask[i][0])).not(),
            vexpr
          );
        // Upper bound: idx < mask[i][1].
        if (this.mask[i][1] !== this.shape[i])
          vexpr = AluExp.mul(
            AluExp.cmplt(idx, AluExp.i32(this.mask[i][1])),
            vexpr
          );
      }
    }
    return [iexpr, vexpr];
  }
  /**
   * Try to compose this view with another one. `this` view is applied first,
   * followed by the argument. If this is not possible for the specific views,
   * return `null` instead.
   *
   * If composable, return a combined view with the same shape as `v1`.
   *
   * This is very tricky. The shapes of v1 and v2 may be different, and in that
   * case, we do some math to figure out whether they're compatible.
   *
   * In the code, `v2` is `this` (the view nearer the buffer) and `v1` is the
   * argument, whose flat offsets index into `v2`'s shape.
   */
  compose(v1) {
    const v2 = this;
    if (v2.contiguous) return v1;
    if (v1.contiguous) {
      // v1 is the identity on v2's elements: the answer is v2 itself, or v2
      // reshaped to v1's shape if that is expressible.
      if (deepEqual(v1.shape, v2.shape)) return v2;
      if (v1.size === v2.size) {
        const ret = v2.reshape(v1.shape);
        if (ret !== null) return ret;
      }
    }
    if (v1.mask !== null) {
      // Compose with v1's unmasked region only, then pad back out to v1's shape.
      const newV1 = v1.shrink(v1.mask);
      const merged = v2.compose(newV1);
      return merged ? merged.pad(zip(v1.mask, v1.shape).map(([m, s]) => [m[0], s - m[1]])) : null;
    }
    // Coordinates in v2's shape of v1's first element.
    const origin = unravel(v2.shape, v1.offset);
    // terms[d2] collects [d1, delta]: one step along v1 axis d1 moves v2 axis
    // d2 by delta. `strides` accumulates the resulting flat stride per v1 axis.
    const terms = rep(v2.ndim, () => []);
    const strides = rep(v1.ndim, 0);
    for (let d1 = 0; d1 < v1.strides.length; d1++) {
      const st = v1.strides[d1];
      if (st === 0) {
        continue;
      }
      const unravelOffset = unravel(v2.shape, v1.offset + st);
      for (let d2 = 0; d2 < v2.ndim; d2++) {
        const o = origin[d2];
        const diff = unravelOffset[d2] - o;
        if (diff === 0) {
          continue;
        }
        terms[d2].push([d1, diff]);
        strides[d1] += diff * v2.strides[d2];
      }
    }
    // Walk v2's axes innermost-first, grouping them into runs whose combined
    // index range [min, max] stays within [0, runSize). Each closed run
    // becomes an extent [runSize, min, max].
    let [mergedSize, mergedTermMin, mergedTermMax] = [1, 0, 0];
    const extents = [];
    for (let i = v2.ndim - 1; i >= 0; i--) {
      const term = terms[i];
      const s = v2.shape[i];
      let [tmin, tmax] = [origin[i], origin[i]];
      for (const [d1, s1] of term) {
        if (s1 > 0) tmax += (v1.shape[d1] - 1) * s1;
        else if (s1 < 0) tmin += (v1.shape[d1] - 1) * s1;
      }
      mergedTermMin += tmin * mergedSize;
      mergedTermMax += tmax * mergedSize;
      mergedSize *= s;
      if (mergedTermMin >= 0 && mergedTermMax < mergedSize) {
        extents.push([mergedSize, mergedTermMin, mergedTermMax]);
        [mergedSize, mergedTermMin, mergedTermMax] = [1, 0, 0];
      }
    }
    if (mergedTermMin !== 0 || mergedTermMax !== 0) return null;
    extents.reverse();
    // If the runs differ from v2's axes, retry with v2 reshaped to the runs.
    const v2Shape = extents.map(([s]) => s);
    if (!deepEqual(v2Shape, v2.shape)) {
      const reshapedV2 = v2.reshape(v2Shape);
      if (reshapedV2 === null) return null;
      if (!deepEqual(reshapedV2.shape, v2.shape)) return reshapedV2.compose(v1);
    }
    if (v2.mask !== null) {
      // Project v2's mask onto v1's axes as [newB, newE) bounds. If that
      // narrows v1, retry with the narrowed mask. A mask axis reached through
      // more than one v1 axis cannot be projected ("bad").
      const newB = rep(v1.ndim, 0);
      const newE = v1.shape.slice();
      let bad = false;
      for (let d2 = 0; d2 < v2.ndim; d2++) {
        const [b, e] = v2.mask[d2];
        const o = origin[d2];
        const term = terms[d2];
        const [_, tmin, tmax] = extents[d2];
        if (b <= tmin && tmax < e) continue;
        if (term.length !== 1) {
          if (term.length === 0 && newE.length) newE[0] = 0;
          else bad = true;
        } else {
          const [d1, s1] = term[0];
          newB[d1] = Math.max(
            newB[d1],
            Math.ceil((s1 > 0 ? b - o : e - o - 1) / s1)
          );
          newE[d1] = Math.min(
            newE[d1],
            Math.floor((s1 < 0 ? b - o : e - o - 1) / s1) + 1
          );
        }
      }
      for (let d1 = 0; d1 < v1.ndim; d1++) {
        if (newB[d1] !== 0 || newE[d1] !== v1.shape[d1]) {
          return v2.compose(
            _View.create(v1.shape, v1.strides, v1.offset, zip(newB, newE))
          );
        }
      }
      if (bad) return null;
    }
    // Flat offset of v1's first element inside v2's buffer.
    let finalOffset = v2.offset;
    for (let d2 = 0; d2 < v2.ndim; d2++) {
      finalOffset += origin[d2] * v2.strides[d2];
    }
    return _View.create(v1.shape, strides, finalOffset, null);
  }
  /**
   * Attempt to simplify this view into a smaller reshaped form (the shape
   * given by `mergeDims`). Returns `this` if that reshape is not possible.
   */
  minify() {
    const minShape = mergeDims(this.shape, this.strides, this.mask).map(
      (x) => x[0]
    );
    const nv = this.reshape(minShape);
    return nv ? nv : this;
  }
  /**
   * Pad the view with zeros on each dimension.
   * @param arg - One `[before, after]` pair of non-negative counts per axis.
   *   The added positions are masked out.
   * @throws {Error} On a wrong-length or negative `arg`.
   */
  pad(arg) {
    if (arg.length !== this.ndim || !arg.every(([b, e]) => b >= 0 && e >= 0)) {
      throw new Error(`invalid pad ${jstr(arg)} for ${jstr(this.shape)}`);
    }
    if (arg.every(([b, e]) => b === 0 && e === 0)) return this;
    const zvarg = arg.map(([b, e], i) => [-b, this.shape[i] + e]);
    const mask = arg.map(([b, _e], i) => [b, this.shape[i] + b]);
    return this.#unsafeResize(zvarg, mask);
  }
  /**
   * Shrink the view by taking a subarray.
   * @param arg - One `[begin, end)` range per axis, within bounds.
   * @throws {Error} On a wrong-length or out-of-range `arg`.
   */
  shrink(arg) {
    if (arg.length !== this.ndim || !arg.every(([b, e], i) => 0 <= b && b <= e && e <= this.shape[i])) {
      throw new Error(`invalid shrink ${jstr(arg)} for ${jstr(this.shape)}`);
    }
    return this.#unsafeResize(arg);
  }
  /**
   * Re-window each axis to `[arg[i][0], arg[i][1])` without bounds checks (the
   * begin may be negative, as in `pad`). The existing mask is shifted into
   * the new window and intersected with `mask`, if one is given.
   */
  #unsafeResize(arg, mask) {
    const offset = this.strides.map((s, i) => s * arg[i][0]).reduce((a, b) => a + b, 0);
    if (this.mask) {
      const nmask = this.mask.map(([mx, my], i) => [
        Math.max(0, Math.min(mx - arg[i][0], arg[i][1] - arg[i][0])),
        Math.max(0, Math.min(my - arg[i][0], arg[i][1] - arg[i][0]))
      ]);
      mask = mask ? mask.map(([mx, my], i) => [
        Math.max(mx, nmask[i][0]),
        Math.min(my, nmask[i][1])
      ]) : nmask;
    }
    return _View.create(
      arg.map(([b, e]) => e - b),
      this.strides,
      this.offset + offset,
      mask
    );
  }
  /**
   * Expand one or more axes with length "1" by repeating the data.
   * @throws {Error} If ranks differ or a non-1 axis would change length.
   */
  expand(newShape) {
    if (newShape.length !== this.ndim) {
      throw new Error(
        `Can't expand ${jstr(this.shape)} into ${jstr(newShape)}`
      );
    }
    for (let i = 0; i < this.ndim; i++) {
      if (newShape[i] !== this.shape[i] && this.shape[i] !== 1) {
        throw new Error(
          `Can't expand ${jstr(this.shape)} into ${jstr(newShape)}`
        );
      }
    }
    if (this.size === 0) return _View.create(newShape);
    // An expanded axis keeps the single element valid everywhere if it was
    // valid, otherwise it becomes fully masked.
    const mask = this.mask ? this.mask.map(
      (m, i) => this.shape[i] === newShape[i] ? m : m[0] === 0 && m[1] === 1 ? [0, newShape[i]] : [0, 0]
    ) : null;
    return _View.create(newShape, this.strides, this.offset, mask);
  }
  /**
   * Permute the axes of an array.
   * @throws {Error} If `axis` is not a permutation of `0..ndim-1`.
   */
  permute(axis) {
    if (!isPermutation(axis, this.ndim))
      throw new Error(`Invalid permutation ${jstr(axis)} of len ${this.ndim}`);
    const newShape = axis.map((a) => this.shape[a]);
    const newStrides = axis.map((a) => this.strides[a]);
    const newMask = this.mask ? axis.map((a) => this.mask[a]) : null;
    return _View.create(newShape, newStrides, this.offset, newMask);
  }
  /**
   * Flip (reverse) one or more axes of the view.
   * @param arg - One truthy/falsy flag per axis.
   */
  flip(arg) {
    if (arg.length !== this.ndim)
      throw new Error(`Invalid flip ${jstr(arg)} for ${jstr(this.shape)}`);
    const strides = this.strides.slice();
    let offset = this.offset;
    const mask = this.mask ? this.mask.slice() : null;
    for (let i = 0; i < this.ndim; i++) {
      const s = this.shape[i];
      if (arg[i]) {
        // Start from the last element and walk backwards; mirror the mask.
        strides[i] = -strides[i];
        offset += (s - 1) * this.strides[i];
        if (mask) mask[i] = [s - mask[i][1], s - mask[i][0]];
      }
    }
    return _View.create(this.shape, strides, offset, mask);
  }
  /**
   * Reshape the view into a new shape.
   * @returns The reshaped view, or `null` when it cannot be expressed as a
   *   single view.
   * @throws {Error} If `newShape` has negative entries or a different size.
   */
  reshape(newShape) {
    if (deepEqual(this.shape, newShape)) return this;
    if (newShape.some((s) => s < 0))
      throw new Error(`Reshape cannot have negative numbers ${jstr(newShape)}`);
    if (this.size !== prod(newShape))
      throw new Error(`Reshape size ${jstr(this.shape)} -> ${jstr(newShape)}`);
    if (this.size === 0) return _View.create(newShape);
    if (newShape.length === 0 && this.mask?.some(([b, e]) => b === e))
      return null;
    if (this.contiguous) return _View.create(newShape);
    // Assign strides innermost-first: each merged run of old axes must be
    // covered exactly by a run of new axes. Once a run's non-broadcast part
    // (`realSize`) is used up, the remaining new axes get stride 0.
    const rStrides = [];
    const merge = mergeDims(this.shape, this.strides, this.mask);
    let rShapeIdx = newShape.length;
    for (let i = merge.length - 1; i >= 0; i--) {
      let [mergedSize, newStride, realSize] = merge[i];
      let acc = 1;
      while (acc < mergedSize && rShapeIdx > 0) {
        const newDim = newShape[--rShapeIdx];
        rStrides.push(newStride * acc);
        acc *= newDim;
        if (acc >= realSize) newStride = 0;
      }
      if (acc !== mergedSize) return null;
    }
    // Leftover outer new axes have size 1, so stride 0.
    const newStrides = rep(newShape.length - rStrides.length, 0).concat(
      rStrides.reverse()
    );
    if (!this.mask) return _View.create(newShape, newStrides, this.offset);
    const newMask = reshapeMask(this.mask, this.shape, newShape);
    if (!newMask) return null;
    // Keep the first valid element at the same flat position.
    let newOffset = this.offset;
    for (let i = 0; i < this.ndim; i++)
      newOffset += this.strides[i] * this.mask[i][0];
    for (let i = 0; i < newShape.length; i++)
      newOffset -= newStrides[i] * newMask[i][0];
    return _View.create(newShape, newStrides, newOffset, newMask);
  }
};
1369
/**
 * Convert a flat row-major `offset` into per-axis coordinates of `shape`
 * (like numpy's `unravel_index`). Out-of-range offsets wrap per axis.
 */
function unravel(shape, offset) {
  const coords = new Array(shape.length);
  let blockSize = 1;
  for (let axis = shape.length - 1; axis >= 0; axis--) {
    const dim = shape[axis];
    coords[axis] = Math.floor(offset / blockSize) % dim;
    blockSize *= dim;
  }
  return coords;
}
1379
/**
 * Symbolic version of `unravel`: build one `(offset / blockSize) % dim`
 * expression per axis of `shape`, where `offset` is an AluExp.
 */
function unravelAlu(shape, offset) {
  const coords = new Array(shape.length);
  let blockSize = 1;
  for (let axis = shape.length - 1; axis >= 0; axis--) {
    const dim = shape[axis];
    const quotient = AluExp.idiv(offset, AluExp.i32(blockSize));
    coords[axis] = AluExp.mod(quotient, AluExp.i32(dim));
    blockSize *= dim;
  }
  return coords;
}
1389
- var ShapeTracker = class _ShapeTracker {
1390
- constructor(views) {
1391
- this.views = views;
1392
- }
1393
- // Views apply left-to-right
1394
- /** Compose this shape tracker with another, applying after. */
1395
- compose(other) {
1396
- if (this.contiguous) return other;
1397
- let ret = this;
1398
- for (const v of other.views) {
1399
- ret = new _ShapeTracker(ret.views.concat(v)).simplify();
1400
- }
1401
- return ret;
1402
- }
1403
- static fromShape(shape) {
1404
- return new _ShapeTracker([View.create(shape)]);
1405
- }
1406
- get contiguous() {
1407
- return this.views.length === 1 && this.views[0].contiguous;
1408
- }
1409
- get consecutive() {
1410
- return this.views.length === 1 && this.views[0].mask === null && deepEqual(this.views[0].strides, defaultStrides(this.views[0].shape));
1411
- }
1412
- get lastStrides() {
1413
- return this.views[this.views.length - 1].strides;
1414
- }
1415
- get shape() {
1416
- return this.views[this.views.length - 1].shape;
1417
- }
1418
- get size() {
1419
- return this.views[this.views.length - 1].size;
1420
- }
1421
- toAluExp(idxs) {
1422
- let [iexpr, vexpr] = this.views[this.views.length - 1].toAluExp(idxs);
1423
- for (let i = this.views.length - 2; i >= 0; i--) {
1424
- const view = this.views[i].minify();
1425
- const exprs = view.toAluExp(unravelAlu(view.shape, iexpr));
1426
- iexpr = exprs[0];
1427
- vexpr = AluExp.mul(vexpr, exprs[1]);
1428
- }
1429
- return [iexpr.simplify(), vexpr.simplify()];
1430
- }
1431
- simplify() {
1432
- const views = this.views.slice();
1433
- while (views.length >= 2) {
1434
- const newView = views[views.length - 2].compose(views[views.length - 1]);
1435
- if (newView === null) break;
1436
- views.splice(views.length - 2, 2, newView);
1437
- }
1438
- return new _ShapeTracker(views);
1439
- }
1440
- pad(arg) {
1441
- return new _ShapeTracker(applyLast(this.views, (x) => x.pad(arg)));
1442
- }
1443
- shrink(arg) {
1444
- return new _ShapeTracker(applyLast(this.views, (x) => x.shrink(arg)));
1445
- }
1446
- expand(newShape) {
1447
- return new _ShapeTracker(applyLast(this.views, (x) => x.expand(newShape)));
1448
- }
1449
- permute(axis) {
1450
- return new _ShapeTracker(applyLast(this.views, (x) => x.permute(axis)));
1451
- }
1452
- flip(arg) {
1453
- return new _ShapeTracker(applyLast(this.views, (x) => x.flip(arg)));
1454
- }
1455
- reshape(newShape) {
1456
- const newView = this.views[this.views.length - 1].reshape(newShape);
1457
- return new _ShapeTracker(
1458
- newView === null ? this.views.concat(View.create(newShape)) : this.views.toSpliced(this.views.length - 1, 1, newView)
1459
- );
1460
- }
1461
- // Below this line are "composite" operations.
1462
- /** Broadcast along the given new axes, then expand the shape. */
1463
- broadcast(newShape, axis) {
1464
- let st = this;
1465
- if (axis.length > 0) {
1466
- const unsqueezed = [...st.shape];
1467
- for (const i of axis.toSorted()) {
1468
- unsqueezed.splice(i, 0, 1);
1469
- }
1470
- st = st.reshape(unsqueezed);
1471
- }
1472
- return st.expand(newShape);
1473
- }
1474
- };
1475
/** Return a copy of `ar` whose last element is replaced by `f(last)`. */
function applyLast(ar, f) {
  const replaced = f(ar.at(-1));
  return [...ar.slice(0, -1), replaced];
}
1478
-
1479
- // src/tuner.ts
1480
/**
 * Mutable bookkeeping for how a reduction kernel's iteration space is split
 * while tuning.
 *
 * `st` spans every axis. The index fields partition its axes into regions,
 * which tuneWebgpu indexes as follows:
 *   [0, groups)       flattened into the `gidx` special
 *   [groups, reduce)  the `group` special
 *   [reduce, unroll)  the `ridx` special
 *   [unroll, upcast)  `AluVar.unroll`
 *   [upcast, end)     `AluVar.upcast`
 * `outputSt` applies the same split/permute steps to the output axes only.
 */
var TuneDims = class {
  // Shape tracker including reduction axes.
  st;
  // Shape tracker including only output axes.
  outputSt;
  // TODO: Split gidx -> global and local axes during tuning (a `local` field).
  // First axis after the gidx axes; reductions start here, with groups.
  groups;
  // Start of the single-reduction-thread (ridx) axes.
  reduce;
  // Start of the axes upcast along the reduce dimension (unrolled).
  unroll;
  // Start of the axes upcast along the output dimension.
  upcast;
  /** Total number of axes in `st`. */
  get end() {
    return this.st.shape.length;
  }
  /**
   * @param shape - Output axes followed by one reduction axis (the last one);
   *   `outputSt` drops that last axis.
   */
  constructor(shape) {
    this.st = ShapeTracker.fromShape(shape);
    this.outputSt = ShapeTracker.fromShape(shape.slice(0, -1));
    this.groups = this.st.shape.length - 1;
    this.reduce = this.st.shape.length - 1;
    this.unroll = this.st.shape.length;
    this.upcast = this.st.shape.length;
  }
  // Place the axis at the end of the shape, so it is part of each workgroup.
  // Concretely: split off an `amount`-sized part of `axis` (unless it already
  // has that length) and move it to the last position of the gidx region.
  applyLocal(axis, amount) {
    if (axis >= this.groups) throw new Error("Cannot localize reduction axis");
    const length = this.st.shape[axis];
    if (length % amount !== 0)
      throw new Error(`Localize by ${amount} on axis length ${length}`);
    if (length !== amount) {
      // The split adds one axis before every region boundary.
      this.groups++, this.reduce++, this.unroll++, this.upcast++;
      this.st = this.st.reshape([
        ...this.st.shape.slice(0, axis),
        length / amount,
        amount,
        ...this.st.shape.slice(axis + 1)
      ]);
      this.outputSt = this.outputSt.reshape([
        ...this.outputSt.shape.slice(0, axis),
        length / amount,
        amount,
        ...this.outputSt.shape.slice(axis + 1)
      ]);
      // Point at the inner `amount` part.
      axis++;
    }
    this.st = this.st.permute([
      ...range(axis),
      ...range(axis + 1, this.groups),
      axis,
      ...range(this.groups, this.st.shape.length)
    ]);
    this.outputSt = this.outputSt.permute([
      ...range(axis),
      ...range(axis + 1, this.groups),
      axis,
      ...range(this.groups, this.outputSt.shape.length)
    ]);
  }
  // Split `amount` off output axis `axis` and move it to the very end, into
  // the upcast region. The region counters are unchanged: the new trailing
  // axis falls in [upcast, end) because `end` grows.
  applyUpcast(axis, amount) {
    if (axis >= this.groups)
      throw new Error("Cannot upcast along reduction axis");
    const length = this.st.shape[axis];
    if (length % amount !== 0)
      throw new Error(`Upcast by ${amount} on axis length ${length}`);
    // NOTE: inside permute([...]) below, `this.st` still refers to the
    // pre-reshape tracker (assignment happens after the whole right-hand side
    // is evaluated), so `this.st.shape.length + 1` is the new rank.
    this.st = this.st.reshape([
      ...this.st.shape.slice(0, axis),
      length / amount,
      amount,
      ...this.st.shape.slice(axis + 1)
    ]).permute([
      ...range(axis + 1),
      ...range(axis + 2, this.st.shape.length + 1),
      axis + 1
    ]);
    this.outputSt = this.outputSt.reshape([
      ...this.outputSt.shape.slice(0, axis),
      length / amount,
      amount,
      ...this.outputSt.shape.slice(axis + 1)
    ]).permute([
      ...range(axis + 1),
      ...range(axis + 2, this.outputSt.shape.length + 1),
      axis + 1
    ]);
  }
  // Move `amount` elements of reduction axis `axis` into the unroll region
  // [unroll, upcast). The whole axis moves if `amount` equals its length;
  // otherwise an `amount`-sized inner part is split off.
  applyUnroll(axis, amount) {
    if (axis < this.groups) throw new Error("Cannot unroll non-reduce axis");
    if (axis >= this.unroll) throw new Error("Axis already unrolled");
    const length = this.st.shape[axis];
    if (length % amount !== 0)
      throw new Error(`Unroll by ${amount} on axis length ${length}`);
    if (length === amount) {
      // Move the whole axis just before the upcast region, then shrink the
      // region(s) it left so the unroll region absorbs it.
      this.st = this.st.permute([
        ...range(axis),
        ...range(axis + 1, this.upcast),
        axis,
        ...range(this.upcast, this.st.shape.length)
      ]);
      if (axis < this.reduce) this.reduce--;
      this.unroll--;
    } else {
      // As in applyUpcast, `this.st.shape.length + 1` below is evaluated on
      // the pre-reshape tracker and equals the new rank.
      this.st = this.st.reshape([
        ...this.st.shape.slice(0, axis),
        length / amount,
        amount,
        ...this.st.shape.slice(axis + 1)
      ]).permute([
        ...range(axis + 1),
        ...range(axis + 2, this.upcast + 1),
        // Move to just before upcast
        axis + 1,
        ...range(this.upcast + 1, this.st.shape.length + 1)
      ]);
      this.upcast++;
    }
  }
};
1599
/**
 * Untuned lowering: one thread per output element (`gidx`) and, for
 * reductions, a single loop index `ridx` over the reduced elements. Every
 * GlobalView node is replaced by a masked load built with `accessorGlobal`.
 */
function tuneNullopt(kernel) {
  const { size, reduction } = kernel;
  const vars = { gidx: AluExp.special("int32" /* Int32 */, "gidx", size) };
  if (reduction) {
    vars.ridx = AluExp.special("int32" /* Int32 */, "ridx", reduction.size);
  }
  // Rewrite callback: returning undefined leaves a node unchanged.
  const lowerGlobalView = (node) => {
    if (node.op !== "GlobalView" /* GlobalView */) return undefined;
    const gid = node.arg[0];
    const st = node.arg[1];
    return accessorGlobal(node.dtype, gid, st, node.src);
  };
  const lowered = kernel.exp.rewrite(lowerGlobalView).substitute(vars).simplify();
  return {
    exp: lowered,
    outputIdxExp: AluExp.special("int32" /* Int32 */, "gidx", size),
    threadCount: size,
    size: { reduce: reduction ? reduction.size : 0 }
  };
}
1619
/**
 * Lower a kernel for WebGPU, restructuring reductions to expose more work per
 * thread. Falls back to `tuneNullopt` for non-reductions, for kernels that use
 * GlobalIndex directly or have no GlobalView, and when the GlobalView indices
 * are not the plain `(unravel(gidx), ridx)` pattern.
 *
 * Steps:
 *  1. While at least 1024 output threads remain, upcast (by 3 or 4) an
 *     output axis that some input broadcasts (stride 0) and that all
 *     composed trackers stride positively over the unroll region. Candidates
 *     are ranked with `lexCompare` on [nonzero strides, total stride, axis,
 *     amount].
 *  2. On Chrome user agents, unroll the reduction axis for small upcast
 *     regions (NOTE(review): browser-specific heuristic; rationale not shown).
 *  3. Localize (8 or 4) each upcast axis, in ascending axis order.
 *  4. Rewrite GlobalViews through the tuned tracker and build the output
 *     index expression.
 *
 * @throws {Error} On internal invariant violations (shape/size mismatches).
 */
function tuneWebgpu(kernel) {
  const { exp, reduction } = kernel;
  if (!reduction) return tuneNullopt(kernel);
  const globalIndexes = exp.collect((exp2) => exp2.op === "GlobalIndex" /* GlobalIndex */);
  if (globalIndexes.length > 0) {
    if (DEBUG >= 4) console.log("Tuning: Found GlobalIndex ops, skipping opt.");
    return tuneNullopt(kernel);
  }
  const globalViews = exp.collect((exp2) => exp2.op === "GlobalView" /* GlobalView */);
  if (globalViews.length === 0) {
    if (DEBUG >= 4) console.info("Tuning: No GlobalView ops found in kernel.");
    return tuneNullopt(kernel);
  }
  const shape = globalViews[0].arg[1].shape;
  // Every GlobalView must be indexed as (output coords from gidx, ridx).
  const expectedSrc = [
    ...unravelAlu(shape.slice(0, -1), AluVar.gidx),
    AluVar.ridx
  ].map((e) => e.simplify());
  for (const gv of globalViews) {
    if (!gv.src.length || !deepEqual(gv.src, expectedSrc)) {
      if (DEBUG >= 4)
        console.info("Tuning: GlobalView src[] not consistent with reduction.");
      return tuneNullopt(kernel);
    }
  }
  if (shape[shape.length - 1] !== reduction.size)
    throw new Error("Invariant violation: shape doesn't match reduction size.");
  const sts = globalViews.map((gv) => gv.arg[1]);
  for (const st of sts) {
    if (!deepEqual(st.shape, shape))
      throw new Error("Invariant violation: GlobalView shape mismatch");
  }
  const dim = new TuneDims(shape);
  const upcastedAxis = /* @__PURE__ */ new Set();
  // Step 1: upcast output axes while there are plenty of threads.
  while (prod(dim.st.shape.slice(0, dim.groups)) >= 1024) {
    const choices = [];
    const composedSts = sts.map((st) => st.compose(dim.st));
    for (let axis = 0; axis < dim.groups; axis++) {
      for (const amount of [3, 4]) {
        if (!upcastedAxis.has(axis) && dim.st.shape[axis] % amount === 0 && composedSts.some(
          (st) => st.lastStrides[axis] === 0 && st.lastStrides.slice(dim.unroll).every((stride) => stride > 0)
        )) {
          let nonzeroStrides = 0;
          let totalStrides = 0;
          for (const st of composedSts) {
            nonzeroStrides += st.lastStrides[axis] > 0 ? 1 : 0;
            totalStrides += st.lastStrides[axis];
          }
          choices.push([nonzeroStrides, totalStrides, axis, amount]);
        }
      }
    }
    if (choices.length > 0) {
      choices.sort(lexCompare);
      dim.applyUpcast(choices[0][2], choices[0][3]);
      upcastedAxis.add(choices[0][2]);
    } else {
      break;
    }
  }
  // Step 2: Chrome-only unrolling heuristic. `navigator` is checked first so
  // this does not throw where it is undefined.
  const isChrome = typeof navigator !== "undefined" && /chrome/i.test(navigator.userAgent);
  if (isChrome && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
    const s = dim.st.shape[dim.unroll - 1];
    if (s <= 32) {
      dim.applyUnroll(dim.reduce, s);
    } else {
      for (const splits of [4]) {
        if (s % splits === 0) {
          dim.applyUnroll(dim.unroll - 1, splits);
          break;
        }
      }
    }
  }
  // Step 3: localize upcast axes. Sort numerically; the default sort compares
  // as strings and would visit axis 10 before axis 2.
  for (const ax of Array.from(upcastedAxis).sort((a, b) => a - b)) {
    const s = dim.st.shape[ax];
    for (const amount of [8, 4]) {
      if (s % amount === 0) {
        dim.applyLocal(ax, amount);
        break;
      }
    }
  }
  // Step 4: one index expression per axis of the tuned tracker, region by region.
  const indices = [];
  const addIndices = (s, exp2) => {
    if (s.length === 0) return;
    else if (s.length === 1) indices.push(exp2);
    else indices.push(...unravelAlu(s, exp2));
  };
  if (0 < dim.groups) {
    const s = dim.st.shape.slice(0, dim.groups);
    addIndices(s, AluExp.special("int32" /* Int32 */, "gidx", prod(s)));
  }
  if (dim.groups < dim.reduce) {
    const s = dim.st.shape.slice(dim.groups, dim.reduce);
    addIndices(s, AluExp.special("int32" /* Int32 */, "group", prod(s)));
  }
  if (dim.reduce <= dim.unroll) {
    const s = dim.st.shape.slice(dim.reduce, dim.unroll);
    addIndices(s, AluExp.special("int32" /* Int32 */, "ridx", prod(s)));
  }
  if (dim.unroll < dim.upcast) {
    const s = dim.st.shape.slice(dim.unroll, dim.upcast);
    addIndices(s, AluVar.unroll);
  }
  if (dim.upcast < dim.end) {
    const s = dim.st.shape.slice(dim.upcast);
    addIndices(s, AluVar.upcast);
  }
  const newExp = exp.rewrite((exp2) => {
    if (exp2.op === "GlobalView" /* GlobalView */) {
      const gid = exp2.arg[0];
      const st = exp2.arg[1];
      return accessorGlobal(exp2.dtype, gid, st.compose(dim.st), indices);
    }
  });
  const outputGidx = dim.outputSt.shape.slice(0, dim.groups);
  const outputUpcast = dim.outputSt.shape.slice(dim.groups);
  const [outputIdxExp] = dim.outputSt.toAluExp([
    ...unravelAlu(
      outputGidx,
      AluExp.special("int32" /* Int32 */, "gidx", prod(outputGidx))
    ),
    // Needs later substitution.
    ...unravelAlu(outputUpcast, AluVar.upcast)
  ]);
  if (prod(dim.st.shape.slice(dim.groups, dim.upcast)) !== reduction.size) {
    throw new Error(
      `Invariant violation: reduction size ${reduction.size} does not match tuned dims ${JSON.stringify(dim.st.shape.slice(dim.groups, dim.upcast))}`
    );
  }
  const size = {
    groups: prod(dim.st.shape.slice(dim.groups, dim.reduce)),
    reduce: prod(dim.st.shape.slice(dim.reduce, dim.unroll)),
    unroll: prod(dim.st.shape.slice(dim.unroll, dim.upcast)),
    upcast: prod(dim.st.shape.slice(dim.upcast))
  };
  return {
    exp: newExp.simplify(),
    outputIdxExp: outputIdxExp.simplify(),
    threadCount: kernel.size / size.upcast * size.groups,
    size
  };
}
1762
-
1763
- // src/backend/cpu.ts
1764
/**
 * Reference backend that keeps buffers as `ArrayBuffer`s in JS memory and
 * runs kernels by evaluating their expressions one output element at a time.
 *
 * Buffers are addressed by integer slots and reference counted. `malloc`
 * returns a slot holding one reference, and the buffer is dropped once
 * `decRef` brings the count to zero.
 */
var CPUBackend = class {
  type = "cpu";
  maxArgs = Infinity;
  #buffers;
  #nextSlot;
  constructor() {
    this.#buffers = /* @__PURE__ */ new Map();
    this.#nextSlot = 1;
  }
  /**
   * Allocate a zero-filled buffer of `size` bytes, optionally initialized
   * with a byte-for-byte copy of `initialData`.
   *
   * @param {number} size - Buffer size in bytes.
   * @param {ArrayBuffer | ArrayBufferView} [initialData] - Initial contents;
   *   must be exactly `size` bytes long.
   * @returns {number} Slot of the new buffer.
   * @throws {Error} If `initialData.byteLength !== size`.
   */
  malloc(size, initialData) {
    const buffer = new ArrayBuffer(size);
    if (initialData) {
      if (initialData.byteLength !== size) {
        throw new Error("initialData size does not match buffer size");
      }
      // View a typed array's underlying bytes. `new Uint8Array(typedArray)`
      // would convert element-wise, truncating e.g. Float32Array values.
      const bytes = ArrayBuffer.isView(initialData) ? new Uint8Array(initialData.buffer, initialData.byteOffset, initialData.byteLength) : new Uint8Array(initialData);
      new Uint8Array(buffer).set(bytes);
    }
    const slot = this.#nextSlot++;
    this.#buffers.set(slot, { buffer, ref: 1 });
    return slot;
  }
  /** Add a reference to `slot`. @throws {SlotError} For an unknown slot. */
  incRef(slot) {
    const buffer = this.#buffers.get(slot);
    if (!buffer) throw new SlotError(slot);
    buffer.ref++;
  }
  /**
   * Drop a reference to `slot`, freeing the buffer when none remain.
   * @throws {SlotError} For an unknown slot.
   */
  decRef(slot) {
    const buffer = this.#buffers.get(slot);
    if (!buffer) throw new SlotError(slot);
    buffer.ref--;
    if (buffer.ref === 0) {
      this.#buffers.delete(slot);
    }
  }
  /** Async wrapper around `readSync`. */
  async read(slot, start, count) {
    return this.readSync(slot, start, count);
  }
  /**
   * Copy `count` bytes starting at byte `start` into a new ArrayBuffer.
   * `start` defaults to 0 and `count` to the rest of the buffer.
   * @throws {SlotError} For an unknown slot.
   */
  readSync(slot, start, count) {
    const buffer = this.#getBuffer(slot);
    if (start === void 0) start = 0;
    if (count === void 0) count = buffer.byteLength - start;
    return buffer.slice(start, start + count);
  }
  /** Async wrapper around `prepareSync`. */
  async prepare(kernel) {
    return this.prepareSync(kernel);
  }
  /** Wrap `kernel` in an Executable; the CPU backend needs no compilation. */
  prepareSync(kernel) {
    return new Executable(kernel, void 0);
  }
  /**
   * Run a prepared kernel. Input slot `i` backs global buffer id `i` in the
   * kernel expression. `outputs[0]` receives `kernel.size` values, one per
   * output element; for reductions, each is the folded result passed through
   * `reduction.fusion`. float32 buffers are viewed as Float32Array, all other
   * dtypes as Int32Array.
   */
  dispatch({ kernel }, inputs, outputs) {
    const { exp } = tuneNullopt(kernel);
    const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
    const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
    // Global buffer id -> dtype, for the buffers the expression actually reads.
    const usedArgs = new Map(
      exp.collect((exp2) => exp2.op === "GlobalIndex" /* GlobalIndex */).map((exp2) => [exp2.arg, exp2.dtype])
    );
    const inputArrays = inputBuffers.map((buf, i) => {
      const dtype = usedArgs.get(i);
      if (!dtype) return null;
      return dtype === "float32" /* Float32 */ ? new Float32Array(buf) : new Int32Array(buf);
    });
    const outputArray = exp.dtype === "float32" /* Float32 */ ? new Float32Array(outputBuffers[0]) : new Int32Array(outputBuffers[0]);
    const globals = (gidx, bufidx) => inputArrays[gidx][bufidx];
    if (!kernel.reduction) {
      for (let i = 0; i < kernel.size; i++) {
        outputArray[i] = exp.evaluate({ gidx: i }, globals);
      }
    } else {
      for (let i = 0; i < kernel.size; i++) {
        let acc = kernel.reduction.identity;
        for (let j = 0; j < kernel.reduction.size; j++) {
          const item = exp.evaluate({ gidx: i, ridx: j }, globals);
          acc = kernel.reduction.evaluate(acc, item);
        }
        outputArray[i] = kernel.reduction.fusion.evaluate({ acc });
      }
    }
  }
  #getBuffer(slot) {
    const buffer = this.#buffers.get(slot);
    if (!buffer) throw new SlotError(slot);
    return buffer.buffer;
  }
};
1848
-
1849
- // src/backend.ts
1850
- var devices = ["cpu", "webgpu"];
1851
- var defaultBackend = "cpu";
1852
- var initializedBackends = /* @__PURE__ */ new Map();
1853
- initializedBackends.set("cpu", new CPUBackend());
1854
- function setDevice(device) {
1855
- if (initializedBackends.has(device)) {
1856
- defaultBackend = device;
1857
- } else {
1858
- throw new Error(`Backend not initialized: ${device}`);
1859
- }
1860
- }
1861
- async function init(...devicesToInit) {
1862
- if (devicesToInit.length === 0) {
1863
- devicesToInit = devices;
1864
- }
1865
- const promises = [];
1866
- for (const device of new Set(devicesToInit)) {
1867
- if (!initializedBackends.has(device)) {
1868
- promises.push(
1869
- (async () => {
1870
- const backend = await createBackend(device);
1871
- if (backend) {
1872
- initializedBackends.set(device, backend);
1873
- }
1874
- })()
1875
- );
1876
- }
1877
- }
1878
- await Promise.all(promises);
1879
- return Array.from(initializedBackends.keys());
1880
- }
1881
- async function createBackend(device) {
1882
- if (device === "cpu") {
1883
- return new CPUBackend();
1884
- } else if (device === "webgpu") {
1885
- if (!navigator.gpu) return null;
1886
- const adapter = await navigator.gpu.requestAdapter({
1887
- powerPreference: "high-performance"
1888
- });
1889
- if (!adapter) return null;
1890
- const { WebGPUBackend } = await import("./webgpu-QNXDOQZP.js");
1891
- const importantLimits = [
1892
- "maxBufferSize",
1893
- "maxComputeInvocationsPerWorkgroup",
1894
- "maxComputeWorkgroupSizeX",
1895
- // All of our workgroups use X or Y.
1896
- "maxComputeWorkgroupSizeY",
1897
- "maxComputeWorkgroupSizeZ",
1898
- "maxComputeWorkgroupStorageSize",
1899
- "maxComputeWorkgroupsPerDimension",
1900
- // Grid size limited to 65535 due to AMD storage in u16.
1901
- "maxStorageBufferBindingSize",
1902
- "maxStorageBuffersPerShaderStage",
1903
- "maxStorageTexturesPerShaderStage"
1904
- ];
1905
- try {
1906
- const device2 = await adapter.requestDevice({
1907
- requiredLimits: Object.fromEntries(
1908
- importantLimits.map((feature) => [feature, adapter.limits[feature]])
1909
- )
1910
- });
1911
- return new WebGPUBackend(device2);
1912
- } catch (error) {
1913
- console.error("Unexpected error requesting WebGPU device:", error);
1914
- return null;
1915
- }
1916
- } else {
1917
- throw new Error(`Backend not found: ${device}`);
1918
- }
1919
- }
1920
- function getBackend(device) {
1921
- device = device ?? defaultBackend;
1922
- const backend = initializedBackends.get(device);
1923
- if (!backend) {
1924
- throw new Error(`${device} backend not ready, call init() first`);
1925
- }
1926
- return backend;
1927
- }
1928
- var Executable = class {
1929
- constructor(kernel, data) {
1930
- this.kernel = kernel;
1931
- this.data = data;
1932
- }
1933
- };
1934
- var SlotError = class extends Error {
1935
- constructor(slot) {
1936
- super(`Used a buffer that is invalid or already freed: ${slot}`);
1937
- }
1938
- };
1939
-
1940
- export {
1941
- __export,
1942
- __using,
1943
- __callDispose,
1944
- DEBUG,
1945
- unzip2,
1946
- zip,
1947
- rep,
1948
- prod,
1949
- deepEqual,
1950
- partitionList,
1951
- range,
1952
- isPermutation,
1953
- invertPermutation,
1954
- toposort,
1955
- findPow2,
1956
- recursiveFlatten,
1957
- strip1,
1958
- FpHash,
1959
- runWithCache,
1960
- DType,
1961
- isFloatDtype,
1962
- AluExp,
1963
- AluGroup,
1964
- AluVar,
1965
- Kernel,
1966
- Reduction,
1967
- accessorGlobal,
1968
- accessorAluExp,
1969
- unravelAlu,
1970
- ShapeTracker,
1971
- tuneWebgpu,
1972
- devices,
1973
- setDevice,
1974
- init,
1975
- getBackend,
1976
- Executable,
1977
- SlotError
1978
- };