@jax-js/jax 0.0.5 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -51,6 +51,7 @@ let DEBUG = 0;
51
51
  function setDebug(level) {
52
52
  DEBUG = level;
53
53
  }
54
/**
 * Asserts that `value` is neither `null` nor `undefined`.
 *
 * The previous implementation had an empty body, so it checked nothing at
 * runtime despite its name; nullish values now fail fast here instead of
 * surfacing later as an unrelated TypeError.
 *
 * @param {*} value - Value that must be non-nullish.
 * @param {string} [message] - Error message used when the check fails.
 * @throws {Error} If `value` is `null` or `undefined`.
 */
function assertNonNull(value, message = "Expected value to be non-null") {
  // `== null` deliberately matches both null and undefined (but not 0, "", false).
  if (value == null) throw new Error(message);
}
54
55
  function unzip2(pairs) {
55
56
  const lst1 = [];
56
57
  const lst2 = [];
@@ -96,10 +97,15 @@ function deepEqual(a, b) {
96
97
  for (const key of Object.keys(a)) if (!deepEqual(a[key], b[key])) return false;
97
98
  return true;
98
99
  }
99
- function union(...sets) {
100
- const result = /* @__PURE__ */ new Set();
101
- for (const s of sets) if (s) for (const x of s) result.add(x);
102
- return result;
100
/**
 * Produces a union of maps of sets by merging `b` into `a`. This mutates `a`.
 *
 * `b` itself is never modified. Sets copied over from `b` for keys that `a`
 * lacked are cloned rather than shared. Otherwise a later union into `a` would
 * also add elements to `b`'s set.
 *
 * @param {Map<*, Set<*>>} a - Target map; updated in place and returned.
 * @param {Map<*, Set<*>> | null | undefined} b - Source map; ignored when nullish.
 * @returns {Map<*, Set<*>>} The same `a` instance.
 */
function mapSetUnion(a, b) {
  if (!b) return a;
  for (const [key, setB] of b) {
    const setA = a.get(key);
    if (setA) {
      for (const val of setB) setA.add(val);
    } else {
      // Copy instead of aliasing so `a` and `b` stay independent.
      a.set(key, new Set(setB));
    }
  }
  return a;
}
104
110
  /** Splits the list based on a condition, `false` first then `true`. */
105
111
  function partitionList(which, array) {
@@ -429,6 +435,12 @@ var AluExp = class AluExp {
429
435
  static log(a) {
430
436
  return new AluExp(AluOp.Log, a.dtype, [a]);
431
437
  }
438
+ static erf(a) {
439
+ return new AluExp(AluOp.Erf, a.dtype, [a]);
440
+ }
441
+ static erfc(a) {
442
+ return new AluExp(AluOp.Erfc, a.dtype, [a]);
443
+ }
432
444
  static sqrt(a) {
433
445
  return new AluExp(AluOp.Sqrt, a.dtype, [a]);
434
446
  }
@@ -599,6 +611,12 @@ var AluExp = class AluExp {
599
611
  case AluOp.Log:
600
612
  ret = [Math.log(src[0].min), Math.log(src[0].max)];
601
613
  break;
614
+ case AluOp.Erf:
615
+ ret = [erf(src[0].min), erf(src[0].max)];
616
+ break;
617
+ case AluOp.Erfc:
618
+ ret = [erfc(src[0].max), erfc(src[0].min)];
619
+ break;
602
620
  case AluOp.Sqrt:
603
621
  ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
604
622
  break;
@@ -920,6 +938,8 @@ var AluExp = class AluExp {
920
938
  case AluOp.Atan: return Math.atan(x);
921
939
  case AluOp.Exp: return Math.exp(x);
922
940
  case AluOp.Log: return Math.log(x);
941
+ case AluOp.Erf: return erf(x);
942
+ case AluOp.Erfc: return erfc(x);
923
943
  case AluOp.Sqrt: return Math.sqrt(x);
924
944
  case AluOp.Reciprocal: return 1 / x;
925
945
  case AluOp.Cast: {
@@ -1068,11 +1088,15 @@ var AluExp = class AluExp {
1068
1088
  });
1069
1089
  return result;
1070
1090
  }
1071
- /** Produce a list of all distinct AluOp in this expression. */
1091
+ /** Produce all distinct AluOp in this expression, with their dtypes. */
1072
1092
  distinctOps() {
1073
- const ops = /* @__PURE__ */ new Set();
1093
+ const ops = /* @__PURE__ */ new Map();
1074
1094
  this.fold((exp) => {
1075
- ops.add(exp.op);
1095
+ const s = ops.get(exp.op) ?? /* @__PURE__ */ new Set();
1096
+ if (!s.has(exp.dtype)) {
1097
+ s.add(exp.dtype);
1098
+ ops.set(exp.op, s);
1099
+ }
1076
1100
  });
1077
1101
  return ops;
1078
1102
  }
@@ -1101,6 +1125,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
1101
1125
  AluOp$1["Atan"] = "Atan";
1102
1126
  AluOp$1["Exp"] = "Exp";
1103
1127
  AluOp$1["Log"] = "Log";
1128
+ AluOp$1["Erf"] = "Erf";
1129
+ AluOp$1["Erfc"] = "Erfc";
1104
1130
  AluOp$1["Sqrt"] = "Sqrt";
1105
1131
  AluOp$1["Reciprocal"] = "Reciprocal";
1106
1132
  AluOp$1["Cast"] = "Cast";
@@ -1133,6 +1159,8 @@ const AluGroup = {
1133
1159
  AluOp.Atan,
1134
1160
  AluOp.Exp,
1135
1161
  AluOp.Log,
1162
+ AluOp.Erf,
1163
+ AluOp.Erfc,
1136
1164
  AluOp.Sqrt,
1137
1165
  AluOp.Reciprocal,
1138
1166
  AluOp.Cast,
@@ -1158,6 +1186,8 @@ const AluGroup = {
1158
1186
  AluOp.Atan,
1159
1187
  AluOp.Exp,
1160
1188
  AluOp.Log,
1189
+ AluOp.Erf,
1190
+ AluOp.Erfc,
1161
1191
  AluOp.Sqrt,
1162
1192
  AluOp.Reciprocal
1163
1193
  ])
@@ -1333,6 +1363,44 @@ function threefry2x32(k0, k1, c0, c1) {
1333
1363
  x1 = x1 + ks0 + 5 >>> 0;
1334
1364
  return [x0, x1];
1335
1365
  }
1366
/**
 * Tail term of the Abramowitz & Stegun approximation to erf (formula 7.1.26).
 *
 * Returns `E(x) = P(t) * exp(-x^2)`, where `t = 1 / (1 + p*x)` and
 * `P(t) = a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5`. For `x >= 0`,
 * `erf(x) ≈ 1 - E(x)` and `erfc(x) ≈ E(x)`.
 *
 * Returning `E` instead of `erf` lets callers form `erfc` directly, avoiding
 * the cancellation in `1 - erf(x)`. `P` is evaluated by Horner's rule.
 *
 * The input is assumed to be non-negative.
 *
 * Reference: https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions
 */
function _erfapprox$1(x) {
  const p = .3275911;
  // a1 .. a5, lowest degree first.
  const coeffs = [.254829592, -.284496736, 1.421413741, -1.453152027, 1.061405429];
  const t = 1 / (1 + p * x);
  let acc = coeffs[coeffs.length - 1];
  for (let i = coeffs.length - 2; i >= 0; i--) acc = acc * t + coeffs[i];
  // The trailing `* t` supplies the factor shared by every term of P(t).
  return acc * t * Math.exp(-x * x);
}
1396
/**
 * Approximates the error function erf(x) using `_erfapprox$1`.
 *
 * Negative inputs use the odd symmetry erf(-x) = -erf(x). Any input that is
 * not `>= 0` (including NaN) takes the negative branch.
 */
function erf(x) {
  return x >= 0 ? 1 - _erfapprox$1(x) : _erfapprox$1(-x) - 1;
}
1400
/**
 * Approximates the complementary error function erfc(x) = 1 - erf(x).
 *
 * For `x >= 0` this is the tail term from `_erfapprox$1` directly. Otherwise
 * it uses erfc(-x) = 2 - erfc(x).
 */
function erfc(x) {
  return x >= 0 ? _erfapprox$1(x) : 2 - _erfapprox$1(-x);
}
1336
1404
 
1337
1405
  //#endregion
1338
1406
  //#region src/shape.ts
@@ -2018,7 +2086,7 @@ function tuneWebgpu(kernel) {
2018
2086
  }
2019
2087
  if (/chrome/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
2020
2088
  const s = dim.st.shape[dim.unroll - 1];
2021
- if (s <= 32) dim.applyUnroll(dim.reduce, s);
2089
+ if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
2022
2090
  else for (const splits of [4]) if (s % splits === 0) {
2023
2091
  dim.applyUnroll(dim.unroll - 1, splits);
2024
2092
  break;
@@ -2237,6 +2305,19 @@ var WasmAllocator = class {
2237
2305
 
2238
2306
  //#endregion
2239
2307
  //#region src/backend/wasm/builtins.ts
2308
/**
 * Emit code for `as[0] + as[1]*x + ... + as[n]*x^n` using Horner's rule.
 *
 * `x` is the index of an f32 local holding the evaluation point. The result
 * is left on top of the operand stack. Coefficients equal to zero emit no
 * `f32.add`.
 *
 * @throws {Error} If `as` is empty.
 */
function _poly(cg, x, as) {
  const degree = as.length - 1;
  if (degree < 0) throw new Error("_poly needs at least one coefficient");
  cg.f32.const(as[degree]);
  for (let k = degree - 1; k >= 0; k--) {
    const coeff = as[k];
    cg.local.get(x);
    cg.f32.mul();
    if (coeff === 0) continue;
    cg.f32.const(coeff);
    cg.f32.add();
  }
}
2240
2321
  /**
2241
2322
  * Approximate e^x.
2242
2323
  *
@@ -2277,27 +2358,15 @@ function wasm_exp(cg) {
2277
2358
  cg.f32.mul();
2278
2359
  cg.f32.sub();
2279
2360
  cg.local.set(r);
2280
- cg.f32.const(1 / 120);
2281
- cg.local.get(r);
2282
- cg.f32.mul();
2283
- cg.f32.const(1 / 24);
2284
- cg.f32.add();
2285
- cg.local.get(r);
2286
- cg.f32.mul();
2287
- cg.f32.const(1 / 6);
2288
- cg.f32.add();
2289
- cg.local.get(r);
2290
- cg.f32.mul();
2291
- cg.f32.const(1 / 2);
2292
- cg.f32.add();
2293
- cg.local.get(r);
2294
- cg.f32.mul();
2295
- cg.f32.const(1);
2296
- cg.f32.add();
2297
- cg.local.get(r);
2298
- cg.f32.mul();
2299
- cg.f32.const(1);
2300
- cg.f32.add();
2361
+ _poly(cg, r, [
2362
+ 1,
2363
+ 1,
2364
+ 1 / 2,
2365
+ 1 / 6,
2366
+ 1 / 24,
2367
+ 1 / 120,
2368
+ 1 / 720
2369
+ ]);
2301
2370
  cg.local.set(p);
2302
2371
  cg.local.get(k);
2303
2372
  cg.i32.const(127);
@@ -2325,11 +2394,6 @@ function wasm_log(cg) {
2325
2394
  const m = cg.local.declare(cg.f32);
2326
2395
  const t = cg.local.declare(cg.f32);
2327
2396
  const t2 = cg.local.declare(cg.f32);
2328
- const t3 = cg.local.declare(cg.f32);
2329
- const t5 = cg.local.declare(cg.f32);
2330
- const t7 = cg.local.declare(cg.f32);
2331
- const lnm = cg.local.declare(cg.f32);
2332
- const el2 = cg.local.declare(cg.f32);
2333
2397
  cg.local.get(0);
2334
2398
  cg.f32.const(0);
2335
2399
  cg.f32.le();
@@ -2366,41 +2430,18 @@ function wasm_log(cg) {
2366
2430
  cg.local.get(t);
2367
2431
  cg.f32.mul();
2368
2432
  cg.local.set(t2);
2433
+ _poly(cg, t2, [
2434
+ 2,
2435
+ 2 / 3,
2436
+ 2 / 5,
2437
+ 2 / 7
2438
+ ]);
2369
2439
  cg.local.get(t);
2370
- cg.local.get(t2);
2371
- cg.f32.mul();
2372
- cg.local.set(t3);
2373
- cg.local.get(t3);
2374
- cg.local.get(t2);
2375
- cg.f32.mul();
2376
- cg.local.set(t5);
2377
- cg.local.get(t5);
2378
- cg.local.get(t2);
2379
- cg.f32.mul();
2380
- cg.local.set(t7);
2381
- cg.local.get(t7);
2382
- cg.f32.const(1 / 7);
2383
- cg.f32.mul();
2384
- cg.local.get(t5);
2385
- cg.f32.const(1 / 5);
2386
- cg.f32.mul();
2387
- cg.f32.add();
2388
- cg.local.get(t3);
2389
- cg.f32.const(1 / 3);
2390
- cg.f32.mul();
2391
- cg.f32.add();
2392
- cg.local.get(t);
2393
- cg.f32.add();
2394
- cg.f32.const(2);
2395
2440
  cg.f32.mul();
2396
- cg.local.set(lnm);
2397
2441
  cg.local.get(e);
2398
2442
  cg.f32.convert_i32_s();
2399
2443
  cg.f32.const(Math.LN2);
2400
2444
  cg.f32.mul();
2401
- cg.local.set(el2);
2402
- cg.local.get(el2);
2403
- cg.local.get(lnm);
2404
2445
  cg.f32.add();
2405
2446
  });
2406
2447
  }
@@ -2410,7 +2451,7 @@ function wasm_log(cg) {
2410
2451
  * Method: reduce to y in [-π, π], then quadrant via q = round(y/(π/2))
2411
2452
  * z = y - q*(π/2); use one of two polynomials on z:
2412
2453
  * sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
2413
- * cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
2454
+ * cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720) + z^8*(1/40320)
2414
2455
  */
2415
2456
  function _sincos(cg) {
2416
2457
  const y = cg.local.declare(cg.f32);
@@ -2446,35 +2487,22 @@ function _sincos(cg) {
2446
2487
  cg.local.get(z);
2447
2488
  cg.f32.mul();
2448
2489
  cg.local.set(z2);
2449
- cg.f32.const(-1 / 5040);
2450
- cg.local.get(z2);
2451
- cg.f32.mul();
2452
- cg.f32.const(1 / 120);
2453
- cg.f32.add();
2454
- cg.local.get(z2);
2455
- cg.f32.mul();
2456
- cg.f32.const(-1 / 6);
2457
- cg.f32.add();
2458
- cg.local.get(z2);
2459
- cg.f32.mul();
2460
- cg.f32.const(1);
2461
- cg.f32.add();
2490
+ _poly(cg, z2, [
2491
+ 1,
2492
+ -1 / 6,
2493
+ 1 / 120,
2494
+ -1 / 5040
2495
+ ]);
2462
2496
  cg.local.get(z);
2463
2497
  cg.f32.mul();
2464
2498
  cg.local.set(sz);
2465
- cg.f32.const(-1 / 720);
2466
- cg.local.get(z2);
2467
- cg.f32.mul();
2468
- cg.f32.const(1 / 24);
2469
- cg.f32.add();
2470
- cg.local.get(z2);
2471
- cg.f32.mul();
2472
- cg.f32.const(-1 / 2);
2473
- cg.f32.add();
2474
- cg.local.get(z2);
2475
- cg.f32.mul();
2476
- cg.f32.const(1);
2477
- cg.f32.add();
2499
+ _poly(cg, z2, [
2500
+ 1,
2501
+ -1 / 2,
2502
+ 1 / 24,
2503
+ -1 / 720,
2504
+ 1 / 40320
2505
+ ]);
2478
2506
  cg.local.set(cz);
2479
2507
  return {
2480
2508
  q,
@@ -2556,24 +2584,16 @@ function _atan(cg) {
2556
2584
  cg.local.get(z);
2557
2585
  cg.f32.mul();
2558
2586
  cg.local.set(z2);
2559
- cg.f32.const(.0415796528637);
2560
- cg.local.get(z2);
2561
- cg.f32.mul();
2562
- cg.f32.const(.661705427875);
2563
- cg.f32.add();
2564
- cg.local.get(z2);
2565
- cg.f32.mul();
2566
- cg.f32.const(.999998614341);
2567
- cg.f32.add();
2568
- cg.f32.const(.173698870181);
2569
- cg.local.get(z2);
2570
- cg.f32.mul();
2571
- cg.f32.const(.994987933645);
2572
- cg.f32.add();
2573
- cg.local.get(z2);
2574
- cg.f32.mul();
2575
- cg.f32.const(1);
2576
- cg.f32.add();
2587
+ _poly(cg, z2, [
2588
+ .999998614341,
2589
+ .661705427875,
2590
+ .0415796528637
2591
+ ]);
2592
+ _poly(cg, z2, [
2593
+ 1,
2594
+ .994987933645,
2595
+ .173698870181
2596
+ ]);
2577
2597
  cg.f32.div();
2578
2598
  cg.local.get(z);
2579
2599
  cg.f32.mul();
@@ -2627,6 +2647,74 @@ function wasm_asin(cg) {
2627
2647
  });
2628
2648
  }
2629
2649
/**
 * Helper function for erf/erfc approximation.
 *
 * Emits code computing the tail term `E(x) = P(t) * exp(-x^2)` with
 * `t = 1 / (1 + p*x)`. It pops one f32 `x` (assumed non-negative) from the
 * operand stack and leaves `E(x)` in its place. It declares two new f32
 * locals. `exp_func` is the function called to evaluate `exp(-x^2)`.
 *
 * See `_erfapprox` in alu.ts for details on the algorithm used.
 */
function _erfapprox(cg, exp_func) {
  const p = .3275911;
  // [a0, a1, ..., a5] with a0 = 0, so the emitted polynomial carries a factor of t.
  const coeffs = [0, .254829592, -.284496736, 1.421413741, -1.453152027, 1.061405429];
  const xLocal = cg.local.declare(cg.f32);
  const tLocal = cg.local.declare(cg.f32);
  cg.local.set(xLocal);
  // t = 1 / (1 + p*x)
  cg.f32.const(1);
  cg.f32.const(1);
  cg.f32.const(p);
  cg.local.get(xLocal);
  cg.f32.mul();
  cg.f32.add();
  cg.f32.div();
  cg.local.set(tLocal);
  // Push P(t).
  _poly(cg, tLocal, coeffs);
  // Push exp(-x*x), then multiply it into P(t).
  cg.local.get(xLocal);
  cg.f32.neg();
  cg.local.get(xLocal);
  cg.f32.mul();
  cg.call(exp_func);
  cg.f32.mul();
}
2687
/**
 * Approximate erf(x) (error function).
 *
 * Builds an `(f32) -> f32` function computing `copysign(1 - E(|x|), x)`, where
 * `E` is the tail term emitted by `_erfapprox`. This relies on erf being odd.
 * `exp` is the exp function that `_erfapprox` calls.
 */
function wasm_erf(cg, exp) {
  const emitBody = () => {
    cg.f32.const(1);
    // E(|x|)
    cg.local.get(0);
    cg.f32.abs();
    _erfapprox(cg, exp);
    // 1 - E(|x|), then give it the sign of x.
    cg.f32.sub();
    cg.local.get(0);
    cg.f32.copysign();
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
2699
/**
 * Approximate erfc(x) (complementary error function).
 *
 * Builds an `(f32) -> f32` function. It computes `tail = E(|x|)` with
 * `_erfapprox` and returns `x < 0 ? 2 - tail : tail`, using
 * erfc(-x) = 2 - erfc(x). The branch is emitted as a wasm `select`, which picks
 * its first operand when the condition is non-zero. `exp` is the exp function
 * that `_erfapprox` calls.
 */
function wasm_erfc(cg, exp) {
  const emitBody = () => {
    const tail = cg.local.declare(cg.f32);
    // tail = E(|x|)
    cg.local.get(0);
    cg.f32.abs();
    _erfapprox(cg, exp);
    cg.local.set(tail);
    // select(2 - tail, tail, x < 0)
    cg.f32.const(2);
    cg.local.get(tail);
    cg.f32.sub();
    cg.local.get(tail);
    cg.local.get(0);
    cg.f32.const(0);
    cg.f32.lt();
    cg.select();
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
2717
+ /**
2630
2718
  * Threefry2x32 pseudorandom number generator.
2631
2719
  *
2632
2720
  * Takes two 32-bit keys and two 32-bit counters as input,
@@ -3501,14 +3589,16 @@ function codegenWasm(kernel) {
3501
3589
  if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
3502
3590
  const cg = new CodeGenerator();
3503
3591
  cg.memory.import("env", "memory");
3504
- const distinctOps = union(tune.exp.distinctOps(), re?.epilogue.distinctOps());
3592
+ const distinctOps = mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
3505
3593
  const funcs = {};
3506
3594
  if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
3507
3595
  if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
3508
3596
  if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
3509
3597
  if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
3510
- if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
3598
+ if (distinctOps.has(AluOp.Exp) || distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) funcs.exp = wasm_exp(cg);
3511
3599
  if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
3600
+ if (distinctOps.has(AluOp.Erf)) funcs.erf = wasm_erf(cg, funcs.exp);
3601
+ if (distinctOps.has(AluOp.Erfc)) funcs.erfc = wasm_erfc(cg, funcs.exp);
3512
3602
  if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
3513
3603
  const kernelFunc = cg.function(rep(kernel.nargs + 1, cg.i32), [], () => {
3514
3604
  const gidx = cg.local.declare(cg.i32);
@@ -3662,6 +3752,8 @@ function translateExp(cg, funcs, exp, ctx) {
3662
3752
  else if (op === AluOp.Atan) gen(src[0]), cg.call(funcs.atan);
3663
3753
  else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
3664
3754
  else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
3755
+ else if (op === AluOp.Erf) gen(src[0]), cg.call(funcs.erf);
3756
+ else if (op === AluOp.Erfc) gen(src[0]), cg.call(funcs.erfc);
3665
3757
  else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
3666
3758
  else if (op === AluOp.Reciprocal) cg.f32.const(1), gen(src[0]), cg.f32.div();
3667
3759
  else if (op === AluOp.Cast) {
@@ -3789,7 +3881,7 @@ async function createBackend(device) {
3789
3881
  if (!navigator.gpu) return null;
3790
3882
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
3791
3883
  if (!adapter) return null;
3792
- const { WebGPUBackend } = await import("./webgpu-CM-xNYzW.js");
3884
+ const { WebGPUBackend } = await import("./webgpu-LGi2A3mS.js");
3793
3885
  const importantLimits = [
3794
3886
  "maxBufferSize",
3795
3887
  "maxComputeInvocationsPerWorkgroup",
@@ -3842,4 +3934,5 @@ var UnsupportedOpError = class extends Error {
3842
3934
  };
3843
3935
 
3844
3936
  //#endregion
3845
- export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, union, unravelAlu, unzip2, zip, zipn };
3937
+ export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
3938
+ //# sourceMappingURL=backend-DwIAd0AG.js.map