@jax-js/jax 0.0.2 → 0.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +9 -8
- package/dist/{backend-1eVbAoaV.js → backend-BqDtPGaR.js} +1869 -86
- package/dist/{backend-BK21PBVP.cjs → backend-D2C4MJRP.cjs} +1892 -85
- package/dist/index.cjs +737 -118
- package/dist/index.d.cts +247 -44
- package/dist/index.d.ts +247 -44
- package/dist/index.js +726 -114
- package/dist/{webgpu-JVpVad6g.js → webgpu-CNg9JGva.js} +54 -33
- package/dist/{webgpu-c5Fe8nx8.cjs → webgpu-fqhx41TC.cjs} +54 -33
- package/package.json +7 -6
|
@@ -49,6 +49,10 @@ function unzip2(pairs) {
|
|
|
49
49
|
function zip(xs, ys) {
|
|
50
50
|
return xs.map((x, i) => [x, ys[i]]);
|
|
51
51
|
}
|
|
52
|
+
function zipn(...arrays) {
|
|
53
|
+
const minLength = Math.min(...arrays.map((x) => x.length));
|
|
54
|
+
return Array.from({ length: minLength }, (_, i) => arrays.map((arr) => arr[i]));
|
|
55
|
+
}
|
|
52
56
|
function rep(length, value) {
|
|
53
57
|
if (value instanceof Function) return new Array(length).fill(0).map((_, i) => value(i));
|
|
54
58
|
return new Array(length).fill(value);
|
|
@@ -56,6 +60,11 @@ function rep(length, value) {
|
|
|
56
60
|
function prod(arr) {
|
|
57
61
|
return arr.reduce((acc, x) => acc * x, 1);
|
|
58
62
|
}
|
|
63
|
+
function gcd(...values) {
|
|
64
|
+
let a = 0;
|
|
65
|
+
for (let b of values) while (b !== 0) [a, b] = [b, a % b];
|
|
66
|
+
return Math.abs(a);
|
|
67
|
+
}
|
|
59
68
|
/** Shorthand for integer division, like in Python. */
|
|
60
69
|
function intdiv(a, b) {
|
|
61
70
|
return Math.floor(a / b);
|
|
@@ -73,6 +82,11 @@ function deepEqual(a, b) {
|
|
|
73
82
|
for (const key of Object.keys(a)) if (!deepEqual(a[key], b[key])) return false;
|
|
74
83
|
return true;
|
|
75
84
|
}
|
|
85
|
+
function union(...sets) {
|
|
86
|
+
const result = /* @__PURE__ */ new Set();
|
|
87
|
+
for (const s of sets) if (s) for (const x of s) result.add(x);
|
|
88
|
+
return result;
|
|
89
|
+
}
|
|
76
90
|
/** Splits the list based on a condition, `false` first then `true`. */
|
|
77
91
|
function partitionList(which, array) {
|
|
78
92
|
const falseList = [];
|
|
@@ -217,12 +231,13 @@ function runWithCache(cache, key, thunk) {
|
|
|
217
231
|
|
|
218
232
|
//#endregion
|
|
219
233
|
//#region src/alu.ts
|
|
234
|
+
/** A numerical data type for array contents. */
|
|
220
235
|
let DType = /* @__PURE__ */ function(DType$1) {
|
|
221
236
|
DType$1["Float32"] = "float32";
|
|
222
237
|
DType$1["Int32"] = "int32";
|
|
223
238
|
DType$1["Uint32"] = "uint32";
|
|
224
239
|
DType$1["Bool"] = "bool";
|
|
225
|
-
DType$1["
|
|
240
|
+
DType$1["Float16"] = "float16";
|
|
226
241
|
return DType$1;
|
|
227
242
|
}({});
|
|
228
243
|
const byteWidth = (dtype) => {
|
|
@@ -231,17 +246,30 @@ const byteWidth = (dtype) => {
|
|
|
231
246
|
case DType.Int32:
|
|
232
247
|
case DType.Uint32:
|
|
233
248
|
case DType.Bool: return 4;
|
|
234
|
-
case DType.
|
|
249
|
+
case DType.Float16: return 2;
|
|
235
250
|
default: throw new TypeError(`Unknown dtype: ${dtype}`);
|
|
236
251
|
}
|
|
237
252
|
};
|
|
238
|
-
const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.
|
|
253
|
+
const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16;
|
|
239
254
|
function dtypedArray(dtype, data) {
|
|
255
|
+
const { buffer, byteLength, byteOffset } = data;
|
|
256
|
+
const length = byteLength / byteWidth(dtype);
|
|
257
|
+
switch (dtype) {
|
|
258
|
+
case DType.Float32: return new Float32Array(buffer, byteOffset, length);
|
|
259
|
+
case DType.Int32:
|
|
260
|
+
case DType.Bool: return new Int32Array(buffer, byteOffset, length);
|
|
261
|
+
case DType.Uint32: return new Uint32Array(buffer, byteOffset, length);
|
|
262
|
+
case DType.Float16: return new Float16Array(buffer, byteOffset, length);
|
|
263
|
+
default: throw new Error(`Unimplemented dtype: ${dtype}`);
|
|
264
|
+
}
|
|
265
|
+
}
|
|
266
|
+
function dtypedJsArray(dtype, data) {
|
|
240
267
|
switch (dtype) {
|
|
241
268
|
case DType.Float32: return new Float32Array(data);
|
|
242
|
-
case DType.Int32:
|
|
243
|
-
case DType.Uint32: return new Uint32Array(data);
|
|
269
|
+
case DType.Int32:
|
|
244
270
|
case DType.Bool: return new Int32Array(data);
|
|
271
|
+
case DType.Uint32: return new Uint32Array(data);
|
|
272
|
+
case DType.Float16: return new Float16Array(data);
|
|
245
273
|
default: throw new Error(`Unimplemented dtype: ${dtype}`);
|
|
246
274
|
}
|
|
247
275
|
}
|
|
@@ -298,6 +326,9 @@ var AluExp = class AluExp {
|
|
|
298
326
|
static log(a) {
|
|
299
327
|
return new AluExp(AluOp.Log, a.dtype, [a]);
|
|
300
328
|
}
|
|
329
|
+
static sqrt(a) {
|
|
330
|
+
return new AluExp(AluOp.Sqrt, a.dtype, [a]);
|
|
331
|
+
}
|
|
301
332
|
static reciprocal(a) {
|
|
302
333
|
return new AluExp(AluOp.Reciprocal, a.dtype, [a]);
|
|
303
334
|
}
|
|
@@ -343,24 +374,27 @@ var AluExp = class AluExp {
|
|
|
343
374
|
static variable(dtype, name) {
|
|
344
375
|
return new AluExp(AluOp.Variable, dtype, [], name);
|
|
345
376
|
}
|
|
346
|
-
static globalIndex(dtype, gid, bufidx) {
|
|
347
|
-
return new AluExp(AluOp.GlobalIndex, dtype, [bufidx], gid);
|
|
377
|
+
static globalIndex(dtype, gid, len, bufidx) {
|
|
378
|
+
return new AluExp(AluOp.GlobalIndex, dtype, [bufidx], [gid, len]);
|
|
348
379
|
}
|
|
349
380
|
static globalView(dtype, gid, st, indices) {
|
|
350
381
|
return new AluExp(AluOp.GlobalView, dtype, indices, [gid, st]);
|
|
351
382
|
}
|
|
383
|
+
static f32(value) {
|
|
384
|
+
return AluExp.const(DType.Float32, value);
|
|
385
|
+
}
|
|
352
386
|
static i32(value) {
|
|
353
387
|
return AluExp.const(DType.Int32, value);
|
|
354
388
|
}
|
|
355
389
|
static u32(value) {
|
|
356
390
|
return AluExp.const(DType.Uint32, value);
|
|
357
391
|
}
|
|
358
|
-
static f32(value) {
|
|
359
|
-
return AluExp.const(DType.Float32, value);
|
|
360
|
-
}
|
|
361
392
|
static bool(value) {
|
|
362
393
|
return AluExp.const(DType.Bool, Number(value));
|
|
363
394
|
}
|
|
395
|
+
static f16(value) {
|
|
396
|
+
return AluExp.const(DType.Float16, value);
|
|
397
|
+
}
|
|
364
398
|
not() {
|
|
365
399
|
if (this.dtype !== DType.Bool) throw new Error("not() can only be called on boolean expressions");
|
|
366
400
|
return AluExp.cmpne(this, AluExp.const(DType.Bool, true));
|
|
@@ -390,9 +424,9 @@ var AluExp = class AluExp {
|
|
|
390
424
|
reindexGids(gidMap) {
|
|
391
425
|
return this.rewrite((exp) => {
|
|
392
426
|
if (exp.op === AluOp.GlobalIndex) {
|
|
393
|
-
const gid = exp.arg;
|
|
427
|
+
const [gid, len] = exp.arg;
|
|
394
428
|
const newGid = gidMap.get(gid);
|
|
395
|
-
if (newGid !== void 0 && newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, exp.src[0]);
|
|
429
|
+
if (newGid !== void 0 && newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
|
|
396
430
|
} else if (exp.op === AluOp.GlobalView) {
|
|
397
431
|
const gid = exp.arg[0];
|
|
398
432
|
const newGid = gidMap.get(gid);
|
|
@@ -421,17 +455,16 @@ var AluExp = class AluExp {
|
|
|
421
455
|
case AluOp.Sub:
|
|
422
456
|
ret = [src[0].min - src[1].max, src[0].max - src[1].min];
|
|
423
457
|
break;
|
|
424
|
-
case AluOp.Mul:
|
|
458
|
+
case AluOp.Mul:
|
|
425
459
|
ret = minMax4((a, b) => a * b);
|
|
426
460
|
break;
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
ret = minMax4((a, b) => Math.floor(a / b));
|
|
461
|
+
case AluOp.Idiv:
|
|
462
|
+
ret = minMax4((a, b) => Math.trunc(a / b));
|
|
430
463
|
break;
|
|
431
|
-
}
|
|
432
464
|
case AluOp.Mod: {
|
|
433
465
|
let divisorRange = src[1].#computeRange();
|
|
434
466
|
if (divisorRange[0] <= 0 && divisorRange[1] >= 0) divisorRange = [0, Math.max(-divisorRange[0], divisorRange[1])];
|
|
467
|
+
if (divisorRange[1] < 0) divisorRange = [-divisorRange[1], -divisorRange[0]];
|
|
435
468
|
const maxDivisor = isFloatDtype(this.dtype) ? divisorRange[1] : divisorRange[1] - 1;
|
|
436
469
|
ret = [clamp(src[0].min, -maxDivisor, 0), clamp(src[0].max, 0, maxDivisor)];
|
|
437
470
|
break;
|
|
@@ -454,23 +487,31 @@ var AluExp = class AluExp {
|
|
|
454
487
|
case AluOp.Log:
|
|
455
488
|
ret = [Math.log(src[0].min), Math.log(src[0].max)];
|
|
456
489
|
break;
|
|
490
|
+
case AluOp.Sqrt:
|
|
491
|
+
ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
|
|
492
|
+
break;
|
|
457
493
|
case AluOp.Reciprocal:
|
|
458
494
|
if (src[0].min <= 0 && src[0].max >= 0) return [-Infinity, Infinity];
|
|
459
495
|
ret = [1 / src[0].max, 1 / src[0].min];
|
|
460
496
|
break;
|
|
461
|
-
case AluOp.Cast:
|
|
497
|
+
case AluOp.Cast: {
|
|
498
|
+
const wasFloat = isFloatDtype(src[0].dtype);
|
|
499
|
+
const bounded = Number.isFinite(src[0].min) && Number.isFinite(src[0].max);
|
|
462
500
|
if (this.dtype === DType.Bool) {
|
|
463
501
|
const canBeZero = src[0].min <= 0 && src[0].max >= 0;
|
|
464
502
|
const mustBeZero = src[0].min === 0 && src[0].max === 0;
|
|
465
503
|
ret = mustBeZero ? [0, 0] : canBeZero ? [0, 1] : [1, 1];
|
|
466
|
-
} else if (this.dtype === DType.Int32)
|
|
467
|
-
|
|
468
|
-
const
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
504
|
+
} else if (this.dtype === DType.Int32) {
|
|
505
|
+
const a = wasFloat ? clamp(src[0].min, -2147483648, 2147483647) | 0 : src[0].min | 0;
|
|
506
|
+
const b = wasFloat ? clamp(src[0].max, -2147483648, 2147483647) | 0 : src[0].max | 0;
|
|
507
|
+
ret = bounded && a <= b ? [a, b] : [-Infinity, Infinity];
|
|
508
|
+
} else if (this.dtype === DType.Uint32) {
|
|
509
|
+
const a = wasFloat ? clamp(src[0].min, 0, 4294967295) >>> 0 : src[0].min >>> 0;
|
|
510
|
+
const b = wasFloat ? clamp(src[0].max, 0, 4294967295) >>> 0 : src[0].max >>> 0;
|
|
511
|
+
ret = bounded && a <= b ? [a, b] : [0, Infinity];
|
|
472
512
|
} else ret = [src[0].min, src[0].max];
|
|
473
513
|
break;
|
|
514
|
+
}
|
|
474
515
|
case AluOp.Cmplt:
|
|
475
516
|
ret = [0, 1];
|
|
476
517
|
break;
|
|
@@ -493,6 +534,7 @@ var AluExp = class AluExp {
|
|
|
493
534
|
ret[0] = clamp(ret[0], 0, 1);
|
|
494
535
|
ret[1] = clamp(ret[1], 0, 1);
|
|
495
536
|
}
|
|
537
|
+
if (this.dtype === DType.Uint32) ret[0] = Math.max(0, ret[0]);
|
|
496
538
|
this.#range = ret;
|
|
497
539
|
return ret;
|
|
498
540
|
}
|
|
@@ -502,10 +544,51 @@ var AluExp = class AluExp {
|
|
|
502
544
|
get max() {
|
|
503
545
|
return this.#computeRange()[1];
|
|
504
546
|
}
|
|
547
|
+
/** Largest known integer that divides self. */
|
|
548
|
+
constFactor() {
|
|
549
|
+
if (this.op === AluOp.Const) return Math.abs(this.arg);
|
|
550
|
+
if (this.op === AluOp.Add) return gcd(this.src[0].constFactor(), this.src[1].constFactor());
|
|
551
|
+
if (this.op === AluOp.Mul) {
|
|
552
|
+
if (this.src[0].op === AluOp.Const) return Math.abs(this.src[0].arg);
|
|
553
|
+
if (this.src[1].op === AluOp.Const) return Math.abs(this.src[1].arg);
|
|
554
|
+
}
|
|
555
|
+
return 1;
|
|
556
|
+
}
|
|
557
|
+
/**
|
|
558
|
+
* Checks if divisible by an integer v and returns the quotient if it is, or
|
|
559
|
+
* `null` if it's not divisible.
|
|
560
|
+
*/
|
|
561
|
+
divides(v) {
|
|
562
|
+
if (v === 1) return this;
|
|
563
|
+
if (this.op === AluOp.Const && this.arg % v === 0) return AluExp.const(this.dtype, this.arg / v);
|
|
564
|
+
if (this.op === AluOp.Add) {
|
|
565
|
+
const a = this.src[0].divides(v);
|
|
566
|
+
if (a !== null) {
|
|
567
|
+
const b = this.src[1].divides(v);
|
|
568
|
+
if (b !== null) return AluExp.add(a, b);
|
|
569
|
+
}
|
|
570
|
+
}
|
|
571
|
+
if (this.op === AluOp.Mul) {
|
|
572
|
+
const a = this.src[0].divides(v);
|
|
573
|
+
if (a !== null) return AluExp.mul(a, this.src[1]);
|
|
574
|
+
const b = this.src[1].divides(v);
|
|
575
|
+
if (b !== null) return AluExp.mul(this.src[0], b);
|
|
576
|
+
}
|
|
577
|
+
return null;
|
|
578
|
+
}
|
|
505
579
|
#isConstInt() {
|
|
506
580
|
return this.op === AluOp.Const && (this.dtype === DType.Int32 || this.dtype === DType.Uint32);
|
|
507
581
|
}
|
|
508
582
|
/**
|
|
583
|
+
* Get all expressions by deeply matching an operation.
|
|
584
|
+
*
|
|
585
|
+
* For example: `((2+(3*5))+4).splitOp(+) -> [2,(3*5),4]`.
|
|
586
|
+
*/
|
|
587
|
+
*splitOp(sep) {
|
|
588
|
+
if (this.op === sep) for (const src of this.src) yield* src.splitOp(sep);
|
|
589
|
+
else yield this;
|
|
590
|
+
}
|
|
591
|
+
/**
|
|
509
592
|
* Simplify the expression by replacing any known patterns and deduping
|
|
510
593
|
* identical subexpressions.
|
|
511
594
|
*/
|
|
@@ -550,7 +633,24 @@ var AluExp = class AluExp {
|
|
|
550
633
|
if (a.op === AluOp.Const && a.arg === -1) return new AluExp(opNeg, this.dtype, [src[0], b]);
|
|
551
634
|
else if (b.op === AluOp.Const && b.arg === -1) return new AluExp(opNeg, this.dtype, [src[0], a]);
|
|
552
635
|
}
|
|
636
|
+
if (op === AluOp.Where && src.slice(1).every((s, i) => s.op === AluOp.Const && s.arg === 1 - i)) return AluExp.cast(this.dtype, src[0]);
|
|
637
|
+
if (op === AluOp.Cmplt) {
|
|
638
|
+
if (src[0].min >= src[1].max) return AluExp.const(DType.Bool, false);
|
|
639
|
+
if (src[0].max < src[1].min) return AluExp.const(DType.Bool, true);
|
|
640
|
+
}
|
|
641
|
+
if (op === AluOp.Cmpne) {
|
|
642
|
+
if (src[0].max < src[1].min || src[0].min > src[1].max) return AluExp.const(DType.Bool, true);
|
|
643
|
+
}
|
|
644
|
+
if (op === AluOp.Where) {
|
|
645
|
+
if (src[0].max === 0) return src[2];
|
|
646
|
+
if (src[0].min === 1) return src[1];
|
|
647
|
+
}
|
|
553
648
|
if (op === AluOp.Mod && src[1].op === AluOp.Const && src[0].min >= 0 && src[0].max < src[1].arg) return src[0];
|
|
649
|
+
if (op === AluOp.Mod && src[0].op === AluOp.Mod && src[1].#isConstInt() && src[0].src[1].#isConstInt()) {
|
|
650
|
+
const A = src[0].src[1].arg;
|
|
651
|
+
const B = src[1].arg;
|
|
652
|
+
if (A > 0 && B > 0 && (A % B === 0 || B % A === 0)) return AluExp.mod(src[0].src[0], AluExp.const(this.dtype, Math.min(A, B))).simplify();
|
|
653
|
+
}
|
|
554
654
|
if (op === AluOp.Add && src[0].op === AluOp.Mul && src[0].src[1].#isConstInt() && src[1].op === AluOp.Mod && src[1].src[1].#isConstInt() && src[0].src[1].arg === src[1].src[1].arg) {
|
|
555
655
|
const [mul, mod] = src;
|
|
556
656
|
const check = (exp) => {
|
|
@@ -570,7 +670,7 @@ var AluExp = class AluExp {
|
|
|
570
670
|
const A = numer.src[i].arg;
|
|
571
671
|
if (A % B === 0) {
|
|
572
672
|
let ret = numer.src[1 - i];
|
|
573
|
-
if (A / B !== 1) ret = AluExp.mul(ret, AluExp.
|
|
673
|
+
if (A / B !== 1) ret = AluExp.mul(ret, AluExp.const(ret.dtype, A / B));
|
|
574
674
|
return ret.simplify(cache);
|
|
575
675
|
}
|
|
576
676
|
}
|
|
@@ -578,8 +678,8 @@ var AluExp = class AluExp {
|
|
|
578
678
|
const A = numer.src[j].src[i].arg;
|
|
579
679
|
if (A % B === 0) {
|
|
580
680
|
let ret = numer.src[j].src[1 - i];
|
|
581
|
-
if (A / B !== 1) ret = AluExp.mul(ret, AluExp.
|
|
582
|
-
ret = AluExp.add(ret, AluExp.idiv(numer.src[1 - j], B));
|
|
681
|
+
if (A / B !== 1) ret = AluExp.mul(ret, AluExp.const(ret.dtype, A / B));
|
|
682
|
+
ret = AluExp.add(ret, AluExp.idiv(numer.src[1 - j], AluExp.const(ret.dtype, B)));
|
|
583
683
|
return ret.simplify(cache);
|
|
584
684
|
}
|
|
585
685
|
}
|
|
@@ -588,23 +688,81 @@ var AluExp = class AluExp {
|
|
|
588
688
|
if (op === AluOp.Mod && src[1].#isConstInt() && src[1].arg > 0 && src[0].min >= 0) {
|
|
589
689
|
const [numer, denom] = src;
|
|
590
690
|
const B = denom.arg;
|
|
591
|
-
for (let i = 0; i < 2; i++) if (numer.op === AluOp.Add
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
691
|
+
for (let i = 0; i < 2; i++) if (numer.op === AluOp.Add) {
|
|
692
|
+
if (numer.src[i].#isConstInt()) {
|
|
693
|
+
const A = numer.src[i].arg;
|
|
694
|
+
const x = numer.src[1 - i];
|
|
695
|
+
if (A % B === 0 && x.min >= 0) return AluExp.mod(x, denom).simplify(cache);
|
|
696
|
+
}
|
|
697
|
+
for (let j = 0; j < 2; j++) if (numer.src[i].op === AluOp.Mul && numer.src[i].src[j].#isConstInt()) {
|
|
698
|
+
const A = numer.src[i].src[j].arg;
|
|
699
|
+
const x = numer.src[1 - i];
|
|
700
|
+
if (A % B === 0 && x.min >= 0) return AluExp.mod(x, denom).simplify(cache);
|
|
701
|
+
}
|
|
702
|
+
} else if (numer.op === AluOp.Mul) {
|
|
703
|
+
if (numer.src[i].#isConstInt()) {
|
|
704
|
+
const A = numer.src[i].arg;
|
|
705
|
+
if (A % B === 0) return AluExp.const(this.dtype, 0);
|
|
706
|
+
if (A % B === 1) return AluExp.mod(numer.src[1 - i], denom).simplify(cache);
|
|
707
|
+
}
|
|
596
708
|
}
|
|
597
709
|
}
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
710
|
+
const commOps = [
|
|
711
|
+
AluOp.Add,
|
|
712
|
+
AluOp.Mul,
|
|
713
|
+
AluOp.Max,
|
|
714
|
+
AluOp.Min
|
|
715
|
+
];
|
|
716
|
+
if (commOps.includes(op)) {
|
|
717
|
+
const p = (a, b) => new AluExp(op, this.dtype, [a, b]);
|
|
718
|
+
if (src[0].op === AluOp.Const) return p(src[1], src[0]).simplify(cache);
|
|
719
|
+
if (src[0].op === op && src[0].src[1].op === AluOp.Const) if (src[1].op === AluOp.Const) return p(src[0].src[0], p(src[0].src[1], src[1])).simplify(cache);
|
|
720
|
+
else return p(p(src[0].src[0], src[1]), src[0].src[1]).simplify(cache);
|
|
721
|
+
if (src[1].op === op && src[1].src[1].op === AluOp.Const) return p(p(src[0], src[1].src[0]), src[1].src[1]).simplify(cache);
|
|
604
722
|
}
|
|
605
|
-
if (op === AluOp.
|
|
606
|
-
|
|
607
|
-
|
|
723
|
+
if (op === AluOp.Mod || op === AluOp.Idiv && src[1].#isConstInt()) {
|
|
724
|
+
const [x, y] = src;
|
|
725
|
+
{
|
|
726
|
+
const factors = [];
|
|
727
|
+
const terms = [];
|
|
728
|
+
for (const u of x.splitOp(AluOp.Add)) {
|
|
729
|
+
const factor = u.constFactor();
|
|
730
|
+
factors.push(factor);
|
|
731
|
+
terms.push(u.divides(factor));
|
|
732
|
+
}
|
|
733
|
+
const g = gcd(y.arg, ...factors);
|
|
734
|
+
if (g !== 1) {
|
|
735
|
+
let ret = new AluExp(op, this.dtype, [factors.map((f, i) => AluExp.mul(AluExp.const(terms[i].dtype, f / g), terms[i])).reduceRight((a, x$1) => AluExp.add(x$1, a)), AluExp.const(y.dtype, y.arg / g)]);
|
|
736
|
+
if (op === AluOp.Mod) ret = AluExp.mul(ret, AluExp.const(this.dtype, g));
|
|
737
|
+
return ret.simplify(cache);
|
|
738
|
+
}
|
|
739
|
+
}
|
|
740
|
+
if (y.arg > 0) {
|
|
741
|
+
let [xNoConst, constVal] = [x, 0];
|
|
742
|
+
if (x.op === AluOp.Add && x.src[1].op === AluOp.Const) [xNoConst, constVal] = [x.src[0], x.src[1].arg];
|
|
743
|
+
const terms = [];
|
|
744
|
+
const factors = [];
|
|
745
|
+
for (const u of xNoConst.splitOp(AluOp.Add)) {
|
|
746
|
+
const f = u.constFactor();
|
|
747
|
+
const divided = u.divides(f);
|
|
748
|
+
terms.push(divided ?? u);
|
|
749
|
+
factors.push(divided ? f : 1);
|
|
750
|
+
}
|
|
751
|
+
const quotients = factors.map((f) => Math.floor(f / y.arg));
|
|
752
|
+
const remainders = factors.map((f) => f % y.arg);
|
|
753
|
+
const gcdVal = remainders.reduce((g, r) => gcd(g, r), y.arg);
|
|
754
|
+
if (constVal % y.arg !== constVal || gcdVal !== 1 || remainders.some((r, i) => r === 0 || r !== factors[i] && op === AluOp.Mod)) {
|
|
755
|
+
let quo = AluExp.const(x.dtype, Math.floor(constVal / y.arg));
|
|
756
|
+
let rem = AluExp.const(x.dtype, Math.floor(constVal % y.arg / gcdVal));
|
|
757
|
+
for (let i = 0; i < terms.length; i++) if (op === AluOp.Idiv && remainders[i] !== 0) rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(factors[i] / gcdVal)), terms[i]));
|
|
758
|
+
else {
|
|
759
|
+
rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(remainders[i] / gcdVal)), terms[i]));
|
|
760
|
+
quo = AluExp.add(quo, AluExp.mul(AluExp.const(x.dtype, quotients[i]), terms[i]));
|
|
761
|
+
}
|
|
762
|
+
if (!((x.min < 0 || rem.min < 0) && remainders.some((r) => r !== 0))) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
|
|
763
|
+
else return AluExp.add(AluExp.idiv(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal))), quo).simplify(cache);
|
|
764
|
+
}
|
|
765
|
+
}
|
|
608
766
|
}
|
|
609
767
|
const newExp = src.every((s, i) => s === this.src[i]) ? this : new AluExp(op, this.dtype, src, this.arg);
|
|
610
768
|
return newExp;
|
|
@@ -647,12 +805,16 @@ var AluExp = class AluExp {
|
|
|
647
805
|
case AluOp.Cos: return Math.cos(x);
|
|
648
806
|
case AluOp.Exp: return Math.exp(x);
|
|
649
807
|
case AluOp.Log: return Math.log(x);
|
|
808
|
+
case AluOp.Sqrt: return Math.sqrt(x);
|
|
650
809
|
case AluOp.Reciprocal: return 1 / x;
|
|
651
|
-
case AluOp.Cast:
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
810
|
+
case AluOp.Cast: {
|
|
811
|
+
const wasFloat = isFloatDtype(this.src[0].dtype);
|
|
812
|
+
if (this.dtype === DType.Int32) return (wasFloat ? clamp(x, -2147483648, 2147483647) : x) | 0;
|
|
813
|
+
else if (this.dtype === DType.Uint32) return (wasFloat ? clamp(x, 0, 4294967295) : x) >>> 0;
|
|
814
|
+
else if (isFloatDtype(this.dtype)) return x;
|
|
815
|
+
else if (this.dtype === DType.Bool) return Number(Boolean(x));
|
|
816
|
+
else throw new Error(`Unsupported cast to ${this.dtype}`);
|
|
817
|
+
}
|
|
656
818
|
case AluOp.Bitcast: {
|
|
657
819
|
const buf = new ArrayBuffer(byteWidth(this.dtype));
|
|
658
820
|
const view = new DataView(buf);
|
|
@@ -660,10 +822,12 @@ var AluExp = class AluExp {
|
|
|
660
822
|
if (fromType === DType.Float32) view.setFloat32(0, x, true);
|
|
661
823
|
else if (fromType === DType.Int32) view.setInt32(0, x, true);
|
|
662
824
|
else if (fromType === DType.Uint32) view.setUint32(0, x, true);
|
|
825
|
+
else if (fromType === DType.Float16) view.setFloat16(0, x, true);
|
|
663
826
|
else throw new Error(`Unsupported bitcast from ${fromType}`);
|
|
664
827
|
if (this.dtype === DType.Float32) return view.getFloat32(0, true);
|
|
665
828
|
else if (this.dtype === DType.Int32) return view.getInt32(0, true);
|
|
666
829
|
else if (this.dtype === DType.Uint32) return view.getUint32(0, true);
|
|
830
|
+
else if (this.dtype === DType.Float16) return view.getFloat16(0, true);
|
|
667
831
|
else throw new Error(`Unsupported bitcast to ${this.dtype}`);
|
|
668
832
|
}
|
|
669
833
|
default: throw new Error(`Missing implemementation for ${this.op}`);
|
|
@@ -692,7 +856,7 @@ var AluExp = class AluExp {
|
|
|
692
856
|
}
|
|
693
857
|
case AluOp.GlobalIndex: {
|
|
694
858
|
if (!globals) throw new Error("Missing globals function");
|
|
695
|
-
const gid = this.arg;
|
|
859
|
+
const gid = this.arg[0];
|
|
696
860
|
const bufidx = this.src[0].evaluate(context, globals);
|
|
697
861
|
return globals(gid, bufidx);
|
|
698
862
|
}
|
|
@@ -722,13 +886,7 @@ var AluExp = class AluExp {
|
|
|
722
886
|
[AluOp.Cmplt]: "<",
|
|
723
887
|
[AluOp.Cmpne]: "!="
|
|
724
888
|
};
|
|
725
|
-
const UNARY_SYM = {
|
|
726
|
-
[AluOp.Sin]: "sin",
|
|
727
|
-
[AluOp.Cos]: "cos",
|
|
728
|
-
[AluOp.Exp]: "exp",
|
|
729
|
-
[AluOp.Log]: "log",
|
|
730
|
-
[AluOp.Reciprocal]: "1/"
|
|
731
|
-
};
|
|
889
|
+
const UNARY_SYM = { [AluOp.Reciprocal]: "1/" };
|
|
732
890
|
return this.fold((node, parts) => {
|
|
733
891
|
switch (node.op) {
|
|
734
892
|
case AluOp.Const: return "" + (node.dtype === DType.Bool ? Boolean(node.arg) : node.arg);
|
|
@@ -737,7 +895,7 @@ var AluExp = class AluExp {
|
|
|
737
895
|
const [name, n] = node.arg;
|
|
738
896
|
return `#${name}{${n}}`;
|
|
739
897
|
}
|
|
740
|
-
case AluOp.GlobalIndex: return `G_${node.arg}<${node.dtype}>[${strip1(parts[0])}]`;
|
|
898
|
+
case AluOp.GlobalIndex: return `G_${node.arg[0]}<${node.dtype}>[${strip1(parts[0])}]`;
|
|
741
899
|
case AluOp.GlobalView: {
|
|
742
900
|
const [gid, st] = node.arg;
|
|
743
901
|
const shape = st.shape.join(",");
|
|
@@ -766,6 +924,17 @@ var AluExp = class AluExp {
|
|
|
766
924
|
};
|
|
767
925
|
return recurse(this);
|
|
768
926
|
}
|
|
927
|
+
/** Check if any expression in the tree satisfies a predicate. */
|
|
928
|
+
some(predicate) {
|
|
929
|
+
const visited = /* @__PURE__ */ new Set();
|
|
930
|
+
const recurse = (exp) => {
|
|
931
|
+
if (visited.has(exp)) return false;
|
|
932
|
+
if (predicate(exp)) return true;
|
|
933
|
+
visited.add(exp);
|
|
934
|
+
return exp.src.some(recurse);
|
|
935
|
+
};
|
|
936
|
+
return recurse(this);
|
|
937
|
+
}
|
|
769
938
|
/** Rewrite the expression recursively using a visitor. */
|
|
770
939
|
rewrite(visitor) {
|
|
771
940
|
return this.fold((exp, newSrc) => {
|
|
@@ -784,6 +953,23 @@ var AluExp = class AluExp {
|
|
|
784
953
|
});
|
|
785
954
|
return result;
|
|
786
955
|
}
|
|
956
|
+
/** Produce a list of all distinct AluOp in this expression. */
|
|
957
|
+
distinctOps() {
|
|
958
|
+
const ops = /* @__PURE__ */ new Set();
|
|
959
|
+
this.fold((exp) => {
|
|
960
|
+
ops.add(exp.op);
|
|
961
|
+
});
|
|
962
|
+
return ops;
|
|
963
|
+
}
|
|
964
|
+
/** Rewrite GlobalView operations to GlobalIndex operations. */
|
|
965
|
+
rewriteGlobalViews() {
|
|
966
|
+
return this.rewrite((exp) => {
|
|
967
|
+
if (exp.op === AluOp.GlobalView) {
|
|
968
|
+
const [gid, st] = exp.arg;
|
|
969
|
+
return accessorGlobal(exp.dtype, gid, st, exp.src);
|
|
970
|
+
}
|
|
971
|
+
});
|
|
972
|
+
}
|
|
787
973
|
};
|
|
788
974
|
/** Symbolic form for each mathematical operation. */
|
|
789
975
|
let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
@@ -798,6 +984,7 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
798
984
|
AluOp$1["Cos"] = "Cos";
|
|
799
985
|
AluOp$1["Exp"] = "Exp";
|
|
800
986
|
AluOp$1["Log"] = "Log";
|
|
987
|
+
AluOp$1["Sqrt"] = "Sqrt";
|
|
801
988
|
AluOp$1["Reciprocal"] = "Reciprocal";
|
|
802
989
|
AluOp$1["Cast"] = "Cast";
|
|
803
990
|
AluOp$1["Bitcast"] = "Bitcast";
|
|
@@ -827,6 +1014,7 @@ const AluGroup = {
|
|
|
827
1014
|
AluOp.Cos,
|
|
828
1015
|
AluOp.Exp,
|
|
829
1016
|
AluOp.Log,
|
|
1017
|
+
AluOp.Sqrt,
|
|
830
1018
|
AluOp.Reciprocal,
|
|
831
1019
|
AluOp.Cast,
|
|
832
1020
|
AluOp.Bitcast
|
|
@@ -849,6 +1037,7 @@ const AluGroup = {
|
|
|
849
1037
|
AluOp.Cos,
|
|
850
1038
|
AluOp.Exp,
|
|
851
1039
|
AluOp.Log,
|
|
1040
|
+
AluOp.Sqrt,
|
|
852
1041
|
AluOp.Reciprocal
|
|
853
1042
|
])
|
|
854
1043
|
};
|
|
@@ -890,7 +1079,7 @@ var Kernel = class {
|
|
|
890
1079
|
}
|
|
891
1080
|
/** The dtype of the values output by this kernel. */
|
|
892
1081
|
get dtype() {
|
|
893
|
-
if (this.reduction) return this.reduction.
|
|
1082
|
+
if (this.reduction) return this.reduction.epilogue.dtype;
|
|
894
1083
|
else return this.exp.dtype;
|
|
895
1084
|
}
|
|
896
1085
|
/** The number of bytes in the output array when evaluating this kernel. */
|
|
@@ -914,22 +1103,23 @@ var Kernel = class {
|
|
|
914
1103
|
* at this level since they depend on GPU, versus CPU or Wasm.
|
|
915
1104
|
*/
|
|
916
1105
|
var Reduction = class {
|
|
917
|
-
constructor(dtype, op, size,
|
|
1106
|
+
constructor(dtype, op, size, epilogue = AluVar.acc(dtype)) {
|
|
918
1107
|
this.dtype = dtype;
|
|
919
1108
|
this.op = op;
|
|
920
1109
|
this.size = size;
|
|
921
|
-
this.
|
|
1110
|
+
this.epilogue = epilogue;
|
|
922
1111
|
if (!AluGroup.Reduce.has(op)) throw new TypeError(`Unsupported reduction: ${op}`);
|
|
1112
|
+
this.epilogue = epilogue.simplify();
|
|
923
1113
|
}
|
|
924
1114
|
hash(state) {
|
|
925
|
-
state.update(this.dtype, this.op, this.size, this.
|
|
1115
|
+
state.update(this.dtype, this.op, this.size, this.epilogue);
|
|
926
1116
|
}
|
|
927
1117
|
toString() {
|
|
928
|
-
return `${this.op}{${this.size}} -> ${this.
|
|
1118
|
+
return `${this.op}{${this.size}} -> ${this.epilogue}`;
|
|
929
1119
|
}
|
|
930
1120
|
/** Get the identity for this reduction operation. */
|
|
931
1121
|
get identity() {
|
|
932
|
-
if (this.dtype === DType.Bool) return this.op === AluOp.Add || this.op === AluOp.Max ?
|
|
1122
|
+
if (this.dtype === DType.Bool) return this.op === AluOp.Add || this.op === AluOp.Max ? 0 : 1;
|
|
933
1123
|
else if (this.dtype === DType.Int32) {
|
|
934
1124
|
if (this.op === AluOp.Add) return 0;
|
|
935
1125
|
else if (this.op === AluOp.Mul) return 1;
|
|
@@ -940,7 +1130,7 @@ var Reduction = class {
|
|
|
940
1130
|
else if (this.op === AluOp.Mul) return 1;
|
|
941
1131
|
else if (this.op === AluOp.Min) return -1 >>> 0;
|
|
942
1132
|
else if (this.op === AluOp.Max) return 0;
|
|
943
|
-
} else if (this.dtype
|
|
1133
|
+
} else if (isFloatDtype(this.dtype)) {
|
|
944
1134
|
if (this.op === AluOp.Add) return 0;
|
|
945
1135
|
else if (this.op === AluOp.Mul) return 1;
|
|
946
1136
|
else if (this.op === AluOp.Min) return Infinity;
|
|
@@ -963,7 +1153,7 @@ var Reduction = class {
|
|
|
963
1153
|
else if (this.op === AluOp.Mul) return values.reduce((a, b) => a * b >>> 0, 1);
|
|
964
1154
|
else if (this.op === AluOp.Min) return values.reduce((a, b) => Math.min(a, b), -1 >>> 0);
|
|
965
1155
|
else if (this.op === AluOp.Max) return values.reduce((a, b) => Math.max(a, b), 0);
|
|
966
|
-
} else if (this.dtype
|
|
1156
|
+
} else if (isFloatDtype(this.dtype)) {
|
|
967
1157
|
if (this.op === AluOp.Add) return values.reduce((a, b) => a + b, 0);
|
|
968
1158
|
else if (this.op === AluOp.Mul) return values.reduce((a, b) => a * b, 1);
|
|
969
1159
|
else if (this.op === AluOp.Min) return values.reduce((a, b) => Math.min(a, b), Infinity);
|
|
@@ -975,12 +1165,13 @@ var Reduction = class {
|
|
|
975
1165
|
/** Expression for accessing `indices` in input array with the given shape. */
|
|
976
1166
|
function accessorGlobal(dtype, gid, st, indices) {
|
|
977
1167
|
const [index, valid] = st.toAluExp(indices);
|
|
978
|
-
|
|
1168
|
+
const [, len] = st.views[0].dataRange();
|
|
1169
|
+
return AluExp.where(valid, AluExp.globalIndex(dtype, gid, len, index), AluExp.const(dtype, 0));
|
|
979
1170
|
}
|
|
980
1171
|
/** Expression for accessing `indices` in an array recipe with variable "idx". */
|
|
981
|
-
function accessorAluExp(
|
|
1172
|
+
function accessorAluExp(exp, st, indices) {
|
|
982
1173
|
const [index, valid] = st.toAluExp(indices);
|
|
983
|
-
return AluExp.where(valid, exp.substitute({ idx: index }), AluExp.const(dtype, 0));
|
|
1174
|
+
return AluExp.where(valid, exp.substitute({ idx: index }), AluExp.const(exp.dtype, 0));
|
|
984
1175
|
}
|
|
985
1176
|
function threefry2x32(k0, k1, c0, c1) {
|
|
986
1177
|
const rotl32 = (x, r) => (x << r | x >>> 32 - r) >>> 0;
|
|
@@ -1164,6 +1355,25 @@ var View = class View {
|
|
|
1164
1355
|
if (this.#contiguous === void 0) this.#contiguous = this.size === 0 || this.offset === 0 && this.mask === null && deepEqual(this.strides, defaultStrides(this.shape));
|
|
1165
1356
|
return this.#contiguous;
|
|
1166
1357
|
}
|
|
1358
|
+
/** Return the range of data being indexed in this view, or [0, 0] if none. */
|
|
1359
|
+
dataRange() {
|
|
1360
|
+
if (this.size === 0 || this.mask && this.mask[0][0] === this.mask[0][1]) return [0, 0];
|
|
1361
|
+
let min = this.offset;
|
|
1362
|
+
let max = this.offset;
|
|
1363
|
+
for (let i = 0; i < this.ndim; i++) {
|
|
1364
|
+
let [lo, hi] = this.mask ? this.mask[i] : [0, this.shape[i]];
|
|
1365
|
+
--hi;
|
|
1366
|
+
const s = this.strides[i];
|
|
1367
|
+
if (s > 0) {
|
|
1368
|
+
min += s * lo;
|
|
1369
|
+
max += s * hi;
|
|
1370
|
+
} else if (s < 0) {
|
|
1371
|
+
min += s * hi;
|
|
1372
|
+
max += s * lo;
|
|
1373
|
+
}
|
|
1374
|
+
}
|
|
1375
|
+
return [min, max + 1];
|
|
1376
|
+
}
|
|
1167
1377
|
/** Produce an AluExp for evaluating this view at an index. */
|
|
1168
1378
|
toAluExp(idxs) {
|
|
1169
1379
|
let iexpr = AluExp.i32(this.offset);
|
|
@@ -1478,6 +1688,39 @@ var ShapeTracker = class ShapeTracker {
|
|
|
1478
1688
|
}
|
|
1479
1689
|
return st.expand(newShape);
|
|
1480
1690
|
}
|
|
1691
|
+
/**
|
|
1692
|
+
* Repeat data in each axis by a positive number of repetitions.
|
|
1693
|
+
*
|
|
1694
|
+
* - If `tile` is true (default): [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
|
|
1695
|
+
* - If `tile` is false: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
|
|
1696
|
+
*/
|
|
1697
|
+
repeat(reps, tile = true) {
|
|
1698
|
+
if (reps.length > this.shape.length) throw new Error(`Too many repeats ${jstr(reps)} for shape ${jstr(this.shape)}`);
|
|
1699
|
+
if (reps.some((c) => c <= 0)) throw new Error(`Invalid repeats ${jstr(reps)}`);
|
|
1700
|
+
if (reps.length === 0) return this;
|
|
1701
|
+
const noop = this.shape.slice(0, -reps.length);
|
|
1702
|
+
const shape = this.shape.slice(-reps.length);
|
|
1703
|
+
return this.broadcast([...noop, ...shape.flatMap((s, i) => tile ? [reps[i], s] : [s, reps[i]])], shape.map((_, i) => noop.length + 2 * i + (tile ? 0 : 1))).reshape([...noop, ...shape.map((s, i) => s * reps[i])]);
|
|
1704
|
+
}
|
|
1705
|
+
/** Move axis i to axis j. */
|
|
1706
|
+
moveaxis(i, j) {
|
|
1707
|
+
const perm = range(this.shape.length);
|
|
1708
|
+
perm.splice(i, 1);
|
|
1709
|
+
perm.splice(j, 0, i);
|
|
1710
|
+
return this.permute(perm);
|
|
1711
|
+
}
|
|
1712
|
+
/** Like pad(), but allows for negative values. */
|
|
1713
|
+
padOrShrink(arg) {
|
|
1714
|
+
const padArg = [];
|
|
1715
|
+
const shrinkArg = [];
|
|
1716
|
+
for (let i = 0; i < arg.length; i++) {
|
|
1717
|
+
const [b, e] = arg[i];
|
|
1718
|
+
if (b < -this.shape[i] || e < -this.shape[i] || b + e < -this.shape[i]) throw new Error(`Invalid padOrShrink ${jstr(arg)} for ${jstr(this.shape)}`);
|
|
1719
|
+
padArg.push([Math.max(0, b), Math.max(0, e)]);
|
|
1720
|
+
shrinkArg.push([Math.max(0, -b), this.shape[i] - Math.max(0, -e)]);
|
|
1721
|
+
}
|
|
1722
|
+
return this.shrink(shrinkArg).pad(padArg);
|
|
1723
|
+
}
|
|
1481
1724
|
};
|
|
1482
1725
|
function applyLast(ar, f) {
|
|
1483
1726
|
return ar.toSpliced(ar.length - 1, 1, f(ar[ar.length - 1]));
|
|
@@ -1598,13 +1841,7 @@ function tuneNullopt(kernel) {
|
|
|
1598
1841
|
vars.gidx = AluExp.special(DType.Int32, "gidx", kernel.size);
|
|
1599
1842
|
if (kernel.reduction) vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
|
|
1600
1843
|
return {
|
|
1601
|
-
exp: kernel.exp.
|
|
1602
|
-
if (exp.op === AluOp.GlobalView) {
|
|
1603
|
-
const gid = exp.arg[0];
|
|
1604
|
-
const st = exp.arg[1];
|
|
1605
|
-
return accessorGlobal(exp.dtype, gid, st, exp.src);
|
|
1606
|
-
}
|
|
1607
|
-
}).substitute(vars).simplify(),
|
|
1844
|
+
exp: kernel.exp.substitute(vars).rewriteGlobalViews().simplify(),
|
|
1608
1845
|
outputIdxExp: AluExp.special(DType.Int32, "gidx", kernel.size),
|
|
1609
1846
|
threadCount: kernel.size,
|
|
1610
1847
|
size: { reduce: kernel.reduction ? kernel.reduction.size : 0 }
|
|
@@ -1699,13 +1936,19 @@ function tuneWebgpu(kernel) {
|
|
|
1699
1936
|
const s = dim.st.shape.slice(dim.upcast);
|
|
1700
1937
|
addIndices(s, AluVar.upcast);
|
|
1701
1938
|
}
|
|
1702
|
-
|
|
1939
|
+
let newExp = exp.rewrite((exp$1) => {
|
|
1703
1940
|
if (exp$1.op === AluOp.GlobalView) {
|
|
1704
1941
|
const gid = exp$1.arg[0];
|
|
1705
1942
|
const st = exp$1.arg[1];
|
|
1706
1943
|
return accessorGlobal(exp$1.dtype, gid, st.compose(dim.st), indices);
|
|
1707
1944
|
}
|
|
1708
1945
|
});
|
|
1946
|
+
const [iexpr, vexpr] = dim.st.toAluExp(indices);
|
|
1947
|
+
if (vexpr.min !== 1) throw new Error("Invariant violation: vexpr !== true");
|
|
1948
|
+
newExp = newExp.substitute({
|
|
1949
|
+
gidx: AluExp.idiv(iexpr, AluExp.i32(reduction.size)).simplify(),
|
|
1950
|
+
ridx: AluExp.mod(iexpr, AluExp.i32(reduction.size)).simplify()
|
|
1951
|
+
});
|
|
1709
1952
|
const outputGidx = dim.outputSt.shape.slice(0, dim.groups);
|
|
1710
1953
|
const outputUpcast = dim.outputSt.shape.slice(dim.groups);
|
|
1711
1954
|
const [outputIdxExp, _] = dim.outputSt.toAluExp([...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)]);
|
|
@@ -1727,7 +1970,7 @@ function tuneWebgpu(kernel) {
|
|
|
1727
1970
|
//#endregion
|
|
1728
1971
|
//#region src/backend/cpu.ts
|
|
1729
1972
|
/** Most basic implementation of `Backend` for testing. */
|
|
1730
|
-
var
|
|
1973
|
+
var CpuBackend = class {
|
|
1731
1974
|
type = "cpu";
|
|
1732
1975
|
maxArgs = Infinity;
|
|
1733
1976
|
#buffers;
|
|
@@ -1737,10 +1980,10 @@ var CPUBackend = class {
|
|
|
1737
1980
|
this.#nextSlot = 1;
|
|
1738
1981
|
}
|
|
1739
1982
|
malloc(size, initialData) {
|
|
1740
|
-
const buffer = new
|
|
1983
|
+
const buffer = new Uint8Array(size);
|
|
1741
1984
|
if (initialData) {
|
|
1742
1985
|
if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
|
|
1743
|
-
|
|
1986
|
+
buffer.set(initialData);
|
|
1744
1987
|
}
|
|
1745
1988
|
const slot = this.#nextSlot++;
|
|
1746
1989
|
this.#buffers.set(slot, {
|
|
@@ -1779,7 +2022,7 @@ var CPUBackend = class {
|
|
|
1779
2022
|
const { exp } = tuneNullopt(kernel);
|
|
1780
2023
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
|
|
1781
2024
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
|
|
1782
|
-
const usedArgs = new Map(exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex).map((exp$1) => [exp$1.arg, exp$1.dtype]));
|
|
2025
|
+
const usedArgs = new Map(exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex).map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
|
|
1783
2026
|
const inputArrays = inputBuffers.map((buf, i) => {
|
|
1784
2027
|
const dtype = usedArgs.get(i);
|
|
1785
2028
|
if (!dtype) return null;
|
|
@@ -1801,7 +2044,7 @@ var CPUBackend = class {
|
|
|
1801
2044
|
}, globals);
|
|
1802
2045
|
acc = kernel.reduction.evaluate(acc, item);
|
|
1803
2046
|
}
|
|
1804
|
-
outputArray[i] = kernel.reduction.
|
|
2047
|
+
outputArray[i] = kernel.reduction.epilogue.evaluate({ acc });
|
|
1805
2048
|
}
|
|
1806
2049
|
}
|
|
1807
2050
|
#getBuffer(slot) {
|
|
@@ -1811,12 +2054,1540 @@ var CPUBackend = class {
|
|
|
1811
2054
|
}
|
|
1812
2055
|
};
|
|
1813
2056
|
|
|
2057
|
+
//#endregion
|
|
2058
|
+
//#region src/backend/wasm/allocator.ts
|
|
2059
|
+
/**
 * Simple tensor memory allocator for WebAssembly linear memory.
 *
 * Requests are rounded up to a size class and recycled through one free list
 * per class. Otherwise fresh space is bump-allocated from `#headPtr`, growing
 * the memory when the new end exceeds its current byte length. Offsets below
 * 64 are never handed out, so pointer 0 is returned for zero-byte allocations.
 *
 * Size and page arithmetic deliberately uses floating-point math instead of
 * 32-bit bitwise operators: `& -65536` / `>> 16` overflow to negative values
 * once sizes or offsets reach 2 GiB, while wasm32 memories can grow to 4 GiB.
 */
var WasmAllocator = class {
  #memory;
  #headPtr;
  #freeLists;
  #allocatedBuffers;
  constructor(memory) {
    this.#memory = memory;
    this.#headPtr = 64;
    this.#freeLists = /* @__PURE__ */ new Map();
    this.#allocatedBuffers = /* @__PURE__ */ new Map();
  }
  /**
   * Allocate at least `size` bytes.
   * @param {number} size - Non-negative integer byte count.
   * @returns {number} Pointer into linear memory, or 0 when `size` is 0.
   * @throws {RangeError} If `size` is negative or not an integer.
   */
  malloc(size) {
    if (!Number.isInteger(size) || size < 0) throw new RangeError(`Invalid allocation size: ${size}`);
    if (size === 0) return 0;
    const sizeClass = this.#findSizeClass(size);
    const freeList = this.#freeLists.get(sizeClass);
    const ptr = freeList && freeList.length > 0 ? freeList.pop() : this.#bumpAlloc(sizeClass);
    this.#allocatedBuffers.set(ptr, sizeClass);
    return ptr;
  }
  /**
   * Return a pointer from `malloc` to its size class's free list. Freeing 0 is a no-op.
   * @throws {Error} If `ptr` is not currently allocated (including double frees).
   */
  free(ptr) {
    if (ptr === 0) return;
    const sizeClass = this.#allocatedBuffers.get(ptr);
    if (sizeClass === void 0) throw new Error(`Attempting to free unallocated pointer: ${ptr}`);
    const freeList = this.#freeLists.get(sizeClass);
    if (freeList) freeList.push(ptr);
    else this.#freeLists.set(sizeClass, [ptr]);
    this.#allocatedBuffers.delete(ptr);
  }
  /** Reserve `size` bytes (rounded up to 64) at the bump pointer, growing memory if needed. */
  #bumpAlloc(size) {
    const pageSize = 65536;
    const ptr = this.#headPtr;
    const alignedSize = Math.ceil(size / 64) * 64;
    this.#headPtr += alignedSize;
    const end = ptr + alignedSize;
    const currentBytes = this.#memory.buffer.byteLength;
    if (end > currentBytes) {
      // Wasm memory byteLength is always a whole number of 64 KiB pages.
      this.#memory.grow(Math.ceil(end / pageSize) - currentBytes / pageSize);
    }
    return ptr;
  }
  /**
   * Round a request up to its size class: multiples of 64 up to 512 bytes,
   * multiples of 512 up to 2 KiB, powers of two from 4 KiB to 64 KiB, then
   * whole 64 KiB pages.
   */
  #findSizeClass(size) {
    if (size <= 512) return Math.ceil(size / 64) * 64;
    if (size <= 2048) return Math.ceil(size / 512) * 512;
    if (size <= 65536) {
      let sizeClass = 4096;
      while (sizeClass < size) sizeClass *= 2;
      return sizeClass;
    }
    return Math.ceil(size / 65536) * 65536;
  }
  /**
   * Report the bump-pointer high-water mark (including the reserved first
   * 64 bytes) and the number of free blocks per non-empty size class.
   */
  getStats() {
    const freeListSizes = /* @__PURE__ */ new Map();
    for (const [sizeClass, freeList] of this.#freeLists) if (freeList.length > 0) freeListSizes.set(sizeClass, freeList.length);
    return {
      totalAllocated: this.#headPtr,
      freeListSizes
    };
  }
};
|
|
2116
|
+
|
|
2117
|
+
//#endregion
|
|
2118
|
+
//#region src/backend/wasm/builtins.ts
|
|
2119
|
+
/**
 * Approximate e^x.
 *
 * Writes x = k*ln2 + r with k = round(x/ln2), so |r| <= ln2/2 (~0.3466), and
 * returns 2^k * P(r), where P is the degree-5 Taylor polynomial of e^r
 * evaluated in Horner form. k > 127 returns +Infinity and k < -126 returns 0.
 */
function wasm_exp(cg) {
  return cg.function([cg.f32], [cg.f32], () => {
    const kFloat = cg.local.declare(cg.f32);
    const k = cg.local.declare(cg.i32);
    const r = cg.local.declare(cg.f32);
    const poly = cg.local.declare(cg.f32);
    const scale = cg.local.declare(cg.f32);

    // Emits: if (k <compare> bound) return value;
    const returnWhenK = (compare, bound, value) => {
      cg.local.get(k);
      cg.i32.const(bound);
      compare();
      cg.if(cg.void);
      cg.f32.const(value);
      cg.return();
      cg.end();
    };

    // k = round(x / ln2), kept as f32 (kFloat) and as a saturated i32 (k).
    cg.local.get(0);
    cg.f32.const(1 / Math.LN2);
    cg.f32.mul();
    cg.f32.nearest();
    cg.local.tee(kFloat);
    cg.i32.trunc_sat_f32_s();
    cg.local.set(k);

    returnWhenK(() => cg.i32.gt_s(), 127, Infinity);
    returnWhenK(() => cg.i32.lt_s(), -126, 0);

    // r = x - k*ln2
    cg.local.get(0);
    cg.local.get(kFloat);
    cg.f32.const(Math.LN2);
    cg.f32.mul();
    cg.f32.sub();
    cg.local.set(r);

    // poly = ((((r/120 + 1/24)*r + 1/6)*r + 1/2)*r + 1)*r + 1
    cg.f32.const(1 / 120);
    for (const coeff of [1 / 24, 1 / 6, 1 / 2, 1, 1]) {
      cg.local.get(r);
      cg.f32.mul();
      cg.f32.const(coeff);
      cg.f32.add();
    }
    cg.local.set(poly);

    // scale = 2^k, built from the f32 exponent field: (k + 127) << 23.
    cg.local.get(k);
    cg.i32.const(127);
    cg.i32.add();
    cg.i32.const(23);
    cg.i32.shl();
    cg.f32.reinterpret_i32();
    cg.local.set(scale);

    cg.local.get(poly);
    cg.local.get(scale);
    cg.f32.mul();
  });
}
|
|
2193
|
+
/**
 * Approximate ln(x).
 *
 * Special cases: ln(±0) = -Infinity, ln(x < 0) = NaN, ln(+Infinity) = +Infinity,
 * ln(NaN) = NaN.
 *
 * Method: decompose x = m * 2^e with m in [1,2), e integer (via bit ops). Subnormal
 * inputs are first scaled by 2^23 so the decomposition is valid.
 *   ln(x) = e*ln2 + ln(m); use atanh-style series with t=(m-1)/(m+1)
 *   ln(m) ≈ 2*(t + t^3/3 + t^5/5 + t^7/7)
 */
function wasm_log(cg) {
  return cg.function([cg.f32], [cg.f32], () => {
    const bits = cg.local.declare(cg.i32);
    const e = cg.local.declare(cg.i32);
    const m = cg.local.declare(cg.f32);
    const t = cg.local.declare(cg.f32);
    const t2 = cg.local.declare(cg.f32);
    const t3 = cg.local.declare(cg.f32);
    const t5 = cg.local.declare(cg.f32);
    const t7 = cg.local.declare(cg.f32);
    const lnm = cg.local.declare(cg.f32);
    const el2 = cg.local.declare(cg.f32);
    const xs = cg.local.declare(cg.f32); // x, rescaled into the normal range if subnormal
    const eAdj = cg.local.declare(cg.i32); // exponent correction for that rescaling (0 or -23)
    // x <= 0 (true for ±0 and negatives, false for NaN):
    // ±0 (0 <= x also holds) -> -Infinity, otherwise -> NaN.
    cg.local.get(0);
    cg.f32.const(0);
    cg.f32.le();
    cg.if(cg.void);
    cg.f32.const(0);
    cg.local.get(0);
    cg.f32.le();
    cg.if(cg.void);
    cg.f32.const(-Infinity);
    cg.return();
    cg.end();
    cg.f32.const(NaN);
    cg.return();
    cg.end();
    // Subnormals lack the implicit leading 1 bit, so scale them by 2^23 and
    // compensate in the exponent. Wasm locals are zero-initialized, so eAdj
    // stays 0 for normal inputs.
    cg.local.get(0);
    cg.local.set(xs);
    cg.local.get(0);
    cg.f32.const(1.1754943508222875e-38);
    cg.f32.le();
    cg.if(cg.void);
    cg.local.get(0);
    cg.f32.const(8388608);
    cg.f32.mul();
    cg.local.set(xs);
    cg.i32.const(-23);
    cg.local.set(eAdj);
    cg.end();
    // e = unbiased exponent of xs, plus eAdj.
    cg.local.get(xs);
    cg.i32.reinterpret_f32();
    cg.local.tee(bits);
    cg.i32.const(23);
    cg.i32.shr_u();
    cg.i32.const(255);
    cg.i32.and();
    cg.i32.const(127);
    cg.i32.sub();
    cg.local.get(eAdj);
    cg.i32.add();
    cg.local.tee(e);
    // Exponent field 255 (e = 128) encodes +Infinity or NaN: ln returns x unchanged.
    cg.i32.const(127);
    cg.i32.gt_s();
    cg.if(cg.void);
    cg.local.get(0);
    cg.return();
    cg.end();
    // m = mantissa bits with the exponent field forced to 127, i.e. m in [1, 2).
    cg.local.get(bits);
    cg.i32.const(8388607);
    cg.i32.and();
    cg.i32.const(1065353216);
    cg.i32.or();
    cg.f32.reinterpret_i32();
    cg.local.set(m);
    // t = (m - 1) / (m + 1)
    cg.local.get(m);
    cg.f32.const(1);
    cg.f32.sub();
    cg.local.get(m);
    cg.f32.const(1);
    cg.f32.add();
    cg.f32.div();
    cg.local.set(t);
    cg.local.get(t);
    cg.local.get(t);
    cg.f32.mul();
    cg.local.set(t2);
    cg.local.get(t);
    cg.local.get(t2);
    cg.f32.mul();
    cg.local.set(t3);
    cg.local.get(t3);
    cg.local.get(t2);
    cg.f32.mul();
    cg.local.set(t5);
    cg.local.get(t5);
    cg.local.get(t2);
    cg.f32.mul();
    cg.local.set(t7);
    // lnm = 2 * (t7/7 + t5/5 + t3/3 + t)
    cg.local.get(t7);
    cg.f32.const(1 / 7);
    cg.f32.mul();
    cg.local.get(t5);
    cg.f32.const(1 / 5);
    cg.f32.mul();
    cg.f32.add();
    cg.local.get(t3);
    cg.f32.const(1 / 3);
    cg.f32.mul();
    cg.f32.add();
    cg.local.get(t);
    cg.f32.add();
    cg.f32.const(2);
    cg.f32.mul();
    cg.local.set(lnm);
    // result = e*ln2 + lnm
    cg.local.get(e);
    cg.f32.convert_i32_s();
    cg.f32.const(Math.LN2);
    cg.f32.mul();
    cg.local.set(el2);
    cg.local.get(el2);
    cg.local.get(lnm);
    cg.f32.add();
  });
}
|
|
2286
|
+
/**
 * Approximate sin(x).
 *
 * Method: reduce to y in [-π, π], then quadrant via q = round(y/(π/2)),
 *   z = y - q*(π/2); evaluate both odd/even polynomials on z:
 *   sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
 *   cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
 * and pick by q mod 4: 0: +sz, 1: +cz, 2: -sz, 3: -cz.
 *
 * The quadrant uses a saturating float->int conversion: non-finite inputs
 * reduce to y = NaN, and a trapping conversion would abort the whole kernel
 * instead of returning NaN.
 */
function wasm_sin(cg) {
  return cg.function([cg.f32], [cg.f32], () => {
    const y = cg.local.declare(cg.f32);
    const qf = cg.local.declare(cg.f32);
    const q = cg.local.declare(cg.i32);
    const z = cg.local.declare(cg.f32);
    const z2 = cg.local.declare(cg.f32);
    const sz = cg.local.declare(cg.f32);
    const cz = cg.local.declare(cg.f32);
    const mag = cg.local.declare(cg.f32);
    // y = x - round(x / 2π) * 2π
    cg.local.get(0);
    cg.local.get(0);
    cg.f32.const(1 / (2 * Math.PI));
    cg.f32.mul();
    cg.f32.nearest();
    cg.local.tee(qf);
    cg.f32.const(2 * Math.PI);
    cg.f32.mul();
    cg.f32.sub();
    cg.local.set(y);
    // q = round(y / (π/2)); saturating so NaN maps to 0 instead of trapping.
    cg.local.get(y);
    cg.f32.const(2 / Math.PI);
    cg.f32.mul();
    cg.f32.nearest();
    cg.local.tee(qf);
    cg.i32.trunc_sat_f32_s();
    cg.local.set(q);
    // z = y - q*(π/2); z2 = z*z
    cg.local.get(y);
    cg.local.get(qf);
    cg.f32.const(Math.PI / 2);
    cg.f32.mul();
    cg.f32.sub();
    cg.local.tee(z);
    cg.local.get(z);
    cg.f32.mul();
    cg.local.set(z2);
    // sz = z * (1 + z2*(-1/6 + z2*(1/120 + z2*(-1/5040))))
    cg.f32.const(-1 / 5040);
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(1 / 120);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(-1 / 6);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(1);
    cg.f32.add();
    cg.local.get(z);
    cg.f32.mul();
    cg.local.set(sz);
    // cz = 1 + z2*(-1/2 + z2*(1/24 + z2*(-1/720)))
    cg.f32.const(-1 / 720);
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(1 / 24);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(-1 / 2);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(1);
    cg.f32.add();
    cg.local.set(cz);
    // Odd quadrants use cz, even ones sz; quadrants with bit 2 set are negated.
    cg.local.get(cz);
    cg.local.get(sz);
    cg.local.get(q);
    cg.i32.const(1);
    cg.i32.and();
    cg.select();
    cg.local.tee(mag);
    cg.f32.neg();
    cg.local.get(mag);
    cg.local.get(q);
    cg.i32.const(2);
    cg.i32.and();
    cg.select();
  });
}
|
|
2374
|
+
/**
 * Approximate cos(x).
 *
 * Same reduction and polynomials as wasm_sin, then quadrant mapping:
 *   k = q mod 4: 0: +cz, 1: -sz, 2: -cz, 3: +sz
 *
 * The quadrant uses a saturating float->int conversion, so non-finite inputs
 * (which reduce to y = NaN) return NaN instead of trapping.
 */
function wasm_cos(cg) {
  return cg.function([cg.f32], [cg.f32], () => {
    const y = cg.local.declare(cg.f32);
    const qf = cg.local.declare(cg.f32);
    const q = cg.local.declare(cg.i32);
    const z = cg.local.declare(cg.f32);
    const z2 = cg.local.declare(cg.f32);
    const sz = cg.local.declare(cg.f32);
    const cz = cg.local.declare(cg.f32);
    const mag = cg.local.declare(cg.f32);
    // y = x - round(x / 2π) * 2π
    cg.local.get(0);
    cg.local.get(0);
    cg.f32.const(1 / (2 * Math.PI));
    cg.f32.mul();
    cg.f32.nearest();
    cg.local.tee(qf);
    cg.f32.const(2 * Math.PI);
    cg.f32.mul();
    cg.f32.sub();
    cg.local.set(y);
    // q = round(y / (π/2)); saturating so NaN maps to 0 instead of trapping.
    cg.local.get(y);
    cg.f32.const(2 / Math.PI);
    cg.f32.mul();
    cg.f32.nearest();
    cg.local.tee(qf);
    cg.i32.trunc_sat_f32_s();
    cg.local.set(q);
    // z = y - q*(π/2); z2 = z*z
    cg.local.get(y);
    cg.local.get(qf);
    cg.f32.const(Math.PI / 2);
    cg.f32.mul();
    cg.f32.sub();
    cg.local.tee(z);
    cg.local.get(z);
    cg.f32.mul();
    cg.local.set(z2);
    // sz = z * (1 + z2*(-1/6 + z2*(1/120 + z2*(-1/5040))))
    cg.f32.const(-1 / 5040);
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(1 / 120);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(-1 / 6);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(1);
    cg.f32.add();
    cg.local.get(z);
    cg.f32.mul();
    cg.local.set(sz);
    // cz = 1 + z2*(-1/2 + z2*(1/24 + z2*(-1/720)))
    cg.f32.const(-1 / 720);
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(1 / 24);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(-1 / 2);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(1);
    cg.f32.add();
    cg.local.set(cz);
    // Odd quadrants use sz, even ones cz; negate when (q + 1) has bit 2 set.
    cg.local.get(sz);
    cg.local.get(cz);
    cg.local.get(q);
    cg.i32.const(1);
    cg.i32.and();
    cg.select();
    cg.local.tee(mag);
    cg.f32.neg();
    cg.local.get(mag);
    cg.local.get(q);
    cg.i32.const(1);
    cg.i32.add();
    cg.i32.const(2);
    cg.i32.and();
    cg.select();
  });
}
|
|
2463
|
+
/**
 * Threefry2x32 pseudorandom number generator (20 rounds).
 *
 * Parameters 0 and 1 are the two 32-bit key words and parameters 2 and 3 the
 * two 32-bit counter words; the function returns two 32-bit pseudorandom words.
 * Rounds run in five groups of four, alternating between two rotation sets,
 * with a key injection after each group.
 */
function wasm_threefry2x32(cg) {
  const i32 = cg.i32;
  return cg.function([i32, i32, i32, i32], [i32, i32], () => {
    // Extended key schedule ks[0..2]; declared before the two state words.
    const ks = [cg.local.declare(cg.i32), cg.local.declare(cg.i32), cg.local.declare(cg.i32)];
    const x0 = cg.local.declare(cg.i32);
    const x1 = cg.local.declare(cg.i32);
    const rotationSets = [[13, 15, 26, 6], [17, 29, 16, 24]];

    // One MIX round: x0 += x1; x1 = rotl(x1, rot) ^ x0.
    const mixRound = (rot) => {
      cg.local.get(x0);
      cg.local.get(x1);
      cg.i32.add();
      cg.local.set(x0);
      cg.local.get(x1);
      cg.i32.const(rot);
      cg.i32.rotl();
      cg.local.get(x0);
      cg.i32.xor();
      cg.local.set(x1);
    };
    // Key injection: x0 += ka; x1 += kb + injection index.
    const injectKey = (ka, kb, injection) => {
      cg.local.get(x0);
      cg.local.get(ka);
      cg.i32.add();
      cg.local.set(x0);
      cg.local.get(x1);
      cg.local.get(kb);
      cg.i32.add();
      cg.i32.const(injection);
      cg.i32.add();
      cg.local.set(x1);
    };

    // ks0 = k0, ks1 = k1, ks2 = k0 ^ k1 ^ 0x1BD11BDA (Threefish parity constant).
    cg.local.get(0);
    cg.local.set(ks[0]);
    cg.local.get(1);
    cg.local.set(ks[1]);
    cg.local.get(0);
    cg.local.get(1);
    cg.i32.xor();
    cg.i32.const(466688986);
    cg.i32.xor();
    cg.local.set(ks[2]);
    // Initial injection: x = counter + (ks0, ks1).
    cg.local.get(2);
    cg.local.get(ks[0]);
    cg.i32.add();
    cg.local.set(x0);
    cg.local.get(3);
    cg.local.get(ks[1]);
    cg.i32.add();
    cg.local.set(x1);

    for (let group = 1; group <= 5; group++) {
      for (const rot of rotationSets[(group - 1) % 2]) mixRound(rot);
      injectKey(ks[group % 3], ks[(group + 1) % 3], group);
    }

    cg.local.get(x0);
    cg.local.get(x1);
  });
}
|
|
2537
|
+
|
|
2538
|
+
//#endregion
|
|
2539
|
+
//#region src/backend/wasm/wasmblr.ts
|
|
2540
|
+
/**
|
|
2541
|
+
* @file Minimalist WebAssembly assembler. This allows you to emit WebAssembly
|
|
2542
|
+
* bytecode directly from the browser.
|
|
2543
|
+
*
|
|
2544
|
+
* Self-contained port of https://github.com/bwasti/wasmblr to TypeScript.
|
|
2545
|
+
* Some operation names in this module are written in `snake_case` to match
|
|
2546
|
+
* their names in the Wasm specification.
|
|
2547
|
+
*
|
|
2548
|
+
* Reference: https://pengowray.github.io/wasm-ops/.
|
|
2549
|
+
*/
|
|
2550
|
+
/** Wasm binary preamble: the magic bytes "\0asm". */
const magicModuleHeader = [0x00, 0x61, 0x73, 0x6d];
/** Wasm binary format version 1, as a little-endian u32. */
const moduleVersion = [0x01, 0x00, 0x00, 0x00];
|
|
2562
|
+
/** Throw an Error carrying `message` (default "Assertion failed") unless `condition` is truthy. */
function assert(condition, message) {
  if (condition) return;
  throw new Error(message || "Assertion failed");
}
|
|
2565
|
+
/**
 * Encode a 32-bit integer as signed LEB128 bytes.
 * Emission stops once the remaining value is pure sign extension of the last byte's bit 6.
 */
function encodeSigned(n) {
  const bytes = [];
  for (;;) {
    const low7 = n & 0x7f;
    n >>= 7;
    const signBitSet = (low7 & 0x40) !== 0;
    const finished = (n === 0 && !signBitSet) || (n === -1 && signBitSet);
    bytes.push(finished ? low7 : low7 | 0x80);
    if (finished) return bytes;
  }
}
|
|
2577
|
+
/** Encode a non-negative 32-bit integer as unsigned LEB128 bytes (7 bits per byte, low first). */
function encodeUnsigned(n) {
  const bytes = [];
  let rest = n;
  for (;;) {
    const low7 = rest & 0x7f;
    rest >>>= 7;
    if (rest === 0) {
      bytes.push(low7);
      return bytes;
    }
    bytes.push(low7 | 0x80);
  }
}
|
|
2587
|
+
/**
 * Encode a string as a Wasm `name`: its UTF-8 byte length as an unsigned
 * LEB128 integer, followed by the UTF-8 bytes.
 */
function encodeString(s) {
  const bytes = new TextEncoder().encode(s);
  // The length prefix is a u32 LEB128, not a single raw byte; otherwise names
  // of 128+ bytes produce a malformed module.
  return [...encodeUnsigned(bytes.length), ...bytes];
}
|
|
2591
|
+
/**
 * Encode the result type of a `block`/`loop`/`if`.
 * A single type (including `void`, 0x40) is emitted as its type id.
 */
function encodeBlocktype(type) {
  assert(type.length > 0, "blocktype must have at least one type");
  const [first, ...rest] = type;
  if (rest.length === 0) return [first.typeId];
  // NOTE(review): the Wasm spec encodes a multi-value blocktype as an s33 index
  // into the type section; an inline 0x60 function type is not valid there.
  // Left unchanged here because fixing it requires registering the type in
  // the module's type section.
  const typeIds = type.map((t) => t.typeId);
  return [96].concat(encodeUnsigned(0), encodeUnsigned(type.length), typeIds);
}
|
|
2601
|
+
/**
 * Encode an opcode: a plain number is a single byte, while a
 * `[prefix, subOpcode]` pair emits the prefix byte and a LEB128 sub-opcode.
 */
function encodeOpcode(opcode) {
  if (typeof opcode !== "number") {
    const prefix = opcode[0];
    const subOpcode = opcode[1];
    return [prefix, ...encodeUnsigned(subOpcode)];
  }
  return [opcode];
}
|
|
2605
|
+
/** Append every element of `inp` (array or typed array) to the array `out`, in place. */
function concat(out, inp) {
  // Push element by element: `out.push(...inp)` passes each element as a call
  // argument, and large inputs (e.g. a big code section) exceed the engine's
  // argument limit and throw RangeError.
  for (const x of inp) out.push(x);
}
|
|
2608
|
+
/**
 * A function definition held by the code generator: its signature plus a
 * `body` callback that emits instructions when `emit()` runs.
 */
var Function_ = class {
  inputTypes;
  outputTypes;
  body;
  /** Types of locals declared while the body is being emitted. */
  locals = [];
  constructor(inputTypes, outputTypes, body) {
    this.inputTypes = inputTypes;
    this.outputTypes = outputTypes;
    // A missing (falsy) body becomes a no-op emitter.
    this.body = body ? body : () => {};
  }
  /** Start with a fresh locals list, then run the body callback (with `this` bound to this function). */
  emit() {
    this.locals = [];
    this.body();
  }
};
|
|
2623
|
+
/**
 * Linear memory declaration for a module being assembled, plus emitters for
 * the `memory.size` / `memory.grow` instructions on the owning code generator.
 */
var Memory = class {
  min = 0;
  max = 0;
  isShared = false;
  aString = "";
  bString = "";
  constructor(cg) {
    this.cg = cg;
  }
  /** Declare the size of the memory. Each page is 64 KiB. May only be called once. */
  pages(min, max = 0) {
    assert(this.min === 0 && this.max === 0);
    this.min = min;
    this.max = max;
    return this;
  }
  /** Export the memory under name `a`. Cannot be combined with import(). */
  export(a) {
    assert(!this.isImport && !this.isExport, "already set");
    this.aString = a;
    return this;
  }
  /** Set whether the memory is declared as shared. */
  shared(isShared) {
    this.isShared = isShared;
    return this;
  }
  /** Import the memory as `a`.`b` (presumably module and field name — confirm against finish()). */
  import(a, b) {
    assert(!this.isImport && !this.isExport, "already set");
    this.aString = a;
    this.bString = b;
    return this;
  }
  /**
   * Emit `memory.size` for memory 0. Pushes an i32 (the current page count)
   * onto the tracked type stack, like every other value-producing op.
   */
  size() {
    this.cg._push(this.cg.i32);
    this.cg._emit(63);
    this.cg._emit(0);
  }
  /**
   * Emit `memory.grow` for memory 0. Pops the i32 page delta and pushes the
   * i32 result (previous page count, or -1 on failure, per the Wasm spec).
   */
  grow() {
    assert(this.cg._pop().typeId === this.cg.i32.typeId, "memory.grow: expected i32");
    this.cg._push(this.cg.i32);
    this.cg._emit(64);
    this.cg._emit(0);
  }
  get isImport() {
    return this.aString.length > 0 && this.bString.length > 0;
  }
  get isExport() {
    return this.aString.length > 0 && this.bString.length === 0;
  }
};
|
|
2669
|
+
/** Public API of WebAssembly assembler. */
|
|
2670
|
+
var CodeGenerator = class {
|
|
2671
|
+
local;
|
|
2672
|
+
i32;
|
|
2673
|
+
f32;
|
|
2674
|
+
v128;
|
|
2675
|
+
i32x4;
|
|
2676
|
+
f32x4;
|
|
2677
|
+
memory;
|
|
2678
|
+
void = {
|
|
2679
|
+
typeId: 64,
|
|
2680
|
+
name: "void"
|
|
2681
|
+
};
|
|
2682
|
+
#functions = [];
|
|
2683
|
+
#importedFunctions = [];
|
|
2684
|
+
#exportedFunctions = /* @__PURE__ */ new Map();
|
|
2685
|
+
#curFunction = null;
|
|
2686
|
+
#curBytes = [];
|
|
2687
|
+
#typeStack = [];
|
|
2688
|
+
#blockFrames = [];
|
|
2689
|
+
constructor() {
|
|
2690
|
+
this.local = new Local(this);
|
|
2691
|
+
this.i32 = new I32(this);
|
|
2692
|
+
this.f32 = new F32(this);
|
|
2693
|
+
this.v128 = new V128(this);
|
|
2694
|
+
this.i32x4 = new I32x4(this);
|
|
2695
|
+
this.f32x4 = new F32x4(this);
|
|
2696
|
+
this.memory = new Memory(this);
|
|
2697
|
+
}
|
|
2698
|
+
unreachable() {
|
|
2699
|
+
this._emit(0);
|
|
2700
|
+
}
|
|
2701
|
+
nop() {
|
|
2702
|
+
this._emit(1);
|
|
2703
|
+
}
|
|
2704
|
+
block(...type) {
|
|
2705
|
+
this.#blockFrames.push({
|
|
2706
|
+
idx: this.#typeStack.length,
|
|
2707
|
+
ty: type
|
|
2708
|
+
});
|
|
2709
|
+
this._emit(2);
|
|
2710
|
+
this._emit(encodeBlocktype(type));
|
|
2711
|
+
}
|
|
2712
|
+
loop(...type) {
|
|
2713
|
+
this.#blockFrames.push({
|
|
2714
|
+
idx: this.#typeStack.length,
|
|
2715
|
+
ty: type
|
|
2716
|
+
});
|
|
2717
|
+
this._emit(3);
|
|
2718
|
+
this._emit(encodeBlocktype(type));
|
|
2719
|
+
}
|
|
2720
|
+
if(...type) {
|
|
2721
|
+
assert(this._pop().typeId === this.i32.typeId, "if_: expected i32");
|
|
2722
|
+
this.#blockFrames.push({
|
|
2723
|
+
idx: this.#typeStack.length,
|
|
2724
|
+
ty: type
|
|
2725
|
+
});
|
|
2726
|
+
this._emit(4);
|
|
2727
|
+
this._emit(encodeBlocktype(type));
|
|
2728
|
+
}
|
|
2729
|
+
else() {
|
|
2730
|
+
assert(this.#blockFrames.length > 0, "else: no block to else");
|
|
2731
|
+
const frame = this.#blockFrames[this.#blockFrames.length - 1];
|
|
2732
|
+
this.#typeStack = this.#typeStack.slice(0, frame.idx);
|
|
2733
|
+
this._emit(5);
|
|
2734
|
+
}
|
|
2735
|
+
/** End a block (`block`, `if`/`else`, `loop`, or function). */
|
|
2736
|
+
end() {
|
|
2737
|
+
const frame = this.#blockFrames.pop();
|
|
2738
|
+
assert(frame !== void 0, "end: no block to end");
|
|
2739
|
+
this.#typeStack = this.#typeStack.slice(0, frame.idx);
|
|
2740
|
+
for (const ty of frame.ty) if (ty.typeId !== this.void.typeId) this._push(ty);
|
|
2741
|
+
this._emit(11);
|
|
2742
|
+
}
|
|
2743
|
+
/** Branch to a block a certain depth outward on the stack. */
|
|
2744
|
+
br(depth) {
|
|
2745
|
+
this._emit(12);
|
|
2746
|
+
this._emit(encodeUnsigned(depth));
|
|
2747
|
+
}
|
|
2748
|
+
/** Conditional branch to a block a certain depth outward on the stack. */
|
|
2749
|
+
br_if(depth) {
|
|
2750
|
+
assert(this._pop().typeId === this.i32.typeId, "br_if: expected i32");
|
|
2751
|
+
this._emit(13);
|
|
2752
|
+
this._emit(encodeUnsigned(depth));
|
|
2753
|
+
}
|
|
2754
|
+
/** Jump table that indexes into a label vector (like switch). */
|
|
2755
|
+
br_table(...depths) {
|
|
2756
|
+
assert(this._pop().typeId === this.i32.typeId, "br_table: expected i32");
|
|
2757
|
+
assert(depths.length > 0, "br_table: expected at least one default depth");
|
|
2758
|
+
this._emit(14);
|
|
2759
|
+
this._emit(encodeUnsigned(depths.length - 1));
|
|
2760
|
+
for (const d of depths) this._emit(encodeUnsigned(d));
|
|
2761
|
+
}
|
|
2762
|
+
/** Return from a function, branching out of the outermost block. */
|
|
2763
|
+
return() {
|
|
2764
|
+
this._emit(15);
|
|
2765
|
+
}
|
|
2766
|
+
/** Call a function with the given ID. */
|
|
2767
|
+
call(fn) {
|
|
2768
|
+
const totalFunctions = this.#importedFunctions.length + this.#functions.length;
|
|
2769
|
+
assert(fn < totalFunctions, "function index does not exist");
|
|
2770
|
+
const func = fn < this.#importedFunctions.length ? this.#importedFunctions[fn] : this.#functions[fn - this.#importedFunctions.length];
|
|
2771
|
+
for (let i = func.inputTypes.length - 1; i >= 0; i--) {
|
|
2772
|
+
const argType = this._pop();
|
|
2773
|
+
assert(argType.typeId === func.inputTypes[i].typeId, `call: argument ${i} type mismatch, expected ${func.inputTypes[i].name} got ${argType.name}`);
|
|
2774
|
+
}
|
|
2775
|
+
for (const outputType of func.outputTypes) this._push(outputType);
|
|
2776
|
+
this._emit(16);
|
|
2777
|
+
this._emit(encodeUnsigned(fn));
|
|
2778
|
+
}
|
|
2779
|
+
/** Throw away an operand on the stack. */
|
|
2780
|
+
drop() {
|
|
2781
|
+
this._pop();
|
|
2782
|
+
this._emit(26);
|
|
2783
|
+
}
|
|
2784
|
+
/** Select one of the first two operands (T, F) based on the third operand (i32)'s value. */
|
|
2785
|
+
select() {
|
|
2786
|
+
assert(this._pop().typeId === this.i32.typeId, "select: expected i32 condition");
|
|
2787
|
+
const [b, a] = [this._pop(), this._pop()];
|
|
2788
|
+
assert(a.typeId === b.typeId, "select: expected same type for both operands");
|
|
2789
|
+
this._push(a);
|
|
2790
|
+
this._emit(27);
|
|
2791
|
+
}
|
|
2792
|
+
/** Import a JavaScript function; returns its index. */
|
|
2793
|
+
importFunction(module$1, name, inputTypes, outputTypes) {
|
|
2794
|
+
if (this.#functions.length > 0) throw new Error("function imports must precede defining functions");
|
|
2795
|
+
const idx = this.#importedFunctions.length;
|
|
2796
|
+
this.#importedFunctions.push({
|
|
2797
|
+
module: module$1,
|
|
2798
|
+
name,
|
|
2799
|
+
inputTypes,
|
|
2800
|
+
outputTypes
|
|
2801
|
+
});
|
|
2802
|
+
return idx;
|
|
2803
|
+
}
|
|
2804
|
+
/** Export a function. */
|
|
2805
|
+
export(fn, name) {
|
|
2806
|
+
this.#exportedFunctions.set(fn, name);
|
|
2807
|
+
}
|
|
2808
|
+
/** Declare a new function; returns its index. */
|
|
2809
|
+
function(inputTypes, outputTypes, body) {
|
|
2810
|
+
const idx = this.#importedFunctions.length + this.#functions.length;
|
|
2811
|
+
this.#functions.push(new Function_(inputTypes, outputTypes, body));
|
|
2812
|
+
return idx;
|
|
2813
|
+
}
|
|
2814
|
+
_declareLocal(type) {
|
|
2815
|
+
assert(this.#curFunction !== null, "No current function");
|
|
2816
|
+
const idx = this.#curFunction.locals.length + this.#curFunction.inputTypes.length;
|
|
2817
|
+
this.#curFunction.locals.push(type);
|
|
2818
|
+
return idx;
|
|
2819
|
+
}
|
|
2820
|
+
_inputTypes() {
|
|
2821
|
+
assert(this.#curFunction !== null, "No current function");
|
|
2822
|
+
return this.#curFunction.inputTypes;
|
|
2823
|
+
}
|
|
2824
|
+
_locals() {
|
|
2825
|
+
assert(this.#curFunction !== null, "No current function");
|
|
2826
|
+
return this.#curFunction.locals;
|
|
2827
|
+
}
|
|
2828
|
+
_push(type) {
|
|
2829
|
+
if (!type) throw new Error(`pushing type ${type}`);
|
|
2830
|
+
this.#typeStack.push(type);
|
|
2831
|
+
}
|
|
2832
|
+
_pop() {
|
|
2833
|
+
assert(this.#typeStack.length > 0, "popping empty stack");
|
|
2834
|
+
return this.#typeStack.pop();
|
|
2835
|
+
}
|
|
2836
|
+
_emit(bytes) {
|
|
2837
|
+
if (typeof bytes === "number") this.#curBytes.push(bytes);
|
|
2838
|
+
else this.#curBytes.push(...bytes);
|
|
2839
|
+
}
|
|
2840
|
+
/**
 * Serialize everything declared on this generator into a WebAssembly binary module.
 *
 * Writes the magic header and version, then sections in ascending id order:
 * type (1), import (2, only when there is at least one import), function (3),
 * memory (5, only for a non-imported memory whose min or max is truthy),
 * export (7) and code (10). Each section is its id byte, its byte length
 * (via encodeUnsigned, presumably unsigned LEB128), then its payload.
 *
 * Function bodies are generated here: for each defined function the type stack,
 * block frames and byte buffer are reset, `f.emit()` runs, and `this.end()` closes the body.
 *
 * @returns {Uint8Array} the encoded module.
 */
finish() {
  this.#curBytes = [];
  const emittedBytes = [];
  concat(emittedBytes, magicModuleHeader);
  concat(emittedBytes, moduleVersion);
  // Type section: one signature per function, imports first, without deduplication,
  // so function index i always uses type index i.
  const typeSectionBytes = [];
  const totalFunctionTypes = this.#importedFunctions.length + this.#functions.length;
  concat(typeSectionBytes, encodeUnsigned(totalFunctionTypes));
  for (const f of [...this.#importedFunctions, ...this.#functions]) {
    // 0x60 introduces a function type: param vector then result vector.
    typeSectionBytes.push(96);
    concat(typeSectionBytes, encodeUnsigned(f.inputTypes.length));
    for (const t of f.inputTypes) typeSectionBytes.push(t.typeId);
    concat(typeSectionBytes, encodeUnsigned(f.outputTypes.length));
    for (const t of f.outputTypes) typeSectionBytes.push(t.typeId);
  }
  emittedBytes.push(1);
  concat(emittedBytes, encodeUnsigned(typeSectionBytes.length));
  concat(emittedBytes, typeSectionBytes);
  // Import section: imported functions (kind 0, type index = import index), then the memory if imported.
  const importSectionBytes = [];
  const numImports = this.#importedFunctions.length + (this.memory.isImport ? 1 : 0);
  if (numImports > 0) {
    concat(importSectionBytes, encodeUnsigned(numImports));
    for (let i = 0; i < this.#importedFunctions.length; i++) {
      const f = this.#importedFunctions[i];
      concat(importSectionBytes, encodeString(f.module));
      concat(importSectionBytes, encodeString(f.name));
      importSectionBytes.push(0);
      concat(importSectionBytes, encodeUnsigned(i));
    }
    if (this.memory.isImport) {
      // Memory import: aString is the module name, bString the field name, kind 2 = memory.
      concat(importSectionBytes, encodeString(this.memory.aString));
      concat(importSectionBytes, encodeString(this.memory.bString));
      importSectionBytes.push(2);
      // Limits flag: 1 = min+max, 3 = shared min+max, 0 = min only.
      // NOTE(review): chosen by truthiness, so a min of 0 together with a max drops the max;
      // confirm the Memory defaults make this unreachable.
      if (this.memory.min && this.memory.max) {
        if (this.memory.isShared) importSectionBytes.push(3);
        else importSectionBytes.push(1);
        concat(importSectionBytes, encodeUnsigned(this.memory.min));
        concat(importSectionBytes, encodeUnsigned(this.memory.max));
      } else {
        assert(!this.memory.isShared, "shared memory must have a max size");
        importSectionBytes.push(0);
        concat(importSectionBytes, encodeUnsigned(this.memory.min));
      }
    }
    emittedBytes.push(2);
    concat(emittedBytes, encodeUnsigned(importSectionBytes.length));
    concat(emittedBytes, importSectionBytes);
  }
  // Function section: type index of each defined function (they follow the imports in the type section).
  const functionSectionBytes = [];
  concat(functionSectionBytes, encodeUnsigned(this.#functions.length));
  for (let i = 0; i < this.#functions.length; i++) {
    const typeIndex = this.#importedFunctions.length + i;
    concat(functionSectionBytes, encodeUnsigned(typeIndex));
  }
  emittedBytes.push(3);
  concat(emittedBytes, encodeUnsigned(functionSectionBytes.length));
  concat(emittedBytes, functionSectionBytes);
  // Memory section: a single module-defined memory, same limits encoding as the import above.
  const memorySectionBytes = [];
  if (!this.memory.isImport && (this.memory.min || this.memory.max)) {
    memorySectionBytes.push(1);
    if (this.memory.min && this.memory.max) {
      if (this.memory.isShared) memorySectionBytes.push(3);
      else memorySectionBytes.push(1);
      concat(memorySectionBytes, encodeUnsigned(this.memory.min));
      concat(memorySectionBytes, encodeUnsigned(this.memory.max));
    } else {
      assert(!this.memory.isShared, "shared memory must have a max size");
      memorySectionBytes.push(0);
      concat(memorySectionBytes, encodeUnsigned(this.memory.min));
    }
    emittedBytes.push(5);
    concat(emittedBytes, encodeUnsigned(memorySectionBytes.length));
    concat(emittedBytes, memorySectionBytes);
  }
  // Export section (always written, possibly with zero entries): the memory (kind 2, index 0)
  // under aString if exported, then each exported function (kind 0) by function index.
  const exportSectionBytes = [];
  const numExports = this.#exportedFunctions.size + (this.memory.isExport ? 1 : 0);
  concat(exportSectionBytes, encodeUnsigned(numExports));
  if (this.memory.isExport) {
    concat(exportSectionBytes, encodeString(this.memory.aString));
    exportSectionBytes.push(2);
    exportSectionBytes.push(0);
  }
  for (const [key, name] of this.#exportedFunctions.entries()) {
    concat(exportSectionBytes, encodeString(name));
    exportSectionBytes.push(0);
    concat(exportSectionBytes, encodeUnsigned(key));
  }
  emittedBytes.push(7);
  concat(emittedBytes, encodeUnsigned(exportSectionBytes.length));
  concat(emittedBytes, exportSectionBytes);
  // Code section: generate each body now, then prefix it with its local declarations and size.
  const codeSectionBytes = [];
  concat(codeSectionBytes, encodeUnsigned(this.#functions.length));
  for (const f of this.#functions) {
    // Fresh validation state; the outermost block frame yields the function's results.
    this.#typeStack = [];
    this.#blockFrames = [{
      idx: 0,
      ty: f.outputTypes
    }];
    this.#curFunction = f;
    this.#curBytes = [];
    f.emit();
    this.end();
    const bodyBytes = [...this.#curBytes];
    // Local declarations: one (count = 1, type) entry per declared local, not run-length grouped.
    this.#curBytes = [];
    concat(this.#curBytes, encodeUnsigned(f.locals.length));
    for (const l of f.locals) {
      this._emit(1);
      this._emit(l.typeId);
    }
    const headerBytes = [...this.#curBytes];
    const fnSize = headerBytes.length + bodyBytes.length;
    concat(codeSectionBytes, encodeUnsigned(fnSize));
    concat(codeSectionBytes, headerBytes);
    concat(codeSectionBytes, bodyBytes);
  }
  this.#curFunction = null;
  emittedBytes.push(10);
  concat(emittedBytes, encodeUnsigned(codeSectionBytes.length));
  concat(emittedBytes, codeSectionBytes);
  return new Uint8Array(emittedBytes);
}
|
|
2961
|
+
};
|
|
2962
|
+
/**
 * Emits `local.get` / `local.set` / `local.tee` for the function currently being generated,
 * type-checking operands against the declared parameter and local types.
 * Local indices cover the parameters first (0..inputs-1), then the declared locals.
 */
var Local = class {
  constructor(cg) {
    this.cg = cg;
  }
  /** Declare a new local of `type` in the current function; returns its index. */
  declare(type) {
    return this.cg._declareLocal(type);
  }
  /**
   * Declared type of local `idx`.
   * An out-of-range index is rejected with an explicit message. Previously it reached
   * `_push(undefined)` or a TypeError on `undefined.typeId`.
   */
  #typeOf(idx) {
    const inputTypes = this.cg._inputTypes();
    const locals = this.cg._locals();
    assert(idx >= 0 && idx < inputTypes.length + locals.length, `local index ${idx} out of range`);
    return idx < inputTypes.length ? inputTypes[idx] : locals[idx - inputTypes.length];
  }
  /** `local.get idx` (0x20): pushes the local's type. */
  get(idx) {
    assert(Number.isInteger(idx), "getting non-integer local");
    this.cg._push(this.#typeOf(idx));
    this.cg._emit(32);
    this.cg._emit(encodeUnsigned(idx));
  }
  /** `local.set idx` (0x21): pops a value whose type must match the local's. */
  set(idx) {
    assert(Number.isInteger(idx), "setting non-integer local");
    const t = this.cg._pop();
    const expectedType = this.#typeOf(idx);
    assert(expectedType.typeId === t.typeId, "can't set local to this value (wrong type)");
    this.cg._emit(33);
    this.cg._emit(encodeUnsigned(idx));
  }
  /** `local.tee idx` (0x22): like `set`, but the value stays on the stack. */
  tee(idx) {
    assert(Number.isInteger(idx), "teeing non-integer local");
    const t = this.cg._pop();
    const expectedType = this.#typeOf(idx);
    assert(expectedType.typeId === t.typeId, "can't tee local to this value (wrong type)");
    this.cg._emit(34);
    this.cg._emit(encodeUnsigned(idx));
    this.cg._push(expectedType);
  }
};
|
|
2995
|
+
/**
 * Build an instruction method for a one-operand op. The method checks that the top of the
 * type stack is `inType`, emits `opcode` (a byte or prefixed sequence), and pushes `outType`.
 * `inType`/`outType` name code-generator type properties such as "i32" or "f32".
 */
function UNARY_OP(op, opcode, inType, outType) {
  const mismatch = `invalid type for ${op} (${inType} -> ${outType})`;
  return function() {
    const { cg } = this;
    const operand = cg._pop();
    assert(operand.typeId === cg[inType].typeId, mismatch);
    cg._emit(encodeOpcode(opcode));
    cg._push(cg[outType]);
  };
}
|
|
3003
|
+
/**
 * Build an instruction method for a two-operand op `a OP b`. The method pops `b` (top)
 * then `a`, checks them against `typeA`/`typeB`, emits `opcode`, and pushes `outType`.
 */
function BINARY_OP(op, opcode, typeA, typeB, outType) {
  const mismatch = `invalid type for ${op} (${typeA}, ${typeB} -> ${outType})`;
  return function() {
    const { cg } = this;
    const rhs = cg._pop();
    const lhs = cg._pop();
    const typesMatch = lhs.typeId === cg[typeA].typeId && rhs.typeId === cg[typeB].typeId;
    assert(typesMatch, mismatch);
    cg._emit(encodeOpcode(opcode));
    cg._push(cg[outType]);
  };
}
|
|
3012
|
+
/**
 * Build a memory-load method: pops an i32 address, emits `opcode` followed by the
 * memarg immediates (`align`, then `offset`), and pushes `outType`.
 */
function LOAD_OP(op, opcode, outType) {
  return function(align = 0, offset = 0) {
    const { cg } = this;
    const address = cg._pop();
    assert(address.typeId === cg.i32.typeId, `invalid type for ${op}`);
    cg._emit(encodeOpcode(opcode));
    for (const immediate of [align, offset]) cg._emit(encodeUnsigned(immediate));
    cg._push(cg[outType]);
  };
}
|
|
3022
|
+
/**
 * Build a memory-store method: pops the value (must be `inType`) and then its i32 address,
 * and emits `opcode` followed by the memarg immediates (`align`, then `offset`).
 */
function STORE_OP(op, opcode, inType) {
  return function(align = 0, offset = 0) {
    const { cg } = this;
    const value = cg._pop();
    const address = cg._pop();
    assert(value.typeId === cg[inType].typeId, `invalid value type for ${op} (${inType})`);
    assert(address.typeId === cg.i32.typeId, `invalid type for ${op}`);
    cg._emit(encodeOpcode(opcode));
    for (const immediate of [align, offset]) cg._emit(encodeUnsigned(immediate));
  };
}
|
|
3033
|
+
/** Builder for 32-bit integer (`i32`) instructions; instances also serve as the `i32` value type (id 0x7f). */
var I32 = class {
  constructor(cg) {
    this.cg = cg;
  }
  get typeId() {
    return 127;
  }
  get name() {
    return "i32";
  }
  /**
   * Emit `i32.const i` (0x41) and push i32.
   * The immediate is a signed 32-bit value, so `i | 0` maps uint32 bit patterns
   * (e.g. 0xffffffff) to their signed equivalent (-1) before encoding.
   */
  const(i) {
    this.cg._emit(65);
    this.cg._emit(encodeSigned(i | 0));
    this.cg._push(this);
  }
  // Bit counting.
  clz = UNARY_OP("clz", 103, "i32", "i32");
  ctz = UNARY_OP("ctz", 104, "i32", "i32");
  popcnt = UNARY_OP("popcnt", 105, "i32", "i32");
  // Comparisons (result is i32 0/1).
  lt_s = BINARY_OP("lt_s", 72, "i32", "i32", "i32");
  lt_u = BINARY_OP("lt_u", 73, "i32", "i32", "i32");
  gt_s = BINARY_OP("gt_s", 74, "i32", "i32", "i32");
  gt_u = BINARY_OP("gt_u", 75, "i32", "i32", "i32");
  le_s = BINARY_OP("le_s", 76, "i32", "i32", "i32");
  le_u = BINARY_OP("le_u", 77, "i32", "i32", "i32");
  ge_s = BINARY_OP("ge_s", 78, "i32", "i32", "i32");
  ge_u = BINARY_OP("ge_u", 79, "i32", "i32", "i32");
  // Arithmetic and bitwise ops.
  add = BINARY_OP("add", 106, "i32", "i32", "i32");
  sub = BINARY_OP("sub", 107, "i32", "i32", "i32");
  mul = BINARY_OP("mul", 108, "i32", "i32", "i32");
  div_s = BINARY_OP("div_s", 109, "i32", "i32", "i32");
  div_u = BINARY_OP("div_u", 110, "i32", "i32", "i32");
  rem_s = BINARY_OP("rem_s", 111, "i32", "i32", "i32");
  rem_u = BINARY_OP("rem_u", 112, "i32", "i32", "i32");
  and = BINARY_OP("and", 113, "i32", "i32", "i32");
  or = BINARY_OP("or", 114, "i32", "i32", "i32");
  xor = BINARY_OP("xor", 115, "i32", "i32", "i32");
  shl = BINARY_OP("shl", 116, "i32", "i32", "i32");
  shr_s = BINARY_OP("shr_s", 117, "i32", "i32", "i32");
  shr_u = BINARY_OP("shr_u", 118, "i32", "i32", "i32");
  rotl = BINARY_OP("rotl", 119, "i32", "i32", "i32");
  rotr = BINARY_OP("rotr", 120, "i32", "i32", "i32");
  // `eqz` (x == 0) consumes ONE operand. It was declared with BINARY_OP, which popped two
  // types while the instruction only pops one, desynchronising the validation stack.
  eqz = UNARY_OP("eqz", 69, "i32", "i32");
  eq = BINARY_OP("eq", 70, "i32", "i32", "i32");
  ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
  // Conversions from f32 (trapping).
  trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
  trunc_f32_u = UNARY_OP("trunc_f32_u", 169, "f32", "i32");
  // Memory access.
  load = LOAD_OP("load", 40, "i32");
  load8_s = LOAD_OP("load8_s", 44, "i32");
  load8_u = LOAD_OP("load8_u", 45, "i32");
  load16_s = LOAD_OP("load16_s", 46, "i32");
  load16_u = LOAD_OP("load16_u", 47, "i32");
  store = STORE_OP("store", 54, "i32");
  store8 = STORE_OP("store8", 58, "i32");
  store16 = STORE_OP("store16", 59, "i32");
  // Bit reinterpretation and saturating (non-trapping) truncation, 0xfc-prefixed.
  reinterpret_f32 = UNARY_OP("reinterpret_f32", 188, "f32", "i32");
  trunc_sat_f32_s = UNARY_OP("trunc_sat_f32_s", [252, 0], "f32", "i32");
  trunc_sat_f32_u = UNARY_OP("trunc_sat_f32_u", [252, 1], "f32", "i32");
};
|
|
3091
|
+
/** Builder for 32-bit float (`f32`) instructions; instances also serve as the `f32` value type (id 0x7d). */
var F32 = class {
  constructor(cg) {
    this.cg = cg;
  }
  get typeId() {
    return 0x7d;
  }
  get name() {
    return "f32";
  }
  /** Emit `f32.const f` (0x43): the immediate is the IEEE-754 single-precision bits, little-endian. */
  const(f) {
    this.cg._emit(0x43);
    const view = new DataView(new ArrayBuffer(4));
    view.setFloat32(0, f, true);
    for (let byteIndex = 0; byteIndex < 4; byteIndex++) this.cg._emit(view.getUint8(byteIndex));
    this.cg._push(this);
  }
  // Comparisons produce an i32 0/1.
  eq = BINARY_OP("eq", 0x5b, "f32", "f32", "i32");
  ne = BINARY_OP("ne", 0x5c, "f32", "f32", "i32");
  lt = BINARY_OP("lt", 0x5d, "f32", "f32", "i32");
  gt = BINARY_OP("gt", 0x5e, "f32", "f32", "i32");
  le = BINARY_OP("le", 0x5f, "f32", "f32", "i32");
  ge = BINARY_OP("ge", 0x60, "f32", "f32", "i32");
  // Unary math.
  abs = UNARY_OP("abs", 0x8b, "f32", "f32");
  neg = UNARY_OP("neg", 0x8c, "f32", "f32");
  ceil = UNARY_OP("ceil", 0x8d, "f32", "f32");
  floor = UNARY_OP("floor", 0x8e, "f32", "f32");
  trunc = UNARY_OP("trunc", 0x8f, "f32", "f32");
  nearest = UNARY_OP("nearest", 0x90, "f32", "f32");
  sqrt = UNARY_OP("sqrt", 0x91, "f32", "f32");
  // Binary math.
  add = BINARY_OP("add", 0x92, "f32", "f32", "f32");
  sub = BINARY_OP("sub", 0x93, "f32", "f32", "f32");
  mul = BINARY_OP("mul", 0x94, "f32", "f32", "f32");
  div = BINARY_OP("div", 0x95, "f32", "f32", "f32");
  min = BINARY_OP("min", 0x96, "f32", "f32", "f32");
  max = BINARY_OP("max", 0x97, "f32", "f32", "f32");
  copysign = BINARY_OP("copysign", 0x98, "f32", "f32", "f32");
  // Conversions from i32.
  convert_i32_s = UNARY_OP("convert_i32_s", 0xb2, "i32", "f32");
  convert_i32_u = UNARY_OP("convert_i32_u", 0xb3, "i32", "f32");
  // Memory access.
  load = LOAD_OP("load", 0x2a, "f32");
  store = STORE_OP("store", 0x38, "f32");
  // Bit reinterpretation of an i32.
  reinterpret_i32 = UNARY_OP("reinterpret_i32", 0xbe, "i32", "f32");
};
|
|
3135
|
+
/**
 * Build a SIMD instruction method: pops one operand per entry of `inTypes` (last entry on
 * top), checks each type, emits the 0xfd-prefixed `vopcode`, and pushes `outType`.
 */
function VECTOR_OP(op, vopcode, inTypes, outType) {
  const mismatch = `invalid type for ${op} (${inTypes.join(", ")} -> ${outType})`;
  return function() {
    const { cg } = this;
    for (let k = inTypes.length - 1; k >= 0; k--) {
      const actual = cg._pop();
      assert(actual.typeId === cg[inTypes[k]].typeId, mismatch);
    }
    cg._emit(encodeOpcode([253, vopcode]));
    cg._push(cg[outType]);
  };
}
|
|
3145
|
+
function VECTOR_OPL(op, vopcode, inTypes, outType) {
|
|
3146
|
+
return function(lane) {
|
|
3147
|
+
for (const inType of inTypes.toReversed()) {
|
|
3148
|
+
const actualType = this.cg._pop();
|
|
3149
|
+
assert(actualType.typeId === this.cg[inType].typeId, `invalid type for ${op} (${inTypes} -> ${outType})`);
|
|
3150
|
+
}
|
|
3151
|
+
this.cg._emit(encodeOpcode([253, vopcode]));
|
|
3152
|
+
this.cg._emit(lane);
|
|
3153
|
+
this.cg._push(this.cg[outType]);
|
|
3154
|
+
};
|
|
3155
|
+
}
|
|
3156
|
+
/**
 * Build a SIMD load method: pops an i32 address, emits the 0xfd-prefixed `vopcode`
 * followed by the memarg immediates (`align`, then `offset`), and pushes v128.
 */
function VECTOR_LOAD_OP(op, vopcode) {
  return function(align = 0, offset = 0) {
    const { cg } = this;
    const address = cg._pop();
    assert(address.typeId === cg.i32.typeId, `invalid type for ${op}`);
    cg._emit(encodeOpcode([253, vopcode]));
    for (const immediate of [align, offset]) cg._emit(encodeUnsigned(immediate));
    cg._push(cg.v128);
  };
}
|
|
3166
|
+
/** Builder for shape-agnostic 128-bit SIMD (`v128`) instructions; instances also serve as the `v128` value type (id 0x7b). */
var V128 = class {
  constructor(cg) {
    this.cg = cg;
  }
  get typeId() {
    return 0x7b;
  }
  get name() {
    return "v128";
  }
  // Loads (0xfd-prefixed): pop an i32 address, push a v128.
  load = VECTOR_LOAD_OP("load", 0x00);
  load32x2_s = VECTOR_LOAD_OP("load32x2_s", 0x05);
  load32x2_u = VECTOR_LOAD_OP("load32x2_u", 0x06);
  load32_splat = VECTOR_LOAD_OP("load32_splat", 0x09);
  load32_zero = VECTOR_LOAD_OP("load32_zero", 0x5c);
  /** `v128.store`: pops the v128 value, then its i32 address; emits 0xfd 0x0b plus the memarg. */
  store(align = 0, offset = 0) {
    const { cg } = this;
    const value = cg._pop();
    assert(value.typeId === cg.v128.typeId, `invalid type for store`);
    const address = cg._pop();
    assert(address.typeId === cg.i32.typeId, `invalid type for store`);
    cg._emit(0xfd);
    for (const immediate of [0x0b, align, offset]) cg._emit(encodeUnsigned(immediate));
  }
  // Bitwise ops over the whole 128 bits.
  not = VECTOR_OP("not", 0x4d, ["v128"], "v128");
  and = VECTOR_OP("and", 0x4e, ["v128", "v128"], "v128");
  andnot = VECTOR_OP("andnot", 0x4f, ["v128", "v128"], "v128");
  or = VECTOR_OP("or", 0x50, ["v128", "v128"], "v128");
  xor = VECTOR_OP("xor", 0x51, ["v128", "v128"], "v128");
  bitselect = VECTOR_OP("bitselect", 0x52, ["v128", "v128", "v128"], "v128");
  any_true = VECTOR_OP("any_true", 0x53, ["v128"], "i32");
};
|
|
3203
|
+
/** `i32x4` SIMD instructions: a v128 viewed as four 32-bit integer lanes. */
var I32x4 = class extends V128 {
  // Lane construction and access.
  splat = VECTOR_OP("splat", 0x11, ["i32"], "v128");
  extract_lane = VECTOR_OPL("extract_lane", 0x1b, ["v128"], "i32");
  replace_lane = VECTOR_OPL("replace_lane", 0x1c, ["v128", "i32"], "v128");
  // Lane-wise comparisons.
  eq = VECTOR_OP("eq", 0x37, ["v128", "v128"], "v128");
  ne = VECTOR_OP("ne", 0x38, ["v128", "v128"], "v128");
  lt_s = VECTOR_OP("lt_s", 0x39, ["v128", "v128"], "v128");
  lt_u = VECTOR_OP("lt_u", 0x3a, ["v128", "v128"], "v128");
  gt_s = VECTOR_OP("gt_s", 0x3b, ["v128", "v128"], "v128");
  gt_u = VECTOR_OP("gt_u", 0x3c, ["v128", "v128"], "v128");
  le_s = VECTOR_OP("le_s", 0x3d, ["v128", "v128"], "v128");
  le_u = VECTOR_OP("le_u", 0x3e, ["v128", "v128"], "v128");
  ge_s = VECTOR_OP("ge_s", 0x3f, ["v128", "v128"], "v128");
  ge_u = VECTOR_OP("ge_u", 0x40, ["v128", "v128"], "v128");
  // Unary ops and lane reductions to a scalar i32.
  abs = VECTOR_OP("abs", 0xa0, ["v128"], "v128");
  neg = VECTOR_OP("neg", 0xa1, ["v128"], "v128");
  all_true = VECTOR_OP("all_true", 0xa3, ["v128"], "i32");
  bitmask = VECTOR_OP("bitmask", 0xa4, ["v128"], "i32");
  // Shifts by a scalar i32 amount.
  shl = VECTOR_OP("shl", 0xab, ["v128", "i32"], "v128");
  shr_s = VECTOR_OP("shr_s", 0xac, ["v128", "i32"], "v128");
  shr_u = VECTOR_OP("shr_u", 0xad, ["v128", "i32"], "v128");
  // Lane-wise arithmetic.
  add = VECTOR_OP("add", 0xae, ["v128", "v128"], "v128");
  sub = VECTOR_OP("sub", 0xb1, ["v128", "v128"], "v128");
  mul = VECTOR_OP("mul", 0xb5, ["v128", "v128"], "v128");
  min_s = VECTOR_OP("min_s", 0xb6, ["v128", "v128"], "v128");
  min_u = VECTOR_OP("min_u", 0xb7, ["v128", "v128"], "v128");
  max_s = VECTOR_OP("max_s", 0xb8, ["v128", "v128"], "v128");
  max_u = VECTOR_OP("max_u", 0xb9, ["v128", "v128"], "v128");
};
|
|
3232
|
+
/** `f32x4` SIMD instructions: a v128 viewed as four 32-bit float lanes. */
var F32x4 = class extends V128 {
  // Lane construction and access.
  splat = VECTOR_OP("splat", 0x13, ["f32"], "v128");
  extract_lane = VECTOR_OPL("extract_lane", 0x1f, ["v128"], "f32");
  replace_lane = VECTOR_OPL("replace_lane", 0x20, ["v128", "f32"], "v128");
  // Lane-wise comparisons.
  eq = VECTOR_OP("eq", 0x41, ["v128", "v128"], "v128");
  ne = VECTOR_OP("ne", 0x42, ["v128", "v128"], "v128");
  lt = VECTOR_OP("lt", 0x43, ["v128", "v128"], "v128");
  gt = VECTOR_OP("gt", 0x44, ["v128", "v128"], "v128");
  le = VECTOR_OP("le", 0x45, ["v128", "v128"], "v128");
  ge = VECTOR_OP("ge", 0x46, ["v128", "v128"], "v128");
  // Rounding.
  ceil = VECTOR_OP("ceil", 0x67, ["v128"], "v128");
  floor = VECTOR_OP("floor", 0x68, ["v128"], "v128");
  trunc = VECTOR_OP("trunc", 0x69, ["v128"], "v128");
  nearest = VECTOR_OP("nearest", 0x6a, ["v128"], "v128");
  // Unary math.
  abs = VECTOR_OP("abs", 0xe0, ["v128"], "v128");
  neg = VECTOR_OP("neg", 0xe1, ["v128"], "v128");
  sqrt = VECTOR_OP("sqrt", 0xe3, ["v128"], "v128");
  // Binary math; pmin/pmax are the "pseudo" min/max variants.
  add = VECTOR_OP("add", 0xe4, ["v128", "v128"], "v128");
  sub = VECTOR_OP("sub", 0xe5, ["v128", "v128"], "v128");
  mul = VECTOR_OP("mul", 0xe6, ["v128", "v128"], "v128");
  div = VECTOR_OP("div", 0xe7, ["v128", "v128"], "v128");
  min = VECTOR_OP("min", 0xe8, ["v128", "v128"], "v128");
  max = VECTOR_OP("max", 0xe9, ["v128", "v128"], "v128");
  pmin = VECTOR_OP("pmin", 0xea, ["v128", "v128"], "v128");
  pmax = VECTOR_OP("pmax", 0xeb, ["v128", "v128"], "v128");
};
|
|
3258
|
+
|
|
3259
|
+
//#endregion
|
|
3260
|
+
//#region src/backend/wasm.ts
|
|
3261
|
+
/** Backend that compiles into WebAssembly bytecode for immediate execution. */
var WasmBackend = class {
  type = "wasm";
  maxArgs = 64;
  #memory;
  #nextSlot;
  #allocator;
  /** slot -> { ptr, size, ref }: an allocation in #memory plus its reference count. */
  #buffers;
  constructor() {
    this.#memory = new WebAssembly.Memory({ initial: 0 });
    this.#allocator = new WasmAllocator(this.#memory);
    this.#nextSlot = 1;
    this.#buffers = /* @__PURE__ */ new Map();
  }
  /**
   * Allocate `size` bytes in the backend memory and return a new slot with refcount 1.
   * If `initialData` is given, its bytes are copied in. A size mismatch is rejected
   * before allocating, so a failed call no longer leaks the allocation.
   */
  malloc(size, initialData) {
    if (initialData && initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
    const ptr = this.#allocator.malloc(size);
    if (initialData) new Uint8Array(this.#memory.buffer, ptr, size).set(initialData);
    const slot = this.#nextSlot++;
    this.#buffers.set(slot, {
      ptr,
      size,
      ref: 1
    });
    return slot;
  }
  /** Add a reference to `slot`. @throws {SlotError} if the slot does not exist. */
  incRef(slot) {
    this.#lookup(slot).ref++;
  }
  /** Drop a reference to `slot`; at zero the memory is freed and the slot forgotten. */
  decRef(slot) {
    const buffer = this.#lookup(slot);
    buffer.ref--;
    if (buffer.ref === 0) {
      this.#allocator.free(buffer.ptr);
      this.#buffers.delete(slot);
    }
  }
  /** Async wrapper around readSync (reading wasm memory is synchronous). */
  async read(slot, start, count) {
    return this.readSync(slot, start, count);
  }
  /** Copy `count` bytes starting at byte `start` out of `slot` (defaults: the whole buffer). */
  readSync(slot, start, count) {
    const buffer = this.#getBuffer(slot);
    if (start === void 0) start = 0;
    if (count === void 0) count = buffer.byteLength - start;
    return buffer.slice(start, start + count);
  }
  async prepare(kernel) {
    return this.prepareSync(kernel);
  }
  /** Compile `kernel` to a WebAssembly.Module wrapped in an Executable. */
  prepareSync(kernel) {
    const bytes = codegenWasm(kernel);
    const module$1 = new WebAssembly.Module(bytes);
    return new Executable(kernel, { module: module$1 });
  }
  /**
   * Run a prepared kernel: instantiate its module against the shared memory and call the
   * exported `kernel` with the input pointers followed by the output pointers.
   * Slots are resolved first, so an unknown slot raises SlotError (instead of a TypeError
   * on `undefined.ptr`) before any instance is created.
   */
  dispatch(exe, inputs, outputs) {
    const ptrs = [...inputs, ...outputs].map((slot) => this.#lookup(slot).ptr);
    const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
    const func = instance.exports.kernel;
    func(...ptrs);
  }
  /** Bookkeeping record for `slot`. @throws {SlotError} if the slot does not exist. */
  #lookup(slot) {
    const buffer = this.#buffers.get(slot);
    if (!buffer) throw new SlotError(slot);
    return buffer;
  }
  /** Uint8Array view of `slot`'s bytes over the current memory buffer (re-read after any growth). */
  #getBuffer(slot) {
    const { ptr, size } = this.#lookup(slot);
    return new Uint8Array(this.#memory.buffer, ptr, size);
  }
};
|
|
3332
|
+
/**
 * Compile `kernel` into WebAssembly module bytes (the result of `cg.finish()`).
 *
 * The module imports its memory as `env.memory`. It exports one function, `kernel`, which
 * takes `kernel.nargs + 1` i32 parameters and returns nothing. Parameter index
 * `kernel.nargs` (the last one) is the output base address. The first `kernel.nargs` are
 * presumably input pointers read by translateExp — TODO confirm.
 *
 * The generated code loops gidx over [0, kernel.size). For each gidx it pushes the output
 * address `out + gidx * byteWidth(kernel.dtype)`, then evaluates the kernel expression.
 * When the kernel has a reduction, the expression is folded into an accumulator over
 * ridx in [0, re.size) and the epilogue is applied. The value is stored at that address.
 */
function codegenWasm(kernel) {
  const tune = tuneNullopt(kernel);
  const re = kernel.reduction;
  if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
  const cg = new CodeGenerator();
  cg.memory.import("env", "memory");
  // Only generate helper functions for the transcendental / RNG ops the kernel actually uses.
  const distinctOps = union(tune.exp.distinctOps(), re?.epilogue.distinctOps());
  const funcs = {};
  if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
  if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
  if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
  if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
  if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
  const kernelFunc = cg.function(rep(kernel.nargs + 1, cg.i32), [], () => {
    // gidx is never explicitly initialized: wasm locals start at zero.
    const gidx = cg.local.declare(cg.i32);
    // Outer loop: br_if 0 leaves the block once gidx >= kernel.size; the trailing br 1
    // jumps back to the loop header.
    cg.loop(cg.void);
    cg.block(cg.void);
    cg.local.get(gidx);
    cg.i32.const(kernel.size);
    cg.i32.ge_u();
    cg.br_if(0);
    // Store address (left on the stack for the store below): out + gidx * byteWidth(dtype).
    cg.local.get(kernel.nargs);
    cg.local.get(gidx);
    cg.i32.const(byteWidth(kernel.dtype));
    cg.i32.mul();
    cg.i32.add();
    if (re) {
      // Accumulator starts at the reduction identity.
      const acc = cg.local.declare(dty(cg, null, kernel.exp.dtype));
      dty(cg, null, kernel.exp.dtype).const(re.identity);
      cg.local.set(acc);
      const ridx = cg.local.declare(cg.i32);
      cg.i32.const(0);
      cg.local.set(ridx);
      // Inner loop over ridx in [0, re.size), same loop/block/br_if shape as the outer loop.
      cg.loop(cg.void);
      cg.block(cg.void);
      cg.local.get(ridx);
      cg.i32.const(re.size);
      cg.i32.ge_u();
      cg.br_if(0);
      translateExp(cg, funcs, tune.exp, {
        gidx,
        ridx
      });
      // Combine the value just computed with acc. For Bool, add/mul become or/and.
      // Integer min/max has no wasm instruction: the value is stashed in a temp local,
      // compared against acc, and `select` keeps the value (cond true) or acc (cond false).
      if (re.op === AluOp.Add) {
        cg.local.get(acc);
        if (re.dtype === DType.Bool) cg.i32.or();
        else dty(cg, re.op, re.dtype).add();
      } else if (re.op === AluOp.Mul) {
        cg.local.get(acc);
        if (re.dtype === DType.Bool) cg.i32.and();
        else dty(cg, re.op, re.dtype).mul();
      } else if (re.op === AluOp.Min || re.op === AluOp.Max) if (re.dtype === DType.Float32) {
        cg.local.get(acc);
        if (re.op === AluOp.Min) cg.f32.min();
        else cg.f32.max();
      } else if ([
        DType.Int32,
        DType.Uint32,
        DType.Bool
      ].includes(re.dtype)) {
        const local = cg.local.declare(cg.i32);
        cg.local.tee(local);
        cg.local.get(acc);
        cg.local.get(local);
        cg.local.get(acc);
        // Signed compare for Int32, unsigned for Uint32/Bool.
        if (re.op === AluOp.Min) if (re.dtype === DType.Int32) cg.i32.lt_s();
        else cg.i32.lt_u();
        else if (re.dtype === DType.Int32) cg.i32.gt_s();
        else cg.i32.gt_u();
        cg.select();
      } else throw new Error(`invalid reduction min/max over ${re.dtype}`);
      else throw new Error(`invalid wasm reduction op: ${re.op}`);
      cg.local.set(acc);
      cg.local.get(ridx);
      cg.i32.const(1);
      cg.i32.add();
      cg.local.set(ridx);
      cg.br(1);
      cg.end();
      cg.end();
      // The epilogue maps the final accumulator to the value that gets stored.
      translateExp(cg, funcs, kernel.reduction.epilogue, { acc });
    } else translateExp(cg, funcs, tune.exp, { gidx });
    // Store with the alignment hint log2(byteWidth) at the address pushed above.
    dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
    cg.local.get(gidx);
    cg.i32.const(1);
    cg.i32.add();
    cg.local.set(gidx);
    cg.br(1);
    cg.end();
    cg.end();
  });
  cg.export(kernelFunc, "kernel");
  return cg.finish();
}
|
|
3426
|
+
function translateExp(cg, funcs, exp, ctx) {
|
|
3427
|
+
const references = /* @__PURE__ */ new Map();
|
|
3428
|
+
const seen = /* @__PURE__ */ new Set();
|
|
3429
|
+
const countReferences = (exp$1) => {
|
|
3430
|
+
references.set(exp$1, (references.get(exp$1) ?? 0) + 1);
|
|
3431
|
+
if (!seen.has(exp$1)) {
|
|
3432
|
+
seen.add(exp$1);
|
|
3433
|
+
for (const src of exp$1.src) countReferences(src);
|
|
3434
|
+
}
|
|
3435
|
+
};
|
|
3436
|
+
const expContext = /* @__PURE__ */ new Map();
|
|
3437
|
+
const gen = (exp$1) => {
|
|
3438
|
+
if (expContext.has(exp$1)) return cg.local.get(expContext.get(exp$1));
|
|
3439
|
+
const { op, src, dtype, arg } = exp$1;
|
|
3440
|
+
if (AluGroup.Binary.has(op) || AluGroup.Compare.has(op)) {
|
|
3441
|
+
gen(src[0]);
|
|
3442
|
+
gen(src[1]);
|
|
3443
|
+
if (op === AluOp.Add) if (dtype === DType.Bool) cg.i32.or();
|
|
3444
|
+
else dty(cg, op, dtype).add();
|
|
3445
|
+
else if (op === AluOp.Sub) dty(cg, op, dtype).sub();
|
|
3446
|
+
else if (op === AluOp.Mul) if (dtype === DType.Bool) cg.i32.and();
|
|
3447
|
+
else dty(cg, op, dtype).mul();
|
|
3448
|
+
else if (op === AluOp.Idiv) if (dtype === DType.Float32) cg.f32.div(), cg.f32.trunc();
|
|
3449
|
+
else if (dtype === DType.Uint32) cg.i32.div_u();
|
|
3450
|
+
else if (dtype === DType.Int32) cg.i32.div_s();
|
|
3451
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3452
|
+
else if (op === AluOp.Mod) if (dtype === DType.Float32) {
|
|
3453
|
+
const a = cg.local.declare(cg.f32);
|
|
3454
|
+
const b = cg.local.declare(cg.f32);
|
|
3455
|
+
cg.local.set(b);
|
|
3456
|
+
cg.local.tee(a);
|
|
3457
|
+
cg.local.get(a);
|
|
3458
|
+
cg.local.get(b);
|
|
3459
|
+
cg.f32.div();
|
|
3460
|
+
cg.f32.trunc();
|
|
3461
|
+
cg.local.get(b);
|
|
3462
|
+
cg.f32.mul();
|
|
3463
|
+
cg.f32.sub();
|
|
3464
|
+
} else if (dtype === DType.Uint32) cg.i32.rem_u();
|
|
3465
|
+
else if (dtype === DType.Int32) cg.i32.rem_s();
|
|
3466
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3467
|
+
else if (op === AluOp.Min || op === AluOp.Max) if (dtype === DType.Float32) if (op === AluOp.Min) cg.f32.min();
|
|
3468
|
+
else cg.f32.max();
|
|
3469
|
+
else if (dtype === DType.Int32 || dtype === DType.Uint32) {
|
|
3470
|
+
const a = cg.local.declare(cg.i32);
|
|
3471
|
+
const b = cg.local.declare(cg.i32);
|
|
3472
|
+
cg.local.set(b);
|
|
3473
|
+
cg.local.tee(a);
|
|
3474
|
+
cg.local.get(b);
|
|
3475
|
+
cg.local.get(a);
|
|
3476
|
+
cg.local.get(b);
|
|
3477
|
+
if (dtype === DType.Int32) if (op === AluOp.Min) cg.i32.lt_s();
|
|
3478
|
+
else cg.i32.gt_s();
|
|
3479
|
+
else if (op === AluOp.Min) cg.i32.lt_u();
|
|
3480
|
+
else cg.i32.gt_u();
|
|
3481
|
+
cg.select();
|
|
3482
|
+
} else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3483
|
+
else if (op === AluOp.Cmplt) {
|
|
3484
|
+
const srcDtype = src[0].dtype;
|
|
3485
|
+
if (srcDtype === DType.Float32) cg.f32.lt();
|
|
3486
|
+
else if (srcDtype === DType.Int32) cg.i32.lt_s();
|
|
3487
|
+
else if (srcDtype === DType.Uint32) cg.i32.lt_u();
|
|
3488
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3489
|
+
} else if (op === AluOp.Cmpne) dty(cg, op, src[0].dtype).ne();
|
|
3490
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3491
|
+
} else if (AluGroup.Unary.has(op)) if (op === AluOp.Sin) gen(src[0]), cg.call(funcs.sin);
|
|
3492
|
+
else if (op === AluOp.Cos) gen(src[0]), cg.call(funcs.cos);
|
|
3493
|
+
else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
|
|
3494
|
+
else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
|
|
3495
|
+
else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
|
|
3496
|
+
else if (op === AluOp.Reciprocal) cg.f32.const(1), gen(src[0]), cg.f32.div();
|
|
3497
|
+
else if (op === AluOp.Cast) {
|
|
3498
|
+
gen(src[0]);
|
|
3499
|
+
const dtype0 = src[0].dtype;
|
|
3500
|
+
const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
|
|
3501
|
+
if (dtype === DType.Int32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_s();
|
|
3502
|
+
else if (i32repr);
|
|
3503
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3504
|
+
else if (dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_u();
|
|
3505
|
+
else if (i32repr);
|
|
3506
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3507
|
+
else if (dtype === DType.Float32) if (dtype0 === DType.Float32);
|
|
3508
|
+
else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32.convert_i32_s();
|
|
3509
|
+
else if (dtype0 === DType.Uint32) cg.f32.convert_i32_u();
|
|
3510
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3511
|
+
else if (dtype === DType.Bool) if (dtype0 === DType.Bool);
|
|
3512
|
+
else if (i32repr) cg.i32.const(0), cg.i32.ne();
|
|
3513
|
+
else if (dtype0 === DType.Float32) cg.f32.const(0), cg.f32.ne();
|
|
3514
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3515
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3516
|
+
} else if (op === AluOp.Bitcast) {
|
|
3517
|
+
gen(src[0]);
|
|
3518
|
+
const dtype0 = src[0].dtype;
|
|
3519
|
+
const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32;
|
|
3520
|
+
if (dtype === DType.Int32 || dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.reinterpret_f32();
|
|
3521
|
+
else if (i32repr);
|
|
3522
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3523
|
+
else if (dtype === DType.Float32) if (i32repr) cg.f32.reinterpret_i32();
|
|
3524
|
+
else if (dtype0 === DType.Float32);
|
|
3525
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3526
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3527
|
+
} else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3528
|
+
else if (op === AluOp.Where) {
|
|
3529
|
+
gen(src[1]);
|
|
3530
|
+
gen(src[2]);
|
|
3531
|
+
gen(src[0]);
|
|
3532
|
+
cg.select();
|
|
3533
|
+
} else if (op === AluOp.Threefry2x32) {
|
|
3534
|
+
for (let i = 0; i < 4; i++) gen(src[i]);
|
|
3535
|
+
cg.call(funcs.threefry2x32);
|
|
3536
|
+
if (arg === "xor") cg.i32.xor();
|
|
3537
|
+
else if (arg === 0) cg.drop();
|
|
3538
|
+
else if (arg === 1) {
|
|
3539
|
+
const local = cg.local.declare(cg.i32);
|
|
3540
|
+
cg.local.set(local);
|
|
3541
|
+
cg.drop();
|
|
3542
|
+
cg.local.get(local);
|
|
3543
|
+
} else throw new UnsupportedOpError(op, dtype, "wasm", arg);
|
|
3544
|
+
} else if (op === AluOp.Const) return dty(cg, op, dtype).const(arg);
|
|
3545
|
+
else if (op === AluOp.Special) return cg.local.get(ctx[arg[0]]);
|
|
3546
|
+
else if (op === AluOp.Variable) return cg.local.get(ctx[arg]);
|
|
3547
|
+
else if (op === AluOp.GlobalIndex) {
|
|
3548
|
+
const [gid, len] = arg;
|
|
3549
|
+
gen(src[0]);
|
|
3550
|
+
const local = cg.local.declare(cg.i32);
|
|
3551
|
+
cg.local.tee(local);
|
|
3552
|
+
cg.i32.const(0);
|
|
3553
|
+
cg.local.get(local), cg.i32.const(len), cg.i32.lt_u();
|
|
3554
|
+
cg.select();
|
|
3555
|
+
cg.i32.const(byteWidth(dtype));
|
|
3556
|
+
cg.i32.mul();
|
|
3557
|
+
cg.local.get(gid);
|
|
3558
|
+
cg.i32.add();
|
|
3559
|
+
dty(cg, op, dtype).load(Math.log2(byteWidth(dtype)));
|
|
3560
|
+
} else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3561
|
+
if ((references.get(exp$1) ?? 0) > 1) {
|
|
3562
|
+
const local = cg.local.declare(dty(cg, op, dtype));
|
|
3563
|
+
cg.local.tee(local);
|
|
3564
|
+
expContext.set(exp$1, local);
|
|
3565
|
+
}
|
|
3566
|
+
};
|
|
3567
|
+
countReferences(exp);
|
|
3568
|
+
gen(exp);
|
|
3569
|
+
}
|
|
3570
|
+
function dty(cg, op, dtype) {
|
|
3571
|
+
switch (dtype) {
|
|
3572
|
+
case DType.Float32: return cg.f32;
|
|
3573
|
+
case DType.Int32:
|
|
3574
|
+
case DType.Uint32:
|
|
3575
|
+
case DType.Bool: return cg.i32;
|
|
3576
|
+
default: throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3577
|
+
}
|
|
3578
|
+
}
|
|
3579
|
+
|
|
1814
3580
|
//#endregion
|
|
1815
3581
|
//#region src/backend.ts
|
|
1816
|
-
const devices = [
|
|
1817
|
-
|
|
3582
|
+
const devices = [
|
|
3583
|
+
"cpu",
|
|
3584
|
+
"wasm",
|
|
3585
|
+
"webgpu"
|
|
3586
|
+
];
|
|
3587
|
+
let defaultBackend = "wasm";
|
|
1818
3588
|
const initializedBackends = /* @__PURE__ */ new Map();
|
|
1819
|
-
initializedBackends.set("cpu", new
|
|
3589
|
+
initializedBackends.set("cpu", new CpuBackend());
|
|
3590
|
+
initializedBackends.set("wasm", new WasmBackend());
|
|
1820
3591
|
/** Set the default device backend (must be initialized). */
|
|
1821
3592
|
function setDevice(device) {
|
|
1822
3593
|
if (initializedBackends.has(device)) defaultBackend = device;
|
|
@@ -1841,12 +3612,13 @@ async function init(...devicesToInit) {
|
|
|
1841
3612
|
}
|
|
1842
3613
|
/** Create a backend, if available. Internal function called by `init()`. */
|
|
1843
3614
|
async function createBackend(device) {
|
|
1844
|
-
if (device === "cpu") return new
|
|
3615
|
+
if (device === "cpu") return new CpuBackend();
|
|
3616
|
+
else if (device === "wasm") return new WasmBackend();
|
|
1845
3617
|
else if (device === "webgpu") {
|
|
1846
3618
|
if (!navigator.gpu) return null;
|
|
1847
3619
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
1848
3620
|
if (!adapter) return null;
|
|
1849
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
3621
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-fqhx41TC.cjs"));
|
|
1850
3622
|
const importantLimits = [
|
|
1851
3623
|
"maxBufferSize",
|
|
1852
3624
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -1859,8 +3631,12 @@ async function createBackend(device) {
|
|
|
1859
3631
|
"maxStorageBuffersPerShaderStage",
|
|
1860
3632
|
"maxStorageTexturesPerShaderStage"
|
|
1861
3633
|
];
|
|
3634
|
+
const requestedFeatures = ["shader-f16", "timestamp-query"];
|
|
1862
3635
|
try {
|
|
1863
|
-
const device$1 = await adapter.requestDevice({
|
|
3636
|
+
const device$1 = await adapter.requestDevice({
|
|
3637
|
+
requiredLimits: Object.fromEntries(importantLimits.map((limit) => [limit, adapter.limits[limit]])),
|
|
3638
|
+
requiredFeatures: requestedFeatures.filter((feature) => adapter.features.has(feature))
|
|
3639
|
+
});
|
|
1864
3640
|
return new WebGPUBackend(device$1);
|
|
1865
3641
|
} catch (error) {
|
|
1866
3642
|
console.error("Unexpected error requesting WebGPU device:", error);
|
|
@@ -1886,6 +3662,13 @@ var SlotError = class extends Error {
|
|
|
1886
3662
|
super(`Used a buffer that is invalid or already freed: ${slot}`);
|
|
1887
3663
|
}
|
|
1888
3664
|
};
|
|
3665
|
+
var UnsupportedOpError = class extends Error {
|
|
3666
|
+
constructor(op, dtype, device, arg) {
|
|
3667
|
+
let msg = `${op || ""}<${dtype}> not supported in ${device} backend`;
|
|
3668
|
+
if (arg !== void 0) msg += ` with arg ${JSON.stringify(arg)}`;
|
|
3669
|
+
super(msg);
|
|
3670
|
+
}
|
|
3671
|
+
};
|
|
1889
3672
|
|
|
1890
3673
|
//#endregion
|
|
1891
3674
|
Object.defineProperty(exports, 'AluExp', {
|
|
@@ -1966,6 +3749,12 @@ Object.defineProperty(exports, 'SlotError', {
|
|
|
1966
3749
|
return SlotError;
|
|
1967
3750
|
}
|
|
1968
3751
|
});
|
|
3752
|
+
Object.defineProperty(exports, 'UnsupportedOpError', {
|
|
3753
|
+
enumerable: true,
|
|
3754
|
+
get: function () {
|
|
3755
|
+
return UnsupportedOpError;
|
|
3756
|
+
}
|
|
3757
|
+
});
|
|
1969
3758
|
Object.defineProperty(exports, 'accessorAluExp', {
|
|
1970
3759
|
enumerable: true,
|
|
1971
3760
|
get: function () {
|
|
@@ -2008,6 +3797,12 @@ Object.defineProperty(exports, 'dtypedArray', {
|
|
|
2008
3797
|
return dtypedArray;
|
|
2009
3798
|
}
|
|
2010
3799
|
});
|
|
3800
|
+
Object.defineProperty(exports, 'dtypedJsArray', {
|
|
3801
|
+
enumerable: true,
|
|
3802
|
+
get: function () {
|
|
3803
|
+
return dtypedJsArray;
|
|
3804
|
+
}
|
|
3805
|
+
});
|
|
2011
3806
|
Object.defineProperty(exports, 'findPow2', {
|
|
2012
3807
|
enumerable: true,
|
|
2013
3808
|
get: function () {
|
|
@@ -2110,6 +3905,12 @@ Object.defineProperty(exports, 'tuneWebgpu', {
|
|
|
2110
3905
|
return tuneWebgpu;
|
|
2111
3906
|
}
|
|
2112
3907
|
});
|
|
3908
|
+
Object.defineProperty(exports, 'union', {
|
|
3909
|
+
enumerable: true,
|
|
3910
|
+
get: function () {
|
|
3911
|
+
return union;
|
|
3912
|
+
}
|
|
3913
|
+
});
|
|
2113
3914
|
Object.defineProperty(exports, 'unravelAlu', {
|
|
2114
3915
|
enumerable: true,
|
|
2115
3916
|
get: function () {
|
|
@@ -2127,4 +3928,10 @@ Object.defineProperty(exports, 'zip', {
|
|
|
2127
3928
|
get: function () {
|
|
2128
3929
|
return zip;
|
|
2129
3930
|
}
|
|
3931
|
+
});
|
|
3932
|
+
Object.defineProperty(exports, 'zipn', {
|
|
3933
|
+
enumerable: true,
|
|
3934
|
+
get: function () {
|
|
3935
|
+
return zipn;
|
|
3936
|
+
}
|
|
2130
3937
|
});
|