@jax-js/jax 0.0.5 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +267 -92
- package/dist/{backend-yEU0L_ig.cjs → backend-BbrKEB18.cjs} +378 -183
- package/dist/{backend-CdcTZEOF.js → backend-CoVtc9dx.js} +366 -177
- package/dist/index.cjs +385 -74
- package/dist/index.d.cts +115 -23
- package/dist/index.d.ts +115 -23
- package/dist/index.js +378 -74
- package/dist/{webgpu-CM-xNYzW.js → webgpu-B3UVme6n.js} +188 -153
- package/dist/{webgpu-CNOpiO5T.cjs → webgpu-DGYNVHma.cjs} +188 -153
- package/package.json +25 -15
|
@@ -51,6 +51,7 @@ let DEBUG = 0;
|
|
|
51
51
|
function setDebug(level) {
|
|
52
52
|
DEBUG = level;
|
|
53
53
|
}
|
|
54
|
+
function assertNonNull(value) {}
|
|
54
55
|
function unzip2(pairs) {
|
|
55
56
|
const lst1 = [];
|
|
56
57
|
const lst2 = [];
|
|
@@ -96,10 +97,15 @@ function deepEqual(a, b) {
|
|
|
96
97
|
for (const key of Object.keys(a)) if (!deepEqual(a[key], b[key])) return false;
|
|
97
98
|
return true;
|
|
98
99
|
}
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
100
|
+
/** Produces a union of maps of sets. This mutates `a`. */
|
|
101
|
+
function mapSetUnion(a, b) {
|
|
102
|
+
if (!b) return a;
|
|
103
|
+
for (const [key, setB] of b.entries()) {
|
|
104
|
+
const setA = a.get(key);
|
|
105
|
+
if (setA) for (const val of setB) setA.add(val);
|
|
106
|
+
else a.set(key, setB);
|
|
107
|
+
}
|
|
108
|
+
return a;
|
|
103
109
|
}
|
|
104
110
|
/** Splits the list based on a condition, `false` first then `true`. */
|
|
105
111
|
function partitionList(which, array) {
|
|
@@ -300,6 +306,7 @@ let DType = /* @__PURE__ */ function(DType$1) {
|
|
|
300
306
|
DType$1["Uint32"] = "uint32";
|
|
301
307
|
DType$1["Bool"] = "bool";
|
|
302
308
|
DType$1["Float16"] = "float16";
|
|
309
|
+
DType$1["Float64"] = "float64";
|
|
303
310
|
return DType$1;
|
|
304
311
|
}({});
|
|
305
312
|
const byteWidth = (dtype) => {
|
|
@@ -309,10 +316,11 @@ const byteWidth = (dtype) => {
|
|
|
309
316
|
case DType.Uint32:
|
|
310
317
|
case DType.Bool: return 4;
|
|
311
318
|
case DType.Float16: return 2;
|
|
319
|
+
case DType.Float64: return 8;
|
|
312
320
|
default: throw new TypeError(`Unknown dtype: ${dtype}`);
|
|
313
321
|
}
|
|
314
322
|
};
|
|
315
|
-
const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16;
|
|
323
|
+
const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16 || dtype === DType.Float64;
|
|
316
324
|
/**
|
|
317
325
|
* Promote two dtypes to their join according to the type lattice.
|
|
318
326
|
*
|
|
@@ -322,7 +330,7 @@ const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float
|
|
|
322
330
|
*
|
|
323
331
|
* **Type lattice:**
|
|
324
332
|
* ```text
|
|
325
|
-
* bool -> uint32 -> int32 -> float16 -> float32
|
|
333
|
+
* bool -> uint32 -> int32 -> float16 -> float32 -> float64
|
|
326
334
|
* weakType --^
|
|
327
335
|
* ```
|
|
328
336
|
*
|
|
@@ -344,7 +352,8 @@ function promoteTypes(dtype1, dtype2) {
|
|
|
344
352
|
[DType.Uint32]: 1,
|
|
345
353
|
[DType.Int32]: 2,
|
|
346
354
|
[DType.Float16]: 3,
|
|
347
|
-
[DType.Float32]: 4
|
|
355
|
+
[DType.Float32]: 4,
|
|
356
|
+
[DType.Float64]: 5
|
|
348
357
|
};
|
|
349
358
|
return rank[dtype1] > rank[dtype2] ? dtype1 : dtype2;
|
|
350
359
|
}
|
|
@@ -357,6 +366,7 @@ function dtypedArray(dtype, data) {
|
|
|
357
366
|
case DType.Bool: return new Int32Array(buffer, byteOffset, length);
|
|
358
367
|
case DType.Uint32: return new Uint32Array(buffer, byteOffset, length);
|
|
359
368
|
case DType.Float16: return new Float16Array(buffer, byteOffset, length);
|
|
369
|
+
case DType.Float64: return new Float64Array(buffer, byteOffset, length);
|
|
360
370
|
default: throw new Error(`Unimplemented dtype: ${dtype}`);
|
|
361
371
|
}
|
|
362
372
|
}
|
|
@@ -367,6 +377,7 @@ function dtypedJsArray(dtype, data) {
|
|
|
367
377
|
case DType.Bool: return new Int32Array(data);
|
|
368
378
|
case DType.Uint32: return new Uint32Array(data);
|
|
369
379
|
case DType.Float16: return new Float16Array(data);
|
|
380
|
+
case DType.Float64: return new Float64Array(data);
|
|
370
381
|
default: throw new Error(`Unimplemented dtype: ${dtype}`);
|
|
371
382
|
}
|
|
372
383
|
}
|
|
@@ -429,6 +440,12 @@ var AluExp = class AluExp {
|
|
|
429
440
|
static log(a) {
|
|
430
441
|
return new AluExp(AluOp.Log, a.dtype, [a]);
|
|
431
442
|
}
|
|
443
|
+
static erf(a) {
|
|
444
|
+
return new AluExp(AluOp.Erf, a.dtype, [a]);
|
|
445
|
+
}
|
|
446
|
+
static erfc(a) {
|
|
447
|
+
return new AluExp(AluOp.Erfc, a.dtype, [a]);
|
|
448
|
+
}
|
|
432
449
|
static sqrt(a) {
|
|
433
450
|
return new AluExp(AluOp.Sqrt, a.dtype, [a]);
|
|
434
451
|
}
|
|
@@ -498,6 +515,9 @@ var AluExp = class AluExp {
|
|
|
498
515
|
static f16(value) {
|
|
499
516
|
return AluExp.const(DType.Float16, value);
|
|
500
517
|
}
|
|
518
|
+
static f64(value) {
|
|
519
|
+
return AluExp.const(DType.Float64, value);
|
|
520
|
+
}
|
|
501
521
|
not() {
|
|
502
522
|
if (this.dtype !== DType.Bool) throw new Error("not() can only be called on boolean expressions");
|
|
503
523
|
return AluExp.cmpne(this, AluExp.const(DType.Bool, true));
|
|
@@ -508,7 +528,8 @@ var AluExp = class AluExp {
|
|
|
508
528
|
const hasher = new FpHash();
|
|
509
529
|
hasher.update(this.op);
|
|
510
530
|
hasher.update(this.dtype);
|
|
511
|
-
hasher.update(
|
|
531
|
+
if (this.op === AluOp.Const) hasher.update(this.arg);
|
|
532
|
+
else hasher.update(JSON.stringify(this.arg));
|
|
512
533
|
hasher.update(this.src.length);
|
|
513
534
|
for (const s of this.src) hasher.update(s);
|
|
514
535
|
this.#hash = hasher.value;
|
|
@@ -599,6 +620,12 @@ var AluExp = class AluExp {
|
|
|
599
620
|
case AluOp.Log:
|
|
600
621
|
ret = [Math.log(src[0].min), Math.log(src[0].max)];
|
|
601
622
|
break;
|
|
623
|
+
case AluOp.Erf:
|
|
624
|
+
ret = [erf(src[0].min), erf(src[0].max)];
|
|
625
|
+
break;
|
|
626
|
+
case AluOp.Erfc:
|
|
627
|
+
ret = [erfc(src[0].max), erfc(src[0].min)];
|
|
628
|
+
break;
|
|
602
629
|
case AluOp.Sqrt:
|
|
603
630
|
ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
|
|
604
631
|
break;
|
|
@@ -739,6 +766,7 @@ var AluExp = class AluExp {
|
|
|
739
766
|
if (op === AluOp.Mul && x === 1) return src[1 - i];
|
|
740
767
|
if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
|
|
741
768
|
if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
|
|
769
|
+
if (op === AluOp.Cmpne && src[i].dtype === DType.Bool && x === 0) return src[1 - i];
|
|
742
770
|
}
|
|
743
771
|
if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
|
|
744
772
|
const [a, b] = src[1].src;
|
|
@@ -920,6 +948,8 @@ var AluExp = class AluExp {
|
|
|
920
948
|
case AluOp.Atan: return Math.atan(x);
|
|
921
949
|
case AluOp.Exp: return Math.exp(x);
|
|
922
950
|
case AluOp.Log: return Math.log(x);
|
|
951
|
+
case AluOp.Erf: return erf(x);
|
|
952
|
+
case AluOp.Erfc: return erfc(x);
|
|
923
953
|
case AluOp.Sqrt: return Math.sqrt(x);
|
|
924
954
|
case AluOp.Reciprocal: return 1 / x;
|
|
925
955
|
case AluOp.Cast: {
|
|
@@ -938,11 +968,13 @@ var AluExp = class AluExp {
|
|
|
938
968
|
else if (fromType === DType.Int32) view.setInt32(0, x, true);
|
|
939
969
|
else if (fromType === DType.Uint32) view.setUint32(0, x, true);
|
|
940
970
|
else if (fromType === DType.Float16) view.setFloat16(0, x, true);
|
|
971
|
+
else if (fromType === DType.Float64) view.setFloat64(0, x, true);
|
|
941
972
|
else throw new Error(`Unsupported bitcast from ${fromType}`);
|
|
942
973
|
if (this.dtype === DType.Float32) return view.getFloat32(0, true);
|
|
943
974
|
else if (this.dtype === DType.Int32) return view.getInt32(0, true);
|
|
944
975
|
else if (this.dtype === DType.Uint32) return view.getUint32(0, true);
|
|
945
976
|
else if (this.dtype === DType.Float16) return view.getFloat16(0, true);
|
|
977
|
+
else if (this.dtype === DType.Float64) return view.getFloat64(0, true);
|
|
946
978
|
else throw new Error(`Unsupported bitcast to ${this.dtype}`);
|
|
947
979
|
}
|
|
948
980
|
default: throw new Error(`Missing implemementation for ${this.op}`);
|
|
@@ -1068,11 +1100,15 @@ var AluExp = class AluExp {
|
|
|
1068
1100
|
});
|
|
1069
1101
|
return result;
|
|
1070
1102
|
}
|
|
1071
|
-
/** Produce
|
|
1103
|
+
/** Produce all distinct AluOp in this expression, with their dtypes. */
|
|
1072
1104
|
distinctOps() {
|
|
1073
|
-
const ops = /* @__PURE__ */ new
|
|
1105
|
+
const ops = /* @__PURE__ */ new Map();
|
|
1074
1106
|
this.fold((exp) => {
|
|
1075
|
-
ops.
|
|
1107
|
+
const s = ops.get(exp.op) ?? /* @__PURE__ */ new Set();
|
|
1108
|
+
if (!s.has(exp.dtype)) {
|
|
1109
|
+
s.add(exp.dtype);
|
|
1110
|
+
ops.set(exp.op, s);
|
|
1111
|
+
}
|
|
1076
1112
|
});
|
|
1077
1113
|
return ops;
|
|
1078
1114
|
}
|
|
@@ -1101,6 +1137,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
1101
1137
|
AluOp$1["Atan"] = "Atan";
|
|
1102
1138
|
AluOp$1["Exp"] = "Exp";
|
|
1103
1139
|
AluOp$1["Log"] = "Log";
|
|
1140
|
+
AluOp$1["Erf"] = "Erf";
|
|
1141
|
+
AluOp$1["Erfc"] = "Erfc";
|
|
1104
1142
|
AluOp$1["Sqrt"] = "Sqrt";
|
|
1105
1143
|
AluOp$1["Reciprocal"] = "Reciprocal";
|
|
1106
1144
|
AluOp$1["Cast"] = "Cast";
|
|
@@ -1133,6 +1171,8 @@ const AluGroup = {
|
|
|
1133
1171
|
AluOp.Atan,
|
|
1134
1172
|
AluOp.Exp,
|
|
1135
1173
|
AluOp.Log,
|
|
1174
|
+
AluOp.Erf,
|
|
1175
|
+
AluOp.Erfc,
|
|
1136
1176
|
AluOp.Sqrt,
|
|
1137
1177
|
AluOp.Reciprocal,
|
|
1138
1178
|
AluOp.Cast,
|
|
@@ -1158,6 +1198,8 @@ const AluGroup = {
|
|
|
1158
1198
|
AluOp.Atan,
|
|
1159
1199
|
AluOp.Exp,
|
|
1160
1200
|
AluOp.Log,
|
|
1201
|
+
AluOp.Erf,
|
|
1202
|
+
AluOp.Erfc,
|
|
1161
1203
|
AluOp.Sqrt,
|
|
1162
1204
|
AluOp.Reciprocal
|
|
1163
1205
|
])
|
|
@@ -1333,6 +1375,44 @@ function threefry2x32(k0, k1, c0, c1) {
|
|
|
1333
1375
|
x1 = x1 + ks0 + 5 >>> 0;
|
|
1334
1376
|
return [x0, x1];
|
|
1335
1377
|
}
|
|
1378
|
+
/**
|
|
1379
|
+
* Abramowitz & Stegun’s widely used approximation for erf(x).
|
|
1380
|
+
*
|
|
1381
|
+
* `erf(x) = 1 - P(t) * exp(-x^2)` for `x >= 0`, where `t = 1/(1 + p*x)` and
|
|
1382
|
+
* `P(t) = a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5`.
|
|
1383
|
+
*
|
|
1384
|
+
* Coefficients:
|
|
1385
|
+
* - p = 0.3275911
|
|
1386
|
+
* - a1 = 0.254829592
|
|
1387
|
+
* - a2 = -0.284496736
|
|
1388
|
+
* - a3 = 1.421413741
|
|
1389
|
+
* - a4 = -1.453152027
|
|
1390
|
+
* - a5 = 1.061405429
|
|
1391
|
+
*
|
|
1392
|
+
* This function computes just `E = P(t) * exp(-x^2)` for numerical reasons. The
|
|
1393
|
+
* input is assumed to be non-negative.
|
|
1394
|
+
*
|
|
1395
|
+
* Reference: https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions
|
|
1396
|
+
*/
|
|
1397
|
+
function _erfapprox$1(x) {
|
|
1398
|
+
const p = .3275911;
|
|
1399
|
+
const a1 = .254829592;
|
|
1400
|
+
const a2 = -.284496736;
|
|
1401
|
+
const a3 = 1.421413741;
|
|
1402
|
+
const a4 = -1.453152027;
|
|
1403
|
+
const a5 = 1.061405429;
|
|
1404
|
+
const t = 1 / (1 + p * x);
|
|
1405
|
+
const P_t = ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t;
|
|
1406
|
+
return P_t * Math.exp(-x * x);
|
|
1407
|
+
}
|
|
1408
|
+
function erf(x) {
|
|
1409
|
+
if (x >= 0) return 1 - _erfapprox$1(x);
|
|
1410
|
+
else return _erfapprox$1(-x) - 1;
|
|
1411
|
+
}
|
|
1412
|
+
function erfc(x) {
|
|
1413
|
+
if (x >= 0) return _erfapprox$1(x);
|
|
1414
|
+
else return 2 - _erfapprox$1(-x);
|
|
1415
|
+
}
|
|
1336
1416
|
|
|
1337
1417
|
//#endregion
|
|
1338
1418
|
//#region src/shape.ts
|
|
@@ -2018,7 +2098,7 @@ function tuneWebgpu(kernel) {
|
|
|
2018
2098
|
}
|
|
2019
2099
|
if (/chrome/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
|
|
2020
2100
|
const s = dim.st.shape[dim.unroll - 1];
|
|
2021
|
-
if (s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
2101
|
+
if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
2022
2102
|
else for (const splits of [4]) if (s % splits === 0) {
|
|
2023
2103
|
dim.applyUnroll(dim.unroll - 1, splits);
|
|
2024
2104
|
break;
|
|
@@ -2237,6 +2317,19 @@ var WasmAllocator = class {
|
|
|
2237
2317
|
|
|
2238
2318
|
//#endregion
|
|
2239
2319
|
//#region src/backend/wasm/builtins.ts
|
|
2320
|
+
/** Given a local `x`, evaluate `sum[i](a_i * x^i)` and push to stack. */
|
|
2321
|
+
function _poly(cg, x, as) {
|
|
2322
|
+
if (as.length === 0) throw new Error("_poly needs at least one coefficient");
|
|
2323
|
+
cg.f32.const(as[as.length - 1]);
|
|
2324
|
+
for (let i = as.length - 2; i >= 0; i--) {
|
|
2325
|
+
cg.local.get(x);
|
|
2326
|
+
cg.f32.mul();
|
|
2327
|
+
if (as[i] !== 0) {
|
|
2328
|
+
cg.f32.const(as[i]);
|
|
2329
|
+
cg.f32.add();
|
|
2330
|
+
}
|
|
2331
|
+
}
|
|
2332
|
+
}
|
|
2240
2333
|
/**
|
|
2241
2334
|
* Approximate e^x.
|
|
2242
2335
|
*
|
|
@@ -2277,27 +2370,15 @@ function wasm_exp(cg) {
|
|
|
2277
2370
|
cg.f32.mul();
|
|
2278
2371
|
cg.f32.sub();
|
|
2279
2372
|
cg.local.set(r);
|
|
2280
|
-
cg
|
|
2281
|
-
|
|
2282
|
-
|
|
2283
|
-
|
|
2284
|
-
|
|
2285
|
-
|
|
2286
|
-
|
|
2287
|
-
|
|
2288
|
-
|
|
2289
|
-
cg.local.get(r);
|
|
2290
|
-
cg.f32.mul();
|
|
2291
|
-
cg.f32.const(1 / 2);
|
|
2292
|
-
cg.f32.add();
|
|
2293
|
-
cg.local.get(r);
|
|
2294
|
-
cg.f32.mul();
|
|
2295
|
-
cg.f32.const(1);
|
|
2296
|
-
cg.f32.add();
|
|
2297
|
-
cg.local.get(r);
|
|
2298
|
-
cg.f32.mul();
|
|
2299
|
-
cg.f32.const(1);
|
|
2300
|
-
cg.f32.add();
|
|
2373
|
+
_poly(cg, r, [
|
|
2374
|
+
1,
|
|
2375
|
+
1,
|
|
2376
|
+
1 / 2,
|
|
2377
|
+
1 / 6,
|
|
2378
|
+
1 / 24,
|
|
2379
|
+
1 / 120,
|
|
2380
|
+
1 / 720
|
|
2381
|
+
]);
|
|
2301
2382
|
cg.local.set(p);
|
|
2302
2383
|
cg.local.get(k);
|
|
2303
2384
|
cg.i32.const(127);
|
|
@@ -2325,11 +2406,6 @@ function wasm_log(cg) {
|
|
|
2325
2406
|
const m = cg.local.declare(cg.f32);
|
|
2326
2407
|
const t = cg.local.declare(cg.f32);
|
|
2327
2408
|
const t2 = cg.local.declare(cg.f32);
|
|
2328
|
-
const t3 = cg.local.declare(cg.f32);
|
|
2329
|
-
const t5 = cg.local.declare(cg.f32);
|
|
2330
|
-
const t7 = cg.local.declare(cg.f32);
|
|
2331
|
-
const lnm = cg.local.declare(cg.f32);
|
|
2332
|
-
const el2 = cg.local.declare(cg.f32);
|
|
2333
2409
|
cg.local.get(0);
|
|
2334
2410
|
cg.f32.const(0);
|
|
2335
2411
|
cg.f32.le();
|
|
@@ -2366,41 +2442,18 @@ function wasm_log(cg) {
|
|
|
2366
2442
|
cg.local.get(t);
|
|
2367
2443
|
cg.f32.mul();
|
|
2368
2444
|
cg.local.set(t2);
|
|
2445
|
+
_poly(cg, t2, [
|
|
2446
|
+
2,
|
|
2447
|
+
2 / 3,
|
|
2448
|
+
2 / 5,
|
|
2449
|
+
2 / 7
|
|
2450
|
+
]);
|
|
2369
2451
|
cg.local.get(t);
|
|
2370
|
-
cg.local.get(t2);
|
|
2371
|
-
cg.f32.mul();
|
|
2372
|
-
cg.local.set(t3);
|
|
2373
|
-
cg.local.get(t3);
|
|
2374
|
-
cg.local.get(t2);
|
|
2375
|
-
cg.f32.mul();
|
|
2376
|
-
cg.local.set(t5);
|
|
2377
|
-
cg.local.get(t5);
|
|
2378
|
-
cg.local.get(t2);
|
|
2379
|
-
cg.f32.mul();
|
|
2380
|
-
cg.local.set(t7);
|
|
2381
|
-
cg.local.get(t7);
|
|
2382
|
-
cg.f32.const(1 / 7);
|
|
2383
|
-
cg.f32.mul();
|
|
2384
|
-
cg.local.get(t5);
|
|
2385
|
-
cg.f32.const(1 / 5);
|
|
2386
|
-
cg.f32.mul();
|
|
2387
|
-
cg.f32.add();
|
|
2388
|
-
cg.local.get(t3);
|
|
2389
|
-
cg.f32.const(1 / 3);
|
|
2390
|
-
cg.f32.mul();
|
|
2391
|
-
cg.f32.add();
|
|
2392
|
-
cg.local.get(t);
|
|
2393
|
-
cg.f32.add();
|
|
2394
|
-
cg.f32.const(2);
|
|
2395
2452
|
cg.f32.mul();
|
|
2396
|
-
cg.local.set(lnm);
|
|
2397
2453
|
cg.local.get(e);
|
|
2398
2454
|
cg.f32.convert_i32_s();
|
|
2399
2455
|
cg.f32.const(Math.LN2);
|
|
2400
2456
|
cg.f32.mul();
|
|
2401
|
-
cg.local.set(el2);
|
|
2402
|
-
cg.local.get(el2);
|
|
2403
|
-
cg.local.get(lnm);
|
|
2404
2457
|
cg.f32.add();
|
|
2405
2458
|
});
|
|
2406
2459
|
}
|
|
@@ -2410,7 +2463,7 @@ function wasm_log(cg) {
|
|
|
2410
2463
|
* Method: reduce to y in [-π, π], then quadrant via q = round(y/(π/2))
|
|
2411
2464
|
* z = y - q*(π/2); use one of two polynomials on z:
|
|
2412
2465
|
* sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
|
|
2413
|
-
* cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
|
|
2466
|
+
* cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720) + z^8*(1/40320)
|
|
2414
2467
|
*/
|
|
2415
2468
|
function _sincos(cg) {
|
|
2416
2469
|
const y = cg.local.declare(cg.f32);
|
|
@@ -2446,35 +2499,22 @@ function _sincos(cg) {
|
|
|
2446
2499
|
cg.local.get(z);
|
|
2447
2500
|
cg.f32.mul();
|
|
2448
2501
|
cg.local.set(z2);
|
|
2449
|
-
cg
|
|
2450
|
-
|
|
2451
|
-
|
|
2452
|
-
|
|
2453
|
-
|
|
2454
|
-
|
|
2455
|
-
cg.f32.mul();
|
|
2456
|
-
cg.f32.const(-1 / 6);
|
|
2457
|
-
cg.f32.add();
|
|
2458
|
-
cg.local.get(z2);
|
|
2459
|
-
cg.f32.mul();
|
|
2460
|
-
cg.f32.const(1);
|
|
2461
|
-
cg.f32.add();
|
|
2502
|
+
_poly(cg, z2, [
|
|
2503
|
+
1,
|
|
2504
|
+
-1 / 6,
|
|
2505
|
+
1 / 120,
|
|
2506
|
+
-1 / 5040
|
|
2507
|
+
]);
|
|
2462
2508
|
cg.local.get(z);
|
|
2463
2509
|
cg.f32.mul();
|
|
2464
2510
|
cg.local.set(sz);
|
|
2465
|
-
cg
|
|
2466
|
-
|
|
2467
|
-
|
|
2468
|
-
|
|
2469
|
-
|
|
2470
|
-
|
|
2471
|
-
|
|
2472
|
-
cg.f32.const(-1 / 2);
|
|
2473
|
-
cg.f32.add();
|
|
2474
|
-
cg.local.get(z2);
|
|
2475
|
-
cg.f32.mul();
|
|
2476
|
-
cg.f32.const(1);
|
|
2477
|
-
cg.f32.add();
|
|
2511
|
+
_poly(cg, z2, [
|
|
2512
|
+
1,
|
|
2513
|
+
-1 / 2,
|
|
2514
|
+
1 / 24,
|
|
2515
|
+
-1 / 720,
|
|
2516
|
+
1 / 40320
|
|
2517
|
+
]);
|
|
2478
2518
|
cg.local.set(cz);
|
|
2479
2519
|
return {
|
|
2480
2520
|
q,
|
|
@@ -2556,24 +2596,16 @@ function _atan(cg) {
|
|
|
2556
2596
|
cg.local.get(z);
|
|
2557
2597
|
cg.f32.mul();
|
|
2558
2598
|
cg.local.set(z2);
|
|
2559
|
-
cg
|
|
2560
|
-
|
|
2561
|
-
|
|
2562
|
-
|
|
2563
|
-
|
|
2564
|
-
cg
|
|
2565
|
-
|
|
2566
|
-
|
|
2567
|
-
|
|
2568
|
-
|
|
2569
|
-
cg.local.get(z2);
|
|
2570
|
-
cg.f32.mul();
|
|
2571
|
-
cg.f32.const(.994987933645);
|
|
2572
|
-
cg.f32.add();
|
|
2573
|
-
cg.local.get(z2);
|
|
2574
|
-
cg.f32.mul();
|
|
2575
|
-
cg.f32.const(1);
|
|
2576
|
-
cg.f32.add();
|
|
2599
|
+
_poly(cg, z2, [
|
|
2600
|
+
.999998614341,
|
|
2601
|
+
.661705427875,
|
|
2602
|
+
.0415796528637
|
|
2603
|
+
]);
|
|
2604
|
+
_poly(cg, z2, [
|
|
2605
|
+
1,
|
|
2606
|
+
.994987933645,
|
|
2607
|
+
.173698870181
|
|
2608
|
+
]);
|
|
2577
2609
|
cg.f32.div();
|
|
2578
2610
|
cg.local.get(z);
|
|
2579
2611
|
cg.f32.mul();
|
|
@@ -2627,6 +2659,74 @@ function wasm_asin(cg) {
|
|
|
2627
2659
|
});
|
|
2628
2660
|
}
|
|
2629
2661
|
/**
|
|
2662
|
+
* Helper function for erf/erfc approximation.
|
|
2663
|
+
*
|
|
2664
|
+
* See `_erfapprox` in alu.ts for details on the algorithm used.
|
|
2665
|
+
*/
|
|
2666
|
+
function _erfapprox(cg, exp_func) {
|
|
2667
|
+
const x = cg.local.declare(cg.f32);
|
|
2668
|
+
const t = cg.local.declare(cg.f32);
|
|
2669
|
+
cg.local.set(x);
|
|
2670
|
+
const p = .3275911;
|
|
2671
|
+
const a1 = .254829592;
|
|
2672
|
+
const a2 = -.284496736;
|
|
2673
|
+
const a3 = 1.421413741;
|
|
2674
|
+
const a4 = -1.453152027;
|
|
2675
|
+
const a5 = 1.061405429;
|
|
2676
|
+
cg.f32.const(1);
|
|
2677
|
+
cg.f32.const(1);
|
|
2678
|
+
cg.f32.const(p);
|
|
2679
|
+
cg.local.get(x);
|
|
2680
|
+
cg.f32.mul();
|
|
2681
|
+
cg.f32.add();
|
|
2682
|
+
cg.f32.div();
|
|
2683
|
+
cg.local.set(t);
|
|
2684
|
+
_poly(cg, t, [
|
|
2685
|
+
0,
|
|
2686
|
+
a1,
|
|
2687
|
+
a2,
|
|
2688
|
+
a3,
|
|
2689
|
+
a4,
|
|
2690
|
+
a5
|
|
2691
|
+
]);
|
|
2692
|
+
cg.local.get(x);
|
|
2693
|
+
cg.f32.neg();
|
|
2694
|
+
cg.local.get(x);
|
|
2695
|
+
cg.f32.mul();
|
|
2696
|
+
cg.call(exp_func);
|
|
2697
|
+
cg.f32.mul();
|
|
2698
|
+
}
|
|
2699
|
+
/** Approximate erf(x) (error function). */
|
|
2700
|
+
function wasm_erf(cg, exp) {
|
|
2701
|
+
return cg.function([cg.f32], [cg.f32], () => {
|
|
2702
|
+
cg.f32.const(1);
|
|
2703
|
+
cg.local.get(0);
|
|
2704
|
+
cg.f32.abs();
|
|
2705
|
+
_erfapprox(cg, exp);
|
|
2706
|
+
cg.f32.sub();
|
|
2707
|
+
cg.local.get(0);
|
|
2708
|
+
cg.f32.copysign();
|
|
2709
|
+
});
|
|
2710
|
+
}
|
|
2711
|
+
/** Approximate erfc(x) (complementary error function). */
|
|
2712
|
+
function wasm_erfc(cg, exp) {
|
|
2713
|
+
return cg.function([cg.f32], [cg.f32], () => {
|
|
2714
|
+
const e = cg.local.declare(cg.f32);
|
|
2715
|
+
cg.local.get(0);
|
|
2716
|
+
cg.f32.abs();
|
|
2717
|
+
_erfapprox(cg, exp);
|
|
2718
|
+
cg.local.set(e);
|
|
2719
|
+
cg.f32.const(2);
|
|
2720
|
+
cg.local.get(e);
|
|
2721
|
+
cg.f32.sub();
|
|
2722
|
+
cg.local.get(e);
|
|
2723
|
+
cg.local.get(0);
|
|
2724
|
+
cg.f32.const(0);
|
|
2725
|
+
cg.f32.lt();
|
|
2726
|
+
cg.select();
|
|
2727
|
+
});
|
|
2728
|
+
}
|
|
2729
|
+
/**
|
|
2630
2730
|
* Threefry2x32 pseudorandom number generator.
|
|
2631
2731
|
*
|
|
2632
2732
|
* Takes two 32-bit keys and two 32-bit counters as input,
|
|
@@ -2837,6 +2937,7 @@ var CodeGenerator = class {
|
|
|
2837
2937
|
local;
|
|
2838
2938
|
i32;
|
|
2839
2939
|
f32;
|
|
2940
|
+
f64;
|
|
2840
2941
|
v128;
|
|
2841
2942
|
i32x4;
|
|
2842
2943
|
f32x4;
|
|
@@ -2856,6 +2957,7 @@ var CodeGenerator = class {
|
|
|
2856
2957
|
this.local = new Local(this);
|
|
2857
2958
|
this.i32 = new I32(this);
|
|
2858
2959
|
this.f32 = new F32(this);
|
|
2960
|
+
this.f64 = new F64(this);
|
|
2859
2961
|
this.v128 = new V128(this);
|
|
2860
2962
|
this.i32x4 = new I32x4(this);
|
|
2861
2963
|
this.f32x4 = new F32x4(this);
|
|
@@ -3242,6 +3344,8 @@ var I32 = class {
|
|
|
3242
3344
|
ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
|
|
3243
3345
|
trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
|
|
3244
3346
|
trunc_f32_u = UNARY_OP("trunc_f32_u", 169, "f32", "i32");
|
|
3347
|
+
trunc_f64_s = UNARY_OP("trunc_f64_s", 170, "f64", "i32");
|
|
3348
|
+
trunc_f64_u = UNARY_OP("trunc_f64_u", 171, "f64", "i32");
|
|
3245
3349
|
load = LOAD_OP("load", 40, "i32");
|
|
3246
3350
|
load8_s = LOAD_OP("load8_s", 44, "i32");
|
|
3247
3351
|
load8_u = LOAD_OP("load8_u", 45, "i32");
|
|
@@ -3253,6 +3357,8 @@ var I32 = class {
|
|
|
3253
3357
|
reinterpret_f32 = UNARY_OP("reinterpret_f32", 188, "f32", "i32");
|
|
3254
3358
|
trunc_sat_f32_s = UNARY_OP("trunc_sat_f32_s", [252, 0], "f32", "i32");
|
|
3255
3359
|
trunc_sat_f32_u = UNARY_OP("trunc_sat_f32_u", [252, 1], "f32", "i32");
|
|
3360
|
+
trunc_sat_f64_s = UNARY_OP("trunc_sat_f64_s", [252, 2], "f64", "i32");
|
|
3361
|
+
trunc_sat_f64_u = UNARY_OP("trunc_sat_f64_u", [252, 3], "f64", "i32");
|
|
3256
3362
|
};
|
|
3257
3363
|
var F32 = class {
|
|
3258
3364
|
constructor(cg) {
|
|
@@ -3272,6 +3378,8 @@ var F32 = class {
|
|
|
3272
3378
|
for (let i = 0; i < 4; i++) this.cg._emit(bytes[i]);
|
|
3273
3379
|
this.cg._push(this);
|
|
3274
3380
|
}
|
|
3381
|
+
load = LOAD_OP("load", 42, "f32");
|
|
3382
|
+
store = STORE_OP("store", 56, "f32");
|
|
3275
3383
|
eq = BINARY_OP("eq", 91, "f32", "f32", "i32");
|
|
3276
3384
|
ne = BINARY_OP("ne", 92, "f32", "f32", "i32");
|
|
3277
3385
|
lt = BINARY_OP("lt", 93, "f32", "f32", "i32");
|
|
@@ -3294,10 +3402,53 @@ var F32 = class {
|
|
|
3294
3402
|
copysign = BINARY_OP("copysign", 152, "f32", "f32", "f32");
|
|
3295
3403
|
convert_i32_s = UNARY_OP("convert_i32_s", 178, "i32", "f32");
|
|
3296
3404
|
convert_i32_u = UNARY_OP("convert_i32_u", 179, "i32", "f32");
|
|
3297
|
-
|
|
3298
|
-
store = STORE_OP("store", 56, "f32");
|
|
3405
|
+
demote_f64 = UNARY_OP("demote_f64", 182, "f64", "f32");
|
|
3299
3406
|
reinterpret_i32 = UNARY_OP("reinterpret_i32", 190, "i32", "f32");
|
|
3300
3407
|
};
|
|
3408
|
+
var F64 = class {
|
|
3409
|
+
constructor(cg) {
|
|
3410
|
+
this.cg = cg;
|
|
3411
|
+
}
|
|
3412
|
+
get typeId() {
|
|
3413
|
+
return 124;
|
|
3414
|
+
}
|
|
3415
|
+
get name() {
|
|
3416
|
+
return "f64";
|
|
3417
|
+
}
|
|
3418
|
+
const(f) {
|
|
3419
|
+
this.cg._emit(68);
|
|
3420
|
+
const buffer = /* @__PURE__ */ new ArrayBuffer(8);
|
|
3421
|
+
new DataView(buffer).setFloat64(0, f, true);
|
|
3422
|
+
const bytes = new Uint8Array(buffer);
|
|
3423
|
+
for (let i = 0; i < 8; i++) this.cg._emit(bytes[i]);
|
|
3424
|
+
this.cg._push(this);
|
|
3425
|
+
}
|
|
3426
|
+
load = LOAD_OP("load", 43, "f64");
|
|
3427
|
+
store = STORE_OP("store", 57, "f64");
|
|
3428
|
+
eq = BINARY_OP("eq", 97, "f64", "f64", "i32");
|
|
3429
|
+
ne = BINARY_OP("ne", 98, "f64", "f64", "i32");
|
|
3430
|
+
lt = BINARY_OP("lt", 99, "f64", "f64", "i32");
|
|
3431
|
+
gt = BINARY_OP("gt", 100, "f64", "f64", "i32");
|
|
3432
|
+
le = BINARY_OP("le", 101, "f64", "f64", "i32");
|
|
3433
|
+
ge = BINARY_OP("ge", 102, "f64", "f64", "i32");
|
|
3434
|
+
abs = UNARY_OP("abs", 153, "f64", "f64");
|
|
3435
|
+
neg = UNARY_OP("neg", 154, "f64", "f64");
|
|
3436
|
+
ceil = UNARY_OP("ceil", 155, "f64", "f64");
|
|
3437
|
+
floor = UNARY_OP("floor", 156, "f64", "f64");
|
|
3438
|
+
trunc = UNARY_OP("trunc", 157, "f64", "f64");
|
|
3439
|
+
nearest = UNARY_OP("nearest", 158, "f64", "f64");
|
|
3440
|
+
sqrt = UNARY_OP("sqrt", 159, "f64", "f64");
|
|
3441
|
+
add = BINARY_OP("add", 160, "f64", "f64", "f64");
|
|
3442
|
+
sub = BINARY_OP("sub", 161, "f64", "f64", "f64");
|
|
3443
|
+
mul = BINARY_OP("mul", 162, "f64", "f64", "f64");
|
|
3444
|
+
div = BINARY_OP("div", 163, "f64", "f64", "f64");
|
|
3445
|
+
min = BINARY_OP("min", 164, "f64", "f64", "f64");
|
|
3446
|
+
max = BINARY_OP("max", 165, "f64", "f64", "f64");
|
|
3447
|
+
copysign = BINARY_OP("copysign", 166, "f64", "f64", "f64");
|
|
3448
|
+
convert_i32_s = UNARY_OP("convert_i32_s", 183, "i32", "f64");
|
|
3449
|
+
convert_i32_u = UNARY_OP("convert_i32_u", 184, "i32", "f64");
|
|
3450
|
+
promote_f32 = UNARY_OP("promote_f32", 187, "f32", "f64");
|
|
3451
|
+
};
|
|
3301
3452
|
function VECTOR_OP(op, vopcode, inTypes, outType) {
|
|
3302
3453
|
return function() {
|
|
3303
3454
|
for (const inType of inTypes.toReversed()) {
|
|
@@ -3501,14 +3652,16 @@ function codegenWasm(kernel) {
|
|
|
3501
3652
|
if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
3502
3653
|
const cg = new CodeGenerator();
|
|
3503
3654
|
cg.memory.import("env", "memory");
|
|
3504
|
-
const distinctOps =
|
|
3655
|
+
const distinctOps = mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
|
|
3505
3656
|
const funcs = {};
|
|
3506
3657
|
if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
|
|
3507
3658
|
if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
|
|
3508
3659
|
if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
|
|
3509
3660
|
if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
|
|
3510
|
-
if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
|
|
3661
|
+
if (distinctOps.has(AluOp.Exp) || distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) funcs.exp = wasm_exp(cg);
|
|
3511
3662
|
if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
|
|
3663
|
+
if (distinctOps.has(AluOp.Erf)) funcs.erf = wasm_erf(cg, funcs.exp);
|
|
3664
|
+
if (distinctOps.has(AluOp.Erfc)) funcs.erfc = wasm_erfc(cg, funcs.exp);
|
|
3512
3665
|
if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
|
|
3513
3666
|
const kernelFunc = cg.function(rep(kernel.nargs + 1, cg.i32), [], () => {
|
|
3514
3667
|
const gidx = cg.local.declare(cg.i32);
|
|
@@ -3548,10 +3701,10 @@ function codegenWasm(kernel) {
|
|
|
3548
3701
|
cg.local.get(acc);
|
|
3549
3702
|
if (re.dtype === DType.Bool) cg.i32.and();
|
|
3550
3703
|
else dty(cg, re.op, re.dtype).mul();
|
|
3551
|
-
} else if (re.op === AluOp.Min || re.op === AluOp.Max) if (re.dtype
|
|
3704
|
+
} else if (re.op === AluOp.Min || re.op === AluOp.Max) if (isFloatDtype(re.dtype)) {
|
|
3552
3705
|
cg.local.get(acc);
|
|
3553
|
-
if (re.op === AluOp.Min) cg.
|
|
3554
|
-
else cg.
|
|
3706
|
+
if (re.op === AluOp.Min) dtyF(cg, re.op, re.dtype).min();
|
|
3707
|
+
else dtyF(cg, re.op, re.dtype).max();
|
|
3555
3708
|
} else if ([
|
|
3556
3709
|
DType.Int32,
|
|
3557
3710
|
DType.Uint32,
|
|
@@ -3613,27 +3766,30 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3613
3766
|
else if (op === AluOp.Sub) dty(cg, op, dtype).sub();
|
|
3614
3767
|
else if (op === AluOp.Mul) if (dtype === DType.Bool) cg.i32.and();
|
|
3615
3768
|
else dty(cg, op, dtype).mul();
|
|
3616
|
-
else if (op === AluOp.Idiv) if (dtype
|
|
3617
|
-
|
|
3769
|
+
else if (op === AluOp.Idiv) if (isFloatDtype(dtype)) {
|
|
3770
|
+
dtyF(cg, op, dtype).div();
|
|
3771
|
+
dtyF(cg, op, dtype).trunc();
|
|
3772
|
+
} else if (dtype === DType.Uint32) cg.i32.div_u();
|
|
3618
3773
|
else if (dtype === DType.Int32) cg.i32.div_s();
|
|
3619
3774
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3620
|
-
else if (op === AluOp.Mod) if (dtype
|
|
3621
|
-
const
|
|
3622
|
-
const
|
|
3775
|
+
else if (op === AluOp.Mod) if (isFloatDtype(dtype)) {
|
|
3776
|
+
const dt = dtyF(cg, op, dtype);
|
|
3777
|
+
const a = cg.local.declare(dt);
|
|
3778
|
+
const b = cg.local.declare(dt);
|
|
3623
3779
|
cg.local.set(b);
|
|
3624
3780
|
cg.local.tee(a);
|
|
3625
3781
|
cg.local.get(a);
|
|
3626
3782
|
cg.local.get(b);
|
|
3627
|
-
|
|
3628
|
-
|
|
3783
|
+
dt.div();
|
|
3784
|
+
dt.trunc();
|
|
3629
3785
|
cg.local.get(b);
|
|
3630
|
-
|
|
3631
|
-
|
|
3786
|
+
dt.mul();
|
|
3787
|
+
dt.sub();
|
|
3632
3788
|
} else if (dtype === DType.Uint32) cg.i32.rem_u();
|
|
3633
3789
|
else if (dtype === DType.Int32) cg.i32.rem_s();
|
|
3634
3790
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3635
|
-
else if (op === AluOp.Min || op === AluOp.Max) if (dtype
|
|
3636
|
-
else cg.
|
|
3791
|
+
else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
|
|
3792
|
+
else dtyF(cg, op, dtype).max();
|
|
3637
3793
|
else if (dtype === DType.Int32 || dtype === DType.Uint32) {
|
|
3638
3794
|
const a = cg.local.declare(cg.i32);
|
|
3639
3795
|
const b = cg.local.declare(cg.i32);
|
|
@@ -3650,52 +3806,74 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3650
3806
|
} else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3651
3807
|
else if (op === AluOp.Cmplt) {
|
|
3652
3808
|
const srcDtype = src[0].dtype;
|
|
3653
|
-
if (srcDtype
|
|
3809
|
+
if (isFloatDtype(srcDtype)) dtyF(cg, op, srcDtype).lt();
|
|
3654
3810
|
else if (srcDtype === DType.Int32) cg.i32.lt_s();
|
|
3655
3811
|
else if (srcDtype === DType.Uint32) cg.i32.lt_u();
|
|
3656
3812
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3657
3813
|
} else if (op === AluOp.Cmpne) dty(cg, op, src[0].dtype).ne();
|
|
3658
3814
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3659
|
-
} else if (AluGroup.Unary.has(op))
|
|
3660
|
-
|
|
3661
|
-
|
|
3662
|
-
|
|
3663
|
-
|
|
3664
|
-
|
|
3665
|
-
|
|
3666
|
-
|
|
3667
|
-
|
|
3668
|
-
gen(src[0]);
|
|
3669
|
-
|
|
3670
|
-
|
|
3671
|
-
if (
|
|
3672
|
-
else if (
|
|
3673
|
-
else
|
|
3674
|
-
else if (
|
|
3675
|
-
else if (
|
|
3676
|
-
|
|
3677
|
-
|
|
3678
|
-
else if (
|
|
3679
|
-
|
|
3680
|
-
|
|
3681
|
-
|
|
3682
|
-
|
|
3683
|
-
|
|
3684
|
-
|
|
3685
|
-
|
|
3686
|
-
|
|
3687
|
-
|
|
3688
|
-
|
|
3689
|
-
|
|
3690
|
-
|
|
3691
|
-
|
|
3692
|
-
|
|
3693
|
-
|
|
3694
|
-
|
|
3695
|
-
|
|
3696
|
-
|
|
3697
|
-
|
|
3698
|
-
|
|
3815
|
+
} else if (AluGroup.Unary.has(op)) {
|
|
3816
|
+
const callFuncF32 = (func) => {
|
|
3817
|
+
if (dtype !== DType.Float32) if (dtype === DType.Float64) cg.f32.demote_f64();
|
|
3818
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3819
|
+
cg.call(func);
|
|
3820
|
+
if (dtype === DType.Float64) cg.f64.promote_f32();
|
|
3821
|
+
};
|
|
3822
|
+
if (op === AluOp.Sin) gen(src[0]), callFuncF32(funcs.sin);
|
|
3823
|
+
else if (op === AluOp.Cos) gen(src[0]), callFuncF32(funcs.cos);
|
|
3824
|
+
else if (op === AluOp.Asin) gen(src[0]), callFuncF32(funcs.asin);
|
|
3825
|
+
else if (op === AluOp.Atan) gen(src[0]), callFuncF32(funcs.atan);
|
|
3826
|
+
else if (op === AluOp.Exp) gen(src[0]), callFuncF32(funcs.exp);
|
|
3827
|
+
else if (op === AluOp.Log) gen(src[0]), callFuncF32(funcs.log);
|
|
3828
|
+
else if (op === AluOp.Erf) gen(src[0]), callFuncF32(funcs.erf);
|
|
3829
|
+
else if (op === AluOp.Erfc) gen(src[0]), callFuncF32(funcs.erfc);
|
|
3830
|
+
else if (op === AluOp.Sqrt) gen(src[0]), dtyF(cg, op, dtype).sqrt();
|
|
3831
|
+
else if (op === AluOp.Reciprocal) {
|
|
3832
|
+
const dt = dtyF(cg, op, dtype);
|
|
3833
|
+
dt.const(1), gen(src[0]), dt.div();
|
|
3834
|
+
} else if (op === AluOp.Cast) {
|
|
3835
|
+
gen(src[0]);
|
|
3836
|
+
const dtype0 = src[0].dtype;
|
|
3837
|
+
const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
|
|
3838
|
+
if (dtype === DType.Int32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_s();
|
|
3839
|
+
else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_s();
|
|
3840
|
+
else if (i32repr);
|
|
3841
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3842
|
+
else if (dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_u();
|
|
3843
|
+
else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_u();
|
|
3844
|
+
else if (i32repr);
|
|
3845
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3846
|
+
else if (dtype === DType.Float32) if (dtype0 === DType.Float32);
|
|
3847
|
+
else if (dtype0 === DType.Float64) cg.f32.demote_f64();
|
|
3848
|
+
else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32.convert_i32_s();
|
|
3849
|
+
else if (dtype0 === DType.Uint32) cg.f32.convert_i32_u();
|
|
3850
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3851
|
+
else if (dtype === DType.Float64) if (dtype0 === DType.Float32) cg.f64.promote_f32();
|
|
3852
|
+
else if (dtype0 === DType.Float64);
|
|
3853
|
+
else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f64.convert_i32_s();
|
|
3854
|
+
else if (dtype0 === DType.Uint32) cg.f64.convert_i32_u();
|
|
3855
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3856
|
+
else if (dtype === DType.Bool) if (dtype0 === DType.Bool);
|
|
3857
|
+
else if (i32repr) cg.i32.const(0), cg.i32.ne();
|
|
3858
|
+
else if (dtype0 === DType.Float32) cg.f32.const(0), cg.f32.ne();
|
|
3859
|
+
else if (dtype0 === DType.Float64) cg.f64.const(0), cg.f64.ne();
|
|
3860
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3861
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3862
|
+
} else if (op === AluOp.Bitcast) {
|
|
3863
|
+
gen(src[0]);
|
|
3864
|
+
const dtype0 = src[0].dtype;
|
|
3865
|
+
if (dtype !== dtype0) {
|
|
3866
|
+
const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32;
|
|
3867
|
+
if (dtype === DType.Int32 || dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.reinterpret_f32();
|
|
3868
|
+
else if (i32repr);
|
|
3869
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3870
|
+
else if (dtype === DType.Float32) if (i32repr) cg.f32.reinterpret_i32();
|
|
3871
|
+
else if (dtype0 === DType.Float32);
|
|
3872
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3873
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3874
|
+
}
|
|
3875
|
+
} else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3876
|
+
} else if (op === AluOp.Where) {
|
|
3699
3877
|
gen(src[1]);
|
|
3700
3878
|
gen(src[2]);
|
|
3701
3879
|
gen(src[0]);
|
|
@@ -3740,12 +3918,20 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3740
3918
|
function dty(cg, op, dtype) {
|
|
3741
3919
|
switch (dtype) {
|
|
3742
3920
|
case DType.Float32: return cg.f32;
|
|
3921
|
+
case DType.Float64: return cg.f64;
|
|
3743
3922
|
case DType.Int32:
|
|
3744
3923
|
case DType.Uint32:
|
|
3745
3924
|
case DType.Bool: return cg.i32;
|
|
3746
3925
|
default: throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3747
3926
|
}
|
|
3748
3927
|
}
|
|
3928
|
+
function dtyF(cg, op, dtype) {
|
|
3929
|
+
switch (dtype) {
|
|
3930
|
+
case DType.Float32: return cg.f32;
|
|
3931
|
+
case DType.Float64: return cg.f64;
|
|
3932
|
+
default: throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3933
|
+
}
|
|
3934
|
+
}
|
|
3749
3935
|
|
|
3750
3936
|
//#endregion
|
|
3751
3937
|
//#region src/backend.ts
|
|
@@ -3754,10 +3940,10 @@ const devices = [
|
|
|
3754
3940
|
"wasm",
|
|
3755
3941
|
"webgpu"
|
|
3756
3942
|
];
|
|
3757
|
-
let defaultBackend = "wasm";
|
|
3758
3943
|
const initializedBackends = /* @__PURE__ */ new Map();
|
|
3759
3944
|
initializedBackends.set("cpu", new CpuBackend());
|
|
3760
|
-
initializedBackends.set("wasm", new WasmBackend());
|
|
3945
|
+
if (typeof WebAssembly !== "undefined") initializedBackends.set("wasm", new WasmBackend());
|
|
3946
|
+
let defaultBackend = initializedBackends.has("wasm") ? "wasm" : "cpu";
|
|
3761
3947
|
/** Configure the default device for arrays. */
|
|
3762
3948
|
function defaultDevice(device) {
|
|
3763
3949
|
if (device !== void 0) if (initializedBackends.has(device)) defaultBackend = device;
|
|
@@ -3784,12 +3970,14 @@ async function init(...devicesToInit) {
|
|
|
3784
3970
|
/** Create a backend, if available. Internal function called by `init()`. */
|
|
3785
3971
|
async function createBackend(device) {
|
|
3786
3972
|
if (device === "cpu") return new CpuBackend();
|
|
3787
|
-
else if (device === "wasm")
|
|
3788
|
-
|
|
3973
|
+
else if (device === "wasm") {
|
|
3974
|
+
if (typeof WebAssembly === "undefined") return null;
|
|
3975
|
+
return new WasmBackend();
|
|
3976
|
+
} else if (device === "webgpu") {
|
|
3789
3977
|
if (!navigator.gpu) return null;
|
|
3790
3978
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
3791
3979
|
if (!adapter) return null;
|
|
3792
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
3980
|
+
const { WebGPUBackend } = await import("./webgpu-B3UVme6n.js");
|
|
3793
3981
|
const importantLimits = [
|
|
3794
3982
|
"maxBufferSize",
|
|
3795
3983
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -3842,4 +4030,5 @@ var UnsupportedOpError = class extends Error {
|
|
|
3842
4030
|
};
|
|
3843
4031
|
|
|
3844
4032
|
//#endregion
|
|
3845
|
-
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu,
|
|
4033
|
+
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|
|
4034
|
+
//# sourceMappingURL=backend-CoVtc9dx.js.map
|