@jax-js/jax 0.0.4 → 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +296 -78
- package/dist/{backend-EBRGmEYw.js → backend-DwIAd0AG.js} +238 -116
- package/dist/{backend-Ss1Mev_-.cjs → backend-FtkbO6pI.cjs} +256 -122
- package/dist/index.cjs +653 -277
- package/dist/index.d.cts +167 -44
- package/dist/index.d.ts +167 -44
- package/dist/index.js +637 -268
- package/dist/{webgpu-BVdMaO9T.cjs → webgpu-BE7zA_01.cjs} +181 -151
- package/dist/{webgpu-ow0Pn_6q.js → webgpu-LGi2A3mS.js} +181 -151
- package/package.json +7 -5
|
@@ -52,6 +52,7 @@ let DEBUG = 0;
|
|
|
52
52
|
function setDebug(level) {
|
|
53
53
|
DEBUG = level;
|
|
54
54
|
}
|
|
55
|
+
/**
 * Runtime no-op: the body is empty, so `value` is never inspected and nothing
 * is thrown, even when `value` is `null` or `undefined`.
 *
 * NOTE(review): presumably the compiled form of a TypeScript `asserts value`
 * helper that exists only for type narrowing — confirm against the TS source.
 */
function assertNonNull(value) {}
|
|
55
56
|
function unzip2(pairs) {
|
|
56
57
|
const lst1 = [];
|
|
57
58
|
const lst2 = [];
|
|
@@ -97,10 +98,15 @@ function deepEqual(a, b) {
|
|
|
97
98
|
for (const key of Object.keys(a)) if (!deepEqual(a[key], b[key])) return false;
|
|
98
99
|
return true;
|
|
99
100
|
}
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
101
|
+
/**
 * Produces a union of maps of sets. This mutates `a` and returns it.
 *
 * For every key in `b`, the values of `b`'s set are added to the set stored
 * under the same key in `a`. Keys missing from `a` receive a fresh copy of
 * `b`'s set, so `a` and `b` never share a Set instance and `b` is left
 * untouched by this call and by any later mutation of `a`.
 *
 * @param a Map from key to Set; mutated in place.
 * @param b Map from key to Set to merge in; may be null/undefined (no-op).
 * @returns The same `a` instance, for chaining.
 */
function mapSetUnion(a, b) {
  if (!b) return a;
  for (const [key, setB] of b.entries()) {
    const setA = a.get(key);
    if (setA) {
      for (const val of setB) setA.add(val);
    } else {
      // Copy instead of aliasing: storing `setB` itself would let a later
      // union into `a` silently add elements to `b`'s set as well.
      a.set(key, new Set(setB));
    }
  }
  return a;
}
|
|
105
111
|
/** Splits the list based on a condition, `false` first then `true`. */
|
|
106
112
|
function partitionList(which, array) {
|
|
@@ -207,6 +213,35 @@ function findPow2(hint, max) {
|
|
|
207
213
|
while (ret < hint && 2 * ret <= max) ret *= 2;
|
|
208
214
|
return ret;
|
|
209
215
|
}
|
|
216
|
+
/**
 * Computes the NumPy-style broadcast of two array shapes.
 *
 * Shapes are aligned at their trailing (rightmost) dimension and compared
 * pairwise moving left. A pair of dimensions is compatible when the two
 * values are equal or when one of them is 1; the other value is then used.
 * Leading dimensions that exist in only one shape are copied through.
 *
 * Throws a TypeError if the broadcast is not possible.
 *
 * <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
 */
function generalBroadcast(a, b) {
  const rank = Math.max(a.length, b.length);
  const shape = new Array(rank);
  // `k` counts dimensions from the right: k = 1 is the trailing dimension.
  for (let k = 1; k <= rank; k++) {
    const ai = a.length - k;
    const bi = b.length - k;
    let dim;
    if (ai < 0) {
      dim = b[bi];
    } else if (bi < 0) {
      dim = a[ai];
    } else {
      const x = a[ai];
      const y = b[bi];
      if (x === y || y === 1) {
        dim = x;
      } else if (x === 1) {
        dim = y;
      } else {
        throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
      }
    }
    shape[rank - k] = dim;
  }
  return shape;
}
|
|
210
245
|
function recursiveFlatten(ar) {
|
|
211
246
|
if (!Array.isArray(ar)) return [ar];
|
|
212
247
|
return ar.flat(Infinity);
|
|
@@ -295,12 +330,12 @@ const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float
|
|
|
295
330
|
* **Type lattice:**
|
|
296
331
|
* ```text
|
|
297
332
|
* bool -> uint32 -> int32 -> float16 -> float32
|
|
298
|
-
*
|
|
333
|
+
* weakType --^
|
|
299
334
|
* ```
|
|
300
335
|
*
|
|
301
|
-
*
|
|
302
|
-
*
|
|
303
|
-
*
|
|
336
|
+
* `weakType` represents weakly typed arrays. These are created for JS numbers,
|
|
337
|
+
* which default to float32 but are "weak", so they cast to the dtype of any array
|
|
338
|
+
* they are first combined with, except `bool`.
|
|
304
339
|
*
|
|
305
340
|
* **Examples:**
|
|
306
341
|
* - `promoteTypes(bool, int32) → int32`
|
|
@@ -401,6 +436,12 @@ var AluExp = class AluExp {
|
|
|
401
436
|
static log(a) {
|
|
402
437
|
return new AluExp(AluOp.Log, a.dtype, [a]);
|
|
403
438
|
}
|
|
439
|
+
static erf(a) {
|
|
440
|
+
return new AluExp(AluOp.Erf, a.dtype, [a]);
|
|
441
|
+
}
|
|
442
|
+
static erfc(a) {
|
|
443
|
+
return new AluExp(AluOp.Erfc, a.dtype, [a]);
|
|
444
|
+
}
|
|
404
445
|
static sqrt(a) {
|
|
405
446
|
return new AluExp(AluOp.Sqrt, a.dtype, [a]);
|
|
406
447
|
}
|
|
@@ -571,6 +612,12 @@ var AluExp = class AluExp {
|
|
|
571
612
|
case AluOp.Log:
|
|
572
613
|
ret = [Math.log(src[0].min), Math.log(src[0].max)];
|
|
573
614
|
break;
|
|
615
|
+
case AluOp.Erf:
|
|
616
|
+
ret = [erf(src[0].min), erf(src[0].max)];
|
|
617
|
+
break;
|
|
618
|
+
case AluOp.Erfc:
|
|
619
|
+
ret = [erfc(src[0].max), erfc(src[0].min)];
|
|
620
|
+
break;
|
|
574
621
|
case AluOp.Sqrt:
|
|
575
622
|
ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
|
|
576
623
|
break;
|
|
@@ -892,6 +939,8 @@ var AluExp = class AluExp {
|
|
|
892
939
|
case AluOp.Atan: return Math.atan(x);
|
|
893
940
|
case AluOp.Exp: return Math.exp(x);
|
|
894
941
|
case AluOp.Log: return Math.log(x);
|
|
942
|
+
case AluOp.Erf: return erf(x);
|
|
943
|
+
case AluOp.Erfc: return erfc(x);
|
|
895
944
|
case AluOp.Sqrt: return Math.sqrt(x);
|
|
896
945
|
case AluOp.Reciprocal: return 1 / x;
|
|
897
946
|
case AluOp.Cast: {
|
|
@@ -1040,11 +1089,15 @@ var AluExp = class AluExp {
|
|
|
1040
1089
|
});
|
|
1041
1090
|
return result;
|
|
1042
1091
|
}
|
|
1043
|
-
/** Produce
|
|
1092
|
+
/** Produce all distinct AluOp in this expression, with their dtypes. */
|
|
1044
1093
|
distinctOps() {
|
|
1045
|
-
const ops = /* @__PURE__ */ new
|
|
1094
|
+
const ops = /* @__PURE__ */ new Map();
|
|
1046
1095
|
this.fold((exp) => {
|
|
1047
|
-
ops.
|
|
1096
|
+
const s = ops.get(exp.op) ?? /* @__PURE__ */ new Set();
|
|
1097
|
+
if (!s.has(exp.dtype)) {
|
|
1098
|
+
s.add(exp.dtype);
|
|
1099
|
+
ops.set(exp.op, s);
|
|
1100
|
+
}
|
|
1048
1101
|
});
|
|
1049
1102
|
return ops;
|
|
1050
1103
|
}
|
|
@@ -1073,6 +1126,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
1073
1126
|
AluOp$1["Atan"] = "Atan";
|
|
1074
1127
|
AluOp$1["Exp"] = "Exp";
|
|
1075
1128
|
AluOp$1["Log"] = "Log";
|
|
1129
|
+
AluOp$1["Erf"] = "Erf";
|
|
1130
|
+
AluOp$1["Erfc"] = "Erfc";
|
|
1076
1131
|
AluOp$1["Sqrt"] = "Sqrt";
|
|
1077
1132
|
AluOp$1["Reciprocal"] = "Reciprocal";
|
|
1078
1133
|
AluOp$1["Cast"] = "Cast";
|
|
@@ -1105,6 +1160,8 @@ const AluGroup = {
|
|
|
1105
1160
|
AluOp.Atan,
|
|
1106
1161
|
AluOp.Exp,
|
|
1107
1162
|
AluOp.Log,
|
|
1163
|
+
AluOp.Erf,
|
|
1164
|
+
AluOp.Erfc,
|
|
1108
1165
|
AluOp.Sqrt,
|
|
1109
1166
|
AluOp.Reciprocal,
|
|
1110
1167
|
AluOp.Cast,
|
|
@@ -1130,6 +1187,8 @@ const AluGroup = {
|
|
|
1130
1187
|
AluOp.Atan,
|
|
1131
1188
|
AluOp.Exp,
|
|
1132
1189
|
AluOp.Log,
|
|
1190
|
+
AluOp.Erf,
|
|
1191
|
+
AluOp.Erfc,
|
|
1133
1192
|
AluOp.Sqrt,
|
|
1134
1193
|
AluOp.Reciprocal
|
|
1135
1194
|
])
|
|
@@ -1305,6 +1364,44 @@ function threefry2x32(k0, k1, c0, c1) {
|
|
|
1305
1364
|
x1 = x1 + ks0 + 5 >>> 0;
|
|
1306
1365
|
return [x0, x1];
|
|
1307
1366
|
}
|
|
1367
|
+
/**
 * Abramowitz & Stegun's widely used approximation for erf(x).
 *
 * `erf(x) = 1 - P(t) * exp(-x^2)` for `x >= 0`, where `t = 1/(1 + p*x)` and
 * `P(t) = a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5`, with
 * p = 0.3275911, a1 = 0.254829592, a2 = -0.284496736, a3 = 1.421413741,
 * a4 = -1.453152027, a5 = 1.061405429.
 *
 * Only the complement term `E = P(t) * exp(-x^2)` is returned, for numerical
 * reasons; callers combine it into erf/erfc. The input is assumed to be
 * non-negative.
 *
 * Reference: https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions
 */
function _erfapprox$1(x) {
  const p = 0.3275911;
  // Coefficients ordered from highest degree (a5) to lowest (a1) for Horner.
  const descending = [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592];
  const t = 1 / (1 + p * x);
  let horner = descending[0];
  for (let k = 1; k < descending.length; k++) {
    horner = horner * t + descending[k];
  }
  // The final multiply by `t` accounts for P(t) having no constant term.
  const polynomial = horner * t;
  return polynomial * Math.exp(-x * x);
}
|
|
1397
|
+
/**
 * Approximate error function, built on `_erfapprox$1`.
 *
 * Non-negative inputs use `1 - E(x)`; everything else (negative inputs and
 * NaN) uses `E(-x) - 1`, which mirrors the positive branch.
 */
function erf(x) {
  return x >= 0 ? 1 - _erfapprox$1(x) : _erfapprox$1(-x) - 1;
}
|
|
1401
|
+
/**
 * Approximate complementary error function, built on `_erfapprox$1`.
 *
 * Non-negative inputs return `E(x)` directly; everything else (negative
 * inputs and NaN) returns `2 - E(-x)`.
 */
function erfc(x) {
  return x >= 0 ? _erfapprox$1(x) : 2 - _erfapprox$1(-x);
}
|
|
1308
1405
|
|
|
1309
1406
|
//#endregion
|
|
1310
1407
|
//#region src/shape.ts
|
|
@@ -1990,7 +2087,7 @@ function tuneWebgpu(kernel) {
|
|
|
1990
2087
|
}
|
|
1991
2088
|
if (/chrome/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
|
|
1992
2089
|
const s = dim.st.shape[dim.unroll - 1];
|
|
1993
|
-
if (s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
2090
|
+
if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
1994
2091
|
else for (const splits of [4]) if (s % splits === 0) {
|
|
1995
2092
|
dim.applyUnroll(dim.unroll - 1, splits);
|
|
1996
2093
|
break;
|
|
@@ -2209,6 +2306,19 @@ var WasmAllocator = class {
|
|
|
2209
2306
|
|
|
2210
2307
|
//#endregion
|
|
2211
2308
|
//#region src/backend/wasm/builtins.ts
|
|
2309
|
+
/**
 * Given a local `x`, emit code that evaluates `sum[i](a_i * x^i)` and leaves
 * the result pushed on the stack.
 *
 * Uses Horner's scheme: the highest-degree coefficient is pushed first, then
 * for each lower coefficient the running value is multiplied by `x` and the
 * coefficient is added. Coefficients equal to 0 skip the const/add pair.
 *
 * Throws an Error when `as` is empty.
 */
function _poly(cg, x, as) {
  const top = as.length - 1;
  if (top < 0) throw new Error("_poly needs at least one coefficient");
  cg.f32.const(as[top]);
  let k = top;
  while (k-- > 0) {
    cg.local.get(x);
    cg.f32.mul();
    const coeff = as[k];
    if (coeff === 0) continue;
    cg.f32.const(coeff);
    cg.f32.add();
  }
}
|
|
2212
2322
|
/**
|
|
2213
2323
|
* Approximate e^x.
|
|
2214
2324
|
*
|
|
@@ -2249,27 +2359,15 @@ function wasm_exp(cg) {
|
|
|
2249
2359
|
cg.f32.mul();
|
|
2250
2360
|
cg.f32.sub();
|
|
2251
2361
|
cg.local.set(r);
|
|
2252
|
-
cg
|
|
2253
|
-
|
|
2254
|
-
|
|
2255
|
-
|
|
2256
|
-
|
|
2257
|
-
|
|
2258
|
-
|
|
2259
|
-
|
|
2260
|
-
|
|
2261
|
-
cg.local.get(r);
|
|
2262
|
-
cg.f32.mul();
|
|
2263
|
-
cg.f32.const(1 / 2);
|
|
2264
|
-
cg.f32.add();
|
|
2265
|
-
cg.local.get(r);
|
|
2266
|
-
cg.f32.mul();
|
|
2267
|
-
cg.f32.const(1);
|
|
2268
|
-
cg.f32.add();
|
|
2269
|
-
cg.local.get(r);
|
|
2270
|
-
cg.f32.mul();
|
|
2271
|
-
cg.f32.const(1);
|
|
2272
|
-
cg.f32.add();
|
|
2362
|
+
_poly(cg, r, [
|
|
2363
|
+
1,
|
|
2364
|
+
1,
|
|
2365
|
+
1 / 2,
|
|
2366
|
+
1 / 6,
|
|
2367
|
+
1 / 24,
|
|
2368
|
+
1 / 120,
|
|
2369
|
+
1 / 720
|
|
2370
|
+
]);
|
|
2273
2371
|
cg.local.set(p);
|
|
2274
2372
|
cg.local.get(k);
|
|
2275
2373
|
cg.i32.const(127);
|
|
@@ -2297,11 +2395,6 @@ function wasm_log(cg) {
|
|
|
2297
2395
|
const m = cg.local.declare(cg.f32);
|
|
2298
2396
|
const t = cg.local.declare(cg.f32);
|
|
2299
2397
|
const t2 = cg.local.declare(cg.f32);
|
|
2300
|
-
const t3 = cg.local.declare(cg.f32);
|
|
2301
|
-
const t5 = cg.local.declare(cg.f32);
|
|
2302
|
-
const t7 = cg.local.declare(cg.f32);
|
|
2303
|
-
const lnm = cg.local.declare(cg.f32);
|
|
2304
|
-
const el2 = cg.local.declare(cg.f32);
|
|
2305
2398
|
cg.local.get(0);
|
|
2306
2399
|
cg.f32.const(0);
|
|
2307
2400
|
cg.f32.le();
|
|
@@ -2338,41 +2431,18 @@ function wasm_log(cg) {
|
|
|
2338
2431
|
cg.local.get(t);
|
|
2339
2432
|
cg.f32.mul();
|
|
2340
2433
|
cg.local.set(t2);
|
|
2434
|
+
_poly(cg, t2, [
|
|
2435
|
+
2,
|
|
2436
|
+
2 / 3,
|
|
2437
|
+
2 / 5,
|
|
2438
|
+
2 / 7
|
|
2439
|
+
]);
|
|
2341
2440
|
cg.local.get(t);
|
|
2342
|
-
cg.local.get(t2);
|
|
2343
|
-
cg.f32.mul();
|
|
2344
|
-
cg.local.set(t3);
|
|
2345
|
-
cg.local.get(t3);
|
|
2346
|
-
cg.local.get(t2);
|
|
2347
|
-
cg.f32.mul();
|
|
2348
|
-
cg.local.set(t5);
|
|
2349
|
-
cg.local.get(t5);
|
|
2350
|
-
cg.local.get(t2);
|
|
2351
|
-
cg.f32.mul();
|
|
2352
|
-
cg.local.set(t7);
|
|
2353
|
-
cg.local.get(t7);
|
|
2354
|
-
cg.f32.const(1 / 7);
|
|
2355
|
-
cg.f32.mul();
|
|
2356
|
-
cg.local.get(t5);
|
|
2357
|
-
cg.f32.const(1 / 5);
|
|
2358
|
-
cg.f32.mul();
|
|
2359
|
-
cg.f32.add();
|
|
2360
|
-
cg.local.get(t3);
|
|
2361
|
-
cg.f32.const(1 / 3);
|
|
2362
|
-
cg.f32.mul();
|
|
2363
|
-
cg.f32.add();
|
|
2364
|
-
cg.local.get(t);
|
|
2365
|
-
cg.f32.add();
|
|
2366
|
-
cg.f32.const(2);
|
|
2367
2441
|
cg.f32.mul();
|
|
2368
|
-
cg.local.set(lnm);
|
|
2369
2442
|
cg.local.get(e);
|
|
2370
2443
|
cg.f32.convert_i32_s();
|
|
2371
2444
|
cg.f32.const(Math.LN2);
|
|
2372
2445
|
cg.f32.mul();
|
|
2373
|
-
cg.local.set(el2);
|
|
2374
|
-
cg.local.get(el2);
|
|
2375
|
-
cg.local.get(lnm);
|
|
2376
2446
|
cg.f32.add();
|
|
2377
2447
|
});
|
|
2378
2448
|
}
|
|
@@ -2382,7 +2452,7 @@ function wasm_log(cg) {
|
|
|
2382
2452
|
* Method: reduce to y in [-π, π], then quadrant via q = round(y/(π/2))
|
|
2383
2453
|
* z = y - q*(π/2); use one of two polynomials on z:
|
|
2384
2454
|
* sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
|
|
2385
|
-
* cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
|
|
2455
|
+
* cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720) + z^8*(1/40320)
|
|
2386
2456
|
*/
|
|
2387
2457
|
function _sincos(cg) {
|
|
2388
2458
|
const y = cg.local.declare(cg.f32);
|
|
@@ -2418,35 +2488,22 @@ function _sincos(cg) {
|
|
|
2418
2488
|
cg.local.get(z);
|
|
2419
2489
|
cg.f32.mul();
|
|
2420
2490
|
cg.local.set(z2);
|
|
2421
|
-
cg
|
|
2422
|
-
|
|
2423
|
-
|
|
2424
|
-
|
|
2425
|
-
|
|
2426
|
-
|
|
2427
|
-
cg.f32.mul();
|
|
2428
|
-
cg.f32.const(-1 / 6);
|
|
2429
|
-
cg.f32.add();
|
|
2430
|
-
cg.local.get(z2);
|
|
2431
|
-
cg.f32.mul();
|
|
2432
|
-
cg.f32.const(1);
|
|
2433
|
-
cg.f32.add();
|
|
2491
|
+
_poly(cg, z2, [
|
|
2492
|
+
1,
|
|
2493
|
+
-1 / 6,
|
|
2494
|
+
1 / 120,
|
|
2495
|
+
-1 / 5040
|
|
2496
|
+
]);
|
|
2434
2497
|
cg.local.get(z);
|
|
2435
2498
|
cg.f32.mul();
|
|
2436
2499
|
cg.local.set(sz);
|
|
2437
|
-
cg
|
|
2438
|
-
|
|
2439
|
-
|
|
2440
|
-
|
|
2441
|
-
|
|
2442
|
-
|
|
2443
|
-
|
|
2444
|
-
cg.f32.const(-1 / 2);
|
|
2445
|
-
cg.f32.add();
|
|
2446
|
-
cg.local.get(z2);
|
|
2447
|
-
cg.f32.mul();
|
|
2448
|
-
cg.f32.const(1);
|
|
2449
|
-
cg.f32.add();
|
|
2500
|
+
_poly(cg, z2, [
|
|
2501
|
+
1,
|
|
2502
|
+
-1 / 2,
|
|
2503
|
+
1 / 24,
|
|
2504
|
+
-1 / 720,
|
|
2505
|
+
1 / 40320
|
|
2506
|
+
]);
|
|
2450
2507
|
cg.local.set(cz);
|
|
2451
2508
|
return {
|
|
2452
2509
|
q,
|
|
@@ -2528,24 +2585,16 @@ function _atan(cg) {
|
|
|
2528
2585
|
cg.local.get(z);
|
|
2529
2586
|
cg.f32.mul();
|
|
2530
2587
|
cg.local.set(z2);
|
|
2531
|
-
cg
|
|
2532
|
-
|
|
2533
|
-
|
|
2534
|
-
|
|
2535
|
-
|
|
2536
|
-
cg
|
|
2537
|
-
|
|
2538
|
-
|
|
2539
|
-
|
|
2540
|
-
|
|
2541
|
-
cg.local.get(z2);
|
|
2542
|
-
cg.f32.mul();
|
|
2543
|
-
cg.f32.const(.994987933645);
|
|
2544
|
-
cg.f32.add();
|
|
2545
|
-
cg.local.get(z2);
|
|
2546
|
-
cg.f32.mul();
|
|
2547
|
-
cg.f32.const(1);
|
|
2548
|
-
cg.f32.add();
|
|
2588
|
+
_poly(cg, z2, [
|
|
2589
|
+
.999998614341,
|
|
2590
|
+
.661705427875,
|
|
2591
|
+
.0415796528637
|
|
2592
|
+
]);
|
|
2593
|
+
_poly(cg, z2, [
|
|
2594
|
+
1,
|
|
2595
|
+
.994987933645,
|
|
2596
|
+
.173698870181
|
|
2597
|
+
]);
|
|
2549
2598
|
cg.f32.div();
|
|
2550
2599
|
cg.local.get(z);
|
|
2551
2600
|
cg.f32.mul();
|
|
@@ -2599,6 +2648,74 @@ function wasm_asin(cg) {
|
|
|
2599
2648
|
});
|
|
2600
2649
|
}
|
|
2601
2650
|
/**
 * Helper function for erf/erfc approximation.
 *
 * Emits code computing `P(t) * exp(-x^2)` with `t = 1 / (1 + p*x)`, where the
 * operand `x` is stored from the stack into a fresh local first.
 * NOTE(review): assumes the caller has already pushed a non-negative f32.
 *
 * See `_erfapprox` in alu.ts for details on the algorithm used.
 */
function _erfapprox(cg, exp_func) {
  const xLocal = cg.local.declare(cg.f32);
  const tLocal = cg.local.declare(cg.f32);
  cg.local.set(xLocal);

  const p = 0.3275911;
  // Ascending powers of t; the leading 0 means P(t) has no constant term.
  const coefficients = [0, 0.254829592, -0.284496736, 1.421413741, -1.453152027, 1.061405429];

  // t = 1 / (1 + p * x)
  cg.f32.const(1);
  cg.f32.const(1);
  cg.f32.const(p);
  cg.local.get(xLocal);
  cg.f32.mul();
  cg.f32.add();
  cg.f32.div();
  cg.local.set(tLocal);

  // P(t)
  _poly(cg, tLocal, coefficients);

  // exp(-x * x), then multiply with P(t)
  cg.local.get(xLocal);
  cg.f32.neg();
  cg.local.get(xLocal);
  cg.f32.mul();
  cg.call(exp_func);
  cg.f32.mul();
}
|
|
2688
|
+
/**
 * Approximate erf(x) (error function).
 *
 * Emits `copysign(1 - E(|x|), x)`, where `E` is the code generated by
 * `_erfapprox` using the provided `exp` function.
 */
function wasm_erf(cg, exp) {
  const emitBody = () => {
    // 1 - E(|x|)
    cg.f32.const(1);
    cg.local.get(0);
    cg.f32.abs();
    _erfapprox(cg, exp);
    cg.f32.sub();
    // Apply the sign of the original argument.
    cg.local.get(0);
    cg.f32.copysign();
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
|
|
2700
|
+
/**
 * Approximate erfc(x) (complementary error function).
 *
 * Emits `E(|x|)` into a local, then selects `2 - E` when `x < 0` and `E`
 * otherwise, where `E` is the code generated by `_erfapprox`.
 */
function wasm_erfc(cg, exp) {
  const emitBody = () => {
    const approx = cg.local.declare(cg.f32);
    // approx = E(|x|)
    cg.local.get(0);
    cg.f32.abs();
    _erfapprox(cg, exp);
    cg.local.set(approx);
    // First select operand: 2 - approx (used when x < 0).
    cg.f32.const(2);
    cg.local.get(approx);
    cg.f32.sub();
    // Second select operand: approx (used otherwise).
    cg.local.get(approx);
    // Condition: x < 0.
    cg.local.get(0);
    cg.f32.const(0);
    cg.f32.lt();
    cg.select();
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
|
|
2718
|
+
/**
|
|
2602
2719
|
* Threefry2x32 pseudorandom number generator.
|
|
2603
2720
|
*
|
|
2604
2721
|
* Takes two 32-bit keys and two 32-bit counters as input,
|
|
@@ -3473,14 +3590,16 @@ function codegenWasm(kernel) {
|
|
|
3473
3590
|
if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
3474
3591
|
const cg = new CodeGenerator();
|
|
3475
3592
|
cg.memory.import("env", "memory");
|
|
3476
|
-
const distinctOps =
|
|
3593
|
+
const distinctOps = mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
|
|
3477
3594
|
const funcs = {};
|
|
3478
3595
|
if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
|
|
3479
3596
|
if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
|
|
3480
3597
|
if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
|
|
3481
3598
|
if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
|
|
3482
|
-
if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
|
|
3599
|
+
if (distinctOps.has(AluOp.Exp) || distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) funcs.exp = wasm_exp(cg);
|
|
3483
3600
|
if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
|
|
3601
|
+
if (distinctOps.has(AluOp.Erf)) funcs.erf = wasm_erf(cg, funcs.exp);
|
|
3602
|
+
if (distinctOps.has(AluOp.Erfc)) funcs.erfc = wasm_erfc(cg, funcs.exp);
|
|
3484
3603
|
if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
|
|
3485
3604
|
const kernelFunc = cg.function(rep(kernel.nargs + 1, cg.i32), [], () => {
|
|
3486
3605
|
const gidx = cg.local.declare(cg.i32);
|
|
@@ -3634,6 +3753,8 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3634
3753
|
else if (op === AluOp.Atan) gen(src[0]), cg.call(funcs.atan);
|
|
3635
3754
|
else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
|
|
3636
3755
|
else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
|
|
3756
|
+
else if (op === AluOp.Erf) gen(src[0]), cg.call(funcs.erf);
|
|
3757
|
+
else if (op === AluOp.Erfc) gen(src[0]), cg.call(funcs.erfc);
|
|
3637
3758
|
else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
|
|
3638
3759
|
else if (op === AluOp.Reciprocal) cg.f32.const(1), gen(src[0]), cg.f32.div();
|
|
3639
3760
|
else if (op === AluOp.Cast) {
|
|
@@ -3761,7 +3882,7 @@ async function createBackend(device) {
|
|
|
3761
3882
|
if (!navigator.gpu) return null;
|
|
3762
3883
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
3763
3884
|
if (!adapter) return null;
|
|
3764
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
3885
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BE7zA_01.cjs"));
|
|
3765
3886
|
const importantLimits = [
|
|
3766
3887
|
"maxBufferSize",
|
|
3767
3888
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -3910,6 +4031,12 @@ Object.defineProperty(exports, 'accessorGlobal', {
|
|
|
3910
4031
|
return accessorGlobal;
|
|
3911
4032
|
}
|
|
3912
4033
|
});
|
|
4034
|
+
Object.defineProperty(exports, 'assertNonNull', {
|
|
4035
|
+
enumerable: true,
|
|
4036
|
+
get: function () {
|
|
4037
|
+
return assertNonNull;
|
|
4038
|
+
}
|
|
4039
|
+
});
|
|
3913
4040
|
Object.defineProperty(exports, 'byteWidth', {
|
|
3914
4041
|
enumerable: true,
|
|
3915
4042
|
get: function () {
|
|
@@ -3958,6 +4085,12 @@ Object.defineProperty(exports, 'findPow2', {
|
|
|
3958
4085
|
return findPow2;
|
|
3959
4086
|
}
|
|
3960
4087
|
});
|
|
4088
|
+
Object.defineProperty(exports, 'generalBroadcast', {
|
|
4089
|
+
enumerable: true,
|
|
4090
|
+
get: function () {
|
|
4091
|
+
return generalBroadcast;
|
|
4092
|
+
}
|
|
4093
|
+
});
|
|
3961
4094
|
Object.defineProperty(exports, 'getBackend', {
|
|
3962
4095
|
enumerable: true,
|
|
3963
4096
|
get: function () {
|
|
@@ -3994,6 +4127,12 @@ Object.defineProperty(exports, 'isPermutation', {
|
|
|
3994
4127
|
return isPermutation;
|
|
3995
4128
|
}
|
|
3996
4129
|
});
|
|
4130
|
+
Object.defineProperty(exports, 'mapSetUnion', {
|
|
4131
|
+
enumerable: true,
|
|
4132
|
+
get: function () {
|
|
4133
|
+
return mapSetUnion;
|
|
4134
|
+
}
|
|
4135
|
+
});
|
|
3997
4136
|
Object.defineProperty(exports, 'normalizeAxis', {
|
|
3998
4137
|
enumerable: true,
|
|
3999
4138
|
get: function () {
|
|
@@ -4066,12 +4205,6 @@ Object.defineProperty(exports, 'tuneWebgpu', {
|
|
|
4066
4205
|
return tuneWebgpu;
|
|
4067
4206
|
}
|
|
4068
4207
|
});
|
|
4069
|
-
Object.defineProperty(exports, 'union', {
|
|
4070
|
-
enumerable: true,
|
|
4071
|
-
get: function () {
|
|
4072
|
-
return union;
|
|
4073
|
-
}
|
|
4074
|
-
});
|
|
4075
4208
|
Object.defineProperty(exports, 'unravelAlu', {
|
|
4076
4209
|
enumerable: true,
|
|
4077
4210
|
get: function () {
|
|
@@ -4095,4 +4228,5 @@ Object.defineProperty(exports, 'zipn', {
|
|
|
4095
4228
|
get: function () {
|
|
4096
4229
|
return zipn;
|
|
4097
4230
|
}
|
|
4098
|
-
});
|
|
4231
|
+
});
|
|
4232
|
+
//# sourceMappingURL=backend-FtkbO6pI.cjs.map
|