@jax-js/jax 0.0.5 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +267 -92
- package/dist/{backend-yEU0L_ig.cjs → backend-BbrKEB18.cjs} +378 -183
- package/dist/{backend-CdcTZEOF.js → backend-CoVtc9dx.js} +366 -177
- package/dist/index.cjs +385 -74
- package/dist/index.d.cts +115 -23
- package/dist/index.d.ts +115 -23
- package/dist/index.js +378 -74
- package/dist/{webgpu-CM-xNYzW.js → webgpu-B3UVme6n.js} +188 -153
- package/dist/{webgpu-CNOpiO5T.cjs → webgpu-DGYNVHma.cjs} +188 -153
- package/package.json +25 -15
|
@@ -52,6 +52,7 @@ let DEBUG = 0;
|
|
|
52
52
|
function setDebug(level) {
|
|
53
53
|
DEBUG = level;
|
|
54
54
|
}
|
|
55
|
+
function assertNonNull(value) {}
|
|
55
56
|
function unzip2(pairs) {
|
|
56
57
|
const lst1 = [];
|
|
57
58
|
const lst2 = [];
|
|
@@ -97,10 +98,15 @@ function deepEqual(a, b) {
|
|
|
97
98
|
for (const key of Object.keys(a)) if (!deepEqual(a[key], b[key])) return false;
|
|
98
99
|
return true;
|
|
99
100
|
}
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
101
|
+
/** Produces a union of maps of sets. This mutates `a`. */
|
|
102
|
+
function mapSetUnion(a, b) {
|
|
103
|
+
if (!b) return a;
|
|
104
|
+
for (const [key, setB] of b.entries()) {
|
|
105
|
+
const setA = a.get(key);
|
|
106
|
+
if (setA) for (const val of setB) setA.add(val);
|
|
107
|
+
else a.set(key, setB);
|
|
108
|
+
}
|
|
109
|
+
return a;
|
|
104
110
|
}
|
|
105
111
|
/** Splits the list based on a condition, `false` first then `true`. */
|
|
106
112
|
function partitionList(which, array) {
|
|
@@ -301,6 +307,7 @@ let DType = /* @__PURE__ */ function(DType$1) {
|
|
|
301
307
|
DType$1["Uint32"] = "uint32";
|
|
302
308
|
DType$1["Bool"] = "bool";
|
|
303
309
|
DType$1["Float16"] = "float16";
|
|
310
|
+
DType$1["Float64"] = "float64";
|
|
304
311
|
return DType$1;
|
|
305
312
|
}({});
|
|
306
313
|
const byteWidth = (dtype) => {
|
|
@@ -310,10 +317,11 @@ const byteWidth = (dtype) => {
|
|
|
310
317
|
case DType.Uint32:
|
|
311
318
|
case DType.Bool: return 4;
|
|
312
319
|
case DType.Float16: return 2;
|
|
320
|
+
case DType.Float64: return 8;
|
|
313
321
|
default: throw new TypeError(`Unknown dtype: ${dtype}`);
|
|
314
322
|
}
|
|
315
323
|
};
|
|
316
|
-
const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16;
|
|
324
|
+
const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16 || dtype === DType.Float64;
|
|
317
325
|
/**
|
|
318
326
|
* Promote two dtypes to their join according to the type lattice.
|
|
319
327
|
*
|
|
@@ -323,7 +331,7 @@ const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float
|
|
|
323
331
|
*
|
|
324
332
|
* **Type lattice:**
|
|
325
333
|
* ```text
|
|
326
|
-
* bool -> uint32 -> int32 -> float16 -> float32
|
|
334
|
+
* bool -> uint32 -> int32 -> float16 -> float32 -> float64
|
|
327
335
|
* weakType --^
|
|
328
336
|
* ```
|
|
329
337
|
*
|
|
@@ -345,7 +353,8 @@ function promoteTypes(dtype1, dtype2) {
|
|
|
345
353
|
[DType.Uint32]: 1,
|
|
346
354
|
[DType.Int32]: 2,
|
|
347
355
|
[DType.Float16]: 3,
|
|
348
|
-
[DType.Float32]: 4
|
|
356
|
+
[DType.Float32]: 4,
|
|
357
|
+
[DType.Float64]: 5
|
|
349
358
|
};
|
|
350
359
|
return rank[dtype1] > rank[dtype2] ? dtype1 : dtype2;
|
|
351
360
|
}
|
|
@@ -358,6 +367,7 @@ function dtypedArray(dtype, data) {
|
|
|
358
367
|
case DType.Bool: return new Int32Array(buffer, byteOffset, length);
|
|
359
368
|
case DType.Uint32: return new Uint32Array(buffer, byteOffset, length);
|
|
360
369
|
case DType.Float16: return new Float16Array(buffer, byteOffset, length);
|
|
370
|
+
case DType.Float64: return new Float64Array(buffer, byteOffset, length);
|
|
361
371
|
default: throw new Error(`Unimplemented dtype: ${dtype}`);
|
|
362
372
|
}
|
|
363
373
|
}
|
|
@@ -368,6 +378,7 @@ function dtypedJsArray(dtype, data) {
|
|
|
368
378
|
case DType.Bool: return new Int32Array(data);
|
|
369
379
|
case DType.Uint32: return new Uint32Array(data);
|
|
370
380
|
case DType.Float16: return new Float16Array(data);
|
|
381
|
+
case DType.Float64: return new Float64Array(data);
|
|
371
382
|
default: throw new Error(`Unimplemented dtype: ${dtype}`);
|
|
372
383
|
}
|
|
373
384
|
}
|
|
@@ -430,6 +441,12 @@ var AluExp = class AluExp {
|
|
|
430
441
|
static log(a) {
|
|
431
442
|
return new AluExp(AluOp.Log, a.dtype, [a]);
|
|
432
443
|
}
|
|
444
|
+
static erf(a) {
|
|
445
|
+
return new AluExp(AluOp.Erf, a.dtype, [a]);
|
|
446
|
+
}
|
|
447
|
+
static erfc(a) {
|
|
448
|
+
return new AluExp(AluOp.Erfc, a.dtype, [a]);
|
|
449
|
+
}
|
|
433
450
|
static sqrt(a) {
|
|
434
451
|
return new AluExp(AluOp.Sqrt, a.dtype, [a]);
|
|
435
452
|
}
|
|
@@ -499,6 +516,9 @@ var AluExp = class AluExp {
|
|
|
499
516
|
static f16(value) {
|
|
500
517
|
return AluExp.const(DType.Float16, value);
|
|
501
518
|
}
|
|
519
|
+
static f64(value) {
|
|
520
|
+
return AluExp.const(DType.Float64, value);
|
|
521
|
+
}
|
|
502
522
|
not() {
|
|
503
523
|
if (this.dtype !== DType.Bool) throw new Error("not() can only be called on boolean expressions");
|
|
504
524
|
return AluExp.cmpne(this, AluExp.const(DType.Bool, true));
|
|
@@ -509,7 +529,8 @@ var AluExp = class AluExp {
|
|
|
509
529
|
const hasher = new FpHash();
|
|
510
530
|
hasher.update(this.op);
|
|
511
531
|
hasher.update(this.dtype);
|
|
512
|
-
hasher.update(
|
|
532
|
+
if (this.op === AluOp.Const) hasher.update(this.arg);
|
|
533
|
+
else hasher.update(JSON.stringify(this.arg));
|
|
513
534
|
hasher.update(this.src.length);
|
|
514
535
|
for (const s of this.src) hasher.update(s);
|
|
515
536
|
this.#hash = hasher.value;
|
|
@@ -600,6 +621,12 @@ var AluExp = class AluExp {
|
|
|
600
621
|
case AluOp.Log:
|
|
601
622
|
ret = [Math.log(src[0].min), Math.log(src[0].max)];
|
|
602
623
|
break;
|
|
624
|
+
case AluOp.Erf:
|
|
625
|
+
ret = [erf(src[0].min), erf(src[0].max)];
|
|
626
|
+
break;
|
|
627
|
+
case AluOp.Erfc:
|
|
628
|
+
ret = [erfc(src[0].max), erfc(src[0].min)];
|
|
629
|
+
break;
|
|
603
630
|
case AluOp.Sqrt:
|
|
604
631
|
ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
|
|
605
632
|
break;
|
|
@@ -740,6 +767,7 @@ var AluExp = class AluExp {
|
|
|
740
767
|
if (op === AluOp.Mul && x === 1) return src[1 - i];
|
|
741
768
|
if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
|
|
742
769
|
if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
|
|
770
|
+
if (op === AluOp.Cmpne && src[i].dtype === DType.Bool && x === 0) return src[1 - i];
|
|
743
771
|
}
|
|
744
772
|
if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
|
|
745
773
|
const [a, b] = src[1].src;
|
|
@@ -921,6 +949,8 @@ var AluExp = class AluExp {
|
|
|
921
949
|
case AluOp.Atan: return Math.atan(x);
|
|
922
950
|
case AluOp.Exp: return Math.exp(x);
|
|
923
951
|
case AluOp.Log: return Math.log(x);
|
|
952
|
+
case AluOp.Erf: return erf(x);
|
|
953
|
+
case AluOp.Erfc: return erfc(x);
|
|
924
954
|
case AluOp.Sqrt: return Math.sqrt(x);
|
|
925
955
|
case AluOp.Reciprocal: return 1 / x;
|
|
926
956
|
case AluOp.Cast: {
|
|
@@ -939,11 +969,13 @@ var AluExp = class AluExp {
|
|
|
939
969
|
else if (fromType === DType.Int32) view.setInt32(0, x, true);
|
|
940
970
|
else if (fromType === DType.Uint32) view.setUint32(0, x, true);
|
|
941
971
|
else if (fromType === DType.Float16) view.setFloat16(0, x, true);
|
|
972
|
+
else if (fromType === DType.Float64) view.setFloat64(0, x, true);
|
|
942
973
|
else throw new Error(`Unsupported bitcast from ${fromType}`);
|
|
943
974
|
if (this.dtype === DType.Float32) return view.getFloat32(0, true);
|
|
944
975
|
else if (this.dtype === DType.Int32) return view.getInt32(0, true);
|
|
945
976
|
else if (this.dtype === DType.Uint32) return view.getUint32(0, true);
|
|
946
977
|
else if (this.dtype === DType.Float16) return view.getFloat16(0, true);
|
|
978
|
+
else if (this.dtype === DType.Float64) return view.getFloat64(0, true);
|
|
947
979
|
else throw new Error(`Unsupported bitcast to ${this.dtype}`);
|
|
948
980
|
}
|
|
949
981
|
default: throw new Error(`Missing implemementation for ${this.op}`);
|
|
@@ -1069,11 +1101,15 @@ var AluExp = class AluExp {
|
|
|
1069
1101
|
});
|
|
1070
1102
|
return result;
|
|
1071
1103
|
}
|
|
1072
|
-
/** Produce
|
|
1104
|
+
/** Produce all distinct AluOp in this expression, with their dtypes. */
|
|
1073
1105
|
distinctOps() {
|
|
1074
|
-
const ops = /* @__PURE__ */ new
|
|
1106
|
+
const ops = /* @__PURE__ */ new Map();
|
|
1075
1107
|
this.fold((exp) => {
|
|
1076
|
-
ops.
|
|
1108
|
+
const s = ops.get(exp.op) ?? /* @__PURE__ */ new Set();
|
|
1109
|
+
if (!s.has(exp.dtype)) {
|
|
1110
|
+
s.add(exp.dtype);
|
|
1111
|
+
ops.set(exp.op, s);
|
|
1112
|
+
}
|
|
1077
1113
|
});
|
|
1078
1114
|
return ops;
|
|
1079
1115
|
}
|
|
@@ -1102,6 +1138,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
1102
1138
|
AluOp$1["Atan"] = "Atan";
|
|
1103
1139
|
AluOp$1["Exp"] = "Exp";
|
|
1104
1140
|
AluOp$1["Log"] = "Log";
|
|
1141
|
+
AluOp$1["Erf"] = "Erf";
|
|
1142
|
+
AluOp$1["Erfc"] = "Erfc";
|
|
1105
1143
|
AluOp$1["Sqrt"] = "Sqrt";
|
|
1106
1144
|
AluOp$1["Reciprocal"] = "Reciprocal";
|
|
1107
1145
|
AluOp$1["Cast"] = "Cast";
|
|
@@ -1134,6 +1172,8 @@ const AluGroup = {
|
|
|
1134
1172
|
AluOp.Atan,
|
|
1135
1173
|
AluOp.Exp,
|
|
1136
1174
|
AluOp.Log,
|
|
1175
|
+
AluOp.Erf,
|
|
1176
|
+
AluOp.Erfc,
|
|
1137
1177
|
AluOp.Sqrt,
|
|
1138
1178
|
AluOp.Reciprocal,
|
|
1139
1179
|
AluOp.Cast,
|
|
@@ -1159,6 +1199,8 @@ const AluGroup = {
|
|
|
1159
1199
|
AluOp.Atan,
|
|
1160
1200
|
AluOp.Exp,
|
|
1161
1201
|
AluOp.Log,
|
|
1202
|
+
AluOp.Erf,
|
|
1203
|
+
AluOp.Erfc,
|
|
1162
1204
|
AluOp.Sqrt,
|
|
1163
1205
|
AluOp.Reciprocal
|
|
1164
1206
|
])
|
|
@@ -1334,6 +1376,44 @@ function threefry2x32(k0, k1, c0, c1) {
|
|
|
1334
1376
|
x1 = x1 + ks0 + 5 >>> 0;
|
|
1335
1377
|
return [x0, x1];
|
|
1336
1378
|
}
|
|
1379
|
+
/**
|
|
1380
|
+
* Abramowitz & Stegun’s widely used approximation for erf(x).
|
|
1381
|
+
*
|
|
1382
|
+
* `erf(x) = 1 - P(t) * exp(-x^2)` for `x >= 0`, where `t = 1/(1 + p*x)` and
|
|
1383
|
+
* `P(t) = a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5`.
|
|
1384
|
+
*
|
|
1385
|
+
* Coefficients:
|
|
1386
|
+
* - p = 0.3275911
|
|
1387
|
+
* - a1 = 0.254829592
|
|
1388
|
+
* - a2 = -0.284496736
|
|
1389
|
+
* - a3 = 1.421413741
|
|
1390
|
+
* - a4 = -1.453152027
|
|
1391
|
+
* - a5 = 1.061405429
|
|
1392
|
+
*
|
|
1393
|
+
* This function computes just `E = P(t) * exp(-x^2)` for numerical reasons. The
|
|
1394
|
+
* input is assumed to be non-negative.
|
|
1395
|
+
*
|
|
1396
|
+
* Reference: https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions
|
|
1397
|
+
*/
|
|
1398
|
+
function _erfapprox$1(x) {
|
|
1399
|
+
const p = .3275911;
|
|
1400
|
+
const a1 = .254829592;
|
|
1401
|
+
const a2 = -.284496736;
|
|
1402
|
+
const a3 = 1.421413741;
|
|
1403
|
+
const a4 = -1.453152027;
|
|
1404
|
+
const a5 = 1.061405429;
|
|
1405
|
+
const t = 1 / (1 + p * x);
|
|
1406
|
+
const P_t = ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t;
|
|
1407
|
+
return P_t * Math.exp(-x * x);
|
|
1408
|
+
}
|
|
1409
|
+
function erf(x) {
|
|
1410
|
+
if (x >= 0) return 1 - _erfapprox$1(x);
|
|
1411
|
+
else return _erfapprox$1(-x) - 1;
|
|
1412
|
+
}
|
|
1413
|
+
function erfc(x) {
|
|
1414
|
+
if (x >= 0) return _erfapprox$1(x);
|
|
1415
|
+
else return 2 - _erfapprox$1(-x);
|
|
1416
|
+
}
|
|
1337
1417
|
|
|
1338
1418
|
//#endregion
|
|
1339
1419
|
//#region src/shape.ts
|
|
@@ -2019,7 +2099,7 @@ function tuneWebgpu(kernel) {
|
|
|
2019
2099
|
}
|
|
2020
2100
|
if (/chrome/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
|
|
2021
2101
|
const s = dim.st.shape[dim.unroll - 1];
|
|
2022
|
-
if (s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
2102
|
+
if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
2023
2103
|
else for (const splits of [4]) if (s % splits === 0) {
|
|
2024
2104
|
dim.applyUnroll(dim.unroll - 1, splits);
|
|
2025
2105
|
break;
|
|
@@ -2238,6 +2318,19 @@ var WasmAllocator = class {
|
|
|
2238
2318
|
|
|
2239
2319
|
//#endregion
|
|
2240
2320
|
//#region src/backend/wasm/builtins.ts
|
|
2321
|
+
/** Given a local `x`, evaluate `sum[i](a_i * x^i)` and push to stack. */
|
|
2322
|
+
function _poly(cg, x, as) {
|
|
2323
|
+
if (as.length === 0) throw new Error("_poly needs at least one coefficient");
|
|
2324
|
+
cg.f32.const(as[as.length - 1]);
|
|
2325
|
+
for (let i = as.length - 2; i >= 0; i--) {
|
|
2326
|
+
cg.local.get(x);
|
|
2327
|
+
cg.f32.mul();
|
|
2328
|
+
if (as[i] !== 0) {
|
|
2329
|
+
cg.f32.const(as[i]);
|
|
2330
|
+
cg.f32.add();
|
|
2331
|
+
}
|
|
2332
|
+
}
|
|
2333
|
+
}
|
|
2241
2334
|
/**
|
|
2242
2335
|
* Approximate e^x.
|
|
2243
2336
|
*
|
|
@@ -2278,27 +2371,15 @@ function wasm_exp(cg) {
|
|
|
2278
2371
|
cg.f32.mul();
|
|
2279
2372
|
cg.f32.sub();
|
|
2280
2373
|
cg.local.set(r);
|
|
2281
|
-
cg
|
|
2282
|
-
|
|
2283
|
-
|
|
2284
|
-
|
|
2285
|
-
|
|
2286
|
-
|
|
2287
|
-
|
|
2288
|
-
|
|
2289
|
-
|
|
2290
|
-
cg.local.get(r);
|
|
2291
|
-
cg.f32.mul();
|
|
2292
|
-
cg.f32.const(1 / 2);
|
|
2293
|
-
cg.f32.add();
|
|
2294
|
-
cg.local.get(r);
|
|
2295
|
-
cg.f32.mul();
|
|
2296
|
-
cg.f32.const(1);
|
|
2297
|
-
cg.f32.add();
|
|
2298
|
-
cg.local.get(r);
|
|
2299
|
-
cg.f32.mul();
|
|
2300
|
-
cg.f32.const(1);
|
|
2301
|
-
cg.f32.add();
|
|
2374
|
+
_poly(cg, r, [
|
|
2375
|
+
1,
|
|
2376
|
+
1,
|
|
2377
|
+
1 / 2,
|
|
2378
|
+
1 / 6,
|
|
2379
|
+
1 / 24,
|
|
2380
|
+
1 / 120,
|
|
2381
|
+
1 / 720
|
|
2382
|
+
]);
|
|
2302
2383
|
cg.local.set(p);
|
|
2303
2384
|
cg.local.get(k);
|
|
2304
2385
|
cg.i32.const(127);
|
|
@@ -2326,11 +2407,6 @@ function wasm_log(cg) {
|
|
|
2326
2407
|
const m = cg.local.declare(cg.f32);
|
|
2327
2408
|
const t = cg.local.declare(cg.f32);
|
|
2328
2409
|
const t2 = cg.local.declare(cg.f32);
|
|
2329
|
-
const t3 = cg.local.declare(cg.f32);
|
|
2330
|
-
const t5 = cg.local.declare(cg.f32);
|
|
2331
|
-
const t7 = cg.local.declare(cg.f32);
|
|
2332
|
-
const lnm = cg.local.declare(cg.f32);
|
|
2333
|
-
const el2 = cg.local.declare(cg.f32);
|
|
2334
2410
|
cg.local.get(0);
|
|
2335
2411
|
cg.f32.const(0);
|
|
2336
2412
|
cg.f32.le();
|
|
@@ -2367,41 +2443,18 @@ function wasm_log(cg) {
|
|
|
2367
2443
|
cg.local.get(t);
|
|
2368
2444
|
cg.f32.mul();
|
|
2369
2445
|
cg.local.set(t2);
|
|
2446
|
+
_poly(cg, t2, [
|
|
2447
|
+
2,
|
|
2448
|
+
2 / 3,
|
|
2449
|
+
2 / 5,
|
|
2450
|
+
2 / 7
|
|
2451
|
+
]);
|
|
2370
2452
|
cg.local.get(t);
|
|
2371
|
-
cg.local.get(t2);
|
|
2372
|
-
cg.f32.mul();
|
|
2373
|
-
cg.local.set(t3);
|
|
2374
|
-
cg.local.get(t3);
|
|
2375
|
-
cg.local.get(t2);
|
|
2376
|
-
cg.f32.mul();
|
|
2377
|
-
cg.local.set(t5);
|
|
2378
|
-
cg.local.get(t5);
|
|
2379
|
-
cg.local.get(t2);
|
|
2380
|
-
cg.f32.mul();
|
|
2381
|
-
cg.local.set(t7);
|
|
2382
|
-
cg.local.get(t7);
|
|
2383
|
-
cg.f32.const(1 / 7);
|
|
2384
|
-
cg.f32.mul();
|
|
2385
|
-
cg.local.get(t5);
|
|
2386
|
-
cg.f32.const(1 / 5);
|
|
2387
|
-
cg.f32.mul();
|
|
2388
|
-
cg.f32.add();
|
|
2389
|
-
cg.local.get(t3);
|
|
2390
|
-
cg.f32.const(1 / 3);
|
|
2391
|
-
cg.f32.mul();
|
|
2392
|
-
cg.f32.add();
|
|
2393
|
-
cg.local.get(t);
|
|
2394
|
-
cg.f32.add();
|
|
2395
|
-
cg.f32.const(2);
|
|
2396
2453
|
cg.f32.mul();
|
|
2397
|
-
cg.local.set(lnm);
|
|
2398
2454
|
cg.local.get(e);
|
|
2399
2455
|
cg.f32.convert_i32_s();
|
|
2400
2456
|
cg.f32.const(Math.LN2);
|
|
2401
2457
|
cg.f32.mul();
|
|
2402
|
-
cg.local.set(el2);
|
|
2403
|
-
cg.local.get(el2);
|
|
2404
|
-
cg.local.get(lnm);
|
|
2405
2458
|
cg.f32.add();
|
|
2406
2459
|
});
|
|
2407
2460
|
}
|
|
@@ -2411,7 +2464,7 @@ function wasm_log(cg) {
|
|
|
2411
2464
|
* Method: reduce to y in [-π, π], then quadrant via q = round(y/(π/2))
|
|
2412
2465
|
* z = y - q*(π/2); use one of two polynomials on z:
|
|
2413
2466
|
* sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
|
|
2414
|
-
* cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
|
|
2467
|
+
* cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720) + z^8*(1/40320)
|
|
2415
2468
|
*/
|
|
2416
2469
|
function _sincos(cg) {
|
|
2417
2470
|
const y = cg.local.declare(cg.f32);
|
|
@@ -2447,35 +2500,22 @@ function _sincos(cg) {
|
|
|
2447
2500
|
cg.local.get(z);
|
|
2448
2501
|
cg.f32.mul();
|
|
2449
2502
|
cg.local.set(z2);
|
|
2450
|
-
cg
|
|
2451
|
-
|
|
2452
|
-
|
|
2453
|
-
|
|
2454
|
-
|
|
2455
|
-
|
|
2456
|
-
cg.f32.mul();
|
|
2457
|
-
cg.f32.const(-1 / 6);
|
|
2458
|
-
cg.f32.add();
|
|
2459
|
-
cg.local.get(z2);
|
|
2460
|
-
cg.f32.mul();
|
|
2461
|
-
cg.f32.const(1);
|
|
2462
|
-
cg.f32.add();
|
|
2503
|
+
_poly(cg, z2, [
|
|
2504
|
+
1,
|
|
2505
|
+
-1 / 6,
|
|
2506
|
+
1 / 120,
|
|
2507
|
+
-1 / 5040
|
|
2508
|
+
]);
|
|
2463
2509
|
cg.local.get(z);
|
|
2464
2510
|
cg.f32.mul();
|
|
2465
2511
|
cg.local.set(sz);
|
|
2466
|
-
cg
|
|
2467
|
-
|
|
2468
|
-
|
|
2469
|
-
|
|
2470
|
-
|
|
2471
|
-
|
|
2472
|
-
|
|
2473
|
-
cg.f32.const(-1 / 2);
|
|
2474
|
-
cg.f32.add();
|
|
2475
|
-
cg.local.get(z2);
|
|
2476
|
-
cg.f32.mul();
|
|
2477
|
-
cg.f32.const(1);
|
|
2478
|
-
cg.f32.add();
|
|
2512
|
+
_poly(cg, z2, [
|
|
2513
|
+
1,
|
|
2514
|
+
-1 / 2,
|
|
2515
|
+
1 / 24,
|
|
2516
|
+
-1 / 720,
|
|
2517
|
+
1 / 40320
|
|
2518
|
+
]);
|
|
2479
2519
|
cg.local.set(cz);
|
|
2480
2520
|
return {
|
|
2481
2521
|
q,
|
|
@@ -2557,24 +2597,16 @@ function _atan(cg) {
|
|
|
2557
2597
|
cg.local.get(z);
|
|
2558
2598
|
cg.f32.mul();
|
|
2559
2599
|
cg.local.set(z2);
|
|
2560
|
-
cg
|
|
2561
|
-
|
|
2562
|
-
|
|
2563
|
-
|
|
2564
|
-
|
|
2565
|
-
cg
|
|
2566
|
-
|
|
2567
|
-
|
|
2568
|
-
|
|
2569
|
-
|
|
2570
|
-
cg.local.get(z2);
|
|
2571
|
-
cg.f32.mul();
|
|
2572
|
-
cg.f32.const(.994987933645);
|
|
2573
|
-
cg.f32.add();
|
|
2574
|
-
cg.local.get(z2);
|
|
2575
|
-
cg.f32.mul();
|
|
2576
|
-
cg.f32.const(1);
|
|
2577
|
-
cg.f32.add();
|
|
2600
|
+
_poly(cg, z2, [
|
|
2601
|
+
.999998614341,
|
|
2602
|
+
.661705427875,
|
|
2603
|
+
.0415796528637
|
|
2604
|
+
]);
|
|
2605
|
+
_poly(cg, z2, [
|
|
2606
|
+
1,
|
|
2607
|
+
.994987933645,
|
|
2608
|
+
.173698870181
|
|
2609
|
+
]);
|
|
2578
2610
|
cg.f32.div();
|
|
2579
2611
|
cg.local.get(z);
|
|
2580
2612
|
cg.f32.mul();
|
|
@@ -2628,6 +2660,74 @@ function wasm_asin(cg) {
|
|
|
2628
2660
|
});
|
|
2629
2661
|
}
|
|
2630
2662
|
/**
|
|
2663
|
+
* Helper function for erf/erfc approximation.
|
|
2664
|
+
*
|
|
2665
|
+
* See `_erfapprox` in alu.ts for details on the algorithm used.
|
|
2666
|
+
*/
|
|
2667
|
+
function _erfapprox(cg, exp_func) {
|
|
2668
|
+
const x = cg.local.declare(cg.f32);
|
|
2669
|
+
const t = cg.local.declare(cg.f32);
|
|
2670
|
+
cg.local.set(x);
|
|
2671
|
+
const p = .3275911;
|
|
2672
|
+
const a1 = .254829592;
|
|
2673
|
+
const a2 = -.284496736;
|
|
2674
|
+
const a3 = 1.421413741;
|
|
2675
|
+
const a4 = -1.453152027;
|
|
2676
|
+
const a5 = 1.061405429;
|
|
2677
|
+
cg.f32.const(1);
|
|
2678
|
+
cg.f32.const(1);
|
|
2679
|
+
cg.f32.const(p);
|
|
2680
|
+
cg.local.get(x);
|
|
2681
|
+
cg.f32.mul();
|
|
2682
|
+
cg.f32.add();
|
|
2683
|
+
cg.f32.div();
|
|
2684
|
+
cg.local.set(t);
|
|
2685
|
+
_poly(cg, t, [
|
|
2686
|
+
0,
|
|
2687
|
+
a1,
|
|
2688
|
+
a2,
|
|
2689
|
+
a3,
|
|
2690
|
+
a4,
|
|
2691
|
+
a5
|
|
2692
|
+
]);
|
|
2693
|
+
cg.local.get(x);
|
|
2694
|
+
cg.f32.neg();
|
|
2695
|
+
cg.local.get(x);
|
|
2696
|
+
cg.f32.mul();
|
|
2697
|
+
cg.call(exp_func);
|
|
2698
|
+
cg.f32.mul();
|
|
2699
|
+
}
|
|
2700
|
+
/** Approximate erf(x) (error function). */
|
|
2701
|
+
function wasm_erf(cg, exp) {
|
|
2702
|
+
return cg.function([cg.f32], [cg.f32], () => {
|
|
2703
|
+
cg.f32.const(1);
|
|
2704
|
+
cg.local.get(0);
|
|
2705
|
+
cg.f32.abs();
|
|
2706
|
+
_erfapprox(cg, exp);
|
|
2707
|
+
cg.f32.sub();
|
|
2708
|
+
cg.local.get(0);
|
|
2709
|
+
cg.f32.copysign();
|
|
2710
|
+
});
|
|
2711
|
+
}
|
|
2712
|
+
/** Approximate erfc(x) (complementary error function). */
|
|
2713
|
+
function wasm_erfc(cg, exp) {
|
|
2714
|
+
return cg.function([cg.f32], [cg.f32], () => {
|
|
2715
|
+
const e = cg.local.declare(cg.f32);
|
|
2716
|
+
cg.local.get(0);
|
|
2717
|
+
cg.f32.abs();
|
|
2718
|
+
_erfapprox(cg, exp);
|
|
2719
|
+
cg.local.set(e);
|
|
2720
|
+
cg.f32.const(2);
|
|
2721
|
+
cg.local.get(e);
|
|
2722
|
+
cg.f32.sub();
|
|
2723
|
+
cg.local.get(e);
|
|
2724
|
+
cg.local.get(0);
|
|
2725
|
+
cg.f32.const(0);
|
|
2726
|
+
cg.f32.lt();
|
|
2727
|
+
cg.select();
|
|
2728
|
+
});
|
|
2729
|
+
}
|
|
2730
|
+
/**
|
|
2631
2731
|
* Threefry2x32 pseudorandom number generator.
|
|
2632
2732
|
*
|
|
2633
2733
|
* Takes two 32-bit keys and two 32-bit counters as input,
|
|
@@ -2838,6 +2938,7 @@ var CodeGenerator = class {
|
|
|
2838
2938
|
local;
|
|
2839
2939
|
i32;
|
|
2840
2940
|
f32;
|
|
2941
|
+
f64;
|
|
2841
2942
|
v128;
|
|
2842
2943
|
i32x4;
|
|
2843
2944
|
f32x4;
|
|
@@ -2857,6 +2958,7 @@ var CodeGenerator = class {
|
|
|
2857
2958
|
this.local = new Local(this);
|
|
2858
2959
|
this.i32 = new I32(this);
|
|
2859
2960
|
this.f32 = new F32(this);
|
|
2961
|
+
this.f64 = new F64(this);
|
|
2860
2962
|
this.v128 = new V128(this);
|
|
2861
2963
|
this.i32x4 = new I32x4(this);
|
|
2862
2964
|
this.f32x4 = new F32x4(this);
|
|
@@ -3243,6 +3345,8 @@ var I32 = class {
|
|
|
3243
3345
|
ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
|
|
3244
3346
|
trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
|
|
3245
3347
|
trunc_f32_u = UNARY_OP("trunc_f32_u", 169, "f32", "i32");
|
|
3348
|
+
trunc_f64_s = UNARY_OP("trunc_f64_s", 170, "f64", "i32");
|
|
3349
|
+
trunc_f64_u = UNARY_OP("trunc_f64_u", 171, "f64", "i32");
|
|
3246
3350
|
load = LOAD_OP("load", 40, "i32");
|
|
3247
3351
|
load8_s = LOAD_OP("load8_s", 44, "i32");
|
|
3248
3352
|
load8_u = LOAD_OP("load8_u", 45, "i32");
|
|
@@ -3254,6 +3358,8 @@ var I32 = class {
|
|
|
3254
3358
|
reinterpret_f32 = UNARY_OP("reinterpret_f32", 188, "f32", "i32");
|
|
3255
3359
|
trunc_sat_f32_s = UNARY_OP("trunc_sat_f32_s", [252, 0], "f32", "i32");
|
|
3256
3360
|
trunc_sat_f32_u = UNARY_OP("trunc_sat_f32_u", [252, 1], "f32", "i32");
|
|
3361
|
+
trunc_sat_f64_s = UNARY_OP("trunc_sat_f64_s", [252, 2], "f64", "i32");
|
|
3362
|
+
trunc_sat_f64_u = UNARY_OP("trunc_sat_f64_u", [252, 3], "f64", "i32");
|
|
3257
3363
|
};
|
|
3258
3364
|
var F32 = class {
|
|
3259
3365
|
constructor(cg) {
|
|
@@ -3273,6 +3379,8 @@ var F32 = class {
|
|
|
3273
3379
|
for (let i = 0; i < 4; i++) this.cg._emit(bytes[i]);
|
|
3274
3380
|
this.cg._push(this);
|
|
3275
3381
|
}
|
|
3382
|
+
load = LOAD_OP("load", 42, "f32");
|
|
3383
|
+
store = STORE_OP("store", 56, "f32");
|
|
3276
3384
|
eq = BINARY_OP("eq", 91, "f32", "f32", "i32");
|
|
3277
3385
|
ne = BINARY_OP("ne", 92, "f32", "f32", "i32");
|
|
3278
3386
|
lt = BINARY_OP("lt", 93, "f32", "f32", "i32");
|
|
@@ -3295,10 +3403,53 @@ var F32 = class {
|
|
|
3295
3403
|
copysign = BINARY_OP("copysign", 152, "f32", "f32", "f32");
|
|
3296
3404
|
convert_i32_s = UNARY_OP("convert_i32_s", 178, "i32", "f32");
|
|
3297
3405
|
convert_i32_u = UNARY_OP("convert_i32_u", 179, "i32", "f32");
|
|
3298
|
-
|
|
3299
|
-
store = STORE_OP("store", 56, "f32");
|
|
3406
|
+
demote_f64 = UNARY_OP("demote_f64", 182, "f64", "f32");
|
|
3300
3407
|
reinterpret_i32 = UNARY_OP("reinterpret_i32", 190, "i32", "f32");
|
|
3301
3408
|
};
|
|
3409
|
+
var F64 = class {
|
|
3410
|
+
constructor(cg) {
|
|
3411
|
+
this.cg = cg;
|
|
3412
|
+
}
|
|
3413
|
+
get typeId() {
|
|
3414
|
+
return 124;
|
|
3415
|
+
}
|
|
3416
|
+
get name() {
|
|
3417
|
+
return "f64";
|
|
3418
|
+
}
|
|
3419
|
+
const(f) {
|
|
3420
|
+
this.cg._emit(68);
|
|
3421
|
+
const buffer = /* @__PURE__ */ new ArrayBuffer(8);
|
|
3422
|
+
new DataView(buffer).setFloat64(0, f, true);
|
|
3423
|
+
const bytes = new Uint8Array(buffer);
|
|
3424
|
+
for (let i = 0; i < 8; i++) this.cg._emit(bytes[i]);
|
|
3425
|
+
this.cg._push(this);
|
|
3426
|
+
}
|
|
3427
|
+
load = LOAD_OP("load", 43, "f64");
|
|
3428
|
+
store = STORE_OP("store", 57, "f64");
|
|
3429
|
+
eq = BINARY_OP("eq", 97, "f64", "f64", "i32");
|
|
3430
|
+
ne = BINARY_OP("ne", 98, "f64", "f64", "i32");
|
|
3431
|
+
lt = BINARY_OP("lt", 99, "f64", "f64", "i32");
|
|
3432
|
+
gt = BINARY_OP("gt", 100, "f64", "f64", "i32");
|
|
3433
|
+
le = BINARY_OP("le", 101, "f64", "f64", "i32");
|
|
3434
|
+
ge = BINARY_OP("ge", 102, "f64", "f64", "i32");
|
|
3435
|
+
abs = UNARY_OP("abs", 153, "f64", "f64");
|
|
3436
|
+
neg = UNARY_OP("neg", 154, "f64", "f64");
|
|
3437
|
+
ceil = UNARY_OP("ceil", 155, "f64", "f64");
|
|
3438
|
+
floor = UNARY_OP("floor", 156, "f64", "f64");
|
|
3439
|
+
trunc = UNARY_OP("trunc", 157, "f64", "f64");
|
|
3440
|
+
nearest = UNARY_OP("nearest", 158, "f64", "f64");
|
|
3441
|
+
sqrt = UNARY_OP("sqrt", 159, "f64", "f64");
|
|
3442
|
+
add = BINARY_OP("add", 160, "f64", "f64", "f64");
|
|
3443
|
+
sub = BINARY_OP("sub", 161, "f64", "f64", "f64");
|
|
3444
|
+
mul = BINARY_OP("mul", 162, "f64", "f64", "f64");
|
|
3445
|
+
div = BINARY_OP("div", 163, "f64", "f64", "f64");
|
|
3446
|
+
min = BINARY_OP("min", 164, "f64", "f64", "f64");
|
|
3447
|
+
max = BINARY_OP("max", 165, "f64", "f64", "f64");
|
|
3448
|
+
copysign = BINARY_OP("copysign", 166, "f64", "f64", "f64");
|
|
3449
|
+
convert_i32_s = UNARY_OP("convert_i32_s", 183, "i32", "f64");
|
|
3450
|
+
convert_i32_u = UNARY_OP("convert_i32_u", 184, "i32", "f64");
|
|
3451
|
+
promote_f32 = UNARY_OP("promote_f32", 187, "f32", "f64");
|
|
3452
|
+
};
|
|
3302
3453
|
function VECTOR_OP(op, vopcode, inTypes, outType) {
|
|
3303
3454
|
return function() {
|
|
3304
3455
|
for (const inType of inTypes.toReversed()) {
|
|
@@ -3502,14 +3653,16 @@ function codegenWasm(kernel) {
|
|
|
3502
3653
|
if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
3503
3654
|
const cg = new CodeGenerator();
|
|
3504
3655
|
cg.memory.import("env", "memory");
|
|
3505
|
-
const distinctOps =
|
|
3656
|
+
const distinctOps = mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
|
|
3506
3657
|
const funcs = {};
|
|
3507
3658
|
if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
|
|
3508
3659
|
if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
|
|
3509
3660
|
if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
|
|
3510
3661
|
if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
|
|
3511
|
-
if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
|
|
3662
|
+
if (distinctOps.has(AluOp.Exp) || distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) funcs.exp = wasm_exp(cg);
|
|
3512
3663
|
if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
|
|
3664
|
+
if (distinctOps.has(AluOp.Erf)) funcs.erf = wasm_erf(cg, funcs.exp);
|
|
3665
|
+
if (distinctOps.has(AluOp.Erfc)) funcs.erfc = wasm_erfc(cg, funcs.exp);
|
|
3513
3666
|
if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
|
|
3514
3667
|
const kernelFunc = cg.function(rep(kernel.nargs + 1, cg.i32), [], () => {
|
|
3515
3668
|
const gidx = cg.local.declare(cg.i32);
|
|
@@ -3549,10 +3702,10 @@ function codegenWasm(kernel) {
|
|
|
3549
3702
|
cg.local.get(acc);
|
|
3550
3703
|
if (re.dtype === DType.Bool) cg.i32.and();
|
|
3551
3704
|
else dty(cg, re.op, re.dtype).mul();
|
|
3552
|
-
} else if (re.op === AluOp.Min || re.op === AluOp.Max) if (re.dtype
|
|
3705
|
+
} else if (re.op === AluOp.Min || re.op === AluOp.Max) if (isFloatDtype(re.dtype)) {
|
|
3553
3706
|
cg.local.get(acc);
|
|
3554
|
-
if (re.op === AluOp.Min) cg.
|
|
3555
|
-
else cg.
|
|
3707
|
+
if (re.op === AluOp.Min) dtyF(cg, re.op, re.dtype).min();
|
|
3708
|
+
else dtyF(cg, re.op, re.dtype).max();
|
|
3556
3709
|
} else if ([
|
|
3557
3710
|
DType.Int32,
|
|
3558
3711
|
DType.Uint32,
|
|
@@ -3614,27 +3767,30 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3614
3767
|
else if (op === AluOp.Sub) dty(cg, op, dtype).sub();
|
|
3615
3768
|
else if (op === AluOp.Mul) if (dtype === DType.Bool) cg.i32.and();
|
|
3616
3769
|
else dty(cg, op, dtype).mul();
|
|
3617
|
-
else if (op === AluOp.Idiv) if (dtype
|
|
3618
|
-
|
|
3770
|
+
else if (op === AluOp.Idiv) if (isFloatDtype(dtype)) {
|
|
3771
|
+
dtyF(cg, op, dtype).div();
|
|
3772
|
+
dtyF(cg, op, dtype).trunc();
|
|
3773
|
+
} else if (dtype === DType.Uint32) cg.i32.div_u();
|
|
3619
3774
|
else if (dtype === DType.Int32) cg.i32.div_s();
|
|
3620
3775
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3621
|
-
else if (op === AluOp.Mod) if (dtype
|
|
3622
|
-
const
|
|
3623
|
-
const
|
|
3776
|
+
else if (op === AluOp.Mod) if (isFloatDtype(dtype)) {
|
|
3777
|
+
const dt = dtyF(cg, op, dtype);
|
|
3778
|
+
const a = cg.local.declare(dt);
|
|
3779
|
+
const b = cg.local.declare(dt);
|
|
3624
3780
|
cg.local.set(b);
|
|
3625
3781
|
cg.local.tee(a);
|
|
3626
3782
|
cg.local.get(a);
|
|
3627
3783
|
cg.local.get(b);
|
|
3628
|
-
|
|
3629
|
-
|
|
3784
|
+
dt.div();
|
|
3785
|
+
dt.trunc();
|
|
3630
3786
|
cg.local.get(b);
|
|
3631
|
-
|
|
3632
|
-
|
|
3787
|
+
dt.mul();
|
|
3788
|
+
dt.sub();
|
|
3633
3789
|
} else if (dtype === DType.Uint32) cg.i32.rem_u();
|
|
3634
3790
|
else if (dtype === DType.Int32) cg.i32.rem_s();
|
|
3635
3791
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3636
|
-
else if (op === AluOp.Min || op === AluOp.Max) if (dtype
|
|
3637
|
-
else cg.
|
|
3792
|
+
else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
|
|
3793
|
+
else dtyF(cg, op, dtype).max();
|
|
3638
3794
|
else if (dtype === DType.Int32 || dtype === DType.Uint32) {
|
|
3639
3795
|
const a = cg.local.declare(cg.i32);
|
|
3640
3796
|
const b = cg.local.declare(cg.i32);
|
|
@@ -3651,52 +3807,74 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3651
3807
|
} else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3652
3808
|
else if (op === AluOp.Cmplt) {
|
|
3653
3809
|
const srcDtype = src[0].dtype;
|
|
3654
|
-
if (srcDtype
|
|
3810
|
+
if (isFloatDtype(srcDtype)) dtyF(cg, op, srcDtype).lt();
|
|
3655
3811
|
else if (srcDtype === DType.Int32) cg.i32.lt_s();
|
|
3656
3812
|
else if (srcDtype === DType.Uint32) cg.i32.lt_u();
|
|
3657
3813
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3658
3814
|
} else if (op === AluOp.Cmpne) dty(cg, op, src[0].dtype).ne();
|
|
3659
3815
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3660
|
-
} else if (AluGroup.Unary.has(op))
|
|
3661
|
-
|
|
3662
|
-
|
|
3663
|
-
|
|
3664
|
-
|
|
3665
|
-
|
|
3666
|
-
|
|
3667
|
-
|
|
3668
|
-
|
|
3669
|
-
gen(src[0]);
|
|
3670
|
-
|
|
3671
|
-
|
|
3672
|
-
if (
|
|
3673
|
-
else if (
|
|
3674
|
-
else
|
|
3675
|
-
else if (
|
|
3676
|
-
else if (
|
|
3677
|
-
|
|
3678
|
-
|
|
3679
|
-
else if (
|
|
3680
|
-
|
|
3681
|
-
|
|
3682
|
-
|
|
3683
|
-
|
|
3684
|
-
|
|
3685
|
-
|
|
3686
|
-
|
|
3687
|
-
|
|
3688
|
-
|
|
3689
|
-
|
|
3690
|
-
|
|
3691
|
-
|
|
3692
|
-
|
|
3693
|
-
|
|
3694
|
-
|
|
3695
|
-
|
|
3696
|
-
|
|
3697
|
-
|
|
3698
|
-
|
|
3699
|
-
|
|
3816
|
+
} else if (AluGroup.Unary.has(op)) {
|
|
3817
|
+
const callFuncF32 = (func) => {
|
|
3818
|
+
if (dtype !== DType.Float32) if (dtype === DType.Float64) cg.f32.demote_f64();
|
|
3819
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3820
|
+
cg.call(func);
|
|
3821
|
+
if (dtype === DType.Float64) cg.f64.promote_f32();
|
|
3822
|
+
};
|
|
3823
|
+
if (op === AluOp.Sin) gen(src[0]), callFuncF32(funcs.sin);
|
|
3824
|
+
else if (op === AluOp.Cos) gen(src[0]), callFuncF32(funcs.cos);
|
|
3825
|
+
else if (op === AluOp.Asin) gen(src[0]), callFuncF32(funcs.asin);
|
|
3826
|
+
else if (op === AluOp.Atan) gen(src[0]), callFuncF32(funcs.atan);
|
|
3827
|
+
else if (op === AluOp.Exp) gen(src[0]), callFuncF32(funcs.exp);
|
|
3828
|
+
else if (op === AluOp.Log) gen(src[0]), callFuncF32(funcs.log);
|
|
3829
|
+
else if (op === AluOp.Erf) gen(src[0]), callFuncF32(funcs.erf);
|
|
3830
|
+
else if (op === AluOp.Erfc) gen(src[0]), callFuncF32(funcs.erfc);
|
|
3831
|
+
else if (op === AluOp.Sqrt) gen(src[0]), dtyF(cg, op, dtype).sqrt();
|
|
3832
|
+
else if (op === AluOp.Reciprocal) {
|
|
3833
|
+
const dt = dtyF(cg, op, dtype);
|
|
3834
|
+
dt.const(1), gen(src[0]), dt.div();
|
|
3835
|
+
} else if (op === AluOp.Cast) {
|
|
3836
|
+
gen(src[0]);
|
|
3837
|
+
const dtype0 = src[0].dtype;
|
|
3838
|
+
const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
|
|
3839
|
+
if (dtype === DType.Int32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_s();
|
|
3840
|
+
else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_s();
|
|
3841
|
+
else if (i32repr);
|
|
3842
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3843
|
+
else if (dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_u();
|
|
3844
|
+
else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_u();
|
|
3845
|
+
else if (i32repr);
|
|
3846
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3847
|
+
else if (dtype === DType.Float32) if (dtype0 === DType.Float32);
|
|
3848
|
+
else if (dtype0 === DType.Float64) cg.f32.demote_f64();
|
|
3849
|
+
else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32.convert_i32_s();
|
|
3850
|
+
else if (dtype0 === DType.Uint32) cg.f32.convert_i32_u();
|
|
3851
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3852
|
+
else if (dtype === DType.Float64) if (dtype0 === DType.Float32) cg.f64.promote_f32();
|
|
3853
|
+
else if (dtype0 === DType.Float64);
|
|
3854
|
+
else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f64.convert_i32_s();
|
|
3855
|
+
else if (dtype0 === DType.Uint32) cg.f64.convert_i32_u();
|
|
3856
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3857
|
+
else if (dtype === DType.Bool) if (dtype0 === DType.Bool);
|
|
3858
|
+
else if (i32repr) cg.i32.const(0), cg.i32.ne();
|
|
3859
|
+
else if (dtype0 === DType.Float32) cg.f32.const(0), cg.f32.ne();
|
|
3860
|
+
else if (dtype0 === DType.Float64) cg.f64.const(0), cg.f64.ne();
|
|
3861
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3862
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3863
|
+
} else if (op === AluOp.Bitcast) {
|
|
3864
|
+
gen(src[0]);
|
|
3865
|
+
const dtype0 = src[0].dtype;
|
|
3866
|
+
if (dtype !== dtype0) {
|
|
3867
|
+
const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32;
|
|
3868
|
+
if (dtype === DType.Int32 || dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.reinterpret_f32();
|
|
3869
|
+
else if (i32repr);
|
|
3870
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3871
|
+
else if (dtype === DType.Float32) if (i32repr) cg.f32.reinterpret_i32();
|
|
3872
|
+
else if (dtype0 === DType.Float32);
|
|
3873
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3874
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3875
|
+
}
|
|
3876
|
+
} else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3877
|
+
} else if (op === AluOp.Where) {
|
|
3700
3878
|
gen(src[1]);
|
|
3701
3879
|
gen(src[2]);
|
|
3702
3880
|
gen(src[0]);
|
|
@@ -3741,12 +3919,20 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3741
3919
|
function dty(cg, op, dtype) {
|
|
3742
3920
|
switch (dtype) {
|
|
3743
3921
|
case DType.Float32: return cg.f32;
|
|
3922
|
+
case DType.Float64: return cg.f64;
|
|
3744
3923
|
case DType.Int32:
|
|
3745
3924
|
case DType.Uint32:
|
|
3746
3925
|
case DType.Bool: return cg.i32;
|
|
3747
3926
|
default: throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3748
3927
|
}
|
|
3749
3928
|
}
|
|
3929
|
+
function dtyF(cg, op, dtype) {
|
|
3930
|
+
switch (dtype) {
|
|
3931
|
+
case DType.Float32: return cg.f32;
|
|
3932
|
+
case DType.Float64: return cg.f64;
|
|
3933
|
+
default: throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3934
|
+
}
|
|
3935
|
+
}
|
|
3750
3936
|
|
|
3751
3937
|
//#endregion
|
|
3752
3938
|
//#region src/backend.ts
|
|
@@ -3755,10 +3941,10 @@ const devices = [
|
|
|
3755
3941
|
"wasm",
|
|
3756
3942
|
"webgpu"
|
|
3757
3943
|
];
|
|
3758
|
-
let defaultBackend = "wasm";
|
|
3759
3944
|
const initializedBackends = /* @__PURE__ */ new Map();
|
|
3760
3945
|
initializedBackends.set("cpu", new CpuBackend());
|
|
3761
|
-
initializedBackends.set("wasm", new WasmBackend());
|
|
3946
|
+
if (typeof WebAssembly !== "undefined") initializedBackends.set("wasm", new WasmBackend());
|
|
3947
|
+
let defaultBackend = initializedBackends.has("wasm") ? "wasm" : "cpu";
|
|
3762
3948
|
/** Configure the default device for arrays. */
|
|
3763
3949
|
function defaultDevice(device) {
|
|
3764
3950
|
if (device !== void 0) if (initializedBackends.has(device)) defaultBackend = device;
|
|
@@ -3785,12 +3971,14 @@ async function init(...devicesToInit) {
|
|
|
3785
3971
|
/** Create a backend, if available. Internal function called by `init()`. */
|
|
3786
3972
|
async function createBackend(device) {
|
|
3787
3973
|
if (device === "cpu") return new CpuBackend();
|
|
3788
|
-
else if (device === "wasm")
|
|
3789
|
-
|
|
3974
|
+
else if (device === "wasm") {
|
|
3975
|
+
if (typeof WebAssembly === "undefined") return null;
|
|
3976
|
+
return new WasmBackend();
|
|
3977
|
+
} else if (device === "webgpu") {
|
|
3790
3978
|
if (!navigator.gpu) return null;
|
|
3791
3979
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
3792
3980
|
if (!adapter) return null;
|
|
3793
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
3981
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-DGYNVHma.cjs"));
|
|
3794
3982
|
const importantLimits = [
|
|
3795
3983
|
"maxBufferSize",
|
|
3796
3984
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -3939,6 +4127,12 @@ Object.defineProperty(exports, 'accessorGlobal', {
|
|
|
3939
4127
|
return accessorGlobal;
|
|
3940
4128
|
}
|
|
3941
4129
|
});
|
|
4130
|
+
Object.defineProperty(exports, 'assertNonNull', {
|
|
4131
|
+
enumerable: true,
|
|
4132
|
+
get: function () {
|
|
4133
|
+
return assertNonNull;
|
|
4134
|
+
}
|
|
4135
|
+
});
|
|
3942
4136
|
Object.defineProperty(exports, 'byteWidth', {
|
|
3943
4137
|
enumerable: true,
|
|
3944
4138
|
get: function () {
|
|
@@ -4029,6 +4223,12 @@ Object.defineProperty(exports, 'isPermutation', {
|
|
|
4029
4223
|
return isPermutation;
|
|
4030
4224
|
}
|
|
4031
4225
|
});
|
|
4226
|
+
Object.defineProperty(exports, 'mapSetUnion', {
|
|
4227
|
+
enumerable: true,
|
|
4228
|
+
get: function () {
|
|
4229
|
+
return mapSetUnion;
|
|
4230
|
+
}
|
|
4231
|
+
});
|
|
4032
4232
|
Object.defineProperty(exports, 'normalizeAxis', {
|
|
4033
4233
|
enumerable: true,
|
|
4034
4234
|
get: function () {
|
|
@@ -4101,12 +4301,6 @@ Object.defineProperty(exports, 'tuneWebgpu', {
|
|
|
4101
4301
|
return tuneWebgpu;
|
|
4102
4302
|
}
|
|
4103
4303
|
});
|
|
4104
|
-
Object.defineProperty(exports, 'union', {
|
|
4105
|
-
enumerable: true,
|
|
4106
|
-
get: function () {
|
|
4107
|
-
return union;
|
|
4108
|
-
}
|
|
4109
|
-
});
|
|
4110
4304
|
Object.defineProperty(exports, 'unravelAlu', {
|
|
4111
4305
|
enumerable: true,
|
|
4112
4306
|
get: function () {
|
|
@@ -4130,4 +4324,5 @@ Object.defineProperty(exports, 'zipn', {
|
|
|
4130
4324
|
get: function () {
|
|
4131
4325
|
return zipn;
|
|
4132
4326
|
}
|
|
4133
|
-
});
|
|
4327
|
+
});
|
|
4328
|
+
//# sourceMappingURL=backend-BbrKEB18.cjs.map
|