@jax-js/jax 0.1.0 → 0.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -289,10 +289,11 @@ var FpHash = class FpHash {
289
289
  };
290
290
  /** Run a function while caching it inline inside a `Map`. */
291
291
  function runWithCache(cache, key, thunk) {
292
- if (cache.has(key)) return cache.get(key);
292
+ const keyStr = JSON.stringify(key);
293
+ if (cache.has(keyStr)) return cache.get(keyStr);
293
294
  else {
294
295
  const value = thunk();
295
- cache.set(key, value);
296
+ cache.set(keyStr, value);
296
297
  return value;
297
298
  }
298
299
  }
@@ -306,6 +307,7 @@ let DType = /* @__PURE__ */ function(DType$1) {
306
307
  DType$1["Uint32"] = "uint32";
307
308
  DType$1["Bool"] = "bool";
308
309
  DType$1["Float16"] = "float16";
310
+ DType$1["Float64"] = "float64";
309
311
  return DType$1;
310
312
  }({});
311
313
  const byteWidth = (dtype) => {
@@ -315,10 +317,11 @@ const byteWidth = (dtype) => {
315
317
  case DType.Uint32:
316
318
  case DType.Bool: return 4;
317
319
  case DType.Float16: return 2;
320
+ case DType.Float64: return 8;
318
321
  default: throw new TypeError(`Unknown dtype: ${dtype}`);
319
322
  }
320
323
  };
321
- const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16;
324
+ const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16 || dtype === DType.Float64;
322
325
  /**
323
326
  * Promote two dtypes to their join according to the type lattice.
324
327
  *
@@ -328,7 +331,7 @@ const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float
328
331
  *
329
332
  * **Type lattice:**
330
333
  * ```text
331
- * bool -> uint32 -> int32 -> float16 -> float32
334
+ * bool -> uint32 -> int32 -> float16 -> float32 -> float64
332
335
  * weakType --^
333
336
  * ```
334
337
  *
@@ -350,7 +353,8 @@ function promoteTypes(dtype1, dtype2) {
350
353
  [DType.Uint32]: 1,
351
354
  [DType.Int32]: 2,
352
355
  [DType.Float16]: 3,
353
- [DType.Float32]: 4
356
+ [DType.Float32]: 4,
357
+ [DType.Float64]: 5
354
358
  };
355
359
  return rank[dtype1] > rank[dtype2] ? dtype1 : dtype2;
356
360
  }
@@ -363,6 +367,7 @@ function dtypedArray(dtype, data) {
363
367
  case DType.Bool: return new Int32Array(buffer, byteOffset, length);
364
368
  case DType.Uint32: return new Uint32Array(buffer, byteOffset, length);
365
369
  case DType.Float16: return new Float16Array(buffer, byteOffset, length);
370
+ case DType.Float64: return new Float64Array(buffer, byteOffset, length);
366
371
  default: throw new Error(`Unimplemented dtype: ${dtype}`);
367
372
  }
368
373
  }
@@ -373,6 +378,7 @@ function dtypedJsArray(dtype, data) {
373
378
  case DType.Bool: return new Int32Array(data);
374
379
  case DType.Uint32: return new Uint32Array(data);
375
380
  case DType.Float16: return new Float16Array(data);
381
+ case DType.Float64: return new Float64Array(data);
376
382
  default: throw new Error(`Unimplemented dtype: ${dtype}`);
377
383
  }
378
384
  }
@@ -444,6 +450,14 @@ var AluExp = class AluExp {
444
450
  static sqrt(a) {
445
451
  return new AluExp(AluOp.Sqrt, a.dtype, [a]);
446
452
  }
453
+ static floor(a) {
454
+ if (!isFloatDtype(a.dtype)) return a;
455
+ return new AluExp(AluOp.Floor, a.dtype, [a]);
456
+ }
457
+ static ceil(a) {
458
+ if (!isFloatDtype(a.dtype)) return a;
459
+ return new AluExp(AluOp.Ceil, a.dtype, [a]);
460
+ }
447
461
  static reciprocal(a) {
448
462
  return new AluExp(AluOp.Reciprocal, a.dtype, [a]);
449
463
  }
@@ -510,6 +524,9 @@ var AluExp = class AluExp {
510
524
  static f16(value) {
511
525
  return AluExp.const(DType.Float16, value);
512
526
  }
527
+ static f64(value) {
528
+ return AluExp.const(DType.Float64, value);
529
+ }
513
530
  not() {
514
531
  if (this.dtype !== DType.Bool) throw new Error("not() can only be called on boolean expressions");
515
532
  return AluExp.cmpne(this, AluExp.const(DType.Bool, true));
@@ -520,7 +537,8 @@ var AluExp = class AluExp {
520
537
  const hasher = new FpHash();
521
538
  hasher.update(this.op);
522
539
  hasher.update(this.dtype);
523
- hasher.update(JSON.stringify(this.arg));
540
+ if (this.op === AluOp.Const) hasher.update(this.arg);
541
+ else hasher.update(JSON.stringify(this.arg));
524
542
  hasher.update(this.src.length);
525
543
  for (const s of this.src) hasher.update(s);
526
544
  this.#hash = hasher.value;
@@ -620,6 +638,12 @@ var AluExp = class AluExp {
620
638
  case AluOp.Sqrt:
621
639
  ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
622
640
  break;
641
+ case AluOp.Floor:
642
+ ret = [Math.floor(src[0].min), Math.floor(src[0].max)];
643
+ break;
644
+ case AluOp.Ceil:
645
+ ret = [Math.ceil(src[0].min), Math.ceil(src[0].max)];
646
+ break;
623
647
  case AluOp.Reciprocal:
624
648
  if (src[0].min <= 0 && src[0].max >= 0) return [-Infinity, Infinity];
625
649
  ret = [1 / src[0].max, 1 / src[0].min];
@@ -757,6 +781,7 @@ var AluExp = class AluExp {
757
781
  if (op === AluOp.Mul && x === 1) return src[1 - i];
758
782
  if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
759
783
  if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
784
+ if (op === AluOp.Cmpne && src[i].dtype === DType.Bool && x === 0) return src[1 - i];
760
785
  }
761
786
  if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
762
787
  const [a, b] = src[1].src;
@@ -851,7 +876,7 @@ var AluExp = class AluExp {
851
876
  else return p(p(src[0].src[0], src[1]), src[0].src[1]).simplify(cache);
852
877
  if (src[1].op === op && src[1].src[1].op === AluOp.Const) return p(p(src[0], src[1].src[0]), src[1].src[1]).simplify(cache);
853
878
  }
854
- if (op === AluOp.Mod || op === AluOp.Idiv && src[1].#isConstInt()) {
879
+ if ((op === AluOp.Mod || op === AluOp.Idiv) && src[1].#isConstInt()) {
855
880
  const [x, y] = src;
856
881
  {
857
882
  const factors = [];
@@ -941,6 +966,8 @@ var AluExp = class AluExp {
941
966
  case AluOp.Erf: return erf(x);
942
967
  case AluOp.Erfc: return erfc(x);
943
968
  case AluOp.Sqrt: return Math.sqrt(x);
969
+ case AluOp.Floor: return Math.floor(x);
970
+ case AluOp.Ceil: return Math.ceil(x);
944
971
  case AluOp.Reciprocal: return 1 / x;
945
972
  case AluOp.Cast: {
946
973
  const wasFloat = isFloatDtype(this.src[0].dtype);
@@ -958,11 +985,13 @@ var AluExp = class AluExp {
958
985
  else if (fromType === DType.Int32) view.setInt32(0, x, true);
959
986
  else if (fromType === DType.Uint32) view.setUint32(0, x, true);
960
987
  else if (fromType === DType.Float16) view.setFloat16(0, x, true);
988
+ else if (fromType === DType.Float64) view.setFloat64(0, x, true);
961
989
  else throw new Error(`Unsupported bitcast from ${fromType}`);
962
990
  if (this.dtype === DType.Float32) return view.getFloat32(0, true);
963
991
  else if (this.dtype === DType.Int32) return view.getInt32(0, true);
964
992
  else if (this.dtype === DType.Uint32) return view.getUint32(0, true);
965
993
  else if (this.dtype === DType.Float16) return view.getFloat16(0, true);
994
+ else if (this.dtype === DType.Float64) return view.getFloat64(0, true);
966
995
  else throw new Error(`Unsupported bitcast to ${this.dtype}`);
967
996
  }
968
997
  default: throw new Error(`Missing implemementation for ${this.op}`);
@@ -1128,6 +1157,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
1128
1157
  AluOp$1["Erf"] = "Erf";
1129
1158
  AluOp$1["Erfc"] = "Erfc";
1130
1159
  AluOp$1["Sqrt"] = "Sqrt";
1160
+ AluOp$1["Floor"] = "Floor";
1161
+ AluOp$1["Ceil"] = "Ceil";
1131
1162
  AluOp$1["Reciprocal"] = "Reciprocal";
1132
1163
  AluOp$1["Cast"] = "Cast";
1133
1164
  AluOp$1["Bitcast"] = "Bitcast";
@@ -1162,6 +1193,8 @@ const AluGroup = {
1162
1193
  AluOp.Erf,
1163
1194
  AluOp.Erfc,
1164
1195
  AluOp.Sqrt,
1196
+ AluOp.Floor,
1197
+ AluOp.Ceil,
1165
1198
  AluOp.Reciprocal,
1166
1199
  AluOp.Cast,
1167
1200
  AluOp.Bitcast
@@ -1189,7 +1222,9 @@ const AluGroup = {
1189
1222
  AluOp.Erf,
1190
1223
  AluOp.Erfc,
1191
1224
  AluOp.Sqrt,
1192
- AluOp.Reciprocal
1225
+ AluOp.Reciprocal,
1226
+ AluOp.Floor,
1227
+ AluOp.Ceil
1193
1228
  ])
1194
1229
  };
1195
1230
  /** Common variables that can be substituted in expressions. */
@@ -2925,6 +2960,7 @@ var CodeGenerator = class {
2925
2960
  local;
2926
2961
  i32;
2927
2962
  f32;
2963
+ f64;
2928
2964
  v128;
2929
2965
  i32x4;
2930
2966
  f32x4;
@@ -2944,6 +2980,7 @@ var CodeGenerator = class {
2944
2980
  this.local = new Local(this);
2945
2981
  this.i32 = new I32(this);
2946
2982
  this.f32 = new F32(this);
2983
+ this.f64 = new F64(this);
2947
2984
  this.v128 = new V128(this);
2948
2985
  this.i32x4 = new I32x4(this);
2949
2986
  this.f32x4 = new F32x4(this);
@@ -3330,6 +3367,8 @@ var I32 = class {
3330
3367
  ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
3331
3368
  trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
3332
3369
  trunc_f32_u = UNARY_OP("trunc_f32_u", 169, "f32", "i32");
3370
+ trunc_f64_s = UNARY_OP("trunc_f64_s", 170, "f64", "i32");
3371
+ trunc_f64_u = UNARY_OP("trunc_f64_u", 171, "f64", "i32");
3333
3372
  load = LOAD_OP("load", 40, "i32");
3334
3373
  load8_s = LOAD_OP("load8_s", 44, "i32");
3335
3374
  load8_u = LOAD_OP("load8_u", 45, "i32");
@@ -3341,6 +3380,8 @@ var I32 = class {
3341
3380
  reinterpret_f32 = UNARY_OP("reinterpret_f32", 188, "f32", "i32");
3342
3381
  trunc_sat_f32_s = UNARY_OP("trunc_sat_f32_s", [252, 0], "f32", "i32");
3343
3382
  trunc_sat_f32_u = UNARY_OP("trunc_sat_f32_u", [252, 1], "f32", "i32");
3383
+ trunc_sat_f64_s = UNARY_OP("trunc_sat_f64_s", [252, 2], "f64", "i32");
3384
+ trunc_sat_f64_u = UNARY_OP("trunc_sat_f64_u", [252, 3], "f64", "i32");
3344
3385
  };
3345
3386
  var F32 = class {
3346
3387
  constructor(cg) {
@@ -3360,6 +3401,8 @@ var F32 = class {
3360
3401
  for (let i = 0; i < 4; i++) this.cg._emit(bytes[i]);
3361
3402
  this.cg._push(this);
3362
3403
  }
3404
+ load = LOAD_OP("load", 42, "f32");
3405
+ store = STORE_OP("store", 56, "f32");
3363
3406
  eq = BINARY_OP("eq", 91, "f32", "f32", "i32");
3364
3407
  ne = BINARY_OP("ne", 92, "f32", "f32", "i32");
3365
3408
  lt = BINARY_OP("lt", 93, "f32", "f32", "i32");
@@ -3382,10 +3425,53 @@ var F32 = class {
3382
3425
  copysign = BINARY_OP("copysign", 152, "f32", "f32", "f32");
3383
3426
  convert_i32_s = UNARY_OP("convert_i32_s", 178, "i32", "f32");
3384
3427
  convert_i32_u = UNARY_OP("convert_i32_u", 179, "i32", "f32");
3385
- load = LOAD_OP("load", 42, "f32");
3386
- store = STORE_OP("store", 56, "f32");
3428
+ demote_f64 = UNARY_OP("demote_f64", 182, "f64", "f32");
3387
3429
  reinterpret_i32 = UNARY_OP("reinterpret_i32", 190, "i32", "f32");
3388
3430
  };
3431
+ var F64 = class {
3432
+ constructor(cg) {
3433
+ this.cg = cg;
3434
+ }
3435
+ get typeId() {
3436
+ return 124;
3437
+ }
3438
+ get name() {
3439
+ return "f64";
3440
+ }
3441
+ const(f) {
3442
+ this.cg._emit(68);
3443
+ const buffer = /* @__PURE__ */ new ArrayBuffer(8);
3444
+ new DataView(buffer).setFloat64(0, f, true);
3445
+ const bytes = new Uint8Array(buffer);
3446
+ for (let i = 0; i < 8; i++) this.cg._emit(bytes[i]);
3447
+ this.cg._push(this);
3448
+ }
3449
+ load = LOAD_OP("load", 43, "f64");
3450
+ store = STORE_OP("store", 57, "f64");
3451
+ eq = BINARY_OP("eq", 97, "f64", "f64", "i32");
3452
+ ne = BINARY_OP("ne", 98, "f64", "f64", "i32");
3453
+ lt = BINARY_OP("lt", 99, "f64", "f64", "i32");
3454
+ gt = BINARY_OP("gt", 100, "f64", "f64", "i32");
3455
+ le = BINARY_OP("le", 101, "f64", "f64", "i32");
3456
+ ge = BINARY_OP("ge", 102, "f64", "f64", "i32");
3457
+ abs = UNARY_OP("abs", 153, "f64", "f64");
3458
+ neg = UNARY_OP("neg", 154, "f64", "f64");
3459
+ ceil = UNARY_OP("ceil", 155, "f64", "f64");
3460
+ floor = UNARY_OP("floor", 156, "f64", "f64");
3461
+ trunc = UNARY_OP("trunc", 157, "f64", "f64");
3462
+ nearest = UNARY_OP("nearest", 158, "f64", "f64");
3463
+ sqrt = UNARY_OP("sqrt", 159, "f64", "f64");
3464
+ add = BINARY_OP("add", 160, "f64", "f64", "f64");
3465
+ sub = BINARY_OP("sub", 161, "f64", "f64", "f64");
3466
+ mul = BINARY_OP("mul", 162, "f64", "f64", "f64");
3467
+ div = BINARY_OP("div", 163, "f64", "f64", "f64");
3468
+ min = BINARY_OP("min", 164, "f64", "f64", "f64");
3469
+ max = BINARY_OP("max", 165, "f64", "f64", "f64");
3470
+ copysign = BINARY_OP("copysign", 166, "f64", "f64", "f64");
3471
+ convert_i32_s = UNARY_OP("convert_i32_s", 183, "i32", "f64");
3472
+ convert_i32_u = UNARY_OP("convert_i32_u", 184, "i32", "f64");
3473
+ promote_f32 = UNARY_OP("promote_f32", 187, "f32", "f64");
3474
+ };
3389
3475
  function VECTOR_OP(op, vopcode, inTypes, outType) {
3390
3476
  return function() {
3391
3477
  for (const inType of inTypes.toReversed()) {
@@ -3638,10 +3724,10 @@ function codegenWasm(kernel) {
3638
3724
  cg.local.get(acc);
3639
3725
  if (re.dtype === DType.Bool) cg.i32.and();
3640
3726
  else dty(cg, re.op, re.dtype).mul();
3641
- } else if (re.op === AluOp.Min || re.op === AluOp.Max) if (re.dtype === DType.Float32) {
3727
+ } else if (re.op === AluOp.Min || re.op === AluOp.Max) if (isFloatDtype(re.dtype)) {
3642
3728
  cg.local.get(acc);
3643
- if (re.op === AluOp.Min) cg.f32.min();
3644
- else cg.f32.max();
3729
+ if (re.op === AluOp.Min) dtyF(cg, re.op, re.dtype).min();
3730
+ else dtyF(cg, re.op, re.dtype).max();
3645
3731
  } else if ([
3646
3732
  DType.Int32,
3647
3733
  DType.Uint32,
@@ -3703,27 +3789,30 @@ function translateExp(cg, funcs, exp, ctx) {
3703
3789
  else if (op === AluOp.Sub) dty(cg, op, dtype).sub();
3704
3790
  else if (op === AluOp.Mul) if (dtype === DType.Bool) cg.i32.and();
3705
3791
  else dty(cg, op, dtype).mul();
3706
- else if (op === AluOp.Idiv) if (dtype === DType.Float32) cg.f32.div(), cg.f32.trunc();
3707
- else if (dtype === DType.Uint32) cg.i32.div_u();
3792
+ else if (op === AluOp.Idiv) if (isFloatDtype(dtype)) {
3793
+ dtyF(cg, op, dtype).div();
3794
+ dtyF(cg, op, dtype).trunc();
3795
+ } else if (dtype === DType.Uint32) cg.i32.div_u();
3708
3796
  else if (dtype === DType.Int32) cg.i32.div_s();
3709
3797
  else throw new UnsupportedOpError(op, dtype, "wasm");
3710
- else if (op === AluOp.Mod) if (dtype === DType.Float32) {
3711
- const a = cg.local.declare(cg.f32);
3712
- const b = cg.local.declare(cg.f32);
3798
+ else if (op === AluOp.Mod) if (isFloatDtype(dtype)) {
3799
+ const dt = dtyF(cg, op, dtype);
3800
+ const a = cg.local.declare(dt);
3801
+ const b = cg.local.declare(dt);
3713
3802
  cg.local.set(b);
3714
3803
  cg.local.tee(a);
3715
3804
  cg.local.get(a);
3716
3805
  cg.local.get(b);
3717
- cg.f32.div();
3718
- cg.f32.trunc();
3806
+ dt.div();
3807
+ dt.trunc();
3719
3808
  cg.local.get(b);
3720
- cg.f32.mul();
3721
- cg.f32.sub();
3809
+ dt.mul();
3810
+ dt.sub();
3722
3811
  } else if (dtype === DType.Uint32) cg.i32.rem_u();
3723
3812
  else if (dtype === DType.Int32) cg.i32.rem_s();
3724
3813
  else throw new UnsupportedOpError(op, dtype, "wasm");
3725
- else if (op === AluOp.Min || op === AluOp.Max) if (dtype === DType.Float32) if (op === AluOp.Min) cg.f32.min();
3726
- else cg.f32.max();
3814
+ else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
3815
+ else dtyF(cg, op, dtype).max();
3727
3816
  else if (dtype === DType.Int32 || dtype === DType.Uint32) {
3728
3817
  const a = cg.local.declare(cg.i32);
3729
3818
  const b = cg.local.declare(cg.i32);
@@ -3740,54 +3829,76 @@ function translateExp(cg, funcs, exp, ctx) {
3740
3829
  } else throw new UnsupportedOpError(op, dtype, "wasm");
3741
3830
  else if (op === AluOp.Cmplt) {
3742
3831
  const srcDtype = src[0].dtype;
3743
- if (srcDtype === DType.Float32) cg.f32.lt();
3832
+ if (isFloatDtype(srcDtype)) dtyF(cg, op, srcDtype).lt();
3744
3833
  else if (srcDtype === DType.Int32) cg.i32.lt_s();
3745
3834
  else if (srcDtype === DType.Uint32) cg.i32.lt_u();
3746
3835
  else throw new UnsupportedOpError(op, dtype, "wasm");
3747
3836
  } else if (op === AluOp.Cmpne) dty(cg, op, src[0].dtype).ne();
3748
3837
  else throw new UnsupportedOpError(op, dtype, "wasm");
3749
- } else if (AluGroup.Unary.has(op)) if (op === AluOp.Sin) gen(src[0]), cg.call(funcs.sin);
3750
- else if (op === AluOp.Cos) gen(src[0]), cg.call(funcs.cos);
3751
- else if (op === AluOp.Asin) gen(src[0]), cg.call(funcs.asin);
3752
- else if (op === AluOp.Atan) gen(src[0]), cg.call(funcs.atan);
3753
- else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
3754
- else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
3755
- else if (op === AluOp.Erf) gen(src[0]), cg.call(funcs.erf);
3756
- else if (op === AluOp.Erfc) gen(src[0]), cg.call(funcs.erfc);
3757
- else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
3758
- else if (op === AluOp.Reciprocal) cg.f32.const(1), gen(src[0]), cg.f32.div();
3759
- else if (op === AluOp.Cast) {
3760
- gen(src[0]);
3761
- const dtype0 = src[0].dtype;
3762
- const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
3763
- if (dtype === DType.Int32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_s();
3764
- else if (i32repr);
3765
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3766
- else if (dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_u();
3767
- else if (i32repr);
3768
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3769
- else if (dtype === DType.Float32) if (dtype0 === DType.Float32);
3770
- else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32.convert_i32_s();
3771
- else if (dtype0 === DType.Uint32) cg.f32.convert_i32_u();
3772
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3773
- else if (dtype === DType.Bool) if (dtype0 === DType.Bool);
3774
- else if (i32repr) cg.i32.const(0), cg.i32.ne();
3775
- else if (dtype0 === DType.Float32) cg.f32.const(0), cg.f32.ne();
3776
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3777
- else throw new UnsupportedOpError(op, dtype, "wasm");
3778
- } else if (op === AluOp.Bitcast) {
3779
- gen(src[0]);
3780
- const dtype0 = src[0].dtype;
3781
- const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32;
3782
- if (dtype === DType.Int32 || dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.reinterpret_f32();
3783
- else if (i32repr);
3784
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3785
- else if (dtype === DType.Float32) if (i32repr) cg.f32.reinterpret_i32();
3786
- else if (dtype0 === DType.Float32);
3787
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3788
- else throw new UnsupportedOpError(op, dtype, "wasm");
3789
- } else throw new UnsupportedOpError(op, dtype, "wasm");
3790
- else if (op === AluOp.Where) {
3838
+ } else if (AluGroup.Unary.has(op)) {
3839
+ const callFuncF32 = (func) => {
3840
+ if (dtype !== DType.Float32) if (dtype === DType.Float64) cg.f32.demote_f64();
3841
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3842
+ cg.call(func);
3843
+ if (dtype === DType.Float64) cg.f64.promote_f32();
3844
+ };
3845
+ if (op === AluOp.Sin) gen(src[0]), callFuncF32(funcs.sin);
3846
+ else if (op === AluOp.Cos) gen(src[0]), callFuncF32(funcs.cos);
3847
+ else if (op === AluOp.Asin) gen(src[0]), callFuncF32(funcs.asin);
3848
+ else if (op === AluOp.Atan) gen(src[0]), callFuncF32(funcs.atan);
3849
+ else if (op === AluOp.Exp) gen(src[0]), callFuncF32(funcs.exp);
3850
+ else if (op === AluOp.Log) gen(src[0]), callFuncF32(funcs.log);
3851
+ else if (op === AluOp.Erf) gen(src[0]), callFuncF32(funcs.erf);
3852
+ else if (op === AluOp.Erfc) gen(src[0]), callFuncF32(funcs.erfc);
3853
+ else if (op === AluOp.Sqrt) gen(src[0]), dtyF(cg, op, dtype).sqrt();
3854
+ else if (op === AluOp.Reciprocal) {
3855
+ const dt = dtyF(cg, op, dtype);
3856
+ dt.const(1), gen(src[0]), dt.div();
3857
+ } else if (op === AluOp.Floor) gen(src[0]), dtyF(cg, op, dtype).floor();
3858
+ else if (op === AluOp.Ceil) gen(src[0]), dtyF(cg, op, dtype).ceil();
3859
+ else if (op === AluOp.Cast) {
3860
+ gen(src[0]);
3861
+ const dtype0 = src[0].dtype;
3862
+ const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
3863
+ if (dtype === DType.Int32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_s();
3864
+ else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_s();
3865
+ else if (i32repr);
3866
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3867
+ else if (dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_u();
3868
+ else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_u();
3869
+ else if (i32repr);
3870
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3871
+ else if (dtype === DType.Float32) if (dtype0 === DType.Float32);
3872
+ else if (dtype0 === DType.Float64) cg.f32.demote_f64();
3873
+ else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32.convert_i32_s();
3874
+ else if (dtype0 === DType.Uint32) cg.f32.convert_i32_u();
3875
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3876
+ else if (dtype === DType.Float64) if (dtype0 === DType.Float32) cg.f64.promote_f32();
3877
+ else if (dtype0 === DType.Float64);
3878
+ else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f64.convert_i32_s();
3879
+ else if (dtype0 === DType.Uint32) cg.f64.convert_i32_u();
3880
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3881
+ else if (dtype === DType.Bool) if (dtype0 === DType.Bool);
3882
+ else if (i32repr) cg.i32.const(0), cg.i32.ne();
3883
+ else if (dtype0 === DType.Float32) cg.f32.const(0), cg.f32.ne();
3884
+ else if (dtype0 === DType.Float64) cg.f64.const(0), cg.f64.ne();
3885
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3886
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3887
+ } else if (op === AluOp.Bitcast) {
3888
+ gen(src[0]);
3889
+ const dtype0 = src[0].dtype;
3890
+ if (dtype !== dtype0) {
3891
+ const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32;
3892
+ if (dtype === DType.Int32 || dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.reinterpret_f32();
3893
+ else if (i32repr);
3894
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3895
+ else if (dtype === DType.Float32) if (i32repr) cg.f32.reinterpret_i32();
3896
+ else if (dtype0 === DType.Float32);
3897
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3898
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3899
+ }
3900
+ } else throw new UnsupportedOpError(op, dtype, "wasm");
3901
+ } else if (op === AluOp.Where) {
3791
3902
  gen(src[1]);
3792
3903
  gen(src[2]);
3793
3904
  gen(src[0]);
@@ -3832,12 +3943,20 @@ function translateExp(cg, funcs, exp, ctx) {
3832
3943
  function dty(cg, op, dtype) {
3833
3944
  switch (dtype) {
3834
3945
  case DType.Float32: return cg.f32;
3946
+ case DType.Float64: return cg.f64;
3835
3947
  case DType.Int32:
3836
3948
  case DType.Uint32:
3837
3949
  case DType.Bool: return cg.i32;
3838
3950
  default: throw new UnsupportedOpError(op, dtype, "wasm");
3839
3951
  }
3840
3952
  }
3953
+ function dtyF(cg, op, dtype) {
3954
+ switch (dtype) {
3955
+ case DType.Float32: return cg.f32;
3956
+ case DType.Float64: return cg.f64;
3957
+ default: throw new UnsupportedOpError(op, dtype, "wasm");
3958
+ }
3959
+ }
3841
3960
 
3842
3961
  //#endregion
3843
3962
  //#region src/backend.ts
@@ -3846,10 +3965,10 @@ const devices = [
3846
3965
  "wasm",
3847
3966
  "webgpu"
3848
3967
  ];
3849
- let defaultBackend = "wasm";
3850
3968
  const initializedBackends = /* @__PURE__ */ new Map();
3851
3969
  initializedBackends.set("cpu", new CpuBackend());
3852
- initializedBackends.set("wasm", new WasmBackend());
3970
+ if (typeof WebAssembly !== "undefined") initializedBackends.set("wasm", new WasmBackend());
3971
+ let defaultBackend = initializedBackends.has("wasm") ? "wasm" : "cpu";
3853
3972
  /** Configure the default device for arrays. */
3854
3973
  function defaultDevice(device) {
3855
3974
  if (device !== void 0) if (initializedBackends.has(device)) defaultBackend = device;
@@ -3876,12 +3995,14 @@ async function init(...devicesToInit) {
3876
3995
  /** Create a backend, if available. Internal function called by `init()`. */
3877
3996
  async function createBackend(device) {
3878
3997
  if (device === "cpu") return new CpuBackend();
3879
- else if (device === "wasm") return new WasmBackend();
3880
- else if (device === "webgpu") {
3998
+ else if (device === "wasm") {
3999
+ if (typeof WebAssembly === "undefined") return null;
4000
+ return new WasmBackend();
4001
+ } else if (device === "webgpu") {
3881
4002
  if (!navigator.gpu) return null;
3882
4003
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
3883
4004
  if (!adapter) return null;
3884
- const { WebGPUBackend } = await import("./webgpu-LGi2A3mS.js");
4005
+ const { WebGPUBackend } = await import("./webgpu-BGuG58KZ.js");
3885
4006
  const importantLimits = [
3886
4007
  "maxBufferSize",
3887
4008
  "maxComputeInvocationsPerWorkgroup",
@@ -3935,4 +4056,4 @@ var UnsupportedOpError = class extends Error {
3935
4056
 
3936
4057
  //#endregion
3937
4058
  export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
3938
- //# sourceMappingURL=backend-DwIAd0AG.js.map
4059
+ //# sourceMappingURL=backend-BqymqzuU.js.map