@jax-js/jax 0.1.0 → 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/{backend-DwIAd0AG.js → backend-BqymqzuU.js} +194 -73
- package/dist/{backend-FtkbO6pI.cjs → backend-DeVfWEFS.cjs} +194 -73
- package/dist/index.cjs +2725 -2206
- package/dist/index.d.cts +964 -844
- package/dist/index.d.ts +964 -844
- package/dist/index.js +2698 -2179
- package/dist/{webgpu-LGi2A3mS.js → webgpu-BGuG58KZ.js} +20 -13
- package/dist/{webgpu-BE7zA_01.cjs → webgpu-CcGP160M.cjs} +20 -13
- package/package.json +1 -1
|
@@ -289,10 +289,11 @@ var FpHash = class FpHash {
|
|
|
289
289
|
};
|
|
290
290
|
/** Run a function while caching it inline inside a `Map`. */
|
|
291
291
|
function runWithCache(cache, key, thunk) {
|
|
292
|
-
|
|
292
|
+
const keyStr = JSON.stringify(key);
|
|
293
|
+
if (cache.has(keyStr)) return cache.get(keyStr);
|
|
293
294
|
else {
|
|
294
295
|
const value = thunk();
|
|
295
|
-
cache.set(
|
|
296
|
+
cache.set(keyStr, value);
|
|
296
297
|
return value;
|
|
297
298
|
}
|
|
298
299
|
}
|
|
@@ -306,6 +307,7 @@ let DType = /* @__PURE__ */ function(DType$1) {
|
|
|
306
307
|
DType$1["Uint32"] = "uint32";
|
|
307
308
|
DType$1["Bool"] = "bool";
|
|
308
309
|
DType$1["Float16"] = "float16";
|
|
310
|
+
DType$1["Float64"] = "float64";
|
|
309
311
|
return DType$1;
|
|
310
312
|
}({});
|
|
311
313
|
const byteWidth = (dtype) => {
|
|
@@ -315,10 +317,11 @@ const byteWidth = (dtype) => {
|
|
|
315
317
|
case DType.Uint32:
|
|
316
318
|
case DType.Bool: return 4;
|
|
317
319
|
case DType.Float16: return 2;
|
|
320
|
+
case DType.Float64: return 8;
|
|
318
321
|
default: throw new TypeError(`Unknown dtype: ${dtype}`);
|
|
319
322
|
}
|
|
320
323
|
};
|
|
321
|
-
const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16;
|
|
324
|
+
const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16 || dtype === DType.Float64;
|
|
322
325
|
/**
|
|
323
326
|
* Promote two dtypes to their join according to the type lattice.
|
|
324
327
|
*
|
|
@@ -328,7 +331,7 @@ const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float
|
|
|
328
331
|
*
|
|
329
332
|
* **Type lattice:**
|
|
330
333
|
* ```text
|
|
331
|
-
* bool -> uint32 -> int32 -> float16 -> float32
|
|
334
|
+
* bool -> uint32 -> int32 -> float16 -> float32 -> float64
|
|
332
335
|
* weakType --^
|
|
333
336
|
* ```
|
|
334
337
|
*
|
|
@@ -350,7 +353,8 @@ function promoteTypes(dtype1, dtype2) {
|
|
|
350
353
|
[DType.Uint32]: 1,
|
|
351
354
|
[DType.Int32]: 2,
|
|
352
355
|
[DType.Float16]: 3,
|
|
353
|
-
[DType.Float32]: 4
|
|
356
|
+
[DType.Float32]: 4,
|
|
357
|
+
[DType.Float64]: 5
|
|
354
358
|
};
|
|
355
359
|
return rank[dtype1] > rank[dtype2] ? dtype1 : dtype2;
|
|
356
360
|
}
|
|
@@ -363,6 +367,7 @@ function dtypedArray(dtype, data) {
|
|
|
363
367
|
case DType.Bool: return new Int32Array(buffer, byteOffset, length);
|
|
364
368
|
case DType.Uint32: return new Uint32Array(buffer, byteOffset, length);
|
|
365
369
|
case DType.Float16: return new Float16Array(buffer, byteOffset, length);
|
|
370
|
+
case DType.Float64: return new Float64Array(buffer, byteOffset, length);
|
|
366
371
|
default: throw new Error(`Unimplemented dtype: ${dtype}`);
|
|
367
372
|
}
|
|
368
373
|
}
|
|
@@ -373,6 +378,7 @@ function dtypedJsArray(dtype, data) {
|
|
|
373
378
|
case DType.Bool: return new Int32Array(data);
|
|
374
379
|
case DType.Uint32: return new Uint32Array(data);
|
|
375
380
|
case DType.Float16: return new Float16Array(data);
|
|
381
|
+
case DType.Float64: return new Float64Array(data);
|
|
376
382
|
default: throw new Error(`Unimplemented dtype: ${dtype}`);
|
|
377
383
|
}
|
|
378
384
|
}
|
|
@@ -444,6 +450,14 @@ var AluExp = class AluExp {
|
|
|
444
450
|
static sqrt(a) {
|
|
445
451
|
return new AluExp(AluOp.Sqrt, a.dtype, [a]);
|
|
446
452
|
}
|
|
453
|
+
static floor(a) {
|
|
454
|
+
if (!isFloatDtype(a.dtype)) return a;
|
|
455
|
+
return new AluExp(AluOp.Floor, a.dtype, [a]);
|
|
456
|
+
}
|
|
457
|
+
static ceil(a) {
|
|
458
|
+
if (!isFloatDtype(a.dtype)) return a;
|
|
459
|
+
return new AluExp(AluOp.Ceil, a.dtype, [a]);
|
|
460
|
+
}
|
|
447
461
|
static reciprocal(a) {
|
|
448
462
|
return new AluExp(AluOp.Reciprocal, a.dtype, [a]);
|
|
449
463
|
}
|
|
@@ -510,6 +524,9 @@ var AluExp = class AluExp {
|
|
|
510
524
|
static f16(value) {
|
|
511
525
|
return AluExp.const(DType.Float16, value);
|
|
512
526
|
}
|
|
527
|
+
static f64(value) {
|
|
528
|
+
return AluExp.const(DType.Float64, value);
|
|
529
|
+
}
|
|
513
530
|
not() {
|
|
514
531
|
if (this.dtype !== DType.Bool) throw new Error("not() can only be called on boolean expressions");
|
|
515
532
|
return AluExp.cmpne(this, AluExp.const(DType.Bool, true));
|
|
@@ -520,7 +537,8 @@ var AluExp = class AluExp {
|
|
|
520
537
|
const hasher = new FpHash();
|
|
521
538
|
hasher.update(this.op);
|
|
522
539
|
hasher.update(this.dtype);
|
|
523
|
-
hasher.update(
|
|
540
|
+
if (this.op === AluOp.Const) hasher.update(this.arg);
|
|
541
|
+
else hasher.update(JSON.stringify(this.arg));
|
|
524
542
|
hasher.update(this.src.length);
|
|
525
543
|
for (const s of this.src) hasher.update(s);
|
|
526
544
|
this.#hash = hasher.value;
|
|
@@ -620,6 +638,12 @@ var AluExp = class AluExp {
|
|
|
620
638
|
case AluOp.Sqrt:
|
|
621
639
|
ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
|
|
622
640
|
break;
|
|
641
|
+
case AluOp.Floor:
|
|
642
|
+
ret = [Math.floor(src[0].min), Math.floor(src[0].max)];
|
|
643
|
+
break;
|
|
644
|
+
case AluOp.Ceil:
|
|
645
|
+
ret = [Math.ceil(src[0].min), Math.ceil(src[0].max)];
|
|
646
|
+
break;
|
|
623
647
|
case AluOp.Reciprocal:
|
|
624
648
|
if (src[0].min <= 0 && src[0].max >= 0) return [-Infinity, Infinity];
|
|
625
649
|
ret = [1 / src[0].max, 1 / src[0].min];
|
|
@@ -757,6 +781,7 @@ var AluExp = class AluExp {
|
|
|
757
781
|
if (op === AluOp.Mul && x === 1) return src[1 - i];
|
|
758
782
|
if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
|
|
759
783
|
if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
|
|
784
|
+
if (op === AluOp.Cmpne && src[i].dtype === DType.Bool && x === 0) return src[1 - i];
|
|
760
785
|
}
|
|
761
786
|
if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
|
|
762
787
|
const [a, b] = src[1].src;
|
|
@@ -851,7 +876,7 @@ var AluExp = class AluExp {
|
|
|
851
876
|
else return p(p(src[0].src[0], src[1]), src[0].src[1]).simplify(cache);
|
|
852
877
|
if (src[1].op === op && src[1].src[1].op === AluOp.Const) return p(p(src[0], src[1].src[0]), src[1].src[1]).simplify(cache);
|
|
853
878
|
}
|
|
854
|
-
if (op === AluOp.Mod || op === AluOp.Idiv && src[1].#isConstInt()) {
|
|
879
|
+
if ((op === AluOp.Mod || op === AluOp.Idiv) && src[1].#isConstInt()) {
|
|
855
880
|
const [x, y] = src;
|
|
856
881
|
{
|
|
857
882
|
const factors = [];
|
|
@@ -941,6 +966,8 @@ var AluExp = class AluExp {
|
|
|
941
966
|
case AluOp.Erf: return erf(x);
|
|
942
967
|
case AluOp.Erfc: return erfc(x);
|
|
943
968
|
case AluOp.Sqrt: return Math.sqrt(x);
|
|
969
|
+
case AluOp.Floor: return Math.floor(x);
|
|
970
|
+
case AluOp.Ceil: return Math.ceil(x);
|
|
944
971
|
case AluOp.Reciprocal: return 1 / x;
|
|
945
972
|
case AluOp.Cast: {
|
|
946
973
|
const wasFloat = isFloatDtype(this.src[0].dtype);
|
|
@@ -958,11 +985,13 @@ var AluExp = class AluExp {
|
|
|
958
985
|
else if (fromType === DType.Int32) view.setInt32(0, x, true);
|
|
959
986
|
else if (fromType === DType.Uint32) view.setUint32(0, x, true);
|
|
960
987
|
else if (fromType === DType.Float16) view.setFloat16(0, x, true);
|
|
988
|
+
else if (fromType === DType.Float64) view.setFloat64(0, x, true);
|
|
961
989
|
else throw new Error(`Unsupported bitcast from ${fromType}`);
|
|
962
990
|
if (this.dtype === DType.Float32) return view.getFloat32(0, true);
|
|
963
991
|
else if (this.dtype === DType.Int32) return view.getInt32(0, true);
|
|
964
992
|
else if (this.dtype === DType.Uint32) return view.getUint32(0, true);
|
|
965
993
|
else if (this.dtype === DType.Float16) return view.getFloat16(0, true);
|
|
994
|
+
else if (this.dtype === DType.Float64) return view.getFloat64(0, true);
|
|
966
995
|
else throw new Error(`Unsupported bitcast to ${this.dtype}`);
|
|
967
996
|
}
|
|
968
997
|
default: throw new Error(`Missing implemementation for ${this.op}`);
|
|
@@ -1128,6 +1157,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
1128
1157
|
AluOp$1["Erf"] = "Erf";
|
|
1129
1158
|
AluOp$1["Erfc"] = "Erfc";
|
|
1130
1159
|
AluOp$1["Sqrt"] = "Sqrt";
|
|
1160
|
+
AluOp$1["Floor"] = "Floor";
|
|
1161
|
+
AluOp$1["Ceil"] = "Ceil";
|
|
1131
1162
|
AluOp$1["Reciprocal"] = "Reciprocal";
|
|
1132
1163
|
AluOp$1["Cast"] = "Cast";
|
|
1133
1164
|
AluOp$1["Bitcast"] = "Bitcast";
|
|
@@ -1162,6 +1193,8 @@ const AluGroup = {
|
|
|
1162
1193
|
AluOp.Erf,
|
|
1163
1194
|
AluOp.Erfc,
|
|
1164
1195
|
AluOp.Sqrt,
|
|
1196
|
+
AluOp.Floor,
|
|
1197
|
+
AluOp.Ceil,
|
|
1165
1198
|
AluOp.Reciprocal,
|
|
1166
1199
|
AluOp.Cast,
|
|
1167
1200
|
AluOp.Bitcast
|
|
@@ -1189,7 +1222,9 @@ const AluGroup = {
|
|
|
1189
1222
|
AluOp.Erf,
|
|
1190
1223
|
AluOp.Erfc,
|
|
1191
1224
|
AluOp.Sqrt,
|
|
1192
|
-
AluOp.Reciprocal
|
|
1225
|
+
AluOp.Reciprocal,
|
|
1226
|
+
AluOp.Floor,
|
|
1227
|
+
AluOp.Ceil
|
|
1193
1228
|
])
|
|
1194
1229
|
};
|
|
1195
1230
|
/** Common variables that can be substituted in expressions. */
|
|
@@ -2925,6 +2960,7 @@ var CodeGenerator = class {
|
|
|
2925
2960
|
local;
|
|
2926
2961
|
i32;
|
|
2927
2962
|
f32;
|
|
2963
|
+
f64;
|
|
2928
2964
|
v128;
|
|
2929
2965
|
i32x4;
|
|
2930
2966
|
f32x4;
|
|
@@ -2944,6 +2980,7 @@ var CodeGenerator = class {
|
|
|
2944
2980
|
this.local = new Local(this);
|
|
2945
2981
|
this.i32 = new I32(this);
|
|
2946
2982
|
this.f32 = new F32(this);
|
|
2983
|
+
this.f64 = new F64(this);
|
|
2947
2984
|
this.v128 = new V128(this);
|
|
2948
2985
|
this.i32x4 = new I32x4(this);
|
|
2949
2986
|
this.f32x4 = new F32x4(this);
|
|
@@ -3330,6 +3367,8 @@ var I32 = class {
|
|
|
3330
3367
|
ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
|
|
3331
3368
|
trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
|
|
3332
3369
|
trunc_f32_u = UNARY_OP("trunc_f32_u", 169, "f32", "i32");
|
|
3370
|
+
trunc_f64_s = UNARY_OP("trunc_f64_s", 170, "f64", "i32");
|
|
3371
|
+
trunc_f64_u = UNARY_OP("trunc_f64_u", 171, "f64", "i32");
|
|
3333
3372
|
load = LOAD_OP("load", 40, "i32");
|
|
3334
3373
|
load8_s = LOAD_OP("load8_s", 44, "i32");
|
|
3335
3374
|
load8_u = LOAD_OP("load8_u", 45, "i32");
|
|
@@ -3341,6 +3380,8 @@ var I32 = class {
|
|
|
3341
3380
|
reinterpret_f32 = UNARY_OP("reinterpret_f32", 188, "f32", "i32");
|
|
3342
3381
|
trunc_sat_f32_s = UNARY_OP("trunc_sat_f32_s", [252, 0], "f32", "i32");
|
|
3343
3382
|
trunc_sat_f32_u = UNARY_OP("trunc_sat_f32_u", [252, 1], "f32", "i32");
|
|
3383
|
+
trunc_sat_f64_s = UNARY_OP("trunc_sat_f64_s", [252, 2], "f64", "i32");
|
|
3384
|
+
trunc_sat_f64_u = UNARY_OP("trunc_sat_f64_u", [252, 3], "f64", "i32");
|
|
3344
3385
|
};
|
|
3345
3386
|
var F32 = class {
|
|
3346
3387
|
constructor(cg) {
|
|
@@ -3360,6 +3401,8 @@ var F32 = class {
|
|
|
3360
3401
|
for (let i = 0; i < 4; i++) this.cg._emit(bytes[i]);
|
|
3361
3402
|
this.cg._push(this);
|
|
3362
3403
|
}
|
|
3404
|
+
load = LOAD_OP("load", 42, "f32");
|
|
3405
|
+
store = STORE_OP("store", 56, "f32");
|
|
3363
3406
|
eq = BINARY_OP("eq", 91, "f32", "f32", "i32");
|
|
3364
3407
|
ne = BINARY_OP("ne", 92, "f32", "f32", "i32");
|
|
3365
3408
|
lt = BINARY_OP("lt", 93, "f32", "f32", "i32");
|
|
@@ -3382,10 +3425,53 @@ var F32 = class {
|
|
|
3382
3425
|
copysign = BINARY_OP("copysign", 152, "f32", "f32", "f32");
|
|
3383
3426
|
convert_i32_s = UNARY_OP("convert_i32_s", 178, "i32", "f32");
|
|
3384
3427
|
convert_i32_u = UNARY_OP("convert_i32_u", 179, "i32", "f32");
|
|
3385
|
-
|
|
3386
|
-
store = STORE_OP("store", 56, "f32");
|
|
3428
|
+
demote_f64 = UNARY_OP("demote_f64", 182, "f64", "f32");
|
|
3387
3429
|
reinterpret_i32 = UNARY_OP("reinterpret_i32", 190, "i32", "f32");
|
|
3388
3430
|
};
|
|
3431
|
+
var F64 = class {
|
|
3432
|
+
constructor(cg) {
|
|
3433
|
+
this.cg = cg;
|
|
3434
|
+
}
|
|
3435
|
+
get typeId() {
|
|
3436
|
+
return 124;
|
|
3437
|
+
}
|
|
3438
|
+
get name() {
|
|
3439
|
+
return "f64";
|
|
3440
|
+
}
|
|
3441
|
+
const(f) {
|
|
3442
|
+
this.cg._emit(68);
|
|
3443
|
+
const buffer = /* @__PURE__ */ new ArrayBuffer(8);
|
|
3444
|
+
new DataView(buffer).setFloat64(0, f, true);
|
|
3445
|
+
const bytes = new Uint8Array(buffer);
|
|
3446
|
+
for (let i = 0; i < 8; i++) this.cg._emit(bytes[i]);
|
|
3447
|
+
this.cg._push(this);
|
|
3448
|
+
}
|
|
3449
|
+
load = LOAD_OP("load", 43, "f64");
|
|
3450
|
+
store = STORE_OP("store", 57, "f64");
|
|
3451
|
+
eq = BINARY_OP("eq", 97, "f64", "f64", "i32");
|
|
3452
|
+
ne = BINARY_OP("ne", 98, "f64", "f64", "i32");
|
|
3453
|
+
lt = BINARY_OP("lt", 99, "f64", "f64", "i32");
|
|
3454
|
+
gt = BINARY_OP("gt", 100, "f64", "f64", "i32");
|
|
3455
|
+
le = BINARY_OP("le", 101, "f64", "f64", "i32");
|
|
3456
|
+
ge = BINARY_OP("ge", 102, "f64", "f64", "i32");
|
|
3457
|
+
abs = UNARY_OP("abs", 153, "f64", "f64");
|
|
3458
|
+
neg = UNARY_OP("neg", 154, "f64", "f64");
|
|
3459
|
+
ceil = UNARY_OP("ceil", 155, "f64", "f64");
|
|
3460
|
+
floor = UNARY_OP("floor", 156, "f64", "f64");
|
|
3461
|
+
trunc = UNARY_OP("trunc", 157, "f64", "f64");
|
|
3462
|
+
nearest = UNARY_OP("nearest", 158, "f64", "f64");
|
|
3463
|
+
sqrt = UNARY_OP("sqrt", 159, "f64", "f64");
|
|
3464
|
+
add = BINARY_OP("add", 160, "f64", "f64", "f64");
|
|
3465
|
+
sub = BINARY_OP("sub", 161, "f64", "f64", "f64");
|
|
3466
|
+
mul = BINARY_OP("mul", 162, "f64", "f64", "f64");
|
|
3467
|
+
div = BINARY_OP("div", 163, "f64", "f64", "f64");
|
|
3468
|
+
min = BINARY_OP("min", 164, "f64", "f64", "f64");
|
|
3469
|
+
max = BINARY_OP("max", 165, "f64", "f64", "f64");
|
|
3470
|
+
copysign = BINARY_OP("copysign", 166, "f64", "f64", "f64");
|
|
3471
|
+
convert_i32_s = UNARY_OP("convert_i32_s", 183, "i32", "f64");
|
|
3472
|
+
convert_i32_u = UNARY_OP("convert_i32_u", 184, "i32", "f64");
|
|
3473
|
+
promote_f32 = UNARY_OP("promote_f32", 187, "f32", "f64");
|
|
3474
|
+
};
|
|
3389
3475
|
function VECTOR_OP(op, vopcode, inTypes, outType) {
|
|
3390
3476
|
return function() {
|
|
3391
3477
|
for (const inType of inTypes.toReversed()) {
|
|
@@ -3638,10 +3724,10 @@ function codegenWasm(kernel) {
|
|
|
3638
3724
|
cg.local.get(acc);
|
|
3639
3725
|
if (re.dtype === DType.Bool) cg.i32.and();
|
|
3640
3726
|
else dty(cg, re.op, re.dtype).mul();
|
|
3641
|
-
} else if (re.op === AluOp.Min || re.op === AluOp.Max) if (re.dtype
|
|
3727
|
+
} else if (re.op === AluOp.Min || re.op === AluOp.Max) if (isFloatDtype(re.dtype)) {
|
|
3642
3728
|
cg.local.get(acc);
|
|
3643
|
-
if (re.op === AluOp.Min) cg.
|
|
3644
|
-
else cg.
|
|
3729
|
+
if (re.op === AluOp.Min) dtyF(cg, re.op, re.dtype).min();
|
|
3730
|
+
else dtyF(cg, re.op, re.dtype).max();
|
|
3645
3731
|
} else if ([
|
|
3646
3732
|
DType.Int32,
|
|
3647
3733
|
DType.Uint32,
|
|
@@ -3703,27 +3789,30 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3703
3789
|
else if (op === AluOp.Sub) dty(cg, op, dtype).sub();
|
|
3704
3790
|
else if (op === AluOp.Mul) if (dtype === DType.Bool) cg.i32.and();
|
|
3705
3791
|
else dty(cg, op, dtype).mul();
|
|
3706
|
-
else if (op === AluOp.Idiv) if (dtype
|
|
3707
|
-
|
|
3792
|
+
else if (op === AluOp.Idiv) if (isFloatDtype(dtype)) {
|
|
3793
|
+
dtyF(cg, op, dtype).div();
|
|
3794
|
+
dtyF(cg, op, dtype).trunc();
|
|
3795
|
+
} else if (dtype === DType.Uint32) cg.i32.div_u();
|
|
3708
3796
|
else if (dtype === DType.Int32) cg.i32.div_s();
|
|
3709
3797
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3710
|
-
else if (op === AluOp.Mod) if (dtype
|
|
3711
|
-
const
|
|
3712
|
-
const
|
|
3798
|
+
else if (op === AluOp.Mod) if (isFloatDtype(dtype)) {
|
|
3799
|
+
const dt = dtyF(cg, op, dtype);
|
|
3800
|
+
const a = cg.local.declare(dt);
|
|
3801
|
+
const b = cg.local.declare(dt);
|
|
3713
3802
|
cg.local.set(b);
|
|
3714
3803
|
cg.local.tee(a);
|
|
3715
3804
|
cg.local.get(a);
|
|
3716
3805
|
cg.local.get(b);
|
|
3717
|
-
|
|
3718
|
-
|
|
3806
|
+
dt.div();
|
|
3807
|
+
dt.trunc();
|
|
3719
3808
|
cg.local.get(b);
|
|
3720
|
-
|
|
3721
|
-
|
|
3809
|
+
dt.mul();
|
|
3810
|
+
dt.sub();
|
|
3722
3811
|
} else if (dtype === DType.Uint32) cg.i32.rem_u();
|
|
3723
3812
|
else if (dtype === DType.Int32) cg.i32.rem_s();
|
|
3724
3813
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3725
|
-
else if (op === AluOp.Min || op === AluOp.Max) if (dtype
|
|
3726
|
-
else cg.
|
|
3814
|
+
else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
|
|
3815
|
+
else dtyF(cg, op, dtype).max();
|
|
3727
3816
|
else if (dtype === DType.Int32 || dtype === DType.Uint32) {
|
|
3728
3817
|
const a = cg.local.declare(cg.i32);
|
|
3729
3818
|
const b = cg.local.declare(cg.i32);
|
|
@@ -3740,54 +3829,76 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3740
3829
|
} else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3741
3830
|
else if (op === AluOp.Cmplt) {
|
|
3742
3831
|
const srcDtype = src[0].dtype;
|
|
3743
|
-
if (srcDtype
|
|
3832
|
+
if (isFloatDtype(srcDtype)) dtyF(cg, op, srcDtype).lt();
|
|
3744
3833
|
else if (srcDtype === DType.Int32) cg.i32.lt_s();
|
|
3745
3834
|
else if (srcDtype === DType.Uint32) cg.i32.lt_u();
|
|
3746
3835
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3747
3836
|
} else if (op === AluOp.Cmpne) dty(cg, op, src[0].dtype).ne();
|
|
3748
3837
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3749
|
-
} else if (AluGroup.Unary.has(op))
|
|
3750
|
-
|
|
3751
|
-
|
|
3752
|
-
|
|
3753
|
-
|
|
3754
|
-
|
|
3755
|
-
|
|
3756
|
-
|
|
3757
|
-
|
|
3758
|
-
|
|
3759
|
-
|
|
3760
|
-
gen(src[0]);
|
|
3761
|
-
|
|
3762
|
-
|
|
3763
|
-
if (
|
|
3764
|
-
else if (
|
|
3765
|
-
else
|
|
3766
|
-
|
|
3767
|
-
|
|
3768
|
-
else
|
|
3769
|
-
else if (
|
|
3770
|
-
else if (
|
|
3771
|
-
|
|
3772
|
-
|
|
3773
|
-
|
|
3774
|
-
|
|
3775
|
-
|
|
3776
|
-
|
|
3777
|
-
|
|
3778
|
-
|
|
3779
|
-
|
|
3780
|
-
|
|
3781
|
-
|
|
3782
|
-
|
|
3783
|
-
|
|
3784
|
-
|
|
3785
|
-
|
|
3786
|
-
|
|
3787
|
-
|
|
3788
|
-
|
|
3789
|
-
|
|
3790
|
-
|
|
3838
|
+
} else if (AluGroup.Unary.has(op)) {
|
|
3839
|
+
const callFuncF32 = (func) => {
|
|
3840
|
+
if (dtype !== DType.Float32) if (dtype === DType.Float64) cg.f32.demote_f64();
|
|
3841
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3842
|
+
cg.call(func);
|
|
3843
|
+
if (dtype === DType.Float64) cg.f64.promote_f32();
|
|
3844
|
+
};
|
|
3845
|
+
if (op === AluOp.Sin) gen(src[0]), callFuncF32(funcs.sin);
|
|
3846
|
+
else if (op === AluOp.Cos) gen(src[0]), callFuncF32(funcs.cos);
|
|
3847
|
+
else if (op === AluOp.Asin) gen(src[0]), callFuncF32(funcs.asin);
|
|
3848
|
+
else if (op === AluOp.Atan) gen(src[0]), callFuncF32(funcs.atan);
|
|
3849
|
+
else if (op === AluOp.Exp) gen(src[0]), callFuncF32(funcs.exp);
|
|
3850
|
+
else if (op === AluOp.Log) gen(src[0]), callFuncF32(funcs.log);
|
|
3851
|
+
else if (op === AluOp.Erf) gen(src[0]), callFuncF32(funcs.erf);
|
|
3852
|
+
else if (op === AluOp.Erfc) gen(src[0]), callFuncF32(funcs.erfc);
|
|
3853
|
+
else if (op === AluOp.Sqrt) gen(src[0]), dtyF(cg, op, dtype).sqrt();
|
|
3854
|
+
else if (op === AluOp.Reciprocal) {
|
|
3855
|
+
const dt = dtyF(cg, op, dtype);
|
|
3856
|
+
dt.const(1), gen(src[0]), dt.div();
|
|
3857
|
+
} else if (op === AluOp.Floor) gen(src[0]), dtyF(cg, op, dtype).floor();
|
|
3858
|
+
else if (op === AluOp.Ceil) gen(src[0]), dtyF(cg, op, dtype).ceil();
|
|
3859
|
+
else if (op === AluOp.Cast) {
|
|
3860
|
+
gen(src[0]);
|
|
3861
|
+
const dtype0 = src[0].dtype;
|
|
3862
|
+
const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
|
|
3863
|
+
if (dtype === DType.Int32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_s();
|
|
3864
|
+
else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_s();
|
|
3865
|
+
else if (i32repr);
|
|
3866
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3867
|
+
else if (dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_u();
|
|
3868
|
+
else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_u();
|
|
3869
|
+
else if (i32repr);
|
|
3870
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3871
|
+
else if (dtype === DType.Float32) if (dtype0 === DType.Float32);
|
|
3872
|
+
else if (dtype0 === DType.Float64) cg.f32.demote_f64();
|
|
3873
|
+
else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32.convert_i32_s();
|
|
3874
|
+
else if (dtype0 === DType.Uint32) cg.f32.convert_i32_u();
|
|
3875
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3876
|
+
else if (dtype === DType.Float64) if (dtype0 === DType.Float32) cg.f64.promote_f32();
|
|
3877
|
+
else if (dtype0 === DType.Float64);
|
|
3878
|
+
else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f64.convert_i32_s();
|
|
3879
|
+
else if (dtype0 === DType.Uint32) cg.f64.convert_i32_u();
|
|
3880
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3881
|
+
else if (dtype === DType.Bool) if (dtype0 === DType.Bool);
|
|
3882
|
+
else if (i32repr) cg.i32.const(0), cg.i32.ne();
|
|
3883
|
+
else if (dtype0 === DType.Float32) cg.f32.const(0), cg.f32.ne();
|
|
3884
|
+
else if (dtype0 === DType.Float64) cg.f64.const(0), cg.f64.ne();
|
|
3885
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3886
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3887
|
+
} else if (op === AluOp.Bitcast) {
|
|
3888
|
+
gen(src[0]);
|
|
3889
|
+
const dtype0 = src[0].dtype;
|
|
3890
|
+
if (dtype !== dtype0) {
|
|
3891
|
+
const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32;
|
|
3892
|
+
if (dtype === DType.Int32 || dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.reinterpret_f32();
|
|
3893
|
+
else if (i32repr);
|
|
3894
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3895
|
+
else if (dtype === DType.Float32) if (i32repr) cg.f32.reinterpret_i32();
|
|
3896
|
+
else if (dtype0 === DType.Float32);
|
|
3897
|
+
else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
|
|
3898
|
+
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3899
|
+
}
|
|
3900
|
+
} else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3901
|
+
} else if (op === AluOp.Where) {
|
|
3791
3902
|
gen(src[1]);
|
|
3792
3903
|
gen(src[2]);
|
|
3793
3904
|
gen(src[0]);
|
|
@@ -3832,12 +3943,20 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3832
3943
|
function dty(cg, op, dtype) {
|
|
3833
3944
|
switch (dtype) {
|
|
3834
3945
|
case DType.Float32: return cg.f32;
|
|
3946
|
+
case DType.Float64: return cg.f64;
|
|
3835
3947
|
case DType.Int32:
|
|
3836
3948
|
case DType.Uint32:
|
|
3837
3949
|
case DType.Bool: return cg.i32;
|
|
3838
3950
|
default: throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3839
3951
|
}
|
|
3840
3952
|
}
|
|
3953
|
+
function dtyF(cg, op, dtype) {
|
|
3954
|
+
switch (dtype) {
|
|
3955
|
+
case DType.Float32: return cg.f32;
|
|
3956
|
+
case DType.Float64: return cg.f64;
|
|
3957
|
+
default: throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3958
|
+
}
|
|
3959
|
+
}
|
|
3841
3960
|
|
|
3842
3961
|
//#endregion
|
|
3843
3962
|
//#region src/backend.ts
|
|
@@ -3846,10 +3965,10 @@ const devices = [
|
|
|
3846
3965
|
"wasm",
|
|
3847
3966
|
"webgpu"
|
|
3848
3967
|
];
|
|
3849
|
-
let defaultBackend = "wasm";
|
|
3850
3968
|
const initializedBackends = /* @__PURE__ */ new Map();
|
|
3851
3969
|
initializedBackends.set("cpu", new CpuBackend());
|
|
3852
|
-
initializedBackends.set("wasm", new WasmBackend());
|
|
3970
|
+
if (typeof WebAssembly !== "undefined") initializedBackends.set("wasm", new WasmBackend());
|
|
3971
|
+
let defaultBackend = initializedBackends.has("wasm") ? "wasm" : "cpu";
|
|
3853
3972
|
/** Configure the default device for arrays. */
|
|
3854
3973
|
function defaultDevice(device) {
|
|
3855
3974
|
if (device !== void 0) if (initializedBackends.has(device)) defaultBackend = device;
|
|
@@ -3876,12 +3995,14 @@ async function init(...devicesToInit) {
|
|
|
3876
3995
|
/** Create a backend, if available. Internal function called by `init()`. */
|
|
3877
3996
|
async function createBackend(device) {
|
|
3878
3997
|
if (device === "cpu") return new CpuBackend();
|
|
3879
|
-
else if (device === "wasm")
|
|
3880
|
-
|
|
3998
|
+
else if (device === "wasm") {
|
|
3999
|
+
if (typeof WebAssembly === "undefined") return null;
|
|
4000
|
+
return new WasmBackend();
|
|
4001
|
+
} else if (device === "webgpu") {
|
|
3881
4002
|
if (!navigator.gpu) return null;
|
|
3882
4003
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
3883
4004
|
if (!adapter) return null;
|
|
3884
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
4005
|
+
const { WebGPUBackend } = await import("./webgpu-BGuG58KZ.js");
|
|
3885
4006
|
const importantLimits = [
|
|
3886
4007
|
"maxBufferSize",
|
|
3887
4008
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -3935,4 +4056,4 @@ var UnsupportedOpError = class extends Error {
|
|
|
3935
4056
|
|
|
3936
4057
|
//#endregion
|
|
3937
4058
|
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|
|
3938
|
-
//# sourceMappingURL=backend-
|
|
4059
|
+
//# sourceMappingURL=backend-BqymqzuU.js.map
|