@jax-js/jax 0.0.5 → 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +267 -92
- package/dist/{backend-CdcTZEOF.js → backend-DwIAd0AG.js} +205 -112
- package/dist/{backend-yEU0L_ig.cjs → backend-FtkbO6pI.cjs} +217 -118
- package/dist/index.cjs +344 -67
- package/dist/index.d.cts +96 -18
- package/dist/index.d.ts +96 -18
- package/dist/index.js +337 -67
- package/dist/{webgpu-CNOpiO5T.cjs → webgpu-BE7zA_01.cjs} +181 -151
- package/dist/{webgpu-CM-xNYzW.js → webgpu-LGi2A3mS.js} +181 -151
- package/package.json +7 -5
|
@@ -52,6 +52,7 @@ let DEBUG = 0;
|
|
|
52
52
|
function setDebug(level) {
|
|
53
53
|
DEBUG = level;
|
|
54
54
|
}
|
|
55
|
+
function assertNonNull(value) {}
|
|
55
56
|
function unzip2(pairs) {
|
|
56
57
|
const lst1 = [];
|
|
57
58
|
const lst2 = [];
|
|
@@ -97,10 +98,15 @@ function deepEqual(a, b) {
|
|
|
97
98
|
for (const key of Object.keys(a)) if (!deepEqual(a[key], b[key])) return false;
|
|
98
99
|
return true;
|
|
99
100
|
}
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
101
|
+
/** Produces a union of maps of sets. This mutates `a`. */
|
|
102
|
+
function mapSetUnion(a, b) {
|
|
103
|
+
if (!b) return a;
|
|
104
|
+
for (const [key, setB] of b.entries()) {
|
|
105
|
+
const setA = a.get(key);
|
|
106
|
+
if (setA) for (const val of setB) setA.add(val);
|
|
107
|
+
else a.set(key, setB);
|
|
108
|
+
}
|
|
109
|
+
return a;
|
|
104
110
|
}
|
|
105
111
|
/** Splits the list based on a condition, `false` first then `true`. */
|
|
106
112
|
function partitionList(which, array) {
|
|
@@ -430,6 +436,12 @@ var AluExp = class AluExp {
|
|
|
430
436
|
static log(a) {
|
|
431
437
|
return new AluExp(AluOp.Log, a.dtype, [a]);
|
|
432
438
|
}
|
|
439
|
+
static erf(a) {
|
|
440
|
+
return new AluExp(AluOp.Erf, a.dtype, [a]);
|
|
441
|
+
}
|
|
442
|
+
static erfc(a) {
|
|
443
|
+
return new AluExp(AluOp.Erfc, a.dtype, [a]);
|
|
444
|
+
}
|
|
433
445
|
static sqrt(a) {
|
|
434
446
|
return new AluExp(AluOp.Sqrt, a.dtype, [a]);
|
|
435
447
|
}
|
|
@@ -600,6 +612,12 @@ var AluExp = class AluExp {
|
|
|
600
612
|
case AluOp.Log:
|
|
601
613
|
ret = [Math.log(src[0].min), Math.log(src[0].max)];
|
|
602
614
|
break;
|
|
615
|
+
case AluOp.Erf:
|
|
616
|
+
ret = [erf(src[0].min), erf(src[0].max)];
|
|
617
|
+
break;
|
|
618
|
+
case AluOp.Erfc:
|
|
619
|
+
ret = [erfc(src[0].max), erfc(src[0].min)];
|
|
620
|
+
break;
|
|
603
621
|
case AluOp.Sqrt:
|
|
604
622
|
ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
|
|
605
623
|
break;
|
|
@@ -921,6 +939,8 @@ var AluExp = class AluExp {
|
|
|
921
939
|
case AluOp.Atan: return Math.atan(x);
|
|
922
940
|
case AluOp.Exp: return Math.exp(x);
|
|
923
941
|
case AluOp.Log: return Math.log(x);
|
|
942
|
+
case AluOp.Erf: return erf(x);
|
|
943
|
+
case AluOp.Erfc: return erfc(x);
|
|
924
944
|
case AluOp.Sqrt: return Math.sqrt(x);
|
|
925
945
|
case AluOp.Reciprocal: return 1 / x;
|
|
926
946
|
case AluOp.Cast: {
|
|
@@ -1069,11 +1089,15 @@ var AluExp = class AluExp {
|
|
|
1069
1089
|
});
|
|
1070
1090
|
return result;
|
|
1071
1091
|
}
|
|
1072
|
-
/** Produce
|
|
1092
|
+
/** Produce all distinct AluOp in this expression, with their dtypes. */
|
|
1073
1093
|
distinctOps() {
|
|
1074
|
-
const ops = /* @__PURE__ */ new
|
|
1094
|
+
const ops = /* @__PURE__ */ new Map();
|
|
1075
1095
|
this.fold((exp) => {
|
|
1076
|
-
ops.
|
|
1096
|
+
const s = ops.get(exp.op) ?? /* @__PURE__ */ new Set();
|
|
1097
|
+
if (!s.has(exp.dtype)) {
|
|
1098
|
+
s.add(exp.dtype);
|
|
1099
|
+
ops.set(exp.op, s);
|
|
1100
|
+
}
|
|
1077
1101
|
});
|
|
1078
1102
|
return ops;
|
|
1079
1103
|
}
|
|
@@ -1102,6 +1126,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
1102
1126
|
AluOp$1["Atan"] = "Atan";
|
|
1103
1127
|
AluOp$1["Exp"] = "Exp";
|
|
1104
1128
|
AluOp$1["Log"] = "Log";
|
|
1129
|
+
AluOp$1["Erf"] = "Erf";
|
|
1130
|
+
AluOp$1["Erfc"] = "Erfc";
|
|
1105
1131
|
AluOp$1["Sqrt"] = "Sqrt";
|
|
1106
1132
|
AluOp$1["Reciprocal"] = "Reciprocal";
|
|
1107
1133
|
AluOp$1["Cast"] = "Cast";
|
|
@@ -1134,6 +1160,8 @@ const AluGroup = {
|
|
|
1134
1160
|
AluOp.Atan,
|
|
1135
1161
|
AluOp.Exp,
|
|
1136
1162
|
AluOp.Log,
|
|
1163
|
+
AluOp.Erf,
|
|
1164
|
+
AluOp.Erfc,
|
|
1137
1165
|
AluOp.Sqrt,
|
|
1138
1166
|
AluOp.Reciprocal,
|
|
1139
1167
|
AluOp.Cast,
|
|
@@ -1159,6 +1187,8 @@ const AluGroup = {
|
|
|
1159
1187
|
AluOp.Atan,
|
|
1160
1188
|
AluOp.Exp,
|
|
1161
1189
|
AluOp.Log,
|
|
1190
|
+
AluOp.Erf,
|
|
1191
|
+
AluOp.Erfc,
|
|
1162
1192
|
AluOp.Sqrt,
|
|
1163
1193
|
AluOp.Reciprocal
|
|
1164
1194
|
])
|
|
@@ -1334,6 +1364,44 @@ function threefry2x32(k0, k1, c0, c1) {
|
|
|
1334
1364
|
x1 = x1 + ks0 + 5 >>> 0;
|
|
1335
1365
|
return [x0, x1];
|
|
1336
1366
|
}
|
|
1367
|
+
/**
|
|
1368
|
+
* Abramowitz & Stegun’s widely used approximation for erf(x).
|
|
1369
|
+
*
|
|
1370
|
+
* `erf(x) = 1 - P(t) * exp(-x^2)` for `x >= 0`, where `t = 1/(1 + p*x)` and
|
|
1371
|
+
* `P(t) = a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5`.
|
|
1372
|
+
*
|
|
1373
|
+
* Coefficients:
|
|
1374
|
+
* - p = 0.3275911
|
|
1375
|
+
* - a1 = 0.254829592
|
|
1376
|
+
* - a2 = -0.284496736
|
|
1377
|
+
* - a3 = 1.421413741
|
|
1378
|
+
* - a4 = -1.453152027
|
|
1379
|
+
* - a5 = 1.061405429
|
|
1380
|
+
*
|
|
1381
|
+
* This function computes just `E = P(t) * exp(-x^2)` for numerical reasons. The
|
|
1382
|
+
* input is assumed to be non-negative.
|
|
1383
|
+
*
|
|
1384
|
+
* Reference: https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions
|
|
1385
|
+
*/
|
|
1386
|
+
function _erfapprox$1(x) {
|
|
1387
|
+
const p = .3275911;
|
|
1388
|
+
const a1 = .254829592;
|
|
1389
|
+
const a2 = -.284496736;
|
|
1390
|
+
const a3 = 1.421413741;
|
|
1391
|
+
const a4 = -1.453152027;
|
|
1392
|
+
const a5 = 1.061405429;
|
|
1393
|
+
const t = 1 / (1 + p * x);
|
|
1394
|
+
const P_t = ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t;
|
|
1395
|
+
return P_t * Math.exp(-x * x);
|
|
1396
|
+
}
|
|
1397
|
+
function erf(x) {
|
|
1398
|
+
if (x >= 0) return 1 - _erfapprox$1(x);
|
|
1399
|
+
else return _erfapprox$1(-x) - 1;
|
|
1400
|
+
}
|
|
1401
|
+
function erfc(x) {
|
|
1402
|
+
if (x >= 0) return _erfapprox$1(x);
|
|
1403
|
+
else return 2 - _erfapprox$1(-x);
|
|
1404
|
+
}
|
|
1337
1405
|
|
|
1338
1406
|
//#endregion
|
|
1339
1407
|
//#region src/shape.ts
|
|
@@ -2019,7 +2087,7 @@ function tuneWebgpu(kernel) {
|
|
|
2019
2087
|
}
|
|
2020
2088
|
if (/chrome/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
|
|
2021
2089
|
const s = dim.st.shape[dim.unroll - 1];
|
|
2022
|
-
if (s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
2090
|
+
if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
2023
2091
|
else for (const splits of [4]) if (s % splits === 0) {
|
|
2024
2092
|
dim.applyUnroll(dim.unroll - 1, splits);
|
|
2025
2093
|
break;
|
|
@@ -2238,6 +2306,19 @@ var WasmAllocator = class {
|
|
|
2238
2306
|
|
|
2239
2307
|
//#endregion
|
|
2240
2308
|
//#region src/backend/wasm/builtins.ts
|
|
2309
|
+
/** Given a local `x`, evaluate `sum[i](a_i * x^i)` and push to stack. */
|
|
2310
|
+
function _poly(cg, x, as) {
|
|
2311
|
+
if (as.length === 0) throw new Error("_poly needs at least one coefficient");
|
|
2312
|
+
cg.f32.const(as[as.length - 1]);
|
|
2313
|
+
for (let i = as.length - 2; i >= 0; i--) {
|
|
2314
|
+
cg.local.get(x);
|
|
2315
|
+
cg.f32.mul();
|
|
2316
|
+
if (as[i] !== 0) {
|
|
2317
|
+
cg.f32.const(as[i]);
|
|
2318
|
+
cg.f32.add();
|
|
2319
|
+
}
|
|
2320
|
+
}
|
|
2321
|
+
}
|
|
2241
2322
|
/**
|
|
2242
2323
|
* Approximate e^x.
|
|
2243
2324
|
*
|
|
@@ -2278,27 +2359,15 @@ function wasm_exp(cg) {
|
|
|
2278
2359
|
cg.f32.mul();
|
|
2279
2360
|
cg.f32.sub();
|
|
2280
2361
|
cg.local.set(r);
|
|
2281
|
-
cg
|
|
2282
|
-
|
|
2283
|
-
|
|
2284
|
-
|
|
2285
|
-
|
|
2286
|
-
|
|
2287
|
-
|
|
2288
|
-
|
|
2289
|
-
|
|
2290
|
-
cg.local.get(r);
|
|
2291
|
-
cg.f32.mul();
|
|
2292
|
-
cg.f32.const(1 / 2);
|
|
2293
|
-
cg.f32.add();
|
|
2294
|
-
cg.local.get(r);
|
|
2295
|
-
cg.f32.mul();
|
|
2296
|
-
cg.f32.const(1);
|
|
2297
|
-
cg.f32.add();
|
|
2298
|
-
cg.local.get(r);
|
|
2299
|
-
cg.f32.mul();
|
|
2300
|
-
cg.f32.const(1);
|
|
2301
|
-
cg.f32.add();
|
|
2362
|
+
_poly(cg, r, [
|
|
2363
|
+
1,
|
|
2364
|
+
1,
|
|
2365
|
+
1 / 2,
|
|
2366
|
+
1 / 6,
|
|
2367
|
+
1 / 24,
|
|
2368
|
+
1 / 120,
|
|
2369
|
+
1 / 720
|
|
2370
|
+
]);
|
|
2302
2371
|
cg.local.set(p);
|
|
2303
2372
|
cg.local.get(k);
|
|
2304
2373
|
cg.i32.const(127);
|
|
@@ -2326,11 +2395,6 @@ function wasm_log(cg) {
|
|
|
2326
2395
|
const m = cg.local.declare(cg.f32);
|
|
2327
2396
|
const t = cg.local.declare(cg.f32);
|
|
2328
2397
|
const t2 = cg.local.declare(cg.f32);
|
|
2329
|
-
const t3 = cg.local.declare(cg.f32);
|
|
2330
|
-
const t5 = cg.local.declare(cg.f32);
|
|
2331
|
-
const t7 = cg.local.declare(cg.f32);
|
|
2332
|
-
const lnm = cg.local.declare(cg.f32);
|
|
2333
|
-
const el2 = cg.local.declare(cg.f32);
|
|
2334
2398
|
cg.local.get(0);
|
|
2335
2399
|
cg.f32.const(0);
|
|
2336
2400
|
cg.f32.le();
|
|
@@ -2367,41 +2431,18 @@ function wasm_log(cg) {
|
|
|
2367
2431
|
cg.local.get(t);
|
|
2368
2432
|
cg.f32.mul();
|
|
2369
2433
|
cg.local.set(t2);
|
|
2434
|
+
_poly(cg, t2, [
|
|
2435
|
+
2,
|
|
2436
|
+
2 / 3,
|
|
2437
|
+
2 / 5,
|
|
2438
|
+
2 / 7
|
|
2439
|
+
]);
|
|
2370
2440
|
cg.local.get(t);
|
|
2371
|
-
cg.local.get(t2);
|
|
2372
|
-
cg.f32.mul();
|
|
2373
|
-
cg.local.set(t3);
|
|
2374
|
-
cg.local.get(t3);
|
|
2375
|
-
cg.local.get(t2);
|
|
2376
|
-
cg.f32.mul();
|
|
2377
|
-
cg.local.set(t5);
|
|
2378
|
-
cg.local.get(t5);
|
|
2379
|
-
cg.local.get(t2);
|
|
2380
|
-
cg.f32.mul();
|
|
2381
|
-
cg.local.set(t7);
|
|
2382
|
-
cg.local.get(t7);
|
|
2383
|
-
cg.f32.const(1 / 7);
|
|
2384
2441
|
cg.f32.mul();
|
|
2385
|
-
cg.local.get(t5);
|
|
2386
|
-
cg.f32.const(1 / 5);
|
|
2387
|
-
cg.f32.mul();
|
|
2388
|
-
cg.f32.add();
|
|
2389
|
-
cg.local.get(t3);
|
|
2390
|
-
cg.f32.const(1 / 3);
|
|
2391
|
-
cg.f32.mul();
|
|
2392
|
-
cg.f32.add();
|
|
2393
|
-
cg.local.get(t);
|
|
2394
|
-
cg.f32.add();
|
|
2395
|
-
cg.f32.const(2);
|
|
2396
|
-
cg.f32.mul();
|
|
2397
|
-
cg.local.set(lnm);
|
|
2398
2442
|
cg.local.get(e);
|
|
2399
2443
|
cg.f32.convert_i32_s();
|
|
2400
2444
|
cg.f32.const(Math.LN2);
|
|
2401
2445
|
cg.f32.mul();
|
|
2402
|
-
cg.local.set(el2);
|
|
2403
|
-
cg.local.get(el2);
|
|
2404
|
-
cg.local.get(lnm);
|
|
2405
2446
|
cg.f32.add();
|
|
2406
2447
|
});
|
|
2407
2448
|
}
|
|
@@ -2411,7 +2452,7 @@ function wasm_log(cg) {
|
|
|
2411
2452
|
* Method: reduce to y in [-π, π], then quadrant via q = round(y/(π/2))
|
|
2412
2453
|
* z = y - q*(π/2); use one of two polynomials on z:
|
|
2413
2454
|
* sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
|
|
2414
|
-
* cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
|
|
2455
|
+
* cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720) + z^8*(1/40320)
|
|
2415
2456
|
*/
|
|
2416
2457
|
function _sincos(cg) {
|
|
2417
2458
|
const y = cg.local.declare(cg.f32);
|
|
@@ -2447,35 +2488,22 @@ function _sincos(cg) {
|
|
|
2447
2488
|
cg.local.get(z);
|
|
2448
2489
|
cg.f32.mul();
|
|
2449
2490
|
cg.local.set(z2);
|
|
2450
|
-
cg
|
|
2451
|
-
|
|
2452
|
-
|
|
2453
|
-
|
|
2454
|
-
|
|
2455
|
-
|
|
2456
|
-
cg.f32.mul();
|
|
2457
|
-
cg.f32.const(-1 / 6);
|
|
2458
|
-
cg.f32.add();
|
|
2459
|
-
cg.local.get(z2);
|
|
2460
|
-
cg.f32.mul();
|
|
2461
|
-
cg.f32.const(1);
|
|
2462
|
-
cg.f32.add();
|
|
2491
|
+
_poly(cg, z2, [
|
|
2492
|
+
1,
|
|
2493
|
+
-1 / 6,
|
|
2494
|
+
1 / 120,
|
|
2495
|
+
-1 / 5040
|
|
2496
|
+
]);
|
|
2463
2497
|
cg.local.get(z);
|
|
2464
2498
|
cg.f32.mul();
|
|
2465
2499
|
cg.local.set(sz);
|
|
2466
|
-
cg
|
|
2467
|
-
|
|
2468
|
-
|
|
2469
|
-
|
|
2470
|
-
|
|
2471
|
-
|
|
2472
|
-
|
|
2473
|
-
cg.f32.const(-1 / 2);
|
|
2474
|
-
cg.f32.add();
|
|
2475
|
-
cg.local.get(z2);
|
|
2476
|
-
cg.f32.mul();
|
|
2477
|
-
cg.f32.const(1);
|
|
2478
|
-
cg.f32.add();
|
|
2500
|
+
_poly(cg, z2, [
|
|
2501
|
+
1,
|
|
2502
|
+
-1 / 2,
|
|
2503
|
+
1 / 24,
|
|
2504
|
+
-1 / 720,
|
|
2505
|
+
1 / 40320
|
|
2506
|
+
]);
|
|
2479
2507
|
cg.local.set(cz);
|
|
2480
2508
|
return {
|
|
2481
2509
|
q,
|
|
@@ -2557,24 +2585,16 @@ function _atan(cg) {
|
|
|
2557
2585
|
cg.local.get(z);
|
|
2558
2586
|
cg.f32.mul();
|
|
2559
2587
|
cg.local.set(z2);
|
|
2560
|
-
cg
|
|
2561
|
-
|
|
2562
|
-
|
|
2563
|
-
|
|
2564
|
-
|
|
2565
|
-
cg
|
|
2566
|
-
|
|
2567
|
-
|
|
2568
|
-
|
|
2569
|
-
|
|
2570
|
-
cg.local.get(z2);
|
|
2571
|
-
cg.f32.mul();
|
|
2572
|
-
cg.f32.const(.994987933645);
|
|
2573
|
-
cg.f32.add();
|
|
2574
|
-
cg.local.get(z2);
|
|
2575
|
-
cg.f32.mul();
|
|
2576
|
-
cg.f32.const(1);
|
|
2577
|
-
cg.f32.add();
|
|
2588
|
+
_poly(cg, z2, [
|
|
2589
|
+
.999998614341,
|
|
2590
|
+
.661705427875,
|
|
2591
|
+
.0415796528637
|
|
2592
|
+
]);
|
|
2593
|
+
_poly(cg, z2, [
|
|
2594
|
+
1,
|
|
2595
|
+
.994987933645,
|
|
2596
|
+
.173698870181
|
|
2597
|
+
]);
|
|
2578
2598
|
cg.f32.div();
|
|
2579
2599
|
cg.local.get(z);
|
|
2580
2600
|
cg.f32.mul();
|
|
@@ -2628,6 +2648,74 @@ function wasm_asin(cg) {
|
|
|
2628
2648
|
});
|
|
2629
2649
|
}
|
|
2630
2650
|
/**
|
|
2651
|
+
* Helper function for erf/erfc approximation.
|
|
2652
|
+
*
|
|
2653
|
+
* See `_erfapprox` in alu.ts for details on the algorithm used.
|
|
2654
|
+
*/
|
|
2655
|
+
function _erfapprox(cg, exp_func) {
|
|
2656
|
+
const x = cg.local.declare(cg.f32);
|
|
2657
|
+
const t = cg.local.declare(cg.f32);
|
|
2658
|
+
cg.local.set(x);
|
|
2659
|
+
const p = .3275911;
|
|
2660
|
+
const a1 = .254829592;
|
|
2661
|
+
const a2 = -.284496736;
|
|
2662
|
+
const a3 = 1.421413741;
|
|
2663
|
+
const a4 = -1.453152027;
|
|
2664
|
+
const a5 = 1.061405429;
|
|
2665
|
+
cg.f32.const(1);
|
|
2666
|
+
cg.f32.const(1);
|
|
2667
|
+
cg.f32.const(p);
|
|
2668
|
+
cg.local.get(x);
|
|
2669
|
+
cg.f32.mul();
|
|
2670
|
+
cg.f32.add();
|
|
2671
|
+
cg.f32.div();
|
|
2672
|
+
cg.local.set(t);
|
|
2673
|
+
_poly(cg, t, [
|
|
2674
|
+
0,
|
|
2675
|
+
a1,
|
|
2676
|
+
a2,
|
|
2677
|
+
a3,
|
|
2678
|
+
a4,
|
|
2679
|
+
a5
|
|
2680
|
+
]);
|
|
2681
|
+
cg.local.get(x);
|
|
2682
|
+
cg.f32.neg();
|
|
2683
|
+
cg.local.get(x);
|
|
2684
|
+
cg.f32.mul();
|
|
2685
|
+
cg.call(exp_func);
|
|
2686
|
+
cg.f32.mul();
|
|
2687
|
+
}
|
|
2688
|
+
/** Approximate erf(x) (error function). */
|
|
2689
|
+
function wasm_erf(cg, exp) {
|
|
2690
|
+
return cg.function([cg.f32], [cg.f32], () => {
|
|
2691
|
+
cg.f32.const(1);
|
|
2692
|
+
cg.local.get(0);
|
|
2693
|
+
cg.f32.abs();
|
|
2694
|
+
_erfapprox(cg, exp);
|
|
2695
|
+
cg.f32.sub();
|
|
2696
|
+
cg.local.get(0);
|
|
2697
|
+
cg.f32.copysign();
|
|
2698
|
+
});
|
|
2699
|
+
}
|
|
2700
|
+
/** Approximate erfc(x) (complementary error function). */
|
|
2701
|
+
function wasm_erfc(cg, exp) {
|
|
2702
|
+
return cg.function([cg.f32], [cg.f32], () => {
|
|
2703
|
+
const e = cg.local.declare(cg.f32);
|
|
2704
|
+
cg.local.get(0);
|
|
2705
|
+
cg.f32.abs();
|
|
2706
|
+
_erfapprox(cg, exp);
|
|
2707
|
+
cg.local.set(e);
|
|
2708
|
+
cg.f32.const(2);
|
|
2709
|
+
cg.local.get(e);
|
|
2710
|
+
cg.f32.sub();
|
|
2711
|
+
cg.local.get(e);
|
|
2712
|
+
cg.local.get(0);
|
|
2713
|
+
cg.f32.const(0);
|
|
2714
|
+
cg.f32.lt();
|
|
2715
|
+
cg.select();
|
|
2716
|
+
});
|
|
2717
|
+
}
|
|
2718
|
+
/**
|
|
2631
2719
|
* Threefry2x32 pseudorandom number generator.
|
|
2632
2720
|
*
|
|
2633
2721
|
* Takes two 32-bit keys and two 32-bit counters as input,
|
|
@@ -3502,14 +3590,16 @@ function codegenWasm(kernel) {
|
|
|
3502
3590
|
if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
3503
3591
|
const cg = new CodeGenerator();
|
|
3504
3592
|
cg.memory.import("env", "memory");
|
|
3505
|
-
const distinctOps =
|
|
3593
|
+
const distinctOps = mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
|
|
3506
3594
|
const funcs = {};
|
|
3507
3595
|
if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
|
|
3508
3596
|
if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
|
|
3509
3597
|
if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
|
|
3510
3598
|
if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
|
|
3511
|
-
if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
|
|
3599
|
+
if (distinctOps.has(AluOp.Exp) || distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) funcs.exp = wasm_exp(cg);
|
|
3512
3600
|
if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
|
|
3601
|
+
if (distinctOps.has(AluOp.Erf)) funcs.erf = wasm_erf(cg, funcs.exp);
|
|
3602
|
+
if (distinctOps.has(AluOp.Erfc)) funcs.erfc = wasm_erfc(cg, funcs.exp);
|
|
3513
3603
|
if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
|
|
3514
3604
|
const kernelFunc = cg.function(rep(kernel.nargs + 1, cg.i32), [], () => {
|
|
3515
3605
|
const gidx = cg.local.declare(cg.i32);
|
|
@@ -3663,6 +3753,8 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3663
3753
|
else if (op === AluOp.Atan) gen(src[0]), cg.call(funcs.atan);
|
|
3664
3754
|
else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
|
|
3665
3755
|
else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
|
|
3756
|
+
else if (op === AluOp.Erf) gen(src[0]), cg.call(funcs.erf);
|
|
3757
|
+
else if (op === AluOp.Erfc) gen(src[0]), cg.call(funcs.erfc);
|
|
3666
3758
|
else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
|
|
3667
3759
|
else if (op === AluOp.Reciprocal) cg.f32.const(1), gen(src[0]), cg.f32.div();
|
|
3668
3760
|
else if (op === AluOp.Cast) {
|
|
@@ -3790,7 +3882,7 @@ async function createBackend(device) {
|
|
|
3790
3882
|
if (!navigator.gpu) return null;
|
|
3791
3883
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
3792
3884
|
if (!adapter) return null;
|
|
3793
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
3885
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BE7zA_01.cjs"));
|
|
3794
3886
|
const importantLimits = [
|
|
3795
3887
|
"maxBufferSize",
|
|
3796
3888
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -3939,6 +4031,12 @@ Object.defineProperty(exports, 'accessorGlobal', {
|
|
|
3939
4031
|
return accessorGlobal;
|
|
3940
4032
|
}
|
|
3941
4033
|
});
|
|
4034
|
+
Object.defineProperty(exports, 'assertNonNull', {
|
|
4035
|
+
enumerable: true,
|
|
4036
|
+
get: function () {
|
|
4037
|
+
return assertNonNull;
|
|
4038
|
+
}
|
|
4039
|
+
});
|
|
3942
4040
|
Object.defineProperty(exports, 'byteWidth', {
|
|
3943
4041
|
enumerable: true,
|
|
3944
4042
|
get: function () {
|
|
@@ -4029,6 +4127,12 @@ Object.defineProperty(exports, 'isPermutation', {
|
|
|
4029
4127
|
return isPermutation;
|
|
4030
4128
|
}
|
|
4031
4129
|
});
|
|
4130
|
+
Object.defineProperty(exports, 'mapSetUnion', {
|
|
4131
|
+
enumerable: true,
|
|
4132
|
+
get: function () {
|
|
4133
|
+
return mapSetUnion;
|
|
4134
|
+
}
|
|
4135
|
+
});
|
|
4032
4136
|
Object.defineProperty(exports, 'normalizeAxis', {
|
|
4033
4137
|
enumerable: true,
|
|
4034
4138
|
get: function () {
|
|
@@ -4101,12 +4205,6 @@ Object.defineProperty(exports, 'tuneWebgpu', {
|
|
|
4101
4205
|
return tuneWebgpu;
|
|
4102
4206
|
}
|
|
4103
4207
|
});
|
|
4104
|
-
Object.defineProperty(exports, 'union', {
|
|
4105
|
-
enumerable: true,
|
|
4106
|
-
get: function () {
|
|
4107
|
-
return union;
|
|
4108
|
-
}
|
|
4109
|
-
});
|
|
4110
4208
|
Object.defineProperty(exports, 'unravelAlu', {
|
|
4111
4209
|
enumerable: true,
|
|
4112
4210
|
get: function () {
|
|
@@ -4130,4 +4228,5 @@ Object.defineProperty(exports, 'zipn', {
|
|
|
4130
4228
|
get: function () {
|
|
4131
4229
|
return zipn;
|
|
4132
4230
|
}
|
|
4133
|
-
});
|
|
4231
|
+
});
|
|
4232
|
+
//# sourceMappingURL=backend-FtkbO6pI.cjs.map
|