@jax-js/jax 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/{backend-FtkbO6pI.cjs → backend-BbrKEB18.cjs} +165 -69
- package/dist/{backend-DwIAd0AG.js → backend-CoVtc9dx.js} +165 -69
- package/dist/index.cjs +42 -8
- package/dist/index.d.cts +20 -6
- package/dist/index.d.ts +20 -6
- package/dist/index.js +42 -8
- package/dist/{webgpu-LGi2A3mS.js → webgpu-B3UVme6n.js} +9 -4
- package/dist/{webgpu-BE7zA_01.cjs → webgpu-DGYNVHma.cjs} +9 -4
- package/package.json +21 -13
|
@@ -307,6 +307,7 @@ let DType = /* @__PURE__ */ function(DType$1) {
|
|
|
307
307
|
DType$1["Uint32"] = "uint32";
|
|
308
308
|
DType$1["Bool"] = "bool";
|
|
309
309
|
DType$1["Float16"] = "float16";
|
|
310
|
+
DType$1["Float64"] = "float64";
|
|
310
311
|
return DType$1;
|
|
311
312
|
}({});
|
|
312
313
|
const byteWidth = (dtype) => {
|
|
@@ -316,10 +317,11 @@ const byteWidth = (dtype) => {
|
|
|
316
317
|
case DType.Uint32:
|
|
317
318
|
case DType.Bool: return 4;
|
|
318
319
|
case DType.Float16: return 2;
|
|
320
|
+
case DType.Float64: return 8;
|
|
319
321
|
default: throw new TypeError(`Unknown dtype: ${dtype}`);
|
|
320
322
|
}
|
|
321
323
|
};
|
|
322
|
-
const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16;
|
|
324
|
+
const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16 || dtype === DType.Float64;
|
|
323
325
|
/**
|
|
324
326
|
* Promote two dtypes to their join according to the type lattice.
|
|
325
327
|
*
|
|
@@ -329,7 +331,7 @@ const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float
|
|
|
329
331
|
*
|
|
330
332
|
* **Type lattice:**
|
|
331
333
|
* ```text
|
|
332
|
-
* bool -> uint32 -> int32 -> float16 -> float32
|
|
334
|
+
* bool -> uint32 -> int32 -> float16 -> float32 -> float64
|
|
333
335
|
* weakType --^
|
|
334
336
|
* ```
|
|
335
337
|
*
|
|
@@ -351,7 +353,8 @@ function promoteTypes(dtype1, dtype2) {
|
|
|
351
353
|
[DType.Uint32]: 1,
|
|
352
354
|
[DType.Int32]: 2,
|
|
353
355
|
[DType.Float16]: 3,
|
|
354
|
-
[DType.Float32]: 4
|
|
356
|
+
[DType.Float32]: 4,
|
|
357
|
+
[DType.Float64]: 5
|
|
355
358
|
};
|
|
356
359
|
return rank[dtype1] > rank[dtype2] ? dtype1 : dtype2;
|
|
357
360
|
}
|
|
@@ -364,6 +367,7 @@ function dtypedArray(dtype, data) {
|
|
|
364
367
|
case DType.Bool: return new Int32Array(buffer, byteOffset, length);
|
|
365
368
|
case DType.Uint32: return new Uint32Array(buffer, byteOffset, length);
|
|
366
369
|
case DType.Float16: return new Float16Array(buffer, byteOffset, length);
|
|
370
|
+
case DType.Float64: return new Float64Array(buffer, byteOffset, length);
|
|
367
371
|
default: throw new Error(`Unimplemented dtype: ${dtype}`);
|
|
368
372
|
}
|
|
369
373
|
}
|
|
@@ -374,6 +378,7 @@ function dtypedJsArray(dtype, data) {
|
|
|
374
378
|
case DType.Bool: return new Int32Array(data);
|
|
375
379
|
case DType.Uint32: return new Uint32Array(data);
|
|
376
380
|
case DType.Float16: return new Float16Array(data);
|
|
381
|
+
case DType.Float64: return new Float64Array(data);
|
|
377
382
|
default: throw new Error(`Unimplemented dtype: ${dtype}`);
|
|
378
383
|
}
|
|
379
384
|
}
|
|
@@ -511,6 +516,9 @@ var AluExp = class AluExp {
|
|
|
511
516
|
static f16(value) {
|
|
512
517
|
return AluExp.const(DType.Float16, value);
|
|
513
518
|
}
|
|
519
|
+
static f64(value) {
|
|
520
|
+
return AluExp.const(DType.Float64, value);
|
|
521
|
+
}
|
|
514
522
|
not() {
|
|
515
523
|
if (this.dtype !== DType.Bool) throw new Error("not() can only be called on boolean expressions");
|
|
516
524
|
return AluExp.cmpne(this, AluExp.const(DType.Bool, true));
|
|
@@ -521,7 +529,8 @@ var AluExp = class AluExp {
|
|
|
521
529
|
const hasher = new FpHash();
|
|
522
530
|
hasher.update(this.op);
|
|
523
531
|
hasher.update(this.dtype);
|
|
524
|
-
hasher.update(
|
|
532
|
+
if (this.op === AluOp.Const) hasher.update(this.arg);
|
|
533
|
+
else hasher.update(JSON.stringify(this.arg));
|
|
525
534
|
hasher.update(this.src.length);
|
|
526
535
|
for (const s of this.src) hasher.update(s);
|
|
527
536
|
this.#hash = hasher.value;
|
|
@@ -758,6 +767,7 @@ var AluExp = class AluExp {
|
|
|
758
767
|
if (op === AluOp.Mul && x === 1) return src[1 - i];
|
|
759
768
|
if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
|
|
760
769
|
if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
|
|
770
|
+
if (op === AluOp.Cmpne && src[i].dtype === DType.Bool && x === 0) return src[1 - i];
|
|
761
771
|
}
|
|
762
772
|
if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
|
|
763
773
|
const [a, b] = src[1].src;
|
|
@@ -959,11 +969,13 @@ var AluExp = class AluExp {
|
|
|
959
969
|
else if (fromType === DType.Int32) view.setInt32(0, x, true);
|
|
960
970
|
else if (fromType === DType.Uint32) view.setUint32(0, x, true);
|
|
961
971
|
else if (fromType === DType.Float16) view.setFloat16(0, x, true);
|
|
972
|
+
else if (fromType === DType.Float64) view.setFloat64(0, x, true);
|
|
962
973
|
else throw new Error(`Unsupported bitcast from ${fromType}`);
|
|
963
974
|
if (this.dtype === DType.Float32) return view.getFloat32(0, true);
|
|
964
975
|
else if (this.dtype === DType.Int32) return view.getInt32(0, true);
|
|
965
976
|
else if (this.dtype === DType.Uint32) return view.getUint32(0, true);
|
|
966
977
|
else if (this.dtype === DType.Float16) return view.getFloat16(0, true);
|
|
978
|
+
else if (this.dtype === DType.Float64) return view.getFloat64(0, true);
|
|
967
979
|
else throw new Error(`Unsupported bitcast to ${this.dtype}`);
|
|
968
980
|
}
|
|
969
981
|
default: throw new Error(`Missing implemementation for ${this.op}`);
|
|
@@ -2926,6 +2938,7 @@ var CodeGenerator = class {
|
|
|
2926
2938
|
local;
|
|
2927
2939
|
i32;
|
|
2928
2940
|
f32;
|
|
2941
|
+
f64;
|
|
2929
2942
|
v128;
|
|
2930
2943
|
i32x4;
|
|
2931
2944
|
f32x4;
|
|
@@ -2945,6 +2958,7 @@ var CodeGenerator = class {
|
|
|
2945
2958
|
this.local = new Local(this);
|
|
2946
2959
|
this.i32 = new I32(this);
|
|
2947
2960
|
this.f32 = new F32(this);
|
|
2961
|
+
this.f64 = new F64(this);
|
|
2948
2962
|
this.v128 = new V128(this);
|
|
2949
2963
|
this.i32x4 = new I32x4(this);
|
|
2950
2964
|
this.f32x4 = new F32x4(this);
|
|
@@ -3331,6 +3345,8 @@ var I32 = class {
|
|
|
3331
3345
|
ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
|
|
3332
3346
|
trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
|
|
3333
3347
|
trunc_f32_u = UNARY_OP("trunc_f32_u", 169, "f32", "i32");
|
|
3348
|
+
trunc_f64_s = UNARY_OP("trunc_f64_s", 170, "f64", "i32");
|
|
3349
|
+
trunc_f64_u = UNARY_OP("trunc_f64_u", 171, "f64", "i32");
|
|
3334
3350
|
load = LOAD_OP("load", 40, "i32");
|
|
3335
3351
|
load8_s = LOAD_OP("load8_s", 44, "i32");
|
|
3336
3352
|
load8_u = LOAD_OP("load8_u", 45, "i32");
|
|
@@ -3342,6 +3358,8 @@ var I32 = class {
|
|
|
3342
3358
|
reinterpret_f32 = UNARY_OP("reinterpret_f32", 188, "f32", "i32");
|
|
3343
3359
|
trunc_sat_f32_s = UNARY_OP("trunc_sat_f32_s", [252, 0], "f32", "i32");
|
|
3344
3360
|
trunc_sat_f32_u = UNARY_OP("trunc_sat_f32_u", [252, 1], "f32", "i32");
|
|
3361
|
+
trunc_sat_f64_s = UNARY_OP("trunc_sat_f64_s", [252, 2], "f64", "i32");
|
|
3362
|
+
trunc_sat_f64_u = UNARY_OP("trunc_sat_f64_u", [252, 3], "f64", "i32");
|
|
3345
3363
|
};
|
|
3346
3364
|
var F32 = class {
|
|
3347
3365
|
constructor(cg) {
|
|
@@ -3361,6 +3379,8 @@ var F32 = class {
|
|
|
3361
3379
|
for (let i = 0; i < 4; i++) this.cg._emit(bytes[i]);
|
|
3362
3380
|
this.cg._push(this);
|
|
3363
3381
|
}
|
|
3382
|
+
load = LOAD_OP("load", 42, "f32");
|
|
3383
|
+
store = STORE_OP("store", 56, "f32");
|
|
3364
3384
|
eq = BINARY_OP("eq", 91, "f32", "f32", "i32");
|
|
3365
3385
|
ne = BINARY_OP("ne", 92, "f32", "f32", "i32");
|
|
3366
3386
|
lt = BINARY_OP("lt", 93, "f32", "f32", "i32");
|
|
@@ -3383,10 +3403,53 @@ var F32 = class {
|
|
|
3383
3403
|
copysign = BINARY_OP("copysign", 152, "f32", "f32", "f32");
|
|
3384
3404
|
convert_i32_s = UNARY_OP("convert_i32_s", 178, "i32", "f32");
|
|
3385
3405
|
convert_i32_u = UNARY_OP("convert_i32_u", 179, "i32", "f32");
|
|
3386
|
-
|
|
3387
|
-
store = STORE_OP("store", 56, "f32");
|
|
3406
|
+
demote_f64 = UNARY_OP("demote_f64", 182, "f64", "f32");
|
|
3388
3407
|
reinterpret_i32 = UNARY_OP("reinterpret_i32", 190, "i32", "f32");
|
|
3389
3408
|
};
|
|
3409
|
+
var F64 = class {
|
|
3410
|
+
constructor(cg) {
|
|
3411
|
+
this.cg = cg;
|
|
3412
|
+
}
|
|
3413
|
+
get typeId() {
|
|
3414
|
+
return 124;
|
|
3415
|
+
}
|
|
3416
|
+
get name() {
|
|
3417
|
+
return "f64";
|
|
3418
|
+
}
|
|
3419
|
+
const(f) {
|
|
3420
|
+
this.cg._emit(68);
|
|
3421
|
+
const buffer = /* @__PURE__ */ new ArrayBuffer(8);
|
|
3422
|
+
new DataView(buffer).setFloat64(0, f, true);
|
|
3423
|
+
const bytes = new Uint8Array(buffer);
|
|
3424
|
+
for (let i = 0; i < 8; i++) this.cg._emit(bytes[i]);
|
|
3425
|
+
this.cg._push(this);
|
|
3426
|
+
}
|
|
3427
|
+
load = LOAD_OP("load", 43, "f64");
|
|
3428
|
+
store = STORE_OP("store", 57, "f64");
|
|
3429
|
+
eq = BINARY_OP("eq", 97, "f64", "f64", "i32");
|
|
3430
|
+
ne = BINARY_OP("ne", 98, "f64", "f64", "i32");
|
|
3431
|
+
lt = BINARY_OP("lt", 99, "f64", "f64", "i32");
|
|
3432
|
+
gt = BINARY_OP("gt", 100, "f64", "f64", "i32");
|
|
3433
|
+
le = BINARY_OP("le", 101, "f64", "f64", "i32");
|
|
3434
|
+
ge = BINARY_OP("ge", 102, "f64", "f64", "i32");
|
|
3435
|
+
abs = UNARY_OP("abs", 153, "f64", "f64");
|
|
3436
|
+
neg = UNARY_OP("neg", 154, "f64", "f64");
|
|
3437
|
+
ceil = UNARY_OP("ceil", 155, "f64", "f64");
|
|
3438
|
+
floor = UNARY_OP("floor", 156, "f64", "f64");
|
|
3439
|
+
trunc = UNARY_OP("trunc", 157, "f64", "f64");
|
|
3440
|
+
nearest = UNARY_OP("nearest", 158, "f64", "f64");
|
|
3441
|
+
sqrt = UNARY_OP("sqrt", 159, "f64", "f64");
|
|
3442
|
+
add = BINARY_OP("add", 160, "f64", "f64", "f64");
|
|
3443
|
+
sub = BINARY_OP("sub", 161, "f64", "f64", "f64");
|
|
3444
|
+
mul = BINARY_OP("mul", 162, "f64", "f64", "f64");
|
|
3445
|
+
div = BINARY_OP("div", 163, "f64", "f64", "f64");
|
|
3446
|
+
min = BINARY_OP("min", 164, "f64", "f64", "f64");
|
|
3447
|
+
max = BINARY_OP("max", 165, "f64", "f64", "f64");
|
|
3448
|
+
copysign = BINARY_OP("copysign", 166, "f64", "f64", "f64");
|
|
3449
|
+
convert_i32_s = UNARY_OP("convert_i32_s", 183, "i32", "f64");
|
|
3450
|
+
convert_i32_u = UNARY_OP("convert_i32_u", 184, "i32", "f64");
|
|
3451
|
+
promote_f32 = UNARY_OP("promote_f32", 187, "f32", "f64");
|
|
3452
|
+
};
|
|
3390
3453
|
function VECTOR_OP(op, vopcode, inTypes, outType) {
|
|
3391
3454
|
return function() {
|
|
3392
3455
|
for (const inType of inTypes.toReversed()) {
|
|
@@ -3639,10 +3702,10 @@ function codegenWasm(kernel) {
|
|
|
3639
3702
|
cg.local.get(acc);
|
|
3640
3703
|
if (re.dtype === DType.Bool) cg.i32.and();
|
|
3641
3704
|
else dty(cg, re.op, re.dtype).mul();
|
|
3642
|
-
} else if (re.op === AluOp.Min || re.op === AluOp.Max) if (re.dtype
|
|
3705
|
+
} else if (re.op === AluOp.Min || re.op === AluOp.Max) if (isFloatDtype(re.dtype)) {
|
|
3643
3706
|
cg.local.get(acc);
|
|
3644
|
-
if (re.op === AluOp.Min) cg.
|
|
3645
|
-
else cg.
|
|
3707
|
+
if (re.op === AluOp.Min) dtyF(cg, re.op, re.dtype).min();
|
|
3708
|
+
else dtyF(cg, re.op, re.dtype).max();
|
|
3646
3709
|
} else if ([
|
|
3647
3710
|
DType.Int32,
|
|
3648
3711
|
DType.Uint32,
|
|
@@ -3704,27 +3767,30 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3704
3767
|
else if (op === AluOp.Sub) dty(cg, op, dtype).sub();
|
|
3705
3768
|
else if (op === AluOp.Mul) if (dtype === DType.Bool) cg.i32.and();
|
|
3706
3769
|
else dty(cg, op, dtype).mul();
|
|
3707
|
-
else if (op === AluOp.Idiv) if (dtype
|
|
3708
|
-
|
|
3770
|
+
else if (op === AluOp.Idiv) if (isFloatDtype(dtype)) {
|
|
3771
|
+
dtyF(cg, op, dtype).div();
|
|
3772
|
+
dtyF(cg, op, dtype).trunc();
|
|
3773
|
+
} else if (dtype === DType.Uint32) cg.i32.div_u();
|
|
3709
3774
|
else if (dtype === DType.Int32) cg.i32.div_s();
|
|
3710
3775
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3711
|
-
else if (op === AluOp.Mod) if (dtype
|
|
3712
|
-
const
|
|
3713
|
-
const
|
|
3776
|
+
else if (op === AluOp.Mod) if (isFloatDtype(dtype)) {
|
|
3777
|
+
const dt = dtyF(cg, op, dtype);
|
|
3778
|
+
const a = cg.local.declare(dt);
|
|
3779
|
+
const b = cg.local.declare(dt);
|
|
3714
3780
|
cg.local.set(b);
|
|
3715
3781
|
cg.local.tee(a);
|
|
3716
3782
|
cg.local.get(a);
|
|
3717
3783
|
cg.local.get(b);
|
|
3718
|
-
|
|
3719
|
-
|
|
3784
|
+
dt.div();
|
|
3785
|
+
dt.trunc();
|
|
3720
3786
|
cg.local.get(b);
|
|
3721
|
-
|
|
3722
|
-
|
|
3787
|
+
dt.mul();
|
|
3788
|
+
dt.sub();
|
|
3723
3789
|
} else if (dtype === DType.Uint32) cg.i32.rem_u();
|
|
3724
3790
|
else if (dtype === DType.Int32) cg.i32.rem_s();
|
|
3725
3791
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3726
|
-
else if (op === AluOp.Min || op === AluOp.Max) if (dtype
|
|
3727
|
-
else cg.
|
|
3792
|
+
else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
|
|
3793
|
+
else dtyF(cg, op, dtype).max();
|
|
3728
3794
|
else if (dtype === DType.Int32 || dtype === DType.Uint32) {
|
|
3729
3795
|
const a = cg.local.declare(cg.i32);
|
|
3730
3796
|
const b = cg.local.declare(cg.i32);
|
|
@@ -3741,54 +3807,74 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3741
3807
|
} else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3742
3808
|
else if (op === AluOp.Cmplt) {
|
|
3743
3809
|
const srcDtype = src[0].dtype;
|
|
3744
|
-
if (srcDtype
|
|
3810
|
+
if (isFloatDtype(srcDtype)) dtyF(cg, op, srcDtype).lt();
|
|
3745
3811
|
else if (srcDtype === DType.Int32) cg.i32.lt_s();
|
|
3746
3812
|
else if (srcDtype === DType.Uint32) cg.i32.lt_u();
|
|
3747
3813
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3748
3814
|
} else if (op === AluOp.Cmpne) dty(cg, op, src[0].dtype).ne();
|
|
3749
3815
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3750
|
-
} else if (AluGroup.Unary.has(op))
|
|
3751
|
-
|
|
3752
|
-
|
|
3753
|
-
|
|
3754
|
-
|
|
3755
|
-
|
|
3756
|
-
|
|
3757
|
-
|
|
3758
|
-
|
|
3759
|
-
|
|
3760
|
-
|
|
3761
|
-
gen(src[0]);
|
|
3762
|
-
|
|
3763
|
-
|
|
3764
|
-
if (
|
|
3765
|
-
else if (
|
|
3766
|
-
else
|
|
3767
|
-
|
|
3768
|
-
|
|
3769
|
-
else
|
|
3770
|
-
|
|
3771
|
-
|
|
3772
|
-
|
|
3773
|
-
|
|
3774
|
-
|
|
3775
|
-
|
|
3776
|
-
|
|
3777
|
-
|
|
3778
|
-
|
|
3779
|
-
|
|
3780
|
-
|
|
3781
|
-
|
|
3782
|
-
|
|
3783
|
-
|
|
3784
|
-
|
|
3785
|
-
|
|
3786
|
-
|
|
3787
|
-
|
|
3788
|
-
|
|
3789
|
-
|
|
3790
|
-
|
|
3791
|
-
|
|
3816
|
+
} else if (AluGroup.Unary.has(op)) {
|
|
3817
|
+
const callFuncF32 = (func) => {
|
|
3818
|
+
if (dtype !== DType.Float32) if (dtype === DType.Float64) cg.f32.demote_f64();
|
|
3819
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3820
|
+
cg.call(func);
|
|
3821
|
+
if (dtype === DType.Float64) cg.f64.promote_f32();
|
|
3822
|
+
};
|
|
3823
|
+
if (op === AluOp.Sin) gen(src[0]), callFuncF32(funcs.sin);
|
|
3824
|
+
else if (op === AluOp.Cos) gen(src[0]), callFuncF32(funcs.cos);
|
|
3825
|
+
else if (op === AluOp.Asin) gen(src[0]), callFuncF32(funcs.asin);
|
|
3826
|
+
else if (op === AluOp.Atan) gen(src[0]), callFuncF32(funcs.atan);
|
|
3827
|
+
else if (op === AluOp.Exp) gen(src[0]), callFuncF32(funcs.exp);
|
|
3828
|
+
else if (op === AluOp.Log) gen(src[0]), callFuncF32(funcs.log);
|
|
3829
|
+
else if (op === AluOp.Erf) gen(src[0]), callFuncF32(funcs.erf);
|
|
3830
|
+
else if (op === AluOp.Erfc) gen(src[0]), callFuncF32(funcs.erfc);
|
|
3831
|
+
else if (op === AluOp.Sqrt) gen(src[0]), dtyF(cg, op, dtype).sqrt();
|
|
3832
|
+
else if (op === AluOp.Reciprocal) {
|
|
3833
|
+
const dt = dtyF(cg, op, dtype);
|
|
3834
|
+
dt.const(1), gen(src[0]), dt.div();
|
|
3835
|
+
} else if (op === AluOp.Cast) {
|
|
3836
|
+
gen(src[0]);
|
|
3837
|
+
const dtype0 = src[0].dtype;
|
|
3838
|
+
const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
|
|
3839
|
+
if (dtype === DType.Int32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_s();
|
|
3840
|
+
else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_s();
|
|
3841
|
+
else if (i32repr);
|
|
3842
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3843
|
+
else if (dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_u();
|
|
3844
|
+
else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_u();
|
|
3845
|
+
else if (i32repr);
|
|
3846
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3847
|
+
else if (dtype === DType.Float32) if (dtype0 === DType.Float32);
|
|
3848
|
+
else if (dtype0 === DType.Float64) cg.f32.demote_f64();
|
|
3849
|
+
else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32.convert_i32_s();
|
|
3850
|
+
else if (dtype0 === DType.Uint32) cg.f32.convert_i32_u();
|
|
3851
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3852
|
+
else if (dtype === DType.Float64) if (dtype0 === DType.Float32) cg.f64.promote_f32();
|
|
3853
|
+
else if (dtype0 === DType.Float64);
|
|
3854
|
+
else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f64.convert_i32_s();
|
|
3855
|
+
else if (dtype0 === DType.Uint32) cg.f64.convert_i32_u();
|
|
3856
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3857
|
+
else if (dtype === DType.Bool) if (dtype0 === DType.Bool);
|
|
3858
|
+
else if (i32repr) cg.i32.const(0), cg.i32.ne();
|
|
3859
|
+
else if (dtype0 === DType.Float32) cg.f32.const(0), cg.f32.ne();
|
|
3860
|
+
else if (dtype0 === DType.Float64) cg.f64.const(0), cg.f64.ne();
|
|
3861
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3862
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3863
|
+
} else if (op === AluOp.Bitcast) {
|
|
3864
|
+
gen(src[0]);
|
|
3865
|
+
const dtype0 = src[0].dtype;
|
|
3866
|
+
if (dtype !== dtype0) {
|
|
3867
|
+
const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32;
|
|
3868
|
+
if (dtype === DType.Int32 || dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.reinterpret_f32();
|
|
3869
|
+
else if (i32repr);
|
|
3870
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3871
|
+
else if (dtype === DType.Float32) if (i32repr) cg.f32.reinterpret_i32();
|
|
3872
|
+
else if (dtype0 === DType.Float32);
|
|
3873
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3874
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3875
|
+
}
|
|
3876
|
+
} else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3877
|
+
} else if (op === AluOp.Where) {
|
|
3792
3878
|
gen(src[1]);
|
|
3793
3879
|
gen(src[2]);
|
|
3794
3880
|
gen(src[0]);
|
|
@@ -3833,12 +3919,20 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3833
3919
|
function dty(cg, op, dtype) {
|
|
3834
3920
|
switch (dtype) {
|
|
3835
3921
|
case DType.Float32: return cg.f32;
|
|
3922
|
+
case DType.Float64: return cg.f64;
|
|
3836
3923
|
case DType.Int32:
|
|
3837
3924
|
case DType.Uint32:
|
|
3838
3925
|
case DType.Bool: return cg.i32;
|
|
3839
3926
|
default: throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3840
3927
|
}
|
|
3841
3928
|
}
|
|
3929
|
+
function dtyF(cg, op, dtype) {
|
|
3930
|
+
switch (dtype) {
|
|
3931
|
+
case DType.Float32: return cg.f32;
|
|
3932
|
+
case DType.Float64: return cg.f64;
|
|
3933
|
+
default: throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3934
|
+
}
|
|
3935
|
+
}
|
|
3842
3936
|
|
|
3843
3937
|
//#endregion
|
|
3844
3938
|
//#region src/backend.ts
|
|
@@ -3847,10 +3941,10 @@ const devices = [
|
|
|
3847
3941
|
"wasm",
|
|
3848
3942
|
"webgpu"
|
|
3849
3943
|
];
|
|
3850
|
-
let defaultBackend = "wasm";
|
|
3851
3944
|
const initializedBackends = /* @__PURE__ */ new Map();
|
|
3852
3945
|
initializedBackends.set("cpu", new CpuBackend());
|
|
3853
|
-
initializedBackends.set("wasm", new WasmBackend());
|
|
3946
|
+
if (typeof WebAssembly !== "undefined") initializedBackends.set("wasm", new WasmBackend());
|
|
3947
|
+
let defaultBackend = initializedBackends.has("wasm") ? "wasm" : "cpu";
|
|
3854
3948
|
/** Configure the default device for arrays. */
|
|
3855
3949
|
function defaultDevice(device) {
|
|
3856
3950
|
if (device !== void 0) if (initializedBackends.has(device)) defaultBackend = device;
|
|
@@ -3877,12 +3971,14 @@ async function init(...devicesToInit) {
|
|
|
3877
3971
|
/** Create a backend, if available. Internal function called by `init()`. */
|
|
3878
3972
|
async function createBackend(device) {
|
|
3879
3973
|
if (device === "cpu") return new CpuBackend();
|
|
3880
|
-
else if (device === "wasm")
|
|
3881
|
-
|
|
3974
|
+
else if (device === "wasm") {
|
|
3975
|
+
if (typeof WebAssembly === "undefined") return null;
|
|
3976
|
+
return new WasmBackend();
|
|
3977
|
+
} else if (device === "webgpu") {
|
|
3882
3978
|
if (!navigator.gpu) return null;
|
|
3883
3979
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
3884
3980
|
if (!adapter) return null;
|
|
3885
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
3981
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-DGYNVHma.cjs"));
|
|
3886
3982
|
const importantLimits = [
|
|
3887
3983
|
"maxBufferSize",
|
|
3888
3984
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4229,4 +4325,4 @@ Object.defineProperty(exports, 'zipn', {
|
|
|
4229
4325
|
return zipn;
|
|
4230
4326
|
}
|
|
4231
4327
|
});
|
|
4232
|
-
//# sourceMappingURL=backend-
|
|
4328
|
+
//# sourceMappingURL=backend-BbrKEB18.cjs.map
|