@jax-js/jax 0.0.2 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3816 @@
1
+ //#region src/pprint.ts
2
/**
 * Pretty-printer for multi-line expressions.
 *
 * A block is kept as two parallel arrays: `lines[i]` is the text of row `i`,
 * and `indents[i]` is the number of spaces emitted in front of that row.
 */
var PPrint = class PPrint {
  constructor(indents, lines) {
    this.indents = indents;
    this.lines = lines;
  }

  /** Return a new block whose rows are all shifted right by `spaces`. */
  indent(spaces) {
    const shifted = this.indents.map((amount) => amount + spaces);
    return new PPrint(shifted, this.lines);
  }

  /** Return a new block with the rows of every item appended below this one. */
  concat(...items) {
    const indents = [...(this.indents ?? []), ...items.flatMap((item) => item.indents)];
    const lines = [...(this.lines ?? []), ...items.flatMap((item) => item.lines)];
    return new PPrint(indents, lines);
  }

  /**
   * Place `other` to the right of this block. The last row of this block and
   * the first row of `other` are merged into one row; the remaining rows of
   * `other` are indented so that they line up after that merged prefix.
   */
  stack(other) {
    if (other.lines.length === 0) return this;
    if (this.lines.length === 0) return other;
    const lastIndent = this.indents[this.indents.length - 1];
    const lastLine = this.lines[this.lines.length - 1];
    const shifted = other.indent(lastIndent + lastLine.length);
    const merged = lastLine + " ".repeat(other.indents[0]) + other.lines[0];
    return new PPrint(
      [...this.indents, ...shifted.indents.slice(1)],
      [...this.lines.slice(0, -1), merged, ...shifted.lines.slice(1)],
    );
  }

  /** Render the block as a single newline-separated string. */
  toString() {
    const rows = [];
    for (let i = 0; i < this.lines.length; i++) {
      rows.push(" ".repeat(this.indents[i]) + this.lines[i]);
    }
    return rows.join("\n");
  }

  /** Build an unindented block from the string form of `s`, one row per line. */
  static pp(s) {
    const lines = s.toString().split("\n");
    return new PPrint(lines.map(() => 0), lines);
  }
};
34
+
35
+ //#endregion
36
+ //#region src/utils.ts
37
+ /** @file Generic programming utilities with no dependencies on library code. */
38
/** Current verbosity level; 0 means no debug output. */
let DEBUG = 0;

/**
 * Choose how much diagnostic logging is emitted. Each level adds a category
 * on top of the previous ones:
 *
 * 1. JIT compile logs
 * 2. Shader code
 * 3. Expressions and metadata
 * 4. JIT programs, tuning details
 * 5. Most verbose operation traces
 *
 * Experimental: the behavior of this API may change, so production code
 * should not depend on it.
 *
 * @param level New value stored in the module-level `DEBUG` variable.
 */
function setDebug(level) {
  DEBUG = level;
}
54
/**
 * Split an iterable of two-element entries into two arrays.
 *
 * @param pairs Iterable whose items destructure as `[first, second]`.
 * @returns `[firsts, seconds]`, each in iteration order.
 */
function unzip2(pairs) {
  const firsts = [];
  const seconds = [];
  for (const pair of pairs) {
    const [first, second] = pair;
    firsts.push(first);
    seconds.push(second);
  }
  return [firsts, seconds];
}
63
/**
 * Pair each element of `xs` with the element of `ys` at the same index.
 *
 * The result always has the length of `xs`; when `ys` is shorter, the missing
 * partners are `undefined`.
 */
function zip(xs, ys) {
  const pairWithPartner = (x, index) => [x, ys[index]];
  return xs.map(pairWithPartner);
}
66
/**
 * Zip any number of arrays together, truncating to the shortest one.
 *
 * @param arrays Arrays to combine.
 * @returns An array of tuples; tuple `i` holds element `i` of every input.
 *   Returns `[]` when called with no arrays.
 */
function zipn(...arrays) {
  // Math.min() with no arguments is Infinity, which is not a valid array
  // length, so the zero-argument case has to be answered up front.
  if (arrays.length === 0) return [];
  const minLength = Math.min(...arrays.map((x) => x.length));
  return Array.from({ length: minLength }, (_, i) => arrays.map((arr) => arr[i]));
}
70
/**
 * Build an array of `length` entries.
 *
 * When `value` is a function it is called with each index and its results
 * become the entries; otherwise every entry is `value` itself.
 */
function rep(length, value) {
  const out = new Array(length);
  if (!(value instanceof Function)) return out.fill(value);
  for (let i = 0; i < out.length; i++) out[i] = value(i);
  return out;
}
74
/** Multiply all entries of `arr` together; an empty array gives 1. */
function prod(arr) {
  let total = 1;
  arr.forEach((x) => {
    total *= x;
  });
  return total;
}
77
/**
 * Greatest common divisor of all arguments, computed with Euclid's algorithm.
 *
 * @param values Finite numbers; signs are ignored in the result.
 * @returns The non-negative GCD, or 0 when called without arguments or with
 *   only zeros.
 * @throws {TypeError} If any argument is not a finite number. Without this
 *   check a NaN, undefined or infinite argument makes the remainder NaN, and
 *   since `NaN !== 0` is always true the loop would never terminate.
 */
function gcd(...values) {
  let a = 0;
  for (let b of values) {
    if (!Number.isFinite(b)) {
      throw new TypeError(`gcd: expected finite numbers, got ${String(b)}`);
    }
    while (b !== 0) [a, b] = [b, a % b];
  }
  return Math.abs(a);
}
82
/** Floor division: the quotient `a / b` rounded toward negative infinity. */
function intdiv(a, b) {
  const quotient = a / b;
  return Math.floor(quotient);
}
86
/** Restrict `x` to lie between `min` and `max` (both inclusive). */
function clamp(x, min, max) {
  const cappedAbove = Math.min(max, x);
  return Math.max(min, cappedAbove);
}
90
/**
 * Structural equality for plain values.
 *
 * Primitives are compared with `===`. Two objects are equal when both are
 * arrays or both are non-arrays, they have the same number of own enumerable
 * string keys, every key of `a` is also an own key of `b`, and the values
 * under each key are themselves deep-equal.
 *
 * @returns `true` if `a` and `b` are structurally equal.
 */
function deepEqual(a, b) {
  if (a === b) return true;
  if (typeof a !== "object" || typeof b !== "object") return false;
  if (a === null || b === null) return false;
  // An array and an object with the same index keys are different values.
  if (Array.isArray(a) !== Array.isArray(b)) return false;
  const keysA = Object.keys(a);
  if (keysA.length !== Object.keys(b).length) return false;
  // Require that `b` owns the key: otherwise `{x: undefined}` would match
  // `{y: 1}`, because a missing `b.x` also reads as undefined.
  return keysA.every((key) => Object.hasOwn(b, key) && deepEqual(a[key], b[key]));
}
99
/**
 * Collect the elements of several iterables into one new `Set`.
 * Falsy arguments (such as `null` or `undefined`) are skipped.
 */
function union(...sets) {
  const merged = new Set();
  for (const source of sets) {
    if (!source) continue;
    for (const element of source) merged.add(element);
  }
  return merged;
}
104
/**
 * Split `array` in two according to the flags in `which`.
 *
 * @param which Flags, one per position; its length sets how many positions
 *   are examined.
 * @param array Values to distribute.
 * @returns `[rejected, selected]`: values whose flag was falsy, then values
 *   whose flag was truthy, each in original order.
 */
function partitionList(which, array) {
  const rejected = [];
  const selected = [];
  for (let i = 0; i < which.length; i++) {
    const target = which[i] ? selected : rejected;
    target.push(array[i]);
  }
  return [rejected, selected];
}
112
/**
 * Lexicographic comparison of two numeric arrays.
 *
 * @returns A negative number if `a` sorts first, a positive number if `b`
 *   sorts first, and 0 if they are equal. When one array is a prefix of the
 *   other, the shorter one sorts first.
 */
function lexCompare(a, b) {
  const shared = Math.min(a.length, b.length);
  let i = 0;
  // Advance past the common prefix.
  while (i < shared && !(a[i] < b[i]) && !(a[i] > b[i])) i++;
  if (i < shared) return a[i] < b[i] ? -1 : 1;
  return a.length - b.length;
}
121
/** Whether `x` is an array of exactly two entries that are both numbers. */
function isNumberPair(x) {
  if (!Array.isArray(x) || x.length !== 2) return false;
  const [first, second] = x;
  return typeof first === "number" && typeof second === "number";
}
125
/**
 * Validate an axis index for an array with `ndim` dimensions.
 *
 * @param axis Axis in `[-ndim, ndim)`; negative values count from the end.
 * @param ndim Number of dimensions.
 * @returns The equivalent non-negative axis.
 * @throws {Error} If `axis` is below `-ndim` or at least `ndim`.
 */
function checkAxis(axis, ndim) {
  const tooLow = axis < -ndim;
  const tooHigh = axis >= ndim;
  if (tooLow || tooHigh) throw new Error(`Axis ${axis} out of bounds for array of dimension ${ndim}`);
  if (axis < 0) return axis + ndim;
  return axis;
}
130
/**
 * Turn the usual `axis` argument of a reduction-style function into a sorted
 * list of non-negative axes.
 *
 * @param axis `null`/`undefined` for every axis, a single axis number, or an
 *   iterable of axis numbers (negative values count from the end).
 * @param ndim Number of dimensions of the array.
 * @returns Distinct, resolved axes in ascending numeric order.
 * @throws {Error} If an axis is out of bounds (via `checkAxis`) or if two
 *   entries resolve to the same axis.
 */
function normalizeAxis(axis, ndim) {
  if (axis === null || axis === void 0) return range(ndim);
  if (typeof axis === "number") return [checkAxis(axis, ndim)];
  const seen = /* @__PURE__ */ new Set();
  for (const a of axis) {
    const ca = checkAxis(a, ndim);
    if (seen.has(ca)) throw new Error(`Duplicate axis ${ca} passed to function`);
    seen.add(ca);
  }
  // The default sort compares elements as strings, which would place axis 10
  // before axis 2; an explicit numeric comparator is required.
  return [...seen].sort((x, y) => x - y);
}
144
/**
 * Arithmetic progression as an array, in the manner of Python's `range`.
 *
 * `range(stop)` counts from 0; `range(start, stop, step)` counts from `start`
 * in increments of `step` and excludes `stop`. A negative `step` counts down
 * while the value stays above `stop`.
 *
 * @param start First value, or the exclusive end when `stop` is omitted.
 * @param stop Exclusive end of the progression.
 * @param step Increment between values; defaults to 1.
 * @returns The generated values (possibly empty).
 * @throws {RangeError} If `step` is zero or NaN, since the loop could never
 *   make progress toward `stop`.
 */
function range(start, stop, step = 1) {
  const first = stop === void 0 ? 0 : start;
  const end = stop === void 0 ? start : stop;
  const result = [];
  if (step > 0) {
    for (let i = first; i < end; i += step) result.push(i);
  } else if (step < 0) {
    for (let i = first; i > end; i += step) result.push(i);
  } else {
    throw new RangeError(`range: step must be a nonzero number, got ${step}`);
  }
  return result;
}
153
/**
 * Whether `axis` lists each of `0 .. n-1` exactly once.
 *
 * @param axis Candidate permutation.
 * @param n Expected number of entries.
 */
function isPermutation(axis, n) {
  if (axis.length !== n) return false;
  const outOfRange = axis.some((x) => x < 0 || x >= n);
  if (outOfRange) return false;
  // With the length and range already checked, n distinct values means that
  // nothing is repeated.
  return new Set(axis).size === n;
}
162
/**
 * Compute the inverse of a permutation, so that `result[axis[i]] === i`.
 *
 * @param axis A permutation of `0 .. axis.length - 1`.
 * @throws {Error} If `axis` fails the `isPermutation` check.
 */
function invertPermutation(axis) {
  const n = axis.length;
  if (!isPermutation(axis, n)) throw new Error("invertPermutation: axis is not a permutation");
  const inverse = new Array(n);
  for (const [source, target] of axis.entries()) inverse[target] = source;
  return inverse;
}
169
/**
 * Topologically sort the part of a DAG that is reachable from `terminals`.
 *
 * @param terminals Nodes to start from. Duplicates are ignored, and a
 *   terminal may itself be an ancestor of another terminal.
 * @param parents Function returning an iterable of the direct parents of a
 *   node. It is called more than once per node and must return the same
 *   parents each time.
 * @returns Every reachable node exactly once, with each node placed after
 *   all of its ancestors.
 */
function toposort(terminals, parents) {
  const roots = [...new Set(terminals)];

  // Phase 1: count how many times each node is reached. A node is expanded
  // only on its first visit, so the count is (1 if it is a terminal) plus the
  // number of child edges pointing at it.
  const childCounts = /* @__PURE__ */ new Map();
  const stack = [...roots];
  while (stack.length > 0) {
    const node = stack.pop();
    const visits = childCounts.get(node);
    if (visits !== void 0) {
      childCounts.set(node, visits + 1);
    } else {
      childCounts.set(node, 1);
      for (const parent of parents(node)) stack.push(parent);
    }
  }
  // Remove the visit that came from being a terminal, leaving pure child counts.
  for (const node of roots) childCounts.set(node, childCounts.get(node) - 1);

  // Phase 2: repeatedly emit nodes that have no unprocessed children.
  const order = [];
  const frontier = roots.filter((n) => childCounts.get(n) === 0);
  while (frontier.length > 0) {
    const node = frontier.pop();
    order.push(node);
    for (const parent of parents(node)) {
      const remaining = childCounts.get(parent) - 1;
      childCounts.set(parent, remaining);
      if (remaining === 0) frontier.push(parent);
    }
  }
  // Nodes were emitted children-first; callers want ancestors first.
  return order.reverse();
}
197
/**
 * Pick a power of 2 guided by `hint` and capped by `max`.
 *
 * Starting from 1, the value is doubled while it is still below `hint` and
 * doubling would not exceed `max`. The result is therefore the first power
 * of 2 that is at least `hint`, unless that exceeds `max`, in which case it
 * is the largest power of 2 not above `max`. A `hint` of 1 or less gives 1.
 *
 * @throws {Error} If `max` is less than 1.
 */
function findPow2(hint, max) {
  if (max < 1) throw new Error("max must be a positive integer");
  let power = 1;
  for (;;) {
    const doubled = 2 * power;
    if (!(power < hint && doubled <= max)) return power;
    power = doubled;
  }
}
209
/**
 * Flatten arbitrarily nested arrays into a single flat array. A value that is
 * not an array is wrapped in a one-element array.
 */
function recursiveFlatten(ar) {
  return Array.isArray(ar) ? ar.flat(Infinity) : [ar];
}
213
/**
 * Remove one enclosing pair of parentheses from an expression string.
 *
 * The pair is removed only when the opening parenthesis at the start is the
 * one closed by the final character. `"(a+b)"` becomes `"a+b"`, while
 * `"(a)+(b)"` is returned unchanged because its first and last parentheses
 * belong to different groups.
 *
 * @param str Expression text.
 * @returns The text without its outermost pair, or `str` itself.
 */
function strip1(str) {
  const last = str.length - 1;
  if (last < 1 || str[0] !== "(" || str[last] !== ")") return str;
  // Walk everything except the final character; if nesting depth drops back
  // to zero early, the leading "(" was closed before the end of the string.
  let depth = 0;
  for (let i = 0; i < last; i++) {
    if (str[i] === "(") depth++;
    else if (str[i] === ")") depth--;
    if (depth === 0) return str;
  }
  return str.slice(1, -1);
}
218
/** Scratch buffer used to read the bit pattern of a float64 as an integer. */
const _stagingbuf = /* @__PURE__ */ new DataView(/* @__PURE__ */ new ArrayBuffer(8));

/**
 * Incremental polynomial hash over a prime field.
 *
 * Hashes of this form collide rarely in expectation, which is sufficient for
 * tasks such as deduplicating compiler expressions that have already been
 * seen. It offers no protection against deliberately crafted collisions.
 *
 * See https://en.wikipedia.org/wiki/Lagrange%27s_theorem_(number_theory)
 */
var FpHash = class FpHash {
  /** Current hash state as a bigint. */
  value = 8773157n;

  /** Mix one bigint into the state: `value = (value * base + x) mod p`. */
  #update(x) {
    const base = 873192869n;
    const modulus = 3189051996290219n;
    this.value = (this.value * base + x) % modulus;
  }

  /**
   * Feed a value into the hash and return `this` for chaining.
   *
   * Strings, numbers, booleans, bigints, `null` and `undefined` are handled
   * directly. Any other value must provide a `hash(state)` method, which is
   * called with this hasher.
   */
  update(x) {
    switch (typeof x) {
      case "string":
        // Length first, then each UTF-16 code unit.
        this.#update(BigInt(x.length));
        for (let i = 0; i < x.length; i++) this.#update(BigInt(199 + x.charCodeAt(i)));
        break;
      case "number":
        if (Number.isInteger(x)) {
          this.#update(68265653n ^ BigInt(x));
        } else {
          // Non-integers are hashed through their raw 64-bit representation.
          _stagingbuf.setFloat64(0, x, true);
          this.#update(_stagingbuf.getBigUint64(0, true));
        }
        break;
      case "boolean":
        this.#update(x ? 69069841n : 63640693n);
        break;
      case "bigint":
        this.#update(x ^ 71657401n);
        break;
      case "undefined":
        this.#update(18145117n);
        break;
      default:
        if (x === null) this.#update(37832657n);
        else x.hash(this);
    }
    return this;
  }

  /** Hash a sequence of values with a fresh hasher and return the result. */
  static hash(...values) {
    return values.reduce((hasher, x) => hasher.update(x), new FpHash()).value;
  }
};
255
/**
 * Memoize a computation in a `Map`-like cache.
 *
 * @param cache Object with `has`, `get` and `set`.
 * @param key Cache key.
 * @param thunk Called with no arguments to produce the value on a miss.
 * @returns The cached value, or the freshly computed one after storing it.
 */
function runWithCache(cache, key, thunk) {
  if (cache.has(key)) return cache.get(key);
  const computed = thunk();
  cache.set(key, computed);
  return computed;
}
264
+
265
+ //#endregion
266
+ //#region src/alu.ts
267
/**
 * Element types that array contents can have. Each member maps a name to the
 * string used as the dtype value.
 */
let DType = {
  Float32: "float32",
  Int32: "int32",
  Uint32: "uint32",
  Bool: "bool",
  Float16: "float16",
};
276
/**
 * Size in bytes of one element of the given dtype: 2 for float16 and 4 for
 * float32, int32, uint32 and bool.
 *
 * @throws {TypeError} If `dtype` is not one of the known dtypes.
 */
const byteWidth = (dtype) => {
  if (dtype === DType.Float16) return 2;
  const fourByteTypes = [DType.Float32, DType.Int32, DType.Uint32, DType.Bool];
  if (fourByteTypes.includes(dtype)) return 4;
  throw new TypeError(`Unknown dtype: ${dtype}`);
};
286
/** Whether `dtype` is one of the floating-point types (float32 or float16). */
const isFloatDtype = (dtype) => {
  switch (dtype) {
    case DType.Float32:
    case DType.Float16:
      return true;
    default:
      return false;
  }
};
287
/**
 * Find the common dtype for an operation whose operands have two different
 * dtypes, so that the result type can hold values of either input.
 *
 * The dtypes are ordered in a single chain and the higher-ranked one wins:
 *
 * ```text
 * bool < uint32 < int32 < float16 < float32
 * ```
 *
 * Examples:
 * - `promoteTypes(bool, int32)` gives `int32`
 * - `promoteTypes(uint32, int32)` gives `int32`
 * - `promoteTypes(int32, float16)` gives `float16`
 * - `promoteTypes(float16, float32)` gives `float32`
 * - `promoteTypes(uint32, float32)` gives `float32`
 *
 * NOTE: "weak" types for JS number constants (which adopt the dtype of the
 * array they are combined with) are not handled by this function; it only
 * compares the ranks of the two dtypes it is given.
 */
function promoteTypes(dtype1, dtype2) {
  if (dtype1 === dtype2) return dtype1;
  const latticeRank = {
    [DType.Bool]: 0,
    [DType.Uint32]: 1,
    [DType.Int32]: 2,
    [DType.Float16]: 3,
    [DType.Float32]: 4,
  };
  if (latticeRank[dtype1] > latticeRank[dtype2]) return dtype1;
  return dtype2;
}
322
/**
 * Reinterpret the bytes behind `data` as a typed array of the given dtype.
 * No bytes are copied: the result is a view on the same `ArrayBuffer`.
 *
 * @param dtype Element type; bool shares the `Int32Array` representation.
 * @param data Any view exposing `buffer`, `byteOffset` and `byteLength`.
 * @returns A typed array covering the same byte span as `data`.
 * @throws Whatever `byteWidth` throws for an unrecognized dtype; failing
 *   that, an `Error` if no typed array is implemented for `dtype`.
 */
function dtypedArray(dtype, data) {
  const { buffer, byteLength, byteOffset } = data;
  // Element count is computed first, so unknown dtypes fail inside byteWidth.
  const length = byteLength / byteWidth(dtype);
  if (dtype === DType.Float32) return new Float32Array(buffer, byteOffset, length);
  if (dtype === DType.Int32 || dtype === DType.Bool) return new Int32Array(buffer, byteOffset, length);
  if (dtype === DType.Uint32) return new Uint32Array(buffer, byteOffset, length);
  if (dtype === DType.Float16) return new Float16Array(buffer, byteOffset, length);
  throw new Error(`Unimplemented dtype: ${dtype}`);
}
334
/**
 * Create a typed array of the given dtype from `data`, which is passed
 * directly to the typed-array constructor.
 *
 * @param dtype Element type; bool shares the `Int32Array` representation.
 * @param data Argument for the typed-array constructor.
 * @throws {Error} If no typed array is implemented for `dtype`.
 */
function dtypedJsArray(dtype, data) {
  if (dtype === DType.Float32) return new Float32Array(data);
  if (dtype === DType.Int32 || dtype === DType.Bool) return new Int32Array(data);
  if (dtype === DType.Uint32) return new Uint32Array(data);
  if (dtype === DType.Float16) return new Float16Array(data);
  throw new Error(`Unimplemented dtype: ${dtype}`);
}
344
+ /**
345
+ * Mathematical expression on scalar values.
346
+ *
347
+ * This is similar to and based on tinygrad's UOp class, but it's more specific
348
+ * to just math on scalars. We're doing this to avoid the complexity of a full
349
+ * graph rewrite engine.
350
+ */
351
+ var AluExp = class AluExp {
352
+ #hash;
353
+ #simplified;
354
+ #range;
355
+ constructor(op, dtype, src, arg = void 0) {
356
+ this.op = op;
357
+ this.dtype = dtype;
358
+ this.src = src;
359
+ this.arg = arg;
360
+ if (AluGroup.RequiredFloat.has(op) && !isFloatDtype(dtype)) throw new TypeError(`Unsupported dtype for ${op}: ${dtype}`);
361
+ if (op === AluOp.Bitcast && (dtype === DType.Bool || src[0].dtype === DType.Bool || byteWidth(dtype) !== byteWidth(src[0].dtype))) throw new TypeError(`Bitcast from ${src[0].dtype} -> ${dtype}`);
362
+ if (op === AluOp.Threefry2x32 && (dtype !== DType.Uint32 || src.some((x) => x.dtype !== DType.Uint32))) throw new TypeError("Threefry2x32 requires uint32 types");
363
+ }
364
+ static add(a, b) {
365
+ return new AluExp(AluOp.Add, a.dtype, [a, b]);
366
+ }
367
+ static sub(a, b) {
368
+ return new AluExp(AluOp.Sub, a.dtype, [a, b]);
369
+ }
370
+ static mul(a, b) {
371
+ return new AluExp(AluOp.Mul, a.dtype, [a, b]);
372
+ }
373
+ static idiv(a, b) {
374
+ return new AluExp(AluOp.Idiv, a.dtype, [a, b]);
375
+ }
376
+ static mod(a, b) {
377
+ return new AluExp(AluOp.Mod, a.dtype, [a, b]);
378
+ }
379
+ static min(a, b) {
380
+ return new AluExp(AluOp.Min, a.dtype, [a, b]);
381
+ }
382
+ static max(a, b) {
383
+ return new AluExp(AluOp.Max, a.dtype, [a, b]);
384
+ }
385
+ static sin(a) {
386
+ return new AluExp(AluOp.Sin, a.dtype, [a]);
387
+ }
388
+ static cos(a) {
389
+ return new AluExp(AluOp.Cos, a.dtype, [a]);
390
+ }
391
+ static asin(a) {
392
+ return new AluExp(AluOp.Asin, a.dtype, [a]);
393
+ }
394
+ static atan(a) {
395
+ return new AluExp(AluOp.Atan, a.dtype, [a]);
396
+ }
397
+ static exp(a) {
398
+ return new AluExp(AluOp.Exp, a.dtype, [a]);
399
+ }
400
+ static log(a) {
401
+ return new AluExp(AluOp.Log, a.dtype, [a]);
402
+ }
403
+ static sqrt(a) {
404
+ return new AluExp(AluOp.Sqrt, a.dtype, [a]);
405
+ }
406
+ static reciprocal(a) {
407
+ return new AluExp(AluOp.Reciprocal, a.dtype, [a]);
408
+ }
409
+ static cast(dtype, a) {
410
+ if (a.dtype === dtype) return a;
411
+ return new AluExp(AluOp.Cast, dtype, [a]);
412
+ }
413
+ static bitcast(dtype, a) {
414
+ if (a.dtype === dtype) return a;
415
+ return new AluExp(AluOp.Bitcast, dtype, [a]);
416
+ }
417
+ static threefry2x32(k0, k1, c0, c1, mode = "xor") {
418
+ return new AluExp(AluOp.Threefry2x32, DType.Uint32, [
419
+ k0,
420
+ k1,
421
+ c0,
422
+ c1
423
+ ], mode);
424
+ }
425
+ static cmplt(a, b) {
426
+ return new AluExp(AluOp.Cmplt, DType.Bool, [a, b]);
427
+ }
428
+ static cmpne(a, b) {
429
+ return new AluExp(AluOp.Cmpne, DType.Bool, [a, b]);
430
+ }
431
+ static where(cond, a, b) {
432
+ return new AluExp(AluOp.Where, a.dtype, [
433
+ cond,
434
+ a,
435
+ b
436
+ ]);
437
+ }
438
+ static const(dtype, value) {
439
+ if (dtype === DType.Bool) value = Number(Boolean(value));
440
+ else if (dtype === DType.Int32) value = Math.trunc(value) | 0;
441
+ else if (dtype === DType.Uint32) value = Math.trunc(value) >>> 0;
442
+ if (typeof value !== "number") throw new TypeError(`Expected a number for constant, got ${typeof value}: ${value}`);
443
+ return new AluExp(AluOp.Const, dtype, [], value);
444
+ }
445
+ static special(dtype, name, n) {
446
+ return new AluExp(AluOp.Special, dtype, [], [name, n]);
447
+ }
448
+ static variable(dtype, name) {
449
+ return new AluExp(AluOp.Variable, dtype, [], name);
450
+ }
451
+ static globalIndex(dtype, gid, len, bufidx) {
452
+ return new AluExp(AluOp.GlobalIndex, dtype, [bufidx], [gid, len]);
453
+ }
454
+ static globalView(dtype, gid, st, indices) {
455
+ return new AluExp(AluOp.GlobalView, dtype, indices, [gid, st]);
456
+ }
457
+ static f32(value) {
458
+ return AluExp.const(DType.Float32, value);
459
+ }
460
+ static i32(value) {
461
+ return AluExp.const(DType.Int32, value);
462
+ }
463
+ static u32(value) {
464
+ return AluExp.const(DType.Uint32, value);
465
+ }
466
+ static bool(value) {
467
+ return AluExp.const(DType.Bool, Number(value));
468
+ }
469
+ static f16(value) {
470
+ return AluExp.const(DType.Float16, value);
471
+ }
472
+ not() {
473
+ if (this.dtype !== DType.Bool) throw new Error("not() can only be called on boolean expressions");
474
+ return AluExp.cmpne(this, AluExp.const(DType.Bool, true));
475
+ }
476
+ /** Compute a reasonable expression hash with low collision rate. */
477
+ getHash() {
478
+ if (this.#hash !== void 0) return this.#hash;
479
+ const hasher = new FpHash();
480
+ hasher.update(this.op);
481
+ hasher.update(this.dtype);
482
+ hasher.update(JSON.stringify(this.arg));
483
+ hasher.update(this.src.length);
484
+ for (const s of this.src) hasher.update(s);
485
+ this.#hash = hasher.value;
486
+ return this.#hash;
487
+ }
488
+ hash(state) {
489
+ state.update(this.getHash());
490
+ }
491
+ /** Substitute variables in this AluExp to values. */
492
+ substitute(variables) {
493
+ return this.rewrite((exp) => {
494
+ if (exp.op === AluOp.Variable && Object.hasOwn(variables, exp.arg)) {
495
+ if (exp.dtype !== variables[exp.arg].dtype) throw new Error(`Type mismatch: ${exp.dtype} vs ${variables[exp.arg].dtype}`);
496
+ return variables[exp.arg];
497
+ }
498
+ });
499
+ }
500
+ /** Reindex gid values in this expression as needed. */
501
+ reindexGids(gidMap) {
502
+ return this.rewrite((exp) => {
503
+ if (exp.op === AluOp.GlobalIndex) {
504
+ const [gid, len] = exp.arg;
505
+ const newGid = gidMap.get(gid);
506
+ if (newGid !== void 0 && newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
507
+ } else if (exp.op === AluOp.GlobalView) {
508
+ const gid = exp.arg[0];
509
+ const newGid = gidMap.get(gid);
510
+ if (newGid !== void 0 && newGid !== gid) return AluExp.globalView(exp.dtype, newGid, exp.arg[1], exp.src);
511
+ }
512
+ });
513
+ }
514
+ #computeRange() {
515
+ if (this.#range !== void 0) return this.#range;
516
+ const src = this.src;
517
+ const minMax4 = (f) => {
518
+ const [r1, r2] = [src[0].#computeRange(), src[1].#computeRange()];
519
+ const values = [
520
+ f(r1[0], r2[0]),
521
+ f(r1[0], r2[1]),
522
+ f(r1[1], r2[0]),
523
+ f(r1[1], r2[1])
524
+ ];
525
+ return [Math.min(...values), Math.max(...values)];
526
+ };
527
+ let ret;
528
+ switch (this.op) {
529
+ case AluOp.Add:
530
+ ret = [src[0].min + src[1].min, src[0].max + src[1].max];
531
+ break;
532
+ case AluOp.Sub:
533
+ ret = [src[0].min - src[1].max, src[0].max - src[1].min];
534
+ break;
535
+ case AluOp.Mul:
536
+ ret = minMax4((a, b) => a * b);
537
+ break;
538
+ case AluOp.Idiv:
539
+ ret = minMax4((a, b) => Math.trunc(a / b));
540
+ break;
541
+ case AluOp.Mod: {
542
+ let divisorRange = src[1].#computeRange();
543
+ if (divisorRange[0] <= 0 && divisorRange[1] >= 0) divisorRange = [0, Math.max(-divisorRange[0], divisorRange[1])];
544
+ if (divisorRange[1] < 0) divisorRange = [-divisorRange[1], -divisorRange[0]];
545
+ const maxDivisor = isFloatDtype(this.dtype) ? divisorRange[1] : divisorRange[1] - 1;
546
+ ret = [clamp(src[0].min, -maxDivisor, 0), clamp(src[0].max, 0, maxDivisor)];
547
+ break;
548
+ }
549
+ case AluOp.Min:
550
+ ret = [Math.min(src[0].min, src[1].min), Math.min(src[0].max, src[1].max)];
551
+ break;
552
+ case AluOp.Max:
553
+ ret = [Math.max(src[0].min, src[1].min), Math.max(src[0].max, src[1].max)];
554
+ break;
555
+ case AluOp.Sin:
556
+ ret = [-1, 1];
557
+ break;
558
+ case AluOp.Cos:
559
+ ret = [-1, 1];
560
+ break;
561
+ case AluOp.Asin:
562
+ ret = [-Math.PI / 2, Math.PI / 2];
563
+ break;
564
+ case AluOp.Atan:
565
+ ret = [-Math.PI / 2, Math.PI / 2];
566
+ break;
567
+ case AluOp.Exp:
568
+ ret = [Math.exp(src[0].min), Math.exp(src[0].max)];
569
+ break;
570
+ case AluOp.Log:
571
+ ret = [Math.log(src[0].min), Math.log(src[0].max)];
572
+ break;
573
+ case AluOp.Sqrt:
574
+ ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
575
+ break;
576
+ case AluOp.Reciprocal:
577
+ if (src[0].min <= 0 && src[0].max >= 0) return [-Infinity, Infinity];
578
+ ret = [1 / src[0].max, 1 / src[0].min];
579
+ break;
580
+ case AluOp.Cast: {
581
+ const wasFloat = isFloatDtype(src[0].dtype);
582
+ const bounded = Number.isFinite(src[0].min) && Number.isFinite(src[0].max);
583
+ if (this.dtype === DType.Bool) {
584
+ const canBeZero = src[0].min <= 0 && src[0].max >= 0;
585
+ const mustBeZero = src[0].min === 0 && src[0].max === 0;
586
+ ret = mustBeZero ? [0, 0] : canBeZero ? [0, 1] : [1, 1];
587
+ } else if (this.dtype === DType.Int32) {
588
+ const a = wasFloat ? clamp(src[0].min, -2147483648, 2147483647) | 0 : src[0].min | 0;
589
+ const b = wasFloat ? clamp(src[0].max, -2147483648, 2147483647) | 0 : src[0].max | 0;
590
+ ret = bounded && a <= b ? [a, b] : [-Infinity, Infinity];
591
+ } else if (this.dtype === DType.Uint32) {
592
+ const a = wasFloat ? clamp(src[0].min, 0, 4294967295) >>> 0 : src[0].min >>> 0;
593
+ const b = wasFloat ? clamp(src[0].max, 0, 4294967295) >>> 0 : src[0].max >>> 0;
594
+ ret = bounded && a <= b ? [a, b] : [0, Infinity];
595
+ } else ret = [src[0].min, src[0].max];
596
+ break;
597
+ }
598
+ case AluOp.Cmplt:
599
+ ret = [0, 1];
600
+ break;
601
+ case AluOp.Cmpne:
602
+ ret = [0, 1];
603
+ break;
604
+ case AluOp.Where:
605
+ ret = [Math.min(src[1].min, src[2].min), Math.max(src[1].max, src[2].max)];
606
+ break;
607
+ case AluOp.Const:
608
+ ret = [this.arg, this.arg];
609
+ break;
610
+ case AluOp.Special:
611
+ ret = [0, this.arg[1] - 1];
612
+ break;
613
+ default: ret = [-Infinity, Infinity];
614
+ }
615
+ if (isNaN(ret[0]) || isNaN(ret[1])) ret = [-Infinity, Infinity];
616
+ if (this.dtype === DType.Bool) {
617
+ ret[0] = clamp(ret[0], 0, 1);
618
+ ret[1] = clamp(ret[1], 0, 1);
619
+ }
620
+ if (this.dtype === DType.Uint32) ret[0] = Math.max(0, ret[0]);
621
+ this.#range = ret;
622
+ return ret;
623
+ }
624
+ get min() {
625
+ return this.#computeRange()[0];
626
+ }
627
+ get max() {
628
+ return this.#computeRange()[1];
629
+ }
630
+ /** Largest known integer that divides self. */
631
+ constFactor() {
632
+ if (this.op === AluOp.Const) return Math.abs(this.arg);
633
+ if (this.op === AluOp.Add) return gcd(this.src[0].constFactor(), this.src[1].constFactor());
634
+ if (this.op === AluOp.Mul) {
635
+ if (this.src[0].op === AluOp.Const) return Math.abs(this.src[0].arg);
636
+ if (this.src[1].op === AluOp.Const) return Math.abs(this.src[1].arg);
637
+ }
638
+ return 1;
639
+ }
640
+ /**
641
+ * Checks if divisible by an integer v and returns the quotient if it is, or
642
+ * `null` if it's not divisible.
643
+ */
644
+ divides(v) {
645
+ if (v === 1) return this;
646
+ if (this.op === AluOp.Const && this.arg % v === 0) return AluExp.const(this.dtype, this.arg / v);
647
+ if (this.op === AluOp.Add) {
648
+ const a = this.src[0].divides(v);
649
+ if (a !== null) {
650
+ const b = this.src[1].divides(v);
651
+ if (b !== null) return AluExp.add(a, b);
652
+ }
653
+ }
654
+ if (this.op === AluOp.Mul) {
655
+ const a = this.src[0].divides(v);
656
+ if (a !== null) return AluExp.mul(a, this.src[1]);
657
+ const b = this.src[1].divides(v);
658
+ if (b !== null) return AluExp.mul(this.src[0], b);
659
+ }
660
+ return null;
661
+ }
662
+ #isConstInt() {
663
+ return this.op === AluOp.Const && (this.dtype === DType.Int32 || this.dtype === DType.Uint32);
664
+ }
665
+ /**
666
+ * Get all expressions by deeply matching an operation.
667
+ *
668
+ * For example: `((2+(3*5))+4).splitOp(+) -> [2,(3*5),4]`.
669
+ */
670
+ *splitOp(sep) {
671
+ if (this.op === sep) for (const src of this.src) yield* src.splitOp(sep);
672
+ else yield this;
673
+ }
674
+ /**
675
+ * Simplify the expression by replacing any known patterns and deduping
676
+ * identical subexpressions.
677
+ */
678
+ simplify(cache = /* @__PURE__ */ new Map()) {
679
+ if (this.#simplified !== void 0) return this.#simplified;
680
+ const hash = this.getHash();
681
+ const prevCachedValue = cache.get(hash);
682
+ if (prevCachedValue !== void 0) return this.#simplified = prevCachedValue;
683
+ const simplified = this.#simplifyInner(cache);
684
+ const simplifiedHash = simplified.getHash();
685
+ const prevSimplified = cache.get(simplifiedHash);
686
+ if (prevSimplified !== void 0) {
687
+ cache.set(hash, prevSimplified);
688
+ this.#simplified = prevSimplified;
689
+ return prevSimplified;
690
+ } else {
691
+ cache.set(hash, simplified);
692
+ cache.set(simplifiedHash, simplified);
693
+ this.#simplified = simplified;
694
+ return simplified;
695
+ }
696
+ }
697
+ #simplifyInner(cache) {
698
+ const src = this.src.map((x) => x.simplify(cache));
699
+ const { op } = this;
700
+ if (src.every((x) => x.op === AluOp.Const) && !AluGroup.Variable.has(op)) {
701
+ const newExp$1 = new AluExp(op, this.dtype, src, this.arg);
702
+ return AluExp.const(this.dtype, newExp$1.evaluate({}));
703
+ }
704
+ if (op !== AluOp.Const && this.min === this.max) return AluExp.const(this.dtype, this.min);
705
+ if (AluGroup.Binary.has(op)) for (let i = 0; i < 2; i++) {
706
+ if (src[i].op !== AluOp.Const) continue;
707
+ const x = src[i].arg;
708
+ if (op === AluOp.Add && x === 0) return src[1 - i];
709
+ if (op === AluOp.Sub && i === 1 && x === 0) return src[1 - i];
710
+ if (op === AluOp.Mul && x === 1) return src[1 - i];
711
+ if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
712
+ if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
713
+ }
714
+ if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
715
+ const [a, b] = src[1].src;
716
+ const opNeg = op === AluOp.Add ? AluOp.Sub : AluOp.Add;
717
+ if (a.op === AluOp.Const && a.arg === -1) return new AluExp(opNeg, this.dtype, [src[0], b]);
718
+ else if (b.op === AluOp.Const && b.arg === -1) return new AluExp(opNeg, this.dtype, [src[0], a]);
719
+ }
720
+ if (op === AluOp.Where && src.slice(1).every((s, i) => s.op === AluOp.Const && s.arg === 1 - i)) return AluExp.cast(this.dtype, src[0]);
721
+ if (op === AluOp.Cmplt) {
722
+ if (src[0].min >= src[1].max) return AluExp.const(DType.Bool, false);
723
+ if (src[0].max < src[1].min) return AluExp.const(DType.Bool, true);
724
+ }
725
+ if (op === AluOp.Cmpne) {
726
+ if (src[0].max < src[1].min || src[0].min > src[1].max) return AluExp.const(DType.Bool, true);
727
+ }
728
+ if (op === AluOp.Where) {
729
+ if (src[0].max === 0) return src[2];
730
+ if (src[0].min === 1) return src[1];
731
+ }
732
+ if (op === AluOp.Mod && src[1].op === AluOp.Const && src[0].min >= 0 && src[0].max < src[1].arg) return src[0];
733
+ if (op === AluOp.Mod && src[0].op === AluOp.Mod && src[1].#isConstInt() && src[0].src[1].#isConstInt()) {
734
+ const A = src[0].src[1].arg;
735
+ const B = src[1].arg;
736
+ if (A > 0 && B > 0 && (A % B === 0 || B % A === 0)) return AluExp.mod(src[0].src[0], AluExp.const(this.dtype, Math.min(A, B))).simplify();
737
+ }
738
+ if (op === AluOp.Add && src[0].op === AluOp.Mul && src[0].src[1].#isConstInt() && src[1].op === AluOp.Mod && src[1].src[1].#isConstInt() && src[0].src[1].arg === src[1].src[1].arg) {
739
+ const [mul, mod] = src;
740
+ const check = (exp) => {
741
+ return exp.op === AluOp.Idiv && exp.src[1].#isConstInt() && exp.src[1].arg === mod.src[1].arg && exp.src[0] === mod.src[0];
742
+ };
743
+ if (check(mul.src[0])) return mod.src[0];
744
+ if (mul.src[0].op === AluOp.Mod) {
745
+ const [x, y] = mul.src[0].src;
746
+ if (check(x)) return AluExp.mod(mod.src[0], AluExp.mul(mod.src[1], y)).simplify(cache);
747
+ }
748
+ }
749
+ if (op === AluOp.Idiv && src[1].#isConstInt()) {
750
+ const [numer, denom] = src;
751
+ const B = denom.arg;
752
+ for (let i = 0; i < 2; i++) {
753
+ if (numer.op === AluOp.Mul && numer.src[i].#isConstInt()) {
754
+ const A = numer.src[i].arg;
755
+ if (A % B === 0) {
756
+ let ret = numer.src[1 - i];
757
+ if (A / B !== 1) ret = AluExp.mul(ret, AluExp.const(ret.dtype, A / B));
758
+ return ret.simplify(cache);
759
+ }
760
+ }
761
+ for (let j = 0; j < 2; j++) if (numer.op === AluOp.Add && numer.src[j].op === AluOp.Mul && numer.src[j].src[i].#isConstInt()) {
762
+ const A = numer.src[j].src[i].arg;
763
+ if (A % B === 0) {
764
+ let ret = numer.src[j].src[1 - i];
765
+ if (A / B !== 1) ret = AluExp.mul(ret, AluExp.const(ret.dtype, A / B));
766
+ ret = AluExp.add(ret, AluExp.idiv(numer.src[1 - j], AluExp.const(ret.dtype, B)));
767
+ return ret.simplify(cache);
768
+ }
769
+ }
770
+ }
771
+ }
772
+ if (op === AluOp.Mod && src[1].#isConstInt() && src[1].arg > 0 && src[0].min >= 0) {
773
+ const [numer, denom] = src;
774
+ const B = denom.arg;
775
+ for (let i = 0; i < 2; i++) if (numer.op === AluOp.Add) {
776
+ if (numer.src[i].#isConstInt()) {
777
+ const A = numer.src[i].arg;
778
+ const x = numer.src[1 - i];
779
+ if (A % B === 0 && x.min >= 0) return AluExp.mod(x, denom).simplify(cache);
780
+ }
781
+ for (let j = 0; j < 2; j++) if (numer.src[i].op === AluOp.Mul && numer.src[i].src[j].#isConstInt()) {
782
+ const A = numer.src[i].src[j].arg;
783
+ const x = numer.src[1 - i];
784
+ if (A % B === 0 && x.min >= 0) return AluExp.mod(x, denom).simplify(cache);
785
+ }
786
+ } else if (numer.op === AluOp.Mul) {
787
+ if (numer.src[i].#isConstInt()) {
788
+ const A = numer.src[i].arg;
789
+ if (A % B === 0) return AluExp.const(this.dtype, 0);
790
+ if (A % B === 1) return AluExp.mod(numer.src[1 - i], denom).simplify(cache);
791
+ }
792
+ }
793
+ }
794
+ const commOps = [
795
+ AluOp.Add,
796
+ AluOp.Mul,
797
+ AluOp.Max,
798
+ AluOp.Min
799
+ ];
800
+ if (commOps.includes(op)) {
801
+ const p = (a, b) => new AluExp(op, this.dtype, [a, b]);
802
+ if (src[0].op === AluOp.Const) return p(src[1], src[0]).simplify(cache);
803
+ if (src[0].op === op && src[0].src[1].op === AluOp.Const) if (src[1].op === AluOp.Const) return p(src[0].src[0], p(src[0].src[1], src[1])).simplify(cache);
804
+ else return p(p(src[0].src[0], src[1]), src[0].src[1]).simplify(cache);
805
+ if (src[1].op === op && src[1].src[1].op === AluOp.Const) return p(p(src[0], src[1].src[0]), src[1].src[1]).simplify(cache);
806
+ }
807
+ if (op === AluOp.Mod || op === AluOp.Idiv && src[1].#isConstInt()) {
808
+ const [x, y] = src;
809
+ {
810
+ const factors = [];
811
+ const terms = [];
812
+ for (const u of x.splitOp(AluOp.Add)) {
813
+ const factor = u.constFactor();
814
+ factors.push(factor);
815
+ terms.push(u.divides(factor));
816
+ }
817
+ const g = gcd(y.arg, ...factors);
818
+ if (g !== 1) {
819
+ let ret = new AluExp(op, this.dtype, [factors.map((f, i) => AluExp.mul(AluExp.const(terms[i].dtype, f / g), terms[i])).reduceRight((a, x$1) => AluExp.add(x$1, a)), AluExp.const(y.dtype, y.arg / g)]);
820
+ if (op === AluOp.Mod) ret = AluExp.mul(ret, AluExp.const(this.dtype, g));
821
+ return ret.simplify(cache);
822
+ }
823
+ }
824
+ if (y.arg > 0) {
825
+ let [xNoConst, constVal] = [x, 0];
826
+ if (x.op === AluOp.Add && x.src[1].op === AluOp.Const) [xNoConst, constVal] = [x.src[0], x.src[1].arg];
827
+ const terms = [];
828
+ const factors = [];
829
+ for (const u of xNoConst.splitOp(AluOp.Add)) {
830
+ const f = u.constFactor();
831
+ const divided = u.divides(f);
832
+ terms.push(divided ?? u);
833
+ factors.push(divided ? f : 1);
834
+ }
835
+ const quotients = factors.map((f) => Math.floor(f / y.arg));
836
+ const remainders = factors.map((f) => f % y.arg);
837
+ const gcdVal = remainders.reduce((g, r) => gcd(g, r), y.arg);
838
+ if (constVal % y.arg !== constVal || gcdVal !== 1 || remainders.some((r, i) => r === 0 || r !== factors[i] && op === AluOp.Mod)) {
839
+ let quo = AluExp.const(x.dtype, Math.floor(constVal / y.arg));
840
+ let rem = AluExp.const(x.dtype, Math.floor(constVal % y.arg / gcdVal));
841
+ for (let i = 0; i < terms.length; i++) if (op === AluOp.Idiv && remainders[i] !== 0) rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(factors[i] / gcdVal)), terms[i]));
842
+ else {
843
+ rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(remainders[i] / gcdVal)), terms[i]));
844
+ quo = AluExp.add(quo, AluExp.mul(AluExp.const(x.dtype, quotients[i]), terms[i]));
845
+ }
846
+ if (!((x.min < 0 || rem.min < 0) && remainders.some((r) => r !== 0))) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
847
+ else return AluExp.add(AluExp.idiv(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal))), quo).simplify(cache);
848
+ }
849
+ }
850
+ }
851
+ const newExp = src.every((s, i) => s === this.src[i]) ? this : new AluExp(op, this.dtype, src, this.arg);
852
+ return newExp;
853
+ }
854
+ /** Resolve this to a value, or `undefined` if not possible. */
855
+ resolve() {
856
+ const x = this.simplify();
857
+ if (x.op === AluOp.Const) return x.arg;
858
+ return void 0;
859
+ }
860
+ /**
861
+ * Evaluate the expression on CPU, returning the result.
862
+ *
863
+ * Typically you would compile the AluExp as a representation to a lower-level
864
+ * language. This is just to define the semantics and help debug.
865
+ *
866
+ * Note that the representation of Bool is as a number (0 or 1) here.
867
+ */
868
+ evaluate(context, globals) {
869
+ if (AluGroup.Binary.has(this.op) || AluGroup.Compare.has(this.op)) {
870
+ const x = this.src[0].evaluate(context, globals);
871
+ const y = this.src[1].evaluate(context, globals);
872
+ switch (this.op) {
873
+ case AluOp.Add: return this.dtype === DType.Bool ? Number(x || y) : x + y;
874
+ case AluOp.Sub: return x - y;
875
+ case AluOp.Mul: return this.dtype === DType.Bool ? Number(x && y) : x * y;
876
+ case AluOp.Idiv: return Math.trunc(x / y);
877
+ case AluOp.Mod: return x % y;
878
+ case AluOp.Min: return Math.min(x, y);
879
+ case AluOp.Max: return Math.max(x, y);
880
+ case AluOp.Cmplt: return Number(x < y);
881
+ case AluOp.Cmpne: return Number(x != y);
882
+ default: throw new Error(`Missing implemementation for ${this.op}`);
883
+ }
884
+ }
885
+ if (AluGroup.Unary.has(this.op)) {
886
+ const x = this.src[0].evaluate(context, globals);
887
+ switch (this.op) {
888
+ case AluOp.Sin: return Math.sin(x);
889
+ case AluOp.Cos: return Math.cos(x);
890
+ case AluOp.Asin: return Math.asin(x);
891
+ case AluOp.Atan: return Math.atan(x);
892
+ case AluOp.Exp: return Math.exp(x);
893
+ case AluOp.Log: return Math.log(x);
894
+ case AluOp.Sqrt: return Math.sqrt(x);
895
+ case AluOp.Reciprocal: return 1 / x;
896
+ case AluOp.Cast: {
897
+ const wasFloat = isFloatDtype(this.src[0].dtype);
898
+ if (this.dtype === DType.Int32) return (wasFloat ? clamp(x, -2147483648, 2147483647) : x) | 0;
899
+ else if (this.dtype === DType.Uint32) return (wasFloat ? clamp(x, 0, 4294967295) : x) >>> 0;
900
+ else if (isFloatDtype(this.dtype)) return x;
901
+ else if (this.dtype === DType.Bool) return Number(Boolean(x));
902
+ else throw new Error(`Unsupported cast to ${this.dtype}`);
903
+ }
904
+ case AluOp.Bitcast: {
905
+ const buf = new ArrayBuffer(byteWidth(this.dtype));
906
+ const view = new DataView(buf);
907
+ const fromType = this.src[0].dtype;
908
+ if (fromType === DType.Float32) view.setFloat32(0, x, true);
909
+ else if (fromType === DType.Int32) view.setInt32(0, x, true);
910
+ else if (fromType === DType.Uint32) view.setUint32(0, x, true);
911
+ else if (fromType === DType.Float16) view.setFloat16(0, x, true);
912
+ else throw new Error(`Unsupported bitcast from ${fromType}`);
913
+ if (this.dtype === DType.Float32) return view.getFloat32(0, true);
914
+ else if (this.dtype === DType.Int32) return view.getInt32(0, true);
915
+ else if (this.dtype === DType.Uint32) return view.getUint32(0, true);
916
+ else if (this.dtype === DType.Float16) return view.getFloat16(0, true);
917
+ else throw new Error(`Unsupported bitcast to ${this.dtype}`);
918
+ }
919
+ default: throw new Error(`Missing implemementation for ${this.op}`);
920
+ }
921
+ }
922
+ switch (this.op) {
923
+ case AluOp.Where: return this.src[0].evaluate(context, globals) ? this.src[1].evaluate(context, globals) : this.src[2].evaluate(context, globals);
924
+ case AluOp.Threefry2x32: {
925
+ const [k0, k1, c0, c1] = this.src.map((x) => x.evaluate(context, globals));
926
+ const [x0, x1] = threefry2x32(k0, k1, c0, c1);
927
+ if (this.arg === "xor") return (x0 ^ x1) >>> 0;
928
+ else if (this.arg === 0) return x0;
929
+ else if (this.arg === 1) return x1;
930
+ else throw new Error(`Invalid Threefry2x32 mode: ${this.arg}`);
931
+ }
932
+ case AluOp.Const: return this.arg;
933
+ case AluOp.Special: {
934
+ const x = context[this.arg[0]];
935
+ if (x === void 0) throw new Error(`Missing special: ${this.arg[0]}`);
936
+ return x;
937
+ }
938
+ case AluOp.Variable: {
939
+ const x = context[this.arg];
940
+ if (x === void 0) throw new Error(`Missing variable: ${this.arg}`);
941
+ return x;
942
+ }
943
+ case AluOp.GlobalIndex: {
944
+ if (!globals) throw new Error("Missing globals function");
945
+ const gid = this.arg[0];
946
+ const bufidx = this.src[0].evaluate(context, globals);
947
+ return globals(gid, bufidx);
948
+ }
949
+ case AluOp.GlobalView: {
950
+ if (!globals) throw new Error("Missing globals function");
951
+ const gid = this.arg[0];
952
+ const st = this.arg[1];
953
+ const [iexpr, vexpr] = st.toAluExp(this.src);
954
+ if (vexpr.evaluate(context, globals)) {
955
+ const bufidx = iexpr.evaluate(context, globals);
956
+ return globals(gid, bufidx);
957
+ } else return 0;
958
+ }
959
+ default: throw new Error(`Missing implemementation for ${this.op}`);
960
+ }
961
+ }
962
+ /** Get this expression in debug format as a string. */
963
+ toString() {
964
+ const BIN_SYM = {
965
+ [AluOp.Add]: "+",
966
+ [AluOp.Sub]: "-",
967
+ [AluOp.Mul]: "*",
968
+ [AluOp.Idiv]: "/",
969
+ [AluOp.Mod]: "%"
970
+ };
971
+ const CMP_SYM = {
972
+ [AluOp.Cmplt]: "<",
973
+ [AluOp.Cmpne]: "!="
974
+ };
975
+ const UNARY_SYM = { [AluOp.Reciprocal]: "1/" };
976
+ return this.fold((node, parts) => {
977
+ switch (node.op) {
978
+ case AluOp.Const: return "" + (node.dtype === DType.Bool ? Boolean(node.arg) : node.arg);
979
+ case AluOp.Variable: return `$${node.arg}:${node.dtype}`;
980
+ case AluOp.Special: {
981
+ const [name, n] = node.arg;
982
+ return `#${name}{${n}}`;
983
+ }
984
+ case AluOp.GlobalIndex: return `G_${node.arg[0]}<${node.dtype}>[${strip1(parts[0])}]`;
985
+ case AluOp.GlobalView: {
986
+ const [gid, st] = node.arg;
987
+ const shape = st.shape.join(",");
988
+ const lastStrides = st.lastStrides.join(",");
989
+ const cont = st.contiguous ? "c" : "nc";
990
+ return `GV_${gid}<${node.dtype}>{${shape}:${lastStrides}:${cont}}[${parts.map(strip1).join(", ")}]`;
991
+ }
992
+ }
993
+ if (BIN_SYM[node.op]) return `(${parts[0]} ${BIN_SYM[node.op]} ${parts[1]})`;
994
+ if (CMP_SYM[node.op]) return `(${parts[0]} ${CMP_SYM[node.op]} ${parts[1]})`;
995
+ if (UNARY_SYM[node.op]) return `${UNARY_SYM[node.op]}${parts[0]}`;
996
+ if (node.op === AluOp.Cast) return `Cast<${node.dtype}>(${strip1(parts[0])})`;
997
+ if (node.op === AluOp.Bitcast) return `Bitcast<${node.dtype}>(${strip1(parts[0])})`;
998
+ return `${node.op}(${parts.map(strip1).join(", ")})`;
999
+ });
1000
+ }
1001
+ /** Generic fold() operation with a reducer over the expression tree. */
1002
+ fold(reducer) {
1003
+ const visited = /* @__PURE__ */ new Map();
1004
+ const recurse = (exp) => {
1005
+ if (visited.has(exp)) return visited.get(exp);
1006
+ const mappedSrc = exp.src.map((s) => recurse(s));
1007
+ const result = reducer(exp, mappedSrc);
1008
+ visited.set(exp, result);
1009
+ return result;
1010
+ };
1011
+ return recurse(this);
1012
+ }
1013
+ /** Check if any expression in the tree satisfies a predicate. */
1014
+ some(predicate) {
1015
+ const visited = /* @__PURE__ */ new Set();
1016
+ const recurse = (exp) => {
1017
+ if (visited.has(exp)) return false;
1018
+ if (predicate(exp)) return true;
1019
+ visited.add(exp);
1020
+ return exp.src.some(recurse);
1021
+ };
1022
+ return recurse(this);
1023
+ }
1024
+ /** Rewrite the expression recursively using a visitor. */
1025
+ rewrite(visitor) {
1026
+ return this.fold((exp, newSrc) => {
1027
+ if (newSrc.length === exp.src.length && newSrc.every((s, i) => s === exp.src[i])) return visitor(exp) ?? exp;
1028
+ else {
1029
+ const newExp = new AluExp(exp.op, exp.dtype, newSrc, exp.arg);
1030
+ return visitor(newExp) ?? newExp;
1031
+ }
1032
+ });
1033
+ }
1034
+ /** Collect all nodes that satisfy a predicate. */
1035
+ collect(predicate) {
1036
+ const result = [];
1037
+ this.fold((exp) => {
1038
+ if (predicate(exp)) result.push(exp);
1039
+ });
1040
+ return result;
1041
+ }
1042
+ /** Produce a list of all distinct AluOp in this expression. */
1043
+ distinctOps() {
1044
+ const ops = /* @__PURE__ */ new Set();
1045
+ this.fold((exp) => {
1046
+ ops.add(exp.op);
1047
+ });
1048
+ return ops;
1049
+ }
1050
+ /** Rewrite GlobalView operations to GlobalIndex operations. */
1051
+ rewriteGlobalViews() {
1052
+ return this.rewrite((exp) => {
1053
+ if (exp.op === AluOp.GlobalView) {
1054
+ const [gid, st] = exp.arg;
1055
+ return accessorGlobal(exp.dtype, gid, st, exp.src);
1056
+ }
1057
+ });
1058
+ }
1059
+ };
1060
/**
 * Symbolic name for each mathematical operation.
 *
 * Every key maps to a string equal to the key itself.
 */
let AluOp = /* @__PURE__ */ Object.fromEntries(
  [
    // Two-operand arithmetic
    "Add",
    "Sub",
    "Mul",
    "Idiv",
    "Mod",
    "Min",
    "Max",
    // Single-operand math
    "Sin",
    "Cos",
    "Asin",
    "Atan",
    "Exp",
    "Log",
    "Sqrt",
    "Reciprocal",
    // Type conversion
    "Cast",
    "Bitcast",
    // Comparison and selection
    "Cmplt",
    "Cmpne",
    "Where",
    "Threefry2x32",
    // Leaves and memory access
    "Const",
    "Special",
    "Variable",
    "GlobalIndex",
    "GlobalView",
  ].map((name) => [name, name]),
);
1090
/** Sets of AluOp values grouped by how the operations are used. */
const AluGroup = {
  // Operations evaluated on two operands.
  Binary: new Set([AluOp.Add, AluOp.Sub, AluOp.Mul, AluOp.Idiv, AluOp.Mod, AluOp.Min, AluOp.Max]),
  // Operations evaluated on a single operand.
  Unary: new Set([
    AluOp.Sin,
    AluOp.Cos,
    AluOp.Asin,
    AluOp.Atan,
    AluOp.Exp,
    AluOp.Log,
    AluOp.Sqrt,
    AluOp.Reciprocal,
    AluOp.Cast,
    AluOp.Bitcast,
  ]),
  // Two-operand comparisons.
  Compare: new Set([AluOp.Cmplt, AluOp.Cmpne]),
  // Operations that read an externally supplied value or buffer.
  Variable: new Set([AluOp.Special, AluOp.Variable, AluOp.GlobalIndex, AluOp.GlobalView]),
  // Operations accepted by `Reduction`.
  Reduce: new Set([AluOp.Add, AluOp.Mul, AluOp.Min, AluOp.Max]),
  // NOTE(review): presumably the operations that only accept floating-point
  // operands; confirm against the code that consults this set.
  RequiredFloat: new Set([
    AluOp.Sin,
    AluOp.Cos,
    AluOp.Asin,
    AluOp.Atan,
    AluOp.Exp,
    AluOp.Log,
    AluOp.Sqrt,
    AluOp.Reciprocal,
  ]),
};
1136
/**
 * Frequently used variables that can be substituted in expressions.
 *
 * All entries are Int32 variables named after their key, except `acc`, which
 * is a function building a variable named "acc" of the requested dtype.
 */
const AluVar = (() => {
  const int32Var = (name) => AluExp.variable(DType.Int32, name);
  return {
    gidx: int32Var("gidx"),
    ridx: int32Var("ridx"),
    acc: (dtype) => AluExp.variable(dtype, "acc"),
    idx: int32Var("idx"),
    unroll: int32Var("unroll"),
    upcast: int32Var("upcast"),
  };
})();
1145
/**
 * Description of a kernel to be compiled.
 *
 * A backend turns each kernel into some lower-level representation. A kernel
 * consists of one or more fused operations, optionally indexing into a
 * buffer, and optionally followed by a reduction.
 */
var Kernel = class {
  /**
   * @param nargs Stored as `nargs` and included in the hash.
   * @param size Number of output elements (see `bytes`).
   * @param exp Expression for the kernel; it is stored in simplified form.
   * @param reduction Optional reduction description.
   */
  constructor(nargs, size, exp, reduction) {
    this.nargs = nargs;
    this.size = size;
    this.exp = exp.simplify();
    this.reduction = reduction;
  }

  /** Feed every field of this kernel into the given hash state. */
  hash(state) {
    state
      .update(this.nargs)
      .update(this.size)
      .update(this.exp)
      .update(this.reduction);
  }

  /** Pretty-print as `{ exp = ..., size = ..., reduction = ... }`, one field per line. */
  pprint() {
    const fields = [`exp = ${this.exp}`, `size = ${this.size}`];
    if (this.reduction) fields.push(`reduction = ${this.reduction}`);
    const details = fields
      .map((text) => PPrint.pp(text))
      .reduce((block, item) => block.concat(item));
    return PPrint.pp("{ ").stack(details).stack(PPrint.pp(" }"));
  }

  toString() {
    return this.pprint().toString();
  }

  /** The dtype of the values output by this kernel. */
  get dtype() {
    // With a reduction, the epilogue determines the output dtype.
    return this.reduction ? this.reduction.epilogue.dtype : this.exp.dtype;
  }

  /** The number of bytes in the output array when evaluating this kernel. */
  get bytes() {
    return this.size * byteWidth(this.dtype);
  }
};
1182
/**
 * Description of a reduction.
 *
 * The strategy of jax-js backends is to either handle a standard operation that
 * is dispatched in a vectorized way over an array, or to reduce over one axis
 * of some computation. This is a description of the reduction.
 *
 * Reduction only supports a few operations, and only over one axis. Users can
 * always `flatten()` the array before reducing if needed.
 *
 * The backend is responsible for implementing the reduction in a way that
 * minimizes the number of global memory loads, for efficiency. This involves
 * passing through some optimization strategy. But optimizations are not coded
 * at this level since they depend on GPU, versus CPU or Wasm.
 */
var Reduction = class {
  /**
   * @param dtype Data type of the values being reduced.
   * @param op Reduction operation; must be a member of `AluGroup.Reduce`.
   * @param size Stored as `size` and printed by `toString()`.
   * @param epilogue Expression stored in simplified form; defaults to the
   *   "acc" variable of `dtype`.
   * @throws {TypeError} If `op` is not a supported reduction.
   */
  constructor(dtype, op, size, epilogue = AluVar.acc(dtype)) {
    this.dtype = dtype;
    this.op = op;
    this.size = size;
    this.epilogue = epilogue;
    if (!AluGroup.Reduce.has(op)) throw new TypeError(`Unsupported reduction: ${op}`);
    this.epilogue = epilogue.simplify();
  }

  /** Feed every field of this reduction into the given hash state. */
  hash(state) {
    state.update(this.dtype).update(this.op).update(this.size).update(this.epilogue);
  }

  toString() {
    return `${this.op}{${this.size}} -> ${this.epilogue}`;
  }

  /**
   * Get the identity for this reduction operation.
   *
   * @throws {TypeError} If the op/dtype combination is not supported.
   */
  get identity() {
    if (this.dtype === DType.Bool) {
      // Add/Max behave as OR (identity 0); Mul/Min behave as AND (identity 1).
      return this.op === AluOp.Add || this.op === AluOp.Max ? 0 : 1;
    } else if (this.dtype === DType.Int32) {
      if (this.op === AluOp.Add) return 0;
      else if (this.op === AluOp.Mul) return 1;
      else if (this.op === AluOp.Min) return -1 >>> 1; // 2147483647
      else if (this.op === AluOp.Max) return 1 << 31; // -2147483648
    } else if (this.dtype === DType.Uint32) {
      if (this.op === AluOp.Add) return 0;
      else if (this.op === AluOp.Mul) return 1;
      else if (this.op === AluOp.Min) return -1 >>> 0; // 4294967295
      else if (this.op === AluOp.Max) return 0;
    } else if (isFloatDtype(this.dtype)) {
      if (this.op === AluOp.Add) return 0;
      else if (this.op === AluOp.Mul) return 1;
      else if (this.op === AluOp.Min) return Infinity;
      else if (this.op === AluOp.Max) return -Infinity;
    }
    throw new TypeError(`Unsupported reduction: ${this.op} ${this.dtype}`);
  }

  /**
   * Evaluate this operation on CPU by folding all `values` together.
   *
   * Each fold starts from the operation's identity, so an all-false Bool
   * input reduces to a falsy result under Add/Max (logical OR).
   *
   * @throws {TypeError} If the op/dtype combination is not supported.
   */
  evaluate(...values) {
    if (this.dtype === DType.Bool) {
      // OR must start from false: starting from true would make the result
      // truthy regardless of the inputs.
      if (this.op === AluOp.Add || this.op === AluOp.Max) return values.reduce((a, b) => a || b, false);
      else if (this.op === AluOp.Mul || this.op === AluOp.Min) return values.reduce((a, b) => a && b, true);
    } else if (this.dtype === DType.Int32) {
      // `| 0` wraps intermediate results to signed 32 bits.
      if (this.op === AluOp.Add) return values.reduce((a, b) => a + b | 0, 0);
      else if (this.op === AluOp.Mul) return values.reduce((a, b) => a * b | 0, 1);
      else if (this.op === AluOp.Min) return values.reduce((a, b) => Math.min(a, b), -1 >>> 1);
      else if (this.op === AluOp.Max) return values.reduce((a, b) => Math.max(a, b), 1 << 31);
    } else if (this.dtype === DType.Uint32) {
      // `>>> 0` wraps intermediate results to unsigned 32 bits.
      if (this.op === AluOp.Add) return values.reduce((a, b) => a + b >>> 0, 0);
      else if (this.op === AluOp.Mul) return values.reduce((a, b) => a * b >>> 0, 1);
      else if (this.op === AluOp.Min) return values.reduce((a, b) => Math.min(a, b), -1 >>> 0);
      else if (this.op === AluOp.Max) return values.reduce((a, b) => Math.max(a, b), 0);
    } else if (isFloatDtype(this.dtype)) {
      if (this.op === AluOp.Add) return values.reduce((a, b) => a + b, 0);
      else if (this.op === AluOp.Mul) return values.reduce((a, b) => a * b, 1);
      else if (this.op === AluOp.Min) return values.reduce((a, b) => Math.min(a, b), Infinity);
      else if (this.op === AluOp.Max) return values.reduce((a, b) => Math.max(a, b), -Infinity);
    }
    throw new TypeError(`Unsupported reduction: ${this.op} ${this.dtype}`);
  }
};
1257
/**
 * Build the expression that reads global buffer `gid` at `indices` through the
 * shape tracker `st`. Positions the tracker reports as invalid yield a
 * constant 0 of `dtype`.
 */
function accessorGlobal(dtype, gid, st, indices) {
  const [bufferIndex, isValid] = st.toAluExp(indices);
  // The upper bound of the first view's data range is passed as the length.
  const [, bufferLength] = st.views[0].dataRange();
  const load = AluExp.globalIndex(dtype, gid, bufferLength, bufferIndex);
  const zero = AluExp.const(dtype, 0);
  return AluExp.where(isValid, load, zero);
}
1263
/**
 * Build the expression that evaluates the array recipe `exp` at `indices`
 * through the shape tracker `st`, by substituting the computed index for the
 * variable "idx". Positions the tracker reports as invalid yield a constant 0.
 */
function accessorAluExp(exp, st, indices) {
  const [flatIndex, isValid] = st.toAluExp(indices);
  const substituted = exp.substitute({ idx: flatIndex });
  const zero = AluExp.const(exp.dtype, 0);
  return AluExp.where(isValid, substituted, zero);
}
1268
/**
 * Mix two 32-bit key words and two 32-bit counter words into two output words.
 *
 * The computation is five groups of four add-rotate-xor rounds; after each
 * group a key-schedule word and the group number are added in.
 *
 * @param k0 First key word.
 * @param k1 Second key word.
 * @param c0 First counter word.
 * @param c1 Second counter word.
 * @returns `[x0, x1]`, each an unsigned 32-bit integer.
 */
function threefry2x32(k0, k1, c0, c1) {
  const rotateLeft = (value, bits) => ((value << bits) | (value >>> (32 - bits))) >>> 0;

  // Key schedule: both key words plus a third word derived from them.
  const keys = [k0 >>> 0, k1 >>> 0];
  keys.push((keys[0] ^ keys[1] ^ 466688986) >>> 0);

  // Rotation amounts alternate between these two sets from group to group.
  const rotations = [
    [13, 15, 26, 6],
    [17, 29, 16, 24],
  ];

  let x0 = (c0 + keys[0]) >>> 0;
  let x1 = (c1 + keys[1]) >>> 0;
  for (let group = 0; group < 5; group++) {
    for (const bits of rotations[group % 2]) {
      x0 = (x0 + x1) >>> 0;
      x1 = rotateLeft(x1, bits) ^ x0;
    }
    // Key injection: the schedule rotates through the three key words.
    x0 = (x0 + keys[(group + 1) % 3]) >>> 0;
    x1 = (x1 + keys[(group + 2) % 3] + (group + 1)) >>> 0;
  }
  return [x0, x1];
}
1307
+
1308
+ //#endregion
1309
+ //#region src/shape.ts
1310
+ const jstr = JSON.stringify; // Shorthand used below to format shapes and arguments in error messages.
1311
/**
 * Return a copy of `strides` in which every dimension of size 1 has stride 0.
 * Other dimensions keep their stride; the input array is not modified.
 */
function canonicalizeStrides(shape, strides) {
  return shape.map((dim, axis) => (dim === 1 ? 0 : strides[axis]));
}
1318
/**
 * Compute row-major strides for `shape`: the last axis has stride 1 and each
 * earlier axis has the product of all later sizes. The result is passed
 * through `canonicalizeStrides()`.
 */
function defaultStrides(shape) {
  const ndim = shape.length;
  if (ndim === 0) return [];
  const strides = rep(ndim, 1);
  // Walk from the innermost axis outward, accumulating the running product.
  let running = 1;
  for (let axis = ndim - 1; axis >= 1; axis--) {
    running *= shape[axis];
    strides[axis - 1] = running;
  }
  return canonicalizeStrides(shape, strides);
}
1325
/**
 * Merge contiguous subparts or zero-strided dimensions in a view.
 *
 * @param shape Size of each dimension.
 * @param strides Stride of each dimension; must have the same length as `shape`.
 * @param mask Optional `[begin, end]` pair per dimension.
 * @returns A list of `[size, stride, sizeBeforeExpand]` triples, one per merged
 *   group of dimensions, or `[]` for an empty shape.
 * @throws {Error} If the argument lengths disagree.
 */
function mergeDims(shape, strides, mask) {
  if (shape.length === 0) return [];
  if (shape.length !== strides.length || (mask && shape.length !== mask.length)) {
    throw new Error("internal: invalid args to mergeDims");
  }

  // A dimension whose mask spans exactly one index merges into the next one.
  const hasUnitMask = (axis) => mask[axis][1] - mask[axis][0] === 1;

  const groups = [[shape[0], strides[0], strides[0] !== 0 ? shape[0] : 0]];
  let mergeIntoNext = mask ? hasUnitMask(0) : shape[0] === 1;

  for (let axis = 1; axis < shape.length; axis++) {
    const dim = shape[axis];
    const stride = strides[axis];
    // Dimensions of size 1 never start or extend a group.
    if (dim === 1) continue;

    const lastIndex = groups.length - 1;
    const [groupDim, groupStride, groupRealDim] = groups[lastIndex];
    const stridesLineUp = groupStride === dim * stride;
    if (mergeIntoNext || stridesLineUp) {
      groups[lastIndex] = [groupDim * dim, stride, mergeIntoNext ? dim : groupRealDim * dim];
    } else {
      groups.push([dim, stride, dim]);
    }
    mergeIntoNext = mask ? hasUnitMask(axis) : false;
  }
  return groups;
}
1353
/**
 * Return the new mask if a reshape is possible, otherwise `null`.
 *
 * Old and new shapes are consumed from the last axis toward the first. A
 * group of dimensions is either copied directly, split (when one old
 * dimension covers several new ones), or merged (when several old dimensions
 * cover one new one).
 *
 * @param maskInput `[begin, end]` pair per old dimension.
 * @param oldShape Shape the mask currently applies to.
 * @param newShape Shape to express the mask in.
 * @returns One `[begin, end]` pair per new dimension, or `null`.
 */
function reshapeMask(maskInput, oldShape, newShape) {
  // Cursors that walk each list backwards; once exhausted they yield a
  // neutral element (size 1, mask [0, 1]).
  let maskCursor = maskInput.length;
  let oldCursor = oldShape.length;
  let newCursor = newShape.length;
  const takeMask = () => (maskCursor ? maskInput[--maskCursor] : [0, 1]);
  const takeOldDim = () => (oldCursor ? oldShape[--oldCursor] : 1);
  const takeNewDim = () => (newCursor ? newShape[--newCursor] : 1);

  const reversedMask = [];
  let stride = 1;
  let oldDim = takeOldDim();
  let newDim = takeNewDim();
  let mask = takeMask();

  while (reversedMask.length < newShape.length) {
    const [lo, hi] = mask;
    const nextStride = newDim * stride;

    if (oldDim === nextStride) {
      // Sizes line up: emit the mask and start on the next group.
      reversedMask.push([intdiv(lo, stride), intdiv(hi - 1, stride) + 1]);
      stride = 1;
      oldDim = takeOldDim();
      newDim = takeNewDim();
      mask = takeMask();
    } else if (oldDim > nextStride) {
      // Splitting an old dimension: only allowed when the new dimension
      // divides it and the mask does not straddle a split boundary.
      if (oldDim % nextStride !== 0) return null;
      const unaligned = lo % nextStride !== 0 || hi % nextStride !== 0;
      if (unaligned && intdiv(lo, nextStride) !== intdiv(hi - 1, nextStride)) return null;
      reversedMask.push([
        intdiv(lo % nextStride, stride),
        intdiv((hi - 1) % nextStride, stride) + 1,
      ]);
      stride = nextStride;
      newDim = takeNewDim();
    } else {
      // Merging old dimensions: fold the next outer mask into the current one.
      const outerMask = takeMask();
      if (!deepEqual(mask, [0, oldDim]) && lo !== hi && outerMask[1] - outerMask[0] !== 1) {
        return null;
      }
      mask = [outerMask[0] * oldDim + lo, (outerMask[1] - 1) * oldDim + hi];
      oldDim *= takeOldDim();
    }
  }
  return reversedMask.reverse();
}
1393
+ /**
1394
+ * A multidimensional view into memory. An array can be thought of as the
1395
+ * combination of a linear buffer of memory, along with a `View`.
1396
+ *
1397
+ * Formula for getting a data point is basically:
1398
+ * 1. Check if ∀i. 0 <= dim[i] < shape[i], otherwise out of bounds.
1399
+ * 2. If mask exists, and ∃i. dim[i] ∉ mask[i], return 0.
1400
+ * 3. Otherwise, look at this memory address: offset + ∑(strides[i] * dim[i]).
1401
+ */
1402
+ var View = class View {
1403
+ #size;
1404
+ #contiguous;
1405
+ constructor(shape, strides, offset, mask) {
1406
+ this.shape = shape;
1407
+ this.strides = strides;
1408
+ this.offset = offset;
1409
+ this.mask = mask;
1410
+ }
1411
+ static create(shape, strides, offset = 0, mask = null) {
1412
+ if (shape.some((s) => s < 0)) throw new Error("View shape must be non-negative");
1413
+ strides = strides ? canonicalizeStrides(shape, strides) : defaultStrides(shape);
1414
+ if (shape.includes(0)) return new View(shape, rep(shape.length, 0), 0, null);
1415
+ if (mask !== null && mask.every(([b, e], i) => b === 0 && e === shape[i])) mask = null;
1416
+ if (mask !== null) {
1417
+ const elimDims = [];
1418
+ let hasNoData = false;
1419
+ for (let i = 0; i < shape.length; i++) {
1420
+ const [b, e] = mask[i];
1421
+ if (b + 1 >= e) elimDims.push(i);
1422
+ if (b >= e) hasNoData = true;
1423
+ }
1424
+ if (elimDims.length) {
1425
+ if (hasNoData) {
1426
+ strides = rep(shape.length, 0);
1427
+ offset = 0;
1428
+ mask = rep(shape.length, () => [0, 0]);
1429
+ }
1430
+ for (const i of elimDims) {
1431
+ offset += strides[i] * mask[i][0];
1432
+ strides[i] = 0;
1433
+ }
1434
+ }
1435
+ }
1436
+ return new View(shape, strides, offset, mask);
1437
+ }
1438
+ get ndim() {
1439
+ return this.shape.length;
1440
+ }
1441
+ get size() {
1442
+ if (this.#size === void 0) this.#size = prod(this.shape);
1443
+ return this.#size;
1444
+ }
1445
+ /** Whether this is a default, contiguous, unaltered view of the data (identity). */
1446
+ get contiguous() {
1447
+ if (this.#contiguous === void 0) this.#contiguous = this.size === 0 || this.offset === 0 && this.mask === null && deepEqual(this.strides, defaultStrides(this.shape));
1448
+ return this.#contiguous;
1449
+ }
1450
+ /** Return the range of data being indexed in this view, or [0, 0] if none. */
1451
+ dataRange() {
1452
+ if (this.size === 0 || this.mask && this.mask[0][0] === this.mask[0][1]) return [0, 0];
1453
+ let min = this.offset;
1454
+ let max = this.offset;
1455
+ for (let i = 0; i < this.ndim; i++) {
1456
+ let [lo, hi] = this.mask ? this.mask[i] : [0, this.shape[i]];
1457
+ --hi;
1458
+ const s = this.strides[i];
1459
+ if (s > 0) {
1460
+ min += s * lo;
1461
+ max += s * hi;
1462
+ } else if (s < 0) {
1463
+ min += s * hi;
1464
+ max += s * lo;
1465
+ }
1466
+ }
1467
+ return [min, max + 1];
1468
+ }
1469
+ /** Produce an AluExp for evaluating this view at an index. */
1470
+ toAluExp(idxs) {
1471
+ let iexpr = AluExp.i32(this.offset);
1472
+ let vexpr = AluExp.bool(true);
1473
+ for (let i = this.ndim - 1; i >= 0; i--) {
1474
+ const idx = idxs[i];
1475
+ if (this.shape[i] !== 1 && this.strides[i] !== 0) iexpr = AluExp.add(AluExp.mul(idx, AluExp.i32(this.strides[i])), iexpr);
1476
+ if (this.mask) {
1477
+ if (this.mask[i][0] !== 0) vexpr = AluExp.mul(AluExp.cmplt(idx, AluExp.i32(this.mask[i][0])).not(), vexpr);
1478
+ if (this.mask[i][1] !== this.shape[i]) vexpr = AluExp.mul(AluExp.cmplt(idx, AluExp.i32(this.mask[i][1])), vexpr);
1479
+ }
1480
+ }
1481
+ return [iexpr, vexpr];
1482
+ }
1483
+ /**
1483
+ * Try to compose this view with another one. `this` view is applied first,
1484
+ * followed by the argument. If this is not possible for the specific views,
1485
+ * return `null` instead.
1486
+ *
1487
+ * If composable, return a combined view with the same shape as `v1`.
1488
+ *
1489
+ * This is very tricky. The shapes of v1 and v2 may be different, and in that
1490
+ * case, we do some math to figure out whether they're compatible.
+ *
+ * Inside the body, `v2` refers to `this` and `v1` to the argument.
+ *
+ * @param v1 The view applied after this one.
+ * @returns The combined view, or `null` if the views cannot be combined.
1491
+ */
1492
+ compose(v1) {
1493
+ const v2 = this;
1494
+ if (v2.contiguous) return v1; // v2 is the identity view, so v1 alone is the result
1495
+ if (v1.contiguous) { // v1 is the identity view: reuse v2, reshaped to v1's shape if needed
1496
+ if (deepEqual(v1.shape, v2.shape)) return v2;
1497
+ if (v1.size === v2.size) {
1498
+ const ret = v2.reshape(v1.shape);
1499
+ if (ret !== null) return ret;
1500
+ }
1501
+ }
1502
+ if (v1.mask !== null) { // masked v1: shrink to the masked region, compose, then pad back out
1503
+ const newV1 = v1.shrink(v1.mask);
1504
+ const merged = v2.compose(newV1);
1505
+ return merged ? merged.pad(zip(v1.mask, v1.shape).map(([m, s]) => [m[0], s - m[1]])) : null;
1506
+ }
1507
+ const origin = unravel(v2.shape, v1.offset); // NOTE(review): presumably the per-dimension coordinates in v2 of v1's offset; confirm against `unravel`
1508
+ const terms = rep(v2.ndim, () => []); // terms[d2] collects [d1, diff] pairs: how a step along v1 dim d1 moves v2 dim d2
1509
+ const strides = rep(v1.ndim, 0); // strides of the combined view, one per v1 dimension
1510
+ for (let d1 = 0; d1 < v1.strides.length; d1++) {
1511
+ const st = v1.strides[d1];
1512
+ if (st === 0) continue; // zero-stride dimensions of v1 contribute nothing
1513
+ const unravelOffset = unravel(v2.shape, v1.offset + st);
1514
+ for (let d2 = 0; d2 < v2.ndim; d2++) {
1515
+ const o = origin[d2];
1516
+ const diff = unravelOffset[d2] - o;
1517
+ if (diff === 0) continue;
1518
+ terms[d2].push([d1, diff]);
1519
+ strides[d1] += diff * v2.strides[d2];
1520
+ }
1521
+ }
1522
+ let [mergedSize, mergedTermMin, mergedTermMax] = [
1523
+ 1,
1524
+ 0,
1525
+ 0
1526
+ ];
1527
+ const extents = []; // [size, min, max] per group of v2 dimensions, collected from the last dimension backwards
1528
+ for (let i = v2.ndim - 1; i >= 0; i--) {
1529
+ const term = terms[i];
1530
+ const s = v2.shape[i];
1531
+ let [tmin, tmax] = [origin[i], origin[i]];
1532
+ for (const [d1, s1] of term) if (s1 > 0) tmax += (v1.shape[d1] - 1) * s1;
1533
+ else if (s1 < 0) tmin += (v1.shape[d1] - 1) * s1;
1534
+ mergedTermMin += tmin * mergedSize;
1535
+ mergedTermMax += tmax * mergedSize;
1536
+ mergedSize *= s;
1537
+ if (mergedTermMin >= 0 && mergedTermMax < mergedSize) { // the accumulated range fits inside the merged size: close this group
1538
+ extents.push([
1539
+ mergedSize,
1540
+ mergedTermMin,
1541
+ mergedTermMax
1542
+ ]);
1543
+ [mergedSize, mergedTermMin, mergedTermMax] = [
1544
+ 1,
1545
+ 0,
1546
+ 0
1547
+ ];
1548
+ }
1549
+ }
1550
+ if (mergedTermMin !== 0 || mergedTermMax !== 0) return null; // a leftover range never fit into any group: give up
1551
+ extents.reverse();
1552
+ const v2Shape = extents.map(([s]) => s);
1553
+ if (!deepEqual(v2Shape, v2.shape)) { // the grouping changed v2's shape: retry on the reshaped v2
1554
+ const reshapedV2 = v2.reshape(v2Shape);
1555
+ if (reshapedV2 === null) return null;
1556
+ if (!deepEqual(reshapedV2.shape, v2.shape)) return reshapedV2.compose(v1);
1557
+ }
1558
+ if (v2.mask !== null) { // translate v2's mask into begin/end bounds on v1's dimensions
1559
+ const newB = rep(v1.ndim, 0);
1560
+ const newE = v1.shape.slice();
1561
+ let bad = false;
1562
+ for (let d2 = 0; d2 < v2.ndim; d2++) {
1563
+ const [b, e] = v2.mask[d2];
1564
+ const o = origin[d2];
1565
+ const term = terms[d2];
1566
+ const [_, tmin, tmax] = extents[d2];
1567
+ if (b <= tmin && tmax < e) continue; // the whole reachable range lies inside the mask for this dimension
1568
+ if (term.length !== 1) if (term.length === 0 && newE.length) newE[0] = 0; // no v1 dimension moves this one: set the first end bound to 0
1569
+ else bad = true; // the mask bounds cannot be attributed to a single v1 dimension
1570
+ else {
1571
+ const [d1, s1] = term[0];
1572
+ newB[d1] = Math.max(newB[d1], Math.ceil((s1 > 0 ? b - o : e - o - 1) / s1));
1573
+ newE[d1] = Math.min(newE[d1], Math.floor((s1 < 0 ? b - o : e - o - 1) / s1) + 1);
1574
+ }
1575
+ }
1576
+ for (let d1 = 0; d1 < v1.ndim; d1++) if (newB[d1] !== 0 || newE[d1] !== v1.shape[d1]) return v2.compose(View.create(v1.shape, v1.strides, v1.offset, zip(newB, newE))); // bounds were tightened: retry with v1 carrying them as its mask
1577
+ if (bad) return null;
1578
+ }
1579
+ let finalOffset = v2.offset;
1580
+ for (let d2 = 0; d2 < v2.ndim; d2++) finalOffset += origin[d2] * v2.strides[d2];
1581
+ return View.create(v1.shape, strides, finalOffset, null);
1582
+ }
1584
+ /** Attempt to simplify this view into a smaller reshaped form. */
1585
+ minify() {
1586
+ const minShape = mergeDims(this.shape, this.strides, this.mask).map((x) => x[0]);
1587
+ const nv = this.reshape(minShape);
1588
+ return nv ? nv : this;
1589
+ }
1590
+ /** Pad the view with zeros on each dimension. */
1591
+ pad(arg) {
1592
+ if (arg.length !== this.ndim || !arg.every(([b, e]) => b >= 0 && e >= 0)) throw new Error(`invalid pad ${jstr(arg)} for ${jstr(this.shape)}`);
1593
+ if (arg.every(([b, e]) => b === 0 && e === 0)) return this;
1594
+ const zvarg = arg.map(([b, e], i) => [-b, this.shape[i] + e]);
1595
+ const mask = arg.map(([b, _e], i) => [b, this.shape[i] + b]);
1596
+ return this.#unsafeResize(zvarg, mask);
1597
+ }
1598
+ /** Shrink the view by taking a subarray. */
1599
+ shrink(arg) {
1600
+ if (arg.length !== this.ndim || !arg.every(([b, e], i) => 0 <= b && b <= e && e <= this.shape[i])) throw new Error(`invalid shrink ${jstr(arg)} for ${jstr(this.shape)}`);
1601
+ return this.#unsafeResize(arg);
1602
+ }
1603
+ #unsafeResize(arg, mask) {
1604
+ const offset = this.strides.map((s, i) => s * arg[i][0]).reduce((a, b) => a + b, 0);
1605
+ if (this.mask) {
1606
+ const nmask = this.mask.map(([mx, my], i) => [Math.max(0, Math.min(mx - arg[i][0], arg[i][1] - arg[i][0])), Math.max(0, Math.min(my - arg[i][0], arg[i][1] - arg[i][0]))]);
1607
+ mask = mask ? mask.map(([mx, my], i) => [Math.max(mx, nmask[i][0]), Math.min(my, nmask[i][1])]) : nmask;
1608
+ }
1609
+ return View.create(arg.map(([b, e]) => e - b), this.strides, this.offset + offset, mask);
1610
+ }
1611
+ /** Expand one or more axes with length "1" by repeating the data. */
1612
+ expand(newShape) {
1613
+ if (newShape.length !== this.ndim) throw new Error(`Can't expand ${jstr(this.shape)} into ${jstr(newShape)}`);
1614
+ for (let i = 0; i < this.ndim; i++) if (newShape[i] !== this.shape[i] && this.shape[i] !== 1) throw new Error(`Can't expand ${jstr(this.shape)} into ${jstr(newShape)}`);
1615
+ if (this.size === 0) return View.create(newShape);
1616
+ const mask = this.mask ? this.mask.map((m, i) => this.shape[i] === newShape[i] ? m : m[0] === 0 && m[1] === 1 ? [0, newShape[i]] : [0, 0]) : null;
1617
+ return View.create(newShape, this.strides, this.offset, mask);
1618
+ }
1619
+ /** Permute the axes of an array. */
1620
+ permute(axis) {
1621
+ if (!isPermutation(axis, this.ndim)) throw new Error(`Invalid permutation ${jstr(axis)} of len ${this.ndim}`);
1622
+ const newShape = axis.map((a) => this.shape[a]);
1623
+ const newStrides = axis.map((a) => this.strides[a]);
1624
+ const newMask = this.mask ? axis.map((a) => this.mask[a]) : null;
1625
+ return View.create(newShape, newStrides, this.offset, newMask);
1626
+ }
1627
+ /** Flip (reverse) one or more axes of the view. */
1628
+ flip(arg) {
1629
+ if (arg.length !== this.ndim) throw new Error(`Invalid flip ${jstr(arg)} for ${jstr(this.shape)}`);
1630
+ const strides = this.strides.slice();
1631
+ let offset = this.offset;
1632
+ const mask = this.mask ? this.mask.slice() : null;
1633
+ for (let i = 0; i < this.ndim; i++) {
1634
+ const s = this.shape[i];
1635
+ if (arg[i]) {
1636
+ strides[i] = -strides[i];
1637
+ offset += (s - 1) * this.strides[i];
1638
+ if (mask) mask[i] = [s - mask[i][1], s - mask[i][0]];
1639
+ }
1640
+ }
1641
+ return View.create(this.shape, strides, offset, mask);
1642
+ }
1643
/**
 * Reshape the view into a new shape.
 *
 * @param newShape Target shape; must be non-negative and have the same total
 *   size as this view.
 * @returns A view with shape `newShape` (or `this` when the shape is
 *   unchanged), or `null` when the reshape cannot be expressed as a single
 *   strided view.
 * @throws Error if `newShape` has a negative entry or a different total size.
 */
reshape(newShape) {
  // Same shape: nothing to do.
  if (deepEqual(this.shape, newShape)) return this;
  if (newShape.some((s) => s < 0)) throw new Error(`Reshape cannot have negative numbers ${jstr(newShape)}`);
  if (this.size !== prod(newShape)) throw new Error(`Reshape size ${jstr(this.shape)} -> ${jstr(newShape)}`);
  // Empty views carry no data, so a fresh default view of the new shape suffices.
  if (this.size === 0) return View.create(newShape);
  // A view with an empty mask range on some axis is not reshaped to a scalar.
  if (newShape.length === 0 && this.mask?.some(([b, e]) => b === e)) return null;
  if (this.contiguous) return View.create(newShape);
  // Strides for the new dims, collected from the last dimension to the first.
  const rStrides = [];
  // NOTE(review): mergeDims is defined elsewhere; each entry is consumed here
  // as [mergedSize, stride, realSize] — confirm against its definition.
  const merge = mergeDims(this.shape, this.strides, this.mask);
  let rShapeIdx = newShape.length;
  for (let i = merge.length - 1; i >= 0; i--) {
    let [mergedSize, newStride, realSize] = merge[i];
    let acc = 1;
    // Consume trailing new dims until their product covers this merged group.
    while (acc < mergedSize && rShapeIdx > 0) {
      const newDim = newShape[--rShapeIdx];
      rStrides.push(newStride * acc);
      acc *= newDim;
      // Once realSize is covered, remaining dims of this group get stride 0.
      if (acc >= realSize) newStride = 0;
    }
    // The new dims must tile the merged group exactly; otherwise give up.
    if (acc !== mergedSize) return null;
  }
  // New dims that were never consumed get stride 0.
  const newStrides = rep(newShape.length - rStrides.length, 0).concat(rStrides.reverse());
  if (!this.mask) return View.create(newShape, newStrides, this.offset);
  const newMask = reshapeMask(this.mask, this.shape, newShape);
  if (!newMask) return null;
  // Re-anchor the offset: add the old mask's starting position, subtract the
  // new mask's starting position measured with the new strides.
  let newOffset = this.offset;
  for (let i = 0; i < this.ndim; i++) newOffset += this.strides[i] * this.mask[i][0];
  for (let i = 0; i < newShape.length; i++) newOffset -= newStrides[i] * newMask[i][0];
  return View.create(newShape, newStrides, newOffset, newMask);
}
1674
+ };
1675
/**
 * Convert a flat `offset` into one index per dimension of `shape`, with the
 * last dimension varying fastest. Behaves like `numpy.unravel_index`.
 *
 * @param shape Dimension sizes.
 * @param offset Flat position.
 * @returns Array of `shape.length` indices.
 */
function unravel(shape, offset) {
  const coords = new Array(shape.length);
  let stride = 1;
  for (let axis = shape.length - 1; axis >= 0; axis--) {
    const extent = shape[axis];
    coords[axis] = Math.floor(offset / stride) % extent;
    stride *= extent;
  }
  return coords;
}
1689
/**
 * Symbolic counterpart of unravel(): build one AluExp per dimension of
 * `shape` that computes that dimension's index from the `offset` expression.
 *
 * @param shape Dimension sizes (plain numbers).
 * @param offset AluExp for the flat position.
 * @returns Array of `shape.length` AluExp index expressions.
 */
function unravelAlu(shape, offset) {
  const exprs = new Array(shape.length);
  let stride = 1;
  for (let axis = shape.length - 1; axis >= 0; axis--) {
    const extent = shape[axis];
    const quotient = AluExp.idiv(offset, AluExp.i32(stride));
    exprs[axis] = AluExp.mod(quotient, AluExp.i32(extent));
    stride *= extent;
  }
  return exprs;
}
1700
+ /**
1701
+ * Array shape after applying movement operations, as a series of views.
1702
+ *
1703
+ * Each view is applied, then treated as if it were a contiguous array of its
1704
+ * shape, then used as the virtual buffer for the next view.
1705
+ */
1706
+ var ShapeTracker = class ShapeTracker {
1707
+ constructor(views) {
1708
+ this.views = views;
1709
+ }
1710
+ /** Compose this shape tracker with another, applying it after this one. */
1711
+ compose(other) {
1712
+ if (this.contiguous) return other;
1713
+ let ret = this;
1714
+ for (const v of other.views) ret = new ShapeTracker(ret.views.concat(v)).simplify();
1715
+ return ret;
1716
+ }
1717
+ static fromShape(shape) {
1718
+ return new ShapeTracker([View.create(shape)]);
1719
+ }
1720
+ get contiguous() {
1721
+ return this.views.length === 1 && this.views[0].contiguous;
1722
+ }
1723
+ get consecutive() {
1724
+ return this.views.length === 1 && this.views[0].mask === null && deepEqual(this.views[0].strides, defaultStrides(this.views[0].shape));
1725
+ }
1726
+ get lastStrides() {
1727
+ return this.views[this.views.length - 1].strides;
1728
+ }
1729
+ get shape() {
1730
+ return this.views[this.views.length - 1].shape;
1731
+ }
1732
+ get size() {
1733
+ return this.views[this.views.length - 1].size;
1734
+ }
1735
+ toAluExp(idxs) {
1736
+ let [iexpr, vexpr] = this.views[this.views.length - 1].toAluExp(idxs);
1737
+ for (let i = this.views.length - 2; i >= 0; i--) {
1738
+ const view = this.views[i].minify();
1739
+ const exprs = view.toAluExp(unravelAlu(view.shape, iexpr));
1740
+ iexpr = exprs[0];
1741
+ vexpr = AluExp.mul(vexpr, exprs[1]);
1742
+ }
1743
+ return [iexpr.simplify(), vexpr.simplify()];
1744
+ }
1745
+ simplify() {
1746
+ const views = this.views.slice();
1747
+ while (views.length >= 2) {
1748
+ const newView = views[views.length - 2].compose(views[views.length - 1]);
1749
+ if (newView === null) break;
1750
+ views.splice(views.length - 2, 2, newView);
1751
+ }
1752
+ return new ShapeTracker(views);
1753
+ }
1754
+ pad(arg) {
1755
+ return new ShapeTracker(applyLast(this.views, (x) => x.pad(arg)));
1756
+ }
1757
+ shrink(arg) {
1758
+ return new ShapeTracker(applyLast(this.views, (x) => x.shrink(arg)));
1759
+ }
1760
+ expand(newShape) {
1761
+ return new ShapeTracker(applyLast(this.views, (x) => x.expand(newShape)));
1762
+ }
1763
+ permute(axis) {
1764
+ return new ShapeTracker(applyLast(this.views, (x) => x.permute(axis)));
1765
+ }
1766
+ flip(arg) {
1767
+ return new ShapeTracker(applyLast(this.views, (x) => x.flip(arg)));
1768
+ }
1769
+ reshape(newShape) {
1770
+ const newView = this.views[this.views.length - 1].reshape(newShape);
1771
+ return new ShapeTracker(newView === null ? this.views.concat(View.create(newShape)) : this.views.toSpliced(this.views.length - 1, 1, newView));
1772
+ }
1773
+ /** Broadcast along the given new axes, then expand the shape. */
1774
+ broadcast(newShape, axis) {
1775
+ let st = this;
1776
+ if (axis.length > 0) {
1777
+ const unsqueezed = [...st.shape];
1778
+ for (const i of axis.toSorted()) unsqueezed.splice(i, 0, 1);
1779
+ st = st.reshape(unsqueezed);
1780
+ }
1781
+ return st.expand(newShape);
1782
+ }
1783
+ /**
1784
+ * Repeat data in each axis by a positive number of repetitions.
1785
+ *
1786
+ * - If `tile` is true (default): [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
1787
+ * - If `tile` is false: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
1788
+ */
1789
+ repeat(reps, tile = true) {
1790
+ if (reps.length > this.shape.length) throw new Error(`Too many repeats ${jstr(reps)} for shape ${jstr(this.shape)}`);
1791
+ if (reps.some((c) => c <= 0)) throw new Error(`Invalid repeats ${jstr(reps)}`);
1792
+ if (reps.length === 0) return this;
1793
+ const noop = this.shape.slice(0, -reps.length);
1794
+ const shape = this.shape.slice(-reps.length);
1795
+ return this.broadcast([...noop, ...shape.flatMap((s, i) => tile ? [reps[i], s] : [s, reps[i]])], shape.map((_, i) => noop.length + 2 * i + (tile ? 0 : 1))).reshape([...noop, ...shape.map((s, i) => s * reps[i])]);
1796
+ }
1797
+ /** Move axis i to axis j. */
1798
+ moveaxis(i, j) {
1799
+ const perm = range(this.shape.length);
1800
+ perm.splice(i, 1);
1801
+ perm.splice(j, 0, i);
1802
+ return this.permute(perm);
1803
+ }
1804
+ /** Like pad(), but allows for negative values. */
1805
+ padOrShrink(arg) {
1806
+ const padArg = [];
1807
+ const shrinkArg = [];
1808
+ for (let i = 0; i < arg.length; i++) {
1809
+ const [b, e] = arg[i];
1810
+ if (b < -this.shape[i] || e < -this.shape[i] || b + e < -this.shape[i]) throw new Error(`Invalid padOrShrink ${jstr(arg)} for ${jstr(this.shape)}`);
1811
+ padArg.push([Math.max(0, b), Math.max(0, e)]);
1812
+ shrinkArg.push([Math.max(0, -b), this.shape[i] - Math.max(0, -e)]);
1813
+ }
1814
+ return this.shrink(shrinkArg).pad(padArg);
1815
+ }
1816
+ };
1817
/**
 * Return a copy of `ar` whose last element is replaced by `f(last)`.
 * The input array is left untouched.
 */
function applyLast(ar, f) {
  const replaced = f(ar[ar.length - 1]);
  return [...ar.slice(0, -1), replaced];
}
1820
+
1821
+ //#endregion
1822
+ //#region src/tuner.ts
1823
/**
 * Stores dimensions of the kernel's applied shape. Globals start at 0.
 *
 * `groups`, `reduce`, `unroll` and `upcast` are indices into `st.shape` that
 * mark where each region begins; `end` is the total number of dimensions.
 * `outputSt` tracks the same transformations for the shape without its last
 * (reduction) dimension.
 */
var TuneDims = class {
  st;
  outputSt;
  groups;
  reduce;
  unroll;
  upcast;
  get end() {
    return this.st.shape.length;
  }
  /** @param shape Kernel shape whose last dimension is the reduction axis. */
  constructor(shape) {
    this.st = ShapeTracker.fromShape(shape);
    this.outputSt = ShapeTracker.fromShape(shape.slice(0, -1));
    const rank = this.st.shape.length;
    this.groups = rank - 1;
    this.reduce = rank - 1;
    this.unroll = rank;
    this.upcast = rank;
  }
  /** Shape with `axis` replaced by the pair `[length / amount, amount]`. */
  #splitAt(shape, axis, length, amount) {
    return [...shape.slice(0, axis), length / amount, amount, ...shape.slice(axis + 1)];
  }
  /**
   * Split `amount` off a global axis and move it to the end of the global
   * region (the whole axis is moved when its length equals `amount`).
   *
   * @throws Error if `axis` is not a global axis or `amount` does not divide it.
   */
  applyLocal(axis, amount) {
    if (axis >= this.groups) throw new Error("Cannot localize reduction axis");
    const length = this.st.shape[axis];
    if (length % amount !== 0) throw new Error(`Localize by ${amount} on axis length ${length}`);
    let target = axis;
    if (length !== amount) {
      // Splitting adds one dimension ahead of every region boundary.
      this.groups += 1;
      this.reduce += 1;
      this.unroll += 1;
      this.upcast += 1;
      this.st = this.st.reshape(this.#splitAt(this.st.shape, axis, length, amount));
      this.outputSt = this.outputSt.reshape(this.#splitAt(this.outputSt.shape, axis, length, amount));
      target = axis + 1;
    }
    const moveToGroupsEnd = (rank) => [
      ...range(target),
      ...range(target + 1, this.groups),
      target,
      ...range(this.groups, rank)
    ];
    this.st = this.st.permute(moveToGroupsEnd(this.st.shape.length));
    this.outputSt = this.outputSt.permute(moveToGroupsEnd(this.outputSt.shape.length));
  }
  /**
   * Split `amount` off a global axis and move it to the very end of the shape.
   *
   * @throws Error if `axis` is not a global axis or `amount` does not divide it.
   */
  applyUpcast(axis, amount) {
    if (axis >= this.groups) throw new Error("Cannot upcast along reduction axis");
    const length = this.st.shape[axis];
    if (length % amount !== 0) throw new Error(`Upcast by ${amount} on axis length ${length}`);
    const moveToEnd = (rank) => [...range(axis + 1), ...range(axis + 2, rank), axis + 1];
    // Ranks after the split: one more dimension than before.
    const stRank = this.st.shape.length + 1;
    const outputRank = this.outputSt.shape.length + 1;
    this.st = this.st
      .reshape(this.#splitAt(this.st.shape, axis, length, amount))
      .permute(moveToEnd(stRank));
    this.outputSt = this.outputSt
      .reshape(this.#splitAt(this.outputSt.shape, axis, length, amount))
      .permute(moveToEnd(outputRank));
  }
  /**
   * Move `amount` elements of a reduction axis into the unroll region.
   *
   * @throws Error if `axis` is a global axis, is already unrolled, or
   *   `amount` does not divide its length.
   */
  applyUnroll(axis, amount) {
    if (axis < this.groups) throw new Error("Cannot unroll non-reduce axis");
    if (axis >= this.unroll) throw new Error("Axis already unrolled");
    const length = this.st.shape[axis];
    if (length % amount !== 0) throw new Error(`Unroll by ${amount} on axis length ${length}`);
    if (length !== amount) {
      // Partial unroll: split the axis and move the inner part before `upcast`.
      const rank = this.st.shape.length + 1;
      this.st = this.st
        .reshape(this.#splitAt(this.st.shape, axis, length, amount))
        .permute([
          ...range(axis + 1),
          ...range(axis + 2, this.upcast + 1),
          axis + 1,
          ...range(this.upcast + 1, rank)
        ]);
      this.upcast++;
      return;
    }
    // Full unroll: the whole axis moves, so the regions before it shrink.
    this.st = this.st.permute([
      ...range(axis),
      ...range(axis + 1, this.upcast),
      axis,
      ...range(this.upcast, this.st.shape.length)
    ]);
    if (axis < this.reduce) this.reduce--;
    this.unroll--;
  }
};
1930
/**
 * Tuning step that does not apply any optimization.
 *
 * Substitutes the `gidx` (and, for reductions, `ridx`) special variables into
 * the kernel expression and uses one thread per output element.
 *
 * @returns `{ exp, outputIdxExp, threadCount, size: { reduce } }`, where
 *   `reduce` is 0 for kernels without a reduction.
 */
function tuneNullopt(kernel) {
  const makeGidx = () => AluExp.special(DType.Int32, "gidx", kernel.size);
  const vars = { gidx: makeGidx() };
  if (kernel.reduction) {
    vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
  }
  const exp = kernel.exp.substitute(vars).rewriteGlobalViews().simplify();
  const outputIdxExp = makeGidx();
  const reduce = kernel.reduction ? kernel.reduction.size : 0;
  return {
    exp,
    outputIdxExp,
    threadCount: kernel.size,
    size: { reduce }
  };
}
1942
/**
 * Tuning for WebGPU kernels.
 *
 * Falls back to tuneNullopt() when the kernel has no reduction, contains
 * GlobalIndex ops, has no GlobalView ops, or when a GlobalView's index
 * expressions are not the plain `[...unravel(gidx), ridx]` form. Otherwise it
 * reshapes the iteration space with TuneDims (upcast, optional unroll, local)
 * and rewrites every GlobalView access accordingly.
 *
 * @param kernel Kernel with `exp`, `size` and optional `reduction`.
 * @returns `{ exp, outputIdxExp, threadCount, size }`, where `size` holds the
 *   `groups`, `reduce`, `unroll` and `upcast` extents.
 * @throws Error on invariant violations (mismatched GlobalView shapes, or a
 *   reduction size that disagrees with the tuned dimensions).
 */
function tuneWebgpu(kernel) {
  const { exp, reduction } = kernel;
  if (!reduction) return tuneNullopt(kernel);
  const globalIndexes = exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex);
  if (globalIndexes.length > 0) {
    if (DEBUG >= 4) console.info("Tuning: Found GlobalIndex ops, skipping opt.");
    return tuneNullopt(kernel);
  }
  const globalViews = exp.collect((exp$1) => exp$1.op === AluOp.GlobalView);
  if (globalViews.length === 0) {
    if (DEBUG >= 4) console.info("Tuning: No GlobalView ops found in kernel.");
    return tuneNullopt(kernel);
  }
  const shape = globalViews[0].arg[1].shape;
  const expectedSrc = [...unravelAlu(shape.slice(0, -1), AluVar.gidx), AluVar.ridx].map((e) => e.simplify());
  for (const gv of globalViews) {
    if (!gv.src.length || !deepEqual(gv.src, expectedSrc)) {
      if (DEBUG >= 4) console.info("Tuning: GlobalView src[] not consistent with reduction.");
      return tuneNullopt(kernel);
    }
  }
  if (shape[shape.length - 1] !== reduction.size) throw new Error("Invariant violation: shape doesn't match reduction size.");
  const sts = globalViews.map((gv) => gv.arg[1]);
  for (const st of sts) {
    if (!deepEqual(st.shape, shape)) throw new Error("Invariant violation: GlobalView shape mismatch");
  }
  const dim = new TuneDims(shape);
  const upcastedAxis = /* @__PURE__ */ new Set();
  // Upcast global axes while at least 1024 global elements remain. Candidates
  // are axes on which some input has stride 0.
  while (prod(dim.st.shape.slice(0, dim.groups)) >= 1024) {
    const choices = [];
    const composedSts = sts.map((st) => st.compose(dim.st));
    for (let axis = 0; axis < dim.groups; axis++) {
      for (const amount of [3, 4]) {
        if (!upcastedAxis.has(axis) && dim.st.shape[axis] % amount === 0 && composedSts.some((st) => st.lastStrides[axis] === 0 && st.lastStrides.slice(dim.unroll).every((stride) => stride > 0))) {
          let nonzeroStrides = 0;
          let totalStrides = 0;
          for (const st of composedSts) {
            nonzeroStrides += st.lastStrides[axis] > 0 ? 1 : 0;
            totalStrides += st.lastStrides[axis];
          }
          choices.push([
            nonzeroStrides,
            totalStrides,
            axis,
            amount
          ]);
        }
      }
    }
    if (choices.length > 0) {
      // Smallest tuple first, as ordered by lexCompare.
      choices.sort(lexCompare);
      dim.applyUpcast(choices[0][2], choices[0][3]);
      upcastedAxis.add(choices[0][2]);
    } else break;
  }
  // Unrolling is only applied on Chrome. Guard the `navigator` lookup so that
  // hosts without that global skip the step instead of throwing.
  const isChrome = typeof navigator !== "undefined" && /chrome/i.test(navigator.userAgent);
  if (isChrome && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
    const s = dim.st.shape[dim.unroll - 1];
    if (s <= 32) dim.applyUnroll(dim.reduce, s);
    else for (const splits of [4]) if (s % splits === 0) {
      dim.applyUnroll(dim.unroll - 1, splits);
      break;
    }
  }
  // Numeric comparator: the default sort orders numbers as strings ("10" < "2").
  for (const ax of Array.from(upcastedAxis).sort((a, b) => a - b)) {
    const s = dim.st.shape[ax];
    for (const amount of [8, 4]) if (s % amount === 0) {
      dim.applyLocal(ax, amount);
      break;
    }
  }
  // Build one index expression per dimension of dim.st, region by region.
  const indices = [];
  const addIndices = (s, exp$1) => {
    if (s.length === 0) return;
    else if (s.length === 1) indices.push(exp$1);
    else indices.push(...unravelAlu(s, exp$1));
  };
  if (0 < dim.groups) {
    const s = dim.st.shape.slice(0, dim.groups);
    addIndices(s, AluExp.special(DType.Int32, "gidx", prod(s)));
  }
  if (dim.groups < dim.reduce) {
    const s = dim.st.shape.slice(dim.groups, dim.reduce);
    addIndices(s, AluExp.special(DType.Int32, "group", prod(s)));
  }
  if (dim.reduce <= dim.unroll) {
    const s = dim.st.shape.slice(dim.reduce, dim.unroll);
    addIndices(s, AluExp.special(DType.Int32, "ridx", prod(s)));
  }
  if (dim.unroll < dim.upcast) {
    const s = dim.st.shape.slice(dim.unroll, dim.upcast);
    addIndices(s, AluVar.unroll);
  }
  if (dim.upcast < dim.end) {
    const s = dim.st.shape.slice(dim.upcast);
    addIndices(s, AluVar.upcast);
  }
  // Re-express every GlobalView through the tuned shape tracker.
  let newExp = exp.rewrite((exp$1) => {
    if (exp$1.op === AluOp.GlobalView) {
      const gid = exp$1.arg[0];
      const st = exp$1.arg[1];
      return accessorGlobal(exp$1.dtype, gid, st.compose(dim.st), indices);
    }
  });
  const [iexpr, vexpr] = dim.st.toAluExp(indices);
  if (vexpr.min !== 1) throw new Error("Invariant violation: vexpr !== true");
  // Any remaining gidx/ridx variables are derived from the flat tuned index.
  newExp = newExp.substitute({
    gidx: AluExp.idiv(iexpr, AluExp.i32(reduction.size)).simplify(),
    ridx: AluExp.mod(iexpr, AluExp.i32(reduction.size)).simplify()
  });
  const outputGidx = dim.outputSt.shape.slice(0, dim.groups);
  const outputUpcast = dim.outputSt.shape.slice(dim.groups);
  const [outputIdxExp] = dim.outputSt.toAluExp([...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)]);
  if (prod(dim.st.shape.slice(dim.groups, dim.upcast)) !== reduction.size) throw new Error(`Invariant violation: reduction size ${reduction.size} does not match tuned dims ${JSON.stringify(dim.st.shape.slice(dim.groups, dim.upcast))}`);
  const size = {
    groups: prod(dim.st.shape.slice(dim.groups, dim.reduce)),
    reduce: prod(dim.st.shape.slice(dim.reduce, dim.unroll)),
    unroll: prod(dim.st.shape.slice(dim.unroll, dim.upcast)),
    upcast: prod(dim.st.shape.slice(dim.upcast))
  };
  return {
    exp: newExp.simplify(),
    outputIdxExp: outputIdxExp.simplify(),
    threadCount: kernel.size / size.upcast * size.groups,
    size
  };
}
2061
+
2062
+ //#endregion
2063
+ //#region src/backend/cpu.ts
2064
/**
 * Most basic implementation of `Backend` for testing.
 *
 * Buffers live in a Map keyed by an integer slot and are reference counted;
 * kernels are evaluated element by element in JavaScript.
 */
var CpuBackend = class {
  type = "cpu";
  maxArgs = Infinity;
  /** slot -> { buffer: Uint8Array, ref: number } */
  #buffers = /* @__PURE__ */ new Map();
  #nextSlot = 1;
  constructor() {}
  /**
   * Allocate a zero-filled buffer of `size` bytes, optionally copying
   * `initialData` into it, and return its slot with a reference count of 1.
   *
   * @throws Error if `initialData.byteLength` differs from `size`.
   */
  malloc(size, initialData) {
    const buffer = new Uint8Array(size);
    if (initialData) {
      if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
      buffer.set(initialData);
    }
    const slot = this.#nextSlot;
    this.#nextSlot += 1;
    this.#buffers.set(slot, { buffer, ref: 1 });
    return slot;
  }
  /** Add one reference to `slot`. Throws SlotError for an unknown slot. */
  incRef(slot) {
    this.#entry(slot).ref += 1;
  }
  /** Drop one reference to `slot`, releasing the buffer when none remain. */
  decRef(slot) {
    const entry = this.#entry(slot);
    entry.ref -= 1;
    if (entry.ref === 0) this.#buffers.delete(slot);
  }
  /** Async wrapper around readSync(). */
  async read(slot, start, count) {
    return this.readSync(slot, start, count);
  }
  /**
   * Copy `count` bytes starting at `start` out of the buffer in `slot`.
   * `start` defaults to 0 and `count` to the rest of the buffer.
   */
  readSync(slot, start, count) {
    const bytes = this.#getBuffer(slot);
    const from = start === void 0 ? 0 : start;
    const length = count === void 0 ? bytes.byteLength - from : count;
    return bytes.slice(from, from + length);
  }
  /** Async wrapper around prepareSync(). */
  async prepare(kernel) {
    return this.prepareSync(kernel);
  }
  /** Wrap the kernel in an Executable; no compilation happens here. */
  prepareSync(kernel) {
    return new Executable(kernel, void 0);
  }
  /**
   * Evaluate the kernel over every output index and write the results into
   * the first output buffer.
   */
  dispatch({ kernel }, inputs, outputs) {
    const { exp } = tuneNullopt(kernel);
    const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
    const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
    // Only inputs referenced by a GlobalIndex op get a typed-array view.
    const usedArgs = new Map();
    for (const node of exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex)) {
      usedArgs.set(node.arg[0], node.dtype);
    }
    const inputArrays = inputBuffers.map((buf, i) => {
      const dtype = usedArgs.get(i);
      return dtype ? dtypedArray(dtype, buf) : null;
    });
    const outputArray = dtypedArray(kernel.dtype, outputBuffers[0]);
    const globals = (gid, bufidx) => {
      if (gid < 0 || gid >= inputArrays.length) throw new Error(`gid out of bounds: ${gid}`);
      if (bufidx < 0 || bufidx >= inputArrays[gid].length) throw new Error(`bufidx out of bounds: ${bufidx}`);
      return inputArrays[gid][bufidx];
    };
    for (let i = 0; i < kernel.size; i++) {
      if (!kernel.reduction) {
        outputArray[i] = exp.evaluate({ gidx: i }, globals);
        continue;
      }
      // Fold the reduction axis, then apply the epilogue to the accumulator.
      let acc = kernel.reduction.identity;
      for (let j = 0; j < kernel.reduction.size; j++) {
        const item = exp.evaluate({ gidx: i, ridx: j }, globals);
        acc = kernel.reduction.evaluate(acc, item);
      }
      outputArray[i] = kernel.reduction.epilogue.evaluate({ acc });
    }
  }
  /** Look up the bookkeeping record for `slot`, or throw SlotError. */
  #entry(slot) {
    const entry = this.#buffers.get(slot);
    if (!entry) throw new SlotError(slot);
    return entry;
  }
  /** Byte storage for `slot`, or throw SlotError. */
  #getBuffer(slot) {
    return this.#entry(slot).buffer;
  }
};
2148
+
2149
+ //#endregion
2150
+ //#region src/backend/wasm/allocator.ts
2151
/**
 * Simple tensor memory allocator for WebAssembly linear memory.
 *
 * Requests are rounded up to a size class. Freed blocks are kept in a
 * per-class free list and reused; otherwise memory is taken from a bump
 * pointer that starts at byte 64 and advances in multiples of 64.
 */
var WasmAllocator = class {
  #memory;
  /** Next unused byte of linear memory. */
  #headPtr = 64;
  /** size class -> stack of free pointers */
  #freeLists = /* @__PURE__ */ new Map();
  /** live pointer -> size class */
  #allocatedBuffers = /* @__PURE__ */ new Map();
  /** @param memory Object with `buffer.byteLength` and `grow(pages)`. */
  constructor(memory) {
    this.#memory = memory;
  }
  /** Allocate at least `size` bytes; a zero-byte request yields pointer 0. */
  malloc(size) {
    if (size === 0) return 0;
    const sizeClass = this.#findSizeClass(size);
    // Reuse a freed block of this class when one is available.
    const ptr = this.#freeLists.get(sizeClass)?.pop() ?? this.#bumpAlloc(sizeClass);
    this.#allocatedBuffers.set(ptr, sizeClass);
    return ptr;
  }
  /**
   * Return `ptr` to its size class's free list. Freeing 0 is a no-op.
   *
   * @throws Error if `ptr` is not currently allocated.
   */
  free(ptr) {
    if (ptr === 0) return;
    const sizeClass = this.#allocatedBuffers.get(ptr);
    if (sizeClass === void 0) throw new Error(`Attempting to free unallocated pointer: ${ptr}`);
    let bucket = this.#freeLists.get(sizeClass);
    if (!bucket) {
      bucket = [];
      this.#freeLists.set(sizeClass, bucket);
    }
    bucket.push(ptr);
    this.#allocatedBuffers.delete(ptr);
  }
  /** Carve `size` bytes (rounded up to 64) off the head, growing memory if needed. */
  #bumpAlloc(size) {
    const ptr = this.#headPtr;
    const aligned = (size + 63) & -64;
    const end = ptr + aligned;
    this.#headPtr = end;
    const capacity = this.#memory.buffer.byteLength;
    if (end > capacity) {
      // grow() takes a page count; `>> 16` converts bytes to 64 KiB pages.
      this.#memory.grow(((end + 65535) >> 16) - (capacity >> 16));
    }
    return ptr;
  }
  /**
   * Round `size` up to its size class: multiples of 64 up to 512, multiples
   * of 512 up to 2048, powers of two from 4096 up to 65536, and multiples of
   * 65536 beyond that.
   */
  #findSizeClass(size) {
    if (size <= 512) return (size + 63) & -64;
    if (size <= 2048) return (size + 511) & -512;
    if (size <= 65536) {
      let sizeClass = 4096;
      while (sizeClass < size) sizeClass *= 2;
      return sizeClass;
    }
    return (size + 65535) & -65536;
  }
  /** Report the bump pointer and the number of free blocks per size class. */
  getStats() {
    const freeListSizes = /* @__PURE__ */ new Map();
    this.#freeLists.forEach((bucket, sizeClass) => {
      if (bucket.length > 0) freeListSizes.set(sizeClass, bucket.length);
    });
    return {
      totalAllocated: this.#headPtr,
      freeListSizes
    };
  }
};
2208
+
2209
+ //#endregion
2210
+ //#region src/backend/wasm/builtins.ts
2211
/**
 * Approximate e^x.
 *
 * Method: range-reduce x = k*ln2 + r with k = round(x/ln2), |r|<=~0.3466
 * then e^x = 2^k * P(r), where P is 5th-order poly (Taylor).
 *
 * Emits a function taking one f32 (local 0) and returning one f32. Returns
 * Infinity when k > 127 and 0 when k < -126.
 */
function wasm_exp(cg) {
  return cg.function([cg.f32], [cg.f32], () => {
    const k_f = cg.local.declare(cg.f32);
    const k = cg.local.declare(cg.i32);
    const r = cg.local.declare(cg.f32);
    const p = cg.local.declare(cg.f32);
    const scale = cg.local.declare(cg.f32);
    // k_f = nearest(x / ln2); k = (i32) k_f, using the saturating conversion.
    cg.local.get(0);
    cg.f32.const(1 / Math.LN2);
    cg.f32.mul();
    cg.f32.nearest();
    cg.local.tee(k_f);
    cg.i32.trunc_sat_f32_s();
    cg.local.set(k);
    // if (k > 127) return Infinity;
    cg.local.get(k);
    cg.i32.const(127);
    cg.i32.gt_s();
    cg.if(cg.void);
    cg.f32.const(Infinity);
    cg.return();
    cg.end();
    // if (k < -126) return 0;
    cg.local.get(k);
    cg.i32.const(-126);
    cg.i32.lt_s();
    cg.if(cg.void);
    cg.f32.const(0);
    cg.return();
    cg.end();
    // r = x - k_f * ln2
    cg.local.get(0);
    cg.local.get(k_f);
    cg.f32.const(Math.LN2);
    cg.f32.mul();
    cg.f32.sub();
    cg.local.set(r);
    // p = 1 + r*(1 + r*(1/2 + r*(1/6 + r*(1/24 + r/120)))), Horner form.
    cg.f32.const(1 / 120);
    cg.local.get(r);
    cg.f32.mul();
    cg.f32.const(1 / 24);
    cg.f32.add();
    cg.local.get(r);
    cg.f32.mul();
    cg.f32.const(1 / 6);
    cg.f32.add();
    cg.local.get(r);
    cg.f32.mul();
    cg.f32.const(1 / 2);
    cg.f32.add();
    cg.local.get(r);
    cg.f32.mul();
    cg.f32.const(1);
    cg.f32.add();
    cg.local.get(r);
    cg.f32.mul();
    cg.f32.const(1);
    cg.f32.add();
    cg.local.set(p);
    // scale = 2^k, built by placing the biased exponent (k + 127) at bit 23.
    cg.local.get(k);
    cg.i32.const(127);
    cg.i32.add();
    cg.i32.const(23);
    cg.i32.shl();
    cg.f32.reinterpret_i32();
    cg.local.set(scale);
    // result = p * scale (left on the stack as the return value)
    cg.local.get(p);
    cg.local.get(scale);
    cg.f32.mul();
  });
}
2285
/**
 * Approximate ln(x).
 *
 * Method: decompose x = m * 2^e with m in [1,2), e integer (via bit ops)
 * ln(x) = e*ln2 + ln(m); use atanh-style series with t=(m-1)/(m+1)
 * ln(m) ≈ 2*(t + t^3/3 + t^5/5 + t^7/7)
 *
 * Special cases: ln(±0) = -Infinity, ln(x < 0) = NaN, ln(+Infinity) =
 * +Infinity, ln(NaN) = NaN.
 *
 * NOTE(review): subnormal inputs are still decomposed as if they were normal
 * numbers (exponent field 0), so their result is not accurate.
 *
 * Emits a function taking one f32 (local 0) and returning one f32.
 */
function wasm_log(cg) {
  return cg.function([cg.f32], [cg.f32], () => {
    const bits = cg.local.declare(cg.i32);
    const e = cg.local.declare(cg.i32);
    const m = cg.local.declare(cg.f32);
    const t = cg.local.declare(cg.f32);
    const t2 = cg.local.declare(cg.f32);
    const t3 = cg.local.declare(cg.f32);
    const t5 = cg.local.declare(cg.f32);
    const t7 = cg.local.declare(cg.f32);
    const lnm = cg.local.declare(cg.f32);
    const el2 = cg.local.declare(cg.f32);
    // if (x <= 0) return (0 <= x) ? -Infinity : NaN;
    // Inside this branch `0 <= x` holds only for x == ±0.
    cg.local.get(0);
    cg.f32.const(0);
    cg.f32.le();
    cg.if(cg.void);
    cg.f32.const(-Infinity);
    cg.f32.const(NaN);
    cg.f32.const(0);
    cg.local.get(0);
    cg.f32.le();
    cg.select();
    cg.return();
    cg.end();
    // e = ((bits >>> 23) & 0xff) - 127  (unbiased exponent)
    cg.local.get(0);
    cg.i32.reinterpret_f32();
    cg.local.tee(bits);
    cg.i32.const(23);
    cg.i32.shr_u();
    cg.i32.const(255);
    cg.i32.and();
    cg.i32.const(127);
    cg.i32.sub();
    cg.local.set(e);
    // Exponent field 255 (e == 128) means +Infinity or NaN: return x itself.
    cg.local.get(e);
    cg.i32.const(127);
    cg.i32.gt_s();
    cg.if(cg.void);
    cg.local.get(0);
    cg.return();
    cg.end();
    // m = mantissa bits (0x7fffff) with the exponent of 1.0 (0x3f800000): m in [1, 2)
    cg.local.get(bits);
    cg.i32.const(8388607);
    cg.i32.and();
    cg.i32.const(1065353216);
    cg.i32.or();
    cg.f32.reinterpret_i32();
    cg.local.set(m);
    // t = (m - 1) / (m + 1)
    cg.local.get(m);
    cg.f32.const(1);
    cg.f32.sub();
    cg.local.get(m);
    cg.f32.const(1);
    cg.f32.add();
    cg.f32.div();
    cg.local.set(t);
    // Odd powers of t: t2 = t*t, t3 = t*t2, t5 = t3*t2, t7 = t5*t2
    cg.local.get(t);
    cg.local.get(t);
    cg.f32.mul();
    cg.local.set(t2);
    cg.local.get(t);
    cg.local.get(t2);
    cg.f32.mul();
    cg.local.set(t3);
    cg.local.get(t3);
    cg.local.get(t2);
    cg.f32.mul();
    cg.local.set(t5);
    cg.local.get(t5);
    cg.local.get(t2);
    cg.f32.mul();
    cg.local.set(t7);
    // lnm = 2 * (t7/7 + t5/5 + t3/3 + t)
    cg.local.get(t7);
    cg.f32.const(1 / 7);
    cg.f32.mul();
    cg.local.get(t5);
    cg.f32.const(1 / 5);
    cg.f32.mul();
    cg.f32.add();
    cg.local.get(t3);
    cg.f32.const(1 / 3);
    cg.f32.mul();
    cg.f32.add();
    cg.local.get(t);
    cg.f32.add();
    cg.f32.const(2);
    cg.f32.mul();
    cg.local.set(lnm);
    // el2 = e * ln2
    cg.local.get(e);
    cg.f32.convert_i32_s();
    cg.f32.const(Math.LN2);
    cg.f32.mul();
    cg.local.set(el2);
    // result = el2 + lnm (left on the stack as the return value)
    cg.local.get(el2);
    cg.local.get(lnm);
    cg.f32.add();
  });
}
2378
/**
 * Common helper to approximate sin(x) and cos(x).
 *
 * Method: reduce to y in [-π, π], then quadrant via q = round(y/(π/2))
 * z = y - q*(π/2); use one of two polynomials on z:
 * sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
 * cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
 *
 * Must be called while emitting a function whose local 0 is the f32 input.
 * The quadrant is converted with the saturating truncation so that NaN and
 * ±Infinity inputs yield NaN polynomials instead of trapping.
 *
 * @returns Locals `{ q, sz, cz }`: the i32 quadrant, sin(z) and cos(z).
 */
function _sincos(cg) {
  const y = cg.local.declare(cg.f32);
  const qf = cg.local.declare(cg.f32);
  const q = cg.local.declare(cg.i32);
  const z = cg.local.declare(cg.f32);
  const z2 = cg.local.declare(cg.f32);
  const sz = cg.local.declare(cg.f32);
  const cz = cg.local.declare(cg.f32);
  // y = x - nearest(x / 2π) * 2π  (qf is used as scratch here)
  cg.local.get(0);
  cg.local.get(0);
  cg.f32.const(1 / (2 * Math.PI));
  cg.f32.mul();
  cg.f32.nearest();
  cg.local.tee(qf);
  cg.f32.const(2 * Math.PI);
  cg.f32.mul();
  cg.f32.sub();
  cg.local.set(y);
  // qf = nearest(y * 2/π); q = (i32) qf
  cg.local.get(y);
  cg.f32.const(2 / Math.PI);
  cg.f32.mul();
  cg.f32.nearest();
  cg.local.tee(qf);
  // Saturating conversion: the plain trunc traps when qf is NaN.
  cg.i32.trunc_sat_f32_s();
  cg.local.set(q);
  // z = y - qf * π/2; z2 = z * z
  cg.local.get(y);
  cg.local.get(qf);
  cg.f32.const(Math.PI / 2);
  cg.f32.mul();
  cg.f32.sub();
  cg.local.tee(z);
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(z2);
  // sz = z * (1 + z2*(-1/6 + z2*(1/120 + z2*(-1/5040))))
  cg.f32.const(-1 / 5040);
  cg.local.get(z2);
  cg.f32.mul();
  cg.f32.const(1 / 120);
  cg.f32.add();
  cg.local.get(z2);
  cg.f32.mul();
  cg.f32.const(-1 / 6);
  cg.f32.add();
  cg.local.get(z2);
  cg.f32.mul();
  cg.f32.const(1);
  cg.f32.add();
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(sz);
  // cz = 1 + z2*(-1/2 + z2*(1/24 + z2*(-1/720)))
  cg.f32.const(-1 / 720);
  cg.local.get(z2);
  cg.f32.mul();
  cg.f32.const(1 / 24);
  cg.f32.add();
  cg.local.get(z2);
  cg.f32.mul();
  cg.f32.const(-1 / 2);
  cg.f32.add();
  cg.local.get(z2);
  cg.f32.mul();
  cg.f32.const(1);
  cg.f32.add();
  cg.local.set(cz);
  return {
    q,
    sz,
    cz
  };
}
2456
/**
 * Declare a Wasm function (f32) -> f32 that approximates sin(x).
 *
 * With k = q mod 4 from the shared sin/cos helper, the result is
 *   k=0: +sz, k=1: +cz, k=2: -sz, k=3: -cz
 * i.e. bit 0 of q picks the polynomial and bit 1 picks the sign.
 *
 * @returns the index of the declared function.
 */
function wasm_sin(cg) {
  const emitBody = () => {
    const parts = _sincos(cg);
    const mag = cg.local.declare(cg.f32);
    // Pushes (q & mask) as the i32 condition for the next `select`.
    const pushQuadrantBit = (mask) => {
      cg.local.get(parts.q);
      cg.i32.const(mask);
      cg.i32.and();
    };
    // mag = odd quadrant ? cz : sz
    cg.local.get(parts.cz);
    cg.local.get(parts.sz);
    pushQuadrantBit(1);
    cg.select();
    // result = (q & 2) ? -mag : mag
    cg.local.tee(mag);
    cg.f32.neg();
    cg.local.get(mag);
    pushQuadrantBit(2);
    cg.select();
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
2480
/**
 * Declare a Wasm function (f32) -> f32 that approximates cos(x).
 *
 * With k = q mod 4 from the shared sin/cos helper, the result is
 *   k=0: +cz, k=1: -sz, k=2: -cz, k=3: +sz
 * i.e. bit 0 of q picks the polynomial and bit 1 of (q + 1) picks the sign.
 *
 * @returns the index of the declared function.
 */
function wasm_cos(cg) {
  const emitBody = () => {
    const parts = _sincos(cg);
    const mag = cg.local.declare(cg.f32);
    // mag = odd quadrant ? sz : cz
    cg.local.get(parts.sz);
    cg.local.get(parts.cz);
    cg.local.get(parts.q);
    cg.i32.const(1);
    cg.i32.and();
    cg.select();
    // result = ((q + 1) & 2) ? -mag : mag
    cg.local.tee(mag);
    cg.f32.neg();
    cg.local.get(mag);
    cg.local.get(parts.q);
    cg.i32.const(1);
    cg.i32.add();
    cg.i32.const(2);
    cg.i32.and();
    cg.select();
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
2506
/**
 * Emit inline code approximating arctan of the f32 on top of the stack.
 *
 * Consumes one f32 operand and leaves one f32 result. For |x| >= 1 the
 * rational approximation is evaluated at 1/|x| and reflected through π/2;
 * the sign of the input is restored with `copysign` at the end.
 */
function _atan(cg) {
  const x = cg.local.declare(cg.f32);
  const abs_x = cg.local.declare(cg.f32);
  const z = cg.local.declare(cg.f32);
  const z2 = cg.local.declare(cg.f32);
  const p = cg.local.declare(cg.f32);

  // Horner evaluation in z^2 of the given coefficients (highest degree first).
  const polyInZ2 = ([lead, ...rest]) => {
    cg.f32.const(lead);
    for (const coeff of rest) {
      cg.local.get(z2);
      cg.f32.mul();
      cg.f32.const(coeff);
      cg.f32.add();
    }
  };
  // Pushes the i32 flag (|x| >= 1).
  const pushIsLarge = () => {
    cg.local.get(abs_x);
    cg.f32.const(1);
    cg.f32.ge();
  };

  cg.local.set(x);
  cg.local.get(x);
  cg.f32.abs();
  cg.local.set(abs_x);

  // z = |x| >= 1 ? 1/|x| : |x|
  cg.f32.const(1);
  cg.local.get(abs_x);
  cg.f32.div();
  cg.local.get(abs_x);
  pushIsLarge();
  cg.select();
  cg.local.set(z);

  cg.local.get(z);
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(z2);

  // p = z * P(z^2) / Q(z^2)
  polyInZ2([.0415796528637, .661705427875, .999998614341]);
  polyInZ2([.173698870181, .994987933645, 1]);
  cg.f32.div();
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(p);

  // result = copysign(|x| >= 1 ? π/2 - p : p, x)
  cg.f32.const(Math.PI / 2);
  cg.local.get(p);
  cg.f32.sub();
  cg.local.get(p);
  pushIsLarge();
  cg.select();
  cg.local.get(x);
  cg.f32.copysign();
}
2563
/**
 * Declare a Wasm function (f32) -> f32 that approximates atan(x).
 *
 * Method: for |x| < 1, a rational approximation atan(x) ≈ x * P(x^2) / Q(x^2)
 *   with P(u) = A0 + A1*u + A2*u^2 and Q(u) = 1 + B1*u + B2*u^2;
 *   for |x| >= 1, atan(x) = sign(x)*π/2 - atan(1/x).
 *   (fitted coefficients, max error ~5e-7 on [0,1])
 *
 * @returns the index of the declared function.
 */
function wasm_atan(cg) {
  const emitBody = () => {
    // Push the argument, then let the inline helper consume it.
    cg.local.get(0);
    _atan(cg);
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
2578
/**
 * Declare a Wasm function (f32) -> f32 that approximates asin(x).
 *
 * Method: asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
 *
 * @returns the index of the declared function.
 */
function wasm_asin(cg) {
  const pushArg = () => cg.local.get(0);
  const emitBody = () => {
    // numerator: x
    pushArg();
    // denominator: sqrt(1 - x*x) + 1
    cg.f32.const(1);
    pushArg();
    pushArg();
    cg.f32.mul();
    cg.f32.sub();
    cg.f32.sqrt();
    cg.f32.const(1);
    cg.f32.add();
    cg.f32.div();
    // 2 * atan(quotient)
    _atan(cg);
    cg.f32.const(2);
    cg.f32.mul();
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
2600
/**
 * Threefry2x32 pseudorandom number generator.
 *
 * Declares a Wasm function taking two 32-bit key words (locals 0, 1) and two
 * 32-bit counter words (locals 2, 3) and returning two 32-bit pseudorandom
 * words. The body runs five groups of four mix rounds, each followed by a
 * key injection.
 *
 * @returns the index of the declared function.
 */
function wasm_threefry2x32(cg) {
  const i32x4Params = [cg.i32, cg.i32, cg.i32, cg.i32];
  // Rotation constants, alternating between the two sets every four rounds.
  const rotations = [
    [13, 15, 26, 6],
    [17, 29, 16, 24]
  ];
  return cg.function(i32x4Params, [cg.i32, cg.i32], () => {
    const ks0 = cg.local.declare(cg.i32);
    const ks1 = cg.local.declare(cg.i32);
    const ks2 = cg.local.declare(cg.i32);
    const x0 = cg.local.declare(cg.i32);
    const x1 = cg.local.declare(cg.i32);
    const schedule = [ks0, ks1, ks2];

    // x0 += x1; x1 = rotl(x1, rot) ^ x0
    const mixRound = (rot) => {
      cg.local.get(x0);
      cg.local.get(x1);
      cg.i32.add();
      cg.local.set(x0);
      cg.local.get(x1);
      cg.i32.const(rot);
      cg.i32.rotl();
      cg.local.get(x0);
      cg.i32.xor();
      cg.local.set(x1);
    };
    // x0 += k0; x1 += k1 + round
    const injectKey = (k0, k1, round) => {
      cg.local.get(x0);
      cg.local.get(k0);
      cg.i32.add();
      cg.local.set(x0);
      cg.local.get(x1);
      cg.local.get(k1);
      cg.i32.add();
      cg.i32.const(round);
      cg.i32.add();
      cg.local.set(x1);
    };

    // Key schedule: ks0 = key0, ks1 = key1, ks2 = key0 ^ key1 ^ parity const.
    cg.local.get(0);
    cg.local.set(ks0);
    cg.local.get(1);
    cg.local.set(ks1);
    cg.local.get(0);
    cg.local.get(1);
    cg.i32.xor();
    cg.i32.const(466688986);
    cg.i32.xor();
    cg.local.set(ks2);

    // Initial whitening: x0 = ctr0 + ks0, x1 = ctr1 + ks1.
    cg.local.get(2);
    cg.local.get(ks0);
    cg.i32.add();
    cg.local.set(x0);
    cg.local.get(3);
    cg.local.get(ks1);
    cg.i32.add();
    cg.local.set(x1);

    for (let group = 0; group < 5; group++) {
      for (const rot of rotations[group % 2]) mixRound(rot);
      injectKey(schedule[(group + 1) % 3], schedule[(group + 2) % 3], group + 1);
    }

    // Results are returned as the two values left on the stack.
    cg.local.get(x0);
    cg.local.get(x1);
  });
}
2674
+
2675
+ //#endregion
2676
+ //#region src/backend/wasm/wasmblr.ts
2677
+ /**
2678
+ * @file Minimalist WebAssembly assembler. This allows you to emit WebAssembly
2679
+ * bytecode directly from the browser.
2680
+ *
2681
+ * Self-contained port of https://github.com/bwasti/wasmblr to TypeScript.
2682
+ * Some operation names in this module are written in `snake_case` to match
2683
+ * their names in the Wasm specification.
2684
+ *
2685
+ * Reference: https://pengowray.github.io/wasm-ops/.
2686
+ */
2687
/** Magic number that opens every Wasm binary: the bytes of "\0asm". */
const magicModuleHeader = [0x00, 0x61, 0x73, 0x6d];
/** Binary format version field emitted right after the magic number. */
const moduleVersion = [0x01, 0x00, 0x00, 0x00];
2699
/**
 * Throw an `Error` unless `condition` is truthy.
 *
 * @param condition value to check.
 * @param message error text; falls back to "Assertion failed" when falsy.
 */
function assert(condition, message) {
  if (condition) return;
  throw new Error(message || "Assertion failed");
}
2702
/**
 * Encode an integer as signed LEB128.
 *
 * Uses 32-bit bitwise arithmetic, so the input is treated as an i32.
 *
 * @returns the encoded bytes, least-significant group first.
 */
function encodeSigned(n) {
  const bytes = [];
  let rest = n;
  for (;;) {
    const group = rest & 127;
    rest >>= 7;
    const signBitSet = (group & 64) !== 0;
    // Stop once the remaining bits are pure sign extension of this group.
    const isLast = rest === 0 && !signBitSet || rest === -1 && signBitSet;
    bytes.push(isLast ? group : group | 128);
    if (isLast) return bytes;
  }
}
2714
/**
 * Encode an integer as unsigned LEB128.
 *
 * Uses 32-bit bitwise arithmetic, so the input is treated as a u32.
 *
 * @returns the encoded bytes, least-significant group first.
 */
function encodeUnsigned(n) {
  const bytes = [];
  let rest = n;
  for (;;) {
    const group = rest & 127;
    rest >>>= 7;
    if (rest === 0) {
      bytes.push(group);
      return bytes;
    }
    // More groups follow: set the continuation bit.
    bytes.push(group | 128);
  }
}
2724
/**
 * Encode a string as a Wasm name: a LEB128 byte count followed by the UTF-8
 * bytes.
 *
 * The count is LEB128-encoded rather than written as one raw byte, so names
 * of 128 bytes or more produce a valid length prefix.
 *
 * @param s string to encode.
 * @returns the length prefix followed by the UTF-8 bytes.
 */
function encodeString(s) {
  const bytes = new TextEncoder().encode(s);
  return [...encodeUnsigned(bytes.length), ...bytes];
}
2728
/**
 * Encode the result type list of a `block` / `loop` / `if`.
 *
 * A single type is emitted as its one-byte type id. Several types are emitted
 * as an inline function-type form: 0x60, zero parameters, then the results.
 * NOTE(review): the multi-type form is an inline signature rather than a type
 * index — confirm target engines accept it before relying on multi-value
 * blocks.
 *
 * @param type non-empty array of type descriptors with a `typeId`.
 */
function encodeBlocktype(type) {
  const count = type.length;
  assert(count > 0, "blocktype must have at least one type");
  if (count === 1) return [type[0].typeId];
  const resultIds = type.map((t) => t.typeId);
  return [96].concat(encodeUnsigned(0), encodeUnsigned(count), resultIds);
}
2738
/**
 * Encode an opcode as bytes.
 *
 * @param opcode either a single-byte opcode number, or a `[prefix, sub]` pair
 *   whose sub-opcode is written as unsigned LEB128 after the prefix byte.
 */
function encodeOpcode(opcode) {
  if (typeof opcode !== "number") {
    const prefix = opcode[0];
    const subOpcode = opcode[1];
    return [prefix, ...encodeUnsigned(subOpcode)];
  }
  return [opcode];
}
2742
/**
 * Append every element of `inp` to `out`, in place.
 *
 * Elements are pushed one at a time instead of via `out.push(...inp)`:
 * spreading passes each element as a call argument, which throws a RangeError
 * once `inp` grows past the engine's argument limit (reachable with large
 * code sections).
 *
 * @param out array that receives the elements.
 * @param inp iterable of elements to append.
 */
function concat(out, inp) {
  for (const item of inp) out.push(item);
}
2745
/**
 * Record of one function to be assembled: its signature, the callback that
 * emits its body, and the locals declared while that callback runs.
 */
var Function_ = class {
  inputTypes;
  outputTypes;
  body;
  locals = [];
  constructor(inputTypes, outputTypes, body) {
    this.inputTypes = inputTypes;
    this.outputTypes = outputTypes;
    // A missing body behaves as an empty function body.
    this.body = body ? body : () => {};
  }
  /** Forget previously declared locals and run the body callback. */
  emit() {
    this.locals = [];
    const emitBody = this.body;
    this.body === emitBody && this.body();
  }
};
2760
/**
 * Description of the module's single linear memory, plus the `memory.size`
 * and `memory.grow` instructions.
 *
 * `aString` / `bString` hold the binding: both set means the memory is
 * imported as (module, name); only `aString` set means it is exported under
 * that name.
 */
var Memory = class {
  min = 0;
  max = 0;
  isShared = false;
  aString = "";
  bString = "";
  constructor(cg) {
    this.cg = cg;
  }
  /** Declare the size of the memory. Each page is 64 KiB. */
  pages(min, max = 0) {
    // May only be configured once.
    assert(!(this.min !== 0 || this.max !== 0));
    this.min = min;
    this.max = max;
    return this;
  }
  /** Export the memory under the name `a`. */
  export(a) {
    this.#assertUnbound();
    this.aString = a;
    return this;
  }
  /** Mark the memory as shared (or not). */
  shared(isShared) {
    this.isShared = isShared;
    return this;
  }
  /** Import the memory from module `a`, field `b`. */
  import(a, b) {
    this.#assertUnbound();
    this.aString = a;
    this.bString = b;
    return this;
  }
  /** Emit `memory.size` for memory index 0. */
  size() {
    for (const byte of [63, 0]) this.cg._emit(byte);
  }
  /** Emit `memory.grow` for memory index 0. */
  grow() {
    for (const byte of [64, 0]) this.cg._emit(byte);
  }
  get isImport() {
    return this.aString.length > 0 && this.bString.length > 0;
  }
  get isExport() {
    return this.aString.length > 0 && this.bString.length === 0;
  }
  /** Reject a second import/export binding. */
  #assertUnbound() {
    assert(!this.isImport && !this.isExport, "already set");
  }
};
2806
/**
 * Public API of WebAssembly assembler.
 *
 * Functions are registered with `function()` and only assembled when
 * `finish()` runs: each body callback is invoked with this generator as the
 * active emission target, while a type stack mirrors the Wasm operand stack
 * to catch type errors early.
 */
var CodeGenerator = class {
  local;
  i32;
  f32;
  v128;
  i32x4;
  f32x4;
  memory;
  /** Pseudo-type for blocks that produce no value. */
  void = {
    typeId: 64,
    name: "void"
  };
  #functions = [];
  #importedFunctions = [];
  #exportedFunctions = /* @__PURE__ */ new Map();
  #curFunction = null;
  #curBytes = [];
  #typeStack = [];
  #blockFrames = [];
  constructor() {
    this.local = new Local(this);
    this.i32 = new I32(this);
    this.f32 = new F32(this);
    this.v128 = new V128(this);
    this.i32x4 = new I32x4(this);
    this.f32x4 = new F32x4(this);
    this.memory = new Memory(this);
  }
  unreachable() {
    this._emit(0);
  }
  nop() {
    this._emit(1);
  }
  /** Open a `block` with the given result type(s). */
  block(...type) {
    this.#blockFrames.push({
      idx: this.#typeStack.length,
      ty: type
    });
    this._emit(2);
    this._emit(encodeBlocktype(type));
  }
  /** Open a `loop` with the given result type(s). */
  loop(...type) {
    this.#blockFrames.push({
      idx: this.#typeStack.length,
      ty: type
    });
    this._emit(3);
    this._emit(encodeBlocktype(type));
  }
  /** Open an `if`, consuming the i32 condition on top of the stack. */
  if(...type) {
    assert(this._pop().typeId === this.i32.typeId, "if_: expected i32");
    this.#blockFrames.push({
      idx: this.#typeStack.length,
      ty: type
    });
    this._emit(4);
    this._emit(encodeBlocktype(type));
  }
  /** Switch to the `else` arm; the type stack is reset to the block entry. */
  else() {
    assert(this.#blockFrames.length > 0, "else: no block to else");
    const frame = this.#blockFrames[this.#blockFrames.length - 1];
    this.#typeStack = this.#typeStack.slice(0, frame.idx);
    this._emit(5);
  }
  /** End a block (`block`, `if`/`else`, `loop`, or function). */
  end() {
    const frame = this.#blockFrames.pop();
    assert(frame !== void 0, "end: no block to end");
    this.#typeStack = this.#typeStack.slice(0, frame.idx);
    for (const ty of frame.ty) if (ty.typeId !== this.void.typeId) this._push(ty);
    this._emit(11);
  }
  /** Branch to a block a certain depth outward on the stack. */
  br(depth) {
    this._emit(12);
    this._emit(encodeUnsigned(depth));
  }
  /** Conditional branch to a block a certain depth outward on the stack. */
  br_if(depth) {
    assert(this._pop().typeId === this.i32.typeId, "br_if: expected i32");
    this._emit(13);
    this._emit(encodeUnsigned(depth));
  }
  /** Jump table that indexes into a label vector (like switch). */
  br_table(...depths) {
    assert(this._pop().typeId === this.i32.typeId, "br_table: expected i32");
    assert(depths.length > 0, "br_table: expected at least one default depth");
    this._emit(14);
    // The last depth is the default label and is not counted in the vector.
    this._emit(encodeUnsigned(depths.length - 1));
    for (const d of depths) this._emit(encodeUnsigned(d));
  }
  /** Return from a function, branching out of the outermost block. */
  return() {
    this._emit(15);
  }
  /** Call a function with the given ID. */
  call(fn) {
    const totalFunctions = this.#importedFunctions.length + this.#functions.length;
    assert(fn < totalFunctions, "function index does not exist");
    const func = fn < this.#importedFunctions.length ? this.#importedFunctions[fn] : this.#functions[fn - this.#importedFunctions.length];
    // Arguments are popped last-to-first.
    for (let i = func.inputTypes.length - 1; i >= 0; i--) {
      const argType = this._pop();
      assert(argType.typeId === func.inputTypes[i].typeId, `call: argument ${i} type mismatch, expected ${func.inputTypes[i].name} got ${argType.name}`);
    }
    for (const outputType of func.outputTypes) this._push(outputType);
    this._emit(16);
    this._emit(encodeUnsigned(fn));
  }
  /** Throw away an operand on the stack. */
  drop() {
    this._pop();
    this._emit(26);
  }
  /** Select one of the first two operands (T, F) based on the third operand (i32)'s value. */
  select() {
    assert(this._pop().typeId === this.i32.typeId, "select: expected i32 condition");
    const [b, a] = [this._pop(), this._pop()];
    assert(a.typeId === b.typeId, "select: expected same type for both operands");
    this._push(a);
    this._emit(27);
  }
  /** Import a JavaScript function; returns its index. */
  importFunction(module, name, inputTypes, outputTypes) {
    // Imports occupy the lowest function indices, so they must come first.
    if (this.#functions.length > 0) throw new Error("function imports must precede defining functions");
    const idx = this.#importedFunctions.length;
    this.#importedFunctions.push({
      module,
      name,
      inputTypes,
      outputTypes
    });
    return idx;
  }
  /** Export a function. */
  export(fn, name) {
    this.#exportedFunctions.set(fn, name);
  }
  /** Declare a new function; returns its index. */
  function(inputTypes, outputTypes, body) {
    const idx = this.#importedFunctions.length + this.#functions.length;
    this.#functions.push(new Function_(inputTypes, outputTypes, body));
    return idx;
  }
  _declareLocal(type) {
    assert(this.#curFunction !== null, "No current function");
    // Local indices continue after the parameters.
    const idx = this.#curFunction.locals.length + this.#curFunction.inputTypes.length;
    this.#curFunction.locals.push(type);
    return idx;
  }
  _inputTypes() {
    assert(this.#curFunction !== null, "No current function");
    return this.#curFunction.inputTypes;
  }
  _locals() {
    assert(this.#curFunction !== null, "No current function");
    return this.#curFunction.locals;
  }
  _push(type) {
    if (!type) throw new Error(`pushing type ${type}`);
    this.#typeStack.push(type);
  }
  _pop() {
    assert(this.#typeStack.length > 0, "popping empty stack");
    return this.#typeStack.pop();
  }
  _emit(bytes) {
    if (typeof bytes === "number") this.#curBytes.push(bytes);
    else this.#curBytes.push(...bytes);
  }
  /**
   * Encode the memory limits (flag byte, min, optional max).
   *
   * A maximum is present whenever `max > 0` (0 is the "no maximum" default
   * of `Memory.pages`). Testing `max` alone, rather than `min && max`, keeps
   * the maximum for memories that start with zero pages.
   */
  #memoryLimitBytes() {
    const { min, max, isShared } = this.memory;
    if (max > 0) return [
      isShared ? 3 : 1,
      ...encodeUnsigned(min),
      ...encodeUnsigned(max)
    ];
    assert(!isShared, "shared memory must have a max size");
    return [0, ...encodeUnsigned(min)];
  }
  /**
   * Assemble the module: runs every function body and returns the binary.
   *
   * Sections are emitted in order: type, import (if any), function, memory
   * (if declared and not imported), export, code.
   *
   * @returns a `Uint8Array` holding the module bytes.
   */
  finish() {
    this.#curBytes = [];
    const emittedBytes = [];
    concat(emittedBytes, magicModuleHeader);
    concat(emittedBytes, moduleVersion);
    // Type section: one signature per function, imports first, so a
    // function's type index equals its function index.
    const typeSectionBytes = [];
    const totalFunctionTypes = this.#importedFunctions.length + this.#functions.length;
    concat(typeSectionBytes, encodeUnsigned(totalFunctionTypes));
    for (const f of [...this.#importedFunctions, ...this.#functions]) {
      typeSectionBytes.push(96);
      concat(typeSectionBytes, encodeUnsigned(f.inputTypes.length));
      for (const t of f.inputTypes) typeSectionBytes.push(t.typeId);
      concat(typeSectionBytes, encodeUnsigned(f.outputTypes.length));
      for (const t of f.outputTypes) typeSectionBytes.push(t.typeId);
    }
    emittedBytes.push(1);
    concat(emittedBytes, encodeUnsigned(typeSectionBytes.length));
    concat(emittedBytes, typeSectionBytes);
    // Import section.
    const importSectionBytes = [];
    const numImports = this.#importedFunctions.length + (this.memory.isImport ? 1 : 0);
    if (numImports > 0) {
      concat(importSectionBytes, encodeUnsigned(numImports));
      for (let i = 0; i < this.#importedFunctions.length; i++) {
        const f = this.#importedFunctions[i];
        concat(importSectionBytes, encodeString(f.module));
        concat(importSectionBytes, encodeString(f.name));
        importSectionBytes.push(0);
        concat(importSectionBytes, encodeUnsigned(i));
      }
      if (this.memory.isImport) {
        concat(importSectionBytes, encodeString(this.memory.aString));
        concat(importSectionBytes, encodeString(this.memory.bString));
        importSectionBytes.push(2);
        concat(importSectionBytes, this.#memoryLimitBytes());
      }
      emittedBytes.push(2);
      concat(emittedBytes, encodeUnsigned(importSectionBytes.length));
      concat(emittedBytes, importSectionBytes);
    }
    // Function section: maps each defined function to its type index.
    const functionSectionBytes = [];
    concat(functionSectionBytes, encodeUnsigned(this.#functions.length));
    for (let i = 0; i < this.#functions.length; i++) {
      const typeIndex = this.#importedFunctions.length + i;
      concat(functionSectionBytes, encodeUnsigned(typeIndex));
    }
    emittedBytes.push(3);
    concat(emittedBytes, encodeUnsigned(functionSectionBytes.length));
    concat(emittedBytes, functionSectionBytes);
    // Memory section, only for a locally defined memory.
    const memorySectionBytes = [];
    if (!this.memory.isImport && (this.memory.min || this.memory.max)) {
      memorySectionBytes.push(1);
      concat(memorySectionBytes, this.#memoryLimitBytes());
      emittedBytes.push(5);
      concat(emittedBytes, encodeUnsigned(memorySectionBytes.length));
      concat(emittedBytes, memorySectionBytes);
    }
    // Export section.
    const exportSectionBytes = [];
    const numExports = this.#exportedFunctions.size + (this.memory.isExport ? 1 : 0);
    concat(exportSectionBytes, encodeUnsigned(numExports));
    if (this.memory.isExport) {
      concat(exportSectionBytes, encodeString(this.memory.aString));
      exportSectionBytes.push(2);
      exportSectionBytes.push(0);
    }
    for (const [key, name] of this.#exportedFunctions.entries()) {
      concat(exportSectionBytes, encodeString(name));
      exportSectionBytes.push(0);
      concat(exportSectionBytes, encodeUnsigned(key));
    }
    emittedBytes.push(7);
    concat(emittedBytes, encodeUnsigned(exportSectionBytes.length));
    concat(emittedBytes, exportSectionBytes);
    // Code section: run each body callback with fresh emission state.
    const codeSectionBytes = [];
    concat(codeSectionBytes, encodeUnsigned(this.#functions.length));
    for (const f of this.#functions) {
      this.#typeStack = [];
      // The function body is the outermost block frame.
      this.#blockFrames = [{
        idx: 0,
        ty: f.outputTypes
      }];
      this.#curFunction = f;
      this.#curBytes = [];
      f.emit();
      this.end();
      const bodyBytes = [...this.#curBytes];
      // Locals are only known after the body ran; each is written as its own
      // group of count 1.
      this.#curBytes = [];
      concat(this.#curBytes, encodeUnsigned(f.locals.length));
      for (const l of f.locals) {
        this._emit(1);
        this._emit(l.typeId);
      }
      const headerBytes = [...this.#curBytes];
      const fnSize = headerBytes.length + bodyBytes.length;
      concat(codeSectionBytes, encodeUnsigned(fnSize));
      concat(codeSectionBytes, headerBytes);
      concat(codeSectionBytes, bodyBytes);
    }
    this.#curFunction = null;
    emittedBytes.push(10);
    concat(emittedBytes, encodeUnsigned(codeSectionBytes.length));
    concat(emittedBytes, codeSectionBytes);
    return new Uint8Array(emittedBytes);
  }
};
3099
/**
 * `local.get` / `local.set` / `local.tee` instructions and local declaration.
 *
 * Indices below the parameter count refer to function parameters; higher
 * indices refer to declared locals.
 */
var Local = class {
  constructor(cg) {
    this.cg = cg;
  }
  /** Declare a new local of the given type; returns its index. */
  declare(type) {
    return this.cg._declareLocal(type);
  }
  /** Push the value of local `idx`. */
  get(idx) {
    assert(Number.isInteger(idx), "getting non-integer local");
    const params = this.cg._inputTypes();
    const type = idx < params.length ? params[idx] : this.cg._locals()[idx - params.length];
    this.cg._push(type);
    this.cg._emit(32);
    this.cg._emit(encodeUnsigned(idx));
  }
  /** Pop a value into local `idx`. */
  set(idx) {
    this.#store(idx, 33, "set");
  }
  /** Store the top of the stack into local `idx`, leaving it on the stack. */
  tee(idx) {
    const type = this.#store(idx, 34, "tee");
    this.cg._push(type);
  }
  /**
   * Shared part of set/tee: pop the value, check it against the local's
   * declared type, and emit `opcode idx`. Returns the local's type.
   */
  #store(idx, opcode, verb) {
    const actual = this.cg._pop();
    const params = this.cg._inputTypes();
    const expected = idx < params.length ? params[idx] : this.cg._locals()[idx - params.length];
    assert(expected.typeId === actual.typeId, `can't ${verb} local to this value (wrong type)`);
    this.cg._emit(opcode);
    this.cg._emit(encodeUnsigned(idx));
    return expected;
  }
};
3132
/**
 * Build an instruction method for a one-operand op.
 *
 * The returned function must be called as a method of an object with a `cg`
 * property: it pops one operand of type `inType`, emits `opcode`, and pushes
 * a result of type `outType`.
 */
function UNARY_OP(op, opcode, inType, outType) {
  const mismatch = `invalid type for ${op} (${inType} -> ${outType})`;
  return function() {
    const { cg } = this;
    const operand = cg._pop();
    assert(operand.typeId === cg[inType].typeId, mismatch);
    cg._emit(encodeOpcode(opcode));
    cg._push(cg[outType]);
  };
}
3140
/**
 * Build an instruction method for a two-operand op.
 *
 * The returned function must be called as a method of an object with a `cg`
 * property: it pops the right operand (`typeB`) then the left (`typeA`),
 * emits `opcode`, and pushes a result of type `outType`.
 */
function BINARY_OP(op, opcode, typeA, typeB, outType) {
  const mismatch = `invalid type for ${op} (${typeA}, ${typeB} -> ${outType})`;
  return function() {
    const { cg } = this;
    const rhs = cg._pop();
    const lhs = cg._pop();
    const lhsOk = lhs.typeId === cg[typeA].typeId;
    assert(lhsOk && rhs.typeId === cg[typeB].typeId, mismatch);
    cg._emit(encodeOpcode(opcode));
    cg._push(cg[outType]);
  };
}
3149
/**
 * Build an instruction method for a scalar memory load.
 *
 * The returned method pops an i32 address, emits `opcode` followed by the
 * `align` and `offset` immediates, and pushes a value of type `outType`.
 */
function LOAD_OP(op, opcode, outType) {
  return function(align = 0, offset = 0) {
    const { cg } = this;
    const address = cg._pop();
    assert(address.typeId === cg.i32.typeId, `invalid type for ${op}`);
    for (const chunk of [encodeOpcode(opcode), encodeUnsigned(align), encodeUnsigned(offset)]) cg._emit(chunk);
    cg._push(cg[outType]);
  };
}
3159
/**
 * Build an instruction method for a scalar memory store.
 *
 * The returned method pops the value (`inType`) and then the i32 address,
 * and emits `opcode` followed by the `align` and `offset` immediates.
 * Nothing is pushed.
 */
function STORE_OP(op, opcode, inType) {
  return function(align = 0, offset = 0) {
    const { cg } = this;
    const value = cg._pop();
    const address = cg._pop();
    assert(value.typeId === cg[inType].typeId, `invalid value type for ${op} (${inType})`);
    assert(address.typeId === cg.i32.typeId, `invalid type for ${op}`);
    for (const chunk of [encodeOpcode(opcode), encodeUnsigned(align), encodeUnsigned(offset)]) cg._emit(chunk);
  };
}
3170
/**
 * The `i32` value type and its instructions.
 *
 * Each instruction is an instance field built by one of the op factories; the
 * numbers are the Wasm opcodes (a `[prefix, sub]` pair for prefixed ops).
 */
var I32 = class {
  constructor(cg) {
    this.cg = cg;
  }
  get typeId() {
    return 127;
  }
  get name() {
    return "i32";
  }
  /** Push the constant `i`, encoded as signed LEB128. */
  const(i) {
    this.cg._emit(65);
    this.cg._emit(encodeSigned(i));
    this.cg._push(this);
  }
  clz = UNARY_OP("clz", 103, "i32", "i32");
  ctz = UNARY_OP("ctz", 104, "i32", "i32");
  popcnt = UNARY_OP("popcnt", 105, "i32", "i32");
  lt_s = BINARY_OP("lt_s", 72, "i32", "i32", "i32");
  lt_u = BINARY_OP("lt_u", 73, "i32", "i32", "i32");
  gt_s = BINARY_OP("gt_s", 74, "i32", "i32", "i32");
  gt_u = BINARY_OP("gt_u", 75, "i32", "i32", "i32");
  le_s = BINARY_OP("le_s", 76, "i32", "i32", "i32");
  le_u = BINARY_OP("le_u", 77, "i32", "i32", "i32");
  ge_s = BINARY_OP("ge_s", 78, "i32", "i32", "i32");
  ge_u = BINARY_OP("ge_u", 79, "i32", "i32", "i32");
  add = BINARY_OP("add", 106, "i32", "i32", "i32");
  sub = BINARY_OP("sub", 107, "i32", "i32", "i32");
  mul = BINARY_OP("mul", 108, "i32", "i32", "i32");
  div_s = BINARY_OP("div_s", 109, "i32", "i32", "i32");
  div_u = BINARY_OP("div_u", 110, "i32", "i32", "i32");
  rem_s = BINARY_OP("rem_s", 111, "i32", "i32", "i32");
  rem_u = BINARY_OP("rem_u", 112, "i32", "i32", "i32");
  and = BINARY_OP("and", 113, "i32", "i32", "i32");
  or = BINARY_OP("or", 114, "i32", "i32", "i32");
  xor = BINARY_OP("xor", 115, "i32", "i32", "i32");
  shl = BINARY_OP("shl", 116, "i32", "i32", "i32");
  shr_s = BINARY_OP("shr_s", 117, "i32", "i32", "i32");
  shr_u = BINARY_OP("shr_u", 118, "i32", "i32", "i32");
  rotl = BINARY_OP("rotl", 119, "i32", "i32", "i32");
  rotr = BINARY_OP("rotr", 120, "i32", "i32", "i32");
  // `i32.eqz` takes a single operand (it tests for zero), so it must be a
  // unary op; treating it as binary would pop one value too many from the
  // type stack.
  eqz = UNARY_OP("eqz", 69, "i32", "i32");
  eq = BINARY_OP("eq", 70, "i32", "i32", "i32");
  ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
  trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
  trunc_f32_u = UNARY_OP("trunc_f32_u", 169, "f32", "i32");
  load = LOAD_OP("load", 40, "i32");
  load8_s = LOAD_OP("load8_s", 44, "i32");
  load8_u = LOAD_OP("load8_u", 45, "i32");
  load16_s = LOAD_OP("load16_s", 46, "i32");
  load16_u = LOAD_OP("load16_u", 47, "i32");
  store = STORE_OP("store", 54, "i32");
  store8 = STORE_OP("store8", 58, "i32");
  store16 = STORE_OP("store16", 59, "i32");
  reinterpret_f32 = UNARY_OP("reinterpret_f32", 188, "f32", "i32");
  trunc_sat_f32_s = UNARY_OP("trunc_sat_f32_s", [252, 0], "f32", "i32");
  trunc_sat_f32_u = UNARY_OP("trunc_sat_f32_u", [252, 1], "f32", "i32");
};
3228
/**
 * The `f32` value type and its instructions.
 *
 * Each instruction is an instance field built by one of the op factories; the
 * numbers are the Wasm opcodes.
 */
var F32 = class {
  constructor(cg) {
    this.cg = cg;
  }
  get typeId() {
    return 125;
  }
  get name() {
    return "f32";
  }
  /** Push the constant `f`, written as 4 little-endian IEEE-754 bytes. */
  const(f) {
    this.cg._emit(67);
    const view = new DataView(new ArrayBuffer(4));
    view.setFloat32(0, f, true);
    for (let offset = 0; offset < 4; offset++) this.cg._emit(view.getUint8(offset));
    this.cg._push(this);
  }

  // Comparisons (produce i32).
  eq = BINARY_OP("eq", 91, "f32", "f32", "i32");
  ne = BINARY_OP("ne", 92, "f32", "f32", "i32");
  lt = BINARY_OP("lt", 93, "f32", "f32", "i32");
  gt = BINARY_OP("gt", 94, "f32", "f32", "i32");
  le = BINARY_OP("le", 95, "f32", "f32", "i32");
  ge = BINARY_OP("ge", 96, "f32", "f32", "i32");

  // Unary arithmetic.
  abs = UNARY_OP("abs", 139, "f32", "f32");
  neg = UNARY_OP("neg", 140, "f32", "f32");
  ceil = UNARY_OP("ceil", 141, "f32", "f32");
  floor = UNARY_OP("floor", 142, "f32", "f32");
  trunc = UNARY_OP("trunc", 143, "f32", "f32");
  nearest = UNARY_OP("nearest", 144, "f32", "f32");
  sqrt = UNARY_OP("sqrt", 145, "f32", "f32");

  // Binary arithmetic.
  add = BINARY_OP("add", 146, "f32", "f32", "f32");
  sub = BINARY_OP("sub", 147, "f32", "f32", "f32");
  mul = BINARY_OP("mul", 148, "f32", "f32", "f32");
  div = BINARY_OP("div", 149, "f32", "f32", "f32");
  min = BINARY_OP("min", 150, "f32", "f32", "f32");
  max = BINARY_OP("max", 151, "f32", "f32", "f32");
  copysign = BINARY_OP("copysign", 152, "f32", "f32", "f32");

  // Conversions from i32.
  convert_i32_s = UNARY_OP("convert_i32_s", 178, "i32", "f32");
  convert_i32_u = UNARY_OP("convert_i32_u", 179, "i32", "f32");

  // Memory access.
  load = LOAD_OP("load", 42, "f32");
  store = STORE_OP("store", 56, "f32");

  // Bit-level reinterpretation of an i32.
  reinterpret_i32 = UNARY_OP("reinterpret_i32", 190, "i32", "f32");
};
3272
/**
 * Build an instruction method for a SIMD op (prefix byte 253).
 *
 * The returned method pops one operand per entry of `inTypes`, last entry
 * first, checking each type; then emits the prefixed opcode and pushes a
 * value of type `outType`.
 */
function VECTOR_OP(op, vopcode, inTypes, outType) {
  return function() {
    const { cg } = this;
    for (let i = inTypes.length - 1; i >= 0; i--) {
      const expected = inTypes[i];
      const actual = cg._pop();
      assert(actual.typeId === cg[expected].typeId, `invalid type for ${op} (${inTypes.join(", ")} -> ${outType})`);
    }
    cg._emit(encodeOpcode([253, vopcode]));
    cg._push(cg[outType]);
  };
}
3282
/**
 * Build an instruction method for a SIMD op that takes a lane immediate.
 *
 * Same operand handling as the plain vector-op factory, except that the
 * `lane` argument is emitted as one extra byte after the prefixed opcode.
 */
function VECTOR_OPL(op, vopcode, inTypes, outType) {
  return function(lane) {
    const { cg } = this;
    for (let i = inTypes.length - 1; i >= 0; i--) {
      const expected = inTypes[i];
      const actual = cg._pop();
      assert(actual.typeId === cg[expected].typeId, `invalid type for ${op} (${inTypes} -> ${outType})`);
    }
    cg._emit(encodeOpcode([253, vopcode]));
    cg._emit(lane);
    cg._push(cg[outType]);
  };
}
3293
/**
 * Build an instruction method for a SIMD memory load (prefix byte 253).
 *
 * The returned method pops an i32 address, emits the prefixed opcode followed
 * by the `align` and `offset` immediates, and pushes a v128.
 */
function VECTOR_LOAD_OP(op, vopcode) {
  return function(align = 0, offset = 0) {
    const { cg } = this;
    const address = cg._pop();
    assert(address.typeId === cg.i32.typeId, `invalid type for ${op}`);
    for (const chunk of [encodeOpcode([253, vopcode]), encodeUnsigned(align), encodeUnsigned(offset)]) cg._emit(chunk);
    cg._push(cg.v128);
  };
}
3303
/**
 * The `v128` SIMD value type with its lane-agnostic instructions: memory
 * access and bitwise logic.
 */
var V128 = class {
  constructor(cg) {
    this.cg = cg;
  }
  get typeId() {
    return 123;
  }
  get name() {
    return "v128";
  }

  // Memory loads.
  load = VECTOR_LOAD_OP("load", 0);
  load32x2_s = VECTOR_LOAD_OP("load32x2_s", 5);
  load32x2_u = VECTOR_LOAD_OP("load32x2_u", 6);
  load32_splat = VECTOR_LOAD_OP("load32_splat", 9);
  load32_zero = VECTOR_LOAD_OP("load32_zero", 92);

  /** Pop a v128 value and an i32 address, and emit `v128.store`. */
  store(align = 0, offset = 0) {
    const { cg } = this;
    const value = cg._pop();
    assert(value.typeId === cg.v128.typeId, `invalid type for store`);
    const address = cg._pop();
    assert(address.typeId === cg.i32.typeId, `invalid type for store`);
    cg._emit(253);
    for (const immediate of [11, align, offset]) cg._emit(encodeUnsigned(immediate));
  }

  // Bitwise logic.
  not = VECTOR_OP("not", 77, ["v128"], "v128");
  and = VECTOR_OP("and", 78, ["v128", "v128"], "v128");
  andnot = VECTOR_OP("andnot", 79, ["v128", "v128"], "v128");
  or = VECTOR_OP("or", 80, ["v128", "v128"], "v128");
  xor = VECTOR_OP("xor", 81, ["v128", "v128"], "v128");
  bitselect = VECTOR_OP("bitselect", 82, ["v128", "v128", "v128"], "v128");
  any_true = VECTOR_OP("any_true", 83, ["v128"], "i32");
};
3340
/**
 * Instruction emitters for the `i32x4` interpretation of a 128-bit vector.
 *
 * Each field is built by `VECTOR_OP` / `VECTOR_OPL` from a mnemonic, the SIMD
 * sub-opcode (written in hex here), the operand types consumed from the stack
 * and the result type pushed back.
 */
var I32x4 = class extends V128 {
	// Lane construction and access.
	splat = VECTOR_OP("splat", 0x11, ["i32"], "v128");
	extract_lane = VECTOR_OPL("extract_lane", 0x1b, ["v128"], "i32");
	replace_lane = VECTOR_OPL("replace_lane", 0x1c, ["v128", "i32"], "v128");

	// Lane-wise comparisons.
	eq = VECTOR_OP("eq", 0x37, ["v128", "v128"], "v128");
	ne = VECTOR_OP("ne", 0x38, ["v128", "v128"], "v128");
	lt_s = VECTOR_OP("lt_s", 0x39, ["v128", "v128"], "v128");
	lt_u = VECTOR_OP("lt_u", 0x3a, ["v128", "v128"], "v128");
	gt_s = VECTOR_OP("gt_s", 0x3b, ["v128", "v128"], "v128");
	gt_u = VECTOR_OP("gt_u", 0x3c, ["v128", "v128"], "v128");
	le_s = VECTOR_OP("le_s", 0x3d, ["v128", "v128"], "v128");
	le_u = VECTOR_OP("le_u", 0x3e, ["v128", "v128"], "v128");
	ge_s = VECTOR_OP("ge_s", 0x3f, ["v128", "v128"], "v128");
	ge_u = VECTOR_OP("ge_u", 0x40, ["v128", "v128"], "v128");

	// Unary operations and reductions to a scalar i32.
	abs = VECTOR_OP("abs", 0xa0, ["v128"], "v128");
	neg = VECTOR_OP("neg", 0xa1, ["v128"], "v128");
	all_true = VECTOR_OP("all_true", 0xa3, ["v128"], "i32");
	bitmask = VECTOR_OP("bitmask", 0xa4, ["v128"], "i32");

	// Shifts take the shift amount as a scalar i32.
	shl = VECTOR_OP("shl", 0xab, ["v128", "i32"], "v128");
	shr_s = VECTOR_OP("shr_s", 0xac, ["v128", "i32"], "v128");
	shr_u = VECTOR_OP("shr_u", 0xad, ["v128", "i32"], "v128");

	// Lane-wise arithmetic.
	add = VECTOR_OP("add", 0xae, ["v128", "v128"], "v128");
	sub = VECTOR_OP("sub", 0xb1, ["v128", "v128"], "v128");
	mul = VECTOR_OP("mul", 0xb5, ["v128", "v128"], "v128");
	min_s = VECTOR_OP("min_s", 0xb6, ["v128", "v128"], "v128");
	min_u = VECTOR_OP("min_u", 0xb7, ["v128", "v128"], "v128");
	max_s = VECTOR_OP("max_s", 0xb8, ["v128", "v128"], "v128");
	max_u = VECTOR_OP("max_u", 0xb9, ["v128", "v128"], "v128");
};
3369
/**
 * Instruction emitters for the `f32x4` interpretation of a 128-bit vector.
 *
 * Each field is built by `VECTOR_OP` / `VECTOR_OPL` from a mnemonic, the SIMD
 * sub-opcode (written in hex here), the operand types consumed from the stack
 * and the result type pushed back.
 */
var F32x4 = class extends V128 {
	// Lane construction and access.
	splat = VECTOR_OP("splat", 0x13, ["f32"], "v128");
	extract_lane = VECTOR_OPL("extract_lane", 0x1f, ["v128"], "f32");
	replace_lane = VECTOR_OPL("replace_lane", 0x20, ["v128", "f32"], "v128");

	// Lane-wise comparisons.
	eq = VECTOR_OP("eq", 0x41, ["v128", "v128"], "v128");
	ne = VECTOR_OP("ne", 0x42, ["v128", "v128"], "v128");
	lt = VECTOR_OP("lt", 0x43, ["v128", "v128"], "v128");
	gt = VECTOR_OP("gt", 0x44, ["v128", "v128"], "v128");
	le = VECTOR_OP("le", 0x45, ["v128", "v128"], "v128");
	ge = VECTOR_OP("ge", 0x46, ["v128", "v128"], "v128");

	// Rounding.
	ceil = VECTOR_OP("ceil", 0x67, ["v128"], "v128");
	floor = VECTOR_OP("floor", 0x68, ["v128"], "v128");
	trunc = VECTOR_OP("trunc", 0x69, ["v128"], "v128");
	nearest = VECTOR_OP("nearest", 0x6a, ["v128"], "v128");

	// Unary arithmetic.
	abs = VECTOR_OP("abs", 0xe0, ["v128"], "v128");
	neg = VECTOR_OP("neg", 0xe1, ["v128"], "v128");
	sqrt = VECTOR_OP("sqrt", 0xe3, ["v128"], "v128");

	// Binary arithmetic.
	add = VECTOR_OP("add", 0xe4, ["v128", "v128"], "v128");
	sub = VECTOR_OP("sub", 0xe5, ["v128", "v128"], "v128");
	mul = VECTOR_OP("mul", 0xe6, ["v128", "v128"], "v128");
	div = VECTOR_OP("div", 0xe7, ["v128", "v128"], "v128");
	min = VECTOR_OP("min", 0xe8, ["v128", "v128"], "v128");
	max = VECTOR_OP("max", 0xe9, ["v128", "v128"], "v128");
	pmin = VECTOR_OP("pmin", 0xea, ["v128", "v128"], "v128");
	pmax = VECTOR_OP("pmax", 0xeb, ["v128", "v128"], "v128");
};
3395
+
3396
+ //#endregion
3397
+ //#region src/backend/wasm.ts
3398
/**
 * Backend that compiles into WebAssembly bytecode for immediate execution.
 *
 * All buffers live inside one shared `WebAssembly.Memory`. Callers never see
 * raw pointers: every allocation is identified by an integer slot (starting at
 * 1) that maps to `{ ptr, size, ref }`, where `ref` is a reference count.
 */
var WasmBackend = class {
	type = "wasm";
	maxArgs = 64;
	#memory;
	#nextSlot;
	#allocator;
	#buffers;
	constructor() {
		this.#memory = new WebAssembly.Memory({ initial: 0 });
		this.#allocator = new WasmAllocator(this.#memory);
		this.#nextSlot = 1;
		this.#buffers = /* @__PURE__ */ new Map();
	}
	/**
	 * Allocate `size` bytes and return the slot that identifies them.
	 *
	 * @param size Number of bytes to allocate.
	 * @param initialData Optional bytes copied into the new buffer; its
	 *   `byteLength` must equal `size`.
	 * @returns The new slot, created with a reference count of 1.
	 * @throws {Error} If `initialData` does not match `size`.
	 */
	malloc(size, initialData) {
		// Validate before allocating so a mismatch cannot leak the allocation.
		if (initialData && initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
		const ptr = this.#allocator.malloc(size);
		if (initialData) new Uint8Array(this.#memory.buffer, ptr, size).set(initialData);
		const slot = this.#nextSlot++;
		this.#buffers.set(slot, {
			ptr,
			size,
			ref: 1
		});
		return slot;
	}
	/**
	 * Increment the reference count of `slot`.
	 * @throws {SlotError} If the slot is unknown or already freed.
	 */
	incRef(slot) {
		this.#entry(slot).ref++;
	}
	/**
	 * Decrement the reference count of `slot`, freeing its memory and
	 * forgetting the slot once the count reaches zero.
	 * @throws {SlotError} If the slot is unknown or already freed.
	 */
	decRef(slot) {
		const buffer = this.#entry(slot);
		buffer.ref--;
		if (buffer.ref === 0) {
			this.#allocator.free(buffer.ptr);
			this.#buffers.delete(slot);
		}
	}
	/** Async wrapper around `readSync()`. */
	async read(slot, start, count) {
		return this.readSync(slot, start, count);
	}
	/**
	 * Copy bytes out of a buffer.
	 *
	 * @param slot Slot to read from.
	 * @param start Byte offset into the buffer; defaults to 0.
	 * @param count Number of bytes; defaults to everything after `start`.
	 * @returns A fresh `Uint8Array` copy (made with `slice`).
	 * @throws {SlotError} If the slot is unknown or already freed.
	 */
	readSync(slot, start, count) {
		const buffer = this.#getBuffer(slot);
		if (start === void 0) start = 0;
		if (count === void 0) count = buffer.byteLength - start;
		return buffer.slice(start, start + count);
	}
	/** Async wrapper around `prepareSync()`. */
	async prepare(kernel) {
		return this.prepareSync(kernel);
	}
	/**
	 * Generate bytecode for `kernel`, compile it synchronously and wrap the
	 * compiled module in an `Executable`.
	 */
	prepareSync(kernel) {
		const bytes = codegenWasm(kernel);
		const module = new WebAssembly.Module(bytes);
		return new Executable(kernel, { module });
	}
	/**
	 * Run a prepared executable.
	 *
	 * The module is instantiated against this backend's memory and its
	 * exported `kernel` function is called with the pointers of the input
	 * slots followed by the pointers of the output slots.
	 *
	 * @throws {SlotError} If any input or output slot is unknown or freed.
	 */
	dispatch(exe, inputs, outputs) {
		// Resolve every slot first so a bad slot fails with SlotError before
		// any instantiation work is done.
		const ptrs = [...inputs, ...outputs].map((slot) => this.#entry(slot).ptr);
		const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
		const func = instance.exports.kernel;
		func(...ptrs);
	}
	/** Look up the bookkeeping record of a slot, or throw `SlotError`. */
	#entry(slot) {
		const buffer = this.#buffers.get(slot);
		if (!buffer) throw new SlotError(slot);
		return buffer;
	}
	/** Return a `Uint8Array` view (not a copy) over the bytes of a slot. */
	#getBuffer(slot) {
		const buffer = this.#entry(slot);
		return new Uint8Array(this.#memory.buffer, buffer.ptr, buffer.size);
	}
};
3469
/**
 * Generate the WebAssembly module for a kernel.
 *
 * The module imports its memory as `env.memory` and exports one function,
 * `kernel`, declared with `kernel.nargs + 1` i32 parameters. The parameter at
 * index `kernel.nargs` is used as the base address for stores. The function
 * body is a counted loop over `kernel.size` elements; when the kernel has a
 * reduction, each element runs an inner counted loop over `re.size` that folds
 * into an accumulator, followed by the reduction's epilogue expression.
 *
 * @param kernel Kernel to compile.
 * @returns Whatever `CodeGenerator#finish()` returns (the module bytes passed
 *   to `WebAssembly.Module` by the caller).
 */
function codegenWasm(kernel) {
	const tune = tuneNullopt(kernel);
	const re = kernel.reduction;
	if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
	const cg = new CodeGenerator();
	cg.memory.import("env", "memory");

	// Register helper functions only for the operations that actually occur.
	// The table order fixes the order in which helpers are added to the module.
	const distinctOps = union(tune.exp.distinctOps(), re?.epilogue.distinctOps());
	const funcs = {};
	const helperBuilders = [
		[AluOp.Sin, "sin", wasm_sin],
		[AluOp.Cos, "cos", wasm_cos],
		[AluOp.Asin, "asin", wasm_asin],
		[AluOp.Atan, "atan", wasm_atan],
		[AluOp.Exp, "exp", wasm_exp],
		[AluOp.Log, "log", wasm_log],
		[AluOp.Threefry2x32, "threefry2x32", wasm_threefry2x32]
	];
	for (const [aluOp, key, build] of helperBuilders) {
		if (distinctOps.has(aluOp)) funcs[key] = build(cg);
	}

	/** Open `loop { block { if (counter >= limit) break; ...`. */
	const beginCountedLoop = (counter, limit) => {
		cg.loop(cg.void);
		cg.block(cg.void);
		cg.local.get(counter);
		cg.i32.const(limit);
		cg.i32.ge_u();
		cg.br_if(0);
	};

	/** Emit `counter += 1; continue;` and close the block and the loop. */
	const endCountedLoop = (counter) => {
		cg.local.get(counter);
		cg.i32.const(1);
		cg.i32.add();
		cg.local.set(counter);
		cg.br(1);
		cg.end();
		cg.end();
	};

	/**
	 * Combine the value on top of the stack with the accumulator, leaving the
	 * combined value on the stack.
	 */
	const emitReduceStep = (acc) => {
		if (re.op === AluOp.Add) {
			cg.local.get(acc);
			if (re.dtype === DType.Bool) cg.i32.or();
			else dty(cg, re.op, re.dtype).add();
			return;
		}
		if (re.op === AluOp.Mul) {
			cg.local.get(acc);
			if (re.dtype === DType.Bool) cg.i32.and();
			else dty(cg, re.op, re.dtype).mul();
			return;
		}
		if (re.op !== AluOp.Min && re.op !== AluOp.Max) {
			throw new Error(`invalid wasm reduction op: ${re.op}`);
		}
		const wantMin = re.op === AluOp.Min;
		if (re.dtype === DType.Float32) {
			cg.local.get(acc);
			if (wantMin) cg.f32.min();
			else cg.f32.max();
			return;
		}
		const isIntegerLike = [
			DType.Int32,
			DType.Uint32,
			DType.Bool
		].includes(re.dtype);
		if (!isIntegerLike) throw new Error(`invalid reduction min/max over ${re.dtype}`);
		// Integers have no min/max instruction: compare and select instead.
		const value = cg.local.declare(cg.i32);
		cg.local.tee(value);
		cg.local.get(acc);
		cg.local.get(value);
		cg.local.get(acc);
		const signed = re.dtype === DType.Int32;
		if (wantMin) {
			if (signed) cg.i32.lt_s();
			else cg.i32.lt_u();
		} else if (signed) cg.i32.gt_s();
		else cg.i32.gt_u();
		cg.select();
	};

	const kernelFunc = cg.function(rep(kernel.nargs + 1, cg.i32), [], () => {
		const gidx = cg.local.declare(cg.i32);
		beginCountedLoop(gidx, kernel.size);

		// Destination address: output base + gidx * element width.
		cg.local.get(kernel.nargs);
		cg.local.get(gidx);
		cg.i32.const(byteWidth(kernel.dtype));
		cg.i32.mul();
		cg.i32.add();

		if (!re) {
			translateExp(cg, funcs, tune.exp, { gidx });
		} else {
			// Accumulator starts at the reduction identity.
			const acc = cg.local.declare(dty(cg, null, kernel.exp.dtype));
			dty(cg, null, kernel.exp.dtype).const(re.identity);
			cg.local.set(acc);

			// The reduction index must be reset for every output element.
			const ridx = cg.local.declare(cg.i32);
			cg.i32.const(0);
			cg.local.set(ridx);

			beginCountedLoop(ridx, re.size);
			translateExp(cg, funcs, tune.exp, {
				gidx,
				ridx
			});
			emitReduceStep(acc);
			cg.local.set(acc);
			endCountedLoop(ridx);

			translateExp(cg, funcs, kernel.reduction.epilogue, { acc });
		}

		// Store the computed value; alignment is log2 of the element width.
		dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
		endCountedLoop(gidx);
	});
	cg.export(kernelFunc, "kernel");
	return cg.finish();
}
3565
/**
 * Emit WebAssembly instructions that evaluate `exp`, leaving its value on the
 * operand stack.
 *
 * Sub-expressions reachable through more than one edge are evaluated once:
 * after the first evaluation the value is tee'd into a fresh local, and later
 * uses read that local. `Const`, `Special` and `Variable` nodes are always
 * re-emitted directly.
 *
 * @param cg Code generator receiving the instructions.
 * @param funcs Helper function handles keyed by name (`sin`, `cos`, `asin`,
 *   `atan`, `exp`, `log`, `threefry2x32`); only those needed must be present.
 * @param exp Root expression to translate.
 * @param ctx Maps variable / special names to the locals that hold them.
 * @throws {UnsupportedOpError} For an op / dtype combination with no lowering.
 */
function translateExp(cg, funcs, exp, ctx) {
	const references = /* @__PURE__ */ new Map();
	const seen = /* @__PURE__ */ new Set();
	// Count incoming edges per node; children are only walked on first visit.
	const countReferences = (node) => {
		references.set(node, (references.get(node) ?? 0) + 1);
		if (!seen.has(node)) {
			seen.add(node);
			for (const src of node.src) countReferences(src);
		}
	};
	// Node -> local holding its already-computed value.
	const expContext = /* @__PURE__ */ new Map();

	/** Emit the operator for a binary/compare node; operands are on the stack. */
	const emitBinary = (node) => {
		const { op, src, dtype } = node;
		switch (op) {
			case AluOp.Add:
				// Boolean addition is logical OR.
				if (dtype === DType.Bool) cg.i32.or();
				else dty(cg, op, dtype).add();
				return;
			case AluOp.Sub:
				dty(cg, op, dtype).sub();
				return;
			case AluOp.Mul:
				// Boolean multiplication is logical AND.
				if (dtype === DType.Bool) cg.i32.and();
				else dty(cg, op, dtype).mul();
				return;
			case AluOp.Idiv:
				if (dtype === DType.Float32) {
					cg.f32.div();
					cg.f32.trunc();
				} else if (dtype === DType.Uint32) cg.i32.div_u();
				else if (dtype === DType.Int32) cg.i32.div_s();
				else throw new UnsupportedOpError(op, dtype, "wasm");
				return;
			case AluOp.Mod:
				if (dtype === DType.Float32) {
					// a - trunc(a / b) * b
					const a = cg.local.declare(cg.f32);
					const b = cg.local.declare(cg.f32);
					cg.local.set(b);
					cg.local.tee(a);
					cg.local.get(a);
					cg.local.get(b);
					cg.f32.div();
					cg.f32.trunc();
					cg.local.get(b);
					cg.f32.mul();
					cg.f32.sub();
				} else if (dtype === DType.Uint32) cg.i32.rem_u();
				else if (dtype === DType.Int32) cg.i32.rem_s();
				else throw new UnsupportedOpError(op, dtype, "wasm");
				return;
			case AluOp.Min:
			case AluOp.Max: {
				const wantMin = op === AluOp.Min;
				if (dtype === DType.Float32) {
					if (wantMin) cg.f32.min();
					else cg.f32.max();
				} else if (dtype === DType.Int32 || dtype === DType.Uint32) {
					// No integer min/max instruction: select(a, b, a <op> b).
					const a = cg.local.declare(cg.i32);
					const b = cg.local.declare(cg.i32);
					cg.local.set(b);
					cg.local.tee(a);
					cg.local.get(b);
					cg.local.get(a);
					cg.local.get(b);
					if (dtype === DType.Int32) {
						if (wantMin) cg.i32.lt_s();
						else cg.i32.gt_s();
					} else if (wantMin) cg.i32.lt_u();
					else cg.i32.gt_u();
					cg.select();
				} else throw new UnsupportedOpError(op, dtype, "wasm");
				return;
			}
			case AluOp.Cmplt: {
				// The instruction depends on the operand type, not the result type.
				const srcDtype = src[0].dtype;
				if (srcDtype === DType.Float32) cg.f32.lt();
				else if (srcDtype === DType.Int32) cg.i32.lt_s();
				else if (srcDtype === DType.Uint32) cg.i32.lt_u();
				// Report the operand dtype that could not be lowered, matching
				// what the Cmpne path reports through dty().
				else throw new UnsupportedOpError(op, srcDtype, "wasm");
				return;
			}
			case AluOp.Cmpne:
				dty(cg, op, src[0].dtype).ne();
				return;
			default:
				throw new UnsupportedOpError(op, dtype, "wasm");
		}
	};

	/** Emit a value conversion from `dtype0` to `dtype`; the operand is on the stack. */
	const emitCast = (op, dtype, dtype0) => {
		const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
		if (dtype === DType.Int32) {
			if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_s();
			else if (!i32repr) throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
		} else if (dtype === DType.Uint32) {
			if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_u();
			else if (!i32repr) throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
		} else if (dtype === DType.Float32) {
			if (dtype0 === DType.Float32) {
				// Same type: nothing to emit.
			} else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32.convert_i32_s();
			else if (dtype0 === DType.Uint32) cg.f32.convert_i32_u();
			else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
		} else if (dtype === DType.Bool) {
			if (dtype0 === DType.Bool) {
				// Same type: nothing to emit.
			} else if (i32repr) {
				cg.i32.const(0);
				cg.i32.ne();
			} else if (dtype0 === DType.Float32) {
				cg.f32.const(0);
				cg.f32.ne();
			} else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
		} else throw new UnsupportedOpError(op, dtype, "wasm");
	};

	/** Emit a bit-pattern reinterpretation from `dtype0` to `dtype`. */
	const emitBitcast = (op, dtype, dtype0) => {
		const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32;
		if (dtype === DType.Int32 || dtype === DType.Uint32) {
			if (dtype0 === DType.Float32) cg.i32.reinterpret_f32();
			else if (!i32repr) throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
		} else if (dtype === DType.Float32) {
			if (i32repr) cg.f32.reinterpret_i32();
			else if (dtype0 !== DType.Float32) throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
		} else throw new UnsupportedOpError(op, dtype, "wasm");
	};

	/** Emit a unary node, including evaluation of its operand. */
	const emitUnary = (node) => {
		const { op, src, dtype } = node;
		switch (op) {
			case AluOp.Sin:
				gen(src[0]);
				cg.call(funcs.sin);
				return;
			case AluOp.Cos:
				gen(src[0]);
				cg.call(funcs.cos);
				return;
			case AluOp.Asin:
				gen(src[0]);
				cg.call(funcs.asin);
				return;
			case AluOp.Atan:
				gen(src[0]);
				cg.call(funcs.atan);
				return;
			case AluOp.Exp:
				gen(src[0]);
				cg.call(funcs.exp);
				return;
			case AluOp.Log:
				gen(src[0]);
				cg.call(funcs.log);
				return;
			case AluOp.Sqrt:
				gen(src[0]);
				cg.f32.sqrt();
				return;
			case AluOp.Reciprocal:
				// 1 / x: the constant must be pushed before the operand.
				cg.f32.const(1);
				gen(src[0]);
				cg.f32.div();
				return;
			case AluOp.Cast:
				gen(src[0]);
				emitCast(op, dtype, src[0].dtype);
				return;
			case AluOp.Bitcast:
				gen(src[0]);
				emitBitcast(op, dtype, src[0].dtype);
				return;
			default:
				throw new UnsupportedOpError(op, dtype, "wasm");
		}
	};

	const gen = (node) => {
		if (expContext.has(node)) return cg.local.get(expContext.get(node));
		const { op, src, dtype, arg } = node;
		if (AluGroup.Binary.has(op) || AluGroup.Compare.has(op)) {
			gen(src[0]);
			gen(src[1]);
			emitBinary(node);
		} else if (AluGroup.Unary.has(op)) {
			emitUnary(node);
		} else if (op === AluOp.Where) {
			// select pops (ifTrue, ifFalse, condition), so the condition goes last.
			gen(src[1]);
			gen(src[2]);
			gen(src[0]);
			cg.select();
		} else if (op === AluOp.Threefry2x32) {
			for (let i = 0; i < 4; i++) gen(src[i]);
			cg.call(funcs.threefry2x32);
			// The helper leaves two i32 results; `arg` picks how to combine them.
			if (arg === "xor") cg.i32.xor();
			else if (arg === 0) cg.drop();
			else if (arg === 1) {
				const local = cg.local.declare(cg.i32);
				cg.local.set(local);
				cg.drop();
				cg.local.get(local);
			} else throw new UnsupportedOpError(op, dtype, "wasm", arg);
		} else if (op === AluOp.Const) {
			return dty(cg, op, dtype).const(arg);
		} else if (op === AluOp.Special) {
			return cg.local.get(ctx[arg[0]]);
		} else if (op === AluOp.Variable) {
			return cg.local.get(ctx[arg]);
		} else if (op === AluOp.GlobalIndex) {
			const [gid, len] = arg;
			gen(src[0]);
			// Use index 0 when the computed index is not below `len`.
			const local = cg.local.declare(cg.i32);
			cg.local.tee(local);
			cg.i32.const(0);
			cg.local.get(local);
			cg.i32.const(len);
			cg.i32.lt_u();
			cg.select();
			// address = local[gid] + index * element width
			cg.i32.const(byteWidth(dtype));
			cg.i32.mul();
			cg.local.get(gid);
			cg.i32.add();
			dty(cg, op, dtype).load(Math.log2(byteWidth(dtype)));
		} else throw new UnsupportedOpError(op, dtype, "wasm");
		// Shared node: keep the value in a local for subsequent uses.
		if ((references.get(node) ?? 0) > 1) {
			const local = cg.local.declare(dty(cg, op, dtype));
			cg.local.tee(local);
			expContext.set(node, local);
		}
	};

	countReferences(exp);
	gen(exp);
}
3711
/**
 * Pick the code generator's scalar type namespace for a dtype: `cg.f32` for
 * Float32, `cg.i32` for Int32, Uint32 and Bool.
 *
 * @param cg Code generator.
 * @param op Operation being lowered; only used in the error.
 * @param dtype Data type to map.
 * @throws {UnsupportedOpError} For any other dtype.
 */
function dty(cg, op, dtype) {
	if (dtype === DType.Float32) return cg.f32;
	if (dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool) return cg.i32;
	throw new UnsupportedOpError(op, dtype, "wasm");
}
3720
+
3721
+ //#endregion
3722
+ //#region src/backend.ts
3723
/** Names of every device this library knows how to create. */
const devices = ["cpu", "wasm", "webgpu"];

/** Device used when a caller does not name one. */
let defaultBackend = "wasm";

/** Backends that are ready to use; "cpu" and "wasm" are created eagerly. */
const initializedBackends = new Map([
	["cpu", new CpuBackend()],
	["wasm", new WasmBackend()]
]);
3732
/**
 * Get or set the default device for arrays.
 *
 * @param device Device to make the default; omit to only query.
 * @returns The default device after any change.
 * @throws {Error} If `device` has not been initialized.
 */
function defaultDevice(device) {
	if (device === void 0) return defaultBackend;
	if (!initializedBackends.has(device)) throw new Error(`Backend not initialized: ${device}`);
	defaultBackend = device;
	return defaultBackend;
}
3738
/**
 * Initialize `jax-js` library backends.
 *
 * With no arguments every known device is attempted; otherwise only the named
 * devices are. Devices that are already initialized are skipped, and devices
 * that turn out to be unavailable are silently left out.
 *
 * @param devicesToInit Device names to initialize.
 * @returns The names of all backends that are initialized afterwards.
 */
async function init(...devicesToInit) {
	const requested = devicesToInit.length === 0 ? devices : devicesToInit;
	const pending = [...new Set(requested)]
		.filter((device) => !initializedBackends.has(device))
		.map(async (device) => {
			const backend = await createBackend(device);
			if (backend) initializedBackends.set(device, backend);
		});
	await Promise.all(pending);
	return [...initializedBackends.keys()];
}
3755
/**
 * Create a backend, if available. Internal function called by `init()`.
 *
 * @param device One of "cpu", "wasm" or "webgpu".
 * @returns The new backend, or `null` when WebGPU is unavailable (no
 *   `navigator.gpu`, no adapter, or the device request failed).
 * @throws {Error} If `device` is not a known backend name.
 */
async function createBackend(device) {
	if (device === "cpu") return new CpuBackend();
	if (device === "wasm") return new WasmBackend();
	if (device !== "webgpu") throw new Error(`Backend not found: ${device}`);

	// `navigator` is not a global in every JavaScript runtime; treat its
	// absence as "WebGPU unavailable" instead of throwing a ReferenceError.
	if (typeof navigator === "undefined" || !navigator.gpu) return null;
	const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
	if (!adapter) return null;
	const { WebGPUBackend } = await import("./webgpu-ow0Pn_6q.js");

	// Limits that are raised to whatever the adapter reports.
	const importantLimits = [
		"maxBufferSize",
		"maxComputeInvocationsPerWorkgroup",
		"maxComputeWorkgroupSizeX",
		"maxComputeWorkgroupSizeY",
		"maxComputeWorkgroupSizeZ",
		"maxComputeWorkgroupStorageSize",
		"maxComputeWorkgroupsPerDimension",
		"maxStorageBufferBindingSize",
		"maxStorageBuffersPerShaderStage",
		"maxStorageTexturesPerShaderStage"
	];
	// Optional features, requested only when the adapter has them.
	const requestedFeatures = ["shader-f16", "timestamp-query"];
	try {
		// Skip limits the adapter does not report so that no `undefined`
		// value is passed in `requiredLimits`.
		const requiredLimits = Object.fromEntries(
			importantLimits
				.map((limit) => [limit, adapter.limits[limit]])
				.filter(([, value]) => value !== void 0)
		);
		const device$1 = await adapter.requestDevice({
			requiredLimits,
			requiredFeatures: requestedFeatures.filter((feature) => adapter.features.has(feature))
		});
		return new WebGPUBackend(device$1);
	} catch (error) {
		console.error("Unexpected error requesting WebGPU device:", error);
		return null;
	}
}
3789
/**
 * Retrieve a backend that has been initialized.
 *
 * @param device Device name; a nullish value selects the default device.
 * @returns The initialized backend for that device.
 * @throws {Error} If the backend has not been initialized yet.
 */
function getBackend(device) {
	const name = device ?? defaultBackend;
	const backend = initializedBackends.get(name);
	if (backend) return backend;
	throw new Error(`${name} backend not ready, call init() first`);
}
3796
/**
 * A kernel paired with the backend-specific data produced when preparing it.
 */
var Executable = class {
	/**
	 * @param kernel The kernel this executable was prepared from.
	 * @param data Backend-specific payload, stored as-is.
	 */
	constructor(kernel, data) {
		Object.assign(this, { kernel, data });
	}
};
3802
/** Thrown when a buffer slot is used that is unknown or has been freed. */
var SlotError = class extends Error {
	/** @param slot The offending slot identifier, included in the message. */
	constructor(slot) {
		super(`Used a buffer that is invalid or already freed: ${slot}`);
		// Make logs and stack traces show the specific error type.
		this.name = "SlotError";
	}
};
3807
/** Thrown when a backend cannot lower an operation for a given dtype. */
var UnsupportedOpError = class extends Error {
	/**
	 * @param op Operation name; a falsy value is rendered as an empty string.
	 * @param dtype Data type involved.
	 * @param device Name of the backend raising the error.
	 * @param arg Optional extra detail, appended as JSON when not `undefined`.
	 */
	constructor(op, dtype, device, arg) {
		let msg = `${op || ""}<${dtype}> not supported in ${device} backend`;
		if (arg !== void 0) msg += ` with arg ${JSON.stringify(arg)}`;
		super(msg);
		// Make logs and stack traces show the specific error type.
		this.name = "UnsupportedOpError";
	}
};
3814
+
3815
+ //#endregion
3816
+ export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, union, unravelAlu, unzip2, zip, zipn };