@jax-js/jax 0.0.4 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -52,6 +52,7 @@ let DEBUG = 0;
52
52
  function setDebug(level) {
53
53
  DEBUG = level;
54
54
  }
55
/**
 * Nominal non-null assertion helper.
 *
 * NOTE(review): the body is empty in this build, so calling it performs no
 * runtime null/undefined check and never throws. Presumably it exists only
 * for type narrowing in the TypeScript source — confirm before relying on it
 * to reject null values.
 *
 * @param {unknown} value - Value asserted (at the type level) to be non-null.
 * @returns {void}
 */
function assertNonNull(value) {
  // No runtime check is performed here.
}
55
56
  function unzip2(pairs) {
56
57
  const lst1 = [];
57
58
  const lst2 = [];
@@ -97,10 +98,15 @@ function deepEqual(a, b) {
97
98
  for (const key of Object.keys(a)) if (!deepEqual(a[key], b[key])) return false;
98
99
  return true;
99
100
  }
100
/**
 * Produces a union of maps of sets, merging `b` into `a` in place.
 *
 * For every key of `b`, its values are added to `a`'s set for that key. When
 * `a` has no entry for the key, a copy of `b`'s set is inserted rather than
 * the set itself. Otherwise `a` would alias a set owned by `b`, and a later
 * union into `a` would silently mutate `b` as well.
 *
 * @param {Map<K, Set<V>>} a - Target map; mutated and returned.
 * @param {Map<K, Set<V>> | null | undefined} b - Source map; never mutated.
 *   A falsy `b` is a no-op.
 * @returns {Map<K, Set<V>>} `a`.
 */
function mapSetUnion(a, b) {
  if (!b) return a;
  for (const [key, setB] of b.entries()) {
    const setA = a.get(key);
    if (setA) {
      for (const val of setB) setA.add(val);
    } else {
      // Copy so that `a` never shares a Set instance with `b`.
      a.set(key, new Set(setB));
    }
  }
  return a;
}
105
111
  /** Splits the list based on a condition, `false` first then `true`. */
106
112
  function partitionList(which, array) {
@@ -207,6 +213,35 @@ function findPow2(hint, max) {
207
213
  while (ret < hint && 2 * ret <= max) ret *= 2;
208
214
  return ret;
209
215
  }
216
/**
 * Broadcasts two array shapes against each other using NumPy's rules.
 *
 * The shapes are right-aligned, and the shorter one behaves as if padded on
 * the left. Each aligned pair of dimensions must be equal or contain a 1. The
 * result takes the non-1 size, or 1 if both are 1. Dimensions present in
 * only one shape are copied through unchanged.
 *
 * @param {number[]} a - First shape.
 * @param {number[]} b - Second shape.
 * @returns {number[]} The broadcast shape, as a new array.
 * @throws {TypeError} If some aligned pair of dimensions is incompatible.
 * @see https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules
 */
function generalBroadcast(a, b) {
  const rank = Math.max(a.length, b.length);
  // Number of leading (left) positions each shape is missing.
  const padA = rank - a.length;
  const padB = rank - b.length;
  return Array.from({ length: rank }, (_, k) => {
    if (k < padA) return b[k - padB];
    if (k < padB) return a[k - padA];
    const x = a[k - padA];
    const y = b[k - padB];
    if (x === y || y === 1) return x;
    if (x === 1) return y;
    throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
  });
}
210
245
  function recursiveFlatten(ar) {
211
246
  if (!Array.isArray(ar)) return [ar];
212
247
  return ar.flat(Infinity);
@@ -295,12 +330,12 @@ const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float
295
330
  * **Type lattice:**
296
331
  * ```text
297
332
  * bool -> uint32 -> int32 -> float16 -> float32
298
- * weak f* --^
333
+ * weakType --^
299
334
  * ```
300
335
  *
301
- * The asterisk f* is a weak type used for JS number constants. When creating
302
- * arrays, JS numbers default to float32 but "weak" so they cast to the dtype of
303
- * any array they are first combined with.
336
+ * `weakType` represents weakly typed arrays. These are created for JS numbers,
337
+ * which default to float32 but are "weak", so they cast to the dtype of any
338
+ * array they are first combined with, unless that array is `bool`.
304
339
  *
305
340
  * **Examples:**
306
341
  * - `promoteTypes(bool, int32) → int32`
@@ -401,6 +436,12 @@ var AluExp = class AluExp {
401
436
  static log(a) {
402
437
  return new AluExp(AluOp.Log, a.dtype, [a]);
403
438
  }
439
+ static erf(a) {
440
+ return new AluExp(AluOp.Erf, a.dtype, [a]);
441
+ }
442
+ static erfc(a) {
443
+ return new AluExp(AluOp.Erfc, a.dtype, [a]);
444
+ }
404
445
  static sqrt(a) {
405
446
  return new AluExp(AluOp.Sqrt, a.dtype, [a]);
406
447
  }
@@ -571,6 +612,12 @@ var AluExp = class AluExp {
571
612
  case AluOp.Log:
572
613
  ret = [Math.log(src[0].min), Math.log(src[0].max)];
573
614
  break;
615
+ case AluOp.Erf:
616
+ ret = [erf(src[0].min), erf(src[0].max)];
617
+ break;
618
+ case AluOp.Erfc:
619
+ ret = [erfc(src[0].max), erfc(src[0].min)];
620
+ break;
574
621
  case AluOp.Sqrt:
575
622
  ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
576
623
  break;
@@ -892,6 +939,8 @@ var AluExp = class AluExp {
892
939
  case AluOp.Atan: return Math.atan(x);
893
940
  case AluOp.Exp: return Math.exp(x);
894
941
  case AluOp.Log: return Math.log(x);
942
+ case AluOp.Erf: return erf(x);
943
+ case AluOp.Erfc: return erfc(x);
895
944
  case AluOp.Sqrt: return Math.sqrt(x);
896
945
  case AluOp.Reciprocal: return 1 / x;
897
946
  case AluOp.Cast: {
@@ -1040,11 +1089,15 @@ var AluExp = class AluExp {
1040
1089
  });
1041
1090
  return result;
1042
1091
  }
1043
- /** Produce a list of all distinct AluOp in this expression. */
1092
+ /** Produce all distinct AluOp in this expression, with their dtypes. */
1044
1093
  distinctOps() {
1045
- const ops = /* @__PURE__ */ new Set();
1094
+ const ops = /* @__PURE__ */ new Map();
1046
1095
  this.fold((exp) => {
1047
- ops.add(exp.op);
1096
+ const s = ops.get(exp.op) ?? /* @__PURE__ */ new Set();
1097
+ if (!s.has(exp.dtype)) {
1098
+ s.add(exp.dtype);
1099
+ ops.set(exp.op, s);
1100
+ }
1048
1101
  });
1049
1102
  return ops;
1050
1103
  }
@@ -1073,6 +1126,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
1073
1126
  AluOp$1["Atan"] = "Atan";
1074
1127
  AluOp$1["Exp"] = "Exp";
1075
1128
  AluOp$1["Log"] = "Log";
1129
+ AluOp$1["Erf"] = "Erf";
1130
+ AluOp$1["Erfc"] = "Erfc";
1076
1131
  AluOp$1["Sqrt"] = "Sqrt";
1077
1132
  AluOp$1["Reciprocal"] = "Reciprocal";
1078
1133
  AluOp$1["Cast"] = "Cast";
@@ -1105,6 +1160,8 @@ const AluGroup = {
1105
1160
  AluOp.Atan,
1106
1161
  AluOp.Exp,
1107
1162
  AluOp.Log,
1163
+ AluOp.Erf,
1164
+ AluOp.Erfc,
1108
1165
  AluOp.Sqrt,
1109
1166
  AluOp.Reciprocal,
1110
1167
  AluOp.Cast,
@@ -1130,6 +1187,8 @@ const AluGroup = {
1130
1187
  AluOp.Atan,
1131
1188
  AluOp.Exp,
1132
1189
  AluOp.Log,
1190
+ AluOp.Erf,
1191
+ AluOp.Erfc,
1133
1192
  AluOp.Sqrt,
1134
1193
  AluOp.Reciprocal
1135
1194
  ])
@@ -1305,6 +1364,44 @@ function threefry2x32(k0, k1, c0, c1) {
1305
1364
  x1 = x1 + ks0 + 5 >>> 0;
1306
1365
  return [x0, x1];
1307
1366
  }
1367
/**
 * Tail term of the Abramowitz & Stegun (formula 7.1.26) approximation to erf.
 *
 * For x >= 0, erf(x) ≈ 1 - P(t)·exp(-x²), where t = 1 / (1 + p·x) and
 * P(t) = a1·t + a2·t² + a3·t³ + a4·t⁴ + a5·t⁵, with
 *   p = 0.3275911, a1 = 0.254829592, a2 = -0.284496736,
 *   a3 = 1.421413741, a4 = -1.453152027, a5 = 1.061405429.
 *
 * Only E = P(t)·exp(-x²) is returned, so `erfc` can use E directly instead of
 * computing 1 - erf(x). The argument is expected to be non-negative.
 *
 * Reference: https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions
 */
function _erfapprox$1(x) {
  const t = 1 / (1 + 0.3275911 * x);
  // Horner's rule, starting from the highest-order coefficient a5.
  let poly = 1.061405429;
  for (const coeff of [-1.453152027, 1.421413741, -0.284496736, 0.254829592]) {
    poly = poly * t + coeff;
  }
  return poly * t * Math.exp(-x * x);
}

/** erf(x); negative inputs use the odd symmetry erf(-x) = -erf(x). */
function erf(x) {
  return x >= 0 ? 1 - _erfapprox$1(x) : _erfapprox$1(-x) - 1;
}

/** erfc(x) = 1 - erf(x); negative inputs use erfc(-x) = 2 - erfc(x). */
function erfc(x) {
  return x >= 0 ? _erfapprox$1(x) : 2 - _erfapprox$1(-x);
}
1308
1405
 
1309
1406
  //#endregion
1310
1407
  //#region src/shape.ts
@@ -1990,7 +2087,7 @@ function tuneWebgpu(kernel) {
1990
2087
  }
1991
2088
  if (/chrome/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
1992
2089
  const s = dim.st.shape[dim.unroll - 1];
1993
- if (s <= 32) dim.applyUnroll(dim.reduce, s);
2090
+ if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
1994
2091
  else for (const splits of [4]) if (s % splits === 0) {
1995
2092
  dim.applyUnroll(dim.unroll - 1, splits);
1996
2093
  break;
@@ -2209,6 +2306,19 @@ var WasmAllocator = class {
2209
2306
 
2210
2307
  //#endregion
2211
2308
  //#region src/backend/wasm/builtins.ts
2309
/**
 * Emit f32 code evaluating `as[0] + as[1]*x + ... + as[n-1]*x^(n-1)` by
 * Horner's rule, where `x` is a local, leaving the result on the stack.
 * Coefficients equal to zero skip their `const`/`add` pair.
 */
function _poly(cg, x, as) {
  if (as.length === 0) throw new Error("_poly needs at least one coefficient");
  const [leading, ...lower] = as.slice().reverse();
  cg.f32.const(leading);
  for (const coeff of lower) {
    // acc = acc * x (+ coeff, when non-zero)
    cg.local.get(x);
    cg.f32.mul();
    if (coeff === 0) continue;
    cg.f32.const(coeff);
    cg.f32.add();
  }
}
2212
2322
  /**
2213
2323
  * Approximate e^x.
2214
2324
  *
@@ -2249,27 +2359,15 @@ function wasm_exp(cg) {
2249
2359
  cg.f32.mul();
2250
2360
  cg.f32.sub();
2251
2361
  cg.local.set(r);
2252
- cg.f32.const(1 / 120);
2253
- cg.local.get(r);
2254
- cg.f32.mul();
2255
- cg.f32.const(1 / 24);
2256
- cg.f32.add();
2257
- cg.local.get(r);
2258
- cg.f32.mul();
2259
- cg.f32.const(1 / 6);
2260
- cg.f32.add();
2261
- cg.local.get(r);
2262
- cg.f32.mul();
2263
- cg.f32.const(1 / 2);
2264
- cg.f32.add();
2265
- cg.local.get(r);
2266
- cg.f32.mul();
2267
- cg.f32.const(1);
2268
- cg.f32.add();
2269
- cg.local.get(r);
2270
- cg.f32.mul();
2271
- cg.f32.const(1);
2272
- cg.f32.add();
2362
+ _poly(cg, r, [
2363
+ 1,
2364
+ 1,
2365
+ 1 / 2,
2366
+ 1 / 6,
2367
+ 1 / 24,
2368
+ 1 / 120,
2369
+ 1 / 720
2370
+ ]);
2273
2371
  cg.local.set(p);
2274
2372
  cg.local.get(k);
2275
2373
  cg.i32.const(127);
@@ -2297,11 +2395,6 @@ function wasm_log(cg) {
2297
2395
  const m = cg.local.declare(cg.f32);
2298
2396
  const t = cg.local.declare(cg.f32);
2299
2397
  const t2 = cg.local.declare(cg.f32);
2300
- const t3 = cg.local.declare(cg.f32);
2301
- const t5 = cg.local.declare(cg.f32);
2302
- const t7 = cg.local.declare(cg.f32);
2303
- const lnm = cg.local.declare(cg.f32);
2304
- const el2 = cg.local.declare(cg.f32);
2305
2398
  cg.local.get(0);
2306
2399
  cg.f32.const(0);
2307
2400
  cg.f32.le();
@@ -2338,41 +2431,18 @@ function wasm_log(cg) {
2338
2431
  cg.local.get(t);
2339
2432
  cg.f32.mul();
2340
2433
  cg.local.set(t2);
2434
+ _poly(cg, t2, [
2435
+ 2,
2436
+ 2 / 3,
2437
+ 2 / 5,
2438
+ 2 / 7
2439
+ ]);
2341
2440
  cg.local.get(t);
2342
- cg.local.get(t2);
2343
- cg.f32.mul();
2344
- cg.local.set(t3);
2345
- cg.local.get(t3);
2346
- cg.local.get(t2);
2347
- cg.f32.mul();
2348
- cg.local.set(t5);
2349
- cg.local.get(t5);
2350
- cg.local.get(t2);
2351
- cg.f32.mul();
2352
- cg.local.set(t7);
2353
- cg.local.get(t7);
2354
- cg.f32.const(1 / 7);
2355
- cg.f32.mul();
2356
- cg.local.get(t5);
2357
- cg.f32.const(1 / 5);
2358
- cg.f32.mul();
2359
- cg.f32.add();
2360
- cg.local.get(t3);
2361
- cg.f32.const(1 / 3);
2362
- cg.f32.mul();
2363
- cg.f32.add();
2364
- cg.local.get(t);
2365
- cg.f32.add();
2366
- cg.f32.const(2);
2367
2441
  cg.f32.mul();
2368
- cg.local.set(lnm);
2369
2442
  cg.local.get(e);
2370
2443
  cg.f32.convert_i32_s();
2371
2444
  cg.f32.const(Math.LN2);
2372
2445
  cg.f32.mul();
2373
- cg.local.set(el2);
2374
- cg.local.get(el2);
2375
- cg.local.get(lnm);
2376
2446
  cg.f32.add();
2377
2447
  });
2378
2448
  }
@@ -2382,7 +2452,7 @@ function wasm_log(cg) {
2382
2452
  * Method: reduce to y in [-π, π], then quadrant via q = round(y/(π/2))
2383
2453
  * z = y - q*(π/2); use one of two polynomials on z:
2384
2454
  * sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
2385
- * cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
2455
+ * cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720) + z^8*(1/40320)
2386
2456
  */
2387
2457
  function _sincos(cg) {
2388
2458
  const y = cg.local.declare(cg.f32);
@@ -2418,35 +2488,22 @@ function _sincos(cg) {
2418
2488
  cg.local.get(z);
2419
2489
  cg.f32.mul();
2420
2490
  cg.local.set(z2);
2421
- cg.f32.const(-1 / 5040);
2422
- cg.local.get(z2);
2423
- cg.f32.mul();
2424
- cg.f32.const(1 / 120);
2425
- cg.f32.add();
2426
- cg.local.get(z2);
2427
- cg.f32.mul();
2428
- cg.f32.const(-1 / 6);
2429
- cg.f32.add();
2430
- cg.local.get(z2);
2431
- cg.f32.mul();
2432
- cg.f32.const(1);
2433
- cg.f32.add();
2491
+ _poly(cg, z2, [
2492
+ 1,
2493
+ -1 / 6,
2494
+ 1 / 120,
2495
+ -1 / 5040
2496
+ ]);
2434
2497
  cg.local.get(z);
2435
2498
  cg.f32.mul();
2436
2499
  cg.local.set(sz);
2437
- cg.f32.const(-1 / 720);
2438
- cg.local.get(z2);
2439
- cg.f32.mul();
2440
- cg.f32.const(1 / 24);
2441
- cg.f32.add();
2442
- cg.local.get(z2);
2443
- cg.f32.mul();
2444
- cg.f32.const(-1 / 2);
2445
- cg.f32.add();
2446
- cg.local.get(z2);
2447
- cg.f32.mul();
2448
- cg.f32.const(1);
2449
- cg.f32.add();
2500
+ _poly(cg, z2, [
2501
+ 1,
2502
+ -1 / 2,
2503
+ 1 / 24,
2504
+ -1 / 720,
2505
+ 1 / 40320
2506
+ ]);
2450
2507
  cg.local.set(cz);
2451
2508
  return {
2452
2509
  q,
@@ -2528,24 +2585,16 @@ function _atan(cg) {
2528
2585
  cg.local.get(z);
2529
2586
  cg.f32.mul();
2530
2587
  cg.local.set(z2);
2531
- cg.f32.const(.0415796528637);
2532
- cg.local.get(z2);
2533
- cg.f32.mul();
2534
- cg.f32.const(.661705427875);
2535
- cg.f32.add();
2536
- cg.local.get(z2);
2537
- cg.f32.mul();
2538
- cg.f32.const(.999998614341);
2539
- cg.f32.add();
2540
- cg.f32.const(.173698870181);
2541
- cg.local.get(z2);
2542
- cg.f32.mul();
2543
- cg.f32.const(.994987933645);
2544
- cg.f32.add();
2545
- cg.local.get(z2);
2546
- cg.f32.mul();
2547
- cg.f32.const(1);
2548
- cg.f32.add();
2588
+ _poly(cg, z2, [
2589
+ .999998614341,
2590
+ .661705427875,
2591
+ .0415796528637
2592
+ ]);
2593
+ _poly(cg, z2, [
2594
+ 1,
2595
+ .994987933645,
2596
+ .173698870181
2597
+ ]);
2549
2598
  cg.f32.div();
2550
2599
  cg.local.get(z);
2551
2600
  cg.f32.mul();
@@ -2599,6 +2648,74 @@ function wasm_asin(cg) {
2599
2648
  });
2600
2649
  }
2601
2650
  /**
2651
+ * Helper function for erf/erfc approximation.
2652
+ *
2653
+ * See `_erfapprox` in alu.ts for details on the algorithm used.
2654
+ */
2655
+ function _erfapprox(cg, exp_func) {
2656
+ const x = cg.local.declare(cg.f32);
2657
+ const t = cg.local.declare(cg.f32);
2658
+ cg.local.set(x);
2659
+ const p = .3275911;
2660
+ const a1 = .254829592;
2661
+ const a2 = -.284496736;
2662
+ const a3 = 1.421413741;
2663
+ const a4 = -1.453152027;
2664
+ const a5 = 1.061405429;
2665
+ cg.f32.const(1);
2666
+ cg.f32.const(1);
2667
+ cg.f32.const(p);
2668
+ cg.local.get(x);
2669
+ cg.f32.mul();
2670
+ cg.f32.add();
2671
+ cg.f32.div();
2672
+ cg.local.set(t);
2673
+ _poly(cg, t, [
2674
+ 0,
2675
+ a1,
2676
+ a2,
2677
+ a3,
2678
+ a4,
2679
+ a5
2680
+ ]);
2681
+ cg.local.get(x);
2682
+ cg.f32.neg();
2683
+ cg.local.get(x);
2684
+ cg.f32.mul();
2685
+ cg.call(exp_func);
2686
+ cg.f32.mul();
2687
+ }
2688
/**
 * Builds a wasm function computing erf(x) for a single f32 argument.
 *
 * Computes 1 - E(|x|) and then copies the sign of x onto the result, because
 * erf is an odd function. `exp` is forwarded to `_erfapprox` for exp(-x²).
 */
function wasm_erf(cg, exp) {
  const emitBody = () => {
    cg.f32.const(1);
    cg.local.get(0); // the f32 argument
    cg.f32.abs();
    _erfapprox(cg, exp); // |x| -> E
    cg.f32.sub(); // 1 - E
    cg.local.get(0);
    cg.f32.copysign(); // erf(-x) = -erf(x)
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
2700
/**
 * Builds a wasm function computing erfc(x) for a single f32 argument.
 *
 * With E = E(|x|), the result is E when x >= 0 and 2 - E when x < 0.
 * `exp` is forwarded to `_erfapprox` for exp(-x²).
 */
function wasm_erfc(cg, exp) {
  const emitBody = () => {
    const tail = cg.local.declare(cg.f32);
    cg.local.get(0);
    cg.f32.abs();
    _erfapprox(cg, exp);
    cg.local.set(tail);
    // select(first, second, cond) yields `first` when cond != 0, else `second`.
    cg.f32.const(2);
    cg.local.get(tail);
    cg.f32.sub(); // first: 2 - E (taken when x < 0)
    cg.local.get(tail); // second: E
    cg.local.get(0);
    cg.f32.const(0);
    cg.f32.lt(); // cond: x < 0
    cg.select();
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
2718
+ /**
2602
2719
  * Threefry2x32 pseudorandom number generator.
2603
2720
  *
2604
2721
  * Takes two 32-bit keys and two 32-bit counters as input,
@@ -3473,14 +3590,16 @@ function codegenWasm(kernel) {
3473
3590
  if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
3474
3591
  const cg = new CodeGenerator();
3475
3592
  cg.memory.import("env", "memory");
3476
- const distinctOps = union(tune.exp.distinctOps(), re?.epilogue.distinctOps());
3593
+ const distinctOps = mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
3477
3594
  const funcs = {};
3478
3595
  if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
3479
3596
  if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
3480
3597
  if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
3481
3598
  if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
3482
- if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
3599
+ if (distinctOps.has(AluOp.Exp) || distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) funcs.exp = wasm_exp(cg);
3483
3600
  if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
3601
+ if (distinctOps.has(AluOp.Erf)) funcs.erf = wasm_erf(cg, funcs.exp);
3602
+ if (distinctOps.has(AluOp.Erfc)) funcs.erfc = wasm_erfc(cg, funcs.exp);
3484
3603
  if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
3485
3604
  const kernelFunc = cg.function(rep(kernel.nargs + 1, cg.i32), [], () => {
3486
3605
  const gidx = cg.local.declare(cg.i32);
@@ -3634,6 +3753,8 @@ function translateExp(cg, funcs, exp, ctx) {
3634
3753
  else if (op === AluOp.Atan) gen(src[0]), cg.call(funcs.atan);
3635
3754
  else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
3636
3755
  else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
3756
+ else if (op === AluOp.Erf) gen(src[0]), cg.call(funcs.erf);
3757
+ else if (op === AluOp.Erfc) gen(src[0]), cg.call(funcs.erfc);
3637
3758
  else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
3638
3759
  else if (op === AluOp.Reciprocal) cg.f32.const(1), gen(src[0]), cg.f32.div();
3639
3760
  else if (op === AluOp.Cast) {
@@ -3761,7 +3882,7 @@ async function createBackend(device) {
3761
3882
  if (!navigator.gpu) return null;
3762
3883
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
3763
3884
  if (!adapter) return null;
3764
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BVdMaO9T.cjs"));
3885
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BE7zA_01.cjs"));
3765
3886
  const importantLimits = [
3766
3887
  "maxBufferSize",
3767
3888
  "maxComputeInvocationsPerWorkgroup",
@@ -3910,6 +4031,12 @@ Object.defineProperty(exports, 'accessorGlobal', {
3910
4031
  return accessorGlobal;
3911
4032
  }
3912
4033
  });
4034
// CommonJS re-export through a getter, so reads see the current binding.
Object.defineProperty(exports, 'assertNonNull', {
  enumerable: true,
  get() {
    return assertNonNull;
  },
});
3913
4040
  Object.defineProperty(exports, 'byteWidth', {
3914
4041
  enumerable: true,
3915
4042
  get: function () {
@@ -3958,6 +4085,12 @@ Object.defineProperty(exports, 'findPow2', {
3958
4085
  return findPow2;
3959
4086
  }
3960
4087
  });
4088
// CommonJS re-export through a getter, so reads see the current binding.
Object.defineProperty(exports, 'generalBroadcast', {
  enumerable: true,
  get() {
    return generalBroadcast;
  },
});
3961
4094
  Object.defineProperty(exports, 'getBackend', {
3962
4095
  enumerable: true,
3963
4096
  get: function () {
@@ -3994,6 +4127,12 @@ Object.defineProperty(exports, 'isPermutation', {
3994
4127
  return isPermutation;
3995
4128
  }
3996
4129
  });
4130
// CommonJS re-export through a getter, so reads see the current binding.
Object.defineProperty(exports, 'mapSetUnion', {
  enumerable: true,
  get() {
    return mapSetUnion;
  },
});
3997
4136
  Object.defineProperty(exports, 'normalizeAxis', {
3998
4137
  enumerable: true,
3999
4138
  get: function () {
@@ -4066,12 +4205,6 @@ Object.defineProperty(exports, 'tuneWebgpu', {
4066
4205
  return tuneWebgpu;
4067
4206
  }
4068
4207
  });
4069
- Object.defineProperty(exports, 'union', {
4070
- enumerable: true,
4071
- get: function () {
4072
- return union;
4073
- }
4074
- });
4075
4208
  Object.defineProperty(exports, 'unravelAlu', {
4076
4209
  enumerable: true,
4077
4210
  get: function () {
@@ -4095,4 +4228,5 @@ Object.defineProperty(exports, 'zipn', {
4095
4228
  get: function () {
4096
4229
  return zipn;
4097
4230
  }
4098
- });
4231
+ });
4232
+ //# sourceMappingURL=backend-FtkbO6pI.cjs.map