@jax-js/jax 0.0.5 → 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +267 -92
- package/dist/{backend-CdcTZEOF.js → backend-DwIAd0AG.js} +205 -112
- package/dist/{backend-yEU0L_ig.cjs → backend-FtkbO6pI.cjs} +217 -118
- package/dist/index.cjs +344 -67
- package/dist/index.d.cts +96 -18
- package/dist/index.d.ts +96 -18
- package/dist/index.js +337 -67
- package/dist/{webgpu-CNOpiO5T.cjs → webgpu-BE7zA_01.cjs} +181 -151
- package/dist/{webgpu-CM-xNYzW.js → webgpu-LGi2A3mS.js} +181 -151
- package/package.json +7 -5
|
@@ -51,6 +51,7 @@ let DEBUG = 0;
|
|
|
51
51
|
function setDebug(level) {
|
|
52
52
|
DEBUG = level;
|
|
53
53
|
}
|
|
54
|
+
/**
 * Runtime no-op: accepts any value, performs no check, and returns `undefined`.
 *
 * NOTE(review): presumably the compiled remnant of a TypeScript assertion
 * signature whose checking happens only at type level — confirm against the
 * original source.
 *
 * @param {*} value - Ignored at runtime.
 * @returns {void}
 */
function assertNonNull(value) {
  // Intentionally empty.
}
|
|
54
55
|
function unzip2(pairs) {
|
|
55
56
|
const lst1 = [];
|
|
56
57
|
const lst2 = [];
|
|
@@ -96,10 +97,15 @@ function deepEqual(a, b) {
|
|
|
96
97
|
for (const key of Object.keys(a)) if (!deepEqual(a[key], b[key])) return false;
|
|
97
98
|
return true;
|
|
98
99
|
}
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
100
|
+
/**
 * Produces a union of maps of sets by merging `b` into `a`.
 *
 * `a` is mutated in place and returned. `b` and the sets it holds are never
 * modified, and no set object is shared between the two maps afterwards.
 *
 * @param {Map<any, Set<any>>} a - Destination map; updated in place.
 * @param {Map<any, Set<any>> | null | undefined} [b] - Source map. When falsy,
 *   `a` is returned unchanged.
 * @returns {Map<any, Set<any>>} The same `a` instance that was passed in.
 */
function mapSetUnion(a, b) {
  if (!b) return a;
  for (const [key, setB] of b.entries()) {
    const setA = a.get(key);
    if (setA) {
      for (const val of setB) setA.add(val);
    } else {
      // Store a copy, not `setB` itself: aliasing would let later additions
      // made through `a` leak back into `b`.
      a.set(key, new Set(setB));
    }
  }
  return a;
}
|
|
104
110
|
/** Splits the list based on a condition, `false` first then `true`. */
|
|
105
111
|
function partitionList(which, array) {
|
|
@@ -429,6 +435,12 @@ var AluExp = class AluExp {
|
|
|
429
435
|
static log(a) {
|
|
430
436
|
return new AluExp(AluOp.Log, a.dtype, [a]);
|
|
431
437
|
}
|
|
438
|
+
static erf(a) {
|
|
439
|
+
return new AluExp(AluOp.Erf, a.dtype, [a]);
|
|
440
|
+
}
|
|
441
|
+
static erfc(a) {
|
|
442
|
+
return new AluExp(AluOp.Erfc, a.dtype, [a]);
|
|
443
|
+
}
|
|
432
444
|
static sqrt(a) {
|
|
433
445
|
return new AluExp(AluOp.Sqrt, a.dtype, [a]);
|
|
434
446
|
}
|
|
@@ -599,6 +611,12 @@ var AluExp = class AluExp {
|
|
|
599
611
|
case AluOp.Log:
|
|
600
612
|
ret = [Math.log(src[0].min), Math.log(src[0].max)];
|
|
601
613
|
break;
|
|
614
|
+
case AluOp.Erf:
|
|
615
|
+
ret = [erf(src[0].min), erf(src[0].max)];
|
|
616
|
+
break;
|
|
617
|
+
case AluOp.Erfc:
|
|
618
|
+
ret = [erfc(src[0].max), erfc(src[0].min)];
|
|
619
|
+
break;
|
|
602
620
|
case AluOp.Sqrt:
|
|
603
621
|
ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
|
|
604
622
|
break;
|
|
@@ -920,6 +938,8 @@ var AluExp = class AluExp {
|
|
|
920
938
|
case AluOp.Atan: return Math.atan(x);
|
|
921
939
|
case AluOp.Exp: return Math.exp(x);
|
|
922
940
|
case AluOp.Log: return Math.log(x);
|
|
941
|
+
case AluOp.Erf: return erf(x);
|
|
942
|
+
case AluOp.Erfc: return erfc(x);
|
|
923
943
|
case AluOp.Sqrt: return Math.sqrt(x);
|
|
924
944
|
case AluOp.Reciprocal: return 1 / x;
|
|
925
945
|
case AluOp.Cast: {
|
|
@@ -1068,11 +1088,15 @@ var AluExp = class AluExp {
|
|
|
1068
1088
|
});
|
|
1069
1089
|
return result;
|
|
1070
1090
|
}
|
|
1071
|
-
/** Produce
|
|
1091
|
+
/** Produce all distinct AluOp in this expression, with their dtypes. */
|
|
1072
1092
|
distinctOps() {
|
|
1073
|
-
const ops = /* @__PURE__ */ new
|
|
1093
|
+
const ops = /* @__PURE__ */ new Map();
|
|
1074
1094
|
this.fold((exp) => {
|
|
1075
|
-
ops.
|
|
1095
|
+
const s = ops.get(exp.op) ?? /* @__PURE__ */ new Set();
|
|
1096
|
+
if (!s.has(exp.dtype)) {
|
|
1097
|
+
s.add(exp.dtype);
|
|
1098
|
+
ops.set(exp.op, s);
|
|
1099
|
+
}
|
|
1076
1100
|
});
|
|
1077
1101
|
return ops;
|
|
1078
1102
|
}
|
|
@@ -1101,6 +1125,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
1101
1125
|
AluOp$1["Atan"] = "Atan";
|
|
1102
1126
|
AluOp$1["Exp"] = "Exp";
|
|
1103
1127
|
AluOp$1["Log"] = "Log";
|
|
1128
|
+
AluOp$1["Erf"] = "Erf";
|
|
1129
|
+
AluOp$1["Erfc"] = "Erfc";
|
|
1104
1130
|
AluOp$1["Sqrt"] = "Sqrt";
|
|
1105
1131
|
AluOp$1["Reciprocal"] = "Reciprocal";
|
|
1106
1132
|
AluOp$1["Cast"] = "Cast";
|
|
@@ -1133,6 +1159,8 @@ const AluGroup = {
|
|
|
1133
1159
|
AluOp.Atan,
|
|
1134
1160
|
AluOp.Exp,
|
|
1135
1161
|
AluOp.Log,
|
|
1162
|
+
AluOp.Erf,
|
|
1163
|
+
AluOp.Erfc,
|
|
1136
1164
|
AluOp.Sqrt,
|
|
1137
1165
|
AluOp.Reciprocal,
|
|
1138
1166
|
AluOp.Cast,
|
|
@@ -1158,6 +1186,8 @@ const AluGroup = {
|
|
|
1158
1186
|
AluOp.Atan,
|
|
1159
1187
|
AluOp.Exp,
|
|
1160
1188
|
AluOp.Log,
|
|
1189
|
+
AluOp.Erf,
|
|
1190
|
+
AluOp.Erfc,
|
|
1161
1191
|
AluOp.Sqrt,
|
|
1162
1192
|
AluOp.Reciprocal
|
|
1163
1193
|
])
|
|
@@ -1333,6 +1363,44 @@ function threefry2x32(k0, k1, c0, c1) {
|
|
|
1333
1363
|
x1 = x1 + ks0 + 5 >>> 0;
|
|
1334
1364
|
return [x0, x1];
|
|
1335
1365
|
}
|
|
1366
|
+
/**
 * Shared kernel of the Abramowitz & Stegun elementary approximation of erf.
 *
 * For `x >= 0` the approximation reads `erf(x) = 1 - P(t) * exp(-x^2)` with
 * `t = 1 / (1 + p*x)` and
 * `P(t) = a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5`, where
 *
 * - p  =  0.3275911
 * - a1 =  0.254829592
 * - a2 = -0.284496736
 * - a3 =  1.421413741
 * - a4 = -1.453152027
 * - a5 =  1.061405429
 *
 * Only the term `E = P(t) * exp(-x^2)` is returned, for numerical reasons;
 * callers combine it with 1 or 2 themselves. The input is assumed to be
 * non-negative.
 *
 * Reference: https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions
 *
 * @param {number} x - Non-negative input.
 * @returns {number} `P(t) * exp(-x^2)`.
 */
function _erfapprox$1(x) {
  const p = .3275911;
  // Highest-order coefficient first (a5 … a1), ready for Horner evaluation.
  const descending = [
    1.061405429,
    -1.453152027,
    1.421413741,
    -.284496736,
    .254829592
  ];
  const t = 1 / (1 + p * x);
  let horner = descending[0];
  for (let i = 1; i < descending.length; i++) horner = horner * t + descending[i];
  // P(t) has no constant term, hence the trailing multiplication by t.
  const P_t = horner * t;
  return P_t * Math.exp(-x * x);
}
|
|
1396
|
+
/**
 * Approximates the error function erf(x).
 *
 * Non-negative inputs use `1 - E(x)`; every other input (negatives and NaN)
 * uses `E(-x) - 1`, where `E` is `_erfapprox$1`.
 *
 * @param {number} x - Input value.
 * @returns {number} Approximate erf(x).
 */
function erf(x) {
  return x >= 0 ? 1 - _erfapprox$1(x) : _erfapprox$1(-x) - 1;
}
|
|
1400
|
+
/**
 * Approximates the complementary error function erfc(x).
 *
 * Non-negative inputs return `E(x)` directly; every other input (negatives
 * and NaN) returns `2 - E(-x)`, where `E` is `_erfapprox$1`.
 *
 * @param {number} x - Input value.
 * @returns {number} Approximate erfc(x).
 */
function erfc(x) {
  return x >= 0 ? _erfapprox$1(x) : 2 - _erfapprox$1(-x);
}
|
|
1336
1404
|
|
|
1337
1405
|
//#endregion
|
|
1338
1406
|
//#region src/shape.ts
|
|
@@ -2018,7 +2086,7 @@ function tuneWebgpu(kernel) {
|
|
|
2018
2086
|
}
|
|
2019
2087
|
if (/chrome/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
|
|
2020
2088
|
const s = dim.st.shape[dim.unroll - 1];
|
|
2021
|
-
if (s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
2089
|
+
if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
2022
2090
|
else for (const splits of [4]) if (s % splits === 0) {
|
|
2023
2091
|
dim.applyUnroll(dim.unroll - 1, splits);
|
|
2024
2092
|
break;
|
|
@@ -2237,6 +2305,19 @@ var WasmAllocator = class {
|
|
|
2237
2305
|
|
|
2238
2306
|
//#endregion
|
|
2239
2307
|
//#region src/backend/wasm/builtins.ts
|
|
2308
|
+
/**
 * Given a local `x`, emit code that evaluates `sum[i](as[i] * x^i)` and leaves
 * the result on the stack.
 *
 * Uses Horner's scheme: the highest coefficient is pushed first, then each
 * lower coefficient is folded in with one multiply by `x` and one add. The
 * add is skipped when the coefficient is exactly zero.
 *
 * @param cg - Code generator exposing `f32.const`, `f32.mul`, `f32.add` and `local.get`.
 * @param x - Handle of the local holding the polynomial argument.
 * @param {number[]} as - Coefficients in ascending order of power (`as[0]` is the constant term).
 * @throws {Error} If `as` is empty.
 */
function _poly(cg, x, as) {
  const degree = as.length - 1;
  if (degree < 0) throw new Error("_poly needs at least one coefficient");
  cg.f32.const(as[degree]);
  // Remaining coefficients, from the second-highest power down to the constant term.
  const lowerTerms = as.slice(0, degree).reverse();
  for (const coeff of lowerTerms) {
    cg.local.get(x);
    cg.f32.mul();
    if (coeff === 0) continue;
    cg.f32.const(coeff);
    cg.f32.add();
  }
}
|
|
2240
2321
|
/**
|
|
2241
2322
|
* Approximate e^x.
|
|
2242
2323
|
*
|
|
@@ -2277,27 +2358,15 @@ function wasm_exp(cg) {
|
|
|
2277
2358
|
cg.f32.mul();
|
|
2278
2359
|
cg.f32.sub();
|
|
2279
2360
|
cg.local.set(r);
|
|
2280
|
-
cg
|
|
2281
|
-
|
|
2282
|
-
|
|
2283
|
-
|
|
2284
|
-
|
|
2285
|
-
|
|
2286
|
-
|
|
2287
|
-
|
|
2288
|
-
|
|
2289
|
-
cg.local.get(r);
|
|
2290
|
-
cg.f32.mul();
|
|
2291
|
-
cg.f32.const(1 / 2);
|
|
2292
|
-
cg.f32.add();
|
|
2293
|
-
cg.local.get(r);
|
|
2294
|
-
cg.f32.mul();
|
|
2295
|
-
cg.f32.const(1);
|
|
2296
|
-
cg.f32.add();
|
|
2297
|
-
cg.local.get(r);
|
|
2298
|
-
cg.f32.mul();
|
|
2299
|
-
cg.f32.const(1);
|
|
2300
|
-
cg.f32.add();
|
|
2361
|
+
_poly(cg, r, [
|
|
2362
|
+
1,
|
|
2363
|
+
1,
|
|
2364
|
+
1 / 2,
|
|
2365
|
+
1 / 6,
|
|
2366
|
+
1 / 24,
|
|
2367
|
+
1 / 120,
|
|
2368
|
+
1 / 720
|
|
2369
|
+
]);
|
|
2301
2370
|
cg.local.set(p);
|
|
2302
2371
|
cg.local.get(k);
|
|
2303
2372
|
cg.i32.const(127);
|
|
@@ -2325,11 +2394,6 @@ function wasm_log(cg) {
|
|
|
2325
2394
|
const m = cg.local.declare(cg.f32);
|
|
2326
2395
|
const t = cg.local.declare(cg.f32);
|
|
2327
2396
|
const t2 = cg.local.declare(cg.f32);
|
|
2328
|
-
const t3 = cg.local.declare(cg.f32);
|
|
2329
|
-
const t5 = cg.local.declare(cg.f32);
|
|
2330
|
-
const t7 = cg.local.declare(cg.f32);
|
|
2331
|
-
const lnm = cg.local.declare(cg.f32);
|
|
2332
|
-
const el2 = cg.local.declare(cg.f32);
|
|
2333
2397
|
cg.local.get(0);
|
|
2334
2398
|
cg.f32.const(0);
|
|
2335
2399
|
cg.f32.le();
|
|
@@ -2366,41 +2430,18 @@ function wasm_log(cg) {
|
|
|
2366
2430
|
cg.local.get(t);
|
|
2367
2431
|
cg.f32.mul();
|
|
2368
2432
|
cg.local.set(t2);
|
|
2433
|
+
_poly(cg, t2, [
|
|
2434
|
+
2,
|
|
2435
|
+
2 / 3,
|
|
2436
|
+
2 / 5,
|
|
2437
|
+
2 / 7
|
|
2438
|
+
]);
|
|
2369
2439
|
cg.local.get(t);
|
|
2370
|
-
cg.local.get(t2);
|
|
2371
|
-
cg.f32.mul();
|
|
2372
|
-
cg.local.set(t3);
|
|
2373
|
-
cg.local.get(t3);
|
|
2374
|
-
cg.local.get(t2);
|
|
2375
|
-
cg.f32.mul();
|
|
2376
|
-
cg.local.set(t5);
|
|
2377
|
-
cg.local.get(t5);
|
|
2378
|
-
cg.local.get(t2);
|
|
2379
|
-
cg.f32.mul();
|
|
2380
|
-
cg.local.set(t7);
|
|
2381
|
-
cg.local.get(t7);
|
|
2382
|
-
cg.f32.const(1 / 7);
|
|
2383
|
-
cg.f32.mul();
|
|
2384
|
-
cg.local.get(t5);
|
|
2385
|
-
cg.f32.const(1 / 5);
|
|
2386
|
-
cg.f32.mul();
|
|
2387
|
-
cg.f32.add();
|
|
2388
|
-
cg.local.get(t3);
|
|
2389
|
-
cg.f32.const(1 / 3);
|
|
2390
|
-
cg.f32.mul();
|
|
2391
|
-
cg.f32.add();
|
|
2392
|
-
cg.local.get(t);
|
|
2393
|
-
cg.f32.add();
|
|
2394
|
-
cg.f32.const(2);
|
|
2395
2440
|
cg.f32.mul();
|
|
2396
|
-
cg.local.set(lnm);
|
|
2397
2441
|
cg.local.get(e);
|
|
2398
2442
|
cg.f32.convert_i32_s();
|
|
2399
2443
|
cg.f32.const(Math.LN2);
|
|
2400
2444
|
cg.f32.mul();
|
|
2401
|
-
cg.local.set(el2);
|
|
2402
|
-
cg.local.get(el2);
|
|
2403
|
-
cg.local.get(lnm);
|
|
2404
2445
|
cg.f32.add();
|
|
2405
2446
|
});
|
|
2406
2447
|
}
|
|
@@ -2410,7 +2451,7 @@ function wasm_log(cg) {
|
|
|
2410
2451
|
* Method: reduce to y in [-π, π], then quadrant via q = round(y/(π/2))
|
|
2411
2452
|
* z = y - q*(π/2); use one of two polynomials on z:
|
|
2412
2453
|
* sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
|
|
2413
|
-
* cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
|
|
2454
|
+
* cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720) + z^8*(1/40320)
|
|
2414
2455
|
*/
|
|
2415
2456
|
function _sincos(cg) {
|
|
2416
2457
|
const y = cg.local.declare(cg.f32);
|
|
@@ -2446,35 +2487,22 @@ function _sincos(cg) {
|
|
|
2446
2487
|
cg.local.get(z);
|
|
2447
2488
|
cg.f32.mul();
|
|
2448
2489
|
cg.local.set(z2);
|
|
2449
|
-
cg
|
|
2450
|
-
|
|
2451
|
-
|
|
2452
|
-
|
|
2453
|
-
|
|
2454
|
-
|
|
2455
|
-
cg.f32.mul();
|
|
2456
|
-
cg.f32.const(-1 / 6);
|
|
2457
|
-
cg.f32.add();
|
|
2458
|
-
cg.local.get(z2);
|
|
2459
|
-
cg.f32.mul();
|
|
2460
|
-
cg.f32.const(1);
|
|
2461
|
-
cg.f32.add();
|
|
2490
|
+
_poly(cg, z2, [
|
|
2491
|
+
1,
|
|
2492
|
+
-1 / 6,
|
|
2493
|
+
1 / 120,
|
|
2494
|
+
-1 / 5040
|
|
2495
|
+
]);
|
|
2462
2496
|
cg.local.get(z);
|
|
2463
2497
|
cg.f32.mul();
|
|
2464
2498
|
cg.local.set(sz);
|
|
2465
|
-
cg
|
|
2466
|
-
|
|
2467
|
-
|
|
2468
|
-
|
|
2469
|
-
|
|
2470
|
-
|
|
2471
|
-
|
|
2472
|
-
cg.f32.const(-1 / 2);
|
|
2473
|
-
cg.f32.add();
|
|
2474
|
-
cg.local.get(z2);
|
|
2475
|
-
cg.f32.mul();
|
|
2476
|
-
cg.f32.const(1);
|
|
2477
|
-
cg.f32.add();
|
|
2499
|
+
_poly(cg, z2, [
|
|
2500
|
+
1,
|
|
2501
|
+
-1 / 2,
|
|
2502
|
+
1 / 24,
|
|
2503
|
+
-1 / 720,
|
|
2504
|
+
1 / 40320
|
|
2505
|
+
]);
|
|
2478
2506
|
cg.local.set(cz);
|
|
2479
2507
|
return {
|
|
2480
2508
|
q,
|
|
@@ -2556,24 +2584,16 @@ function _atan(cg) {
|
|
|
2556
2584
|
cg.local.get(z);
|
|
2557
2585
|
cg.f32.mul();
|
|
2558
2586
|
cg.local.set(z2);
|
|
2559
|
-
cg
|
|
2560
|
-
|
|
2561
|
-
|
|
2562
|
-
|
|
2563
|
-
|
|
2564
|
-
cg
|
|
2565
|
-
|
|
2566
|
-
|
|
2567
|
-
|
|
2568
|
-
|
|
2569
|
-
cg.local.get(z2);
|
|
2570
|
-
cg.f32.mul();
|
|
2571
|
-
cg.f32.const(.994987933645);
|
|
2572
|
-
cg.f32.add();
|
|
2573
|
-
cg.local.get(z2);
|
|
2574
|
-
cg.f32.mul();
|
|
2575
|
-
cg.f32.const(1);
|
|
2576
|
-
cg.f32.add();
|
|
2587
|
+
_poly(cg, z2, [
|
|
2588
|
+
.999998614341,
|
|
2589
|
+
.661705427875,
|
|
2590
|
+
.0415796528637
|
|
2591
|
+
]);
|
|
2592
|
+
_poly(cg, z2, [
|
|
2593
|
+
1,
|
|
2594
|
+
.994987933645,
|
|
2595
|
+
.173698870181
|
|
2596
|
+
]);
|
|
2577
2597
|
cg.f32.div();
|
|
2578
2598
|
cg.local.get(z);
|
|
2579
2599
|
cg.f32.mul();
|
|
@@ -2627,6 +2647,74 @@ function wasm_asin(cg) {
|
|
|
2627
2647
|
});
|
|
2628
2648
|
}
|
|
2629
2649
|
/**
 * Helper function for erf/erfc approximation.
 *
 * Emits code computing `E = P(t) * exp(-x^2)` with `t = 1 / (1 + p*x)`.
 * The first emitted instruction is a `local.set` into a fresh f32 local, so
 * the argument `x` is expected to already be on the stack when this helper is
 * invoked; the instructions emitted afterwards leave `E` as the final value.
 *
 * See `_erfapprox` in alu.ts for details on the algorithm used.
 *
 * @param cg - Code generator receiving the instructions.
 * @param exp_func - Function handle passed to `cg.call` to compute `exp`.
 */
function _erfapprox(cg, exp_func) {
  const xLocal = cg.local.declare(cg.f32);
  const tLocal = cg.local.declare(cg.f32);
  cg.local.set(xLocal);
  const p = .3275911;
  // Ascending powers of t; the constant term is 0, so P(t) starts at a1*t.
  const coefficients = [
    0,
    .254829592,
    -.284496736,
    1.421413741,
    -1.453152027,
    1.061405429
  ];
  // t = 1 / (1 + p * x)
  cg.f32.const(1);
  cg.f32.const(1);
  cg.f32.const(p);
  cg.local.get(xLocal);
  cg.f32.mul();
  cg.f32.add();
  cg.f32.div();
  cg.local.set(tLocal);
  // P(t)
  _poly(cg, tLocal, coefficients);
  // exp(-x * x)
  cg.local.get(xLocal);
  cg.f32.neg();
  cg.local.get(xLocal);
  cg.f32.mul();
  cg.call(exp_func);
  // P(t) * exp(-x^2)
  cg.f32.mul();
}
|
|
2687
|
+
/**
 * Approximate erf(x) (error function).
 *
 * Emits an f32 -> f32 function that computes `1 - E(|x|)` via `_erfapprox`
 * and then applies `copysign` with the original argument as sign source.
 *
 * @param cg - Code generator the function is added to.
 * @param exp - Handle of the exp function, forwarded to `_erfapprox`.
 * @returns Whatever `cg.function` returns for the newly defined function.
 */
function wasm_erf(cg, exp) {
  const emitBody = () => {
    // 1 - E(|x|)
    cg.f32.const(1);
    cg.local.get(0);
    cg.f32.abs();
    _erfapprox(cg, exp);
    cg.f32.sub();
    // Take the sign from the original argument.
    cg.local.get(0);
    cg.f32.copysign();
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
|
|
2699
|
+
/**
 * Approximate erfc(x) (complementary error function).
 *
 * Emits an f32 -> f32 function that computes `E(|x|)` via `_erfapprox`, then
 * pushes `2 - E`, `E` and the condition `x < 0` before a `select`.
 *
 * @param cg - Code generator the function is added to.
 * @param exp - Handle of the exp function, forwarded to `_erfapprox`.
 * @returns Whatever `cg.function` returns for the newly defined function.
 */
function wasm_erfc(cg, exp) {
  const emitBody = () => {
    const tail = cg.local.declare(cg.f32);
    // tail = E(|x|)
    cg.local.get(0);
    cg.f32.abs();
    _erfapprox(cg, exp);
    cg.local.set(tail);
    // First select operand: 2 - tail
    cg.f32.const(2);
    cg.local.get(tail);
    cg.f32.sub();
    // Second select operand: tail
    cg.local.get(tail);
    // Condition: x < 0
    cg.local.get(0);
    cg.f32.const(0);
    cg.f32.lt();
    cg.select();
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
|
|
2717
|
+
/**
|
|
2630
2718
|
* Threefry2x32 pseudorandom number generator.
|
|
2631
2719
|
*
|
|
2632
2720
|
* Takes two 32-bit keys and two 32-bit counters as input,
|
|
@@ -3501,14 +3589,16 @@ function codegenWasm(kernel) {
|
|
|
3501
3589
|
if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
3502
3590
|
const cg = new CodeGenerator();
|
|
3503
3591
|
cg.memory.import("env", "memory");
|
|
3504
|
-
const distinctOps =
|
|
3592
|
+
const distinctOps = mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
|
|
3505
3593
|
const funcs = {};
|
|
3506
3594
|
if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
|
|
3507
3595
|
if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
|
|
3508
3596
|
if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
|
|
3509
3597
|
if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
|
|
3510
|
-
if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
|
|
3598
|
+
if (distinctOps.has(AluOp.Exp) || distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) funcs.exp = wasm_exp(cg);
|
|
3511
3599
|
if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
|
|
3600
|
+
if (distinctOps.has(AluOp.Erf)) funcs.erf = wasm_erf(cg, funcs.exp);
|
|
3601
|
+
if (distinctOps.has(AluOp.Erfc)) funcs.erfc = wasm_erfc(cg, funcs.exp);
|
|
3512
3602
|
if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
|
|
3513
3603
|
const kernelFunc = cg.function(rep(kernel.nargs + 1, cg.i32), [], () => {
|
|
3514
3604
|
const gidx = cg.local.declare(cg.i32);
|
|
@@ -3662,6 +3752,8 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3662
3752
|
else if (op === AluOp.Atan) gen(src[0]), cg.call(funcs.atan);
|
|
3663
3753
|
else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
|
|
3664
3754
|
else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
|
|
3755
|
+
else if (op === AluOp.Erf) gen(src[0]), cg.call(funcs.erf);
|
|
3756
|
+
else if (op === AluOp.Erfc) gen(src[0]), cg.call(funcs.erfc);
|
|
3665
3757
|
else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
|
|
3666
3758
|
else if (op === AluOp.Reciprocal) cg.f32.const(1), gen(src[0]), cg.f32.div();
|
|
3667
3759
|
else if (op === AluOp.Cast) {
|
|
@@ -3789,7 +3881,7 @@ async function createBackend(device) {
|
|
|
3789
3881
|
if (!navigator.gpu) return null;
|
|
3790
3882
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
3791
3883
|
if (!adapter) return null;
|
|
3792
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
3884
|
+
const { WebGPUBackend } = await import("./webgpu-LGi2A3mS.js");
|
|
3793
3885
|
const importantLimits = [
|
|
3794
3886
|
"maxBufferSize",
|
|
3795
3887
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -3842,4 +3934,5 @@ var UnsupportedOpError = class extends Error {
|
|
|
3842
3934
|
};
|
|
3843
3935
|
|
|
3844
3936
|
//#endregion
|
|
3845
|
-
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu,
|
|
3937
|
+
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|
|
3938
|
+
//# sourceMappingURL=backend-DwIAd0AG.js.map
|