@jax-js/jax 0.1.0 → 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/{backend-DwIAd0AG.js → backend-BqymqzuU.js} +194 -73
- package/dist/{backend-FtkbO6pI.cjs → backend-DeVfWEFS.cjs} +194 -73
- package/dist/index.cjs +2725 -2206
- package/dist/index.d.cts +964 -844
- package/dist/index.d.ts +964 -844
- package/dist/index.js +2698 -2179
- package/dist/{webgpu-LGi2A3mS.js → webgpu-BGuG58KZ.js} +20 -13
- package/dist/{webgpu-BE7zA_01.cjs → webgpu-CcGP160M.cjs} +20 -13
- package/package.json +1 -1
|
@@ -290,10 +290,11 @@ var FpHash = class FpHash {
|
|
|
290
290
|
};
|
|
291
291
|
/** Run a function while caching it inline inside a `Map`. */
|
|
292
292
|
function runWithCache(cache, key, thunk) {
|
|
293
|
-
|
|
293
|
+
const keyStr = JSON.stringify(key);
|
|
294
|
+
if (cache.has(keyStr)) return cache.get(keyStr);
|
|
294
295
|
else {
|
|
295
296
|
const value = thunk();
|
|
296
|
-
cache.set(
|
|
297
|
+
cache.set(keyStr, value);
|
|
297
298
|
return value;
|
|
298
299
|
}
|
|
299
300
|
}
|
|
@@ -307,6 +308,7 @@ let DType = /* @__PURE__ */ function(DType$1) {
|
|
|
307
308
|
DType$1["Uint32"] = "uint32";
|
|
308
309
|
DType$1["Bool"] = "bool";
|
|
309
310
|
DType$1["Float16"] = "float16";
|
|
311
|
+
DType$1["Float64"] = "float64";
|
|
310
312
|
return DType$1;
|
|
311
313
|
}({});
|
|
312
314
|
const byteWidth = (dtype) => {
|
|
@@ -316,10 +318,11 @@ const byteWidth = (dtype) => {
|
|
|
316
318
|
case DType.Uint32:
|
|
317
319
|
case DType.Bool: return 4;
|
|
318
320
|
case DType.Float16: return 2;
|
|
321
|
+
case DType.Float64: return 8;
|
|
319
322
|
default: throw new TypeError(`Unknown dtype: ${dtype}`);
|
|
320
323
|
}
|
|
321
324
|
};
|
|
322
|
-
const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16;
|
|
325
|
+
const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16 || dtype === DType.Float64;
|
|
323
326
|
/**
|
|
324
327
|
* Promote two dtypes to their join according to the type lattice.
|
|
325
328
|
*
|
|
@@ -329,7 +332,7 @@ const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float
|
|
|
329
332
|
*
|
|
330
333
|
* **Type lattice:**
|
|
331
334
|
* ```text
|
|
332
|
-
* bool -> uint32 -> int32 -> float16 -> float32
|
|
335
|
+
* bool -> uint32 -> int32 -> float16 -> float32 -> float64
|
|
333
336
|
* weakType --^
|
|
334
337
|
* ```
|
|
335
338
|
*
|
|
@@ -351,7 +354,8 @@ function promoteTypes(dtype1, dtype2) {
|
|
|
351
354
|
[DType.Uint32]: 1,
|
|
352
355
|
[DType.Int32]: 2,
|
|
353
356
|
[DType.Float16]: 3,
|
|
354
|
-
[DType.Float32]: 4
|
|
357
|
+
[DType.Float32]: 4,
|
|
358
|
+
[DType.Float64]: 5
|
|
355
359
|
};
|
|
356
360
|
return rank[dtype1] > rank[dtype2] ? dtype1 : dtype2;
|
|
357
361
|
}
|
|
@@ -364,6 +368,7 @@ function dtypedArray(dtype, data) {
|
|
|
364
368
|
case DType.Bool: return new Int32Array(buffer, byteOffset, length);
|
|
365
369
|
case DType.Uint32: return new Uint32Array(buffer, byteOffset, length);
|
|
366
370
|
case DType.Float16: return new Float16Array(buffer, byteOffset, length);
|
|
371
|
+
case DType.Float64: return new Float64Array(buffer, byteOffset, length);
|
|
367
372
|
default: throw new Error(`Unimplemented dtype: ${dtype}`);
|
|
368
373
|
}
|
|
369
374
|
}
|
|
@@ -374,6 +379,7 @@ function dtypedJsArray(dtype, data) {
|
|
|
374
379
|
case DType.Bool: return new Int32Array(data);
|
|
375
380
|
case DType.Uint32: return new Uint32Array(data);
|
|
376
381
|
case DType.Float16: return new Float16Array(data);
|
|
382
|
+
case DType.Float64: return new Float64Array(data);
|
|
377
383
|
default: throw new Error(`Unimplemented dtype: ${dtype}`);
|
|
378
384
|
}
|
|
379
385
|
}
|
|
@@ -445,6 +451,14 @@ var AluExp = class AluExp {
|
|
|
445
451
|
static sqrt(a) {
|
|
446
452
|
return new AluExp(AluOp.Sqrt, a.dtype, [a]);
|
|
447
453
|
}
|
|
454
|
+
static floor(a) {
|
|
455
|
+
if (!isFloatDtype(a.dtype)) return a;
|
|
456
|
+
return new AluExp(AluOp.Floor, a.dtype, [a]);
|
|
457
|
+
}
|
|
458
|
+
static ceil(a) {
|
|
459
|
+
if (!isFloatDtype(a.dtype)) return a;
|
|
460
|
+
return new AluExp(AluOp.Ceil, a.dtype, [a]);
|
|
461
|
+
}
|
|
448
462
|
static reciprocal(a) {
|
|
449
463
|
return new AluExp(AluOp.Reciprocal, a.dtype, [a]);
|
|
450
464
|
}
|
|
@@ -511,6 +525,9 @@ var AluExp = class AluExp {
|
|
|
511
525
|
static f16(value) {
|
|
512
526
|
return AluExp.const(DType.Float16, value);
|
|
513
527
|
}
|
|
528
|
+
static f64(value) {
|
|
529
|
+
return AluExp.const(DType.Float64, value);
|
|
530
|
+
}
|
|
514
531
|
not() {
|
|
515
532
|
if (this.dtype !== DType.Bool) throw new Error("not() can only be called on boolean expressions");
|
|
516
533
|
return AluExp.cmpne(this, AluExp.const(DType.Bool, true));
|
|
@@ -521,7 +538,8 @@ var AluExp = class AluExp {
|
|
|
521
538
|
const hasher = new FpHash();
|
|
522
539
|
hasher.update(this.op);
|
|
523
540
|
hasher.update(this.dtype);
|
|
524
|
-
hasher.update(
|
|
541
|
+
if (this.op === AluOp.Const) hasher.update(this.arg);
|
|
542
|
+
else hasher.update(JSON.stringify(this.arg));
|
|
525
543
|
hasher.update(this.src.length);
|
|
526
544
|
for (const s of this.src) hasher.update(s);
|
|
527
545
|
this.#hash = hasher.value;
|
|
@@ -621,6 +639,12 @@ var AluExp = class AluExp {
|
|
|
621
639
|
case AluOp.Sqrt:
|
|
622
640
|
ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
|
|
623
641
|
break;
|
|
642
|
+
case AluOp.Floor:
|
|
643
|
+
ret = [Math.floor(src[0].min), Math.floor(src[0].max)];
|
|
644
|
+
break;
|
|
645
|
+
case AluOp.Ceil:
|
|
646
|
+
ret = [Math.ceil(src[0].min), Math.ceil(src[0].max)];
|
|
647
|
+
break;
|
|
624
648
|
case AluOp.Reciprocal:
|
|
625
649
|
if (src[0].min <= 0 && src[0].max >= 0) return [-Infinity, Infinity];
|
|
626
650
|
ret = [1 / src[0].max, 1 / src[0].min];
|
|
@@ -758,6 +782,7 @@ var AluExp = class AluExp {
|
|
|
758
782
|
if (op === AluOp.Mul && x === 1) return src[1 - i];
|
|
759
783
|
if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
|
|
760
784
|
if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
|
|
785
|
+
if (op === AluOp.Cmpne && src[i].dtype === DType.Bool && x === 0) return src[1 - i];
|
|
761
786
|
}
|
|
762
787
|
if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
|
|
763
788
|
const [a, b] = src[1].src;
|
|
@@ -852,7 +877,7 @@ var AluExp = class AluExp {
|
|
|
852
877
|
else return p(p(src[0].src[0], src[1]), src[0].src[1]).simplify(cache);
|
|
853
878
|
if (src[1].op === op && src[1].src[1].op === AluOp.Const) return p(p(src[0], src[1].src[0]), src[1].src[1]).simplify(cache);
|
|
854
879
|
}
|
|
855
|
-
if (op === AluOp.Mod || op === AluOp.Idiv && src[1].#isConstInt()) {
|
|
880
|
+
if ((op === AluOp.Mod || op === AluOp.Idiv) && src[1].#isConstInt()) {
|
|
856
881
|
const [x, y] = src;
|
|
857
882
|
{
|
|
858
883
|
const factors = [];
|
|
@@ -942,6 +967,8 @@ var AluExp = class AluExp {
|
|
|
942
967
|
case AluOp.Erf: return erf(x);
|
|
943
968
|
case AluOp.Erfc: return erfc(x);
|
|
944
969
|
case AluOp.Sqrt: return Math.sqrt(x);
|
|
970
|
+
case AluOp.Floor: return Math.floor(x);
|
|
971
|
+
case AluOp.Ceil: return Math.ceil(x);
|
|
945
972
|
case AluOp.Reciprocal: return 1 / x;
|
|
946
973
|
case AluOp.Cast: {
|
|
947
974
|
const wasFloat = isFloatDtype(this.src[0].dtype);
|
|
@@ -959,11 +986,13 @@ var AluExp = class AluExp {
|
|
|
959
986
|
else if (fromType === DType.Int32) view.setInt32(0, x, true);
|
|
960
987
|
else if (fromType === DType.Uint32) view.setUint32(0, x, true);
|
|
961
988
|
else if (fromType === DType.Float16) view.setFloat16(0, x, true);
|
|
989
|
+
else if (fromType === DType.Float64) view.setFloat64(0, x, true);
|
|
962
990
|
else throw new Error(`Unsupported bitcast from ${fromType}`);
|
|
963
991
|
if (this.dtype === DType.Float32) return view.getFloat32(0, true);
|
|
964
992
|
else if (this.dtype === DType.Int32) return view.getInt32(0, true);
|
|
965
993
|
else if (this.dtype === DType.Uint32) return view.getUint32(0, true);
|
|
966
994
|
else if (this.dtype === DType.Float16) return view.getFloat16(0, true);
|
|
995
|
+
else if (this.dtype === DType.Float64) return view.getFloat64(0, true);
|
|
967
996
|
else throw new Error(`Unsupported bitcast to ${this.dtype}`);
|
|
968
997
|
}
|
|
969
998
|
default: throw new Error(`Missing implemementation for ${this.op}`);
|
|
@@ -1129,6 +1158,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
1129
1158
|
AluOp$1["Erf"] = "Erf";
|
|
1130
1159
|
AluOp$1["Erfc"] = "Erfc";
|
|
1131
1160
|
AluOp$1["Sqrt"] = "Sqrt";
|
|
1161
|
+
AluOp$1["Floor"] = "Floor";
|
|
1162
|
+
AluOp$1["Ceil"] = "Ceil";
|
|
1132
1163
|
AluOp$1["Reciprocal"] = "Reciprocal";
|
|
1133
1164
|
AluOp$1["Cast"] = "Cast";
|
|
1134
1165
|
AluOp$1["Bitcast"] = "Bitcast";
|
|
@@ -1163,6 +1194,8 @@ const AluGroup = {
|
|
|
1163
1194
|
AluOp.Erf,
|
|
1164
1195
|
AluOp.Erfc,
|
|
1165
1196
|
AluOp.Sqrt,
|
|
1197
|
+
AluOp.Floor,
|
|
1198
|
+
AluOp.Ceil,
|
|
1166
1199
|
AluOp.Reciprocal,
|
|
1167
1200
|
AluOp.Cast,
|
|
1168
1201
|
AluOp.Bitcast
|
|
@@ -1190,7 +1223,9 @@ const AluGroup = {
|
|
|
1190
1223
|
AluOp.Erf,
|
|
1191
1224
|
AluOp.Erfc,
|
|
1192
1225
|
AluOp.Sqrt,
|
|
1193
|
-
AluOp.Reciprocal
|
|
1226
|
+
AluOp.Reciprocal,
|
|
1227
|
+
AluOp.Floor,
|
|
1228
|
+
AluOp.Ceil
|
|
1194
1229
|
])
|
|
1195
1230
|
};
|
|
1196
1231
|
/** Common variables that can be substituted in expressions. */
|
|
@@ -2926,6 +2961,7 @@ var CodeGenerator = class {
|
|
|
2926
2961
|
local;
|
|
2927
2962
|
i32;
|
|
2928
2963
|
f32;
|
|
2964
|
+
f64;
|
|
2929
2965
|
v128;
|
|
2930
2966
|
i32x4;
|
|
2931
2967
|
f32x4;
|
|
@@ -2945,6 +2981,7 @@ var CodeGenerator = class {
|
|
|
2945
2981
|
this.local = new Local(this);
|
|
2946
2982
|
this.i32 = new I32(this);
|
|
2947
2983
|
this.f32 = new F32(this);
|
|
2984
|
+
this.f64 = new F64(this);
|
|
2948
2985
|
this.v128 = new V128(this);
|
|
2949
2986
|
this.i32x4 = new I32x4(this);
|
|
2950
2987
|
this.f32x4 = new F32x4(this);
|
|
@@ -3331,6 +3368,8 @@ var I32 = class {
|
|
|
3331
3368
|
ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
|
|
3332
3369
|
trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
|
|
3333
3370
|
trunc_f32_u = UNARY_OP("trunc_f32_u", 169, "f32", "i32");
|
|
3371
|
+
trunc_f64_s = UNARY_OP("trunc_f64_s", 170, "f64", "i32");
|
|
3372
|
+
trunc_f64_u = UNARY_OP("trunc_f64_u", 171, "f64", "i32");
|
|
3334
3373
|
load = LOAD_OP("load", 40, "i32");
|
|
3335
3374
|
load8_s = LOAD_OP("load8_s", 44, "i32");
|
|
3336
3375
|
load8_u = LOAD_OP("load8_u", 45, "i32");
|
|
@@ -3342,6 +3381,8 @@ var I32 = class {
|
|
|
3342
3381
|
reinterpret_f32 = UNARY_OP("reinterpret_f32", 188, "f32", "i32");
|
|
3343
3382
|
trunc_sat_f32_s = UNARY_OP("trunc_sat_f32_s", [252, 0], "f32", "i32");
|
|
3344
3383
|
trunc_sat_f32_u = UNARY_OP("trunc_sat_f32_u", [252, 1], "f32", "i32");
|
|
3384
|
+
trunc_sat_f64_s = UNARY_OP("trunc_sat_f64_s", [252, 2], "f64", "i32");
|
|
3385
|
+
trunc_sat_f64_u = UNARY_OP("trunc_sat_f64_u", [252, 3], "f64", "i32");
|
|
3345
3386
|
};
|
|
3346
3387
|
var F32 = class {
|
|
3347
3388
|
constructor(cg) {
|
|
@@ -3361,6 +3402,8 @@ var F32 = class {
|
|
|
3361
3402
|
for (let i = 0; i < 4; i++) this.cg._emit(bytes[i]);
|
|
3362
3403
|
this.cg._push(this);
|
|
3363
3404
|
}
|
|
3405
|
+
load = LOAD_OP("load", 42, "f32");
|
|
3406
|
+
store = STORE_OP("store", 56, "f32");
|
|
3364
3407
|
eq = BINARY_OP("eq", 91, "f32", "f32", "i32");
|
|
3365
3408
|
ne = BINARY_OP("ne", 92, "f32", "f32", "i32");
|
|
3366
3409
|
lt = BINARY_OP("lt", 93, "f32", "f32", "i32");
|
|
@@ -3383,10 +3426,53 @@ var F32 = class {
|
|
|
3383
3426
|
copysign = BINARY_OP("copysign", 152, "f32", "f32", "f32");
|
|
3384
3427
|
convert_i32_s = UNARY_OP("convert_i32_s", 178, "i32", "f32");
|
|
3385
3428
|
convert_i32_u = UNARY_OP("convert_i32_u", 179, "i32", "f32");
|
|
3386
|
-
|
|
3387
|
-
store = STORE_OP("store", 56, "f32");
|
|
3429
|
+
demote_f64 = UNARY_OP("demote_f64", 182, "f64", "f32");
|
|
3388
3430
|
reinterpret_i32 = UNARY_OP("reinterpret_i32", 190, "i32", "f32");
|
|
3389
3431
|
};
|
|
3432
|
+
var F64 = class {
|
|
3433
|
+
constructor(cg) {
|
|
3434
|
+
this.cg = cg;
|
|
3435
|
+
}
|
|
3436
|
+
get typeId() {
|
|
3437
|
+
return 124;
|
|
3438
|
+
}
|
|
3439
|
+
get name() {
|
|
3440
|
+
return "f64";
|
|
3441
|
+
}
|
|
3442
|
+
const(f) {
|
|
3443
|
+
this.cg._emit(68);
|
|
3444
|
+
const buffer = /* @__PURE__ */ new ArrayBuffer(8);
|
|
3445
|
+
new DataView(buffer).setFloat64(0, f, true);
|
|
3446
|
+
const bytes = new Uint8Array(buffer);
|
|
3447
|
+
for (let i = 0; i < 8; i++) this.cg._emit(bytes[i]);
|
|
3448
|
+
this.cg._push(this);
|
|
3449
|
+
}
|
|
3450
|
+
load = LOAD_OP("load", 43, "f64");
|
|
3451
|
+
store = STORE_OP("store", 57, "f64");
|
|
3452
|
+
eq = BINARY_OP("eq", 97, "f64", "f64", "i32");
|
|
3453
|
+
ne = BINARY_OP("ne", 98, "f64", "f64", "i32");
|
|
3454
|
+
lt = BINARY_OP("lt", 99, "f64", "f64", "i32");
|
|
3455
|
+
gt = BINARY_OP("gt", 100, "f64", "f64", "i32");
|
|
3456
|
+
le = BINARY_OP("le", 101, "f64", "f64", "i32");
|
|
3457
|
+
ge = BINARY_OP("ge", 102, "f64", "f64", "i32");
|
|
3458
|
+
abs = UNARY_OP("abs", 153, "f64", "f64");
|
|
3459
|
+
neg = UNARY_OP("neg", 154, "f64", "f64");
|
|
3460
|
+
ceil = UNARY_OP("ceil", 155, "f64", "f64");
|
|
3461
|
+
floor = UNARY_OP("floor", 156, "f64", "f64");
|
|
3462
|
+
trunc = UNARY_OP("trunc", 157, "f64", "f64");
|
|
3463
|
+
nearest = UNARY_OP("nearest", 158, "f64", "f64");
|
|
3464
|
+
sqrt = UNARY_OP("sqrt", 159, "f64", "f64");
|
|
3465
|
+
add = BINARY_OP("add", 160, "f64", "f64", "f64");
|
|
3466
|
+
sub = BINARY_OP("sub", 161, "f64", "f64", "f64");
|
|
3467
|
+
mul = BINARY_OP("mul", 162, "f64", "f64", "f64");
|
|
3468
|
+
div = BINARY_OP("div", 163, "f64", "f64", "f64");
|
|
3469
|
+
min = BINARY_OP("min", 164, "f64", "f64", "f64");
|
|
3470
|
+
max = BINARY_OP("max", 165, "f64", "f64", "f64");
|
|
3471
|
+
copysign = BINARY_OP("copysign", 166, "f64", "f64", "f64");
|
|
3472
|
+
convert_i32_s = UNARY_OP("convert_i32_s", 183, "i32", "f64");
|
|
3473
|
+
convert_i32_u = UNARY_OP("convert_i32_u", 184, "i32", "f64");
|
|
3474
|
+
promote_f32 = UNARY_OP("promote_f32", 187, "f32", "f64");
|
|
3475
|
+
};
|
|
3390
3476
|
function VECTOR_OP(op, vopcode, inTypes, outType) {
|
|
3391
3477
|
return function() {
|
|
3392
3478
|
for (const inType of inTypes.toReversed()) {
|
|
@@ -3639,10 +3725,10 @@ function codegenWasm(kernel) {
|
|
|
3639
3725
|
cg.local.get(acc);
|
|
3640
3726
|
if (re.dtype === DType.Bool) cg.i32.and();
|
|
3641
3727
|
else dty(cg, re.op, re.dtype).mul();
|
|
3642
|
-
} else if (re.op === AluOp.Min || re.op === AluOp.Max) if (re.dtype
|
|
3728
|
+
} else if (re.op === AluOp.Min || re.op === AluOp.Max) if (isFloatDtype(re.dtype)) {
|
|
3643
3729
|
cg.local.get(acc);
|
|
3644
|
-
if (re.op === AluOp.Min) cg.
|
|
3645
|
-
else cg.
|
|
3730
|
+
if (re.op === AluOp.Min) dtyF(cg, re.op, re.dtype).min();
|
|
3731
|
+
else dtyF(cg, re.op, re.dtype).max();
|
|
3646
3732
|
} else if ([
|
|
3647
3733
|
DType.Int32,
|
|
3648
3734
|
DType.Uint32,
|
|
@@ -3704,27 +3790,30 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3704
3790
|
else if (op === AluOp.Sub) dty(cg, op, dtype).sub();
|
|
3705
3791
|
else if (op === AluOp.Mul) if (dtype === DType.Bool) cg.i32.and();
|
|
3706
3792
|
else dty(cg, op, dtype).mul();
|
|
3707
|
-
else if (op === AluOp.Idiv) if (dtype
|
|
3708
|
-
|
|
3793
|
+
else if (op === AluOp.Idiv) if (isFloatDtype(dtype)) {
|
|
3794
|
+
dtyF(cg, op, dtype).div();
|
|
3795
|
+
dtyF(cg, op, dtype).trunc();
|
|
3796
|
+
} else if (dtype === DType.Uint32) cg.i32.div_u();
|
|
3709
3797
|
else if (dtype === DType.Int32) cg.i32.div_s();
|
|
3710
3798
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3711
|
-
else if (op === AluOp.Mod) if (dtype
|
|
3712
|
-
const
|
|
3713
|
-
const
|
|
3799
|
+
else if (op === AluOp.Mod) if (isFloatDtype(dtype)) {
|
|
3800
|
+
const dt = dtyF(cg, op, dtype);
|
|
3801
|
+
const a = cg.local.declare(dt);
|
|
3802
|
+
const b = cg.local.declare(dt);
|
|
3714
3803
|
cg.local.set(b);
|
|
3715
3804
|
cg.local.tee(a);
|
|
3716
3805
|
cg.local.get(a);
|
|
3717
3806
|
cg.local.get(b);
|
|
3718
|
-
|
|
3719
|
-
|
|
3807
|
+
dt.div();
|
|
3808
|
+
dt.trunc();
|
|
3720
3809
|
cg.local.get(b);
|
|
3721
|
-
|
|
3722
|
-
|
|
3810
|
+
dt.mul();
|
|
3811
|
+
dt.sub();
|
|
3723
3812
|
} else if (dtype === DType.Uint32) cg.i32.rem_u();
|
|
3724
3813
|
else if (dtype === DType.Int32) cg.i32.rem_s();
|
|
3725
3814
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3726
|
-
else if (op === AluOp.Min || op === AluOp.Max) if (dtype
|
|
3727
|
-
else cg.
|
|
3815
|
+
else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
|
|
3816
|
+
else dtyF(cg, op, dtype).max();
|
|
3728
3817
|
else if (dtype === DType.Int32 || dtype === DType.Uint32) {
|
|
3729
3818
|
const a = cg.local.declare(cg.i32);
|
|
3730
3819
|
const b = cg.local.declare(cg.i32);
|
|
@@ -3741,54 +3830,76 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3741
3830
|
} else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3742
3831
|
else if (op === AluOp.Cmplt) {
|
|
3743
3832
|
const srcDtype = src[0].dtype;
|
|
3744
|
-
if (srcDtype
|
|
3833
|
+
if (isFloatDtype(srcDtype)) dtyF(cg, op, srcDtype).lt();
|
|
3745
3834
|
else if (srcDtype === DType.Int32) cg.i32.lt_s();
|
|
3746
3835
|
else if (srcDtype === DType.Uint32) cg.i32.lt_u();
|
|
3747
3836
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3748
3837
|
} else if (op === AluOp.Cmpne) dty(cg, op, src[0].dtype).ne();
|
|
3749
3838
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3750
|
-
} else if (AluGroup.Unary.has(op))
|
|
3751
|
-
|
|
3752
|
-
|
|
3753
|
-
|
|
3754
|
-
|
|
3755
|
-
|
|
3756
|
-
|
|
3757
|
-
|
|
3758
|
-
|
|
3759
|
-
|
|
3760
|
-
|
|
3761
|
-
gen(src[0]);
|
|
3762
|
-
|
|
3763
|
-
|
|
3764
|
-
if (
|
|
3765
|
-
else if (
|
|
3766
|
-
else
|
|
3767
|
-
|
|
3768
|
-
|
|
3769
|
-
else
|
|
3770
|
-
else if (
|
|
3771
|
-
else if (
|
|
3772
|
-
|
|
3773
|
-
|
|
3774
|
-
|
|
3775
|
-
|
|
3776
|
-
|
|
3777
|
-
|
|
3778
|
-
|
|
3779
|
-
|
|
3780
|
-
|
|
3781
|
-
|
|
3782
|
-
|
|
3783
|
-
|
|
3784
|
-
|
|
3785
|
-
|
|
3786
|
-
|
|
3787
|
-
|
|
3788
|
-
|
|
3789
|
-
|
|
3790
|
-
|
|
3791
|
-
|
|
3839
|
+
} else if (AluGroup.Unary.has(op)) {
|
|
3840
|
+
const callFuncF32 = (func) => {
|
|
3841
|
+
if (dtype !== DType.Float32) if (dtype === DType.Float64) cg.f32.demote_f64();
|
|
3842
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3843
|
+
cg.call(func);
|
|
3844
|
+
if (dtype === DType.Float64) cg.f64.promote_f32();
|
|
3845
|
+
};
|
|
3846
|
+
if (op === AluOp.Sin) gen(src[0]), callFuncF32(funcs.sin);
|
|
3847
|
+
else if (op === AluOp.Cos) gen(src[0]), callFuncF32(funcs.cos);
|
|
3848
|
+
else if (op === AluOp.Asin) gen(src[0]), callFuncF32(funcs.asin);
|
|
3849
|
+
else if (op === AluOp.Atan) gen(src[0]), callFuncF32(funcs.atan);
|
|
3850
|
+
else if (op === AluOp.Exp) gen(src[0]), callFuncF32(funcs.exp);
|
|
3851
|
+
else if (op === AluOp.Log) gen(src[0]), callFuncF32(funcs.log);
|
|
3852
|
+
else if (op === AluOp.Erf) gen(src[0]), callFuncF32(funcs.erf);
|
|
3853
|
+
else if (op === AluOp.Erfc) gen(src[0]), callFuncF32(funcs.erfc);
|
|
3854
|
+
else if (op === AluOp.Sqrt) gen(src[0]), dtyF(cg, op, dtype).sqrt();
|
|
3855
|
+
else if (op === AluOp.Reciprocal) {
|
|
3856
|
+
const dt = dtyF(cg, op, dtype);
|
|
3857
|
+
dt.const(1), gen(src[0]), dt.div();
|
|
3858
|
+
} else if (op === AluOp.Floor) gen(src[0]), dtyF(cg, op, dtype).floor();
|
|
3859
|
+
else if (op === AluOp.Ceil) gen(src[0]), dtyF(cg, op, dtype).ceil();
|
|
3860
|
+
else if (op === AluOp.Cast) {
|
|
3861
|
+
gen(src[0]);
|
|
3862
|
+
const dtype0 = src[0].dtype;
|
|
3863
|
+
const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
|
|
3864
|
+
if (dtype === DType.Int32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_s();
|
|
3865
|
+
else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_s();
|
|
3866
|
+
else if (i32repr);
|
|
3867
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3868
|
+
else if (dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_u();
|
|
3869
|
+
else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_u();
|
|
3870
|
+
else if (i32repr);
|
|
3871
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3872
|
+
else if (dtype === DType.Float32) if (dtype0 === DType.Float32);
|
|
3873
|
+
else if (dtype0 === DType.Float64) cg.f32.demote_f64();
|
|
3874
|
+
else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32.convert_i32_s();
|
|
3875
|
+
else if (dtype0 === DType.Uint32) cg.f32.convert_i32_u();
|
|
3876
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3877
|
+
else if (dtype === DType.Float64) if (dtype0 === DType.Float32) cg.f64.promote_f32();
|
|
3878
|
+
else if (dtype0 === DType.Float64);
|
|
3879
|
+
else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f64.convert_i32_s();
|
|
3880
|
+
else if (dtype0 === DType.Uint32) cg.f64.convert_i32_u();
|
|
3881
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3882
|
+
else if (dtype === DType.Bool) if (dtype0 === DType.Bool);
|
|
3883
|
+
else if (i32repr) cg.i32.const(0), cg.i32.ne();
|
|
3884
|
+
else if (dtype0 === DType.Float32) cg.f32.const(0), cg.f32.ne();
|
|
3885
|
+
else if (dtype0 === DType.Float64) cg.f64.const(0), cg.f64.ne();
|
|
3886
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3887
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3888
|
+
} else if (op === AluOp.Bitcast) {
|
|
3889
|
+
gen(src[0]);
|
|
3890
|
+
const dtype0 = src[0].dtype;
|
|
3891
|
+
if (dtype !== dtype0) {
|
|
3892
|
+
const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32;
|
|
3893
|
+
if (dtype === DType.Int32 || dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.reinterpret_f32();
|
|
3894
|
+
else if (i32repr);
|
|
3895
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3896
|
+
else if (dtype === DType.Float32) if (i32repr) cg.f32.reinterpret_i32();
|
|
3897
|
+
else if (dtype0 === DType.Float32);
|
|
3898
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3899
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3900
|
+
}
|
|
3901
|
+
} else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3902
|
+
} else if (op === AluOp.Where) {
|
|
3792
3903
|
gen(src[1]);
|
|
3793
3904
|
gen(src[2]);
|
|
3794
3905
|
gen(src[0]);
|
|
@@ -3833,12 +3944,20 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3833
3944
|
function dty(cg, op, dtype) {
|
|
3834
3945
|
switch (dtype) {
|
|
3835
3946
|
case DType.Float32: return cg.f32;
|
|
3947
|
+
case DType.Float64: return cg.f64;
|
|
3836
3948
|
case DType.Int32:
|
|
3837
3949
|
case DType.Uint32:
|
|
3838
3950
|
case DType.Bool: return cg.i32;
|
|
3839
3951
|
default: throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3840
3952
|
}
|
|
3841
3953
|
}
|
|
3954
|
+
function dtyF(cg, op, dtype) {
|
|
3955
|
+
switch (dtype) {
|
|
3956
|
+
case DType.Float32: return cg.f32;
|
|
3957
|
+
case DType.Float64: return cg.f64;
|
|
3958
|
+
default: throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3959
|
+
}
|
|
3960
|
+
}
|
|
3842
3961
|
|
|
3843
3962
|
//#endregion
|
|
3844
3963
|
//#region src/backend.ts
|
|
@@ -3847,10 +3966,10 @@ const devices = [
|
|
|
3847
3966
|
"wasm",
|
|
3848
3967
|
"webgpu"
|
|
3849
3968
|
];
|
|
3850
|
-
let defaultBackend = "wasm";
|
|
3851
3969
|
const initializedBackends = /* @__PURE__ */ new Map();
|
|
3852
3970
|
initializedBackends.set("cpu", new CpuBackend());
|
|
3853
|
-
initializedBackends.set("wasm", new WasmBackend());
|
|
3971
|
+
if (typeof WebAssembly !== "undefined") initializedBackends.set("wasm", new WasmBackend());
|
|
3972
|
+
let defaultBackend = initializedBackends.has("wasm") ? "wasm" : "cpu";
|
|
3854
3973
|
/** Configure the default device for arrays. */
|
|
3855
3974
|
function defaultDevice(device) {
|
|
3856
3975
|
if (device !== void 0) if (initializedBackends.has(device)) defaultBackend = device;
|
|
@@ -3877,12 +3996,14 @@ async function init(...devicesToInit) {
|
|
|
3877
3996
|
/** Create a backend, if available. Internal function called by `init()`. */
|
|
3878
3997
|
async function createBackend(device) {
|
|
3879
3998
|
if (device === "cpu") return new CpuBackend();
|
|
3880
|
-
else if (device === "wasm")
|
|
3881
|
-
|
|
3999
|
+
else if (device === "wasm") {
|
|
4000
|
+
if (typeof WebAssembly === "undefined") return null;
|
|
4001
|
+
return new WasmBackend();
|
|
4002
|
+
} else if (device === "webgpu") {
|
|
3882
4003
|
if (!navigator.gpu) return null;
|
|
3883
4004
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
3884
4005
|
if (!adapter) return null;
|
|
3885
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
4006
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-CcGP160M.cjs"));
|
|
3886
4007
|
const importantLimits = [
|
|
3887
4008
|
"maxBufferSize",
|
|
3888
4009
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4229,4 +4350,4 @@ Object.defineProperty(exports, 'zipn', {
|
|
|
4229
4350
|
return zipn;
|
|
4230
4351
|
}
|
|
4231
4352
|
});
|
|
4232
|
-
//# sourceMappingURL=backend-
|
|
4353
|
+
//# sourceMappingURL=backend-DeVfWEFS.cjs.map
|