@jax-js/jax 0.0.2 → 0.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +9 -8
- package/dist/{backend-1eVbAoaV.js → backend-BqDtPGaR.js} +1869 -86
- package/dist/{backend-BK21PBVP.cjs → backend-D2C4MJRP.cjs} +1892 -85
- package/dist/index.cjs +737 -118
- package/dist/index.d.cts +247 -44
- package/dist/index.d.ts +247 -44
- package/dist/index.js +726 -114
- package/dist/{webgpu-JVpVad6g.js → webgpu-CNg9JGva.js} +54 -33
- package/dist/{webgpu-c5Fe8nx8.cjs → webgpu-fqhx41TC.cjs} +54 -33
- package/package.json +7 -6
|
@@ -48,6 +48,10 @@ function unzip2(pairs) {
|
|
|
48
48
|
function zip(xs, ys) {
|
|
49
49
|
return xs.map((x, i) => [x, ys[i]]);
|
|
50
50
|
}
|
|
51
|
+
function zipn(...arrays) {
|
|
52
|
+
const minLength = Math.min(...arrays.map((x) => x.length));
|
|
53
|
+
return Array.from({ length: minLength }, (_, i) => arrays.map((arr) => arr[i]));
|
|
54
|
+
}
|
|
51
55
|
function rep(length, value) {
|
|
52
56
|
if (value instanceof Function) return new Array(length).fill(0).map((_, i) => value(i));
|
|
53
57
|
return new Array(length).fill(value);
|
|
@@ -55,6 +59,11 @@ function rep(length, value) {
|
|
|
55
59
|
function prod(arr) {
|
|
56
60
|
return arr.reduce((acc, x) => acc * x, 1);
|
|
57
61
|
}
|
|
62
|
+
function gcd(...values) {
|
|
63
|
+
let a = 0;
|
|
64
|
+
for (let b of values) while (b !== 0) [a, b] = [b, a % b];
|
|
65
|
+
return Math.abs(a);
|
|
66
|
+
}
|
|
58
67
|
/** Shorthand for integer division, like in Python. */
|
|
59
68
|
function intdiv(a, b) {
|
|
60
69
|
return Math.floor(a / b);
|
|
@@ -72,6 +81,11 @@ function deepEqual(a, b) {
|
|
|
72
81
|
for (const key of Object.keys(a)) if (!deepEqual(a[key], b[key])) return false;
|
|
73
82
|
return true;
|
|
74
83
|
}
|
|
84
|
+
function union(...sets) {
|
|
85
|
+
const result = /* @__PURE__ */ new Set();
|
|
86
|
+
for (const s of sets) if (s) for (const x of s) result.add(x);
|
|
87
|
+
return result;
|
|
88
|
+
}
|
|
75
89
|
/** Splits the list based on a condition, `false` first then `true`. */
|
|
76
90
|
function partitionList(which, array) {
|
|
77
91
|
const falseList = [];
|
|
@@ -216,12 +230,13 @@ function runWithCache(cache, key, thunk) {
|
|
|
216
230
|
|
|
217
231
|
//#endregion
|
|
218
232
|
//#region src/alu.ts
|
|
233
|
+
/** A numerical data type for array contents. */
|
|
219
234
|
let DType = /* @__PURE__ */ function(DType$1) {
|
|
220
235
|
DType$1["Float32"] = "float32";
|
|
221
236
|
DType$1["Int32"] = "int32";
|
|
222
237
|
DType$1["Uint32"] = "uint32";
|
|
223
238
|
DType$1["Bool"] = "bool";
|
|
224
|
-
DType$1["
|
|
239
|
+
DType$1["Float16"] = "float16";
|
|
225
240
|
return DType$1;
|
|
226
241
|
}({});
|
|
227
242
|
const byteWidth = (dtype) => {
|
|
@@ -230,17 +245,30 @@ const byteWidth = (dtype) => {
|
|
|
230
245
|
case DType.Int32:
|
|
231
246
|
case DType.Uint32:
|
|
232
247
|
case DType.Bool: return 4;
|
|
233
|
-
case DType.
|
|
248
|
+
case DType.Float16: return 2;
|
|
234
249
|
default: throw new TypeError(`Unknown dtype: ${dtype}`);
|
|
235
250
|
}
|
|
236
251
|
};
|
|
237
|
-
const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.
|
|
252
|
+
const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16;
|
|
238
253
|
function dtypedArray(dtype, data) {
|
|
254
|
+
const { buffer, byteLength, byteOffset } = data;
|
|
255
|
+
const length = byteLength / byteWidth(dtype);
|
|
256
|
+
switch (dtype) {
|
|
257
|
+
case DType.Float32: return new Float32Array(buffer, byteOffset, length);
|
|
258
|
+
case DType.Int32:
|
|
259
|
+
case DType.Bool: return new Int32Array(buffer, byteOffset, length);
|
|
260
|
+
case DType.Uint32: return new Uint32Array(buffer, byteOffset, length);
|
|
261
|
+
case DType.Float16: return new Float16Array(buffer, byteOffset, length);
|
|
262
|
+
default: throw new Error(`Unimplemented dtype: ${dtype}`);
|
|
263
|
+
}
|
|
264
|
+
}
|
|
265
|
+
function dtypedJsArray(dtype, data) {
|
|
239
266
|
switch (dtype) {
|
|
240
267
|
case DType.Float32: return new Float32Array(data);
|
|
241
|
-
case DType.Int32:
|
|
242
|
-
case DType.Uint32: return new Uint32Array(data);
|
|
268
|
+
case DType.Int32:
|
|
243
269
|
case DType.Bool: return new Int32Array(data);
|
|
270
|
+
case DType.Uint32: return new Uint32Array(data);
|
|
271
|
+
case DType.Float16: return new Float16Array(data);
|
|
244
272
|
default: throw new Error(`Unimplemented dtype: ${dtype}`);
|
|
245
273
|
}
|
|
246
274
|
}
|
|
@@ -297,6 +325,9 @@ var AluExp = class AluExp {
|
|
|
297
325
|
static log(a) {
|
|
298
326
|
return new AluExp(AluOp.Log, a.dtype, [a]);
|
|
299
327
|
}
|
|
328
|
+
static sqrt(a) {
|
|
329
|
+
return new AluExp(AluOp.Sqrt, a.dtype, [a]);
|
|
330
|
+
}
|
|
300
331
|
static reciprocal(a) {
|
|
301
332
|
return new AluExp(AluOp.Reciprocal, a.dtype, [a]);
|
|
302
333
|
}
|
|
@@ -342,24 +373,27 @@ var AluExp = class AluExp {
|
|
|
342
373
|
static variable(dtype, name) {
|
|
343
374
|
return new AluExp(AluOp.Variable, dtype, [], name);
|
|
344
375
|
}
|
|
345
|
-
static globalIndex(dtype, gid, bufidx) {
|
|
346
|
-
return new AluExp(AluOp.GlobalIndex, dtype, [bufidx], gid);
|
|
376
|
+
static globalIndex(dtype, gid, len, bufidx) {
|
|
377
|
+
return new AluExp(AluOp.GlobalIndex, dtype, [bufidx], [gid, len]);
|
|
347
378
|
}
|
|
348
379
|
static globalView(dtype, gid, st, indices) {
|
|
349
380
|
return new AluExp(AluOp.GlobalView, dtype, indices, [gid, st]);
|
|
350
381
|
}
|
|
382
|
+
static f32(value) {
|
|
383
|
+
return AluExp.const(DType.Float32, value);
|
|
384
|
+
}
|
|
351
385
|
static i32(value) {
|
|
352
386
|
return AluExp.const(DType.Int32, value);
|
|
353
387
|
}
|
|
354
388
|
static u32(value) {
|
|
355
389
|
return AluExp.const(DType.Uint32, value);
|
|
356
390
|
}
|
|
357
|
-
static f32(value) {
|
|
358
|
-
return AluExp.const(DType.Float32, value);
|
|
359
|
-
}
|
|
360
391
|
static bool(value) {
|
|
361
392
|
return AluExp.const(DType.Bool, Number(value));
|
|
362
393
|
}
|
|
394
|
+
static f16(value) {
|
|
395
|
+
return AluExp.const(DType.Float16, value);
|
|
396
|
+
}
|
|
363
397
|
not() {
|
|
364
398
|
if (this.dtype !== DType.Bool) throw new Error("not() can only be called on boolean expressions");
|
|
365
399
|
return AluExp.cmpne(this, AluExp.const(DType.Bool, true));
|
|
@@ -389,9 +423,9 @@ var AluExp = class AluExp {
|
|
|
389
423
|
reindexGids(gidMap) {
|
|
390
424
|
return this.rewrite((exp) => {
|
|
391
425
|
if (exp.op === AluOp.GlobalIndex) {
|
|
392
|
-
const gid = exp.arg;
|
|
426
|
+
const [gid, len] = exp.arg;
|
|
393
427
|
const newGid = gidMap.get(gid);
|
|
394
|
-
if (newGid !== void 0 && newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, exp.src[0]);
|
|
428
|
+
if (newGid !== void 0 && newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
|
|
395
429
|
} else if (exp.op === AluOp.GlobalView) {
|
|
396
430
|
const gid = exp.arg[0];
|
|
397
431
|
const newGid = gidMap.get(gid);
|
|
@@ -420,17 +454,16 @@ var AluExp = class AluExp {
|
|
|
420
454
|
case AluOp.Sub:
|
|
421
455
|
ret = [src[0].min - src[1].max, src[0].max - src[1].min];
|
|
422
456
|
break;
|
|
423
|
-
case AluOp.Mul:
|
|
457
|
+
case AluOp.Mul:
|
|
424
458
|
ret = minMax4((a, b) => a * b);
|
|
425
459
|
break;
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
ret = minMax4((a, b) => Math.floor(a / b));
|
|
460
|
+
case AluOp.Idiv:
|
|
461
|
+
ret = minMax4((a, b) => Math.trunc(a / b));
|
|
429
462
|
break;
|
|
430
|
-
}
|
|
431
463
|
case AluOp.Mod: {
|
|
432
464
|
let divisorRange = src[1].#computeRange();
|
|
433
465
|
if (divisorRange[0] <= 0 && divisorRange[1] >= 0) divisorRange = [0, Math.max(-divisorRange[0], divisorRange[1])];
|
|
466
|
+
if (divisorRange[1] < 0) divisorRange = [-divisorRange[1], -divisorRange[0]];
|
|
434
467
|
const maxDivisor = isFloatDtype(this.dtype) ? divisorRange[1] : divisorRange[1] - 1;
|
|
435
468
|
ret = [clamp(src[0].min, -maxDivisor, 0), clamp(src[0].max, 0, maxDivisor)];
|
|
436
469
|
break;
|
|
@@ -453,23 +486,31 @@ var AluExp = class AluExp {
|
|
|
453
486
|
case AluOp.Log:
|
|
454
487
|
ret = [Math.log(src[0].min), Math.log(src[0].max)];
|
|
455
488
|
break;
|
|
489
|
+
case AluOp.Sqrt:
|
|
490
|
+
ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
|
|
491
|
+
break;
|
|
456
492
|
case AluOp.Reciprocal:
|
|
457
493
|
if (src[0].min <= 0 && src[0].max >= 0) return [-Infinity, Infinity];
|
|
458
494
|
ret = [1 / src[0].max, 1 / src[0].min];
|
|
459
495
|
break;
|
|
460
|
-
case AluOp.Cast:
|
|
496
|
+
case AluOp.Cast: {
|
|
497
|
+
const wasFloat = isFloatDtype(src[0].dtype);
|
|
498
|
+
const bounded = Number.isFinite(src[0].min) && Number.isFinite(src[0].max);
|
|
461
499
|
if (this.dtype === DType.Bool) {
|
|
462
500
|
const canBeZero = src[0].min <= 0 && src[0].max >= 0;
|
|
463
501
|
const mustBeZero = src[0].min === 0 && src[0].max === 0;
|
|
464
502
|
ret = mustBeZero ? [0, 0] : canBeZero ? [0, 1] : [1, 1];
|
|
465
|
-
} else if (this.dtype === DType.Int32)
|
|
466
|
-
|
|
467
|
-
const
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
503
|
+
} else if (this.dtype === DType.Int32) {
|
|
504
|
+
const a = wasFloat ? clamp(src[0].min, -2147483648, 2147483647) | 0 : src[0].min | 0;
|
|
505
|
+
const b = wasFloat ? clamp(src[0].max, -2147483648, 2147483647) | 0 : src[0].max | 0;
|
|
506
|
+
ret = bounded && a <= b ? [a, b] : [-Infinity, Infinity];
|
|
507
|
+
} else if (this.dtype === DType.Uint32) {
|
|
508
|
+
const a = wasFloat ? clamp(src[0].min, 0, 4294967295) >>> 0 : src[0].min >>> 0;
|
|
509
|
+
const b = wasFloat ? clamp(src[0].max, 0, 4294967295) >>> 0 : src[0].max >>> 0;
|
|
510
|
+
ret = bounded && a <= b ? [a, b] : [0, Infinity];
|
|
471
511
|
} else ret = [src[0].min, src[0].max];
|
|
472
512
|
break;
|
|
513
|
+
}
|
|
473
514
|
case AluOp.Cmplt:
|
|
474
515
|
ret = [0, 1];
|
|
475
516
|
break;
|
|
@@ -492,6 +533,7 @@ var AluExp = class AluExp {
|
|
|
492
533
|
ret[0] = clamp(ret[0], 0, 1);
|
|
493
534
|
ret[1] = clamp(ret[1], 0, 1);
|
|
494
535
|
}
|
|
536
|
+
if (this.dtype === DType.Uint32) ret[0] = Math.max(0, ret[0]);
|
|
495
537
|
this.#range = ret;
|
|
496
538
|
return ret;
|
|
497
539
|
}
|
|
@@ -501,10 +543,51 @@ var AluExp = class AluExp {
|
|
|
501
543
|
get max() {
|
|
502
544
|
return this.#computeRange()[1];
|
|
503
545
|
}
|
|
546
|
+
/** Largest known integer that divides self. */
|
|
547
|
+
constFactor() {
|
|
548
|
+
if (this.op === AluOp.Const) return Math.abs(this.arg);
|
|
549
|
+
if (this.op === AluOp.Add) return gcd(this.src[0].constFactor(), this.src[1].constFactor());
|
|
550
|
+
if (this.op === AluOp.Mul) {
|
|
551
|
+
if (this.src[0].op === AluOp.Const) return Math.abs(this.src[0].arg);
|
|
552
|
+
if (this.src[1].op === AluOp.Const) return Math.abs(this.src[1].arg);
|
|
553
|
+
}
|
|
554
|
+
return 1;
|
|
555
|
+
}
|
|
556
|
+
/**
|
|
557
|
+
* Checks if divisible by an integer v and returns the quotient if it is, or
|
|
558
|
+
* `null` if it's not divisible.
|
|
559
|
+
*/
|
|
560
|
+
divides(v) {
|
|
561
|
+
if (v === 1) return this;
|
|
562
|
+
if (this.op === AluOp.Const && this.arg % v === 0) return AluExp.const(this.dtype, this.arg / v);
|
|
563
|
+
if (this.op === AluOp.Add) {
|
|
564
|
+
const a = this.src[0].divides(v);
|
|
565
|
+
if (a !== null) {
|
|
566
|
+
const b = this.src[1].divides(v);
|
|
567
|
+
if (b !== null) return AluExp.add(a, b);
|
|
568
|
+
}
|
|
569
|
+
}
|
|
570
|
+
if (this.op === AluOp.Mul) {
|
|
571
|
+
const a = this.src[0].divides(v);
|
|
572
|
+
if (a !== null) return AluExp.mul(a, this.src[1]);
|
|
573
|
+
const b = this.src[1].divides(v);
|
|
574
|
+
if (b !== null) return AluExp.mul(this.src[0], b);
|
|
575
|
+
}
|
|
576
|
+
return null;
|
|
577
|
+
}
|
|
504
578
|
#isConstInt() {
|
|
505
579
|
return this.op === AluOp.Const && (this.dtype === DType.Int32 || this.dtype === DType.Uint32);
|
|
506
580
|
}
|
|
507
581
|
/**
|
|
582
|
+
* Get all expressions by deeply matching an operation.
|
|
583
|
+
*
|
|
584
|
+
* For example: `((2+(3*5))+4).splitOp(+) -> [2,(3*5),4]`.
|
|
585
|
+
*/
|
|
586
|
+
*splitOp(sep) {
|
|
587
|
+
if (this.op === sep) for (const src of this.src) yield* src.splitOp(sep);
|
|
588
|
+
else yield this;
|
|
589
|
+
}
|
|
590
|
+
/**
|
|
508
591
|
* Simplify the expression by replacing any known patterns and deduping
|
|
509
592
|
* identical subexpressions.
|
|
510
593
|
*/
|
|
@@ -549,7 +632,24 @@ var AluExp = class AluExp {
|
|
|
549
632
|
if (a.op === AluOp.Const && a.arg === -1) return new AluExp(opNeg, this.dtype, [src[0], b]);
|
|
550
633
|
else if (b.op === AluOp.Const && b.arg === -1) return new AluExp(opNeg, this.dtype, [src[0], a]);
|
|
551
634
|
}
|
|
635
|
+
if (op === AluOp.Where && src.slice(1).every((s, i) => s.op === AluOp.Const && s.arg === 1 - i)) return AluExp.cast(this.dtype, src[0]);
|
|
636
|
+
if (op === AluOp.Cmplt) {
|
|
637
|
+
if (src[0].min >= src[1].max) return AluExp.const(DType.Bool, false);
|
|
638
|
+
if (src[0].max < src[1].min) return AluExp.const(DType.Bool, true);
|
|
639
|
+
}
|
|
640
|
+
if (op === AluOp.Cmpne) {
|
|
641
|
+
if (src[0].max < src[1].min || src[0].min > src[1].max) return AluExp.const(DType.Bool, true);
|
|
642
|
+
}
|
|
643
|
+
if (op === AluOp.Where) {
|
|
644
|
+
if (src[0].max === 0) return src[2];
|
|
645
|
+
if (src[0].min === 1) return src[1];
|
|
646
|
+
}
|
|
552
647
|
if (op === AluOp.Mod && src[1].op === AluOp.Const && src[0].min >= 0 && src[0].max < src[1].arg) return src[0];
|
|
648
|
+
if (op === AluOp.Mod && src[0].op === AluOp.Mod && src[1].#isConstInt() && src[0].src[1].#isConstInt()) {
|
|
649
|
+
const A = src[0].src[1].arg;
|
|
650
|
+
const B = src[1].arg;
|
|
651
|
+
if (A > 0 && B > 0 && (A % B === 0 || B % A === 0)) return AluExp.mod(src[0].src[0], AluExp.const(this.dtype, Math.min(A, B))).simplify();
|
|
652
|
+
}
|
|
553
653
|
if (op === AluOp.Add && src[0].op === AluOp.Mul && src[0].src[1].#isConstInt() && src[1].op === AluOp.Mod && src[1].src[1].#isConstInt() && src[0].src[1].arg === src[1].src[1].arg) {
|
|
554
654
|
const [mul, mod] = src;
|
|
555
655
|
const check = (exp) => {
|
|
@@ -569,7 +669,7 @@ var AluExp = class AluExp {
|
|
|
569
669
|
const A = numer.src[i].arg;
|
|
570
670
|
if (A % B === 0) {
|
|
571
671
|
let ret = numer.src[1 - i];
|
|
572
|
-
if (A / B !== 1) ret = AluExp.mul(ret, AluExp.
|
|
672
|
+
if (A / B !== 1) ret = AluExp.mul(ret, AluExp.const(ret.dtype, A / B));
|
|
573
673
|
return ret.simplify(cache);
|
|
574
674
|
}
|
|
575
675
|
}
|
|
@@ -577,8 +677,8 @@ var AluExp = class AluExp {
|
|
|
577
677
|
const A = numer.src[j].src[i].arg;
|
|
578
678
|
if (A % B === 0) {
|
|
579
679
|
let ret = numer.src[j].src[1 - i];
|
|
580
|
-
if (A / B !== 1) ret = AluExp.mul(ret, AluExp.
|
|
581
|
-
ret = AluExp.add(ret, AluExp.idiv(numer.src[1 - j], B));
|
|
680
|
+
if (A / B !== 1) ret = AluExp.mul(ret, AluExp.const(ret.dtype, A / B));
|
|
681
|
+
ret = AluExp.add(ret, AluExp.idiv(numer.src[1 - j], AluExp.const(ret.dtype, B)));
|
|
582
682
|
return ret.simplify(cache);
|
|
583
683
|
}
|
|
584
684
|
}
|
|
@@ -587,23 +687,81 @@ var AluExp = class AluExp {
|
|
|
587
687
|
if (op === AluOp.Mod && src[1].#isConstInt() && src[1].arg > 0 && src[0].min >= 0) {
|
|
588
688
|
const [numer, denom] = src;
|
|
589
689
|
const B = denom.arg;
|
|
590
|
-
for (let i = 0; i < 2; i++) if (numer.op === AluOp.Add
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
690
|
+
for (let i = 0; i < 2; i++) if (numer.op === AluOp.Add) {
|
|
691
|
+
if (numer.src[i].#isConstInt()) {
|
|
692
|
+
const A = numer.src[i].arg;
|
|
693
|
+
const x = numer.src[1 - i];
|
|
694
|
+
if (A % B === 0 && x.min >= 0) return AluExp.mod(x, denom).simplify(cache);
|
|
695
|
+
}
|
|
696
|
+
for (let j = 0; j < 2; j++) if (numer.src[i].op === AluOp.Mul && numer.src[i].src[j].#isConstInt()) {
|
|
697
|
+
const A = numer.src[i].src[j].arg;
|
|
698
|
+
const x = numer.src[1 - i];
|
|
699
|
+
if (A % B === 0 && x.min >= 0) return AluExp.mod(x, denom).simplify(cache);
|
|
700
|
+
}
|
|
701
|
+
} else if (numer.op === AluOp.Mul) {
|
|
702
|
+
if (numer.src[i].#isConstInt()) {
|
|
703
|
+
const A = numer.src[i].arg;
|
|
704
|
+
if (A % B === 0) return AluExp.const(this.dtype, 0);
|
|
705
|
+
if (A % B === 1) return AluExp.mod(numer.src[1 - i], denom).simplify(cache);
|
|
706
|
+
}
|
|
595
707
|
}
|
|
596
708
|
}
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
709
|
+
const commOps = [
|
|
710
|
+
AluOp.Add,
|
|
711
|
+
AluOp.Mul,
|
|
712
|
+
AluOp.Max,
|
|
713
|
+
AluOp.Min
|
|
714
|
+
];
|
|
715
|
+
if (commOps.includes(op)) {
|
|
716
|
+
const p = (a, b) => new AluExp(op, this.dtype, [a, b]);
|
|
717
|
+
if (src[0].op === AluOp.Const) return p(src[1], src[0]).simplify(cache);
|
|
718
|
+
if (src[0].op === op && src[0].src[1].op === AluOp.Const) if (src[1].op === AluOp.Const) return p(src[0].src[0], p(src[0].src[1], src[1])).simplify(cache);
|
|
719
|
+
else return p(p(src[0].src[0], src[1]), src[0].src[1]).simplify(cache);
|
|
720
|
+
if (src[1].op === op && src[1].src[1].op === AluOp.Const) return p(p(src[0], src[1].src[0]), src[1].src[1]).simplify(cache);
|
|
603
721
|
}
|
|
604
|
-
if (op === AluOp.
|
|
605
|
-
|
|
606
|
-
|
|
722
|
+
if (op === AluOp.Mod || op === AluOp.Idiv && src[1].#isConstInt()) {
|
|
723
|
+
const [x, y] = src;
|
|
724
|
+
{
|
|
725
|
+
const factors = [];
|
|
726
|
+
const terms = [];
|
|
727
|
+
for (const u of x.splitOp(AluOp.Add)) {
|
|
728
|
+
const factor = u.constFactor();
|
|
729
|
+
factors.push(factor);
|
|
730
|
+
terms.push(u.divides(factor));
|
|
731
|
+
}
|
|
732
|
+
const g = gcd(y.arg, ...factors);
|
|
733
|
+
if (g !== 1) {
|
|
734
|
+
let ret = new AluExp(op, this.dtype, [factors.map((f, i) => AluExp.mul(AluExp.const(terms[i].dtype, f / g), terms[i])).reduceRight((a, x$1) => AluExp.add(x$1, a)), AluExp.const(y.dtype, y.arg / g)]);
|
|
735
|
+
if (op === AluOp.Mod) ret = AluExp.mul(ret, AluExp.const(this.dtype, g));
|
|
736
|
+
return ret.simplify(cache);
|
|
737
|
+
}
|
|
738
|
+
}
|
|
739
|
+
if (y.arg > 0) {
|
|
740
|
+
let [xNoConst, constVal] = [x, 0];
|
|
741
|
+
if (x.op === AluOp.Add && x.src[1].op === AluOp.Const) [xNoConst, constVal] = [x.src[0], x.src[1].arg];
|
|
742
|
+
const terms = [];
|
|
743
|
+
const factors = [];
|
|
744
|
+
for (const u of xNoConst.splitOp(AluOp.Add)) {
|
|
745
|
+
const f = u.constFactor();
|
|
746
|
+
const divided = u.divides(f);
|
|
747
|
+
terms.push(divided ?? u);
|
|
748
|
+
factors.push(divided ? f : 1);
|
|
749
|
+
}
|
|
750
|
+
const quotients = factors.map((f) => Math.floor(f / y.arg));
|
|
751
|
+
const remainders = factors.map((f) => f % y.arg);
|
|
752
|
+
const gcdVal = remainders.reduce((g, r) => gcd(g, r), y.arg);
|
|
753
|
+
if (constVal % y.arg !== constVal || gcdVal !== 1 || remainders.some((r, i) => r === 0 || r !== factors[i] && op === AluOp.Mod)) {
|
|
754
|
+
let quo = AluExp.const(x.dtype, Math.floor(constVal / y.arg));
|
|
755
|
+
let rem = AluExp.const(x.dtype, Math.floor(constVal % y.arg / gcdVal));
|
|
756
|
+
for (let i = 0; i < terms.length; i++) if (op === AluOp.Idiv && remainders[i] !== 0) rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(factors[i] / gcdVal)), terms[i]));
|
|
757
|
+
else {
|
|
758
|
+
rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(remainders[i] / gcdVal)), terms[i]));
|
|
759
|
+
quo = AluExp.add(quo, AluExp.mul(AluExp.const(x.dtype, quotients[i]), terms[i]));
|
|
760
|
+
}
|
|
761
|
+
if (!((x.min < 0 || rem.min < 0) && remainders.some((r) => r !== 0))) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
|
|
762
|
+
else return AluExp.add(AluExp.idiv(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal))), quo).simplify(cache);
|
|
763
|
+
}
|
|
764
|
+
}
|
|
607
765
|
}
|
|
608
766
|
const newExp = src.every((s, i) => s === this.src[i]) ? this : new AluExp(op, this.dtype, src, this.arg);
|
|
609
767
|
return newExp;
|
|
@@ -646,12 +804,16 @@ var AluExp = class AluExp {
|
|
|
646
804
|
case AluOp.Cos: return Math.cos(x);
|
|
647
805
|
case AluOp.Exp: return Math.exp(x);
|
|
648
806
|
case AluOp.Log: return Math.log(x);
|
|
807
|
+
case AluOp.Sqrt: return Math.sqrt(x);
|
|
649
808
|
case AluOp.Reciprocal: return 1 / x;
|
|
650
|
-
case AluOp.Cast:
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
809
|
+
case AluOp.Cast: {
|
|
810
|
+
const wasFloat = isFloatDtype(this.src[0].dtype);
|
|
811
|
+
if (this.dtype === DType.Int32) return (wasFloat ? clamp(x, -2147483648, 2147483647) : x) | 0;
|
|
812
|
+
else if (this.dtype === DType.Uint32) return (wasFloat ? clamp(x, 0, 4294967295) : x) >>> 0;
|
|
813
|
+
else if (isFloatDtype(this.dtype)) return x;
|
|
814
|
+
else if (this.dtype === DType.Bool) return Number(Boolean(x));
|
|
815
|
+
else throw new Error(`Unsupported cast to ${this.dtype}`);
|
|
816
|
+
}
|
|
655
817
|
case AluOp.Bitcast: {
|
|
656
818
|
const buf = new ArrayBuffer(byteWidth(this.dtype));
|
|
657
819
|
const view = new DataView(buf);
|
|
@@ -659,10 +821,12 @@ var AluExp = class AluExp {
|
|
|
659
821
|
if (fromType === DType.Float32) view.setFloat32(0, x, true);
|
|
660
822
|
else if (fromType === DType.Int32) view.setInt32(0, x, true);
|
|
661
823
|
else if (fromType === DType.Uint32) view.setUint32(0, x, true);
|
|
824
|
+
else if (fromType === DType.Float16) view.setFloat16(0, x, true);
|
|
662
825
|
else throw new Error(`Unsupported bitcast from ${fromType}`);
|
|
663
826
|
if (this.dtype === DType.Float32) return view.getFloat32(0, true);
|
|
664
827
|
else if (this.dtype === DType.Int32) return view.getInt32(0, true);
|
|
665
828
|
else if (this.dtype === DType.Uint32) return view.getUint32(0, true);
|
|
829
|
+
else if (this.dtype === DType.Float16) return view.getFloat16(0, true);
|
|
666
830
|
else throw new Error(`Unsupported bitcast to ${this.dtype}`);
|
|
667
831
|
}
|
|
668
832
|
default: throw new Error(`Missing implemementation for ${this.op}`);
|
|
@@ -691,7 +855,7 @@ var AluExp = class AluExp {
|
|
|
691
855
|
}
|
|
692
856
|
case AluOp.GlobalIndex: {
|
|
693
857
|
if (!globals) throw new Error("Missing globals function");
|
|
694
|
-
const gid = this.arg;
|
|
858
|
+
const gid = this.arg[0];
|
|
695
859
|
const bufidx = this.src[0].evaluate(context, globals);
|
|
696
860
|
return globals(gid, bufidx);
|
|
697
861
|
}
|
|
@@ -721,13 +885,7 @@ var AluExp = class AluExp {
|
|
|
721
885
|
[AluOp.Cmplt]: "<",
|
|
722
886
|
[AluOp.Cmpne]: "!="
|
|
723
887
|
};
|
|
724
|
-
const UNARY_SYM = {
|
|
725
|
-
[AluOp.Sin]: "sin",
|
|
726
|
-
[AluOp.Cos]: "cos",
|
|
727
|
-
[AluOp.Exp]: "exp",
|
|
728
|
-
[AluOp.Log]: "log",
|
|
729
|
-
[AluOp.Reciprocal]: "1/"
|
|
730
|
-
};
|
|
888
|
+
const UNARY_SYM = { [AluOp.Reciprocal]: "1/" };
|
|
731
889
|
return this.fold((node, parts) => {
|
|
732
890
|
switch (node.op) {
|
|
733
891
|
case AluOp.Const: return "" + (node.dtype === DType.Bool ? Boolean(node.arg) : node.arg);
|
|
@@ -736,7 +894,7 @@ var AluExp = class AluExp {
|
|
|
736
894
|
const [name, n] = node.arg;
|
|
737
895
|
return `#${name}{${n}}`;
|
|
738
896
|
}
|
|
739
|
-
case AluOp.GlobalIndex: return `G_${node.arg}<${node.dtype}>[${strip1(parts[0])}]`;
|
|
897
|
+
case AluOp.GlobalIndex: return `G_${node.arg[0]}<${node.dtype}>[${strip1(parts[0])}]`;
|
|
740
898
|
case AluOp.GlobalView: {
|
|
741
899
|
const [gid, st] = node.arg;
|
|
742
900
|
const shape = st.shape.join(",");
|
|
@@ -765,6 +923,17 @@ var AluExp = class AluExp {
|
|
|
765
923
|
};
|
|
766
924
|
return recurse(this);
|
|
767
925
|
}
|
|
926
|
+
/** Check if any expression in the tree satisfies a predicate. */
|
|
927
|
+
some(predicate) {
|
|
928
|
+
const visited = /* @__PURE__ */ new Set();
|
|
929
|
+
const recurse = (exp) => {
|
|
930
|
+
if (visited.has(exp)) return false;
|
|
931
|
+
if (predicate(exp)) return true;
|
|
932
|
+
visited.add(exp);
|
|
933
|
+
return exp.src.some(recurse);
|
|
934
|
+
};
|
|
935
|
+
return recurse(this);
|
|
936
|
+
}
|
|
768
937
|
/** Rewrite the expression recursively using a visitor. */
|
|
769
938
|
rewrite(visitor) {
|
|
770
939
|
return this.fold((exp, newSrc) => {
|
|
@@ -783,6 +952,23 @@ var AluExp = class AluExp {
|
|
|
783
952
|
});
|
|
784
953
|
return result;
|
|
785
954
|
}
|
|
955
|
+
/** Produce a list of all distinct AluOp in this expression. */
|
|
956
|
+
distinctOps() {
|
|
957
|
+
const ops = /* @__PURE__ */ new Set();
|
|
958
|
+
this.fold((exp) => {
|
|
959
|
+
ops.add(exp.op);
|
|
960
|
+
});
|
|
961
|
+
return ops;
|
|
962
|
+
}
|
|
963
|
+
/** Rewrite GlobalView operations to GlobalIndex operations. */
|
|
964
|
+
rewriteGlobalViews() {
|
|
965
|
+
return this.rewrite((exp) => {
|
|
966
|
+
if (exp.op === AluOp.GlobalView) {
|
|
967
|
+
const [gid, st] = exp.arg;
|
|
968
|
+
return accessorGlobal(exp.dtype, gid, st, exp.src);
|
|
969
|
+
}
|
|
970
|
+
});
|
|
971
|
+
}
|
|
786
972
|
};
|
|
787
973
|
/** Symbolic form for each mathematical operation. */
|
|
788
974
|
let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
@@ -797,6 +983,7 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
797
983
|
AluOp$1["Cos"] = "Cos";
|
|
798
984
|
AluOp$1["Exp"] = "Exp";
|
|
799
985
|
AluOp$1["Log"] = "Log";
|
|
986
|
+
AluOp$1["Sqrt"] = "Sqrt";
|
|
800
987
|
AluOp$1["Reciprocal"] = "Reciprocal";
|
|
801
988
|
AluOp$1["Cast"] = "Cast";
|
|
802
989
|
AluOp$1["Bitcast"] = "Bitcast";
|
|
@@ -826,6 +1013,7 @@ const AluGroup = {
|
|
|
826
1013
|
AluOp.Cos,
|
|
827
1014
|
AluOp.Exp,
|
|
828
1015
|
AluOp.Log,
|
|
1016
|
+
AluOp.Sqrt,
|
|
829
1017
|
AluOp.Reciprocal,
|
|
830
1018
|
AluOp.Cast,
|
|
831
1019
|
AluOp.Bitcast
|
|
@@ -848,6 +1036,7 @@ const AluGroup = {
|
|
|
848
1036
|
AluOp.Cos,
|
|
849
1037
|
AluOp.Exp,
|
|
850
1038
|
AluOp.Log,
|
|
1039
|
+
AluOp.Sqrt,
|
|
851
1040
|
AluOp.Reciprocal
|
|
852
1041
|
])
|
|
853
1042
|
};
|
|
@@ -889,7 +1078,7 @@ var Kernel = class {
|
|
|
889
1078
|
}
|
|
890
1079
|
/** The dtype of the values output by this kernel. */
|
|
891
1080
|
get dtype() {
|
|
892
|
-
if (this.reduction) return this.reduction.
|
|
1081
|
+
if (this.reduction) return this.reduction.epilogue.dtype;
|
|
893
1082
|
else return this.exp.dtype;
|
|
894
1083
|
}
|
|
895
1084
|
/** The number of bytes in the output array when evaluating this kernel. */
|
|
@@ -913,22 +1102,23 @@ var Kernel = class {
|
|
|
913
1102
|
* at this level since they depend on GPU, versus CPU or Wasm.
|
|
914
1103
|
*/
|
|
915
1104
|
var Reduction = class {
|
|
916
|
-
constructor(dtype, op, size,
|
|
1105
|
+
constructor(dtype, op, size, epilogue = AluVar.acc(dtype)) {
|
|
917
1106
|
this.dtype = dtype;
|
|
918
1107
|
this.op = op;
|
|
919
1108
|
this.size = size;
|
|
920
|
-
this.
|
|
1109
|
+
this.epilogue = epilogue;
|
|
921
1110
|
if (!AluGroup.Reduce.has(op)) throw new TypeError(`Unsupported reduction: ${op}`);
|
|
1111
|
+
this.epilogue = epilogue.simplify();
|
|
922
1112
|
}
|
|
923
1113
|
hash(state) {
|
|
924
|
-
state.update(this.dtype, this.op, this.size, this.
|
|
1114
|
+
state.update(this.dtype, this.op, this.size, this.epilogue);
|
|
925
1115
|
}
|
|
926
1116
|
toString() {
|
|
927
|
-
return `${this.op}{${this.size}} -> ${this.
|
|
1117
|
+
return `${this.op}{${this.size}} -> ${this.epilogue}`;
|
|
928
1118
|
}
|
|
929
1119
|
/** Get the identity for this reduction operation. */
|
|
930
1120
|
get identity() {
|
|
931
|
-
if (this.dtype === DType.Bool) return this.op === AluOp.Add || this.op === AluOp.Max ?
|
|
1121
|
+
if (this.dtype === DType.Bool) return this.op === AluOp.Add || this.op === AluOp.Max ? 0 : 1;
|
|
932
1122
|
else if (this.dtype === DType.Int32) {
|
|
933
1123
|
if (this.op === AluOp.Add) return 0;
|
|
934
1124
|
else if (this.op === AluOp.Mul) return 1;
|
|
@@ -939,7 +1129,7 @@ var Reduction = class {
|
|
|
939
1129
|
else if (this.op === AluOp.Mul) return 1;
|
|
940
1130
|
else if (this.op === AluOp.Min) return -1 >>> 0;
|
|
941
1131
|
else if (this.op === AluOp.Max) return 0;
|
|
942
|
-
} else if (this.dtype
|
|
1132
|
+
} else if (isFloatDtype(this.dtype)) {
|
|
943
1133
|
if (this.op === AluOp.Add) return 0;
|
|
944
1134
|
else if (this.op === AluOp.Mul) return 1;
|
|
945
1135
|
else if (this.op === AluOp.Min) return Infinity;
|
|
@@ -962,7 +1152,7 @@ var Reduction = class {
|
|
|
962
1152
|
else if (this.op === AluOp.Mul) return values.reduce((a, b) => a * b >>> 0, 1);
|
|
963
1153
|
else if (this.op === AluOp.Min) return values.reduce((a, b) => Math.min(a, b), -1 >>> 0);
|
|
964
1154
|
else if (this.op === AluOp.Max) return values.reduce((a, b) => Math.max(a, b), 0);
|
|
965
|
-
} else if (this.dtype
|
|
1155
|
+
} else if (isFloatDtype(this.dtype)) {
|
|
966
1156
|
if (this.op === AluOp.Add) return values.reduce((a, b) => a + b, 0);
|
|
967
1157
|
else if (this.op === AluOp.Mul) return values.reduce((a, b) => a * b, 1);
|
|
968
1158
|
else if (this.op === AluOp.Min) return values.reduce((a, b) => Math.min(a, b), Infinity);
|
|
@@ -974,12 +1164,13 @@ var Reduction = class {
|
|
|
974
1164
|
/** Expression for accessing `indices` in input array with the given shape. */
|
|
975
1165
|
function accessorGlobal(dtype, gid, st, indices) {
|
|
976
1166
|
const [index, valid] = st.toAluExp(indices);
|
|
977
|
-
|
|
1167
|
+
const [, len] = st.views[0].dataRange();
|
|
1168
|
+
return AluExp.where(valid, AluExp.globalIndex(dtype, gid, len, index), AluExp.const(dtype, 0));
|
|
978
1169
|
}
|
|
979
1170
|
/** Expression for accessing `indices` in an array recipe with variable "idx". */
|
|
980
|
-
function accessorAluExp(
|
|
1171
|
+
function accessorAluExp(exp, st, indices) {
|
|
981
1172
|
const [index, valid] = st.toAluExp(indices);
|
|
982
|
-
return AluExp.where(valid, exp.substitute({ idx: index }), AluExp.const(dtype, 0));
|
|
1173
|
+
return AluExp.where(valid, exp.substitute({ idx: index }), AluExp.const(exp.dtype, 0));
|
|
983
1174
|
}
|
|
984
1175
|
function threefry2x32(k0, k1, c0, c1) {
|
|
985
1176
|
const rotl32 = (x, r) => (x << r | x >>> 32 - r) >>> 0;
|
|
@@ -1163,6 +1354,25 @@ var View = class View {
|
|
|
1163
1354
|
if (this.#contiguous === void 0) this.#contiguous = this.size === 0 || this.offset === 0 && this.mask === null && deepEqual(this.strides, defaultStrides(this.shape));
|
|
1164
1355
|
return this.#contiguous;
|
|
1165
1356
|
}
|
|
1357
|
+
/** Return the range of data being indexed in this view, or [0, 0] if none. */
|
|
1358
|
+
dataRange() {
|
|
1359
|
+
if (this.size === 0 || this.mask && this.mask[0][0] === this.mask[0][1]) return [0, 0];
|
|
1360
|
+
let min = this.offset;
|
|
1361
|
+
let max = this.offset;
|
|
1362
|
+
for (let i = 0; i < this.ndim; i++) {
|
|
1363
|
+
let [lo, hi] = this.mask ? this.mask[i] : [0, this.shape[i]];
|
|
1364
|
+
--hi;
|
|
1365
|
+
const s = this.strides[i];
|
|
1366
|
+
if (s > 0) {
|
|
1367
|
+
min += s * lo;
|
|
1368
|
+
max += s * hi;
|
|
1369
|
+
} else if (s < 0) {
|
|
1370
|
+
min += s * hi;
|
|
1371
|
+
max += s * lo;
|
|
1372
|
+
}
|
|
1373
|
+
}
|
|
1374
|
+
return [min, max + 1];
|
|
1375
|
+
}
|
|
1166
1376
|
/** Produce an AluExp for evaluating this view at an index. */
|
|
1167
1377
|
toAluExp(idxs) {
|
|
1168
1378
|
let iexpr = AluExp.i32(this.offset);
|
|
@@ -1477,6 +1687,39 @@ var ShapeTracker = class ShapeTracker {
|
|
|
1477
1687
|
}
|
|
1478
1688
|
return st.expand(newShape);
|
|
1479
1689
|
}
|
|
1690
|
+
/**
|
|
1691
|
+
* Repeat data in each axis by a positive number of repetitions.
|
|
1692
|
+
*
|
|
1693
|
+
* - If `tile` is true (default): [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
|
|
1694
|
+
* - If `tile` is false: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
|
|
1695
|
+
*/
|
|
1696
|
+
repeat(reps, tile = true) {
|
|
1697
|
+
if (reps.length > this.shape.length) throw new Error(`Too many repeats ${jstr(reps)} for shape ${jstr(this.shape)}`);
|
|
1698
|
+
if (reps.some((c) => c <= 0)) throw new Error(`Invalid repeats ${jstr(reps)}`);
|
|
1699
|
+
if (reps.length === 0) return this;
|
|
1700
|
+
const noop = this.shape.slice(0, -reps.length);
|
|
1701
|
+
const shape = this.shape.slice(-reps.length);
|
|
1702
|
+
return this.broadcast([...noop, ...shape.flatMap((s, i) => tile ? [reps[i], s] : [s, reps[i]])], shape.map((_, i) => noop.length + 2 * i + (tile ? 0 : 1))).reshape([...noop, ...shape.map((s, i) => s * reps[i])]);
|
|
1703
|
+
}
|
|
1704
|
+
/** Move axis i to axis j. */
|
|
1705
|
+
moveaxis(i, j) {
|
|
1706
|
+
const perm = range(this.shape.length);
|
|
1707
|
+
perm.splice(i, 1);
|
|
1708
|
+
perm.splice(j, 0, i);
|
|
1709
|
+
return this.permute(perm);
|
|
1710
|
+
}
|
|
1711
|
+
/** Like pad(), but allows for negative values. */
|
|
1712
|
+
padOrShrink(arg) {
|
|
1713
|
+
const padArg = [];
|
|
1714
|
+
const shrinkArg = [];
|
|
1715
|
+
for (let i = 0; i < arg.length; i++) {
|
|
1716
|
+
const [b, e] = arg[i];
|
|
1717
|
+
if (b < -this.shape[i] || e < -this.shape[i] || b + e < -this.shape[i]) throw new Error(`Invalid padOrShrink ${jstr(arg)} for ${jstr(this.shape)}`);
|
|
1718
|
+
padArg.push([Math.max(0, b), Math.max(0, e)]);
|
|
1719
|
+
shrinkArg.push([Math.max(0, -b), this.shape[i] - Math.max(0, -e)]);
|
|
1720
|
+
}
|
|
1721
|
+
return this.shrink(shrinkArg).pad(padArg);
|
|
1722
|
+
}
|
|
1480
1723
|
};
|
|
1481
1724
|
function applyLast(ar, f) {
|
|
1482
1725
|
return ar.toSpliced(ar.length - 1, 1, f(ar[ar.length - 1]));
|
|
@@ -1597,13 +1840,7 @@ function tuneNullopt(kernel) {
|
|
|
1597
1840
|
vars.gidx = AluExp.special(DType.Int32, "gidx", kernel.size);
|
|
1598
1841
|
if (kernel.reduction) vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
|
|
1599
1842
|
return {
|
|
1600
|
-
exp: kernel.exp.
|
|
1601
|
-
if (exp.op === AluOp.GlobalView) {
|
|
1602
|
-
const gid = exp.arg[0];
|
|
1603
|
-
const st = exp.arg[1];
|
|
1604
|
-
return accessorGlobal(exp.dtype, gid, st, exp.src);
|
|
1605
|
-
}
|
|
1606
|
-
}).substitute(vars).simplify(),
|
|
1843
|
+
exp: kernel.exp.substitute(vars).rewriteGlobalViews().simplify(),
|
|
1607
1844
|
outputIdxExp: AluExp.special(DType.Int32, "gidx", kernel.size),
|
|
1608
1845
|
threadCount: kernel.size,
|
|
1609
1846
|
size: { reduce: kernel.reduction ? kernel.reduction.size : 0 }
|
|
@@ -1698,13 +1935,19 @@ function tuneWebgpu(kernel) {
|
|
|
1698
1935
|
const s = dim.st.shape.slice(dim.upcast);
|
|
1699
1936
|
addIndices(s, AluVar.upcast);
|
|
1700
1937
|
}
|
|
1701
|
-
|
|
1938
|
+
let newExp = exp.rewrite((exp$1) => {
|
|
1702
1939
|
if (exp$1.op === AluOp.GlobalView) {
|
|
1703
1940
|
const gid = exp$1.arg[0];
|
|
1704
1941
|
const st = exp$1.arg[1];
|
|
1705
1942
|
return accessorGlobal(exp$1.dtype, gid, st.compose(dim.st), indices);
|
|
1706
1943
|
}
|
|
1707
1944
|
});
|
|
1945
|
+
const [iexpr, vexpr] = dim.st.toAluExp(indices);
|
|
1946
|
+
if (vexpr.min !== 1) throw new Error("Invariant violation: vexpr !== true");
|
|
1947
|
+
newExp = newExp.substitute({
|
|
1948
|
+
gidx: AluExp.idiv(iexpr, AluExp.i32(reduction.size)).simplify(),
|
|
1949
|
+
ridx: AluExp.mod(iexpr, AluExp.i32(reduction.size)).simplify()
|
|
1950
|
+
});
|
|
1708
1951
|
const outputGidx = dim.outputSt.shape.slice(0, dim.groups);
|
|
1709
1952
|
const outputUpcast = dim.outputSt.shape.slice(dim.groups);
|
|
1710
1953
|
const [outputIdxExp, _] = dim.outputSt.toAluExp([...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)]);
|
|
@@ -1726,7 +1969,7 @@ function tuneWebgpu(kernel) {
|
|
|
1726
1969
|
//#endregion
|
|
1727
1970
|
//#region src/backend/cpu.ts
|
|
1728
1971
|
/** Most basic implementation of `Backend` for testing. */
|
|
1729
|
-
var
|
|
1972
|
+
var CpuBackend = class {
|
|
1730
1973
|
type = "cpu";
|
|
1731
1974
|
maxArgs = Infinity;
|
|
1732
1975
|
#buffers;
|
|
@@ -1736,10 +1979,10 @@ var CPUBackend = class {
|
|
|
1736
1979
|
this.#nextSlot = 1;
|
|
1737
1980
|
}
|
|
1738
1981
|
malloc(size, initialData) {
|
|
1739
|
-
const buffer = new
|
|
1982
|
+
const buffer = new Uint8Array(size);
|
|
1740
1983
|
if (initialData) {
|
|
1741
1984
|
if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
|
|
1742
|
-
|
|
1985
|
+
buffer.set(initialData);
|
|
1743
1986
|
}
|
|
1744
1987
|
const slot = this.#nextSlot++;
|
|
1745
1988
|
this.#buffers.set(slot, {
|
|
@@ -1778,7 +2021,7 @@ var CPUBackend = class {
|
|
|
1778
2021
|
const { exp } = tuneNullopt(kernel);
|
|
1779
2022
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
|
|
1780
2023
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
|
|
1781
|
-
const usedArgs = new Map(exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex).map((exp$1) => [exp$1.arg, exp$1.dtype]));
|
|
2024
|
+
const usedArgs = new Map(exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex).map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
|
|
1782
2025
|
const inputArrays = inputBuffers.map((buf, i) => {
|
|
1783
2026
|
const dtype = usedArgs.get(i);
|
|
1784
2027
|
if (!dtype) return null;
|
|
@@ -1800,7 +2043,7 @@ var CPUBackend = class {
|
|
|
1800
2043
|
}, globals);
|
|
1801
2044
|
acc = kernel.reduction.evaluate(acc, item);
|
|
1802
2045
|
}
|
|
1803
|
-
outputArray[i] = kernel.reduction.
|
|
2046
|
+
outputArray[i] = kernel.reduction.epilogue.evaluate({ acc });
|
|
1804
2047
|
}
|
|
1805
2048
|
}
|
|
1806
2049
|
#getBuffer(slot) {
|
|
@@ -1810,12 +2053,1540 @@ var CPUBackend = class {
|
|
|
1810
2053
|
}
|
|
1811
2054
|
};
|
|
1812
2055
|
|
|
2056
|
+
//#endregion
|
|
2057
|
+
//#region src/backend/wasm/allocator.ts
|
|
2058
|
+
/**
 * Simple tensor memory allocator for WebAssembly linear memory.
 *
 * Requests are rounded up to a size class. Freed pointers go on a per-class
 * free list and are reused by later requests of the same class; otherwise
 * space is bump-allocated at the heap end, growing the memory as needed.
 * The first 64 bytes are never handed out, so pointer 0 can stand for the
 * empty allocation.
 */
var WasmAllocator = class {
  #memory;
  #headPtr;
  #freeLists;
  #allocatedBuffers;
  constructor(memory) {
    this.#memory = memory;
    this.#headPtr = 64;
    this.#freeLists = /* @__PURE__ */ new Map();
    this.#allocatedBuffers = /* @__PURE__ */ new Map();
  }
  /** Allocate at least `size` bytes; returns a 64-byte-aligned pointer, or 0 when `size` is 0. */
  malloc(size) {
    if (size === 0) return 0;
    const sizeClass = this.#findSizeClass(size);
    const freeList = this.#freeLists.get(sizeClass);
    const ptr = freeList && freeList.length > 0 ? freeList.pop() : this.#bumpAlloc(sizeClass);
    this.#allocatedBuffers.set(ptr, sizeClass);
    return ptr;
  }
  /**
   * Release a pointer returned by `malloc` so its size class can reuse it.
   * Freeing 0 is a no-op; freeing a pointer that is not allocated throws.
   */
  free(ptr) {
    if (ptr === 0) return;
    const sizeClass = this.#allocatedBuffers.get(ptr);
    if (sizeClass === void 0) throw new Error(`Attempting to free unallocated pointer: ${ptr}`);
    const freeList = this.#freeLists.get(sizeClass);
    if (freeList) freeList.push(ptr);
    else this.#freeLists.set(sizeClass, [ptr]);
    this.#allocatedBuffers.delete(ptr);
  }
  /**
   * Carve `size` bytes (rounded up to 64) off the heap end, growing memory first.
   *
   * Uses floating-point arithmetic instead of 32-bit bitwise operators so that
   * offsets and page counts stay correct above 2 GiB. The heap pointer only
   * advances after `memory.grow` succeeds, so a failed grow (which throws)
   * does not leak address space.
   */
  #bumpAlloc(size) {
    const ptr = this.#headPtr;
    const end = ptr + this.#alignUp(size, 64);
    const byteLength = this.#memory.buffer.byteLength;
    if (end > byteLength) {
      // Wasm pages are 64 KiB.
      this.#memory.grow(Math.ceil(end / 65536) - Math.floor(byteLength / 65536));
    }
    this.#headPtr = end;
    return ptr;
  }
  /**
   * Round `size` up to its size class: 64-byte multiples up to 512 B, 512-byte
   * multiples up to 2 KiB, powers of two from 4 KiB to 64 KiB, then 64 KiB multiples.
   */
  #findSizeClass(size) {
    if (size <= 512) return this.#alignUp(size, 64);
    if (size <= 2048) return this.#alignUp(size, 512);
    if (size <= 65536) {
      let sizeClass = 4096;
      while (sizeClass < size) sizeClass *= 2;
      return sizeClass;
    }
    return this.#alignUp(size, 65536);
  }
  /** Round `n` up to a multiple of `alignment` without 32-bit truncation. */
  #alignUp(n, alignment) {
    return Math.ceil(n / alignment) * alignment;
  }
  /** Report the heap high-water mark and the number of free blocks in each size class. */
  getStats() {
    const freeListSizes = /* @__PURE__ */ new Map();
    for (const [sizeClass, freeList] of this.#freeLists) if (freeList.length > 0) freeListSizes.set(sizeClass, freeList.length);
    return {
      totalAllocated: this.#headPtr,
      freeListSizes
    };
  }
};
|
|
2115
|
+
|
|
2116
|
+
//#endregion
|
|
2117
|
+
//#region src/backend/wasm/builtins.ts
|
|
2118
|
+
/**
 * Approximate e^x.
 *
 * Method: range-reduce x = k*ln2 + r with k = round(x/ln2), |r|<=~0.3466
 * then e^x = 2^k * P(r), where P is 5th-order poly (Taylor).
 *
 * Special cases:
 * - k > 128 returns +Infinity.
 * - k < -126 returns 0 (results below the smallest normal float are flushed).
 * - NaN propagates.
 * - For k == 128 the scale 2^128 has no finite f32 encoding, so one factor of 2
 *   is moved into P(r). The final multiply then overflows to Infinity only when
 *   the true result exceeds the float range (previously x in ~[88.38, 88.72]
 *   wrongly returned Infinity).
 */
function wasm_exp(cg) {
  return cg.function([cg.f32], [cg.f32], () => {
    const k_f = cg.local.declare(cg.f32);
    const k = cg.local.declare(cg.i32);
    const r = cg.local.declare(cg.f32);
    const p = cg.local.declare(cg.f32);
    const scale = cg.local.declare(cg.f32);
    // k_f = round(x / ln2); k = k_f as i32, saturating (NaN -> 0, ±Inf -> INT_MAX/INT_MIN).
    cg.local.get(0);
    cg.f32.const(1 / Math.LN2);
    cg.f32.mul();
    cg.f32.nearest();
    cg.local.tee(k_f);
    cg.i32.trunc_sat_f32_s();
    cg.local.set(k);
    // k >= 129 implies e^x >= 2^128.5, which exceeds the largest f32.
    cg.local.get(k);
    cg.i32.const(128);
    cg.i32.gt_s();
    cg.if(cg.void);
    cg.f32.const(Infinity);
    cg.return();
    cg.end();
    // Below the normal range: flush to zero.
    cg.local.get(k);
    cg.i32.const(-126);
    cg.i32.lt_s();
    cg.if(cg.void);
    cg.f32.const(0);
    cg.return();
    cg.end();
    // r = x - k*ln2
    cg.local.get(0);
    cg.local.get(k_f);
    cg.f32.const(Math.LN2);
    cg.f32.mul();
    cg.f32.sub();
    cg.local.set(r);
    // p = 1 + r*(1 + r*(1/2 + r*(1/6 + r*(1/24 + r/120)))) (Horner form)
    cg.f32.const(1 / 120);
    cg.local.get(r);
    cg.f32.mul();
    cg.f32.const(1 / 24);
    cg.f32.add();
    cg.local.get(r);
    cg.f32.mul();
    cg.f32.const(1 / 6);
    cg.f32.add();
    cg.local.get(r);
    cg.f32.mul();
    cg.f32.const(1 / 2);
    cg.f32.add();
    cg.local.get(r);
    cg.f32.mul();
    cg.f32.const(1);
    cg.f32.add();
    cg.local.get(r);
    cg.f32.mul();
    cg.f32.const(1);
    cg.f32.add();
    cg.local.set(p);
    // k == 128: use p*2 and k = 127 so the scale stays a finite normal float.
    cg.local.get(k);
    cg.i32.const(127);
    cg.i32.gt_s();
    cg.if(cg.void);
    cg.local.get(p);
    cg.f32.const(2);
    cg.f32.mul();
    cg.local.set(p);
    cg.local.get(k);
    cg.i32.const(1);
    cg.i32.sub();
    cg.local.set(k);
    cg.end();
    // scale = 2^k, built directly from IEEE-754 bits: biased exponent (k + 127) << 23.
    cg.local.get(k);
    cg.i32.const(127);
    cg.i32.add();
    cg.i32.const(23);
    cg.i32.shl();
    cg.f32.reinterpret_i32();
    cg.local.set(scale);
    // e^x ≈ p * 2^k
    cg.local.get(p);
    cg.local.get(scale);
    cg.f32.mul();
  });
}
|
|
2192
|
+
/**
 * Approximate ln(x).
 *
 * Method: decompose x = m * 2^e with m in [1,2), e integer (via bit ops)
 * ln(x) = e*ln2 + ln(m); use atanh-style series with t=(m-1)/(m+1)
 * ln(m) ≈ 2*(t + t^3/3 + t^5/5 + t^7/7)
 *
 * Special cases (matching Math.log):
 * - ±0 -> -Infinity.
 * - Negative values, -Infinity, and NaNs with the sign bit set -> NaN.
 * - +Infinity and NaN -> returned unchanged.
 * - Subnormals are first normalized by 2^23, so their exponent decodes correctly.
 */
function wasm_log(cg) {
  return cg.function([cg.f32], [cg.f32], () => {
    const bits = cg.local.declare(cg.i32);
    const e = cg.local.declare(cg.i32);
    const m = cg.local.declare(cg.f32);
    const t = cg.local.declare(cg.f32);
    const t2 = cg.local.declare(cg.f32);
    const t3 = cg.local.declare(cg.f32);
    const t5 = cg.local.declare(cg.f32);
    const t7 = cg.local.declare(cg.f32);
    const lnm = cg.local.declare(cg.f32);
    const el2 = cg.local.declare(cg.f32);
    const adj = cg.local.declare(cg.i32);
    cg.local.get(0);
    cg.i32.reinterpret_f32();
    cg.local.set(bits);
    // Magnitude bits all zero: ln(±0) = -Infinity.
    cg.local.get(bits);
    cg.i32.const(2147483647);
    cg.i32.and();
    cg.i32.const(1);
    cg.i32.lt_s();
    cg.if(cg.void);
    cg.f32.const(-Infinity);
    cg.return();
    cg.end();
    // Sign bit set (negative numbers, -Infinity, negative NaN): NaN.
    cg.local.get(bits);
    cg.i32.const(0);
    cg.i32.lt_s();
    cg.if(cg.void);
    cg.f32.const(NaN);
    cg.return();
    cg.end();
    // Positive bit patterns above the largest finite float are +Infinity or NaN,
    // both of which ln maps to themselves.
    cg.local.get(bits);
    cg.i32.const(2139095039);
    cg.i32.gt_s();
    cg.if(cg.void);
    cg.local.get(0);
    cg.return();
    cg.end();
    // Subnormal (exponent field 0): scale by 2^23 to normalize, and remember to
    // subtract 23 from the decoded exponent.
    cg.local.get(bits);
    cg.i32.const(8388608);
    cg.i32.lt_s();
    cg.if(cg.void);
    cg.local.get(0);
    cg.f32.const(8388608);
    cg.f32.mul();
    cg.i32.reinterpret_f32();
    cg.local.set(bits);
    cg.i32.const(-23);
    cg.local.set(adj);
    cg.end();
    // e = ((bits >> 23) & 0xff) - 127 + adj
    cg.local.get(bits);
    cg.i32.const(23);
    cg.i32.shr_u();
    cg.i32.const(255);
    cg.i32.and();
    cg.i32.const(127);
    cg.i32.sub();
    cg.local.get(adj);
    cg.i32.add();
    cg.local.set(e);
    // m = mantissa bits with the exponent forced to 0, i.e. a float in [1, 2).
    cg.local.get(bits);
    cg.i32.const(8388607);
    cg.i32.and();
    cg.i32.const(1065353216);
    cg.i32.or();
    cg.f32.reinterpret_i32();
    cg.local.set(m);
    // t = (m - 1) / (m + 1)
    cg.local.get(m);
    cg.f32.const(1);
    cg.f32.sub();
    cg.local.get(m);
    cg.f32.const(1);
    cg.f32.add();
    cg.f32.div();
    cg.local.set(t);
    // Odd powers of t.
    cg.local.get(t);
    cg.local.get(t);
    cg.f32.mul();
    cg.local.set(t2);
    cg.local.get(t);
    cg.local.get(t2);
    cg.f32.mul();
    cg.local.set(t3);
    cg.local.get(t3);
    cg.local.get(t2);
    cg.f32.mul();
    cg.local.set(t5);
    cg.local.get(t5);
    cg.local.get(t2);
    cg.f32.mul();
    cg.local.set(t7);
    // lnm = 2 * (t7/7 + t5/5 + t3/3 + t)
    cg.local.get(t7);
    cg.f32.const(1 / 7);
    cg.f32.mul();
    cg.local.get(t5);
    cg.f32.const(1 / 5);
    cg.f32.mul();
    cg.f32.add();
    cg.local.get(t3);
    cg.f32.const(1 / 3);
    cg.f32.mul();
    cg.f32.add();
    cg.local.get(t);
    cg.f32.add();
    cg.f32.const(2);
    cg.f32.mul();
    cg.local.set(lnm);
    // el2 = e * ln2
    cg.local.get(e);
    cg.f32.convert_i32_s();
    cg.f32.const(Math.LN2);
    cg.f32.mul();
    cg.local.set(el2);
    cg.local.get(el2);
    cg.local.get(lnm);
    cg.f32.add();
  });
}
|
|
2285
|
+
/**
 * Approximate sin(x).
 *
 * Method:
 * 1. Reduce to y in [-π, π].
 * 2. Find the quadrant q = round(y/(π/2)) and set z = y - q*(π/2), so |z| <= π/4.
 * 3. Evaluate both polynomials on z:
 *      sin(z) ≈ z - z^3/6 + z^5/120 - z^7/5040 + z^9/362880
 *      cos(z) ≈ 1 - z^2/2 + z^4/24 - z^6/720 + z^8/40320
 * 4. Pick by quadrant, with q mod 4: 0: +sz, 1: +cz, 2: -sz, 3: -cz.
 *
 * The quadrant index uses a saturating float->int conversion, so NaN/±Infinity
 * (or a huge reduced value) yields NaN/garbage instead of trapping the module.
 * Accuracy degrades for very large |x| because range reduction is done in f32.
 */
function wasm_sin(cg) {
  return cg.function([cg.f32], [cg.f32], () => {
    const y = cg.local.declare(cg.f32);
    const qf = cg.local.declare(cg.f32);
    const q = cg.local.declare(cg.i32);
    const z = cg.local.declare(cg.f32);
    const z2 = cg.local.declare(cg.f32);
    const sz = cg.local.declare(cg.f32);
    const cz = cg.local.declare(cg.f32);
    const mag = cg.local.declare(cg.f32);
    // y = x - round(x / 2π) * 2π
    cg.local.get(0);
    cg.local.get(0);
    cg.f32.const(1 / (2 * Math.PI));
    cg.f32.mul();
    cg.f32.nearest();
    cg.local.tee(qf);
    cg.f32.const(2 * Math.PI);
    cg.f32.mul();
    cg.f32.sub();
    cg.local.set(y);
    // q = round(y / (π/2)); the saturating conversion never traps (NaN -> 0).
    cg.local.get(y);
    cg.f32.const(2 / Math.PI);
    cg.f32.mul();
    cg.f32.nearest();
    cg.local.tee(qf);
    cg.i32.trunc_sat_f32_s();
    cg.local.set(q);
    // z = y - q*(π/2); z2 = z*z
    cg.local.get(y);
    cg.local.get(qf);
    cg.f32.const(Math.PI / 2);
    cg.f32.mul();
    cg.f32.sub();
    cg.local.tee(z);
    cg.local.get(z);
    cg.f32.mul();
    cg.local.set(z2);
    // sz = z * (1 + z2*(-1/6 + z2*(1/120 + z2*(-1/5040 + z2/362880))))
    cg.f32.const(1 / 362880);
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(-1 / 5040);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(1 / 120);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(-1 / 6);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(1);
    cg.f32.add();
    cg.local.get(z);
    cg.f32.mul();
    cg.local.set(sz);
    // cz = 1 + z2*(-1/2 + z2*(1/24 + z2*(-1/720 + z2/40320)))
    cg.f32.const(1 / 40320);
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(-1 / 720);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(1 / 24);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(-1 / 2);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(1);
    cg.f32.add();
    cg.local.set(cz);
    // mag = (q odd) ? cz : sz; negate when bit 1 of q is set.
    cg.local.get(cz);
    cg.local.get(sz);
    cg.local.get(q);
    cg.i32.const(1);
    cg.i32.and();
    cg.select();
    cg.local.tee(mag);
    cg.f32.neg();
    cg.local.get(mag);
    cg.local.get(q);
    cg.i32.const(2);
    cg.i32.and();
    cg.select();
  });
}
|
|
2373
|
+
/**
 * Approximate cos(x).
 *
 * Method:
 * 1. Reduce to y in [-π, π].
 * 2. Find the quadrant q = round(y/(π/2)) and set z = y - q*(π/2), so |z| <= π/4.
 * 3. Evaluate both polynomials on z:
 *      sin(z) ≈ z - z^3/6 + z^5/120 - z^7/5040 + z^9/362880
 *      cos(z) ≈ 1 - z^2/2 + z^4/24 - z^6/720 + z^8/40320
 * 4. Pick by quadrant, with k = q mod 4: 0: +cz, 1: -sz, 2: -cz, 3: +sz.
 *
 * The quadrant index uses a saturating float->int conversion, so NaN/±Infinity
 * (or a huge reduced value) yields NaN/garbage instead of trapping the module.
 * Accuracy degrades for very large |x| because range reduction is done in f32.
 */
function wasm_cos(cg) {
  return cg.function([cg.f32], [cg.f32], () => {
    const y = cg.local.declare(cg.f32);
    const qf = cg.local.declare(cg.f32);
    const q = cg.local.declare(cg.i32);
    const z = cg.local.declare(cg.f32);
    const z2 = cg.local.declare(cg.f32);
    const sz = cg.local.declare(cg.f32);
    const cz = cg.local.declare(cg.f32);
    const mag = cg.local.declare(cg.f32);
    // y = x - round(x / 2π) * 2π
    cg.local.get(0);
    cg.local.get(0);
    cg.f32.const(1 / (2 * Math.PI));
    cg.f32.mul();
    cg.f32.nearest();
    cg.local.tee(qf);
    cg.f32.const(2 * Math.PI);
    cg.f32.mul();
    cg.f32.sub();
    cg.local.set(y);
    // q = round(y / (π/2)); the saturating conversion never traps (NaN -> 0).
    cg.local.get(y);
    cg.f32.const(2 / Math.PI);
    cg.f32.mul();
    cg.f32.nearest();
    cg.local.tee(qf);
    cg.i32.trunc_sat_f32_s();
    cg.local.set(q);
    // z = y - q*(π/2); z2 = z*z
    cg.local.get(y);
    cg.local.get(qf);
    cg.f32.const(Math.PI / 2);
    cg.f32.mul();
    cg.f32.sub();
    cg.local.tee(z);
    cg.local.get(z);
    cg.f32.mul();
    cg.local.set(z2);
    // sz = z * (1 + z2*(-1/6 + z2*(1/120 + z2*(-1/5040 + z2/362880))))
    cg.f32.const(1 / 362880);
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(-1 / 5040);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(1 / 120);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(-1 / 6);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(1);
    cg.f32.add();
    cg.local.get(z);
    cg.f32.mul();
    cg.local.set(sz);
    // cz = 1 + z2*(-1/2 + z2*(1/24 + z2*(-1/720 + z2/40320)))
    cg.f32.const(1 / 40320);
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(-1 / 720);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(1 / 24);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(-1 / 2);
    cg.f32.add();
    cg.local.get(z2);
    cg.f32.mul();
    cg.f32.const(1);
    cg.f32.add();
    cg.local.set(cz);
    // mag = (q odd) ? sz : cz; negate when bit 1 of (q + 1) is set.
    cg.local.get(sz);
    cg.local.get(cz);
    cg.local.get(q);
    cg.i32.const(1);
    cg.i32.and();
    cg.select();
    cg.local.tee(mag);
    cg.f32.neg();
    cg.local.get(mag);
    cg.local.get(q);
    cg.i32.const(1);
    cg.i32.add();
    cg.i32.const(2);
    cg.i32.and();
    cg.select();
  });
}
|
|
2462
|
+
/**
 * Threefry2x32 pseudorandom number generator (20 rounds).
 *
 * Takes two 32-bit keys and two 32-bit counters as input,
 * returns two 32-bit pseudorandom values.
 *
 * Structure:
 * - Five groups of four mix rounds.
 * - Rotation amounts alternate between two fixed sets.
 * - Each group is followed by a key injection that cycles through the
 *   three-word key schedule and adds the group number to the second word.
 */
function wasm_threefry2x32(cg) {
  const rotationSets = [[13, 15, 26, 6], [17, 29, 16, 24]];
  const params = [cg.i32, cg.i32, cg.i32, cg.i32];
  const results = [cg.i32, cg.i32];
  return cg.function(params, results, () => {
    // Declaration order fixes the local indices: ks0, ks1, ks2, x0, x1.
    const ks = [cg.local.declare(cg.i32), cg.local.declare(cg.i32), cg.local.declare(cg.i32)];
    const x0 = cg.local.declare(cg.i32);
    const x1 = cg.local.declare(cg.i32);
    // One mix round: x0 += x1; x1 = rotl(x1, rot) ^ x0.
    const mixRound = (rot) => {
      cg.local.get(x0);
      cg.local.get(x1);
      cg.i32.add();
      cg.local.set(x0);
      cg.local.get(x1);
      cg.i32.const(rot);
      cg.i32.rotl();
      cg.local.get(x0);
      cg.i32.xor();
      cg.local.set(x1);
    };
    // Key injection: x0 += a; x1 += b + group.
    const injectKey = (a, b, group) => {
      cg.local.get(x0);
      cg.local.get(a);
      cg.i32.add();
      cg.local.set(x0);
      cg.local.get(x1);
      cg.local.get(b);
      cg.i32.add();
      cg.i32.const(group);
      cg.i32.add();
      cg.local.set(x1);
    };
    // Key schedule: ks0 = key0, ks1 = key1, ks2 = key0 ^ key1 ^ 0x1BD11BDA (parity constant).
    cg.local.get(0);
    cg.local.set(ks[0]);
    cg.local.get(1);
    cg.local.set(ks[1]);
    cg.local.get(0);
    cg.local.get(1);
    cg.i32.xor();
    cg.i32.const(0x1bd11bda);
    cg.i32.xor();
    cg.local.set(ks[2]);
    // Initial injection: x = counters + (ks0, ks1).
    cg.local.get(2);
    cg.local.get(ks[0]);
    cg.i32.add();
    cg.local.set(x0);
    cg.local.get(3);
    cg.local.get(ks[1]);
    cg.i32.add();
    cg.local.set(x1);
    for (let group = 0; group < 5; group++) {
      for (const rot of rotationSets[group % 2]) mixRound(rot);
      injectKey(ks[(group + 1) % 3], ks[(group + 2) % 3], group + 1);
    }
    cg.local.get(x0);
    cg.local.get(x1);
  });
}
|
|
2536
|
+
|
|
2537
|
+
//#endregion
|
|
2538
|
+
//#region src/backend/wasm/wasmblr.ts
|
|
2539
|
+
/**
|
|
2540
|
+
* @file Minimalist WebAssembly assembler. This allows you to emit WebAssembly
|
|
2541
|
+
* bytecode directly from the browser.
|
|
2542
|
+
*
|
|
2543
|
+
* Self-contained port of https://github.com/bwasti/wasmblr to TypeScript.
|
|
2544
|
+
* Some operation names in this module are written in `snake_case` to match
|
|
2545
|
+
* their names in the Wasm specification.
|
|
2546
|
+
*
|
|
2547
|
+
* Reference: https://pengowray.github.io/wasm-ops/.
|
|
2548
|
+
*/
|
|
2549
|
+
/** Wasm binary preamble: the magic bytes "\0asm". */
const magicModuleHeader = [0x00, 0x61, 0x73, 0x6d];
/** Wasm binary format version 1, encoded as a little-endian u32. */
const moduleVersion = [0x01, 0x00, 0x00, 0x00];
|
|
2561
|
+
/** Throw an Error carrying `message` (or a generic default) when `condition` is falsy. */
function assert(condition, message) {
  if (condition) return;
  throw new Error(message || "Assertion failed");
}
|
|
2564
|
+
/**
 * Encode a 32-bit signed integer as signed LEB128 bytes.
 *
 * Emits 7 bits per byte, least significant first. Stops once the remaining
 * value is pure sign extension of the last byte's bit 6.
 */
function encodeSigned(n) {
  const bytes = [];
  let rest = n;
  for (;;) {
    const low = rest & 127;
    rest >>= 7;
    const signBitSet = (low & 64) !== 0;
    const finished = signBitSet ? rest === -1 : rest === 0;
    if (finished) {
      bytes.push(low);
      return bytes;
    }
    bytes.push(low | 128);
  }
}
|
|
2576
|
+
/**
 * Encode an integer as unsigned LEB128 bytes (7 bits per byte, least
 * significant first, high bit set on every byte except the last).
 * Values are treated as unsigned 32-bit after the first shift.
 */
function encodeUnsigned(n) {
  const bytes = [];
  let rest = n;
  for (;;) {
    const low = rest & 127;
    rest = rest >>> 7;
    if (rest === 0) {
      bytes.push(low);
      return bytes;
    }
    bytes.push(low | 128);
  }
}
|
|
2586
|
+
/**
 * Encode a string as a WebAssembly `name`: its UTF-8 byte length as an
 * unsigned LEB128 integer, followed by the UTF-8 bytes.
 */
function encodeString(s) {
  const bytes = new TextEncoder().encode(s);
  // The length prefix is a LEB128 u32. Pushing the raw length only works below
  // 128 bytes; longer names need a multi-byte prefix.
  return [...encodeUnsigned(bytes.length), ...bytes];
}
|
|
2590
|
+
/**
 * Encode a block's result type(s) for `block`/`loop`/`if`.
 *
 * A single type (including `void`, typeId 0x40) is emitted as its type byte.
 * Several types are emitted as an inline function type: 0x60, zero params,
 * then the result count and result type bytes.
 *
 * NOTE(review): the Wasm binary format encodes multi-value block types as a
 * signed-LEB128 index into the type section, not an inline 0x60 signature.
 * Confirm the multi-type branch validates before relying on it.
 */
function encodeBlocktype(type) {
  assert(type.length > 0, "blocktype must have at least one type");
  const [first] = type;
  if (type.length === 1) return [first.typeId];
  const header = [96, ...encodeUnsigned(0), ...encodeUnsigned(type.length)];
  return header.concat(type.map(({ typeId }) => typeId));
}
|
|
2600
|
+
/**
 * Encode an opcode. A plain number is a single-byte opcode. A
 * [prefix, subOpcode] pair is a prefixed opcode whose sub-opcode is written
 * as unsigned LEB128.
 */
function encodeOpcode(opcode) {
  if (typeof opcode !== "number") {
    const [prefix, subOpcode] = opcode;
    return [prefix, ...encodeUnsigned(subOpcode)];
  }
  return [opcode];
}
|
|
2604
|
+
/**
 * Append every element of `inp` to `out`, in place.
 *
 * Uses a loop rather than `out.push(...inp)`: spreading a large array into
 * call arguments exceeds the engine's argument limit and throws RangeError,
 * which large module sections can hit.
 */
function concat(out, inp) {
  for (const item of inp) out.push(item);
}
|
|
2607
|
+
/**
 * A function queued for emission: its parameter and result types, the body
 * callback that emits its instructions, and the locals that body declares.
 */
var Function_ = class {
  inputTypes;
  outputTypes;
  body;
  locals = [];
  constructor(inputTypes, outputTypes, body) {
    const noop = () => {};
    this.inputTypes = inputTypes;
    this.outputTypes = outputTypes;
    this.body = body ? body : noop;
  }
  /** Start from a fresh list of locals, then run the body callback. */
  emit() {
    const freshLocals = [];
    this.locals = freshLocals;
    this.body();
  }
};
|
|
2622
|
+
/**
 * The module's linear-memory declaration (page limits, sharing, and
 * import/export name), plus the `memory.size` and `memory.grow` instructions.
 */
var Memory = class {
  min = 0;
  max = 0;
  isShared = false;
  aString = "";
  bString = "";
  constructor(cg) {
    this.cg = cg;
  }
  /** Declare the size of the memory. Each page is 64 KiB. */
  pages(min, max = 0) {
    assert(this.min === 0 && this.max === 0);
    this.min = min;
    this.max = max;
    return this;
  }
  /** Export the memory under name `a`. Only one of export/import may be set. */
  export(a) {
    assert(!this.isImport && !this.isExport, "already set");
    this.aString = a;
    return this;
  }
  shared(isShared) {
    this.isShared = isShared;
    return this;
  }
  /** Import the memory as `a`.`b`. Only one of export/import may be set. */
  import(a, b) {
    assert(!this.isImport && !this.isExport, "already set");
    this.aString = a;
    this.bString = b;
    return this;
  }
  /**
   * Emit `memory.size`, which pushes the current size in pages as an i32.
   * The type stack is updated so later instructions see that operand.
   */
  size() {
    this.cg._emit(63);
    this.cg._emit(0);
    this.cg._push(this.cg.i32);
  }
  /**
   * Emit `memory.grow`: pops a page delta (i32) and pushes the previous size
   * in pages, or -1 on failure (i32). The type stack is updated to match.
   */
  grow() {
    assert(this.cg._pop().typeId === this.cg.i32.typeId, "memory.grow: expected i32");
    this.cg._emit(64);
    this.cg._emit(0);
    this.cg._push(this.cg.i32);
  }
  get isImport() {
    return this.aString.length > 0 && this.bString.length > 0;
  }
  get isExport() {
    return this.aString.length > 0 && this.bString.length === 0;
  }
};
|
|
2668
|
+
/** Public API of WebAssembly assembler. */
|
|
2669
|
+
var CodeGenerator = class {
|
|
2670
|
+
local;
|
|
2671
|
+
i32;
|
|
2672
|
+
f32;
|
|
2673
|
+
v128;
|
|
2674
|
+
i32x4;
|
|
2675
|
+
f32x4;
|
|
2676
|
+
memory;
|
|
2677
|
+
void = {
|
|
2678
|
+
typeId: 64,
|
|
2679
|
+
name: "void"
|
|
2680
|
+
};
|
|
2681
|
+
#functions = [];
|
|
2682
|
+
#importedFunctions = [];
|
|
2683
|
+
#exportedFunctions = /* @__PURE__ */ new Map();
|
|
2684
|
+
#curFunction = null;
|
|
2685
|
+
#curBytes = [];
|
|
2686
|
+
#typeStack = [];
|
|
2687
|
+
#blockFrames = [];
|
|
2688
|
+
constructor() {
|
|
2689
|
+
this.local = new Local(this);
|
|
2690
|
+
this.i32 = new I32(this);
|
|
2691
|
+
this.f32 = new F32(this);
|
|
2692
|
+
this.v128 = new V128(this);
|
|
2693
|
+
this.i32x4 = new I32x4(this);
|
|
2694
|
+
this.f32x4 = new F32x4(this);
|
|
2695
|
+
this.memory = new Memory(this);
|
|
2696
|
+
}
|
|
2697
|
+
unreachable() {
|
|
2698
|
+
this._emit(0);
|
|
2699
|
+
}
|
|
2700
|
+
nop() {
|
|
2701
|
+
this._emit(1);
|
|
2702
|
+
}
|
|
2703
|
+
block(...type) {
|
|
2704
|
+
this.#blockFrames.push({
|
|
2705
|
+
idx: this.#typeStack.length,
|
|
2706
|
+
ty: type
|
|
2707
|
+
});
|
|
2708
|
+
this._emit(2);
|
|
2709
|
+
this._emit(encodeBlocktype(type));
|
|
2710
|
+
}
|
|
2711
|
+
loop(...type) {
|
|
2712
|
+
this.#blockFrames.push({
|
|
2713
|
+
idx: this.#typeStack.length,
|
|
2714
|
+
ty: type
|
|
2715
|
+
});
|
|
2716
|
+
this._emit(3);
|
|
2717
|
+
this._emit(encodeBlocktype(type));
|
|
2718
|
+
}
|
|
2719
|
+
if(...type) {
|
|
2720
|
+
assert(this._pop().typeId === this.i32.typeId, "if_: expected i32");
|
|
2721
|
+
this.#blockFrames.push({
|
|
2722
|
+
idx: this.#typeStack.length,
|
|
2723
|
+
ty: type
|
|
2724
|
+
});
|
|
2725
|
+
this._emit(4);
|
|
2726
|
+
this._emit(encodeBlocktype(type));
|
|
2727
|
+
}
|
|
2728
|
+
else() {
|
|
2729
|
+
assert(this.#blockFrames.length > 0, "else: no block to else");
|
|
2730
|
+
const frame = this.#blockFrames[this.#blockFrames.length - 1];
|
|
2731
|
+
this.#typeStack = this.#typeStack.slice(0, frame.idx);
|
|
2732
|
+
this._emit(5);
|
|
2733
|
+
}
|
|
2734
|
+
/** End a block (`block`, `if`/`else`, `loop`, or function). */
|
|
2735
|
+
end() {
|
|
2736
|
+
const frame = this.#blockFrames.pop();
|
|
2737
|
+
assert(frame !== void 0, "end: no block to end");
|
|
2738
|
+
this.#typeStack = this.#typeStack.slice(0, frame.idx);
|
|
2739
|
+
for (const ty of frame.ty) if (ty.typeId !== this.void.typeId) this._push(ty);
|
|
2740
|
+
this._emit(11);
|
|
2741
|
+
}
|
|
2742
|
+
/** Branch to a block a certain depth outward on the stack. */
|
|
2743
|
+
br(depth) {
|
|
2744
|
+
this._emit(12);
|
|
2745
|
+
this._emit(encodeUnsigned(depth));
|
|
2746
|
+
}
|
|
2747
|
+
/** Conditional branch to a block a certain depth outward on the stack. */
|
|
2748
|
+
br_if(depth) {
|
|
2749
|
+
assert(this._pop().typeId === this.i32.typeId, "br_if: expected i32");
|
|
2750
|
+
this._emit(13);
|
|
2751
|
+
this._emit(encodeUnsigned(depth));
|
|
2752
|
+
}
|
|
2753
|
+
/** Jump table that indexes into a label vector (like switch). */
|
|
2754
|
+
br_table(...depths) {
|
|
2755
|
+
assert(this._pop().typeId === this.i32.typeId, "br_table: expected i32");
|
|
2756
|
+
assert(depths.length > 0, "br_table: expected at least one default depth");
|
|
2757
|
+
this._emit(14);
|
|
2758
|
+
this._emit(encodeUnsigned(depths.length - 1));
|
|
2759
|
+
for (const d of depths) this._emit(encodeUnsigned(d));
|
|
2760
|
+
}
|
|
2761
|
+
/** Return from a function, branching out of the outermost block. */
|
|
2762
|
+
return() {
|
|
2763
|
+
this._emit(15);
|
|
2764
|
+
}
|
|
2765
|
+
/** Call a function with the given ID. */
|
|
2766
|
+
call(fn) {
|
|
2767
|
+
const totalFunctions = this.#importedFunctions.length + this.#functions.length;
|
|
2768
|
+
assert(fn < totalFunctions, "function index does not exist");
|
|
2769
|
+
const func = fn < this.#importedFunctions.length ? this.#importedFunctions[fn] : this.#functions[fn - this.#importedFunctions.length];
|
|
2770
|
+
for (let i = func.inputTypes.length - 1; i >= 0; i--) {
|
|
2771
|
+
const argType = this._pop();
|
|
2772
|
+
assert(argType.typeId === func.inputTypes[i].typeId, `call: argument ${i} type mismatch, expected ${func.inputTypes[i].name} got ${argType.name}`);
|
|
2773
|
+
}
|
|
2774
|
+
for (const outputType of func.outputTypes) this._push(outputType);
|
|
2775
|
+
this._emit(16);
|
|
2776
|
+
this._emit(encodeUnsigned(fn));
|
|
2777
|
+
}
|
|
2778
|
+
/** Throw away an operand on the stack. */
|
|
2779
|
+
drop() {
|
|
2780
|
+
this._pop();
|
|
2781
|
+
this._emit(26);
|
|
2782
|
+
}
|
|
2783
|
+
/** Select one of the first two operands (T, F) based on the third operand (i32)'s value. */
|
|
2784
|
+
select() {
|
|
2785
|
+
assert(this._pop().typeId === this.i32.typeId, "select: expected i32 condition");
|
|
2786
|
+
const [b, a] = [this._pop(), this._pop()];
|
|
2787
|
+
assert(a.typeId === b.typeId, "select: expected same type for both operands");
|
|
2788
|
+
this._push(a);
|
|
2789
|
+
this._emit(27);
|
|
2790
|
+
}
|
|
2791
|
+
/** Import a JavaScript function; returns its index. */
|
|
2792
|
+
importFunction(module, name, inputTypes, outputTypes) {
|
|
2793
|
+
if (this.#functions.length > 0) throw new Error("function imports must precede defining functions");
|
|
2794
|
+
const idx = this.#importedFunctions.length;
|
|
2795
|
+
this.#importedFunctions.push({
|
|
2796
|
+
module,
|
|
2797
|
+
name,
|
|
2798
|
+
inputTypes,
|
|
2799
|
+
outputTypes
|
|
2800
|
+
});
|
|
2801
|
+
return idx;
|
|
2802
|
+
}
|
|
2803
|
+
/** Export a function. */
|
|
2804
|
+
export(fn, name) {
|
|
2805
|
+
this.#exportedFunctions.set(fn, name);
|
|
2806
|
+
}
|
|
2807
|
+
/** Declare a new function; returns its index. */
|
|
2808
|
+
function(inputTypes, outputTypes, body) {
|
|
2809
|
+
const idx = this.#importedFunctions.length + this.#functions.length;
|
|
2810
|
+
this.#functions.push(new Function_(inputTypes, outputTypes, body));
|
|
2811
|
+
return idx;
|
|
2812
|
+
}
|
|
2813
|
+
_declareLocal(type) {
|
|
2814
|
+
assert(this.#curFunction !== null, "No current function");
|
|
2815
|
+
const idx = this.#curFunction.locals.length + this.#curFunction.inputTypes.length;
|
|
2816
|
+
this.#curFunction.locals.push(type);
|
|
2817
|
+
return idx;
|
|
2818
|
+
}
|
|
2819
|
+
_inputTypes() {
|
|
2820
|
+
assert(this.#curFunction !== null, "No current function");
|
|
2821
|
+
return this.#curFunction.inputTypes;
|
|
2822
|
+
}
|
|
2823
|
+
_locals() {
|
|
2824
|
+
assert(this.#curFunction !== null, "No current function");
|
|
2825
|
+
return this.#curFunction.locals;
|
|
2826
|
+
}
|
|
2827
|
+
_push(type) {
|
|
2828
|
+
if (!type) throw new Error(`pushing type ${type}`);
|
|
2829
|
+
this.#typeStack.push(type);
|
|
2830
|
+
}
|
|
2831
|
+
_pop() {
|
|
2832
|
+
assert(this.#typeStack.length > 0, "popping empty stack");
|
|
2833
|
+
return this.#typeStack.pop();
|
|
2834
|
+
}
|
|
2835
|
+
_emit(bytes) {
|
|
2836
|
+
if (typeof bytes === "number") this.#curBytes.push(bytes);
|
|
2837
|
+
else this.#curBytes.push(...bytes);
|
|
2838
|
+
}
|
|
2839
|
+
finish() {
|
|
2840
|
+
this.#curBytes = [];
|
|
2841
|
+
const emittedBytes = [];
|
|
2842
|
+
concat(emittedBytes, magicModuleHeader);
|
|
2843
|
+
concat(emittedBytes, moduleVersion);
|
|
2844
|
+
const typeSectionBytes = [];
|
|
2845
|
+
const totalFunctionTypes = this.#importedFunctions.length + this.#functions.length;
|
|
2846
|
+
concat(typeSectionBytes, encodeUnsigned(totalFunctionTypes));
|
|
2847
|
+
for (const f of [...this.#importedFunctions, ...this.#functions]) {
|
|
2848
|
+
typeSectionBytes.push(96);
|
|
2849
|
+
concat(typeSectionBytes, encodeUnsigned(f.inputTypes.length));
|
|
2850
|
+
for (const t of f.inputTypes) typeSectionBytes.push(t.typeId);
|
|
2851
|
+
concat(typeSectionBytes, encodeUnsigned(f.outputTypes.length));
|
|
2852
|
+
for (const t of f.outputTypes) typeSectionBytes.push(t.typeId);
|
|
2853
|
+
}
|
|
2854
|
+
emittedBytes.push(1);
|
|
2855
|
+
concat(emittedBytes, encodeUnsigned(typeSectionBytes.length));
|
|
2856
|
+
concat(emittedBytes, typeSectionBytes);
|
|
2857
|
+
const importSectionBytes = [];
|
|
2858
|
+
const numImports = this.#importedFunctions.length + (this.memory.isImport ? 1 : 0);
|
|
2859
|
+
if (numImports > 0) {
|
|
2860
|
+
concat(importSectionBytes, encodeUnsigned(numImports));
|
|
2861
|
+
for (let i = 0; i < this.#importedFunctions.length; i++) {
|
|
2862
|
+
const f = this.#importedFunctions[i];
|
|
2863
|
+
concat(importSectionBytes, encodeString(f.module));
|
|
2864
|
+
concat(importSectionBytes, encodeString(f.name));
|
|
2865
|
+
importSectionBytes.push(0);
|
|
2866
|
+
concat(importSectionBytes, encodeUnsigned(i));
|
|
2867
|
+
}
|
|
2868
|
+
if (this.memory.isImport) {
|
|
2869
|
+
concat(importSectionBytes, encodeString(this.memory.aString));
|
|
2870
|
+
concat(importSectionBytes, encodeString(this.memory.bString));
|
|
2871
|
+
importSectionBytes.push(2);
|
|
2872
|
+
if (this.memory.min && this.memory.max) {
|
|
2873
|
+
if (this.memory.isShared) importSectionBytes.push(3);
|
|
2874
|
+
else importSectionBytes.push(1);
|
|
2875
|
+
concat(importSectionBytes, encodeUnsigned(this.memory.min));
|
|
2876
|
+
concat(importSectionBytes, encodeUnsigned(this.memory.max));
|
|
2877
|
+
} else {
|
|
2878
|
+
assert(!this.memory.isShared, "shared memory must have a max size");
|
|
2879
|
+
importSectionBytes.push(0);
|
|
2880
|
+
concat(importSectionBytes, encodeUnsigned(this.memory.min));
|
|
2881
|
+
}
|
|
2882
|
+
}
|
|
2883
|
+
emittedBytes.push(2);
|
|
2884
|
+
concat(emittedBytes, encodeUnsigned(importSectionBytes.length));
|
|
2885
|
+
concat(emittedBytes, importSectionBytes);
|
|
2886
|
+
}
|
|
2887
|
+
const functionSectionBytes = [];
|
|
2888
|
+
concat(functionSectionBytes, encodeUnsigned(this.#functions.length));
|
|
2889
|
+
for (let i = 0; i < this.#functions.length; i++) {
|
|
2890
|
+
const typeIndex = this.#importedFunctions.length + i;
|
|
2891
|
+
concat(functionSectionBytes, encodeUnsigned(typeIndex));
|
|
2892
|
+
}
|
|
2893
|
+
emittedBytes.push(3);
|
|
2894
|
+
concat(emittedBytes, encodeUnsigned(functionSectionBytes.length));
|
|
2895
|
+
concat(emittedBytes, functionSectionBytes);
|
|
2896
|
+
const memorySectionBytes = [];
|
|
2897
|
+
if (!this.memory.isImport && (this.memory.min || this.memory.max)) {
|
|
2898
|
+
memorySectionBytes.push(1);
|
|
2899
|
+
if (this.memory.min && this.memory.max) {
|
|
2900
|
+
if (this.memory.isShared) memorySectionBytes.push(3);
|
|
2901
|
+
else memorySectionBytes.push(1);
|
|
2902
|
+
concat(memorySectionBytes, encodeUnsigned(this.memory.min));
|
|
2903
|
+
concat(memorySectionBytes, encodeUnsigned(this.memory.max));
|
|
2904
|
+
} else {
|
|
2905
|
+
assert(!this.memory.isShared, "shared memory must have a max size");
|
|
2906
|
+
memorySectionBytes.push(0);
|
|
2907
|
+
concat(memorySectionBytes, encodeUnsigned(this.memory.min));
|
|
2908
|
+
}
|
|
2909
|
+
emittedBytes.push(5);
|
|
2910
|
+
concat(emittedBytes, encodeUnsigned(memorySectionBytes.length));
|
|
2911
|
+
concat(emittedBytes, memorySectionBytes);
|
|
2912
|
+
}
|
|
2913
|
+
const exportSectionBytes = [];
|
|
2914
|
+
const numExports = this.#exportedFunctions.size + (this.memory.isExport ? 1 : 0);
|
|
2915
|
+
concat(exportSectionBytes, encodeUnsigned(numExports));
|
|
2916
|
+
if (this.memory.isExport) {
|
|
2917
|
+
concat(exportSectionBytes, encodeString(this.memory.aString));
|
|
2918
|
+
exportSectionBytes.push(2);
|
|
2919
|
+
exportSectionBytes.push(0);
|
|
2920
|
+
}
|
|
2921
|
+
for (const [key, name] of this.#exportedFunctions.entries()) {
|
|
2922
|
+
concat(exportSectionBytes, encodeString(name));
|
|
2923
|
+
exportSectionBytes.push(0);
|
|
2924
|
+
concat(exportSectionBytes, encodeUnsigned(key));
|
|
2925
|
+
}
|
|
2926
|
+
emittedBytes.push(7);
|
|
2927
|
+
concat(emittedBytes, encodeUnsigned(exportSectionBytes.length));
|
|
2928
|
+
concat(emittedBytes, exportSectionBytes);
|
|
2929
|
+
const codeSectionBytes = [];
|
|
2930
|
+
concat(codeSectionBytes, encodeUnsigned(this.#functions.length));
|
|
2931
|
+
for (const f of this.#functions) {
|
|
2932
|
+
this.#typeStack = [];
|
|
2933
|
+
this.#blockFrames = [{
|
|
2934
|
+
idx: 0,
|
|
2935
|
+
ty: f.outputTypes
|
|
2936
|
+
}];
|
|
2937
|
+
this.#curFunction = f;
|
|
2938
|
+
this.#curBytes = [];
|
|
2939
|
+
f.emit();
|
|
2940
|
+
this.end();
|
|
2941
|
+
const bodyBytes = [...this.#curBytes];
|
|
2942
|
+
this.#curBytes = [];
|
|
2943
|
+
concat(this.#curBytes, encodeUnsigned(f.locals.length));
|
|
2944
|
+
for (const l of f.locals) {
|
|
2945
|
+
this._emit(1);
|
|
2946
|
+
this._emit(l.typeId);
|
|
2947
|
+
}
|
|
2948
|
+
const headerBytes = [...this.#curBytes];
|
|
2949
|
+
const fnSize = headerBytes.length + bodyBytes.length;
|
|
2950
|
+
concat(codeSectionBytes, encodeUnsigned(fnSize));
|
|
2951
|
+
concat(codeSectionBytes, headerBytes);
|
|
2952
|
+
concat(codeSectionBytes, bodyBytes);
|
|
2953
|
+
}
|
|
2954
|
+
this.#curFunction = null;
|
|
2955
|
+
emittedBytes.push(10);
|
|
2956
|
+
concat(emittedBytes, encodeUnsigned(codeSectionBytes.length));
|
|
2957
|
+
concat(emittedBytes, codeSectionBytes);
|
|
2958
|
+
return new Uint8Array(emittedBytes);
|
|
2959
|
+
}
|
|
2960
|
+
};
|
|
2961
|
+
var Local = class {
|
|
2962
|
+
constructor(cg) {
|
|
2963
|
+
this.cg = cg;
|
|
2964
|
+
}
|
|
2965
|
+
declare(type) {
|
|
2966
|
+
return this.cg._declareLocal(type);
|
|
2967
|
+
}
|
|
2968
|
+
get(idx) {
|
|
2969
|
+
assert(Number.isInteger(idx), "getting non-integer local");
|
|
2970
|
+
const inputTypes = this.cg._inputTypes();
|
|
2971
|
+
if (idx < inputTypes.length) this.cg._push(inputTypes[idx]);
|
|
2972
|
+
else this.cg._push(this.cg._locals()[idx - inputTypes.length]);
|
|
2973
|
+
this.cg._emit(32);
|
|
2974
|
+
this.cg._emit(encodeUnsigned(idx));
|
|
2975
|
+
}
|
|
2976
|
+
set(idx) {
|
|
2977
|
+
const t = this.cg._pop();
|
|
2978
|
+
const inputTypes = this.cg._inputTypes();
|
|
2979
|
+
const expectedType = idx < inputTypes.length ? inputTypes[idx] : this.cg._locals()[idx - inputTypes.length];
|
|
2980
|
+
assert(expectedType.typeId === t.typeId, "can't set local to this value (wrong type)");
|
|
2981
|
+
this.cg._emit(33);
|
|
2982
|
+
this.cg._emit(encodeUnsigned(idx));
|
|
2983
|
+
}
|
|
2984
|
+
tee(idx) {
|
|
2985
|
+
const t = this.cg._pop();
|
|
2986
|
+
const inputTypes = this.cg._inputTypes();
|
|
2987
|
+
const expectedType = idx < inputTypes.length ? inputTypes[idx] : this.cg._locals()[idx - inputTypes.length];
|
|
2988
|
+
assert(expectedType.typeId === t.typeId, "can't tee local to this value (wrong type)");
|
|
2989
|
+
this.cg._emit(34);
|
|
2990
|
+
this.cg._emit(encodeUnsigned(idx));
|
|
2991
|
+
this.cg._push(expectedType);
|
|
2992
|
+
}
|
|
2993
|
+
};
|
|
2994
|
+
function UNARY_OP(op, opcode, inType, outType) {
|
|
2995
|
+
return function() {
|
|
2996
|
+
const t = this.cg._pop();
|
|
2997
|
+
assert(t.typeId === this.cg[inType].typeId, `invalid type for ${op} (${inType} -> ${outType})`);
|
|
2998
|
+
this.cg._emit(encodeOpcode(opcode));
|
|
2999
|
+
this.cg._push(this.cg[outType]);
|
|
3000
|
+
};
|
|
3001
|
+
}
|
|
3002
|
+
function BINARY_OP(op, opcode, typeA, typeB, outType) {
|
|
3003
|
+
return function() {
|
|
3004
|
+
const b = this.cg._pop();
|
|
3005
|
+
const a = this.cg._pop();
|
|
3006
|
+
assert(a.typeId === this.cg[typeA].typeId && b.typeId === this.cg[typeB].typeId, `invalid type for ${op} (${typeA}, ${typeB} -> ${outType})`);
|
|
3007
|
+
this.cg._emit(encodeOpcode(opcode));
|
|
3008
|
+
this.cg._push(this.cg[outType]);
|
|
3009
|
+
};
|
|
3010
|
+
}
|
|
3011
|
+
function LOAD_OP(op, opcode, outType) {
|
|
3012
|
+
return function(align = 0, offset = 0) {
|
|
3013
|
+
const idxType = this.cg._pop();
|
|
3014
|
+
assert(idxType.typeId === this.cg.i32.typeId, `invalid type for ${op}`);
|
|
3015
|
+
this.cg._emit(encodeOpcode(opcode));
|
|
3016
|
+
this.cg._emit(encodeUnsigned(align));
|
|
3017
|
+
this.cg._emit(encodeUnsigned(offset));
|
|
3018
|
+
this.cg._push(this.cg[outType]);
|
|
3019
|
+
};
|
|
3020
|
+
}
|
|
3021
|
+
function STORE_OP(op, opcode, inType) {
|
|
3022
|
+
return function(align = 0, offset = 0) {
|
|
3023
|
+
const valType = this.cg._pop();
|
|
3024
|
+
const idxType = this.cg._pop();
|
|
3025
|
+
assert(valType.typeId === this.cg[inType].typeId, `invalid value type for ${op} (${inType})`);
|
|
3026
|
+
assert(idxType.typeId === this.cg.i32.typeId, `invalid type for ${op}`);
|
|
3027
|
+
this.cg._emit(encodeOpcode(opcode));
|
|
3028
|
+
this.cg._emit(encodeUnsigned(align));
|
|
3029
|
+
this.cg._emit(encodeUnsigned(offset));
|
|
3030
|
+
};
|
|
3031
|
+
}
|
|
3032
|
+
var I32 = class {
|
|
3033
|
+
constructor(cg) {
|
|
3034
|
+
this.cg = cg;
|
|
3035
|
+
}
|
|
3036
|
+
get typeId() {
|
|
3037
|
+
return 127;
|
|
3038
|
+
}
|
|
3039
|
+
get name() {
|
|
3040
|
+
return "i32";
|
|
3041
|
+
}
|
|
3042
|
+
const(i) {
|
|
3043
|
+
this.cg._emit(65);
|
|
3044
|
+
this.cg._emit(encodeSigned(i));
|
|
3045
|
+
this.cg._push(this);
|
|
3046
|
+
}
|
|
3047
|
+
clz = UNARY_OP("clz", 103, "i32", "i32");
|
|
3048
|
+
ctz = UNARY_OP("ctz", 104, "i32", "i32");
|
|
3049
|
+
popcnt = UNARY_OP("popcnt", 105, "i32", "i32");
|
|
3050
|
+
lt_s = BINARY_OP("lt_s", 72, "i32", "i32", "i32");
|
|
3051
|
+
lt_u = BINARY_OP("lt_u", 73, "i32", "i32", "i32");
|
|
3052
|
+
gt_s = BINARY_OP("gt_s", 74, "i32", "i32", "i32");
|
|
3053
|
+
gt_u = BINARY_OP("gt_u", 75, "i32", "i32", "i32");
|
|
3054
|
+
le_s = BINARY_OP("le_s", 76, "i32", "i32", "i32");
|
|
3055
|
+
le_u = BINARY_OP("le_u", 77, "i32", "i32", "i32");
|
|
3056
|
+
ge_s = BINARY_OP("ge_s", 78, "i32", "i32", "i32");
|
|
3057
|
+
ge_u = BINARY_OP("ge_u", 79, "i32", "i32", "i32");
|
|
3058
|
+
add = BINARY_OP("add", 106, "i32", "i32", "i32");
|
|
3059
|
+
sub = BINARY_OP("sub", 107, "i32", "i32", "i32");
|
|
3060
|
+
mul = BINARY_OP("mul", 108, "i32", "i32", "i32");
|
|
3061
|
+
div_s = BINARY_OP("div_s", 109, "i32", "i32", "i32");
|
|
3062
|
+
div_u = BINARY_OP("div_u", 110, "i32", "i32", "i32");
|
|
3063
|
+
rem_s = BINARY_OP("rem_s", 111, "i32", "i32", "i32");
|
|
3064
|
+
rem_u = BINARY_OP("rem_u", 112, "i32", "i32", "i32");
|
|
3065
|
+
and = BINARY_OP("and", 113, "i32", "i32", "i32");
|
|
3066
|
+
or = BINARY_OP("or", 114, "i32", "i32", "i32");
|
|
3067
|
+
xor = BINARY_OP("xor", 115, "i32", "i32", "i32");
|
|
3068
|
+
shl = BINARY_OP("shl", 116, "i32", "i32", "i32");
|
|
3069
|
+
shr_s = BINARY_OP("shr_s", 117, "i32", "i32", "i32");
|
|
3070
|
+
shr_u = BINARY_OP("shr_u", 118, "i32", "i32", "i32");
|
|
3071
|
+
rotl = BINARY_OP("rotl", 119, "i32", "i32", "i32");
|
|
3072
|
+
rotr = BINARY_OP("rotr", 120, "i32", "i32", "i32");
|
|
3073
|
+
eqz = BINARY_OP("eqz", 69, "i32", "i32", "i32");
|
|
3074
|
+
eq = BINARY_OP("eq", 70, "i32", "i32", "i32");
|
|
3075
|
+
ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
|
|
3076
|
+
trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
|
|
3077
|
+
trunc_f32_u = UNARY_OP("trunc_f32_u", 169, "f32", "i32");
|
|
3078
|
+
load = LOAD_OP("load", 40, "i32");
|
|
3079
|
+
load8_s = LOAD_OP("load8_s", 44, "i32");
|
|
3080
|
+
load8_u = LOAD_OP("load8_u", 45, "i32");
|
|
3081
|
+
load16_s = LOAD_OP("load16_s", 46, "i32");
|
|
3082
|
+
load16_u = LOAD_OP("load16_u", 47, "i32");
|
|
3083
|
+
store = STORE_OP("store", 54, "i32");
|
|
3084
|
+
store8 = STORE_OP("store8", 58, "i32");
|
|
3085
|
+
store16 = STORE_OP("store16", 59, "i32");
|
|
3086
|
+
reinterpret_f32 = UNARY_OP("reinterpret_f32", 188, "f32", "i32");
|
|
3087
|
+
trunc_sat_f32_s = UNARY_OP("trunc_sat_f32_s", [252, 0], "f32", "i32");
|
|
3088
|
+
trunc_sat_f32_u = UNARY_OP("trunc_sat_f32_u", [252, 1], "f32", "i32");
|
|
3089
|
+
};
|
|
3090
|
+
var F32 = class {
|
|
3091
|
+
constructor(cg) {
|
|
3092
|
+
this.cg = cg;
|
|
3093
|
+
}
|
|
3094
|
+
get typeId() {
|
|
3095
|
+
return 125;
|
|
3096
|
+
}
|
|
3097
|
+
get name() {
|
|
3098
|
+
return "f32";
|
|
3099
|
+
}
|
|
3100
|
+
const(f) {
|
|
3101
|
+
this.cg._emit(67);
|
|
3102
|
+
const buffer = /* @__PURE__ */ new ArrayBuffer(4);
|
|
3103
|
+
new DataView(buffer).setFloat32(0, f, true);
|
|
3104
|
+
const bytes = new Uint8Array(buffer);
|
|
3105
|
+
for (let i = 0; i < 4; i++) this.cg._emit(bytes[i]);
|
|
3106
|
+
this.cg._push(this);
|
|
3107
|
+
}
|
|
3108
|
+
eq = BINARY_OP("eq", 91, "f32", "f32", "i32");
|
|
3109
|
+
ne = BINARY_OP("ne", 92, "f32", "f32", "i32");
|
|
3110
|
+
lt = BINARY_OP("lt", 93, "f32", "f32", "i32");
|
|
3111
|
+
gt = BINARY_OP("gt", 94, "f32", "f32", "i32");
|
|
3112
|
+
le = BINARY_OP("le", 95, "f32", "f32", "i32");
|
|
3113
|
+
ge = BINARY_OP("ge", 96, "f32", "f32", "i32");
|
|
3114
|
+
abs = UNARY_OP("abs", 139, "f32", "f32");
|
|
3115
|
+
neg = UNARY_OP("neg", 140, "f32", "f32");
|
|
3116
|
+
ceil = UNARY_OP("ceil", 141, "f32", "f32");
|
|
3117
|
+
floor = UNARY_OP("floor", 142, "f32", "f32");
|
|
3118
|
+
trunc = UNARY_OP("trunc", 143, "f32", "f32");
|
|
3119
|
+
nearest = UNARY_OP("nearest", 144, "f32", "f32");
|
|
3120
|
+
sqrt = UNARY_OP("sqrt", 145, "f32", "f32");
|
|
3121
|
+
add = BINARY_OP("add", 146, "f32", "f32", "f32");
|
|
3122
|
+
sub = BINARY_OP("sub", 147, "f32", "f32", "f32");
|
|
3123
|
+
mul = BINARY_OP("mul", 148, "f32", "f32", "f32");
|
|
3124
|
+
div = BINARY_OP("div", 149, "f32", "f32", "f32");
|
|
3125
|
+
min = BINARY_OP("min", 150, "f32", "f32", "f32");
|
|
3126
|
+
max = BINARY_OP("max", 151, "f32", "f32", "f32");
|
|
3127
|
+
copysign = BINARY_OP("copysign", 152, "f32", "f32", "f32");
|
|
3128
|
+
convert_i32_s = UNARY_OP("convert_i32_s", 178, "i32", "f32");
|
|
3129
|
+
convert_i32_u = UNARY_OP("convert_i32_u", 179, "i32", "f32");
|
|
3130
|
+
load = LOAD_OP("load", 42, "f32");
|
|
3131
|
+
store = STORE_OP("store", 56, "f32");
|
|
3132
|
+
reinterpret_i32 = UNARY_OP("reinterpret_i32", 190, "i32", "f32");
|
|
3133
|
+
};
|
|
3134
|
+
function VECTOR_OP(op, vopcode, inTypes, outType) {
|
|
3135
|
+
return function() {
|
|
3136
|
+
for (const inType of inTypes.toReversed()) {
|
|
3137
|
+
const actualType = this.cg._pop();
|
|
3138
|
+
assert(actualType.typeId === this.cg[inType].typeId, `invalid type for ${op} (${inTypes.join(", ")} -> ${outType})`);
|
|
3139
|
+
}
|
|
3140
|
+
this.cg._emit(encodeOpcode([253, vopcode]));
|
|
3141
|
+
this.cg._push(this.cg[outType]);
|
|
3142
|
+
};
|
|
3143
|
+
}
|
|
3144
|
+
function VECTOR_OPL(op, vopcode, inTypes, outType) {
|
|
3145
|
+
return function(lane) {
|
|
3146
|
+
for (const inType of inTypes.toReversed()) {
|
|
3147
|
+
const actualType = this.cg._pop();
|
|
3148
|
+
assert(actualType.typeId === this.cg[inType].typeId, `invalid type for ${op} (${inTypes} -> ${outType})`);
|
|
3149
|
+
}
|
|
3150
|
+
this.cg._emit(encodeOpcode([253, vopcode]));
|
|
3151
|
+
this.cg._emit(lane);
|
|
3152
|
+
this.cg._push(this.cg[outType]);
|
|
3153
|
+
};
|
|
3154
|
+
}
|
|
3155
|
+
function VECTOR_LOAD_OP(op, vopcode) {
|
|
3156
|
+
return function(align = 0, offset = 0) {
|
|
3157
|
+
const idxType = this.cg._pop();
|
|
3158
|
+
assert(idxType.typeId === this.cg.i32.typeId, `invalid type for ${op}`);
|
|
3159
|
+
this.cg._emit(encodeOpcode([253, vopcode]));
|
|
3160
|
+
this.cg._emit(encodeUnsigned(align));
|
|
3161
|
+
this.cg._emit(encodeUnsigned(offset));
|
|
3162
|
+
this.cg._push(this.cg.v128);
|
|
3163
|
+
};
|
|
3164
|
+
}
|
|
3165
|
+
var V128 = class {
|
|
3166
|
+
constructor(cg) {
|
|
3167
|
+
this.cg = cg;
|
|
3168
|
+
}
|
|
3169
|
+
get typeId() {
|
|
3170
|
+
return 123;
|
|
3171
|
+
}
|
|
3172
|
+
get name() {
|
|
3173
|
+
return "v128";
|
|
3174
|
+
}
|
|
3175
|
+
load = VECTOR_LOAD_OP("load", 0);
|
|
3176
|
+
load32x2_s = VECTOR_LOAD_OP("load32x2_s", 5);
|
|
3177
|
+
load32x2_u = VECTOR_LOAD_OP("load32x2_u", 6);
|
|
3178
|
+
load32_splat = VECTOR_LOAD_OP("load32_splat", 9);
|
|
3179
|
+
load32_zero = VECTOR_LOAD_OP("load32_zero", 92);
|
|
3180
|
+
store(align = 0, offset = 0) {
|
|
3181
|
+
const valType = this.cg._pop();
|
|
3182
|
+
assert(valType.typeId === this.cg.v128.typeId, `invalid type for store`);
|
|
3183
|
+
const idxType = this.cg._pop();
|
|
3184
|
+
assert(idxType.typeId === this.cg.i32.typeId, `invalid type for store`);
|
|
3185
|
+
this.cg._emit(253);
|
|
3186
|
+
this.cg._emit(encodeUnsigned(11));
|
|
3187
|
+
this.cg._emit(encodeUnsigned(align));
|
|
3188
|
+
this.cg._emit(encodeUnsigned(offset));
|
|
3189
|
+
}
|
|
3190
|
+
not = VECTOR_OP("not", 77, ["v128"], "v128");
|
|
3191
|
+
and = VECTOR_OP("and", 78, ["v128", "v128"], "v128");
|
|
3192
|
+
andnot = VECTOR_OP("andnot", 79, ["v128", "v128"], "v128");
|
|
3193
|
+
or = VECTOR_OP("or", 80, ["v128", "v128"], "v128");
|
|
3194
|
+
xor = VECTOR_OP("xor", 81, ["v128", "v128"], "v128");
|
|
3195
|
+
bitselect = VECTOR_OP("bitselect", 82, [
|
|
3196
|
+
"v128",
|
|
3197
|
+
"v128",
|
|
3198
|
+
"v128"
|
|
3199
|
+
], "v128");
|
|
3200
|
+
any_true = VECTOR_OP("any_true", 83, ["v128"], "i32");
|
|
3201
|
+
};
|
|
3202
|
+
var I32x4 = class extends V128 {
|
|
3203
|
+
splat = VECTOR_OP("splat", 17, ["i32"], "v128");
|
|
3204
|
+
extract_lane = VECTOR_OPL("extract_lane", 27, ["v128"], "i32");
|
|
3205
|
+
replace_lane = VECTOR_OPL("replace_lane", 28, ["v128", "i32"], "v128");
|
|
3206
|
+
eq = VECTOR_OP("eq", 55, ["v128", "v128"], "v128");
|
|
3207
|
+
ne = VECTOR_OP("ne", 56, ["v128", "v128"], "v128");
|
|
3208
|
+
lt_s = VECTOR_OP("lt_s", 57, ["v128", "v128"], "v128");
|
|
3209
|
+
lt_u = VECTOR_OP("lt_u", 58, ["v128", "v128"], "v128");
|
|
3210
|
+
gt_s = VECTOR_OP("gt_s", 59, ["v128", "v128"], "v128");
|
|
3211
|
+
gt_u = VECTOR_OP("gt_u", 60, ["v128", "v128"], "v128");
|
|
3212
|
+
le_s = VECTOR_OP("le_s", 61, ["v128", "v128"], "v128");
|
|
3213
|
+
le_u = VECTOR_OP("le_u", 62, ["v128", "v128"], "v128");
|
|
3214
|
+
ge_s = VECTOR_OP("ge_s", 63, ["v128", "v128"], "v128");
|
|
3215
|
+
ge_u = VECTOR_OP("ge_u", 64, ["v128", "v128"], "v128");
|
|
3216
|
+
abs = VECTOR_OP("abs", 160, ["v128"], "v128");
|
|
3217
|
+
neg = VECTOR_OP("neg", 161, ["v128"], "v128");
|
|
3218
|
+
all_true = VECTOR_OP("all_true", 163, ["v128"], "i32");
|
|
3219
|
+
bitmask = VECTOR_OP("bitmask", 164, ["v128"], "i32");
|
|
3220
|
+
shl = VECTOR_OP("shl", 171, ["v128", "i32"], "v128");
|
|
3221
|
+
shr_s = VECTOR_OP("shr_s", 172, ["v128", "i32"], "v128");
|
|
3222
|
+
shr_u = VECTOR_OP("shr_u", 173, ["v128", "i32"], "v128");
|
|
3223
|
+
add = VECTOR_OP("add", 174, ["v128", "v128"], "v128");
|
|
3224
|
+
sub = VECTOR_OP("sub", 177, ["v128", "v128"], "v128");
|
|
3225
|
+
mul = VECTOR_OP("mul", 181, ["v128", "v128"], "v128");
|
|
3226
|
+
min_s = VECTOR_OP("min_s", 182, ["v128", "v128"], "v128");
|
|
3227
|
+
min_u = VECTOR_OP("min_u", 183, ["v128", "v128"], "v128");
|
|
3228
|
+
max_s = VECTOR_OP("max_s", 184, ["v128", "v128"], "v128");
|
|
3229
|
+
max_u = VECTOR_OP("max_u", 185, ["v128", "v128"], "v128");
|
|
3230
|
+
};
|
|
3231
|
+
var F32x4 = class extends V128 {
|
|
3232
|
+
splat = VECTOR_OP("splat", 19, ["f32"], "v128");
|
|
3233
|
+
extract_lane = VECTOR_OPL("extract_lane", 31, ["v128"], "f32");
|
|
3234
|
+
replace_lane = VECTOR_OPL("replace_lane", 32, ["v128", "f32"], "v128");
|
|
3235
|
+
eq = VECTOR_OP("eq", 65, ["v128", "v128"], "v128");
|
|
3236
|
+
ne = VECTOR_OP("ne", 66, ["v128", "v128"], "v128");
|
|
3237
|
+
lt = VECTOR_OP("lt", 67, ["v128", "v128"], "v128");
|
|
3238
|
+
gt = VECTOR_OP("gt", 68, ["v128", "v128"], "v128");
|
|
3239
|
+
le = VECTOR_OP("le", 69, ["v128", "v128"], "v128");
|
|
3240
|
+
ge = VECTOR_OP("ge", 70, ["v128", "v128"], "v128");
|
|
3241
|
+
ceil = VECTOR_OP("ceil", 103, ["v128"], "v128");
|
|
3242
|
+
floor = VECTOR_OP("floor", 104, ["v128"], "v128");
|
|
3243
|
+
trunc = VECTOR_OP("trunc", 105, ["v128"], "v128");
|
|
3244
|
+
nearest = VECTOR_OP("nearest", 106, ["v128"], "v128");
|
|
3245
|
+
abs = VECTOR_OP("abs", 224, ["v128"], "v128");
|
|
3246
|
+
neg = VECTOR_OP("neg", 225, ["v128"], "v128");
|
|
3247
|
+
sqrt = VECTOR_OP("sqrt", 227, ["v128"], "v128");
|
|
3248
|
+
add = VECTOR_OP("add", 228, ["v128", "v128"], "v128");
|
|
3249
|
+
sub = VECTOR_OP("sub", 229, ["v128", "v128"], "v128");
|
|
3250
|
+
mul = VECTOR_OP("mul", 230, ["v128", "v128"], "v128");
|
|
3251
|
+
div = VECTOR_OP("div", 231, ["v128", "v128"], "v128");
|
|
3252
|
+
min = VECTOR_OP("min", 232, ["v128", "v128"], "v128");
|
|
3253
|
+
max = VECTOR_OP("max", 233, ["v128", "v128"], "v128");
|
|
3254
|
+
pmin = VECTOR_OP("pmin", 234, ["v128", "v128"], "v128");
|
|
3255
|
+
pmax = VECTOR_OP("pmax", 235, ["v128", "v128"], "v128");
|
|
3256
|
+
};
|
|
3257
|
+
|
|
3258
|
+
//#endregion
|
|
3259
|
+
//#region src/backend/wasm.ts
|
|
3260
|
+
/** Backend that compiles into WebAssembly bytecode for immediate execution. */
|
|
3261
|
+
var WasmBackend = class {
|
|
3262
|
+
type = "wasm";
|
|
3263
|
+
maxArgs = 64;
|
|
3264
|
+
#memory;
|
|
3265
|
+
#nextSlot;
|
|
3266
|
+
#allocator;
|
|
3267
|
+
#buffers;
|
|
3268
|
+
constructor() {
|
|
3269
|
+
this.#memory = new WebAssembly.Memory({ initial: 0 });
|
|
3270
|
+
this.#allocator = new WasmAllocator(this.#memory);
|
|
3271
|
+
this.#nextSlot = 1;
|
|
3272
|
+
this.#buffers = /* @__PURE__ */ new Map();
|
|
3273
|
+
}
|
|
3274
|
+
malloc(size, initialData) {
|
|
3275
|
+
const ptr = this.#allocator.malloc(size);
|
|
3276
|
+
if (initialData) {
|
|
3277
|
+
if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
|
|
3278
|
+
new Uint8Array(this.#memory.buffer, ptr, size).set(initialData);
|
|
3279
|
+
}
|
|
3280
|
+
const slot = this.#nextSlot++;
|
|
3281
|
+
this.#buffers.set(slot, {
|
|
3282
|
+
ptr,
|
|
3283
|
+
size,
|
|
3284
|
+
ref: 1
|
|
3285
|
+
});
|
|
3286
|
+
return slot;
|
|
3287
|
+
}
|
|
3288
|
+
incRef(slot) {
|
|
3289
|
+
const buffer = this.#buffers.get(slot);
|
|
3290
|
+
if (!buffer) throw new SlotError(slot);
|
|
3291
|
+
buffer.ref++;
|
|
3292
|
+
}
|
|
3293
|
+
decRef(slot) {
|
|
3294
|
+
const buffer = this.#buffers.get(slot);
|
|
3295
|
+
if (!buffer) throw new SlotError(slot);
|
|
3296
|
+
buffer.ref--;
|
|
3297
|
+
if (buffer.ref === 0) {
|
|
3298
|
+
this.#allocator.free(buffer.ptr);
|
|
3299
|
+
this.#buffers.delete(slot);
|
|
3300
|
+
}
|
|
3301
|
+
}
|
|
3302
|
+
async read(slot, start, count) {
|
|
3303
|
+
return this.readSync(slot, start, count);
|
|
3304
|
+
}
|
|
3305
|
+
readSync(slot, start, count) {
|
|
3306
|
+
const buffer = this.#getBuffer(slot);
|
|
3307
|
+
if (start === void 0) start = 0;
|
|
3308
|
+
if (count === void 0) count = buffer.byteLength - start;
|
|
3309
|
+
return buffer.slice(start, start + count);
|
|
3310
|
+
}
|
|
3311
|
+
async prepare(kernel) {
|
|
3312
|
+
return this.prepareSync(kernel);
|
|
3313
|
+
}
|
|
3314
|
+
prepareSync(kernel) {
|
|
3315
|
+
const bytes = codegenWasm(kernel);
|
|
3316
|
+
const module = new WebAssembly.Module(bytes);
|
|
3317
|
+
return new Executable(kernel, { module });
|
|
3318
|
+
}
|
|
3319
|
+
dispatch(exe, inputs, outputs) {
|
|
3320
|
+
const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
|
|
3321
|
+
const func = instance.exports.kernel;
|
|
3322
|
+
const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
|
|
3323
|
+
func(...ptrs);
|
|
3324
|
+
}
|
|
3325
|
+
#getBuffer(slot) {
|
|
3326
|
+
const buffer = this.#buffers.get(slot);
|
|
3327
|
+
if (!buffer) throw new SlotError(slot);
|
|
3328
|
+
return new Uint8Array(this.#memory.buffer, buffer.ptr, buffer.size);
|
|
3329
|
+
}
|
|
3330
|
+
};
|
|
3331
|
+
function codegenWasm(kernel) {
|
|
3332
|
+
const tune = tuneNullopt(kernel);
|
|
3333
|
+
const re = kernel.reduction;
|
|
3334
|
+
if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
3335
|
+
const cg = new CodeGenerator();
|
|
3336
|
+
cg.memory.import("env", "memory");
|
|
3337
|
+
const distinctOps = union(tune.exp.distinctOps(), re?.epilogue.distinctOps());
|
|
3338
|
+
const funcs = {};
|
|
3339
|
+
if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
|
|
3340
|
+
if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
|
|
3341
|
+
if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
|
|
3342
|
+
if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
|
|
3343
|
+
if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
|
|
3344
|
+
const kernelFunc = cg.function(rep(kernel.nargs + 1, cg.i32), [], () => {
|
|
3345
|
+
const gidx = cg.local.declare(cg.i32);
|
|
3346
|
+
cg.loop(cg.void);
|
|
3347
|
+
cg.block(cg.void);
|
|
3348
|
+
cg.local.get(gidx);
|
|
3349
|
+
cg.i32.const(kernel.size);
|
|
3350
|
+
cg.i32.ge_u();
|
|
3351
|
+
cg.br_if(0);
|
|
3352
|
+
cg.local.get(kernel.nargs);
|
|
3353
|
+
cg.local.get(gidx);
|
|
3354
|
+
cg.i32.const(byteWidth(kernel.dtype));
|
|
3355
|
+
cg.i32.mul();
|
|
3356
|
+
cg.i32.add();
|
|
3357
|
+
if (re) {
|
|
3358
|
+
const acc = cg.local.declare(dty(cg, null, kernel.exp.dtype));
|
|
3359
|
+
dty(cg, null, kernel.exp.dtype).const(re.identity);
|
|
3360
|
+
cg.local.set(acc);
|
|
3361
|
+
const ridx = cg.local.declare(cg.i32);
|
|
3362
|
+
cg.i32.const(0);
|
|
3363
|
+
cg.local.set(ridx);
|
|
3364
|
+
cg.loop(cg.void);
|
|
3365
|
+
cg.block(cg.void);
|
|
3366
|
+
cg.local.get(ridx);
|
|
3367
|
+
cg.i32.const(re.size);
|
|
3368
|
+
cg.i32.ge_u();
|
|
3369
|
+
cg.br_if(0);
|
|
3370
|
+
translateExp(cg, funcs, tune.exp, {
|
|
3371
|
+
gidx,
|
|
3372
|
+
ridx
|
|
3373
|
+
});
|
|
3374
|
+
if (re.op === AluOp.Add) {
|
|
3375
|
+
cg.local.get(acc);
|
|
3376
|
+
if (re.dtype === DType.Bool) cg.i32.or();
|
|
3377
|
+
else dty(cg, re.op, re.dtype).add();
|
|
3378
|
+
} else if (re.op === AluOp.Mul) {
|
|
3379
|
+
cg.local.get(acc);
|
|
3380
|
+
if (re.dtype === DType.Bool) cg.i32.and();
|
|
3381
|
+
else dty(cg, re.op, re.dtype).mul();
|
|
3382
|
+
} else if (re.op === AluOp.Min || re.op === AluOp.Max) if (re.dtype === DType.Float32) {
|
|
3383
|
+
cg.local.get(acc);
|
|
3384
|
+
if (re.op === AluOp.Min) cg.f32.min();
|
|
3385
|
+
else cg.f32.max();
|
|
3386
|
+
} else if ([
|
|
3387
|
+
DType.Int32,
|
|
3388
|
+
DType.Uint32,
|
|
3389
|
+
DType.Bool
|
|
3390
|
+
].includes(re.dtype)) {
|
|
3391
|
+
const local = cg.local.declare(cg.i32);
|
|
3392
|
+
cg.local.tee(local);
|
|
3393
|
+
cg.local.get(acc);
|
|
3394
|
+
cg.local.get(local);
|
|
3395
|
+
cg.local.get(acc);
|
|
3396
|
+
if (re.op === AluOp.Min) if (re.dtype === DType.Int32) cg.i32.lt_s();
|
|
3397
|
+
else cg.i32.lt_u();
|
|
3398
|
+
else if (re.dtype === DType.Int32) cg.i32.gt_s();
|
|
3399
|
+
else cg.i32.gt_u();
|
|
3400
|
+
cg.select();
|
|
3401
|
+
} else throw new Error(`invalid reduction min/max over ${re.dtype}`);
|
|
3402
|
+
else throw new Error(`invalid wasm reduction op: ${re.op}`);
|
|
3403
|
+
cg.local.set(acc);
|
|
3404
|
+
cg.local.get(ridx);
|
|
3405
|
+
cg.i32.const(1);
|
|
3406
|
+
cg.i32.add();
|
|
3407
|
+
cg.local.set(ridx);
|
|
3408
|
+
cg.br(1);
|
|
3409
|
+
cg.end();
|
|
3410
|
+
cg.end();
|
|
3411
|
+
translateExp(cg, funcs, kernel.reduction.epilogue, { acc });
|
|
3412
|
+
} else translateExp(cg, funcs, tune.exp, { gidx });
|
|
3413
|
+
dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
|
|
3414
|
+
cg.local.get(gidx);
|
|
3415
|
+
cg.i32.const(1);
|
|
3416
|
+
cg.i32.add();
|
|
3417
|
+
cg.local.set(gidx);
|
|
3418
|
+
cg.br(1);
|
|
3419
|
+
cg.end();
|
|
3420
|
+
cg.end();
|
|
3421
|
+
});
|
|
3422
|
+
cg.export(kernelFunc, "kernel");
|
|
3423
|
+
return cg.finish();
|
|
3424
|
+
}
|
|
3425
|
+
function translateExp(cg, funcs, exp, ctx) {
|
|
3426
|
+
const references = /* @__PURE__ */ new Map();
|
|
3427
|
+
const seen = /* @__PURE__ */ new Set();
|
|
3428
|
+
const countReferences = (exp$1) => {
|
|
3429
|
+
references.set(exp$1, (references.get(exp$1) ?? 0) + 1);
|
|
3430
|
+
if (!seen.has(exp$1)) {
|
|
3431
|
+
seen.add(exp$1);
|
|
3432
|
+
for (const src of exp$1.src) countReferences(src);
|
|
3433
|
+
}
|
|
3434
|
+
};
|
|
3435
|
+
const expContext = /* @__PURE__ */ new Map();
|
|
3436
|
+
/**
 * Emit WebAssembly instructions that push the value of expression `exp$1`
 * onto the operand stack, recursing into its `src` operands first.
 *
 * Nodes that `countReferences` saw more than once are emitted a single time.
 * After emission, the value is `tee`d into a freshly declared local, which is
 * recorded in `expContext`. Every later visit just re-loads that local.
 * Const, Special and Variable nodes `return` early and are therefore never
 * cached in a local.
 *
 * @param exp$1 - ALU expression node; its `op`, `src`, `dtype` and `arg` are read.
 * @throws {UnsupportedOpError} if the op/dtype combination (or the cast source
 *   dtype, or the Threefry2x32 `arg`) has no lowering in this backend.
 */
const gen = (exp$1) => {
  // Shared sub-expression that was already emitted: reuse its local.
  if (expContext.has(exp$1)) return cg.local.get(expContext.get(exp$1));
  const { op, src, dtype, arg } = exp$1;
  if (AluGroup.Binary.has(op) || AluGroup.Compare.has(op)) {
    // Push both operands: the stack is now [src[0], src[1]].
    gen(src[0]);
    gen(src[1]);
    // Bool lives in an i32, so Add on Bool is lowered to bitwise OR...
    if (op === AluOp.Add) if (dtype === DType.Bool) cg.i32.or();
      else dty(cg, op, dtype).add();
    else if (op === AluOp.Sub) dty(cg, op, dtype).sub();
    // ...and Mul on Bool is lowered to bitwise AND.
    else if (op === AluOp.Mul) if (dtype === DType.Bool) cg.i32.and();
      else dty(cg, op, dtype).mul();
    // Float32 Idiv is a float division followed by truncation toward zero.
    else if (op === AluOp.Idiv) if (dtype === DType.Float32) cg.f32.div(), cg.f32.trunc();
      else if (dtype === DType.Uint32) cg.i32.div_u();
      else if (dtype === DType.Int32) cg.i32.div_s();
      else throw new UnsupportedOpError(op, dtype, "wasm");
    else if (op === AluOp.Mod) if (dtype === DType.Float32) {
      // WebAssembly has no f32 remainder instruction, so compute
      // x - trunc(x / y) * y using two scratch locals (a = x, b = y).
      const a = cg.local.declare(cg.f32);
      const b = cg.local.declare(cg.f32);
      cg.local.set(b);
      cg.local.tee(a);
      cg.local.get(a);
      cg.local.get(b);
      cg.f32.div();
      cg.f32.trunc();
      cg.local.get(b);
      cg.f32.mul();
      cg.f32.sub();
    } else if (dtype === DType.Uint32) cg.i32.rem_u();
      else if (dtype === DType.Int32) cg.i32.rem_s();
      else throw new UnsupportedOpError(op, dtype, "wasm");
    else if (op === AluOp.Min || op === AluOp.Max) if (dtype === DType.Float32) if (op === AluOp.Min) cg.f32.min();
        else cg.f32.max();
      else if (dtype === DType.Int32 || dtype === DType.Uint32) {
        // No integer min/max instruction. The stack becomes [a, b, a, b],
        // and the compare consumes the top pair. `select` then yields
        // (a < b ? a : b) for Min or (a > b ? a : b) for Max. The compare's
        // signedness follows the dtype.
        const a = cg.local.declare(cg.i32);
        const b = cg.local.declare(cg.i32);
        cg.local.set(b);
        cg.local.tee(a);
        cg.local.get(b);
        cg.local.get(a);
        cg.local.get(b);
        if (dtype === DType.Int32) if (op === AluOp.Min) cg.i32.lt_s();
          else cg.i32.gt_s();
        else if (op === AluOp.Min) cg.i32.lt_u();
        else cg.i32.gt_u();
        cg.select();
      } else throw new UnsupportedOpError(op, dtype, "wasm");
    else if (op === AluOp.Cmplt) {
      // The comparison kind depends on the operand dtype, not the (Bool) result dtype.
      const srcDtype = src[0].dtype;
      if (srcDtype === DType.Float32) cg.f32.lt();
      else if (srcDtype === DType.Int32) cg.i32.lt_s();
      else if (srcDtype === DType.Uint32) cg.i32.lt_u();
      else throw new UnsupportedOpError(op, dtype, "wasm");
    } else if (op === AluOp.Cmpne) dty(cg, op, src[0].dtype).ne();
    else throw new UnsupportedOpError(op, dtype, "wasm");
  } else if (AluGroup.Unary.has(op)) if (op === AluOp.Sin) gen(src[0]), cg.call(funcs.sin);
    // Sin/Cos/Exp/Log call helper functions from `funcs`; Sqrt is a native f32 op.
    else if (op === AluOp.Cos) gen(src[0]), cg.call(funcs.cos);
    else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
    else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
    else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
    // Push 1.0 before the operand so that f32.div computes 1 / x.
    else if (op === AluOp.Reciprocal) cg.f32.const(1), gen(src[0]), cg.f32.div();
    else if (op === AluOp.Cast) {
      gen(src[0]);
      const dtype0 = src[0].dtype;
      // Int32, Uint32 and Bool all share the i32 representation.
      const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
      // Float -> integer uses saturating truncation; i32-represented sources
      // need no instruction (the empty `;` statements below are intentional no-ops).
      if (dtype === DType.Int32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_s();
        else if (i32repr);
        else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
      else if (dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_u();
        else if (i32repr);
        else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
      else if (dtype === DType.Float32) if (dtype0 === DType.Float32);
        else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32.convert_i32_s();
        else if (dtype0 === DType.Uint32) cg.f32.convert_i32_u();
        else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
      // Casting to Bool means "value != 0".
      else if (dtype === DType.Bool) if (dtype0 === DType.Bool);
        else if (i32repr) cg.i32.const(0), cg.i32.ne();
        else if (dtype0 === DType.Float32) cg.f32.const(0), cg.f32.ne();
        else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
      else throw new UnsupportedOpError(op, dtype, "wasm");
    } else if (op === AluOp.Bitcast) {
      // Reinterpret the bits between f32 and i32; same-representation bitcasts are no-ops.
      gen(src[0]);
      const dtype0 = src[0].dtype;
      const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32;
      if (dtype === DType.Int32 || dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.reinterpret_f32();
        else if (i32repr);
        else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
      else if (dtype === DType.Float32) if (i32repr) cg.f32.reinterpret_i32();
        else if (dtype0 === DType.Float32);
        else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
      else throw new UnsupportedOpError(op, dtype, "wasm");
    } else throw new UnsupportedOpError(op, dtype, "wasm");
  else if (op === AluOp.Where) {
    // Stack [ifTrue, ifFalse, cond]: `select` yields ifTrue when cond != 0.
    gen(src[1]);
    gen(src[2]);
    gen(src[0]);
    cg.select();
  } else if (op === AluOp.Threefry2x32) {
    // The helper call leaves two i32 outputs on the stack (as the xor/drop
    // handling below implies). `arg` picks the result: "xor" combines both,
    // 0 keeps the first (drops the top), and 1 keeps the second (the top).
    for (let i = 0; i < 4; i++) gen(src[i]);
    cg.call(funcs.threefry2x32);
    if (arg === "xor") cg.i32.xor();
    else if (arg === 0) cg.drop();
    else if (arg === 1) {
      const local = cg.local.declare(cg.i32);
      cg.local.set(local);
      cg.drop();
      cg.local.get(local);
    } else throw new UnsupportedOpError(op, dtype, "wasm", arg);
  // Leaf nodes return immediately and are never cached in a local.
  } else if (op === AluOp.Const) return dty(cg, op, dtype).const(arg);
  // NOTE(review): presumably `ctx` maps special/variable names to wasm locals — confirm.
  else if (op === AluOp.Special) return cg.local.get(ctx[arg[0]]);
  else if (op === AluOp.Variable) return cg.local.get(ctx[arg]);
  else if (op === AluOp.GlobalIndex) {
    // arg = [gid, len]. The element index is clamped to 0 when it is not
    // < len under an unsigned compare, so negative indices are clamped too.
    // The address idx * byteWidth(dtype) + local(gid) is then loaded.
    // NOTE(review): gid is presumably a local holding the buffer base address,
    // and the load argument presumably the log2 alignment hint — confirm.
    const [gid, len] = arg;
    gen(src[0]);
    const local = cg.local.declare(cg.i32);
    cg.local.tee(local);
    cg.i32.const(0);
    cg.local.get(local), cg.i32.const(len), cg.i32.lt_u();
    cg.select();
    cg.i32.const(byteWidth(dtype));
    cg.i32.mul();
    cg.local.get(gid);
    cg.i32.add();
    dty(cg, op, dtype).load(Math.log2(byteWidth(dtype)));
  } else throw new UnsupportedOpError(op, dtype, "wasm");
  // Multiply-referenced node: leave the value on the stack and also save it
  // in a new local so that later visits can re-load it.
  if ((references.get(exp$1) ?? 0) > 1) {
    const local = cg.local.declare(dty(cg, op, dtype));
    cg.local.tee(local);
    expContext.set(exp$1, local);
  }
};
|
|
3566
|
+
countReferences(exp);
|
|
3567
|
+
gen(exp);
|
|
3568
|
+
}
|
|
3569
|
+
/**
 * Select the wasm value-type helper on the code generator for a dtype.
 *
 * Float32 maps to `cg.f32`. Int32, Uint32 and Bool are all stored as 32-bit
 * integers and map to `cg.i32`.
 *
 * @param cg - wasm code generator exposing `f32` and `i32` helpers.
 * @param op - ALU op being lowered; used only in the error message.
 * @param dtype - element dtype to map.
 * @returns `cg.f32` or `cg.i32`.
 * @throws {UnsupportedOpError} for any other dtype.
 */
function dty(cg, op, dtype) {
  if (dtype === DType.Float32) return cg.f32;
  const storedAsI32 = dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool;
  if (storedAsI32) return cg.i32;
  throw new UnsupportedOpError(op, dtype, "wasm");
}
|
|
3578
|
+
|
|
1813
3579
|
//#endregion
|
|
1814
3580
|
//#region src/backend.ts
|
|
1815
|
-
const devices = [
|
|
1816
|
-
|
|
3581
|
+
/** All device names known to this build; not every one is initialized at load time. */
const devices = ["cpu", "wasm", "webgpu"];

/** Device used when none is given explicitly; updated by `setDevice`. */
let defaultBackend = "wasm";

/** Device name -> backend instance, for every backend initialized so far. */
const initializedBackends = /* @__PURE__ */ new Map();
|
|
1818
|
-
initializedBackends.set("cpu", new
|
|
3588
|
+
// The CPU and Wasm backends are constructed synchronously, so they are
// registered eagerly at module load. Map#set returns the map, which allows
// chaining while keeping the same construction and insertion order.
initializedBackends
  .set("cpu", new CpuBackend())
  .set("wasm", new WasmBackend());
|
|
1819
3590
|
/** Set the default device backend (must be initialized). */
|
|
1820
3591
|
function setDevice(device) {
|
|
1821
3592
|
if (initializedBackends.has(device)) defaultBackend = device;
|
|
@@ -1840,12 +3611,13 @@ async function init(...devicesToInit) {
|
|
|
1840
3611
|
}
|
|
1841
3612
|
/** Create a backend, if available. Internal function called by `init()`. */
|
|
1842
3613
|
async function createBackend(device) {
|
|
1843
|
-
if (device === "cpu") return new
|
|
3614
|
+
if (device === "cpu") return new CpuBackend();
|
|
3615
|
+
else if (device === "wasm") return new WasmBackend();
|
|
1844
3616
|
else if (device === "webgpu") {
|
|
1845
3617
|
if (!navigator.gpu) return null;
|
|
1846
3618
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
1847
3619
|
if (!adapter) return null;
|
|
1848
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
3620
|
+
const { WebGPUBackend } = await import("./webgpu-CNg9JGva.js");
|
|
1849
3621
|
const importantLimits = [
|
|
1850
3622
|
"maxBufferSize",
|
|
1851
3623
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -1858,8 +3630,12 @@ async function createBackend(device) {
|
|
|
1858
3630
|
"maxStorageBuffersPerShaderStage",
|
|
1859
3631
|
"maxStorageTexturesPerShaderStage"
|
|
1860
3632
|
];
|
|
3633
|
+
const requestedFeatures = ["shader-f16", "timestamp-query"];
|
|
1861
3634
|
try {
|
|
1862
|
-
const device$1 = await adapter.requestDevice({
|
|
3635
|
+
const device$1 = await adapter.requestDevice({
|
|
3636
|
+
requiredLimits: Object.fromEntries(importantLimits.map((limit) => [limit, adapter.limits[limit]])),
|
|
3637
|
+
requiredFeatures: requestedFeatures.filter((feature) => adapter.features.has(feature))
|
|
3638
|
+
});
|
|
1863
3639
|
return new WebGPUBackend(device$1);
|
|
1864
3640
|
} catch (error) {
|
|
1865
3641
|
console.error("Unexpected error requesting WebGPU device:", error);
|
|
@@ -1885,6 +3661,13 @@ var SlotError = class extends Error {
|
|
|
1885
3661
|
super(`Used a buffer that is invalid or already freed: ${slot}`);
|
|
1886
3662
|
}
|
|
1887
3663
|
};
|
|
3664
|
+
/**
 * Error thrown when a backend has no lowering for an operation/dtype pair.
 *
 * Message format: `"<op><<dtype>> not supported in <device> backend"`. When
 * `arg` is provided, ` with arg <JSON>` is appended.
 *
 * @param op - operation identifier. Only null/undefined are rendered as an
 *   empty prefix; falsy but valid values such as a numeric enum `0` are kept.
 * @param dtype - dtype the operation was requested for.
 * @param device - backend name, e.g. "wasm".
 * @param [arg] - optional extra detail (source dtype, op argument, ...),
 *   JSON-encoded into the message.
 */
var UnsupportedOpError = class extends Error {
  constructor(op, dtype, device, arg) {
    let msg = `${op ?? ""}<${dtype}> not supported in ${device} backend`;
    if (arg !== void 0) msg += ` with arg ${JSON.stringify(arg)}`;
    super(msg);
    // Without this, instances inherit name "Error" from Error.prototype, and
    // stack traces / String(err) would not show the specific error type.
    this.name = "UnsupportedOpError";
  }
};
|
|
1888
3671
|
|
|
1889
3672
|
//#endregion
|
|
1890
|
-
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, devices, dtypedArray, findPow2, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, partitionList, prod, range, recursiveFlatten, rep, runWithCache, setDevice, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip };
|
|
3673
|
+
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, devices, dtypedArray, dtypedJsArray, findPow2, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, partitionList, prod, range, recursiveFlatten, rep, runWithCache, setDevice, strip1, toposort, tuneWebgpu, union, unravelAlu, unzip2, zip, zipn };
|