@jax-js/jax 0.0.2 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -36,7 +36,22 @@ var PPrint = class PPrint {
36
36
  //#endregion
37
37
  //#region src/utils.ts
38
38
  /** @file Generic programming utilities with no dependencies on library code. */
39
- const DEBUG = 3;
39
+ let DEBUG = 0;
40
+ /**
41
+ * Set the debug level for verbose logging.
42
+ *
43
+ * 1. JIT compile logs
44
+ * 2. Shader code
45
+ * 3. Expressions and metadata
46
+ * 4. JIT programs, tuning details
47
+ * 5. Most verbose operation traces
48
+ *
49
+ * This is an experimental API and may change in behavior. Do not rely on this
50
+ * in production.
51
+ */
52
+ function setDebug(level) {
53
+ DEBUG = level;
54
+ }
40
55
  function unzip2(pairs) {
41
56
  const lst1 = [];
42
57
  const lst2 = [];
@@ -49,6 +64,10 @@ function unzip2(pairs) {
49
64
  function zip(xs, ys) {
50
65
  return xs.map((x, i) => [x, ys[i]]);
51
66
  }
67
+ function zipn(...arrays) {
68
+ const minLength = Math.min(...arrays.map((x) => x.length));
69
+ return Array.from({ length: minLength }, (_, i) => arrays.map((arr) => arr[i]));
70
+ }
52
71
  function rep(length, value) {
53
72
  if (value instanceof Function) return new Array(length).fill(0).map((_, i) => value(i));
54
73
  return new Array(length).fill(value);
@@ -56,6 +75,11 @@ function rep(length, value) {
56
75
  function prod(arr) {
57
76
  return arr.reduce((acc, x) => acc * x, 1);
58
77
  }
78
+ function gcd(...values) {
79
+ let a = 0;
80
+ for (let b of values) while (b !== 0) [a, b] = [b, a % b];
81
+ return Math.abs(a);
82
+ }
59
83
  /** Shorthand for integer division, like in Python. */
60
84
  function intdiv(a, b) {
61
85
  return Math.floor(a / b);
@@ -73,6 +97,11 @@ function deepEqual(a, b) {
73
97
  for (const key of Object.keys(a)) if (!deepEqual(a[key], b[key])) return false;
74
98
  return true;
75
99
  }
100
+ function union(...sets) {
101
+ const result = /* @__PURE__ */ new Set();
102
+ for (const s of sets) if (s) for (const x of s) result.add(x);
103
+ return result;
104
+ }
76
105
  /** Splits the list based on a condition, `false` first then `true`. */
77
106
  function partitionList(which, array) {
78
107
  const falseList = [];
@@ -96,9 +125,23 @@ function isNumberPair(x) {
96
125
  }
97
126
  /** Check an axis against number of dimensions, and resolve negative axes. */
98
127
  function checkAxis(axis, ndim) {
99
- if (axis < -ndim || axis >= ndim) throw new Error(`Invalid axis ${axis} for array of ${ndim} dimensions`);
128
+ if (axis < -ndim || axis >= ndim) throw new Error(`Axis ${axis} out of bounds for array of dimension ${ndim}`);
100
129
  return axis < 0 ? axis + ndim : axis;
101
130
  }
131
+ /** Normalize common axis argument for functions, defaulting to all axes. */
132
+ function normalizeAxis(axis, ndim) {
133
+ if (axis === null) return range(ndim);
134
+ else if (typeof axis === "number") return [checkAxis(axis, ndim)];
135
+ else {
136
+ const seen = /* @__PURE__ */ new Set();
137
+ for (const a of axis) {
138
+ const ca = checkAxis(a, ndim);
139
+ if (seen.has(ca)) throw new Error(`Duplicate axis ${ca} passed to function`);
140
+ seen.add(ca);
141
+ }
142
+ return [...seen].sort();
143
+ }
144
+ }
102
145
  function range(start, stop, step = 1) {
103
146
  if (stop === void 0) {
104
147
  stop = start;
@@ -173,6 +216,7 @@ function strip1(str) {
173
216
  if (str[0] === "(" && str[str.length - 1] === ")") return str.slice(1, -1);
174
217
  return str;
175
218
  }
219
+ const _stagingbuf = /* @__PURE__ */ new DataView(/* @__PURE__ */ new ArrayBuffer(8));
176
220
  /**
177
221
  * Polynomial hashes modulo p are good at avoiding collisions in expectation.
178
222
  * Probability-wise, it's good enough to be used for something like
@@ -187,22 +231,26 @@ var FpHash = class FpHash {
187
231
  const modulus = 3189051996290219n;
188
232
  this.value = (this.value * base + x) % modulus;
189
233
  }
190
- update(...values) {
191
- for (const x of values) if (typeof x === "string") for (const c of x) this.#update(BigInt(199 + c.charCodeAt(0)));
192
- else if (typeof x === "number") if (Number.isInteger(x)) this.#update(68265653n ^ BigInt(x));
234
+ update(x) {
235
+ if (typeof x === "string") {
236
+ this.#update(BigInt(x.length));
237
+ for (let i = 0; i < x.length; i++) this.#update(BigInt(199 + x.charCodeAt(i)));
238
+ } else if (typeof x === "number") if (Number.isInteger(x)) this.#update(68265653n ^ BigInt(x));
193
239
  else {
194
- const ar = new Float64Array([x]);
195
- this.#update(new DataView(ar.buffer).getBigUint64(0, true));
240
+ _stagingbuf.setFloat64(0, x, true);
241
+ this.#update(_stagingbuf.getBigUint64(0, true));
196
242
  }
197
243
  else if (typeof x === "boolean") this.#update(x ? 69069841n : 63640693n);
198
244
  else if (typeof x === "bigint") this.#update(x ^ 71657401n);
199
245
  else if (x === null) this.#update(37832657n);
200
246
  else if (x === void 0) this.#update(18145117n);
201
- else if (typeof x === "object" && "hash" in x) x.hash(this);
247
+ else x.hash(this);
202
248
  return this;
203
249
  }
204
250
  static hash(...values) {
205
- return new FpHash().update(...values).value;
251
+ const h = new FpHash();
252
+ for (const x of values) h.update(x);
253
+ return h.value;
206
254
  }
207
255
  };
208
256
  /** Run a function while caching it inline inside a `Map`. */
@@ -217,12 +265,13 @@ function runWithCache(cache, key, thunk) {
217
265
 
218
266
  //#endregion
219
267
  //#region src/alu.ts
268
+ /** A numerical data type for array contents. */
220
269
  let DType = /* @__PURE__ */ function(DType$1) {
221
270
  DType$1["Float32"] = "float32";
222
271
  DType$1["Int32"] = "int32";
223
272
  DType$1["Uint32"] = "uint32";
224
273
  DType$1["Bool"] = "bool";
225
- DType$1["Complex64"] = "complex64";
274
+ DType$1["Float16"] = "float16";
226
275
  return DType$1;
227
276
  }({});
228
277
  const byteWidth = (dtype) => {
@@ -231,17 +280,65 @@ const byteWidth = (dtype) => {
231
280
  case DType.Int32:
232
281
  case DType.Uint32:
233
282
  case DType.Bool: return 4;
234
- case DType.Complex64: return 8;
283
+ case DType.Float16: return 2;
235
284
  default: throw new TypeError(`Unknown dtype: ${dtype}`);
236
285
  }
237
286
  };
238
- const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Complex64;
287
+ const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16;
288
+ /**
289
+ * Promote two dtypes to their join according to the type lattice.
290
+ *
291
+ * When performing operations between arrays of different types, we need to
292
+ * promote both operands to a common type that can represent values from both
293
+ * input types. This follows JAX's type promotion rules.
294
+ *
295
+ * **Type lattice:**
296
+ * ```text
297
+ * bool -> uint32 -> int32 -> float16 -> float32
298
+ * weak f* --^
299
+ * ```
300
+ *
301
+ * The asterisk f* is a weak type used for JS number constants. When creating
302
+ * arrays, JS numbers default to float32 but "weak" so they cast to the dtype of
303
+ * any array they are first combined with.
304
+ *
305
+ * **Examples:**
306
+ * - `promoteTypes(bool, int32) → int32`
307
+ * - `promoteTypes(uint32, int32) → int32`
308
+ * - `promoteTypes(int32, float16) → float16`
309
+ * - `promoteTypes(float16, float32) → float32`
310
+ * - `promoteTypes(uint32, float32) → float32`
311
+ */
312
+ function promoteTypes(dtype1, dtype2) {
313
+ if (dtype1 === dtype2) return dtype1;
314
+ const rank = {
315
+ [DType.Bool]: 0,
316
+ [DType.Uint32]: 1,
317
+ [DType.Int32]: 2,
318
+ [DType.Float16]: 3,
319
+ [DType.Float32]: 4
320
+ };
321
+ return rank[dtype1] > rank[dtype2] ? dtype1 : dtype2;
322
+ }
239
323
  function dtypedArray(dtype, data) {
324
+ const { buffer, byteLength, byteOffset } = data;
325
+ const length = byteLength / byteWidth(dtype);
326
+ switch (dtype) {
327
+ case DType.Float32: return new Float32Array(buffer, byteOffset, length);
328
+ case DType.Int32:
329
+ case DType.Bool: return new Int32Array(buffer, byteOffset, length);
330
+ case DType.Uint32: return new Uint32Array(buffer, byteOffset, length);
331
+ case DType.Float16: return new Float16Array(buffer, byteOffset, length);
332
+ default: throw new Error(`Unimplemented dtype: ${dtype}`);
333
+ }
334
+ }
335
+ function dtypedJsArray(dtype, data) {
240
336
  switch (dtype) {
241
337
  case DType.Float32: return new Float32Array(data);
242
- case DType.Int32: return new Int32Array(data);
243
- case DType.Uint32: return new Uint32Array(data);
338
+ case DType.Int32:
244
339
  case DType.Bool: return new Int32Array(data);
340
+ case DType.Uint32: return new Uint32Array(data);
341
+ case DType.Float16: return new Float16Array(data);
245
342
  default: throw new Error(`Unimplemented dtype: ${dtype}`);
246
343
  }
247
344
  }
@@ -292,12 +389,21 @@ var AluExp = class AluExp {
292
389
  static cos(a) {
293
390
  return new AluExp(AluOp.Cos, a.dtype, [a]);
294
391
  }
392
+ static asin(a) {
393
+ return new AluExp(AluOp.Asin, a.dtype, [a]);
394
+ }
395
+ static atan(a) {
396
+ return new AluExp(AluOp.Atan, a.dtype, [a]);
397
+ }
295
398
  static exp(a) {
296
399
  return new AluExp(AluOp.Exp, a.dtype, [a]);
297
400
  }
298
401
  static log(a) {
299
402
  return new AluExp(AluOp.Log, a.dtype, [a]);
300
403
  }
404
+ static sqrt(a) {
405
+ return new AluExp(AluOp.Sqrt, a.dtype, [a]);
406
+ }
301
407
  static reciprocal(a) {
302
408
  return new AluExp(AluOp.Reciprocal, a.dtype, [a]);
303
409
  }
@@ -343,24 +449,27 @@ var AluExp = class AluExp {
343
449
  static variable(dtype, name) {
344
450
  return new AluExp(AluOp.Variable, dtype, [], name);
345
451
  }
346
- static globalIndex(dtype, gid, bufidx) {
347
- return new AluExp(AluOp.GlobalIndex, dtype, [bufidx], gid);
452
+ static globalIndex(dtype, gid, len, bufidx) {
453
+ return new AluExp(AluOp.GlobalIndex, dtype, [bufidx], [gid, len]);
348
454
  }
349
455
  static globalView(dtype, gid, st, indices) {
350
456
  return new AluExp(AluOp.GlobalView, dtype, indices, [gid, st]);
351
457
  }
458
+ static f32(value) {
459
+ return AluExp.const(DType.Float32, value);
460
+ }
352
461
  static i32(value) {
353
462
  return AluExp.const(DType.Int32, value);
354
463
  }
355
464
  static u32(value) {
356
465
  return AluExp.const(DType.Uint32, value);
357
466
  }
358
- static f32(value) {
359
- return AluExp.const(DType.Float32, value);
360
- }
361
467
  static bool(value) {
362
468
  return AluExp.const(DType.Bool, Number(value));
363
469
  }
470
+ static f16(value) {
471
+ return AluExp.const(DType.Float16, value);
472
+ }
364
473
  not() {
365
474
  if (this.dtype !== DType.Bool) throw new Error("not() can only be called on boolean expressions");
366
475
  return AluExp.cmpne(this, AluExp.const(DType.Bool, true));
@@ -369,8 +478,11 @@ var AluExp = class AluExp {
369
478
  getHash() {
370
479
  if (this.#hash !== void 0) return this.#hash;
371
480
  const hasher = new FpHash();
372
- hasher.update(this.op, this.dtype, JSON.stringify(this.arg));
373
- hasher.update(this.src.length, ...this.src);
481
+ hasher.update(this.op);
482
+ hasher.update(this.dtype);
483
+ hasher.update(JSON.stringify(this.arg));
484
+ hasher.update(this.src.length);
485
+ for (const s of this.src) hasher.update(s);
374
486
  this.#hash = hasher.value;
375
487
  return this.#hash;
376
488
  }
@@ -390,9 +502,9 @@ var AluExp = class AluExp {
390
502
  reindexGids(gidMap) {
391
503
  return this.rewrite((exp) => {
392
504
  if (exp.op === AluOp.GlobalIndex) {
393
- const gid = exp.arg;
505
+ const [gid, len] = exp.arg;
394
506
  const newGid = gidMap.get(gid);
395
- if (newGid !== void 0 && newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, exp.src[0]);
507
+ if (newGid !== void 0 && newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
396
508
  } else if (exp.op === AluOp.GlobalView) {
397
509
  const gid = exp.arg[0];
398
510
  const newGid = gidMap.get(gid);
@@ -421,17 +533,16 @@ var AluExp = class AluExp {
421
533
  case AluOp.Sub:
422
534
  ret = [src[0].min - src[1].max, src[0].max - src[1].min];
423
535
  break;
424
- case AluOp.Mul: {
536
+ case AluOp.Mul:
425
537
  ret = minMax4((a, b) => a * b);
426
538
  break;
427
- }
428
- case AluOp.Idiv: {
429
- ret = minMax4((a, b) => Math.floor(a / b));
539
+ case AluOp.Idiv:
540
+ ret = minMax4((a, b) => Math.trunc(a / b));
430
541
  break;
431
- }
432
542
  case AluOp.Mod: {
433
543
  let divisorRange = src[1].#computeRange();
434
544
  if (divisorRange[0] <= 0 && divisorRange[1] >= 0) divisorRange = [0, Math.max(-divisorRange[0], divisorRange[1])];
545
+ if (divisorRange[1] < 0) divisorRange = [-divisorRange[1], -divisorRange[0]];
435
546
  const maxDivisor = isFloatDtype(this.dtype) ? divisorRange[1] : divisorRange[1] - 1;
436
547
  ret = [clamp(src[0].min, -maxDivisor, 0), clamp(src[0].max, 0, maxDivisor)];
437
548
  break;
@@ -443,10 +554,16 @@ var AluExp = class AluExp {
443
554
  ret = [Math.max(src[0].min, src[1].min), Math.max(src[0].max, src[1].max)];
444
555
  break;
445
556
  case AluOp.Sin:
446
- ret = [Math.sin(src[0].min), Math.sin(src[0].max)];
557
+ ret = [-1, 1];
447
558
  break;
448
559
  case AluOp.Cos:
449
- ret = [Math.cos(src[0].min), Math.cos(src[0].max)];
560
+ ret = [-1, 1];
561
+ break;
562
+ case AluOp.Asin:
563
+ ret = [-Math.PI / 2, Math.PI / 2];
564
+ break;
565
+ case AluOp.Atan:
566
+ ret = [-Math.PI / 2, Math.PI / 2];
450
567
  break;
451
568
  case AluOp.Exp:
452
569
  ret = [Math.exp(src[0].min), Math.exp(src[0].max)];
@@ -454,23 +571,31 @@ var AluExp = class AluExp {
454
571
  case AluOp.Log:
455
572
  ret = [Math.log(src[0].min), Math.log(src[0].max)];
456
573
  break;
574
+ case AluOp.Sqrt:
575
+ ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
576
+ break;
457
577
  case AluOp.Reciprocal:
458
578
  if (src[0].min <= 0 && src[0].max >= 0) return [-Infinity, Infinity];
459
579
  ret = [1 / src[0].max, 1 / src[0].min];
460
580
  break;
461
- case AluOp.Cast:
581
+ case AluOp.Cast: {
582
+ const wasFloat = isFloatDtype(src[0].dtype);
583
+ const bounded = Number.isFinite(src[0].min) && Number.isFinite(src[0].max);
462
584
  if (this.dtype === DType.Bool) {
463
585
  const canBeZero = src[0].min <= 0 && src[0].max >= 0;
464
586
  const mustBeZero = src[0].min === 0 && src[0].max === 0;
465
587
  ret = mustBeZero ? [0, 0] : canBeZero ? [0, 1] : [1, 1];
466
- } else if (this.dtype === DType.Int32) ret = [Math.trunc(src[0].min), Math.trunc(src[0].max)];
467
- else if (this.dtype === DType.Uint32) {
468
- const a = Math.trunc(src[0].min);
469
- const b = Math.trunc(src[0].max);
470
- if (Math.floor(a / 2 ** 32) !== Math.floor(b / 2 ** 32)) ret = [0, -1 >>> 0];
471
- else ret = [a % 2 ** 32, b % 2 ** 32];
588
+ } else if (this.dtype === DType.Int32) {
589
+ const a = wasFloat ? clamp(src[0].min, -2147483648, 2147483647) | 0 : src[0].min | 0;
590
+ const b = wasFloat ? clamp(src[0].max, -2147483648, 2147483647) | 0 : src[0].max | 0;
591
+ ret = bounded && a <= b ? [a, b] : [-Infinity, Infinity];
592
+ } else if (this.dtype === DType.Uint32) {
593
+ const a = wasFloat ? clamp(src[0].min, 0, 4294967295) >>> 0 : src[0].min >>> 0;
594
+ const b = wasFloat ? clamp(src[0].max, 0, 4294967295) >>> 0 : src[0].max >>> 0;
595
+ ret = bounded && a <= b ? [a, b] : [0, Infinity];
472
596
  } else ret = [src[0].min, src[0].max];
473
597
  break;
598
+ }
474
599
  case AluOp.Cmplt:
475
600
  ret = [0, 1];
476
601
  break;
@@ -493,6 +618,7 @@ var AluExp = class AluExp {
493
618
  ret[0] = clamp(ret[0], 0, 1);
494
619
  ret[1] = clamp(ret[1], 0, 1);
495
620
  }
621
+ if (this.dtype === DType.Uint32) ret[0] = Math.max(0, ret[0]);
496
622
  this.#range = ret;
497
623
  return ret;
498
624
  }
@@ -502,21 +628,63 @@ var AluExp = class AluExp {
502
628
  get max() {
503
629
  return this.#computeRange()[1];
504
630
  }
631
+ /** Largest known integer that divides self. */
632
+ constFactor() {
633
+ if (this.op === AluOp.Const) return Math.abs(this.arg);
634
+ if (this.op === AluOp.Add) return gcd(this.src[0].constFactor(), this.src[1].constFactor());
635
+ if (this.op === AluOp.Mul) {
636
+ if (this.src[0].op === AluOp.Const) return Math.abs(this.src[0].arg);
637
+ if (this.src[1].op === AluOp.Const) return Math.abs(this.src[1].arg);
638
+ }
639
+ return 1;
640
+ }
641
+ /**
642
+ * Checks if divisible by an integer v and returns the quotient if it is, or
643
+ * `null` if it's not divisible.
644
+ */
645
+ divides(v) {
646
+ if (v === 1) return this;
647
+ if (this.op === AluOp.Const && this.arg % v === 0) return AluExp.const(this.dtype, this.arg / v);
648
+ if (this.op === AluOp.Add) {
649
+ const a = this.src[0].divides(v);
650
+ if (a !== null) {
651
+ const b = this.src[1].divides(v);
652
+ if (b !== null) return AluExp.add(a, b);
653
+ }
654
+ }
655
+ if (this.op === AluOp.Mul) {
656
+ const a = this.src[0].divides(v);
657
+ if (a !== null) return AluExp.mul(a, this.src[1]);
658
+ const b = this.src[1].divides(v);
659
+ if (b !== null) return AluExp.mul(this.src[0], b);
660
+ }
661
+ return null;
662
+ }
505
663
  #isConstInt() {
506
664
  return this.op === AluOp.Const && (this.dtype === DType.Int32 || this.dtype === DType.Uint32);
507
665
  }
508
666
  /**
667
+ * Get all expressions by deeply matching an operation.
668
+ *
669
+ * For example: `((2+(3*5))+4).splitOp(+) -> [2,(3*5),4]`.
670
+ */
671
+ *splitOp(sep) {
672
+ if (this.op === sep) for (const src of this.src) yield* src.splitOp(sep);
673
+ else yield this;
674
+ }
675
+ /**
509
676
  * Simplify the expression by replacing any known patterns and deduping
510
677
  * identical subexpressions.
511
678
  */
512
679
  simplify(cache = /* @__PURE__ */ new Map()) {
513
680
  if (this.#simplified !== void 0) return this.#simplified;
514
681
  const hash = this.getHash();
515
- if (cache.has(hash)) return this.#simplified = cache.get(hash);
682
+ const prevCachedValue = cache.get(hash);
683
+ if (prevCachedValue !== void 0) return this.#simplified = prevCachedValue;
516
684
  const simplified = this.#simplifyInner(cache);
517
685
  const simplifiedHash = simplified.getHash();
518
- if (cache.has(simplifiedHash)) {
519
- const prevSimplified = cache.get(simplifiedHash);
686
+ const prevSimplified = cache.get(simplifiedHash);
687
+ if (prevSimplified !== void 0) {
520
688
  cache.set(hash, prevSimplified);
521
689
  this.#simplified = prevSimplified;
522
690
  return prevSimplified;
@@ -550,7 +718,24 @@ var AluExp = class AluExp {
550
718
  if (a.op === AluOp.Const && a.arg === -1) return new AluExp(opNeg, this.dtype, [src[0], b]);
551
719
  else if (b.op === AluOp.Const && b.arg === -1) return new AluExp(opNeg, this.dtype, [src[0], a]);
552
720
  }
721
+ if (op === AluOp.Where && src.slice(1).every((s, i) => s.op === AluOp.Const && s.arg === 1 - i)) return AluExp.cast(this.dtype, src[0]);
722
+ if (op === AluOp.Cmplt) {
723
+ if (src[0].min >= src[1].max) return AluExp.const(DType.Bool, false);
724
+ if (src[0].max < src[1].min) return AluExp.const(DType.Bool, true);
725
+ }
726
+ if (op === AluOp.Cmpne) {
727
+ if (src[0].max < src[1].min || src[0].min > src[1].max) return AluExp.const(DType.Bool, true);
728
+ }
729
+ if (op === AluOp.Where) {
730
+ if (src[0].max === 0) return src[2];
731
+ if (src[0].min === 1) return src[1];
732
+ }
553
733
  if (op === AluOp.Mod && src[1].op === AluOp.Const && src[0].min >= 0 && src[0].max < src[1].arg) return src[0];
734
+ if (op === AluOp.Mod && src[0].op === AluOp.Mod && src[1].#isConstInt() && src[0].src[1].#isConstInt()) {
735
+ const A = src[0].src[1].arg;
736
+ const B = src[1].arg;
737
+ if (A > 0 && B > 0 && (A % B === 0 || B % A === 0)) return AluExp.mod(src[0].src[0], AluExp.const(this.dtype, Math.min(A, B))).simplify();
738
+ }
554
739
  if (op === AluOp.Add && src[0].op === AluOp.Mul && src[0].src[1].#isConstInt() && src[1].op === AluOp.Mod && src[1].src[1].#isConstInt() && src[0].src[1].arg === src[1].src[1].arg) {
555
740
  const [mul, mod] = src;
556
741
  const check = (exp) => {
@@ -570,7 +755,7 @@ var AluExp = class AluExp {
570
755
  const A = numer.src[i].arg;
571
756
  if (A % B === 0) {
572
757
  let ret = numer.src[1 - i];
573
- if (A / B !== 1) ret = AluExp.mul(ret, AluExp.i32(A / B));
758
+ if (A / B !== 1) ret = AluExp.mul(ret, AluExp.const(ret.dtype, A / B));
574
759
  return ret.simplify(cache);
575
760
  }
576
761
  }
@@ -578,8 +763,8 @@ var AluExp = class AluExp {
578
763
  const A = numer.src[j].src[i].arg;
579
764
  if (A % B === 0) {
580
765
  let ret = numer.src[j].src[1 - i];
581
- if (A / B !== 1) ret = AluExp.mul(ret, AluExp.i32(A / B));
582
- ret = AluExp.add(ret, AluExp.idiv(numer.src[1 - j], B));
766
+ if (A / B !== 1) ret = AluExp.mul(ret, AluExp.const(ret.dtype, A / B));
767
+ ret = AluExp.add(ret, AluExp.idiv(numer.src[1 - j], AluExp.const(ret.dtype, B)));
583
768
  return ret.simplify(cache);
584
769
  }
585
770
  }
@@ -588,23 +773,81 @@ var AluExp = class AluExp {
588
773
  if (op === AluOp.Mod && src[1].#isConstInt() && src[1].arg > 0 && src[0].min >= 0) {
589
774
  const [numer, denom] = src;
590
775
  const B = denom.arg;
591
- for (let i = 0; i < 2; i++) if (numer.op === AluOp.Add && numer.src[i].#isConstInt()) {
592
- const A = numer.src[i].arg;
593
- let ret = numer.src[1 - i];
594
- if (A % B !== 0) ret = AluExp.add(ret, AluExp.i32(A % B));
595
- return ret.simplify(cache);
776
+ for (let i = 0; i < 2; i++) if (numer.op === AluOp.Add) {
777
+ if (numer.src[i].#isConstInt()) {
778
+ const A = numer.src[i].arg;
779
+ const x = numer.src[1 - i];
780
+ if (A % B === 0 && x.min >= 0) return AluExp.mod(x, denom).simplify(cache);
781
+ }
782
+ for (let j = 0; j < 2; j++) if (numer.src[i].op === AluOp.Mul && numer.src[i].src[j].#isConstInt()) {
783
+ const A = numer.src[i].src[j].arg;
784
+ const x = numer.src[1 - i];
785
+ if (A % B === 0 && x.min >= 0) return AluExp.mod(x, denom).simplify(cache);
786
+ }
787
+ } else if (numer.op === AluOp.Mul) {
788
+ if (numer.src[i].#isConstInt()) {
789
+ const A = numer.src[i].arg;
790
+ if (A % B === 0) return AluExp.const(this.dtype, 0);
791
+ if (A % B === 1) return AluExp.mod(numer.src[1 - i], denom).simplify(cache);
792
+ }
596
793
  }
597
794
  }
598
- if (op === AluOp.Cmplt) {
599
- if (src[0].min >= src[1].max) return AluExp.const(DType.Bool, false);
600
- if (src[0].max < src[1].min) return AluExp.const(DType.Bool, true);
601
- }
602
- if (op === AluOp.Cmpne) {
603
- if (src[0].max < src[1].min || src[0].min > src[1].max) return AluExp.const(DType.Bool, true);
795
+ const commOps = [
796
+ AluOp.Add,
797
+ AluOp.Mul,
798
+ AluOp.Max,
799
+ AluOp.Min
800
+ ];
801
+ if (commOps.includes(op)) {
802
+ const p = (a, b) => new AluExp(op, this.dtype, [a, b]);
803
+ if (src[0].op === AluOp.Const) return p(src[1], src[0]).simplify(cache);
804
+ if (src[0].op === op && src[0].src[1].op === AluOp.Const) if (src[1].op === AluOp.Const) return p(src[0].src[0], p(src[0].src[1], src[1])).simplify(cache);
805
+ else return p(p(src[0].src[0], src[1]), src[0].src[1]).simplify(cache);
806
+ if (src[1].op === op && src[1].src[1].op === AluOp.Const) return p(p(src[0], src[1].src[0]), src[1].src[1]).simplify(cache);
604
807
  }
605
- if (op === AluOp.Where) {
606
- if (src[0].max === 0) return src[2];
607
- if (src[0].min === 1) return src[1];
808
+ if (op === AluOp.Mod || op === AluOp.Idiv && src[1].#isConstInt()) {
809
+ const [x, y] = src;
810
+ {
811
+ const factors = [];
812
+ const terms = [];
813
+ for (const u of x.splitOp(AluOp.Add)) {
814
+ const factor = u.constFactor();
815
+ factors.push(factor);
816
+ terms.push(u.divides(factor));
817
+ }
818
+ const g = gcd(y.arg, ...factors);
819
+ if (g !== 1) {
820
+ let ret = new AluExp(op, this.dtype, [factors.map((f, i) => AluExp.mul(AluExp.const(terms[i].dtype, f / g), terms[i])).reduceRight((a, x$1) => AluExp.add(x$1, a)), AluExp.const(y.dtype, y.arg / g)]);
821
+ if (op === AluOp.Mod) ret = AluExp.mul(ret, AluExp.const(this.dtype, g));
822
+ return ret.simplify(cache);
823
+ }
824
+ }
825
+ if (y.arg > 0) {
826
+ let [xNoConst, constVal] = [x, 0];
827
+ if (x.op === AluOp.Add && x.src[1].op === AluOp.Const) [xNoConst, constVal] = [x.src[0], x.src[1].arg];
828
+ const terms = [];
829
+ const factors = [];
830
+ for (const u of xNoConst.splitOp(AluOp.Add)) {
831
+ const f = u.constFactor();
832
+ const divided = u.divides(f);
833
+ terms.push(divided ?? u);
834
+ factors.push(divided ? f : 1);
835
+ }
836
+ const quotients = factors.map((f) => Math.floor(f / y.arg));
837
+ const remainders = factors.map((f) => f % y.arg);
838
+ const gcdVal = remainders.reduce((g, r) => gcd(g, r), y.arg);
839
+ if (constVal % y.arg !== constVal || gcdVal !== 1 || remainders.some((r, i) => r === 0 || r !== factors[i] && op === AluOp.Mod)) {
840
+ let quo = AluExp.const(x.dtype, Math.floor(constVal / y.arg));
841
+ let rem = AluExp.const(x.dtype, Math.floor(constVal % y.arg / gcdVal));
842
+ for (let i = 0; i < terms.length; i++) if (op === AluOp.Idiv && remainders[i] !== 0) rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(factors[i] / gcdVal)), terms[i]));
843
+ else {
844
+ rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(remainders[i] / gcdVal)), terms[i]));
845
+ quo = AluExp.add(quo, AluExp.mul(AluExp.const(x.dtype, quotients[i]), terms[i]));
846
+ }
847
+ if (!((x.min < 0 || rem.min < 0) && remainders.some((r) => r !== 0))) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
848
+ else return AluExp.add(AluExp.idiv(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal))), quo).simplify(cache);
849
+ }
850
+ }
608
851
  }
609
852
  const newExp = src.every((s, i) => s === this.src[i]) ? this : new AluExp(op, this.dtype, src, this.arg);
610
853
  return newExp;
@@ -645,14 +888,20 @@ var AluExp = class AluExp {
645
888
  switch (this.op) {
646
889
  case AluOp.Sin: return Math.sin(x);
647
890
  case AluOp.Cos: return Math.cos(x);
891
+ case AluOp.Asin: return Math.asin(x);
892
+ case AluOp.Atan: return Math.atan(x);
648
893
  case AluOp.Exp: return Math.exp(x);
649
894
  case AluOp.Log: return Math.log(x);
895
+ case AluOp.Sqrt: return Math.sqrt(x);
650
896
  case AluOp.Reciprocal: return 1 / x;
651
- case AluOp.Cast: if (this.dtype === DType.Int32) return Math.trunc(x) | 0;
652
- else if (this.dtype === DType.Uint32) return Math.trunc(x) >>> 0;
653
- else if (this.dtype === DType.Float32) return x;
654
- else if (this.dtype === DType.Bool) return Number(Boolean(x));
655
- else throw new Error(`Unsupported cast to ${this.dtype}`);
897
+ case AluOp.Cast: {
898
+ const wasFloat = isFloatDtype(this.src[0].dtype);
899
+ if (this.dtype === DType.Int32) return (wasFloat ? clamp(x, -2147483648, 2147483647) : x) | 0;
900
+ else if (this.dtype === DType.Uint32) return (wasFloat ? clamp(x, 0, 4294967295) : x) >>> 0;
901
+ else if (isFloatDtype(this.dtype)) return x;
902
+ else if (this.dtype === DType.Bool) return Number(Boolean(x));
903
+ else throw new Error(`Unsupported cast to ${this.dtype}`);
904
+ }
656
905
  case AluOp.Bitcast: {
657
906
  const buf = new ArrayBuffer(byteWidth(this.dtype));
658
907
  const view = new DataView(buf);
@@ -660,10 +909,12 @@ var AluExp = class AluExp {
660
909
  if (fromType === DType.Float32) view.setFloat32(0, x, true);
661
910
  else if (fromType === DType.Int32) view.setInt32(0, x, true);
662
911
  else if (fromType === DType.Uint32) view.setUint32(0, x, true);
912
+ else if (fromType === DType.Float16) view.setFloat16(0, x, true);
663
913
  else throw new Error(`Unsupported bitcast from ${fromType}`);
664
914
  if (this.dtype === DType.Float32) return view.getFloat32(0, true);
665
915
  else if (this.dtype === DType.Int32) return view.getInt32(0, true);
666
916
  else if (this.dtype === DType.Uint32) return view.getUint32(0, true);
917
+ else if (this.dtype === DType.Float16) return view.getFloat16(0, true);
667
918
  else throw new Error(`Unsupported bitcast to ${this.dtype}`);
668
919
  }
669
920
  default: throw new Error(`Missing implemementation for ${this.op}`);
@@ -692,7 +943,7 @@ var AluExp = class AluExp {
692
943
  }
693
944
  case AluOp.GlobalIndex: {
694
945
  if (!globals) throw new Error("Missing globals function");
695
- const gid = this.arg;
946
+ const gid = this.arg[0];
696
947
  const bufidx = this.src[0].evaluate(context, globals);
697
948
  return globals(gid, bufidx);
698
949
  }
@@ -722,13 +973,7 @@ var AluExp = class AluExp {
722
973
  [AluOp.Cmplt]: "<",
723
974
  [AluOp.Cmpne]: "!="
724
975
  };
725
- const UNARY_SYM = {
726
- [AluOp.Sin]: "sin",
727
- [AluOp.Cos]: "cos",
728
- [AluOp.Exp]: "exp",
729
- [AluOp.Log]: "log",
730
- [AluOp.Reciprocal]: "1/"
731
- };
976
+ const UNARY_SYM = { [AluOp.Reciprocal]: "1/" };
732
977
  return this.fold((node, parts) => {
733
978
  switch (node.op) {
734
979
  case AluOp.Const: return "" + (node.dtype === DType.Bool ? Boolean(node.arg) : node.arg);
@@ -737,7 +982,7 @@ var AluExp = class AluExp {
737
982
  const [name, n] = node.arg;
738
983
  return `#${name}{${n}}`;
739
984
  }
740
- case AluOp.GlobalIndex: return `G_${node.arg}<${node.dtype}>[${strip1(parts[0])}]`;
985
+ case AluOp.GlobalIndex: return `G_${node.arg[0]}<${node.dtype}>[${strip1(parts[0])}]`;
741
986
  case AluOp.GlobalView: {
742
987
  const [gid, st] = node.arg;
743
988
  const shape = st.shape.join(",");
@@ -766,6 +1011,17 @@ var AluExp = class AluExp {
766
1011
  };
767
1012
  return recurse(this);
768
1013
  }
1014
+ /** Check if any expression in the tree satisfies a predicate. */
1015
+ some(predicate) {
1016
+ const visited = /* @__PURE__ */ new Set();
1017
+ const recurse = (exp) => {
1018
+ if (visited.has(exp)) return false;
1019
+ if (predicate(exp)) return true;
1020
+ visited.add(exp);
1021
+ return exp.src.some(recurse);
1022
+ };
1023
+ return recurse(this);
1024
+ }
769
1025
  /** Rewrite the expression recursively using a visitor. */
770
1026
  rewrite(visitor) {
771
1027
  return this.fold((exp, newSrc) => {
@@ -784,6 +1040,23 @@ var AluExp = class AluExp {
784
1040
  });
785
1041
  return result;
786
1042
  }
1043
+ /** Produce a list of all distinct AluOp in this expression. */
1044
+ distinctOps() {
1045
+ const ops = /* @__PURE__ */ new Set();
1046
+ this.fold((exp) => {
1047
+ ops.add(exp.op);
1048
+ });
1049
+ return ops;
1050
+ }
1051
+ /** Rewrite GlobalView operations to GlobalIndex operations. */
1052
+ rewriteGlobalViews() {
1053
+ return this.rewrite((exp) => {
1054
+ if (exp.op === AluOp.GlobalView) {
1055
+ const [gid, st] = exp.arg;
1056
+ return accessorGlobal(exp.dtype, gid, st, exp.src);
1057
+ }
1058
+ });
1059
+ }
787
1060
  };
788
1061
  /** Symbolic form for each mathematical operation. */
789
1062
  let AluOp = /* @__PURE__ */ function(AluOp$1) {
@@ -796,8 +1069,11 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
796
1069
  AluOp$1["Max"] = "Max";
797
1070
  AluOp$1["Sin"] = "Sin";
798
1071
  AluOp$1["Cos"] = "Cos";
1072
+ AluOp$1["Asin"] = "Asin";
1073
+ AluOp$1["Atan"] = "Atan";
799
1074
  AluOp$1["Exp"] = "Exp";
800
1075
  AluOp$1["Log"] = "Log";
1076
+ AluOp$1["Sqrt"] = "Sqrt";
801
1077
  AluOp$1["Reciprocal"] = "Reciprocal";
802
1078
  AluOp$1["Cast"] = "Cast";
803
1079
  AluOp$1["Bitcast"] = "Bitcast";
@@ -825,8 +1101,11 @@ const AluGroup = {
825
1101
  Unary: new Set([
826
1102
  AluOp.Sin,
827
1103
  AluOp.Cos,
1104
+ AluOp.Asin,
1105
+ AluOp.Atan,
828
1106
  AluOp.Exp,
829
1107
  AluOp.Log,
1108
+ AluOp.Sqrt,
830
1109
  AluOp.Reciprocal,
831
1110
  AluOp.Cast,
832
1111
  AluOp.Bitcast
@@ -847,8 +1126,11 @@ const AluGroup = {
847
1126
  RequiredFloat: new Set([
848
1127
  AluOp.Sin,
849
1128
  AluOp.Cos,
1129
+ AluOp.Asin,
1130
+ AluOp.Atan,
850
1131
  AluOp.Exp,
851
1132
  AluOp.Log,
1133
+ AluOp.Sqrt,
852
1134
  AluOp.Reciprocal
853
1135
  ])
854
1136
  };
@@ -877,7 +1159,7 @@ var Kernel = class {
877
1159
  this.exp = exp.simplify();
878
1160
  }
879
1161
  hash(state) {
880
- state.update(this.nargs, this.size, this.exp, this.reduction);
1162
+ state.update(this.nargs).update(this.size).update(this.exp).update(this.reduction);
881
1163
  }
882
1164
  pprint() {
883
1165
  let details = PPrint.pp(`exp = ${this.exp}`);
@@ -890,7 +1172,7 @@ var Kernel = class {
890
1172
  }
891
1173
  /** The dtype of the values output by this kernel. */
892
1174
  get dtype() {
893
- if (this.reduction) return this.reduction.fusion.dtype;
1175
+ if (this.reduction) return this.reduction.epilogue.dtype;
894
1176
  else return this.exp.dtype;
895
1177
  }
896
1178
  /** The number of bytes in the output array when evaluating this kernel. */
@@ -914,22 +1196,23 @@ var Kernel = class {
914
1196
  * at this level since they depend on GPU, versus CPU or Wasm.
915
1197
  */
916
1198
  var Reduction = class {
917
- constructor(dtype, op, size, fusion = AluVar.acc(dtype)) {
1199
+ constructor(dtype, op, size, epilogue = AluVar.acc(dtype)) {
918
1200
  this.dtype = dtype;
919
1201
  this.op = op;
920
1202
  this.size = size;
921
- this.fusion = fusion;
1203
+ this.epilogue = epilogue;
922
1204
  if (!AluGroup.Reduce.has(op)) throw new TypeError(`Unsupported reduction: ${op}`);
1205
+ this.epilogue = epilogue.simplify();
923
1206
  }
924
1207
  hash(state) {
925
- state.update(this.dtype, this.op, this.size, this.fusion);
1208
+ state.update(this.dtype).update(this.op).update(this.size).update(this.epilogue);
926
1209
  }
927
1210
  toString() {
928
- return `${this.op}{${this.size}} -> ${this.fusion}`;
1211
+ return `${this.op}{${this.size}} -> ${this.epilogue}`;
929
1212
  }
930
1213
  /** Get the identity for this reduction operation. */
931
1214
  get identity() {
932
- if (this.dtype === DType.Bool) return this.op === AluOp.Add || this.op === AluOp.Max ? false : true;
1215
+ if (this.dtype === DType.Bool) return this.op === AluOp.Add || this.op === AluOp.Max ? 0 : 1;
933
1216
  else if (this.dtype === DType.Int32) {
934
1217
  if (this.op === AluOp.Add) return 0;
935
1218
  else if (this.op === AluOp.Mul) return 1;
@@ -940,7 +1223,7 @@ var Reduction = class {
940
1223
  else if (this.op === AluOp.Mul) return 1;
941
1224
  else if (this.op === AluOp.Min) return -1 >>> 0;
942
1225
  else if (this.op === AluOp.Max) return 0;
943
- } else if (this.dtype === DType.Float32) {
1226
+ } else if (isFloatDtype(this.dtype)) {
944
1227
  if (this.op === AluOp.Add) return 0;
945
1228
  else if (this.op === AluOp.Mul) return 1;
946
1229
  else if (this.op === AluOp.Min) return Infinity;
@@ -963,7 +1246,7 @@ var Reduction = class {
963
1246
  else if (this.op === AluOp.Mul) return values.reduce((a, b) => a * b >>> 0, 1);
964
1247
  else if (this.op === AluOp.Min) return values.reduce((a, b) => Math.min(a, b), -1 >>> 0);
965
1248
  else if (this.op === AluOp.Max) return values.reduce((a, b) => Math.max(a, b), 0);
966
- } else if (this.dtype === DType.Float32) {
1249
+ } else if (isFloatDtype(this.dtype)) {
967
1250
  if (this.op === AluOp.Add) return values.reduce((a, b) => a + b, 0);
968
1251
  else if (this.op === AluOp.Mul) return values.reduce((a, b) => a * b, 1);
969
1252
  else if (this.op === AluOp.Min) return values.reduce((a, b) => Math.min(a, b), Infinity);
@@ -975,12 +1258,13 @@ var Reduction = class {
975
1258
  /** Expression for accessing `indices` in input array with the given shape. */
976
1259
  function accessorGlobal(dtype, gid, st, indices) {
977
1260
  const [index, valid] = st.toAluExp(indices);
978
- return AluExp.where(valid, AluExp.globalIndex(dtype, gid, index), AluExp.const(dtype, 0));
1261
+ const [, len] = st.views[0].dataRange();
1262
+ return AluExp.where(valid, AluExp.globalIndex(dtype, gid, len, index), AluExp.const(dtype, 0));
979
1263
  }
980
1264
  /** Expression for accessing `indices` in an array recipe with variable "idx". */
981
- function accessorAluExp(dtype, exp, st, indices) {
1265
+ function accessorAluExp(exp, st, indices) {
982
1266
  const [index, valid] = st.toAluExp(indices);
983
- return AluExp.where(valid, exp.substitute({ idx: index }), AluExp.const(dtype, 0));
1267
+ return AluExp.where(valid, exp.substitute({ idx: index }), AluExp.const(exp.dtype, 0));
984
1268
  }
985
1269
  function threefry2x32(k0, k1, c0, c1) {
986
1270
  const rotl32 = (x, r) => (x << r | x >>> 32 - r) >>> 0;
@@ -1164,6 +1448,25 @@ var View = class View {
1164
1448
  if (this.#contiguous === void 0) this.#contiguous = this.size === 0 || this.offset === 0 && this.mask === null && deepEqual(this.strides, defaultStrides(this.shape));
1165
1449
  return this.#contiguous;
1166
1450
  }
1451
+ /** Return the range of data being indexed in this view, or [0, 0] if none. */
1452
+ dataRange() {
1453
+ if (this.size === 0 || this.mask && this.mask[0][0] === this.mask[0][1]) return [0, 0];
1454
+ let min = this.offset;
1455
+ let max = this.offset;
1456
+ for (let i = 0; i < this.ndim; i++) {
1457
+ let [lo, hi] = this.mask ? this.mask[i] : [0, this.shape[i]];
1458
+ --hi;
1459
+ const s = this.strides[i];
1460
+ if (s > 0) {
1461
+ min += s * lo;
1462
+ max += s * hi;
1463
+ } else if (s < 0) {
1464
+ min += s * hi;
1465
+ max += s * lo;
1466
+ }
1467
+ }
1468
+ return [min, max + 1];
1469
+ }
1167
1470
  /** Produce an AluExp for evaluating this view at an index. */
1168
1471
  toAluExp(idxs) {
1169
1472
  let iexpr = AluExp.i32(this.offset);
@@ -1478,6 +1781,39 @@ var ShapeTracker = class ShapeTracker {
1478
1781
  }
1479
1782
  return st.expand(newShape);
1480
1783
  }
1784
+ /**
1785
+ * Repeat data in each axis by a positive number of repetitions.
1786
+ *
1787
+ * - If `tile` is true (default): [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
1788
+ * - If `tile` is false: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
1789
+ */
1790
+ repeat(reps, tile = true) {
1791
+ if (reps.length > this.shape.length) throw new Error(`Too many repeats ${jstr(reps)} for shape ${jstr(this.shape)}`);
1792
+ if (reps.some((c) => c <= 0)) throw new Error(`Invalid repeats ${jstr(reps)}`);
1793
+ if (reps.length === 0) return this;
1794
+ const noop = this.shape.slice(0, -reps.length);
1795
+ const shape = this.shape.slice(-reps.length);
1796
+ return this.broadcast([...noop, ...shape.flatMap((s, i) => tile ? [reps[i], s] : [s, reps[i]])], shape.map((_, i) => noop.length + 2 * i + (tile ? 0 : 1))).reshape([...noop, ...shape.map((s, i) => s * reps[i])]);
1797
+ }
1798
+ /** Move axis i to axis j. */
1799
+ moveaxis(i, j) {
1800
+ const perm = range(this.shape.length);
1801
+ perm.splice(i, 1);
1802
+ perm.splice(j, 0, i);
1803
+ return this.permute(perm);
1804
+ }
1805
+ /** Like pad(), but allows for negative values. */
1806
+ padOrShrink(arg) {
1807
+ const padArg = [];
1808
+ const shrinkArg = [];
1809
+ for (let i = 0; i < arg.length; i++) {
1810
+ const [b, e] = arg[i];
1811
+ if (b < -this.shape[i] || e < -this.shape[i] || b + e < -this.shape[i]) throw new Error(`Invalid padOrShrink ${jstr(arg)} for ${jstr(this.shape)}`);
1812
+ padArg.push([Math.max(0, b), Math.max(0, e)]);
1813
+ shrinkArg.push([Math.max(0, -b), this.shape[i] - Math.max(0, -e)]);
1814
+ }
1815
+ return this.shrink(shrinkArg).pad(padArg);
1816
+ }
1481
1817
  };
1482
1818
  function applyLast(ar, f) {
1483
1819
  return ar.toSpliced(ar.length - 1, 1, f(ar[ar.length - 1]));
@@ -1598,13 +1934,7 @@ function tuneNullopt(kernel) {
1598
1934
  vars.gidx = AluExp.special(DType.Int32, "gidx", kernel.size);
1599
1935
  if (kernel.reduction) vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
1600
1936
  return {
1601
- exp: kernel.exp.rewrite((exp) => {
1602
- if (exp.op === AluOp.GlobalView) {
1603
- const gid = exp.arg[0];
1604
- const st = exp.arg[1];
1605
- return accessorGlobal(exp.dtype, gid, st, exp.src);
1606
- }
1607
- }).substitute(vars).simplify(),
1937
+ exp: kernel.exp.substitute(vars).rewriteGlobalViews().simplify(),
1608
1938
  outputIdxExp: AluExp.special(DType.Int32, "gidx", kernel.size),
1609
1939
  threadCount: kernel.size,
1610
1940
  size: { reduce: kernel.reduction ? kernel.reduction.size : 0 }
@@ -1699,13 +2029,19 @@ function tuneWebgpu(kernel) {
1699
2029
  const s = dim.st.shape.slice(dim.upcast);
1700
2030
  addIndices(s, AluVar.upcast);
1701
2031
  }
1702
- const newExp = exp.rewrite((exp$1) => {
2032
+ let newExp = exp.rewrite((exp$1) => {
1703
2033
  if (exp$1.op === AluOp.GlobalView) {
1704
2034
  const gid = exp$1.arg[0];
1705
2035
  const st = exp$1.arg[1];
1706
2036
  return accessorGlobal(exp$1.dtype, gid, st.compose(dim.st), indices);
1707
2037
  }
1708
2038
  });
2039
+ const [iexpr, vexpr] = dim.st.toAluExp(indices);
2040
+ if (vexpr.min !== 1) throw new Error("Invariant violation: vexpr !== true");
2041
+ newExp = newExp.substitute({
2042
+ gidx: AluExp.idiv(iexpr, AluExp.i32(reduction.size)).simplify(),
2043
+ ridx: AluExp.mod(iexpr, AluExp.i32(reduction.size)).simplify()
2044
+ });
1709
2045
  const outputGidx = dim.outputSt.shape.slice(0, dim.groups);
1710
2046
  const outputUpcast = dim.outputSt.shape.slice(dim.groups);
1711
2047
  const [outputIdxExp, _] = dim.outputSt.toAluExp([...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)]);
@@ -1727,7 +2063,7 @@ function tuneWebgpu(kernel) {
1727
2063
  //#endregion
1728
2064
  //#region src/backend/cpu.ts
1729
2065
  /** Most basic implementation of `Backend` for testing. */
1730
- var CPUBackend = class {
2066
+ var CpuBackend = class {
1731
2067
  type = "cpu";
1732
2068
  maxArgs = Infinity;
1733
2069
  #buffers;
@@ -1737,10 +2073,10 @@ var CPUBackend = class {
1737
2073
  this.#nextSlot = 1;
1738
2074
  }
1739
2075
  malloc(size, initialData) {
1740
- const buffer = new ArrayBuffer(size);
2076
+ const buffer = new Uint8Array(size);
1741
2077
  if (initialData) {
1742
2078
  if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
1743
- new Uint8Array(buffer).set(new Uint8Array(initialData));
2079
+ buffer.set(initialData);
1744
2080
  }
1745
2081
  const slot = this.#nextSlot++;
1746
2082
  this.#buffers.set(slot, {
@@ -1779,7 +2115,7 @@ var CPUBackend = class {
1779
2115
  const { exp } = tuneNullopt(kernel);
1780
2116
  const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
1781
2117
  const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
1782
- const usedArgs = new Map(exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex).map((exp$1) => [exp$1.arg, exp$1.dtype]));
2118
+ const usedArgs = new Map(exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex).map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
1783
2119
  const inputArrays = inputBuffers.map((buf, i) => {
1784
2120
  const dtype = usedArgs.get(i);
1785
2121
  if (!dtype) return null;
@@ -1801,7 +2137,7 @@ var CPUBackend = class {
1801
2137
  }, globals);
1802
2138
  acc = kernel.reduction.evaluate(acc, item);
1803
2139
  }
1804
- outputArray[i] = kernel.reduction.fusion.evaluate({ acc });
2140
+ outputArray[i] = kernel.reduction.epilogue.evaluate({ acc });
1805
2141
  }
1806
2142
  }
1807
2143
  #getBuffer(slot) {
@@ -1811,16 +2147,1594 @@ var CPUBackend = class {
1811
2147
  }
1812
2148
  };
1813
2149
 
2150
+ //#endregion
2151
+ //#region src/backend/wasm/allocator.ts
2152
/**
 * Simple tensor memory allocator for WebAssembly linear memory.
 *
 * Requests are rounded up to a size class. Freed blocks go onto a per-class
 * free list and are reused before fresh space is bump-allocated. Memory is
 * never handed back to the host. Pointer 0 means "no allocation", and bump
 * allocation starts at offset 64.
 */
var WasmAllocator = class {
	#memory;
	#headPtr;
	#freeLists;
	#allocatedBuffers;
	constructor(memory) {
		this.#memory = memory;
		this.#headPtr = 64;
		this.#freeLists = /* @__PURE__ */ new Map();
		this.#allocatedBuffers = /* @__PURE__ */ new Map();
	}
	/** Allocate at least `size` bytes; returns 0 for a zero-size request. */
	malloc(size) {
		if (size === 0) return 0;
		const cls = this.#findSizeClass(size);
		const reusable = this.#freeLists.get(cls);
		const ptr = reusable?.length ? reusable.pop() : this.#bumpAlloc(cls);
		this.#allocatedBuffers.set(ptr, cls);
		return ptr;
	}
	/** Return a pointer from malloc() to its size-class free list (0 is a no-op). */
	free(ptr) {
		if (ptr === 0) return;
		const cls = this.#allocatedBuffers.get(ptr);
		if (cls === void 0) throw new Error(`Attempting to free unallocated pointer: ${ptr}`);
		const list = this.#freeLists.get(cls);
		if (list === void 0) this.#freeLists.set(cls, [ptr]);
		else list.push(ptr);
		this.#allocatedBuffers.delete(ptr);
	}
	/** Carve a 64-byte-aligned block from the head, growing memory by whole pages if needed. */
	#bumpAlloc(size) {
		const start = this.#headPtr;
		const aligned = size + 63 & -64;
		const end = start + aligned;
		this.#headPtr = end;
		const capacity = this.#memory.buffer.byteLength;
		if (end > capacity) {
			const pagesNeeded = end + 65535 >> 16;
			this.#memory.grow(pagesNeeded - (capacity >> 16));
		}
		return start;
	}
	/**
	 * Round a request up to its size class: 64-byte steps up to 512, 512-byte
	 * steps up to 2048, powers of two from 4096 up to 64 KiB, then whole 64 KiB pages.
	 */
	#findSizeClass(size) {
		if (size > 65536) return size + 65535 & -65536;
		if (size > 2048) {
			let cls = 4096;
			while (cls < size) cls *= 2;
			return cls;
		}
		return size > 512 ? size + 511 & -512 : size + 63 & -64;
	}
	/** Report the bump-pointer high-water mark and non-empty free-list lengths. */
	getStats() {
		const freeListSizes = /* @__PURE__ */ new Map();
		this.#freeLists.forEach((list, cls) => {
			if (list.length > 0) freeListSizes.set(cls, list.length);
		});
		return {
			totalAllocated: this.#headPtr,
			freeListSizes
		};
	}
};
2209
+
2210
+ //#endregion
2211
+ //#region src/backend/wasm/builtins.ts
2212
/**
 * Approximate e^x for f32.
 *
 * Range reduction: k = nearest(x / ln2) and r = x - k*ln2, so |r| <= ~0.3466.
 * Then e^x = 2^k * P(r), where P is the degree-5 Taylor polynomial of e^r.
 * If k > 127 the result is +Infinity; if k < -126 it is 0 (no subnormal results).
 * 2^k is built directly in the IEEE-754 exponent field.
 */
function wasm_exp(cg) {
	return cg.function([cg.f32], [cg.f32], () => {
		const kFloat = cg.local.declare(cg.f32);
		const kInt = cg.local.declare(cg.i32);
		const rem = cg.local.declare(cg.f32);
		const poly = cg.local.declare(cg.f32);
		const pow2k = cg.local.declare(cg.f32);
		// k = nearest(x / ln2), kept as f32 and as a saturated i32.
		cg.local.get(0);
		cg.f32.const(1 / Math.LN2);
		cg.f32.mul();
		cg.f32.nearest();
		cg.local.tee(kFloat);
		cg.i32.trunc_sat_f32_s();
		cg.local.set(kInt);
		// Return `value` early when `k <cmp> limit` holds.
		const bailIf = (limit, compare, value) => {
			cg.local.get(kInt);
			cg.i32.const(limit);
			compare();
			cg.if(cg.void);
			cg.f32.const(value);
			cg.return();
			cg.end();
		};
		bailIf(127, () => cg.i32.gt_s(), Infinity);
		bailIf(-126, () => cg.i32.lt_s(), 0);
		// r = x - k*ln2
		cg.local.get(0);
		cg.local.get(kFloat);
		cg.f32.const(Math.LN2);
		cg.f32.mul();
		cg.f32.sub();
		cg.local.set(rem);
		// Horner form of sum_{n=0..5} r^n / n!, highest degree first.
		const [lead, ...rest] = [1 / 120, 1 / 24, 1 / 6, 1 / 2, 1, 1];
		cg.f32.const(lead);
		for (const c of rest) {
			cg.local.get(rem);
			cg.f32.mul();
			cg.f32.const(c);
			cg.f32.add();
		}
		cg.local.set(poly);
		// 2^k = bits((k + 127) << 23)
		cg.local.get(kInt);
		cg.i32.const(127);
		cg.i32.add();
		cg.i32.const(23);
		cg.i32.shl();
		cg.f32.reinterpret_i32();
		cg.local.set(pow2k);
		cg.local.get(poly);
		cg.local.get(pow2k);
		cg.f32.mul();
	});
}
2286
/**
 * Approximate ln(x) for f32.
 *
 * Special cases: x < 0 or NaN gives NaN; x = ±0 gives -Infinity; x = +Infinity
 * gives +Infinity. Subnormal inputs are first scaled by 2^23 (with the
 * exponent corrected by -23) so the bit-level decomposition sees a normal number.
 *
 * Method: decompose x = m * 2^e with m in [1,2) and integer e (via bit ops),
 * so ln(x) = e*ln2 + ln(m). Use the atanh-style series with t = (m-1)/(m+1):
 *   ln(m) ≈ 2*(t + t^3/3 + t^5/5 + t^7/7)
 */
function wasm_log(cg) {
	return cg.function([cg.f32], [cg.f32], () => {
		const bits = cg.local.declare(cg.i32);
		const e = cg.local.declare(cg.i32);
		const m = cg.local.declare(cg.f32);
		const t = cg.local.declare(cg.f32);
		const t2 = cg.local.declare(cg.f32);
		const t3 = cg.local.declare(cg.f32);
		const t5 = cg.local.declare(cg.f32);
		const t7 = cg.local.declare(cg.f32);
		const lnm = cg.local.declare(cg.f32);
		const el2 = cg.local.declare(cg.f32);
		const xs = cg.local.declare(cg.f32); // x, rescaled if subnormal
		const eAdj = cg.local.declare(cg.i32); // exponent correction (0 or -23); Wasm zero-initializes locals
		// x < 0 or NaN: `x >= 0` is false for both, so the else branch returns NaN.
		cg.local.get(0);
		cg.f32.const(0);
		cg.f32.ge();
		cg.if(cg.void);
		cg.else();
		cg.f32.const(NaN);
		cg.return();
		cg.end();
		// x == ±0 (x is known to be >= 0 here): ln(0) = -Infinity.
		cg.local.get(0);
		cg.f32.const(0);
		cg.f32.le();
		cg.if(cg.void);
		cg.f32.const(-Infinity);
		cg.return();
		cg.end();
		// x == +Infinity: otherwise exponent field 255 would decode as a finite e.
		cg.local.get(0);
		cg.f32.const(Infinity);
		cg.f32.ge();
		cg.if(cg.void);
		cg.local.get(0);
		cg.return();
		cg.end();
		// Subnormal (x <= 2^-126): scale by 2^23 so the exponent field is meaningful.
		cg.local.get(0);
		cg.local.set(xs);
		cg.f32.const(2 ** -126);
		cg.local.get(0);
		cg.f32.ge();
		cg.if(cg.void);
		cg.local.get(0);
		cg.f32.const(2 ** 23);
		cg.f32.mul();
		cg.local.set(xs);
		cg.i32.const(-23);
		cg.local.set(eAdj);
		cg.end();
		// e = ((bits >>> 23) & 255) - 127 + eAdj
		cg.local.get(xs);
		cg.i32.reinterpret_f32();
		cg.local.tee(bits);
		cg.i32.const(23);
		cg.i32.shr_u();
		cg.i32.const(255);
		cg.i32.and();
		cg.i32.const(127);
		cg.i32.sub();
		cg.local.get(eAdj);
		cg.i32.add();
		cg.local.set(e);
		// m = mantissa with exponent forced to 0, i.e. in [1, 2)
		cg.local.get(bits);
		cg.i32.const(8388607);
		cg.i32.and();
		cg.i32.const(1065353216);
		cg.i32.or();
		cg.f32.reinterpret_i32();
		cg.local.set(m);
		// t = (m - 1) / (m + 1)
		cg.local.get(m);
		cg.f32.const(1);
		cg.f32.sub();
		cg.local.get(m);
		cg.f32.const(1);
		cg.f32.add();
		cg.f32.div();
		cg.local.set(t);
		cg.local.get(t);
		cg.local.get(t);
		cg.f32.mul();
		cg.local.set(t2);
		cg.local.get(t);
		cg.local.get(t2);
		cg.f32.mul();
		cg.local.set(t3);
		cg.local.get(t3);
		cg.local.get(t2);
		cg.f32.mul();
		cg.local.set(t5);
		cg.local.get(t5);
		cg.local.get(t2);
		cg.f32.mul();
		cg.local.set(t7);
		// lnm = 2 * (t7/7 + t5/5 + t3/3 + t)
		cg.local.get(t7);
		cg.f32.const(1 / 7);
		cg.f32.mul();
		cg.local.get(t5);
		cg.f32.const(1 / 5);
		cg.f32.mul();
		cg.f32.add();
		cg.local.get(t3);
		cg.f32.const(1 / 3);
		cg.f32.mul();
		cg.f32.add();
		cg.local.get(t);
		cg.f32.add();
		cg.f32.const(2);
		cg.f32.mul();
		cg.local.set(lnm);
		cg.local.get(e);
		cg.f32.convert_i32_s();
		cg.f32.const(Math.LN2);
		cg.f32.mul();
		cg.local.set(el2);
		cg.local.get(el2);
		cg.local.get(lnm);
		cg.f32.add();
	});
}
2379
/**
 * Common helper to approximate sin(x) and cos(x) of the enclosing function's
 * f32 parameter 0.
 *
 * Method: reduce to y = x - round(x/2π)*2π in [-π, π], then take the quadrant
 * q = round(y/(π/2)) and z = y - q*(π/2). Evaluate both polynomials on z:
 *   sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
 *   cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
 *
 * Leaves nothing on the operand stack. Returns the local indices { q, sz, cz }
 * (q is i32; sz and cz are f32) for the caller to combine by quadrant.
 */
function _sincos(cg) {
	const y = cg.local.declare(cg.f32);
	const qf = cg.local.declare(cg.f32);
	const q = cg.local.declare(cg.i32);
	const z = cg.local.declare(cg.f32);
	const z2 = cg.local.declare(cg.f32);
	const sz = cg.local.declare(cg.f32);
	const cz = cg.local.declare(cg.f32);
	// Horner evaluation in z2, highest-degree coefficient first.
	const horner = (coeffs) => {
		cg.f32.const(coeffs[0]);
		for (const c of coeffs.slice(1)) {
			cg.local.get(z2);
			cg.f32.mul();
			cg.f32.const(c);
			cg.f32.add();
		}
	};
	// y = x - nearest(x / 2π) * 2π
	cg.local.get(0);
	cg.local.get(0);
	cg.f32.const(1 / (2 * Math.PI));
	cg.f32.mul();
	cg.f32.nearest();
	cg.local.tee(qf);
	cg.f32.const(2 * Math.PI);
	cg.f32.mul();
	cg.f32.sub();
	cg.local.set(y);
	// q = nearest(y / (π/2)). Use the saturating conversion: for a non-finite x,
	// y is NaN and the trapping `i32.trunc_f32_s` would abort execution. The
	// saturating form gives q = 0, and the NaN then propagates through z into sz/cz.
	cg.local.get(y);
	cg.f32.const(2 / Math.PI);
	cg.f32.mul();
	cg.f32.nearest();
	cg.local.tee(qf);
	cg.i32.trunc_sat_f32_s();
	cg.local.set(q);
	// z = y - q*(π/2); z2 = z*z
	cg.local.get(y);
	cg.local.get(qf);
	cg.f32.const(Math.PI / 2);
	cg.f32.mul();
	cg.f32.sub();
	cg.local.tee(z);
	cg.local.get(z);
	cg.f32.mul();
	cg.local.set(z2);
	// sz = z * (1 - z2/6 + z2^2/120 - z2^3/5040)
	horner([-1 / 5040, 1 / 120, -1 / 6, 1]);
	cg.local.get(z);
	cg.f32.mul();
	cg.local.set(sz);
	// cz = 1 - z2/2 + z2^2/24 - z2^3/720
	horner([-1 / 720, 1 / 24, -1 / 2, 1]);
	cg.local.set(cz);
	return {
		q,
		sz,
		cz
	};
}
2457
/**
 * Approximate sin(x).
 *
 * Using the quadrant q and polynomials sz/cz from _sincos, the sign and
 * polynomial depend on q mod 4: 0 -> +sz, 1 -> +cz, 2 -> -sz, 3 -> -cz.
 * Bit 0 of q selects the polynomial and bit 1 selects the sign.
 */
function wasm_sin(cg) {
	return cg.function([cg.f32], [cg.f32], () => {
		const { q, sz, cz } = _sincos(cg);
		const mag = cg.local.declare(cg.f32);
		// select(a, b, cond) yields a when cond != 0, else b.
		const selectOnQBit = (mask) => {
			cg.local.get(q);
			cg.i32.const(mask);
			cg.i32.and();
			cg.select();
		};
		cg.local.get(cz);
		cg.local.get(sz);
		selectOnQBit(1);
		cg.local.tee(mag);
		cg.f32.neg();
		cg.local.get(mag);
		selectOnQBit(2);
	});
}
2481
/**
 * Approximate cos(x).
 *
 * Using the quadrant q and polynomials sz/cz from _sincos, the sign and
 * polynomial depend on q mod 4: 0 -> +cz, 1 -> -sz, 2 -> -cz, 3 -> +sz.
 * Bit 0 of q selects the polynomial; bit 1 of (q + 1) selects the sign.
 */
function wasm_cos(cg) {
	return cg.function([cg.f32], [cg.f32], () => {
		const { q, sz, cz } = _sincos(cg);
		const mag = cg.local.declare(cg.f32);
		// Push ((q + offset) & mask) as the i32 condition for `select`.
		const quadrantBit = (offset, mask) => {
			cg.local.get(q);
			if (offset !== 0) {
				cg.i32.const(offset);
				cg.i32.add();
			}
			cg.i32.const(mask);
			cg.i32.and();
		};
		cg.local.get(sz);
		cg.local.get(cz);
		quadrantBit(0, 1);
		cg.select();
		cg.local.tee(mag);
		cg.f32.neg();
		cg.local.get(mag);
		quadrantBit(1, 2);
		cg.select();
	});
}
2507
/**
 * Helper function for approximating arctan(x). Consumes the f32 operand on
 * top of the stack and leaves atan(operand) in its place.
 *
 * Reduces to r = |x| if |x| < 1, else r = 1/|x|. Then computes
 * atan(r) ≈ r * P(r²)/Q(r²), undoes the reduction with π/2 - atan(r) when
 * |x| >= 1, and finally restores the sign of x.
 */
function _atan(cg) {
	const input = cg.local.declare(cg.f32);
	const mag = cg.local.declare(cg.f32);
	const red = cg.local.declare(cg.f32);
	const red2 = cg.local.declare(cg.f32);
	const core = cg.local.declare(cg.f32);
	// Push the i32 condition |x| >= 1.
	const pushIsLarge = () => {
		cg.local.get(mag);
		cg.f32.const(1);
		cg.f32.ge();
	};
	// Horner evaluation in red2, highest-degree coefficient first.
	const horner = (coeffs) => {
		cg.f32.const(coeffs[0]);
		for (const c of coeffs.slice(1)) {
			cg.local.get(red2);
			cg.f32.mul();
			cg.f32.const(c);
			cg.f32.add();
		}
	};
	cg.local.set(input);
	cg.local.get(input);
	cg.f32.abs();
	cg.local.set(mag);
	// red = |x| >= 1 ? 1/|x| : |x|
	cg.f32.const(1);
	cg.local.get(mag);
	cg.f32.div();
	cg.local.get(mag);
	pushIsLarge();
	cg.select();
	cg.local.set(red);
	cg.local.get(red);
	cg.local.get(red);
	cg.f32.mul();
	cg.local.set(red2);
	// core = red * P(red²) / Q(red²)
	horner([.0415796528637, .661705427875, .999998614341]);
	horner([.173698870181, .994987933645, 1]);
	cg.f32.div();
	cg.local.get(red);
	cg.f32.mul();
	cg.local.set(core);
	// |x| >= 1: atan(|x|) = π/2 - atan(1/|x|)
	cg.f32.const(Math.PI / 2);
	cg.local.get(core);
	cg.f32.sub();
	cg.local.get(core);
	pushIsLarge();
	cg.select();
	// atan is odd: copy the sign of x onto the result.
	cg.local.get(input);
	cg.f32.copysign();
}
2564
/**
 * Approximate atan(x).
 *
 * Method: if |x| < 1, use a rational approximation atan(x) ≈ x * P(x^2) / Q(x^2),
 * where P(u) = A0 + A1*u + A2*u^2 and Q(u) = 1 + B1*u + B2*u^2 (both degree 2).
 * If |x| >= 1, use atan(x) = sign(x)*π/2 - atan(1/x).
 * (Fitted coefficients; max error ~5e-7 on [0,1].)
 */
function wasm_atan(cg) {
	const body = () => {
		// _atan consumes its operand from the stack.
		cg.local.get(0);
		_atan(cg);
	};
	return cg.function([cg.f32], [cg.f32], body);
}
2579
/**
 * Approximate asin(x).
 *
 * Method: asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))). For |x| > 1 the sqrt
 * of a negative number yields NaN, which propagates to the result.
 */
function wasm_asin(cg) {
	return cg.function([cg.f32], [cg.f32], () => {
		const pushX = () => cg.local.get(0);
		// numerator: x
		pushX();
		// denominator: 1 + sqrt(1 - x*x)
		cg.f32.const(1);
		pushX();
		pushX();
		cg.f32.mul();
		cg.f32.sub();
		cg.f32.sqrt();
		cg.f32.const(1);
		cg.f32.add();
		cg.f32.div();
		_atan(cg);
		// double the half-angle result
		cg.f32.const(2);
		cg.f32.mul();
	});
}
2601
/**
 * Threefry2x32 pseudorandom number generator (20 rounds).
 *
 * Parameters are (key0, key1, counter0, counter1), all i32. Returns the two
 * i32 output words. The key schedule uses ks2 = key0 ^ key1 ^ 0x1BD11BDA.
 * Keys are injected after every 4 MIX rounds, with rotations alternating
 * between [13, 15, 26, 6] and [17, 29, 16, 24].
 */
function wasm_threefry2x32(cg) {
	return cg.function([
		cg.i32,
		cg.i32,
		cg.i32,
		cg.i32
	], [cg.i32, cg.i32], () => {
		const ks = [
			cg.local.declare(cg.i32),
			cg.local.declare(cg.i32),
			cg.local.declare(cg.i32)
		];
		const x0 = cg.local.declare(cg.i32);
		const x1 = cg.local.declare(cg.i32);
		// One MIX round: x0 += x1; x1 = rotl(x1, r) ^ x0.
		const mix = (r) => {
			cg.local.get(x0);
			cg.local.get(x1);
			cg.i32.add();
			cg.local.set(x0);
			cg.local.get(x1);
			cg.i32.const(r);
			cg.i32.rotl();
			cg.local.get(x0);
			cg.i32.xor();
			cg.local.set(x1);
		};
		// Key injection s: x0 += ks[s % 3]; x1 += ks[(s + 1) % 3] + s.
		const inject = (s) => {
			cg.local.get(x0);
			cg.local.get(ks[s % 3]);
			cg.i32.add();
			cg.local.set(x0);
			cg.local.get(x1);
			cg.local.get(ks[(s + 1) % 3]);
			cg.i32.add();
			cg.i32.const(s);
			cg.i32.add();
			cg.local.set(x1);
		};
		// ks0 = key0, ks1 = key1, ks2 = key0 ^ key1 ^ 0x1BD11BDA
		cg.local.get(0);
		cg.local.set(ks[0]);
		cg.local.get(1);
		cg.local.set(ks[1]);
		cg.local.get(0);
		cg.local.get(1);
		cg.i32.xor();
		cg.i32.const(0x1bd11bda);
		cg.i32.xor();
		cg.local.set(ks[2]);
		// Initial injection: x = counter + (ks0, ks1)
		cg.local.get(2);
		cg.local.get(ks[0]);
		cg.i32.add();
		cg.local.set(x0);
		cg.local.get(3);
		cg.local.get(ks[1]);
		cg.i32.add();
		cg.local.set(x1);
		const rotations = [[13, 15, 26, 6], [17, 29, 16, 24]];
		for (let s = 1; s <= 5; s++) {
			for (const r of rotations[(s - 1) % 2]) mix(r);
			inject(s);
		}
		cg.local.get(x0);
		cg.local.get(x1);
	});
}
2675
+
2676
+ //#endregion
2677
+ //#region src/backend/wasm/wasmblr.ts
2678
/**
 * @file Minimalist WebAssembly assembler. This allows you to emit WebAssembly
 * bytecode directly from the browser.
 *
 * Self-contained port of https://github.com/bwasti/wasmblr to TypeScript.
 * Some operation names in this module are written in `snake_case` to match
 * their names in the Wasm specification.
 *
 * Reference: https://pengowray.github.io/wasm-ops/.
 */
/** Module preamble magic bytes: "\0asm". */
const magicModuleHeader = [0x00, 0x61, 0x73, 0x6d];
/** Binary format version 1, encoded as a little-endian u32. */
const moduleVersion = [0x01, 0x00, 0x00, 0x00];
2700
/** Throw an Error with `message` (or "Assertion failed") when `condition` is falsy. */
function assert(condition, message) {
	if (condition) return;
	throw new Error(message || "Assertion failed");
}
2703
/**
 * Encode an integer as signed LEB128. Uses JS 32-bit bitwise operators, so
 * the value is treated as a 32-bit two's-complement integer.
 */
function encodeSigned(n) {
	const bytes = [];
	let rest = n;
	for (;;) {
		const low = rest & 127;
		rest >>= 7;
		// Stop once the remaining bits are just sign extension of bit 6.
		const done = (low & 64) !== 0 ? rest === -1 : rest === 0;
		bytes.push(done ? low : low | 128);
		if (done) return bytes;
	}
}
2715
/**
 * Encode a value as unsigned LEB128. Uses JS 32-bit bitwise operators, so
 * the value is treated as a 32-bit unsigned integer.
 */
function encodeUnsigned(n) {
	const bytes = [];
	let rest = n;
	for (;;) {
		const low = rest & 127;
		rest >>>= 7;
		if (rest === 0) {
			bytes.push(low);
			return bytes;
		}
		bytes.push(low | 128);
	}
}
2725
/**
 * Encode a string as a Wasm `name`: the UTF-8 byte length as unsigned LEB128,
 * followed by the UTF-8 bytes. The length was previously emitted as one raw
 * byte, which corrupted names of 128 bytes or more.
 */
function encodeString(s) {
	const bytes = new TextEncoder().encode(s);
	return [...encodeUnsigned(bytes.length), ...bytes];
}
2729
/**
 * Encode a block's result type list.
 * A single type (including `void`, 0x40) is encoded as its type byte.
 * Longer lists are encoded as an inline 0x60 function type with no params.
 * NOTE(review): the Wasm spec's blocktype is 0x40, a single valtype, or an s33
 * type index, so multi-value blocks likely need a registered type index
 * instead of this inline form. Confirm before relying on multi-value blocks.
 */
function encodeBlocktype(type) {
	assert(type.length > 0, "blocktype must have at least one type");
	if (type.length !== 1) {
		const typeIds = type.map(({ typeId }) => typeId);
		return [96, ...encodeUnsigned(0), ...encodeUnsigned(type.length), ...typeIds];
	}
	return [type[0].typeId];
}
2739
/**
 * Encode an opcode. A number is a single-byte opcode; a [prefix, sub] pair
 * is a prefixed opcode whose sub-opcode is unsigned LEB128.
 */
function encodeOpcode(opcode) {
	if (typeof opcode !== "number") {
		const [prefix, sub] = opcode;
		return [prefix, ...encodeUnsigned(sub)];
	}
	return [opcode];
}
2743
/**
 * Append every element of `inp` to `out`, in place.
 * Uses a loop rather than `out.push(...inp)`. Spreading a very large array
 * (e.g. a big code section) into call arguments can exceed the engine's
 * argument/stack limit and throw a RangeError.
 */
function concat(out, inp) {
	for (const x of inp) out.push(x);
}
2746
/**
 * A function definition: input/output type lists plus a body callback that
 * emits the function's instructions when `emit()` is called.
 */
var Function_ = class {
	inputTypes;
	outputTypes;
	body;
	locals = [];
	constructor(inputTypes, outputTypes, body) {
		this.inputTypes = inputTypes;
		this.outputTypes = outputTypes;
		// A missing (falsy) body becomes a no-op emitter.
		this.body = body ? body : () => {};
	}
	/** Clear the declared locals, then run the body callback. */
	emit() {
		this.locals = [];
		this.body();
	}
};
2761
/**
 * Linear memory declaration for a module: page limits, a shared flag, and
 * either an export name (`aString` only) or an import module/field pair
 * (`aString` and `bString`). Also emits the memory.size / memory.grow
 * instructions through the owning code generator.
 */
var Memory = class {
	min = 0;
	max = 0;
	isShared = false;
	aString = "";
	bString = "";
	constructor(cg) {
		this.cg = cg;
	}
	/** Declare the size of the memory. Each page is 64 KiB. Asserts no nonzero limits were set earlier. */
	pages(min, max = 0) {
		assert(this.min === 0 && this.max === 0);
		Object.assign(this, { min, max });
		return this;
	}
	/** Export this memory under name `a`. */
	export(a) {
		this.#assertUnnamed();
		this.aString = a;
		return this;
	}
	shared(isShared) {
		this.isShared = isShared;
		return this;
	}
	/** Import this memory from module `a`, field `b`. */
	import(a, b) {
		this.#assertUnnamed();
		this.aString = a;
		this.bString = b;
		return this;
	}
	/** Emit `memory.size` (0x3F) on memory index 0. */
	size() {
		this.#emitMemoryOp(63);
	}
	/** Emit `memory.grow` (0x40) on memory index 0. */
	grow() {
		this.#emitMemoryOp(64);
	}
	get isImport() {
		return this.aString.length > 0 && this.bString.length > 0;
	}
	get isExport() {
		return this.aString.length > 0 && this.bString.length === 0;
	}
	#assertUnnamed() {
		assert(!this.isImport && !this.isExport, "already set");
	}
	#emitMemoryOp(opcode) {
		this.cg._emit(opcode);
		this.cg._emit(0);
	}
};
2807
+ /** Public API of WebAssembly assembler. */
2808
+ var CodeGenerator = class {
2809
+ local;
2810
+ i32;
2811
+ f32;
2812
+ v128;
2813
+ i32x4;
2814
+ f32x4;
2815
+ memory;
2816
+ void = {
2817
+ typeId: 64,
2818
+ name: "void"
2819
+ };
2820
+ #functions = [];
2821
+ #importedFunctions = [];
2822
+ #exportedFunctions = /* @__PURE__ */ new Map();
2823
+ #curFunction = null;
2824
+ #curBytes = [];
2825
+ #typeStack = [];
2826
+ #blockFrames = [];
2827
+ constructor() {
2828
+ this.local = new Local(this);
2829
+ this.i32 = new I32(this);
2830
+ this.f32 = new F32(this);
2831
+ this.v128 = new V128(this);
2832
+ this.i32x4 = new I32x4(this);
2833
+ this.f32x4 = new F32x4(this);
2834
+ this.memory = new Memory(this);
2835
+ }
2836
+ unreachable() {
2837
+ this._emit(0);
2838
+ }
2839
+ nop() {
2840
+ this._emit(1);
2841
+ }
2842
+ block(...type) {
2843
+ this.#blockFrames.push({
2844
+ idx: this.#typeStack.length,
2845
+ ty: type
2846
+ });
2847
+ this._emit(2);
2848
+ this._emit(encodeBlocktype(type));
2849
+ }
2850
+ loop(...type) {
2851
+ this.#blockFrames.push({
2852
+ idx: this.#typeStack.length,
2853
+ ty: type
2854
+ });
2855
+ this._emit(3);
2856
+ this._emit(encodeBlocktype(type));
2857
+ }
2858
+ if(...type) {
2859
+ assert(this._pop().typeId === this.i32.typeId, "if_: expected i32");
2860
+ this.#blockFrames.push({
2861
+ idx: this.#typeStack.length,
2862
+ ty: type
2863
+ });
2864
+ this._emit(4);
2865
+ this._emit(encodeBlocktype(type));
2866
+ }
2867
+ else() {
2868
+ assert(this.#blockFrames.length > 0, "else: no block to else");
2869
+ const frame = this.#blockFrames[this.#blockFrames.length - 1];
2870
+ this.#typeStack = this.#typeStack.slice(0, frame.idx);
2871
+ this._emit(5);
2872
+ }
2873
+ /** End a block (`block`, `if`/`else`, `loop`, or function). */
2874
+ end() {
2875
+ const frame = this.#blockFrames.pop();
2876
+ assert(frame !== void 0, "end: no block to end");
2877
+ this.#typeStack = this.#typeStack.slice(0, frame.idx);
2878
+ for (const ty of frame.ty) if (ty.typeId !== this.void.typeId) this._push(ty);
2879
+ this._emit(11);
2880
+ }
2881
+ /** Branch to a block a certain depth outward on the stack. */
2882
+ br(depth) {
2883
+ this._emit(12);
2884
+ this._emit(encodeUnsigned(depth));
2885
+ }
2886
+ /** Conditional branch to a block a certain depth outward on the stack. */
2887
+ br_if(depth) {
2888
+ assert(this._pop().typeId === this.i32.typeId, "br_if: expected i32");
2889
+ this._emit(13);
2890
+ this._emit(encodeUnsigned(depth));
2891
+ }
2892
+ /** Jump table that indexes into a label vector (like switch). */
2893
+ br_table(...depths) {
2894
+ assert(this._pop().typeId === this.i32.typeId, "br_table: expected i32");
2895
+ assert(depths.length > 0, "br_table: expected at least one default depth");
2896
+ this._emit(14);
2897
+ this._emit(encodeUnsigned(depths.length - 1));
2898
+ for (const d of depths) this._emit(encodeUnsigned(d));
2899
+ }
2900
+ /** Return from a function, branching out of the outermost block. */
2901
+ return() {
2902
+ this._emit(15);
2903
+ }
2904
+ /** Call a function with the given ID. */
2905
+ call(fn) {
2906
+ const totalFunctions = this.#importedFunctions.length + this.#functions.length;
2907
+ assert(fn < totalFunctions, "function index does not exist");
2908
+ const func = fn < this.#importedFunctions.length ? this.#importedFunctions[fn] : this.#functions[fn - this.#importedFunctions.length];
2909
+ for (let i = func.inputTypes.length - 1; i >= 0; i--) {
2910
+ const argType = this._pop();
2911
+ assert(argType.typeId === func.inputTypes[i].typeId, `call: argument ${i} type mismatch, expected ${func.inputTypes[i].name} got ${argType.name}`);
2912
+ }
2913
+ for (const outputType of func.outputTypes) this._push(outputType);
2914
+ this._emit(16);
2915
+ this._emit(encodeUnsigned(fn));
2916
+ }
2917
+ /** Throw away an operand on the stack. */
2918
+ drop() {
2919
+ this._pop();
2920
+ this._emit(26);
2921
+ }
2922
+ /** Select one of the first two operands (T, F) based on the third operand (i32)'s value. */
2923
+ select() {
2924
+ assert(this._pop().typeId === this.i32.typeId, "select: expected i32 condition");
2925
+ const [b, a] = [this._pop(), this._pop()];
2926
+ assert(a.typeId === b.typeId, "select: expected same type for both operands");
2927
+ this._push(a);
2928
+ this._emit(27);
2929
+ }
2930
+ /** Import a JavaScript function; returns its index. */
2931
+ importFunction(module$1, name, inputTypes, outputTypes) {
2932
+ if (this.#functions.length > 0) throw new Error("function imports must precede defining functions");
2933
+ const idx = this.#importedFunctions.length;
2934
+ this.#importedFunctions.push({
2935
+ module: module$1,
2936
+ name,
2937
+ inputTypes,
2938
+ outputTypes
2939
+ });
2940
+ return idx;
2941
+ }
2942
+ /** Export a function. */
2943
+ export(fn, name) {
2944
+ this.#exportedFunctions.set(fn, name);
2945
+ }
2946
+ /** Declare a new function; returns its index. */
2947
+ function(inputTypes, outputTypes, body) {
2948
+ const idx = this.#importedFunctions.length + this.#functions.length;
2949
+ this.#functions.push(new Function_(inputTypes, outputTypes, body));
2950
+ return idx;
2951
+ }
2952
+ _declareLocal(type) {
2953
+ assert(this.#curFunction !== null, "No current function");
2954
+ const idx = this.#curFunction.locals.length + this.#curFunction.inputTypes.length;
2955
+ this.#curFunction.locals.push(type);
2956
+ return idx;
2957
+ }
2958
+ _inputTypes() {
2959
+ assert(this.#curFunction !== null, "No current function");
2960
+ return this.#curFunction.inputTypes;
2961
+ }
2962
+ _locals() {
2963
+ assert(this.#curFunction !== null, "No current function");
2964
+ return this.#curFunction.locals;
2965
+ }
2966
+ _push(type) {
2967
+ if (!type) throw new Error(`pushing type ${type}`);
2968
+ this.#typeStack.push(type);
2969
+ }
2970
+ _pop() {
2971
+ assert(this.#typeStack.length > 0, "popping empty stack");
2972
+ return this.#typeStack.pop();
2973
+ }
2974
+ _emit(bytes) {
2975
+ if (typeof bytes === "number") this.#curBytes.push(bytes);
2976
+ else this.#curBytes.push(...bytes);
2977
+ }
2978
+ finish() {
2979
+ this.#curBytes = [];
2980
+ const emittedBytes = [];
2981
+ concat(emittedBytes, magicModuleHeader);
2982
+ concat(emittedBytes, moduleVersion);
2983
+ const typeSectionBytes = [];
2984
+ const totalFunctionTypes = this.#importedFunctions.length + this.#functions.length;
2985
+ concat(typeSectionBytes, encodeUnsigned(totalFunctionTypes));
2986
+ for (const f of [...this.#importedFunctions, ...this.#functions]) {
2987
+ typeSectionBytes.push(96);
2988
+ concat(typeSectionBytes, encodeUnsigned(f.inputTypes.length));
2989
+ for (const t of f.inputTypes) typeSectionBytes.push(t.typeId);
2990
+ concat(typeSectionBytes, encodeUnsigned(f.outputTypes.length));
2991
+ for (const t of f.outputTypes) typeSectionBytes.push(t.typeId);
2992
+ }
2993
+ emittedBytes.push(1);
2994
+ concat(emittedBytes, encodeUnsigned(typeSectionBytes.length));
2995
+ concat(emittedBytes, typeSectionBytes);
2996
+ const importSectionBytes = [];
2997
+ const numImports = this.#importedFunctions.length + (this.memory.isImport ? 1 : 0);
2998
+ if (numImports > 0) {
2999
+ concat(importSectionBytes, encodeUnsigned(numImports));
3000
+ for (let i = 0; i < this.#importedFunctions.length; i++) {
3001
+ const f = this.#importedFunctions[i];
3002
+ concat(importSectionBytes, encodeString(f.module));
3003
+ concat(importSectionBytes, encodeString(f.name));
3004
+ importSectionBytes.push(0);
3005
+ concat(importSectionBytes, encodeUnsigned(i));
3006
+ }
3007
+ if (this.memory.isImport) {
3008
+ concat(importSectionBytes, encodeString(this.memory.aString));
3009
+ concat(importSectionBytes, encodeString(this.memory.bString));
3010
+ importSectionBytes.push(2);
3011
+ if (this.memory.min && this.memory.max) {
3012
+ if (this.memory.isShared) importSectionBytes.push(3);
3013
+ else importSectionBytes.push(1);
3014
+ concat(importSectionBytes, encodeUnsigned(this.memory.min));
3015
+ concat(importSectionBytes, encodeUnsigned(this.memory.max));
3016
+ } else {
3017
+ assert(!this.memory.isShared, "shared memory must have a max size");
3018
+ importSectionBytes.push(0);
3019
+ concat(importSectionBytes, encodeUnsigned(this.memory.min));
3020
+ }
3021
+ }
3022
+ emittedBytes.push(2);
3023
+ concat(emittedBytes, encodeUnsigned(importSectionBytes.length));
3024
+ concat(emittedBytes, importSectionBytes);
3025
+ }
3026
+ const functionSectionBytes = [];
3027
+ concat(functionSectionBytes, encodeUnsigned(this.#functions.length));
3028
+ for (let i = 0; i < this.#functions.length; i++) {
3029
+ const typeIndex = this.#importedFunctions.length + i;
3030
+ concat(functionSectionBytes, encodeUnsigned(typeIndex));
3031
+ }
3032
+ emittedBytes.push(3);
3033
+ concat(emittedBytes, encodeUnsigned(functionSectionBytes.length));
3034
+ concat(emittedBytes, functionSectionBytes);
3035
+ const memorySectionBytes = [];
3036
+ if (!this.memory.isImport && (this.memory.min || this.memory.max)) {
3037
+ memorySectionBytes.push(1);
3038
+ if (this.memory.min && this.memory.max) {
3039
+ if (this.memory.isShared) memorySectionBytes.push(3);
3040
+ else memorySectionBytes.push(1);
3041
+ concat(memorySectionBytes, encodeUnsigned(this.memory.min));
3042
+ concat(memorySectionBytes, encodeUnsigned(this.memory.max));
3043
+ } else {
3044
+ assert(!this.memory.isShared, "shared memory must have a max size");
3045
+ memorySectionBytes.push(0);
3046
+ concat(memorySectionBytes, encodeUnsigned(this.memory.min));
3047
+ }
3048
+ emittedBytes.push(5);
3049
+ concat(emittedBytes, encodeUnsigned(memorySectionBytes.length));
3050
+ concat(emittedBytes, memorySectionBytes);
3051
+ }
3052
+ const exportSectionBytes = [];
3053
+ const numExports = this.#exportedFunctions.size + (this.memory.isExport ? 1 : 0);
3054
+ concat(exportSectionBytes, encodeUnsigned(numExports));
3055
+ if (this.memory.isExport) {
3056
+ concat(exportSectionBytes, encodeString(this.memory.aString));
3057
+ exportSectionBytes.push(2);
3058
+ exportSectionBytes.push(0);
3059
+ }
3060
+ for (const [key, name] of this.#exportedFunctions.entries()) {
3061
+ concat(exportSectionBytes, encodeString(name));
3062
+ exportSectionBytes.push(0);
3063
+ concat(exportSectionBytes, encodeUnsigned(key));
3064
+ }
3065
+ emittedBytes.push(7);
3066
+ concat(emittedBytes, encodeUnsigned(exportSectionBytes.length));
3067
+ concat(emittedBytes, exportSectionBytes);
3068
+ const codeSectionBytes = [];
3069
+ concat(codeSectionBytes, encodeUnsigned(this.#functions.length));
3070
+ for (const f of this.#functions) {
3071
+ this.#typeStack = [];
3072
+ this.#blockFrames = [{
3073
+ idx: 0,
3074
+ ty: f.outputTypes
3075
+ }];
3076
+ this.#curFunction = f;
3077
+ this.#curBytes = [];
3078
+ f.emit();
3079
+ this.end();
3080
+ const bodyBytes = [...this.#curBytes];
3081
+ this.#curBytes = [];
3082
+ concat(this.#curBytes, encodeUnsigned(f.locals.length));
3083
+ for (const l of f.locals) {
3084
+ this._emit(1);
3085
+ this._emit(l.typeId);
3086
+ }
3087
+ const headerBytes = [...this.#curBytes];
3088
+ const fnSize = headerBytes.length + bodyBytes.length;
3089
+ concat(codeSectionBytes, encodeUnsigned(fnSize));
3090
+ concat(codeSectionBytes, headerBytes);
3091
+ concat(codeSectionBytes, bodyBytes);
3092
+ }
3093
+ this.#curFunction = null;
3094
+ emittedBytes.push(10);
3095
+ concat(emittedBytes, encodeUnsigned(codeSectionBytes.length));
3096
+ concat(emittedBytes, codeSectionBytes);
3097
+ return new Uint8Array(emittedBytes);
3098
+ }
3099
+ };
3100
/**
 * Emitters for `local.get` (0x20), `local.set` (0x21) and `local.tee` (0x22)
 * in the function currently being generated. Local indices cover the
 * function's parameters first, then its declared locals.
 */
var Local = class {
  constructor(cg) {
    this.cg = cg;
  }
  /** Declare a new local of `type` in the current function; returns its index. */
  declare(type) {
    return this.cg._declareLocal(type);
  }
  /**
   * Resolve the value type of local `idx`, asserting that the index is an
   * integer and refers to an existing parameter or declared local.
   * @param {number} idx - local index
   * @param {string} verb - "getting" / "setting" / "teeing", used in error messages
   */
  #typeOf(idx, verb) {
    assert(Number.isInteger(idx), `${verb} non-integer local`);
    const inputTypes = this.cg._inputTypes();
    const type = idx < inputTypes.length ? inputTypes[idx] : this.cg._locals()[idx - inputTypes.length];
    // Catches negative and too-large indices, which would otherwise surface
    // as an opaque TypeError or "pushing type undefined".
    assert(type !== void 0, `${verb} local ${idx} out of range`);
    return type;
  }
  /** Push the value of local `idx`. */
  get(idx) {
    const type = this.#typeOf(idx, "getting");
    this.cg._push(type);
    this.cg._emit(32);
    this.cg._emit(encodeUnsigned(idx));
  }
  /** Pop a value into local `idx`; its type must match the local's type. */
  set(idx) {
    const t = this.cg._pop();
    const expectedType = this.#typeOf(idx, "setting");
    assert(expectedType.typeId === t.typeId, "can't set local to this value (wrong type)");
    this.cg._emit(33);
    this.cg._emit(encodeUnsigned(idx));
  }
  /** Store the top value into local `idx` while leaving it on the stack. */
  tee(idx) {
    const t = this.cg._pop();
    const expectedType = this.#typeOf(idx, "teeing");
    assert(expectedType.typeId === t.typeId, "can't tee local to this value (wrong type)");
    this.cg._emit(34);
    this.cg._emit(encodeUnsigned(idx));
    this.cg._push(expectedType);
  }
};
3133
/**
 * Build a method for a one-operand instruction: pops a value that must have
 * type `cg[inType]`, emits `opcode` (via `encodeOpcode`), and pushes
 * `cg[outType]`. The returned function expects `this.cg` to be the generator.
 */
function UNARY_OP(op, opcode, inType, outType) {
  const message = `invalid type for ${op} (${inType} -> ${outType})`;
  return function() {
    const { cg } = this;
    const operand = cg._pop();
    assert(operand.typeId === cg[inType].typeId, message);
    cg._emit(encodeOpcode(opcode));
    cg._push(cg[outType]);
  };
}
/**
 * Build a method for a two-operand instruction: pops the right operand
 * (`typeB`) and then the left operand (`typeA`), emits `opcode`, and pushes
 * `cg[outType]`.
 */
function BINARY_OP(op, opcode, typeA, typeB, outType) {
  const message = `invalid type for ${op} (${typeA}, ${typeB} -> ${outType})`;
  return function() {
    const { cg } = this;
    const rhs = cg._pop();
    const lhs = cg._pop();
    const typesOk = lhs.typeId === cg[typeA].typeId && rhs.typeId === cg[typeB].typeId;
    assert(typesOk, message);
    cg._emit(encodeOpcode(opcode));
    cg._push(cg[outType]);
  };
}
3150
/**
 * Build a scalar load method: pops an i32 address, emits `opcode` followed by
 * the memarg immediates `align` and `offset`, and pushes `cg[outType]`.
 */
function LOAD_OP(op, opcode, outType) {
  return function(align = 0, offset = 0) {
    const { cg } = this;
    const address = cg._pop();
    assert(address.typeId === cg.i32.typeId, `invalid type for ${op}`);
    cg._emit(encodeOpcode(opcode));
    for (const immediate of [align, offset]) cg._emit(encodeUnsigned(immediate));
    cg._push(cg[outType]);
  };
}
/**
 * Build a scalar store method: pops a value of type `cg[inType]` and then an
 * i32 address, and emits `opcode` with the memarg immediates. Pushes nothing.
 */
function STORE_OP(op, opcode, inType) {
  return function(align = 0, offset = 0) {
    const { cg } = this;
    const value = cg._pop();
    const address = cg._pop();
    assert(value.typeId === cg[inType].typeId, `invalid value type for ${op} (${inType})`);
    assert(address.typeId === cg.i32.typeId, `invalid type for ${op}`);
    cg._emit(encodeOpcode(opcode));
    for (const immediate of [align, offset]) cg._emit(encodeUnsigned(immediate));
  };
}
3171
/**
 * Emitters for scalar `i32` instructions. An instance also serves as the
 * `i32` value type (typeId 127) on the generator's type stack: `const` pushes
 * `this`, and the op factories push `cg.i32`.
 */
var I32 = class {
  constructor(cg) {
    this.cg = cg;
  }
  get typeId() {
    return 127;
  }
  get name() {
    return "i32";
  }
  /** Push the constant `i` (immediate encoded with `encodeSigned`). */
  const(i) {
    this.cg._emit(65);
    this.cg._emit(encodeSigned(i));
    this.cg._push(this);
  }
  clz = UNARY_OP("clz", 103, "i32", "i32");
  ctz = UNARY_OP("ctz", 104, "i32", "i32");
  popcnt = UNARY_OP("popcnt", 105, "i32", "i32");
  lt_s = BINARY_OP("lt_s", 72, "i32", "i32", "i32");
  lt_u = BINARY_OP("lt_u", 73, "i32", "i32", "i32");
  gt_s = BINARY_OP("gt_s", 74, "i32", "i32", "i32");
  gt_u = BINARY_OP("gt_u", 75, "i32", "i32", "i32");
  le_s = BINARY_OP("le_s", 76, "i32", "i32", "i32");
  le_u = BINARY_OP("le_u", 77, "i32", "i32", "i32");
  ge_s = BINARY_OP("ge_s", 78, "i32", "i32", "i32");
  ge_u = BINARY_OP("ge_u", 79, "i32", "i32", "i32");
  add = BINARY_OP("add", 106, "i32", "i32", "i32");
  sub = BINARY_OP("sub", 107, "i32", "i32", "i32");
  mul = BINARY_OP("mul", 108, "i32", "i32", "i32");
  div_s = BINARY_OP("div_s", 109, "i32", "i32", "i32");
  div_u = BINARY_OP("div_u", 110, "i32", "i32", "i32");
  rem_s = BINARY_OP("rem_s", 111, "i32", "i32", "i32");
  rem_u = BINARY_OP("rem_u", 112, "i32", "i32", "i32");
  and = BINARY_OP("and", 113, "i32", "i32", "i32");
  or = BINARY_OP("or", 114, "i32", "i32", "i32");
  xor = BINARY_OP("xor", 115, "i32", "i32", "i32");
  shl = BINARY_OP("shl", 116, "i32", "i32", "i32");
  shr_s = BINARY_OP("shr_s", 117, "i32", "i32", "i32");
  shr_u = BINARY_OP("shr_u", 118, "i32", "i32", "i32");
  rotl = BINARY_OP("rotl", 119, "i32", "i32", "i32");
  rotr = BINARY_OP("rotr", 120, "i32", "i32", "i32");
  // `i32.eqz` tests a single operand against zero: [i32] -> [i32]. It was
  // previously declared binary, which popped (and type-checked) one operand
  // too many.
  eqz = UNARY_OP("eqz", 69, "i32", "i32");
  eq = BINARY_OP("eq", 70, "i32", "i32", "i32");
  ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
  trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
  trunc_f32_u = UNARY_OP("trunc_f32_u", 169, "f32", "i32");
  load = LOAD_OP("load", 40, "i32");
  load8_s = LOAD_OP("load8_s", 44, "i32");
  load8_u = LOAD_OP("load8_u", 45, "i32");
  load16_s = LOAD_OP("load16_s", 46, "i32");
  load16_u = LOAD_OP("load16_u", 47, "i32");
  store = STORE_OP("store", 54, "i32");
  store8 = STORE_OP("store8", 58, "i32");
  store16 = STORE_OP("store16", 59, "i32");
  reinterpret_f32 = UNARY_OP("reinterpret_f32", 188, "f32", "i32");
  trunc_sat_f32_s = UNARY_OP("trunc_sat_f32_s", [252, 0], "f32", "i32");
  trunc_sat_f32_u = UNARY_OP("trunc_sat_f32_u", [252, 1], "f32", "i32");
};
3229
/**
 * Emitters for scalar `f32` instructions. An instance also serves as the
 * `f32` value type (typeId 0x7d) on the type stack. Opcodes are written in
 * hex to match the WebAssembly instruction tables.
 */
var F32 = class {
  constructor(cg) {
    this.cg = cg;
  }
  get typeId() {
    return 0x7d;
  }
  get name() {
    return "f32";
  }
  /** Push the constant `f` (0x43 followed by its 4 little-endian IEEE-754 bytes). */
  const(f) {
    this.cg._emit(0x43);
    const view = new DataView(new ArrayBuffer(4));
    view.setFloat32(0, f, true);
    for (let byteIndex = 0; byteIndex < 4; byteIndex++) this.cg._emit(view.getUint8(byteIndex));
    this.cg._push(this);
  }
  // Comparisons: [f32, f32] -> [i32]
  eq = BINARY_OP("eq", 0x5b, "f32", "f32", "i32");
  ne = BINARY_OP("ne", 0x5c, "f32", "f32", "i32");
  lt = BINARY_OP("lt", 0x5d, "f32", "f32", "i32");
  gt = BINARY_OP("gt", 0x5e, "f32", "f32", "i32");
  le = BINARY_OP("le", 0x5f, "f32", "f32", "i32");
  ge = BINARY_OP("ge", 0x60, "f32", "f32", "i32");
  // Unary arithmetic: [f32] -> [f32]
  abs = UNARY_OP("abs", 0x8b, "f32", "f32");
  neg = UNARY_OP("neg", 0x8c, "f32", "f32");
  ceil = UNARY_OP("ceil", 0x8d, "f32", "f32");
  floor = UNARY_OP("floor", 0x8e, "f32", "f32");
  trunc = UNARY_OP("trunc", 0x8f, "f32", "f32");
  nearest = UNARY_OP("nearest", 0x90, "f32", "f32");
  sqrt = UNARY_OP("sqrt", 0x91, "f32", "f32");
  // Binary arithmetic: [f32, f32] -> [f32]
  add = BINARY_OP("add", 0x92, "f32", "f32", "f32");
  sub = BINARY_OP("sub", 0x93, "f32", "f32", "f32");
  mul = BINARY_OP("mul", 0x94, "f32", "f32", "f32");
  div = BINARY_OP("div", 0x95, "f32", "f32", "f32");
  min = BINARY_OP("min", 0x96, "f32", "f32", "f32");
  max = BINARY_OP("max", 0x97, "f32", "f32", "f32");
  copysign = BINARY_OP("copysign", 0x98, "f32", "f32", "f32");
  // Conversions and memory access
  convert_i32_s = UNARY_OP("convert_i32_s", 0xb2, "i32", "f32");
  convert_i32_u = UNARY_OP("convert_i32_u", 0xb3, "i32", "f32");
  load = LOAD_OP("load", 0x2a, "f32");
  store = STORE_OP("store", 0x38, "f32");
  reinterpret_i32 = UNARY_OP("reinterpret_i32", 0xbe, "i32", "f32");
};
3273
/**
 * Build a SIMD method (0xfd prefix + `vopcode`) that pops operands whose types
 * are listed in `inTypes` (the last listed type is on top of the stack) and
 * pushes `cg[outType]`.
 */
function VECTOR_OP(op, vopcode, inTypes, outType) {
  return function() {
    const { cg } = this;
    for (let i = inTypes.length - 1; i >= 0; i--) {
      const actual = cg._pop();
      assert(actual.typeId === cg[inTypes[i]].typeId, `invalid type for ${op} (${inTypes.join(", ")} -> ${outType})`);
    }
    cg._emit(encodeOpcode([0xfd, vopcode]));
    cg._push(cg[outType]);
  };
}
3283
+ function VECTOR_OPL(op, vopcode, inTypes, outType) {
3284
+ return function(lane) {
3285
+ for (const inType of inTypes.toReversed()) {
3286
+ const actualType = this.cg._pop();
3287
+ assert(actualType.typeId === this.cg[inType].typeId, `invalid type for ${op} (${inTypes} -> ${outType})`);
3288
+ }
3289
+ this.cg._emit(encodeOpcode([253, vopcode]));
3290
+ this.cg._emit(lane);
3291
+ this.cg._push(this.cg[outType]);
3292
+ };
3293
+ }
3294
/**
 * Build a SIMD load method (0xfd prefix + `vopcode`): pops an i32 address,
 * emits the memarg immediates `align` and `offset`, and pushes `v128`.
 */
function VECTOR_LOAD_OP(op, vopcode) {
  return function(align = 0, offset = 0) {
    const { cg } = this;
    const address = cg._pop();
    assert(address.typeId === cg.i32.typeId, `invalid type for ${op}`);
    cg._emit(encodeOpcode([0xfd, vopcode]));
    for (const immediate of [align, offset]) cg._emit(encodeUnsigned(immediate));
    cg._push(cg.v128);
  };
}
3304
/**
 * Emitters for untyped 128-bit SIMD instructions (0xfd-prefixed). An instance
 * also serves as the `v128` value type (typeId 0x7b) on the type stack.
 */
var V128 = class {
  constructor(cg) {
    this.cg = cg;
  }
  get typeId() {
    return 0x7b;
  }
  get name() {
    return "v128";
  }
  load = VECTOR_LOAD_OP("load", 0x00);
  load32x2_s = VECTOR_LOAD_OP("load32x2_s", 0x05);
  load32x2_u = VECTOR_LOAD_OP("load32x2_u", 0x06);
  load32_splat = VECTOR_LOAD_OP("load32_splat", 0x09);
  load32_zero = VECTOR_LOAD_OP("load32_zero", 0x5c);
  /** `v128.store`: pops a v128 value, then an i32 address; emits the memarg immediates. */
  store(align = 0, offset = 0) {
    const { cg } = this;
    const value = cg._pop();
    assert(value.typeId === cg.v128.typeId, `invalid type for store`);
    const address = cg._pop();
    assert(address.typeId === cg.i32.typeId, `invalid type for store`);
    cg._emit(0xfd);
    cg._emit(encodeUnsigned(0x0b));
    cg._emit(encodeUnsigned(align));
    cg._emit(encodeUnsigned(offset));
  }
  not = VECTOR_OP("not", 0x4d, ["v128"], "v128");
  and = VECTOR_OP("and", 0x4e, ["v128", "v128"], "v128");
  andnot = VECTOR_OP("andnot", 0x4f, ["v128", "v128"], "v128");
  or = VECTOR_OP("or", 0x50, ["v128", "v128"], "v128");
  xor = VECTOR_OP("xor", 0x51, ["v128", "v128"], "v128");
  bitselect = VECTOR_OP("bitselect", 0x52, ["v128", "v128", "v128"], "v128");
  any_true = VECTOR_OP("any_true", 0x53, ["v128"], "i32");
};
3341
/**
 * Emitters for `i32x4` SIMD lane operations (0xfd-prefixed opcodes, in hex).
 * Inherits the v128 type identity and the untyped v128 ops from `V128`.
 */
var I32x4 = class extends V128 {
  splat = VECTOR_OP("splat", 0x11, ["i32"], "v128");
  extract_lane = VECTOR_OPL("extract_lane", 0x1b, ["v128"], "i32");
  replace_lane = VECTOR_OPL("replace_lane", 0x1c, ["v128", "i32"], "v128");
  // Lane-wise comparisons produce all-ones / all-zeros masks.
  eq = VECTOR_OP("eq", 0x37, ["v128", "v128"], "v128");
  ne = VECTOR_OP("ne", 0x38, ["v128", "v128"], "v128");
  lt_s = VECTOR_OP("lt_s", 0x39, ["v128", "v128"], "v128");
  lt_u = VECTOR_OP("lt_u", 0x3a, ["v128", "v128"], "v128");
  gt_s = VECTOR_OP("gt_s", 0x3b, ["v128", "v128"], "v128");
  gt_u = VECTOR_OP("gt_u", 0x3c, ["v128", "v128"], "v128");
  le_s = VECTOR_OP("le_s", 0x3d, ["v128", "v128"], "v128");
  le_u = VECTOR_OP("le_u", 0x3e, ["v128", "v128"], "v128");
  ge_s = VECTOR_OP("ge_s", 0x3f, ["v128", "v128"], "v128");
  ge_u = VECTOR_OP("ge_u", 0x40, ["v128", "v128"], "v128");
  abs = VECTOR_OP("abs", 0xa0, ["v128"], "v128");
  neg = VECTOR_OP("neg", 0xa1, ["v128"], "v128");
  all_true = VECTOR_OP("all_true", 0xa3, ["v128"], "i32");
  bitmask = VECTOR_OP("bitmask", 0xa4, ["v128"], "i32");
  // Shifts take a scalar i32 shift amount.
  shl = VECTOR_OP("shl", 0xab, ["v128", "i32"], "v128");
  shr_s = VECTOR_OP("shr_s", 0xac, ["v128", "i32"], "v128");
  shr_u = VECTOR_OP("shr_u", 0xad, ["v128", "i32"], "v128");
  add = VECTOR_OP("add", 0xae, ["v128", "v128"], "v128");
  sub = VECTOR_OP("sub", 0xb1, ["v128", "v128"], "v128");
  mul = VECTOR_OP("mul", 0xb5, ["v128", "v128"], "v128");
  min_s = VECTOR_OP("min_s", 0xb6, ["v128", "v128"], "v128");
  min_u = VECTOR_OP("min_u", 0xb7, ["v128", "v128"], "v128");
  max_s = VECTOR_OP("max_s", 0xb8, ["v128", "v128"], "v128");
  max_u = VECTOR_OP("max_u", 0xb9, ["v128", "v128"], "v128");
};
3370
/**
 * Emitters for `f32x4` SIMD lane operations (0xfd-prefixed opcodes, in hex).
 * Inherits the v128 type identity and the untyped v128 ops from `V128`.
 */
var F32x4 = class extends V128 {
  splat = VECTOR_OP("splat", 0x13, ["f32"], "v128");
  extract_lane = VECTOR_OPL("extract_lane", 0x1f, ["v128"], "f32");
  replace_lane = VECTOR_OPL("replace_lane", 0x20, ["v128", "f32"], "v128");
  // Lane-wise comparisons produce all-ones / all-zeros masks.
  eq = VECTOR_OP("eq", 0x41, ["v128", "v128"], "v128");
  ne = VECTOR_OP("ne", 0x42, ["v128", "v128"], "v128");
  lt = VECTOR_OP("lt", 0x43, ["v128", "v128"], "v128");
  gt = VECTOR_OP("gt", 0x44, ["v128", "v128"], "v128");
  le = VECTOR_OP("le", 0x45, ["v128", "v128"], "v128");
  ge = VECTOR_OP("ge", 0x46, ["v128", "v128"], "v128");
  // Rounding
  ceil = VECTOR_OP("ceil", 0x67, ["v128"], "v128");
  floor = VECTOR_OP("floor", 0x68, ["v128"], "v128");
  trunc = VECTOR_OP("trunc", 0x69, ["v128"], "v128");
  nearest = VECTOR_OP("nearest", 0x6a, ["v128"], "v128");
  // Arithmetic
  abs = VECTOR_OP("abs", 0xe0, ["v128"], "v128");
  neg = VECTOR_OP("neg", 0xe1, ["v128"], "v128");
  sqrt = VECTOR_OP("sqrt", 0xe3, ["v128"], "v128");
  add = VECTOR_OP("add", 0xe4, ["v128", "v128"], "v128");
  sub = VECTOR_OP("sub", 0xe5, ["v128", "v128"], "v128");
  mul = VECTOR_OP("mul", 0xe6, ["v128", "v128"], "v128");
  div = VECTOR_OP("div", 0xe7, ["v128", "v128"], "v128");
  min = VECTOR_OP("min", 0xe8, ["v128", "v128"], "v128");
  max = VECTOR_OP("max", 0xe9, ["v128", "v128"], "v128");
  pmin = VECTOR_OP("pmin", 0xea, ["v128", "v128"], "v128");
  pmax = VECTOR_OP("pmax", 0xeb, ["v128", "v128"], "v128");
};
3396
+
3397
+ //#endregion
3398
+ //#region src/backend/wasm.ts
3399
/**
 * Backend that compiles into WebAssembly bytecode for immediate execution.
 *
 * All buffers live in a single `WebAssembly.Memory` managed by a
 * `WasmAllocator`. Buffers are addressed by integer slots with reference
 * counts; a buffer is freed when its count reaches zero.
 */
var WasmBackend = class {
  type = "wasm";
  maxArgs = 64;
  #memory;
  #nextSlot;
  #allocator;
  #buffers;
  constructor() {
    this.#memory = new WebAssembly.Memory({ initial: 0 });
    this.#allocator = new WasmAllocator(this.#memory);
    this.#nextSlot = 1;
    this.#buffers = /* @__PURE__ */ new Map();
  }
  /**
   * Allocate `size` bytes and return a new slot with reference count 1.
   * @param {number} size - buffer size in bytes
   * @param {ArrayBufferView|ArrayBuffer} [initialData] - optional contents; its
   *   `byteLength` must equal `size`. Its raw bytes are copied, whatever the
   *   typed-array element type.
   * @throws {Error} if `initialData` has the wrong size (nothing is allocated)
   */
  malloc(size, initialData) {
    // Validate before allocating so a bad call does not leak an allocation.
    if (initialData && initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
    const ptr = this.#allocator.malloc(size);
    if (initialData) {
      // Copy raw bytes: `Uint8Array#set` on a non-byte typed array would copy
      // (and truncate) elements instead of bytes.
      const bytes = ArrayBuffer.isView(initialData) ? new Uint8Array(initialData.buffer, initialData.byteOffset, initialData.byteLength) : new Uint8Array(initialData);
      // The view over memory is created after malloc, which may have grown
      // (and thereby replaced) `memory.buffer`.
      new Uint8Array(this.#memory.buffer, ptr, size).set(bytes);
    }
    const slot = this.#nextSlot++;
    this.#buffers.set(slot, {
      ptr,
      size,
      ref: 1
    });
    return slot;
  }
  incRef(slot) {
    this.#entry(slot).ref++;
  }
  /** Decrement a slot's reference count, freeing its memory when it hits zero. */
  decRef(slot) {
    const buffer = this.#entry(slot);
    buffer.ref--;
    if (buffer.ref === 0) {
      this.#allocator.free(buffer.ptr);
      this.#buffers.delete(slot);
    }
  }
  async read(slot, start, count) {
    return this.readSync(slot, start, count);
  }
  /** Copy `count` bytes starting at byte `start` (defaults: the whole buffer) out of a slot. */
  readSync(slot, start, count) {
    const buffer = this.#getBuffer(slot);
    if (start === void 0) start = 0;
    if (count === void 0) count = buffer.byteLength - start;
    return buffer.slice(start, start + count);
  }
  async prepare(kernel) {
    return this.prepareSync(kernel);
  }
  prepareSync(kernel) {
    const bytes = codegenWasm(kernel);
    const module$1 = new WebAssembly.Module(bytes);
    return new Executable(kernel, { module: module$1 });
  }
  /**
   * Run a prepared kernel, passing the byte pointers of the input slots
   * followed by the output slots as arguments to its exported `kernel`.
   * @throws {SlotError} for an unknown slot (checked before instantiation)
   */
  dispatch(exe, inputs, outputs) {
    const ptrs = [...inputs, ...outputs].map((slot) => this.#entry(slot).ptr);
    const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
    const func = instance.exports.kernel;
    func(...ptrs);
  }
  /** Look up a slot's bookkeeping record, throwing SlotError if it does not exist. */
  #entry(slot) {
    const buffer = this.#buffers.get(slot);
    if (!buffer) throw new SlotError(slot);
    return buffer;
  }
  /** Byte view of a slot over the current memory buffer. */
  #getBuffer(slot) {
    const buffer = this.#entry(slot);
    return new Uint8Array(this.#memory.buffer, buffer.ptr, buffer.size);
  }
};
3470
/**
 * Generate a WebAssembly module for `kernel` that exports one function named
 * "kernel".
 *
 * The exported function takes `kernel.nargs + 1` i32 parameters. The last one
 * is used as the output base address. The others are presumably input buffer
 * pointers (WasmBackend.dispatch passes input pointers followed by output
 * pointers) — TODO confirm. The function loops `gidx` over [0, kernel.size),
 * evaluates the kernel expression (folding over `ridx` in [0, re.size) when
 * `kernel.reduction` is set), and stores each result at
 * `out + gidx * byteWidth(kernel.dtype)`.
 *
 * Memory is imported as `env.memory`, matching WasmBackend.dispatch's import
 * object. Math helper functions are only emitted for ALU ops that occur in the
 * expression or the reduction epilogue.
 *
 * @returns {Uint8Array} the encoded module (from `cg.finish()`)
 */
function codegenWasm(kernel) {
  const tune = tuneNullopt(kernel);
  const re = kernel.reduction;
  if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
  const cg = new CodeGenerator();
  cg.memory.import("env", "memory");
  // Only emit helper functions for ops that are actually used.
  const distinctOps = union(tune.exp.distinctOps(), re?.epilogue.distinctOps());
  const funcs = {};
  if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
  if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
  if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
  if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
  if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
  if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
  if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
  const kernelFunc = cg.function(rep(kernel.nargs + 1, cg.i32), [], () => {
    // `gidx` is never explicitly initialized; this relies on WebAssembly
    // locals starting at zero.
    const gidx = cg.local.declare(cg.i32);
    // Outer loop over output elements: `br_if(0)` exits the inner block (and
    // thus the loop) once gidx >= kernel.size; `br(1)` jumps back to the loop.
    cg.loop(cg.void);
    cg.block(cg.void);
    cg.local.get(gidx);
    cg.i32.const(kernel.size);
    cg.i32.ge_u();
    cg.br_if(0);
    // Store address = last parameter + gidx * element byte width. It stays on
    // the operand stack until the store below.
    cg.local.get(kernel.nargs);
    cg.local.get(gidx);
    cg.i32.const(byteWidth(kernel.dtype));
    cg.i32.mul();
    cg.i32.add();
    if (re) {
      // Accumulator starts at the reduction identity.
      const acc = cg.local.declare(dty(cg, null, kernel.exp.dtype));
      dty(cg, null, kernel.exp.dtype).const(re.identity);
      cg.local.set(acc);
      const ridx = cg.local.declare(cg.i32);
      cg.i32.const(0);
      cg.local.set(ridx);
      // Inner reduction loop over ridx in [0, re.size), same loop/block shape.
      cg.loop(cg.void);
      cg.block(cg.void);
      cg.local.get(ridx);
      cg.i32.const(re.size);
      cg.i32.ge_u();
      cg.br_if(0);
      translateExp(cg, funcs, tune.exp, {
        gidx,
        ridx
      });
      // Combine the new value with `acc`; the result is stored back into acc.
      // Bool add/mul are lowered to i32 or/and.
      if (re.op === AluOp.Add) {
        cg.local.get(acc);
        if (re.dtype === DType.Bool) cg.i32.or();
        else dty(cg, re.op, re.dtype).add();
      } else if (re.op === AluOp.Mul) {
        cg.local.get(acc);
        if (re.dtype === DType.Bool) cg.i32.and();
        else dty(cg, re.op, re.dtype).mul();
      } else if (re.op === AluOp.Min || re.op === AluOp.Max) if (re.dtype === DType.Float32) {
        cg.local.get(acc);
        if (re.op === AluOp.Min) cg.f32.min();
        else cg.f32.max();
      } else if ([
        DType.Int32,
        DType.Uint32,
        DType.Bool
      ].includes(re.dtype)) {
        // Integer min/max via select(x, acc, x < acc) (or x > acc for Max),
        // with a signed comparison only for Int32.
        const local = cg.local.declare(cg.i32);
        cg.local.tee(local);
        cg.local.get(acc);
        cg.local.get(local);
        cg.local.get(acc);
        if (re.op === AluOp.Min) if (re.dtype === DType.Int32) cg.i32.lt_s();
        else cg.i32.lt_u();
        else if (re.dtype === DType.Int32) cg.i32.gt_s();
        else cg.i32.gt_u();
        cg.select();
      } else throw new Error(`invalid reduction min/max over ${re.dtype}`);
      else throw new Error(`invalid wasm reduction op: ${re.op}`);
      cg.local.set(acc);
      cg.local.get(ridx);
      cg.i32.const(1);
      cg.i32.add();
      cg.local.set(ridx);
      cg.br(1);
      cg.end();
      cg.end();
      // The epilogue maps the final accumulator to the stored value.
      translateExp(cg, funcs, kernel.reduction.epilogue, { acc });
    } else translateExp(cg, funcs, tune.exp, { gidx });
    // Alignment immediate is log2 of the element byte width.
    dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
    cg.local.get(gidx);
    cg.i32.const(1);
    cg.i32.add();
    cg.local.set(gidx);
    cg.br(1);
    cg.end();
    cg.end();
  });
  cg.export(kernelFunc, "kernel");
  return cg.finish();
}
3566
+ function translateExp(cg, funcs, exp, ctx) {
3567
+ const references = /* @__PURE__ */ new Map();
3568
+ const seen = /* @__PURE__ */ new Set();
3569
+ const countReferences = (exp$1) => {
3570
+ references.set(exp$1, (references.get(exp$1) ?? 0) + 1);
3571
+ if (!seen.has(exp$1)) {
3572
+ seen.add(exp$1);
3573
+ for (const src of exp$1.src) countReferences(src);
3574
+ }
3575
+ };
3576
+ const expContext = /* @__PURE__ */ new Map();
3577
+ const gen = (exp$1) => {
3578
+ if (expContext.has(exp$1)) return cg.local.get(expContext.get(exp$1));
3579
+ const { op, src, dtype, arg } = exp$1;
3580
+ if (AluGroup.Binary.has(op) || AluGroup.Compare.has(op)) {
3581
+ gen(src[0]);
3582
+ gen(src[1]);
3583
+ if (op === AluOp.Add) if (dtype === DType.Bool) cg.i32.or();
3584
+ else dty(cg, op, dtype).add();
3585
+ else if (op === AluOp.Sub) dty(cg, op, dtype).sub();
3586
+ else if (op === AluOp.Mul) if (dtype === DType.Bool) cg.i32.and();
3587
+ else dty(cg, op, dtype).mul();
3588
+ else if (op === AluOp.Idiv) if (dtype === DType.Float32) cg.f32.div(), cg.f32.trunc();
3589
+ else if (dtype === DType.Uint32) cg.i32.div_u();
3590
+ else if (dtype === DType.Int32) cg.i32.div_s();
3591
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3592
+ else if (op === AluOp.Mod) if (dtype === DType.Float32) {
3593
+ const a = cg.local.declare(cg.f32);
3594
+ const b = cg.local.declare(cg.f32);
3595
+ cg.local.set(b);
3596
+ cg.local.tee(a);
3597
+ cg.local.get(a);
3598
+ cg.local.get(b);
3599
+ cg.f32.div();
3600
+ cg.f32.trunc();
3601
+ cg.local.get(b);
3602
+ cg.f32.mul();
3603
+ cg.f32.sub();
3604
+ } else if (dtype === DType.Uint32) cg.i32.rem_u();
3605
+ else if (dtype === DType.Int32) cg.i32.rem_s();
3606
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3607
+ else if (op === AluOp.Min || op === AluOp.Max) if (dtype === DType.Float32) if (op === AluOp.Min) cg.f32.min();
3608
+ else cg.f32.max();
3609
+ else if (dtype === DType.Int32 || dtype === DType.Uint32) {
3610
+ const a = cg.local.declare(cg.i32);
3611
+ const b = cg.local.declare(cg.i32);
3612
+ cg.local.set(b);
3613
+ cg.local.tee(a);
3614
+ cg.local.get(b);
3615
+ cg.local.get(a);
3616
+ cg.local.get(b);
3617
+ if (dtype === DType.Int32) if (op === AluOp.Min) cg.i32.lt_s();
3618
+ else cg.i32.gt_s();
3619
+ else if (op === AluOp.Min) cg.i32.lt_u();
3620
+ else cg.i32.gt_u();
3621
+ cg.select();
3622
+ } else throw new UnsupportedOpError(op, dtype, "wasm");
3623
+ else if (op === AluOp.Cmplt) {
3624
+ const srcDtype = src[0].dtype;
3625
+ if (srcDtype === DType.Float32) cg.f32.lt();
3626
+ else if (srcDtype === DType.Int32) cg.i32.lt_s();
3627
+ else if (srcDtype === DType.Uint32) cg.i32.lt_u();
3628
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3629
+ } else if (op === AluOp.Cmpne) dty(cg, op, src[0].dtype).ne();
3630
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3631
+ } else if (AluGroup.Unary.has(op)) if (op === AluOp.Sin) gen(src[0]), cg.call(funcs.sin);
3632
+ else if (op === AluOp.Cos) gen(src[0]), cg.call(funcs.cos);
3633
+ else if (op === AluOp.Asin) gen(src[0]), cg.call(funcs.asin);
3634
+ else if (op === AluOp.Atan) gen(src[0]), cg.call(funcs.atan);
3635
+ else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
3636
+ else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
3637
+ else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
3638
+ else if (op === AluOp.Reciprocal) cg.f32.const(1), gen(src[0]), cg.f32.div();
3639
+ else if (op === AluOp.Cast) {
3640
+ gen(src[0]);
3641
+ const dtype0 = src[0].dtype;
3642
+ const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
3643
+ if (dtype === DType.Int32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_s();
3644
+ else if (i32repr);
3645
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3646
+ else if (dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_u();
3647
+ else if (i32repr);
3648
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3649
+ else if (dtype === DType.Float32) if (dtype0 === DType.Float32);
3650
+ else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32.convert_i32_s();
3651
+ else if (dtype0 === DType.Uint32) cg.f32.convert_i32_u();
3652
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3653
+ else if (dtype === DType.Bool) if (dtype0 === DType.Bool);
3654
+ else if (i32repr) cg.i32.const(0), cg.i32.ne();
3655
+ else if (dtype0 === DType.Float32) cg.f32.const(0), cg.f32.ne();
3656
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3657
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3658
+ } else if (op === AluOp.Bitcast) {
3659
+ gen(src[0]);
3660
+ const dtype0 = src[0].dtype;
3661
+ const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32;
3662
+ if (dtype === DType.Int32 || dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.reinterpret_f32();
3663
+ else if (i32repr);
3664
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3665
+ else if (dtype === DType.Float32) if (i32repr) cg.f32.reinterpret_i32();
3666
+ else if (dtype0 === DType.Float32);
3667
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3668
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3669
+ } else throw new UnsupportedOpError(op, dtype, "wasm");
3670
+ else if (op === AluOp.Where) {
3671
+ gen(src[1]);
3672
+ gen(src[2]);
3673
+ gen(src[0]);
3674
+ cg.select();
3675
+ } else if (op === AluOp.Threefry2x32) {
3676
+ for (let i = 0; i < 4; i++) gen(src[i]);
3677
+ cg.call(funcs.threefry2x32);
3678
+ if (arg === "xor") cg.i32.xor();
3679
+ else if (arg === 0) cg.drop();
3680
+ else if (arg === 1) {
3681
+ const local = cg.local.declare(cg.i32);
3682
+ cg.local.set(local);
3683
+ cg.drop();
3684
+ cg.local.get(local);
3685
+ } else throw new UnsupportedOpError(op, dtype, "wasm", arg);
3686
+ } else if (op === AluOp.Const) return dty(cg, op, dtype).const(arg);
3687
+ else if (op === AluOp.Special) return cg.local.get(ctx[arg[0]]);
3688
+ else if (op === AluOp.Variable) return cg.local.get(ctx[arg]);
3689
+ else if (op === AluOp.GlobalIndex) {
3690
+ const [gid, len] = arg;
3691
+ gen(src[0]);
3692
+ const local = cg.local.declare(cg.i32);
3693
+ cg.local.tee(local);
3694
+ cg.i32.const(0);
3695
+ cg.local.get(local), cg.i32.const(len), cg.i32.lt_u();
3696
+ cg.select();
3697
+ cg.i32.const(byteWidth(dtype));
3698
+ cg.i32.mul();
3699
+ cg.local.get(gid);
3700
+ cg.i32.add();
3701
+ dty(cg, op, dtype).load(Math.log2(byteWidth(dtype)));
3702
+ } else throw new UnsupportedOpError(op, dtype, "wasm");
3703
+ if ((references.get(exp$1) ?? 0) > 1) {
3704
+ const local = cg.local.declare(dty(cg, op, dtype));
3705
+ cg.local.tee(local);
3706
+ expContext.set(exp$1, local);
3707
+ }
3708
+ };
3709
+ countReferences(exp);
3710
+ gen(exp);
3711
+ }
3712
+ function dty(cg, op, dtype) {
3713
+ switch (dtype) {
3714
+ case DType.Float32: return cg.f32;
3715
+ case DType.Int32:
3716
+ case DType.Uint32:
3717
+ case DType.Bool: return cg.i32;
3718
+ default: throw new UnsupportedOpError(op, dtype, "wasm");
3719
+ }
3720
+ }
3721
+
1814
3722
  //#endregion
1815
3723
  //#region src/backend.ts
1816
- const devices = ["cpu", "webgpu"];
1817
- let defaultBackend = "cpu";
3724
+ const devices = [
3725
+ "cpu",
3726
+ "wasm",
3727
+ "webgpu"
3728
+ ];
3729
+ let defaultBackend = "wasm";
1818
3730
  const initializedBackends = /* @__PURE__ */ new Map();
1819
- initializedBackends.set("cpu", new CPUBackend());
1820
- /** Set the default device backend (must be initialized). */
1821
- function setDevice(device) {
1822
- if (initializedBackends.has(device)) defaultBackend = device;
3731
+ initializedBackends.set("cpu", new CpuBackend());
3732
+ initializedBackends.set("wasm", new WasmBackend());
3733
+ /** Configure the default device for arrays. */
3734
+ function defaultDevice(device) {
3735
+ if (device !== void 0) if (initializedBackends.has(device)) defaultBackend = device;
1823
3736
  else throw new Error(`Backend not initialized: ${device}`);
3737
+ return defaultBackend;
1824
3738
  }
1825
3739
  /**
1826
3740
  * Initialize `jax-js` library backends.
@@ -1841,12 +3755,13 @@ async function init(...devicesToInit) {
1841
3755
  }
1842
3756
  /** Create a backend, if available. Internal function called by `init()`. */
1843
3757
  async function createBackend(device) {
1844
- if (device === "cpu") return new CPUBackend();
3758
+ if (device === "cpu") return new CpuBackend();
3759
+ else if (device === "wasm") return new WasmBackend();
1845
3760
  else if (device === "webgpu") {
1846
3761
  if (!navigator.gpu) return null;
1847
3762
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
1848
3763
  if (!adapter) return null;
1849
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-c5Fe8nx8.cjs"));
3764
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BVdMaO9T.cjs"));
1850
3765
  const importantLimits = [
1851
3766
  "maxBufferSize",
1852
3767
  "maxComputeInvocationsPerWorkgroup",
@@ -1859,8 +3774,12 @@ async function createBackend(device) {
1859
3774
  "maxStorageBuffersPerShaderStage",
1860
3775
  "maxStorageTexturesPerShaderStage"
1861
3776
  ];
3777
+ const requestedFeatures = ["shader-f16", "timestamp-query"];
1862
3778
  try {
1863
- const device$1 = await adapter.requestDevice({ requiredLimits: Object.fromEntries(importantLimits.map((feature) => [feature, adapter.limits[feature]])) });
3779
+ const device$1 = await adapter.requestDevice({
3780
+ requiredLimits: Object.fromEntries(importantLimits.map((limit) => [limit, adapter.limits[limit]])),
3781
+ requiredFeatures: requestedFeatures.filter((feature) => adapter.features.has(feature))
3782
+ });
1864
3783
  return new WebGPUBackend(device$1);
1865
3784
  } catch (error) {
1866
3785
  console.error("Unexpected error requesting WebGPU device:", error);
@@ -1886,6 +3805,13 @@ var SlotError = class extends Error {
1886
3805
  super(`Used a buffer that is invalid or already freed: ${slot}`);
1887
3806
  }
1888
3807
  };
3808
+ var UnsupportedOpError = class extends Error {
3809
+ constructor(op, dtype, device, arg) {
3810
+ let msg = `${op || ""}<${dtype}> not supported in ${device} backend`;
3811
+ if (arg !== void 0) msg += ` with arg ${JSON.stringify(arg)}`;
3812
+ super(msg);
3813
+ }
3814
+ };
1889
3815
 
1890
3816
  //#endregion
1891
3817
  Object.defineProperty(exports, 'AluExp', {
@@ -1966,6 +3892,12 @@ Object.defineProperty(exports, 'SlotError', {
1966
3892
  return SlotError;
1967
3893
  }
1968
3894
  });
3895
+ Object.defineProperty(exports, 'UnsupportedOpError', {
3896
+ enumerable: true,
3897
+ get: function () {
3898
+ return UnsupportedOpError;
3899
+ }
3900
+ });
1969
3901
  Object.defineProperty(exports, 'accessorAluExp', {
1970
3902
  enumerable: true,
1971
3903
  get: function () {
@@ -1996,6 +3928,12 @@ Object.defineProperty(exports, 'deepEqual', {
1996
3928
  return deepEqual;
1997
3929
  }
1998
3930
  });
3931
+ Object.defineProperty(exports, 'defaultDevice', {
3932
+ enumerable: true,
3933
+ get: function () {
3934
+ return defaultDevice;
3935
+ }
3936
+ });
1999
3937
  Object.defineProperty(exports, 'devices', {
2000
3938
  enumerable: true,
2001
3939
  get: function () {
@@ -2008,6 +3946,12 @@ Object.defineProperty(exports, 'dtypedArray', {
2008
3946
  return dtypedArray;
2009
3947
  }
2010
3948
  });
3949
+ Object.defineProperty(exports, 'dtypedJsArray', {
3950
+ enumerable: true,
3951
+ get: function () {
3952
+ return dtypedJsArray;
3953
+ }
3954
+ });
2011
3955
  Object.defineProperty(exports, 'findPow2', {
2012
3956
  enumerable: true,
2013
3957
  get: function () {
@@ -2050,6 +3994,12 @@ Object.defineProperty(exports, 'isPermutation', {
2050
3994
  return isPermutation;
2051
3995
  }
2052
3996
  });
3997
+ Object.defineProperty(exports, 'normalizeAxis', {
3998
+ enumerable: true,
3999
+ get: function () {
4000
+ return normalizeAxis;
4001
+ }
4002
+ });
2053
4003
  Object.defineProperty(exports, 'partitionList', {
2054
4004
  enumerable: true,
2055
4005
  get: function () {
@@ -2062,6 +4012,12 @@ Object.defineProperty(exports, 'prod', {
2062
4012
  return prod;
2063
4013
  }
2064
4014
  });
4015
+ Object.defineProperty(exports, 'promoteTypes', {
4016
+ enumerable: true,
4017
+ get: function () {
4018
+ return promoteTypes;
4019
+ }
4020
+ });
2065
4021
  Object.defineProperty(exports, 'range', {
2066
4022
  enumerable: true,
2067
4023
  get: function () {
@@ -2086,10 +4042,10 @@ Object.defineProperty(exports, 'runWithCache', {
2086
4042
  return runWithCache;
2087
4043
  }
2088
4044
  });
2089
- Object.defineProperty(exports, 'setDevice', {
4045
+ Object.defineProperty(exports, 'setDebug', {
2090
4046
  enumerable: true,
2091
4047
  get: function () {
2092
- return setDevice;
4048
+ return setDebug;
2093
4049
  }
2094
4050
  });
2095
4051
  Object.defineProperty(exports, 'strip1', {
@@ -2110,6 +4066,12 @@ Object.defineProperty(exports, 'tuneWebgpu', {
2110
4066
  return tuneWebgpu;
2111
4067
  }
2112
4068
  });
4069
+ Object.defineProperty(exports, 'union', {
4070
+ enumerable: true,
4071
+ get: function () {
4072
+ return union;
4073
+ }
4074
+ });
2113
4075
  Object.defineProperty(exports, 'unravelAlu', {
2114
4076
  enumerable: true,
2115
4077
  get: function () {
@@ -2127,4 +4089,10 @@ Object.defineProperty(exports, 'zip', {
2127
4089
  get: function () {
2128
4090
  return zip;
2129
4091
  }
4092
+ });
4093
+ Object.defineProperty(exports, 'zipn', {
4094
+ enumerable: true,
4095
+ get: function () {
4096
+ return zipn;
4097
+ }
2130
4098
  });