@jax-js/jax 0.0.4 → 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +296 -78
- package/dist/{backend-EBRGmEYw.js → backend-DwIAd0AG.js} +238 -116
- package/dist/{backend-Ss1Mev_-.cjs → backend-FtkbO6pI.cjs} +256 -122
- package/dist/index.cjs +653 -277
- package/dist/index.d.cts +167 -44
- package/dist/index.d.ts +167 -44
- package/dist/index.js +637 -268
- package/dist/{webgpu-BVdMaO9T.cjs → webgpu-BE7zA_01.cjs} +181 -151
- package/dist/{webgpu-ow0Pn_6q.js → webgpu-LGi2A3mS.js} +181 -151
- package/package.json +7 -5
|
@@ -51,6 +51,7 @@ let DEBUG = 0;
|
|
|
51
51
|
/**
 * Set the module-wide debug verbosity level.
 *
 * Writes the module-level `DEBUG` binding. Other code in this module compares
 * it against thresholds; for example, `DEBUG >= 3` makes `codegenWasm` log
 * kernel expressions.
 *
 * @param level - New debug level.
 */
function setDebug(level) {
  DEBUG = level;
}
|
|
54
|
+
/**
 * Runtime no-op: the body is empty, so this never inspects `value` and never throws.
 *
 * NOTE(review): presumably this exists only for static type narrowing in the
 * TypeScript source (an `asserts value is NonNullable<T>` signature). Confirm
 * that against the TS source. Callers must not rely on it to reject
 * `null`/`undefined` at runtime.
 */
function assertNonNull(value) {}
|
|
54
55
|
function unzip2(pairs) {
|
|
55
56
|
const lst1 = [];
|
|
56
57
|
const lst2 = [];
|
|
@@ -96,10 +97,15 @@ function deepEqual(a, b) {
|
|
|
96
97
|
for (const key of Object.keys(a)) if (!deepEqual(a[key], b[key])) return false;
|
|
97
98
|
return true;
|
|
98
99
|
}
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
100
|
+
/**
 * Produces a union of maps of sets. This mutates `a` and the sets it owns.
 *
 * For every key in `b`, the members of `b`'s set are added to `a`'s set for
 * that key. Keys missing from `a` receive a fresh copy of `b`'s set. Copying,
 * rather than storing `b`'s Set object directly, keeps `a` and `b` from
 * aliasing, so a later union into `a` can never silently modify `b`.
 *
 * @param {Map<K, Set<V>>} a - Target map, mutated in place.
 * @param {Map<K, Set<V>> | null | undefined} b - Source map; a nullish value is a no-op.
 * @returns {Map<K, Set<V>>} `a`, to allow chaining.
 */
function mapSetUnion(a, b) {
  if (!b) return a;
  for (const [key, setB] of b.entries()) {
    const setA = a.get(key);
    if (setA) {
      for (const val of setB) setA.add(val);
    } else {
      // Copy so that `a` never shares a Set instance with `b`.
      a.set(key, new Set(setB));
    }
  }
  return a;
}
|
|
104
110
|
/** Splits the list based on a condition, `false` first then `true`. */
|
|
105
111
|
function partitionList(which, array) {
|
|
@@ -206,6 +212,35 @@ function findPow2(hint, max) {
|
|
|
206
212
|
while (ret < hint && 2 * ret <= max) ret *= 2;
|
|
207
213
|
return ret;
|
|
208
214
|
}
|
|
215
|
+
/**
 * Computes the NumPy-style broadcast of two array shapes.
 *
 * The shapes are aligned at their trailing (rightmost) axes. At each aligned
 * position, the two sizes must either be equal or one of them must be 1; the
 * other size then wins. Axes that exist only in the longer shape are copied
 * through unchanged. Axes are checked from right to left, so the first
 * incompatible pair found is the rightmost one.
 *
 * @param {number[]} a - First shape.
 * @param {number[]} b - Second shape.
 * @returns {number[]} A new array holding the broadcast shape.
 * @throws {TypeError} If the two shapes cannot be broadcast together.
 *
 * <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
 */
function generalBroadcast(a, b) {
  const rank = Math.max(a.length, b.length);
  const shape = new Array(rank);
  for (let pos = rank - 1; pos >= 0; pos--) {
    // Index of the axis in each input that lines up with output axis `pos`.
    const ia = pos - (rank - a.length);
    const ib = pos - (rank - b.length);
    if (ia < 0) {
      shape[pos] = b[ib];
      continue;
    }
    if (ib < 0) {
      shape[pos] = a[ia];
      continue;
    }
    const sizeA = a[ia];
    const sizeB = b[ib];
    if (sizeA === sizeB || sizeB === 1) shape[pos] = sizeA;
    else if (sizeA === 1) shape[pos] = sizeB;
    else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
  }
  return shape;
}
|
|
209
244
|
function recursiveFlatten(ar) {
|
|
210
245
|
if (!Array.isArray(ar)) return [ar];
|
|
211
246
|
return ar.flat(Infinity);
|
|
@@ -294,12 +329,12 @@ const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float
|
|
|
294
329
|
* **Type lattice:**
|
|
295
330
|
* ```text
|
|
296
331
|
* bool -> uint32 -> int32 -> float16 -> float32
|
|
297
|
-
*
|
|
332
|
+
* weakType --^
|
|
298
333
|
* ```
|
|
299
334
|
*
|
|
300
|
-
*
|
|
301
|
-
*
|
|
302
|
-
*
|
|
335
|
+
* `weakType` represents weakly typed arrays. These are created for JS numbers,
|
|
336
|
+
* which default to float32 but are "weak", so they cast to the dtype of any array
|
|
337
|
+
* they are first combined with, except `bool`.
|
|
303
338
|
*
|
|
304
339
|
* **Examples:**
|
|
305
340
|
* - `promoteTypes(bool, int32) → int32`
|
|
@@ -400,6 +435,12 @@ var AluExp = class AluExp {
|
|
|
400
435
|
static log(a) {
|
|
401
436
|
return new AluExp(AluOp.Log, a.dtype, [a]);
|
|
402
437
|
}
|
|
438
|
+
static erf(a) {
|
|
439
|
+
return new AluExp(AluOp.Erf, a.dtype, [a]);
|
|
440
|
+
}
|
|
441
|
+
static erfc(a) {
|
|
442
|
+
return new AluExp(AluOp.Erfc, a.dtype, [a]);
|
|
443
|
+
}
|
|
403
444
|
static sqrt(a) {
|
|
404
445
|
return new AluExp(AluOp.Sqrt, a.dtype, [a]);
|
|
405
446
|
}
|
|
@@ -570,6 +611,12 @@ var AluExp = class AluExp {
|
|
|
570
611
|
case AluOp.Log:
|
|
571
612
|
ret = [Math.log(src[0].min), Math.log(src[0].max)];
|
|
572
613
|
break;
|
|
614
|
+
case AluOp.Erf:
|
|
615
|
+
ret = [erf(src[0].min), erf(src[0].max)];
|
|
616
|
+
break;
|
|
617
|
+
case AluOp.Erfc:
|
|
618
|
+
ret = [erfc(src[0].max), erfc(src[0].min)];
|
|
619
|
+
break;
|
|
573
620
|
case AluOp.Sqrt:
|
|
574
621
|
ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
|
|
575
622
|
break;
|
|
@@ -891,6 +938,8 @@ var AluExp = class AluExp {
|
|
|
891
938
|
case AluOp.Atan: return Math.atan(x);
|
|
892
939
|
case AluOp.Exp: return Math.exp(x);
|
|
893
940
|
case AluOp.Log: return Math.log(x);
|
|
941
|
+
case AluOp.Erf: return erf(x);
|
|
942
|
+
case AluOp.Erfc: return erfc(x);
|
|
894
943
|
case AluOp.Sqrt: return Math.sqrt(x);
|
|
895
944
|
case AluOp.Reciprocal: return 1 / x;
|
|
896
945
|
case AluOp.Cast: {
|
|
@@ -1039,11 +1088,15 @@ var AluExp = class AluExp {
|
|
|
1039
1088
|
});
|
|
1040
1089
|
return result;
|
|
1041
1090
|
}
|
|
1042
|
-
/** Produce
|
|
1091
|
+
/** Produce all distinct AluOp in this expression, with their dtypes. */
|
|
1043
1092
|
distinctOps() {
|
|
1044
|
-
const ops = /* @__PURE__ */ new
|
|
1093
|
+
const ops = /* @__PURE__ */ new Map();
|
|
1045
1094
|
this.fold((exp) => {
|
|
1046
|
-
ops.
|
|
1095
|
+
const s = ops.get(exp.op) ?? /* @__PURE__ */ new Set();
|
|
1096
|
+
if (!s.has(exp.dtype)) {
|
|
1097
|
+
s.add(exp.dtype);
|
|
1098
|
+
ops.set(exp.op, s);
|
|
1099
|
+
}
|
|
1047
1100
|
});
|
|
1048
1101
|
return ops;
|
|
1049
1102
|
}
|
|
@@ -1072,6 +1125,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
1072
1125
|
AluOp$1["Atan"] = "Atan";
|
|
1073
1126
|
AluOp$1["Exp"] = "Exp";
|
|
1074
1127
|
AluOp$1["Log"] = "Log";
|
|
1128
|
+
AluOp$1["Erf"] = "Erf";
|
|
1129
|
+
AluOp$1["Erfc"] = "Erfc";
|
|
1075
1130
|
AluOp$1["Sqrt"] = "Sqrt";
|
|
1076
1131
|
AluOp$1["Reciprocal"] = "Reciprocal";
|
|
1077
1132
|
AluOp$1["Cast"] = "Cast";
|
|
@@ -1104,6 +1159,8 @@ const AluGroup = {
|
|
|
1104
1159
|
AluOp.Atan,
|
|
1105
1160
|
AluOp.Exp,
|
|
1106
1161
|
AluOp.Log,
|
|
1162
|
+
AluOp.Erf,
|
|
1163
|
+
AluOp.Erfc,
|
|
1107
1164
|
AluOp.Sqrt,
|
|
1108
1165
|
AluOp.Reciprocal,
|
|
1109
1166
|
AluOp.Cast,
|
|
@@ -1129,6 +1186,8 @@ const AluGroup = {
|
|
|
1129
1186
|
AluOp.Atan,
|
|
1130
1187
|
AluOp.Exp,
|
|
1131
1188
|
AluOp.Log,
|
|
1189
|
+
AluOp.Erf,
|
|
1190
|
+
AluOp.Erfc,
|
|
1132
1191
|
AluOp.Sqrt,
|
|
1133
1192
|
AluOp.Reciprocal
|
|
1134
1193
|
])
|
|
@@ -1304,6 +1363,44 @@ function threefry2x32(k0, k1, c0, c1) {
|
|
|
1304
1363
|
x1 = x1 + ks0 + 5 >>> 0;
|
|
1305
1364
|
return [x0, x1];
|
|
1306
1365
|
}
|
|
1366
|
+
/**
 * Core of the Abramowitz & Stegun approximation for erf(x).
 *
 * For `x >= 0`, `erf(x) = 1 - E` with `E = P(t) * exp(-x^2)`, where
 * `t = 1/(1 + p*x)` and `P(t) = a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5`.
 *
 * Coefficients:
 * - p = 0.3275911
 * - a1..a5 = 0.254829592, -0.284496736, 1.421413741, -1.453152027, 1.061405429
 *
 * Only `E` is returned, so that callers can form `erf` (1 - E) or `erfc` (E)
 * without cancellation. The input is assumed to be non-negative.
 *
 * Reference: https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions
 */
function _erfapprox$1(x) {
  const p = .3275911;
  // a1..a5, lowest order first.
  const coeffs = [.254829592, -.284496736, 1.421413741, -1.453152027, 1.061405429];
  const t = 1 / (1 + p * x);
  // Horner's scheme, starting from a5; the final `* t` supplies the missing power.
  let poly = coeffs[coeffs.length - 1];
  for (let i = coeffs.length - 2; i >= 0; i--) poly = poly * t + coeffs[i];
  return poly * t * Math.exp(-x * x);
}
|
|
1396
|
+
/**
 * Approximates erf(x) using `_erfapprox$1` and the odd symmetry
 * erf(-x) = -erf(x). NaN fails `x >= 0` and so takes the mirrored branch.
 */
function erf(x) {
  return x >= 0 ? 1 - _erfapprox$1(x) : _erfapprox$1(-x) - 1;
}
|
|
1400
|
+
/**
 * Approximates erfc(x) = 1 - erf(x) using `_erfapprox$1` directly for x >= 0,
 * and the reflection erfc(-x) = 2 - erfc(x) otherwise. NaN is routed to the
 * reflected branch.
 */
function erfc(x) {
  if (!(x >= 0)) return 2 - _erfapprox$1(-x);
  return _erfapprox$1(x);
}
|
|
1307
1404
|
|
|
1308
1405
|
//#endregion
|
|
1309
1406
|
//#region src/shape.ts
|
|
@@ -1989,7 +2086,7 @@ function tuneWebgpu(kernel) {
|
|
|
1989
2086
|
}
|
|
1990
2087
|
if (/chrome/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
|
|
1991
2088
|
const s = dim.st.shape[dim.unroll - 1];
|
|
1992
|
-
if (s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
2089
|
+
if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
1993
2090
|
else for (const splits of [4]) if (s % splits === 0) {
|
|
1994
2091
|
dim.applyUnroll(dim.unroll - 1, splits);
|
|
1995
2092
|
break;
|
|
@@ -2208,6 +2305,19 @@ var WasmAllocator = class {
|
|
|
2208
2305
|
|
|
2209
2306
|
//#endregion
|
|
2210
2307
|
//#region src/backend/wasm/builtins.ts
|
|
2308
|
+
/**
 * Emits code that evaluates `as[0] + as[1]*x + ... + as[n]*x^n` at the f32
 * local `x` using Horner's scheme, and leaves the result on the stack.
 *
 * Evaluation starts from the leading coefficient. Each step multiplies the
 * running value by `x` and then adds the next lower coefficient. The add is
 * omitted for zero coefficients, since adding zero is redundant.
 *
 * @param cg - Code generator to emit into.
 * @param x - The f32 local holding the evaluation point.
 * @param {number[]} as - Coefficients, lowest degree first.
 * @throws {Error} If `as` is empty.
 */
function _poly(cg, x, as) {
  const degree = as.length - 1;
  if (degree < 0) throw new Error("_poly needs at least one coefficient");
  cg.f32.const(as[degree]);
  for (const coeff of as.slice(0, degree).reverse()) {
    cg.local.get(x);
    cg.f32.mul();
    if (coeff === 0) continue;
    cg.f32.const(coeff);
    cg.f32.add();
  }
}
|
|
2211
2321
|
/**
|
|
2212
2322
|
* Approximate e^x.
|
|
2213
2323
|
*
|
|
@@ -2248,27 +2358,15 @@ function wasm_exp(cg) {
|
|
|
2248
2358
|
cg.f32.mul();
|
|
2249
2359
|
cg.f32.sub();
|
|
2250
2360
|
cg.local.set(r);
|
|
2251
|
-
cg
|
|
2252
|
-
|
|
2253
|
-
|
|
2254
|
-
|
|
2255
|
-
|
|
2256
|
-
|
|
2257
|
-
|
|
2258
|
-
|
|
2259
|
-
|
|
2260
|
-
cg.local.get(r);
|
|
2261
|
-
cg.f32.mul();
|
|
2262
|
-
cg.f32.const(1 / 2);
|
|
2263
|
-
cg.f32.add();
|
|
2264
|
-
cg.local.get(r);
|
|
2265
|
-
cg.f32.mul();
|
|
2266
|
-
cg.f32.const(1);
|
|
2267
|
-
cg.f32.add();
|
|
2268
|
-
cg.local.get(r);
|
|
2269
|
-
cg.f32.mul();
|
|
2270
|
-
cg.f32.const(1);
|
|
2271
|
-
cg.f32.add();
|
|
2361
|
+
_poly(cg, r, [
|
|
2362
|
+
1,
|
|
2363
|
+
1,
|
|
2364
|
+
1 / 2,
|
|
2365
|
+
1 / 6,
|
|
2366
|
+
1 / 24,
|
|
2367
|
+
1 / 120,
|
|
2368
|
+
1 / 720
|
|
2369
|
+
]);
|
|
2272
2370
|
cg.local.set(p);
|
|
2273
2371
|
cg.local.get(k);
|
|
2274
2372
|
cg.i32.const(127);
|
|
@@ -2296,11 +2394,6 @@ function wasm_log(cg) {
|
|
|
2296
2394
|
const m = cg.local.declare(cg.f32);
|
|
2297
2395
|
const t = cg.local.declare(cg.f32);
|
|
2298
2396
|
const t2 = cg.local.declare(cg.f32);
|
|
2299
|
-
const t3 = cg.local.declare(cg.f32);
|
|
2300
|
-
const t5 = cg.local.declare(cg.f32);
|
|
2301
|
-
const t7 = cg.local.declare(cg.f32);
|
|
2302
|
-
const lnm = cg.local.declare(cg.f32);
|
|
2303
|
-
const el2 = cg.local.declare(cg.f32);
|
|
2304
2397
|
cg.local.get(0);
|
|
2305
2398
|
cg.f32.const(0);
|
|
2306
2399
|
cg.f32.le();
|
|
@@ -2337,41 +2430,18 @@ function wasm_log(cg) {
|
|
|
2337
2430
|
cg.local.get(t);
|
|
2338
2431
|
cg.f32.mul();
|
|
2339
2432
|
cg.local.set(t2);
|
|
2433
|
+
_poly(cg, t2, [
|
|
2434
|
+
2,
|
|
2435
|
+
2 / 3,
|
|
2436
|
+
2 / 5,
|
|
2437
|
+
2 / 7
|
|
2438
|
+
]);
|
|
2340
2439
|
cg.local.get(t);
|
|
2341
|
-
cg.local.get(t2);
|
|
2342
|
-
cg.f32.mul();
|
|
2343
|
-
cg.local.set(t3);
|
|
2344
|
-
cg.local.get(t3);
|
|
2345
|
-
cg.local.get(t2);
|
|
2346
|
-
cg.f32.mul();
|
|
2347
|
-
cg.local.set(t5);
|
|
2348
|
-
cg.local.get(t5);
|
|
2349
|
-
cg.local.get(t2);
|
|
2350
|
-
cg.f32.mul();
|
|
2351
|
-
cg.local.set(t7);
|
|
2352
|
-
cg.local.get(t7);
|
|
2353
|
-
cg.f32.const(1 / 7);
|
|
2354
|
-
cg.f32.mul();
|
|
2355
|
-
cg.local.get(t5);
|
|
2356
|
-
cg.f32.const(1 / 5);
|
|
2357
|
-
cg.f32.mul();
|
|
2358
|
-
cg.f32.add();
|
|
2359
|
-
cg.local.get(t3);
|
|
2360
|
-
cg.f32.const(1 / 3);
|
|
2361
|
-
cg.f32.mul();
|
|
2362
|
-
cg.f32.add();
|
|
2363
|
-
cg.local.get(t);
|
|
2364
|
-
cg.f32.add();
|
|
2365
|
-
cg.f32.const(2);
|
|
2366
2440
|
cg.f32.mul();
|
|
2367
|
-
cg.local.set(lnm);
|
|
2368
2441
|
cg.local.get(e);
|
|
2369
2442
|
cg.f32.convert_i32_s();
|
|
2370
2443
|
cg.f32.const(Math.LN2);
|
|
2371
2444
|
cg.f32.mul();
|
|
2372
|
-
cg.local.set(el2);
|
|
2373
|
-
cg.local.get(el2);
|
|
2374
|
-
cg.local.get(lnm);
|
|
2375
2445
|
cg.f32.add();
|
|
2376
2446
|
});
|
|
2377
2447
|
}
|
|
@@ -2381,7 +2451,7 @@ function wasm_log(cg) {
|
|
|
2381
2451
|
* Method: reduce to y in [-π, π], then quadrant via q = round(y/(π/2))
|
|
2382
2452
|
* z = y - q*(π/2); use one of two polynomials on z:
|
|
2383
2453
|
* sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
|
|
2384
|
-
* cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
|
|
2454
|
+
* cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720) + z^8*(1/40320)
|
|
2385
2455
|
*/
|
|
2386
2456
|
function _sincos(cg) {
|
|
2387
2457
|
const y = cg.local.declare(cg.f32);
|
|
@@ -2417,35 +2487,22 @@ function _sincos(cg) {
|
|
|
2417
2487
|
cg.local.get(z);
|
|
2418
2488
|
cg.f32.mul();
|
|
2419
2489
|
cg.local.set(z2);
|
|
2420
|
-
cg
|
|
2421
|
-
|
|
2422
|
-
|
|
2423
|
-
|
|
2424
|
-
|
|
2425
|
-
|
|
2426
|
-
cg.f32.mul();
|
|
2427
|
-
cg.f32.const(-1 / 6);
|
|
2428
|
-
cg.f32.add();
|
|
2429
|
-
cg.local.get(z2);
|
|
2430
|
-
cg.f32.mul();
|
|
2431
|
-
cg.f32.const(1);
|
|
2432
|
-
cg.f32.add();
|
|
2490
|
+
_poly(cg, z2, [
|
|
2491
|
+
1,
|
|
2492
|
+
-1 / 6,
|
|
2493
|
+
1 / 120,
|
|
2494
|
+
-1 / 5040
|
|
2495
|
+
]);
|
|
2433
2496
|
cg.local.get(z);
|
|
2434
2497
|
cg.f32.mul();
|
|
2435
2498
|
cg.local.set(sz);
|
|
2436
|
-
cg
|
|
2437
|
-
|
|
2438
|
-
|
|
2439
|
-
|
|
2440
|
-
|
|
2441
|
-
|
|
2442
|
-
|
|
2443
|
-
cg.f32.const(-1 / 2);
|
|
2444
|
-
cg.f32.add();
|
|
2445
|
-
cg.local.get(z2);
|
|
2446
|
-
cg.f32.mul();
|
|
2447
|
-
cg.f32.const(1);
|
|
2448
|
-
cg.f32.add();
|
|
2499
|
+
_poly(cg, z2, [
|
|
2500
|
+
1,
|
|
2501
|
+
-1 / 2,
|
|
2502
|
+
1 / 24,
|
|
2503
|
+
-1 / 720,
|
|
2504
|
+
1 / 40320
|
|
2505
|
+
]);
|
|
2449
2506
|
cg.local.set(cz);
|
|
2450
2507
|
return {
|
|
2451
2508
|
q,
|
|
@@ -2527,24 +2584,16 @@ function _atan(cg) {
|
|
|
2527
2584
|
cg.local.get(z);
|
|
2528
2585
|
cg.f32.mul();
|
|
2529
2586
|
cg.local.set(z2);
|
|
2530
|
-
cg
|
|
2531
|
-
|
|
2532
|
-
|
|
2533
|
-
|
|
2534
|
-
|
|
2535
|
-
cg
|
|
2536
|
-
|
|
2537
|
-
|
|
2538
|
-
|
|
2539
|
-
|
|
2540
|
-
cg.local.get(z2);
|
|
2541
|
-
cg.f32.mul();
|
|
2542
|
-
cg.f32.const(.994987933645);
|
|
2543
|
-
cg.f32.add();
|
|
2544
|
-
cg.local.get(z2);
|
|
2545
|
-
cg.f32.mul();
|
|
2546
|
-
cg.f32.const(1);
|
|
2547
|
-
cg.f32.add();
|
|
2587
|
+
_poly(cg, z2, [
|
|
2588
|
+
.999998614341,
|
|
2589
|
+
.661705427875,
|
|
2590
|
+
.0415796528637
|
|
2591
|
+
]);
|
|
2592
|
+
_poly(cg, z2, [
|
|
2593
|
+
1,
|
|
2594
|
+
.994987933645,
|
|
2595
|
+
.173698870181
|
|
2596
|
+
]);
|
|
2548
2597
|
cg.f32.div();
|
|
2549
2598
|
cg.local.get(z);
|
|
2550
2599
|
cg.f32.mul();
|
|
@@ -2598,6 +2647,74 @@ function wasm_asin(cg) {
|
|
|
2598
2647
|
});
|
|
2599
2648
|
}
|
|
2600
2649
|
/**
 * Helper function for erf/erfc approximation.
 *
 * Emits code computing `E = P(t) * exp(-x^2)`, with `t = 1 / (1 + p*x)` and
 * `P(t) = a1*t + ... + a5*t^5`. This is the same Abramowitz & Stegun
 * approximation as the pure-JS `_erfapprox$1` (named `_erfapprox` in alu.ts).
 * Like that version, it assumes a non-negative input; `wasm_erf` and
 * `wasm_erfc` pass `|x|`.
 *
 * Stack effect: pops the f32 input and pushes `E`.
 *
 * @param cg - Code generator to emit into. This declares two new f32 locals.
 * @param exp_func - Callee passed to `cg.call` to compute exp of the top of the stack.
 */
function _erfapprox(cg, exp_func) {
  const x = cg.local.declare(cg.f32);
  const t = cg.local.declare(cg.f32);
  // Pop the input into local `x`.
  cg.local.set(x);
  const p = .3275911;
  const a1 = .254829592;
  const a2 = -.284496736;
  const a3 = 1.421413741;
  const a4 = -1.453152027;
  const a5 = 1.061405429;
  // t = 1 / (1 + p * x)
  cg.f32.const(1);
  cg.f32.const(1);
  cg.f32.const(p);
  cg.local.get(x);
  cg.f32.mul();
  cg.f32.add();
  cg.f32.div();
  cg.local.set(t);
  // Push P(t); the leading 0 means there is no constant term.
  _poly(cg, t, [
    0,
    a1,
    a2,
    a3,
    a4,
    a5
  ]);
  // Push exp(-x * x).
  cg.local.get(x);
  cg.f32.neg();
  cg.local.get(x);
  cg.f32.mul();
  cg.call(exp_func);
  // E = P(t) * exp(-x^2)
  cg.f32.mul();
}
|
|
2687
|
+
/**
 * Approximate erf(x) (error function).
 *
 * Builds an f32 -> f32 function computing `copysign(1 - E(|x|), x)`, where
 * `E` is emitted by `_erfapprox`. This relies on the odd symmetry
 * erf(-x) = -erf(x).
 *
 * @param cg - Code generator to emit into.
 * @param exp - Callee for exp, forwarded to `_erfapprox`.
 * @returns The value returned by `cg.function` for the new function. The
 *   kernel code later invokes it with `cg.call`.
 */
function wasm_erf(cg, exp) {
  return cg.function([cg.f32], [cg.f32], () => {
    cg.f32.const(1);
    // E(|x|); local 0 is the function's single f32 parameter.
    cg.local.get(0);
    cg.f32.abs();
    _erfapprox(cg, exp);
    // 1 - E(|x|) approximates erf(|x|).
    cg.f32.sub();
    // Transfer the sign of x onto the result.
    cg.local.get(0);
    cg.f32.copysign();
  });
}
|
|
2699
|
+
/**
 * Approximate erfc(x) (complementary error function).
 *
 * Builds an f32 -> f32 function that returns `E(|x|)` when `x >= 0` and
 * `2 - E(|x|)` when `x < 0`, using erfc(-x) = 2 - erfc(x). `E` is emitted by
 * `_erfapprox`. Because `x < 0` is false for NaN, NaN inputs take the `E(|x|)`
 * path.
 *
 * @param cg - Code generator to emit into.
 * @param exp - Callee for exp, forwarded to `_erfapprox`.
 * @returns The value returned by `cg.function` for the new function.
 */
function wasm_erfc(cg, exp) {
  return cg.function([cg.f32], [cg.f32], () => {
    const e = cg.local.declare(cg.f32);
    // e = E(|x|); local 0 is the function's single f32 parameter.
    cg.local.get(0);
    cg.f32.abs();
    _erfapprox(cg, exp);
    cg.local.set(e);
    // Candidate result for x < 0: 2 - e.
    cg.f32.const(2);
    cg.local.get(e);
    cg.f32.sub();
    // Candidate result for x >= 0: e.
    cg.local.get(e);
    // `select` keeps the first candidate when (x < 0) is nonzero, otherwise the second.
    cg.local.get(0);
    cg.f32.const(0);
    cg.f32.lt();
    cg.select();
  });
}
|
|
2717
|
+
/**
|
|
2601
2718
|
* Threefry2x32 pseudorandom number generator.
|
|
2602
2719
|
*
|
|
2603
2720
|
* Takes two 32-bit keys and two 32-bit counters as input,
|
|
@@ -3472,14 +3589,16 @@ function codegenWasm(kernel) {
|
|
|
3472
3589
|
if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
3473
3590
|
const cg = new CodeGenerator();
|
|
3474
3591
|
cg.memory.import("env", "memory");
|
|
3475
|
-
const distinctOps =
|
|
3592
|
+
const distinctOps = mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
|
|
3476
3593
|
const funcs = {};
|
|
3477
3594
|
if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
|
|
3478
3595
|
if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
|
|
3479
3596
|
if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
|
|
3480
3597
|
if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
|
|
3481
|
-
if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
|
|
3598
|
+
if (distinctOps.has(AluOp.Exp) || distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) funcs.exp = wasm_exp(cg);
|
|
3482
3599
|
if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
|
|
3600
|
+
if (distinctOps.has(AluOp.Erf)) funcs.erf = wasm_erf(cg, funcs.exp);
|
|
3601
|
+
if (distinctOps.has(AluOp.Erfc)) funcs.erfc = wasm_erfc(cg, funcs.exp);
|
|
3483
3602
|
if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
|
|
3484
3603
|
const kernelFunc = cg.function(rep(kernel.nargs + 1, cg.i32), [], () => {
|
|
3485
3604
|
const gidx = cg.local.declare(cg.i32);
|
|
@@ -3633,6 +3752,8 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3633
3752
|
else if (op === AluOp.Atan) gen(src[0]), cg.call(funcs.atan);
|
|
3634
3753
|
else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
|
|
3635
3754
|
else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
|
|
3755
|
+
else if (op === AluOp.Erf) gen(src[0]), cg.call(funcs.erf);
|
|
3756
|
+
else if (op === AluOp.Erfc) gen(src[0]), cg.call(funcs.erfc);
|
|
3636
3757
|
else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
|
|
3637
3758
|
else if (op === AluOp.Reciprocal) cg.f32.const(1), gen(src[0]), cg.f32.div();
|
|
3638
3759
|
else if (op === AluOp.Cast) {
|
|
@@ -3760,7 +3881,7 @@ async function createBackend(device) {
|
|
|
3760
3881
|
if (!navigator.gpu) return null;
|
|
3761
3882
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
3762
3883
|
if (!adapter) return null;
|
|
3763
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
3884
|
+
const { WebGPUBackend } = await import("./webgpu-LGi2A3mS.js");
|
|
3764
3885
|
const importantLimits = [
|
|
3765
3886
|
"maxBufferSize",
|
|
3766
3887
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -3813,4 +3934,5 @@ var UnsupportedOpError = class extends Error {
|
|
|
3813
3934
|
};
|
|
3814
3935
|
|
|
3815
3936
|
//#endregion
|
|
3816
|
-
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu,
|
|
3937
|
+
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|
|
3938
|
+
//# sourceMappingURL=backend-DwIAd0AG.js.map
|