@jax-js/jax 0.0.4 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -51,6 +51,7 @@ let DEBUG = 0;
51
51
  function setDebug(level) {
52
52
  DEBUG = level;
53
53
  }
54
/**
 * Runtime no-op: the body is empty, so `value` is never inspected and nothing
 * is thrown or returned, even when `value` is `null` or `undefined`.
 *
 * NOTE(review): presumably a compile-time-only non-null assertion whose type
 * signature was erased when the bundle was built; confirm against the
 * original source before relying on it as a runtime check.
 */
+ function assertNonNull(value) {}
54
55
  function unzip2(pairs) {
55
56
  const lst1 = [];
56
57
  const lst2 = [];
@@ -96,10 +97,15 @@ function deepEqual(a, b) {
96
97
  for (const key of Object.keys(a)) if (!deepEqual(a[key], b[key])) return false;
97
98
  return true;
98
99
  }
99
- function union(...sets) {
100
- const result = /* @__PURE__ */ new Set();
101
- for (const s of sets) if (s) for (const x of s) result.add(x);
102
- return result;
100
/**
 * Produces a union of maps of sets by merging `b` into `a`.
 *
 * This mutates `a` (and the sets already stored in it) and returns it.
 * `b` and the sets stored in `b` are never modified or shared: when a key
 * exists only in `b`, a copy of its set is stored in `a`, so that later
 * unions into `a` cannot leak back into `b`.
 *
 * @param a Map from key to Set; updated in place.
 * @param b Optional map from key to Set; a falsy value leaves `a` untouched.
 * @returns The same `a` instance that was passed in.
 */
function mapSetUnion(a, b) {
  if (!b) return a;
  for (const [key, setB] of b.entries()) {
    const setA = a.get(key);
    if (setA) for (const val of setB) setA.add(val);
    // Copy instead of aliasing: `a` must own every set it holds.
    else a.set(key, new Set(setB));
  }
  return a;
}
104
110
  /** Splits the list based on a condition, `false` first then `true`. */
105
111
  function partitionList(which, array) {
@@ -206,6 +212,35 @@ function findPow2(hint, max) {
206
212
  while (ret < hint && 2 * ret <= max) ret *= 2;
207
213
  return ret;
208
214
  }
215
/**
 * Computes the result shape of broadcasting two array shapes, NumPy-style.
 *
 * The shapes are aligned at their trailing (rightmost) dimension and compared
 * pairwise moving left. A pair of dimensions is compatible when both are
 * equal or when either one is 1; the non-1 size wins. Leading dimensions that
 * exist in only one of the shapes are carried over unchanged.
 *
 * Throws a TypeError if some aligned pair is incompatible.
 *
 * <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
 */
function generalBroadcast(a, b) {
  const rank = Math.max(a.length, b.length);
  // Collected right-to-left, flipped once at the end.
  const trailingFirst = [];
  for (let offset = 1; offset <= rank; offset++) {
    const i = a.length - offset;
    const j = b.length - offset;
    if (j < 0) {
      trailingFirst.push(a[i]);
      continue;
    }
    if (i < 0) {
      trailingFirst.push(b[j]);
      continue;
    }
    const dimA = a[i];
    const dimB = b[j];
    if (dimA === dimB || dimB === 1) trailingFirst.push(dimA);
    else if (dimA === 1) trailingFirst.push(dimB);
    else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
  }
  return trailingFirst.reverse();
}
209
244
  function recursiveFlatten(ar) {
210
245
  if (!Array.isArray(ar)) return [ar];
211
246
  return ar.flat(Infinity);
@@ -294,12 +329,12 @@ const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float
294
329
  * **Type lattice:**
295
330
  * ```text
296
331
  * bool -> uint32 -> int32 -> float16 -> float32
297
- * weak f* --^
332
+ * weakType --^
298
333
  * ```
299
334
  *
300
- * The asterisk f* is a weak type used for JS number constants. When creating
301
- * arrays, JS numbers default to float32 but "weak" so they cast to the dtype of
302
- * any array they are first combined with.
335
+ * `weakType` represents weakly typed arrays. These are created for JS numbers,
336
+ * which default to float32 but "weak" so they cast to the dtype of any array
337
+ * they are first combined with, except `bool`.
303
338
  *
304
339
  * **Examples:**
305
340
  * - `promoteTypes(bool, int32) → int32`
@@ -400,6 +435,12 @@ var AluExp = class AluExp {
400
435
  static log(a) {
401
436
  return new AluExp(AluOp.Log, a.dtype, [a]);
402
437
  }
438
+ static erf(a) {
439
+ return new AluExp(AluOp.Erf, a.dtype, [a]);
440
+ }
441
+ static erfc(a) {
442
+ return new AluExp(AluOp.Erfc, a.dtype, [a]);
443
+ }
403
444
  static sqrt(a) {
404
445
  return new AluExp(AluOp.Sqrt, a.dtype, [a]);
405
446
  }
@@ -570,6 +611,12 @@ var AluExp = class AluExp {
570
611
  case AluOp.Log:
571
612
  ret = [Math.log(src[0].min), Math.log(src[0].max)];
572
613
  break;
614
+ case AluOp.Erf:
615
+ ret = [erf(src[0].min), erf(src[0].max)];
616
+ break;
617
+ case AluOp.Erfc:
618
+ ret = [erfc(src[0].max), erfc(src[0].min)];
619
+ break;
573
620
  case AluOp.Sqrt:
574
621
  ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
575
622
  break;
@@ -891,6 +938,8 @@ var AluExp = class AluExp {
891
938
  case AluOp.Atan: return Math.atan(x);
892
939
  case AluOp.Exp: return Math.exp(x);
893
940
  case AluOp.Log: return Math.log(x);
941
+ case AluOp.Erf: return erf(x);
942
+ case AluOp.Erfc: return erfc(x);
894
943
  case AluOp.Sqrt: return Math.sqrt(x);
895
944
  case AluOp.Reciprocal: return 1 / x;
896
945
  case AluOp.Cast: {
@@ -1039,11 +1088,15 @@ var AluExp = class AluExp {
1039
1088
  });
1040
1089
  return result;
1041
1090
  }
1042
- /** Produce a list of all distinct AluOp in this expression. */
1091
+ /** Produce all distinct AluOp in this expression, with their dtypes. */
1043
1092
  distinctOps() {
1044
- const ops = /* @__PURE__ */ new Set();
1093
+ const ops = /* @__PURE__ */ new Map();
1045
1094
  this.fold((exp) => {
1046
- ops.add(exp.op);
1095
+ const s = ops.get(exp.op) ?? /* @__PURE__ */ new Set();
1096
+ if (!s.has(exp.dtype)) {
1097
+ s.add(exp.dtype);
1098
+ ops.set(exp.op, s);
1099
+ }
1047
1100
  });
1048
1101
  return ops;
1049
1102
  }
@@ -1072,6 +1125,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
1072
1125
  AluOp$1["Atan"] = "Atan";
1073
1126
  AluOp$1["Exp"] = "Exp";
1074
1127
  AluOp$1["Log"] = "Log";
1128
+ AluOp$1["Erf"] = "Erf";
1129
+ AluOp$1["Erfc"] = "Erfc";
1075
1130
  AluOp$1["Sqrt"] = "Sqrt";
1076
1131
  AluOp$1["Reciprocal"] = "Reciprocal";
1077
1132
  AluOp$1["Cast"] = "Cast";
@@ -1104,6 +1159,8 @@ const AluGroup = {
1104
1159
  AluOp.Atan,
1105
1160
  AluOp.Exp,
1106
1161
  AluOp.Log,
1162
+ AluOp.Erf,
1163
+ AluOp.Erfc,
1107
1164
  AluOp.Sqrt,
1108
1165
  AluOp.Reciprocal,
1109
1166
  AluOp.Cast,
@@ -1129,6 +1186,8 @@ const AluGroup = {
1129
1186
  AluOp.Atan,
1130
1187
  AluOp.Exp,
1131
1188
  AluOp.Log,
1189
+ AluOp.Erf,
1190
+ AluOp.Erfc,
1132
1191
  AluOp.Sqrt,
1133
1192
  AluOp.Reciprocal
1134
1193
  ])
@@ -1304,6 +1363,44 @@ function threefry2x32(k0, k1, c0, c1) {
1304
1363
  x1 = x1 + ks0 + 5 >>> 0;
1305
1364
  return [x0, x1];
1306
1365
  }
1366
/**
 * Core of the Abramowitz & Stegun elementary-function approximation of erf.
 *
 * For `x >= 0`, `erf(x) = 1 - E(x)` with `E(x) = P(t) * exp(-x^2)`, where
 * `t = 1 / (1 + p*x)`, `p = 0.3275911` and
 * `P(t) = a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5` using
 * `a1 = 0.254829592`, `a2 = -0.284496736`, `a3 = 1.421413741`,
 * `a4 = -1.453152027`, `a5 = 1.061405429`.
 *
 * Only `E(x)` is returned, which lets callers form either `1 - E` (erf) or
 * `E` itself (erfc). The input is assumed to be non-negative.
 *
 * Reference: https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions
 */
function _erfapprox$1(x) {
  const p = .3275911;
  // a5 down to a1: highest power first, ready for Horner's scheme.
  const descending = [
    1.061405429,
    -1.453152027,
    1.421413741,
    -.284496736,
    .254829592
  ];
  const t = 1 / (1 + p * x);
  // No seed value: the fold starts at a5 so the arithmetic matches
  // ((((a5*t + a4)*t + a3)*t + a2)*t + a1) exactly.
  const horner = descending.reduce((acc, coeff) => acc * t + coeff);
  return horner * t * Math.exp(-x * x);
}
1396
/**
 * Approximates the error function erf(x).
 *
 * `_erfapprox$1` is only evaluated on non-negative arguments; negative inputs
 * are handled through the odd symmetry erf(-x) = -erf(x). Any input that is
 * not `>= 0` (including NaN) takes the second branch.
 */
function erf(x) {
  return x >= 0 ? 1 - _erfapprox$1(x) : _erfapprox$1(-x) - 1;
}
1400
/**
 * Approximates the complementary error function erfc(x) = 1 - erf(x).
 *
 * `_erfapprox$1` is only evaluated on non-negative arguments; negative inputs
 * use the reflection erfc(-x) = 2 - erfc(x). Any input that is not `>= 0`
 * (including NaN) takes the second branch.
 */
function erfc(x) {
  return x >= 0 ? _erfapprox$1(x) : 2 - _erfapprox$1(-x);
}
1307
1404
 
1308
1405
  //#endregion
1309
1406
  //#region src/shape.ts
@@ -1989,7 +2086,7 @@ function tuneWebgpu(kernel) {
1989
2086
  }
1990
2087
  if (/chrome/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
1991
2088
  const s = dim.st.shape[dim.unroll - 1];
1992
- if (s <= 32) dim.applyUnroll(dim.reduce, s);
2089
+ if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
1993
2090
  else for (const splits of [4]) if (s % splits === 0) {
1994
2091
  dim.applyUnroll(dim.unroll - 1, splits);
1995
2092
  break;
@@ -2208,6 +2305,19 @@ var WasmAllocator = class {
2208
2305
 
2209
2306
  //#endregion
2210
2307
  //#region src/backend/wasm/builtins.ts
2308
/**
 * Given a local `x`, emits code that evaluates `sum[i](as[i] * x^i)` with
 * Horner's scheme and leaves the result on the stack.
 *
 * `as` lists coefficients from the constant term upwards. Coefficients equal
 * to zero (other than the leading one) emit no `const`/`add` pair.
 *
 * Throws an Error when `as` is empty.
 */
function _poly(cg, x, as) {
  if (as.length === 0) throw new Error("_poly needs at least one coefficient");
  const [leading, ...remaining] = [...as].reverse();
  cg.f32.const(leading);
  for (const coeff of remaining) {
    // acc = acc * x
    cg.local.get(x);
    cg.f32.mul();
    if (coeff === 0) continue;
    // acc = acc + coeff
    cg.f32.const(coeff);
    cg.f32.add();
  }
}
2211
2321
  /**
2212
2322
  * Approximate e^x.
2213
2323
  *
@@ -2248,27 +2358,15 @@ function wasm_exp(cg) {
2248
2358
  cg.f32.mul();
2249
2359
  cg.f32.sub();
2250
2360
  cg.local.set(r);
2251
- cg.f32.const(1 / 120);
2252
- cg.local.get(r);
2253
- cg.f32.mul();
2254
- cg.f32.const(1 / 24);
2255
- cg.f32.add();
2256
- cg.local.get(r);
2257
- cg.f32.mul();
2258
- cg.f32.const(1 / 6);
2259
- cg.f32.add();
2260
- cg.local.get(r);
2261
- cg.f32.mul();
2262
- cg.f32.const(1 / 2);
2263
- cg.f32.add();
2264
- cg.local.get(r);
2265
- cg.f32.mul();
2266
- cg.f32.const(1);
2267
- cg.f32.add();
2268
- cg.local.get(r);
2269
- cg.f32.mul();
2270
- cg.f32.const(1);
2271
- cg.f32.add();
2361
+ _poly(cg, r, [
2362
+ 1,
2363
+ 1,
2364
+ 1 / 2,
2365
+ 1 / 6,
2366
+ 1 / 24,
2367
+ 1 / 120,
2368
+ 1 / 720
2369
+ ]);
2272
2370
  cg.local.set(p);
2273
2371
  cg.local.get(k);
2274
2372
  cg.i32.const(127);
@@ -2296,11 +2394,6 @@ function wasm_log(cg) {
2296
2394
  const m = cg.local.declare(cg.f32);
2297
2395
  const t = cg.local.declare(cg.f32);
2298
2396
  const t2 = cg.local.declare(cg.f32);
2299
- const t3 = cg.local.declare(cg.f32);
2300
- const t5 = cg.local.declare(cg.f32);
2301
- const t7 = cg.local.declare(cg.f32);
2302
- const lnm = cg.local.declare(cg.f32);
2303
- const el2 = cg.local.declare(cg.f32);
2304
2397
  cg.local.get(0);
2305
2398
  cg.f32.const(0);
2306
2399
  cg.f32.le();
@@ -2337,41 +2430,18 @@ function wasm_log(cg) {
2337
2430
  cg.local.get(t);
2338
2431
  cg.f32.mul();
2339
2432
  cg.local.set(t2);
2433
+ _poly(cg, t2, [
2434
+ 2,
2435
+ 2 / 3,
2436
+ 2 / 5,
2437
+ 2 / 7
2438
+ ]);
2340
2439
  cg.local.get(t);
2341
- cg.local.get(t2);
2342
- cg.f32.mul();
2343
- cg.local.set(t3);
2344
- cg.local.get(t3);
2345
- cg.local.get(t2);
2346
- cg.f32.mul();
2347
- cg.local.set(t5);
2348
- cg.local.get(t5);
2349
- cg.local.get(t2);
2350
- cg.f32.mul();
2351
- cg.local.set(t7);
2352
- cg.local.get(t7);
2353
- cg.f32.const(1 / 7);
2354
- cg.f32.mul();
2355
- cg.local.get(t5);
2356
- cg.f32.const(1 / 5);
2357
- cg.f32.mul();
2358
- cg.f32.add();
2359
- cg.local.get(t3);
2360
- cg.f32.const(1 / 3);
2361
- cg.f32.mul();
2362
- cg.f32.add();
2363
- cg.local.get(t);
2364
- cg.f32.add();
2365
- cg.f32.const(2);
2366
2440
  cg.f32.mul();
2367
- cg.local.set(lnm);
2368
2441
  cg.local.get(e);
2369
2442
  cg.f32.convert_i32_s();
2370
2443
  cg.f32.const(Math.LN2);
2371
2444
  cg.f32.mul();
2372
- cg.local.set(el2);
2373
- cg.local.get(el2);
2374
- cg.local.get(lnm);
2375
2445
  cg.f32.add();
2376
2446
  });
2377
2447
  }
@@ -2381,7 +2451,7 @@ function wasm_log(cg) {
2381
2451
  * Method: reduce to y in [-π, π], then quadrant via q = round(y/(π/2))
2382
2452
  * z = y - q*(π/2); use one of two polynomials on z:
2383
2453
  * sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
2384
- * cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
2454
+ * cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720) + z^8*(1/40320)
2385
2455
  */
2386
2456
  function _sincos(cg) {
2387
2457
  const y = cg.local.declare(cg.f32);
@@ -2417,35 +2487,22 @@ function _sincos(cg) {
2417
2487
  cg.local.get(z);
2418
2488
  cg.f32.mul();
2419
2489
  cg.local.set(z2);
2420
- cg.f32.const(-1 / 5040);
2421
- cg.local.get(z2);
2422
- cg.f32.mul();
2423
- cg.f32.const(1 / 120);
2424
- cg.f32.add();
2425
- cg.local.get(z2);
2426
- cg.f32.mul();
2427
- cg.f32.const(-1 / 6);
2428
- cg.f32.add();
2429
- cg.local.get(z2);
2430
- cg.f32.mul();
2431
- cg.f32.const(1);
2432
- cg.f32.add();
2490
+ _poly(cg, z2, [
2491
+ 1,
2492
+ -1 / 6,
2493
+ 1 / 120,
2494
+ -1 / 5040
2495
+ ]);
2433
2496
  cg.local.get(z);
2434
2497
  cg.f32.mul();
2435
2498
  cg.local.set(sz);
2436
- cg.f32.const(-1 / 720);
2437
- cg.local.get(z2);
2438
- cg.f32.mul();
2439
- cg.f32.const(1 / 24);
2440
- cg.f32.add();
2441
- cg.local.get(z2);
2442
- cg.f32.mul();
2443
- cg.f32.const(-1 / 2);
2444
- cg.f32.add();
2445
- cg.local.get(z2);
2446
- cg.f32.mul();
2447
- cg.f32.const(1);
2448
- cg.f32.add();
2499
+ _poly(cg, z2, [
2500
+ 1,
2501
+ -1 / 2,
2502
+ 1 / 24,
2503
+ -1 / 720,
2504
+ 1 / 40320
2505
+ ]);
2449
2506
  cg.local.set(cz);
2450
2507
  return {
2451
2508
  q,
@@ -2527,24 +2584,16 @@ function _atan(cg) {
2527
2584
  cg.local.get(z);
2528
2585
  cg.f32.mul();
2529
2586
  cg.local.set(z2);
2530
- cg.f32.const(.0415796528637);
2531
- cg.local.get(z2);
2532
- cg.f32.mul();
2533
- cg.f32.const(.661705427875);
2534
- cg.f32.add();
2535
- cg.local.get(z2);
2536
- cg.f32.mul();
2537
- cg.f32.const(.999998614341);
2538
- cg.f32.add();
2539
- cg.f32.const(.173698870181);
2540
- cg.local.get(z2);
2541
- cg.f32.mul();
2542
- cg.f32.const(.994987933645);
2543
- cg.f32.add();
2544
- cg.local.get(z2);
2545
- cg.f32.mul();
2546
- cg.f32.const(1);
2547
- cg.f32.add();
2587
+ _poly(cg, z2, [
2588
+ .999998614341,
2589
+ .661705427875,
2590
+ .0415796528637
2591
+ ]);
2592
+ _poly(cg, z2, [
2593
+ 1,
2594
+ .994987933645,
2595
+ .173698870181
2596
+ ]);
2548
2597
  cg.f32.div();
2549
2598
  cg.local.get(z);
2550
2599
  cg.f32.mul();
@@ -2598,6 +2647,74 @@ function wasm_asin(cg) {
2598
2647
  });
2599
2648
  }
2600
2649
/**
 * Shared code emitter for the erf/erfc approximation.
 *
 * Stores the incoming stack value in a fresh local `x` and emits code for
 * `E = P(t) * exp(-x^2)`, where `t = 1 / (1 + p*x)` and `P(t)` is the
 * degree-5 polynomial with zero constant term built from the coefficients
 * below. `exp_func` is the function invoked for the exponential.
 *
 * Both callers in this file apply `abs` first, so `x` is non-negative here.
 * NOTE(review): stack-effect wording assumes each `cg` method emits the
 * like-named WebAssembly instruction.
 */
function _erfapprox(cg, exp_func) {
  const input = cg.local.declare(cg.f32);
  const recip = cg.local.declare(cg.f32);
  cg.local.set(input);
  const p = .3275911;
  // Constant term 0, then a1..a5.
  const coeffs = [
    0,
    .254829592,
    -.284496736,
    1.421413741,
    -1.453152027,
    1.061405429
  ];
  // recip = 1 / (1 + p * input)
  cg.f32.const(1);
  cg.f32.const(1);
  cg.f32.const(p);
  cg.local.get(input);
  cg.f32.mul();
  cg.f32.add();
  cg.f32.div();
  cg.local.set(recip);
  // P(recip)
  _poly(cg, recip, coeffs);
  // exp_func(-input * input)
  cg.local.get(input);
  cg.f32.neg();
  cg.local.get(input);
  cg.f32.mul();
  cg.call(exp_func);
  // P(recip) * exp_func(-input * input)
  cg.f32.mul();
}
2687
/**
 * Approximate erf(x) (error function).
 *
 * Emits an f32 -> f32 function computing `copysign(1 - E(|x|), x)`, where
 * `E` is the code emitted by `_erfapprox`; `exp` is forwarded to that helper
 * as the exponential to call.
 */
function wasm_erf(cg, exp) {
  const emitBody = () => {
    // 1 - E(|x|)
    cg.f32.const(1);
    cg.local.get(0);
    cg.f32.abs();
    _erfapprox(cg, exp);
    cg.f32.sub();
    // Give the magnitude the sign of the original argument.
    cg.local.get(0);
    cg.f32.copysign();
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
2699
/**
 * Approximate erfc(x) (complementary error function).
 *
 * Emits an f32 -> f32 function that evaluates `E(|x|)` once via `_erfapprox`
 * and then selects `2 - E` when `x < 0` and `E` otherwise; `exp` is forwarded
 * to that helper as the exponential to call.
 */
function wasm_erfc(cg, exp) {
  const emitBody = () => {
    const tail = cg.local.declare(cg.f32);
    // tail = E(|x|)
    cg.local.get(0);
    cg.f32.abs();
    _erfapprox(cg, exp);
    cg.local.set(tail);
    // First candidate: 2 - tail (used for negative x).
    cg.f32.const(2);
    cg.local.get(tail);
    cg.f32.sub();
    // Second candidate: tail itself.
    cg.local.get(tail);
    // Condition: x < 0.
    cg.local.get(0);
    cg.f32.const(0);
    cg.f32.lt();
    cg.select();
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
2717
+ /**
2601
2718
  * Threefry2x32 pseudorandom number generator.
2602
2719
  *
2603
2720
  * Takes two 32-bit keys and two 32-bit counters as input,
@@ -3472,14 +3589,16 @@ function codegenWasm(kernel) {
3472
3589
  if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
3473
3590
  const cg = new CodeGenerator();
3474
3591
  cg.memory.import("env", "memory");
3475
- const distinctOps = union(tune.exp.distinctOps(), re?.epilogue.distinctOps());
3592
+ const distinctOps = mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
3476
3593
  const funcs = {};
3477
3594
  if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
3478
3595
  if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
3479
3596
  if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
3480
3597
  if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
3481
- if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
3598
+ if (distinctOps.has(AluOp.Exp) || distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) funcs.exp = wasm_exp(cg);
3482
3599
  if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
3600
+ if (distinctOps.has(AluOp.Erf)) funcs.erf = wasm_erf(cg, funcs.exp);
3601
+ if (distinctOps.has(AluOp.Erfc)) funcs.erfc = wasm_erfc(cg, funcs.exp);
3483
3602
  if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
3484
3603
  const kernelFunc = cg.function(rep(kernel.nargs + 1, cg.i32), [], () => {
3485
3604
  const gidx = cg.local.declare(cg.i32);
@@ -3633,6 +3752,8 @@ function translateExp(cg, funcs, exp, ctx) {
3633
3752
  else if (op === AluOp.Atan) gen(src[0]), cg.call(funcs.atan);
3634
3753
  else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
3635
3754
  else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
3755
+ else if (op === AluOp.Erf) gen(src[0]), cg.call(funcs.erf);
3756
+ else if (op === AluOp.Erfc) gen(src[0]), cg.call(funcs.erfc);
3636
3757
  else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
3637
3758
  else if (op === AluOp.Reciprocal) cg.f32.const(1), gen(src[0]), cg.f32.div();
3638
3759
  else if (op === AluOp.Cast) {
@@ -3760,7 +3881,7 @@ async function createBackend(device) {
3760
3881
  if (!navigator.gpu) return null;
3761
3882
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
3762
3883
  if (!adapter) return null;
3763
- const { WebGPUBackend } = await import("./webgpu-ow0Pn_6q.js");
3884
+ const { WebGPUBackend } = await import("./webgpu-LGi2A3mS.js");
3764
3885
  const importantLimits = [
3765
3886
  "maxBufferSize",
3766
3887
  "maxComputeInvocationsPerWorkgroup",
@@ -3813,4 +3934,5 @@ var UnsupportedOpError = class extends Error {
3813
3934
  };
3814
3935
 
3815
3936
  //#endregion
3816
- export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, union, unravelAlu, unzip2, zip, zipn };
3937
+ export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
3938
+ //# sourceMappingURL=backend-DwIAd0AG.js.map