@jax-js/jax 0.1.0 → 0.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -290,10 +290,11 @@ var FpHash = class FpHash {
290
290
  };
291
291
  /** Run a function while caching it inline inside a `Map`. */
292
292
  function runWithCache(cache, key, thunk) {
293
- if (cache.has(key)) return cache.get(key);
293
+ const keyStr = JSON.stringify(key);
294
+ if (cache.has(keyStr)) return cache.get(keyStr);
294
295
  else {
295
296
  const value = thunk();
296
- cache.set(key, value);
297
+ cache.set(keyStr, value);
297
298
  return value;
298
299
  }
299
300
  }
@@ -307,6 +308,7 @@ let DType = /* @__PURE__ */ function(DType$1) {
307
308
  DType$1["Uint32"] = "uint32";
308
309
  DType$1["Bool"] = "bool";
309
310
  DType$1["Float16"] = "float16";
311
+ DType$1["Float64"] = "float64";
310
312
  return DType$1;
311
313
  }({});
312
314
  const byteWidth = (dtype) => {
@@ -316,10 +318,11 @@ const byteWidth = (dtype) => {
316
318
  case DType.Uint32:
317
319
  case DType.Bool: return 4;
318
320
  case DType.Float16: return 2;
321
+ case DType.Float64: return 8;
319
322
  default: throw new TypeError(`Unknown dtype: ${dtype}`);
320
323
  }
321
324
  };
322
/** Whether `dtype` is a floating-point dtype (float16, float32 or float64). */
const isFloatDtype = (dtype) => {
  switch (dtype) {
    case DType.Float16:
    case DType.Float32:
    case DType.Float64:
      return true;
    default:
      return false;
  }
};
323
326
  /**
324
327
  * Promote two dtypes to their join according to the type lattice.
325
328
  *
@@ -329,7 +332,7 @@ const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float
329
332
  *
330
333
  * **Type lattice:**
331
334
  * ```text
332
- * bool -> uint32 -> int32 -> float16 -> float32
335
+ * bool -> uint32 -> int32 -> float16 -> float32 -> float64
333
336
  * weakType --^
334
337
  * ```
335
338
  *
@@ -351,7 +354,8 @@ function promoteTypes(dtype1, dtype2) {
351
354
  [DType.Uint32]: 1,
352
355
  [DType.Int32]: 2,
353
356
  [DType.Float16]: 3,
354
- [DType.Float32]: 4
357
+ [DType.Float32]: 4,
358
+ [DType.Float64]: 5
355
359
  };
356
360
  return rank[dtype1] > rank[dtype2] ? dtype1 : dtype2;
357
361
  }
@@ -364,6 +368,7 @@ function dtypedArray(dtype, data) {
364
368
  case DType.Bool: return new Int32Array(buffer, byteOffset, length);
365
369
  case DType.Uint32: return new Uint32Array(buffer, byteOffset, length);
366
370
  case DType.Float16: return new Float16Array(buffer, byteOffset, length);
371
+ case DType.Float64: return new Float64Array(buffer, byteOffset, length);
367
372
  default: throw new Error(`Unimplemented dtype: ${dtype}`);
368
373
  }
369
374
  }
@@ -374,6 +379,7 @@ function dtypedJsArray(dtype, data) {
374
379
  case DType.Bool: return new Int32Array(data);
375
380
  case DType.Uint32: return new Uint32Array(data);
376
381
  case DType.Float16: return new Float16Array(data);
382
+ case DType.Float64: return new Float64Array(data);
377
383
  default: throw new Error(`Unimplemented dtype: ${dtype}`);
378
384
  }
379
385
  }
@@ -445,6 +451,14 @@ var AluExp = class AluExp {
445
451
  static sqrt(a) {
446
452
  return new AluExp(AluOp.Sqrt, a.dtype, [a]);
447
453
  }
454
+ static floor(a) {
455
+ if (!isFloatDtype(a.dtype)) return a;
456
+ return new AluExp(AluOp.Floor, a.dtype, [a]);
457
+ }
458
+ static ceil(a) {
459
+ if (!isFloatDtype(a.dtype)) return a;
460
+ return new AluExp(AluOp.Ceil, a.dtype, [a]);
461
+ }
448
462
  static reciprocal(a) {
449
463
  return new AluExp(AluOp.Reciprocal, a.dtype, [a]);
450
464
  }
@@ -511,6 +525,9 @@ var AluExp = class AluExp {
511
525
  static f16(value) {
512
526
  return AluExp.const(DType.Float16, value);
513
527
  }
528
+ static f64(value) {
529
+ return AluExp.const(DType.Float64, value);
530
+ }
514
531
  not() {
515
532
  if (this.dtype !== DType.Bool) throw new Error("not() can only be called on boolean expressions");
516
533
  return AluExp.cmpne(this, AluExp.const(DType.Bool, true));
@@ -521,7 +538,8 @@ var AluExp = class AluExp {
521
538
  const hasher = new FpHash();
522
539
  hasher.update(this.op);
523
540
  hasher.update(this.dtype);
524
- hasher.update(JSON.stringify(this.arg));
541
+ if (this.op === AluOp.Const) hasher.update(this.arg);
542
+ else hasher.update(JSON.stringify(this.arg));
525
543
  hasher.update(this.src.length);
526
544
  for (const s of this.src) hasher.update(s);
527
545
  this.#hash = hasher.value;
@@ -621,6 +639,12 @@ var AluExp = class AluExp {
621
639
  case AluOp.Sqrt:
622
640
  ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
623
641
  break;
642
+ case AluOp.Floor:
643
+ ret = [Math.floor(src[0].min), Math.floor(src[0].max)];
644
+ break;
645
+ case AluOp.Ceil:
646
+ ret = [Math.ceil(src[0].min), Math.ceil(src[0].max)];
647
+ break;
624
648
  case AluOp.Reciprocal:
625
649
  if (src[0].min <= 0 && src[0].max >= 0) return [-Infinity, Infinity];
626
650
  ret = [1 / src[0].max, 1 / src[0].min];
@@ -758,6 +782,7 @@ var AluExp = class AluExp {
758
782
  if (op === AluOp.Mul && x === 1) return src[1 - i];
759
783
  if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
760
784
  if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
785
+ if (op === AluOp.Cmpne && src[i].dtype === DType.Bool && x === 0) return src[1 - i];
761
786
  }
762
787
  if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
763
788
  const [a, b] = src[1].src;
@@ -852,7 +877,7 @@ var AluExp = class AluExp {
852
877
  else return p(p(src[0].src[0], src[1]), src[0].src[1]).simplify(cache);
853
878
  if (src[1].op === op && src[1].src[1].op === AluOp.Const) return p(p(src[0], src[1].src[0]), src[1].src[1]).simplify(cache);
854
879
  }
855
- if (op === AluOp.Mod || op === AluOp.Idiv && src[1].#isConstInt()) {
880
+ if ((op === AluOp.Mod || op === AluOp.Idiv) && src[1].#isConstInt()) {
856
881
  const [x, y] = src;
857
882
  {
858
883
  const factors = [];
@@ -942,6 +967,8 @@ var AluExp = class AluExp {
942
967
  case AluOp.Erf: return erf(x);
943
968
  case AluOp.Erfc: return erfc(x);
944
969
  case AluOp.Sqrt: return Math.sqrt(x);
970
+ case AluOp.Floor: return Math.floor(x);
971
+ case AluOp.Ceil: return Math.ceil(x);
945
972
  case AluOp.Reciprocal: return 1 / x;
946
973
  case AluOp.Cast: {
947
974
  const wasFloat = isFloatDtype(this.src[0].dtype);
@@ -959,11 +986,13 @@ var AluExp = class AluExp {
959
986
  else if (fromType === DType.Int32) view.setInt32(0, x, true);
960
987
  else if (fromType === DType.Uint32) view.setUint32(0, x, true);
961
988
  else if (fromType === DType.Float16) view.setFloat16(0, x, true);
989
+ else if (fromType === DType.Float64) view.setFloat64(0, x, true);
962
990
  else throw new Error(`Unsupported bitcast from ${fromType}`);
963
991
  if (this.dtype === DType.Float32) return view.getFloat32(0, true);
964
992
  else if (this.dtype === DType.Int32) return view.getInt32(0, true);
965
993
  else if (this.dtype === DType.Uint32) return view.getUint32(0, true);
966
994
  else if (this.dtype === DType.Float16) return view.getFloat16(0, true);
995
+ else if (this.dtype === DType.Float64) return view.getFloat64(0, true);
967
996
  else throw new Error(`Unsupported bitcast to ${this.dtype}`);
968
997
  }
969
998
  default: throw new Error(`Missing implemementation for ${this.op}`);
@@ -1129,6 +1158,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
1129
1158
  AluOp$1["Erf"] = "Erf";
1130
1159
  AluOp$1["Erfc"] = "Erfc";
1131
1160
  AluOp$1["Sqrt"] = "Sqrt";
1161
+ AluOp$1["Floor"] = "Floor";
1162
+ AluOp$1["Ceil"] = "Ceil";
1132
1163
  AluOp$1["Reciprocal"] = "Reciprocal";
1133
1164
  AluOp$1["Cast"] = "Cast";
1134
1165
  AluOp$1["Bitcast"] = "Bitcast";
@@ -1163,6 +1194,8 @@ const AluGroup = {
1163
1194
  AluOp.Erf,
1164
1195
  AluOp.Erfc,
1165
1196
  AluOp.Sqrt,
1197
+ AluOp.Floor,
1198
+ AluOp.Ceil,
1166
1199
  AluOp.Reciprocal,
1167
1200
  AluOp.Cast,
1168
1201
  AluOp.Bitcast
@@ -1190,7 +1223,9 @@ const AluGroup = {
1190
1223
  AluOp.Erf,
1191
1224
  AluOp.Erfc,
1192
1225
  AluOp.Sqrt,
1193
- AluOp.Reciprocal
1226
+ AluOp.Reciprocal,
1227
+ AluOp.Floor,
1228
+ AluOp.Ceil
1194
1229
  ])
1195
1230
  };
1196
1231
  /** Common variables that can be substituted in expressions. */
@@ -2926,6 +2961,7 @@ var CodeGenerator = class {
2926
2961
  local;
2927
2962
  i32;
2928
2963
  f32;
2964
+ f64;
2929
2965
  v128;
2930
2966
  i32x4;
2931
2967
  f32x4;
@@ -2945,6 +2981,7 @@ var CodeGenerator = class {
2945
2981
  this.local = new Local(this);
2946
2982
  this.i32 = new I32(this);
2947
2983
  this.f32 = new F32(this);
2984
+ this.f64 = new F64(this);
2948
2985
  this.v128 = new V128(this);
2949
2986
  this.i32x4 = new I32x4(this);
2950
2987
  this.f32x4 = new F32x4(this);
@@ -3331,6 +3368,8 @@ var I32 = class {
3331
3368
  ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
3332
3369
  trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
3333
3370
  trunc_f32_u = UNARY_OP("trunc_f32_u", 169, "f32", "i32");
3371
+ trunc_f64_s = UNARY_OP("trunc_f64_s", 170, "f64", "i32");
3372
+ trunc_f64_u = UNARY_OP("trunc_f64_u", 171, "f64", "i32");
3334
3373
  load = LOAD_OP("load", 40, "i32");
3335
3374
  load8_s = LOAD_OP("load8_s", 44, "i32");
3336
3375
  load8_u = LOAD_OP("load8_u", 45, "i32");
@@ -3342,6 +3381,8 @@ var I32 = class {
3342
3381
  reinterpret_f32 = UNARY_OP("reinterpret_f32", 188, "f32", "i32");
3343
3382
  trunc_sat_f32_s = UNARY_OP("trunc_sat_f32_s", [252, 0], "f32", "i32");
3344
3383
  trunc_sat_f32_u = UNARY_OP("trunc_sat_f32_u", [252, 1], "f32", "i32");
3384
+ trunc_sat_f64_s = UNARY_OP("trunc_sat_f64_s", [252, 2], "f64", "i32");
3385
+ trunc_sat_f64_u = UNARY_OP("trunc_sat_f64_u", [252, 3], "f64", "i32");
3345
3386
  };
3346
3387
  var F32 = class {
3347
3388
  constructor(cg) {
@@ -3361,6 +3402,8 @@ var F32 = class {
3361
3402
  for (let i = 0; i < 4; i++) this.cg._emit(bytes[i]);
3362
3403
  this.cg._push(this);
3363
3404
  }
3405
+ load = LOAD_OP("load", 42, "f32");
3406
+ store = STORE_OP("store", 56, "f32");
3364
3407
  eq = BINARY_OP("eq", 91, "f32", "f32", "i32");
3365
3408
  ne = BINARY_OP("ne", 92, "f32", "f32", "i32");
3366
3409
  lt = BINARY_OP("lt", 93, "f32", "f32", "i32");
@@ -3383,10 +3426,53 @@ var F32 = class {
3383
3426
  copysign = BINARY_OP("copysign", 152, "f32", "f32", "f32");
3384
3427
  convert_i32_s = UNARY_OP("convert_i32_s", 178, "i32", "f32");
3385
3428
  convert_i32_u = UNARY_OP("convert_i32_u", 179, "i32", "f32");
3386
- load = LOAD_OP("load", 42, "f32");
3387
- store = STORE_OP("store", 56, "f32");
3429
+ demote_f64 = UNARY_OP("demote_f64", 182, "f64", "f32");
3388
3430
  reinterpret_i32 = UNARY_OP("reinterpret_i32", 190, "i32", "f32");
3389
3431
  };
3432
/**
 * Emitter for WebAssembly `f64` instructions.
 *
 * `typeId`/`name` identify the value type, `const` writes an `f64.const`
 * instruction with its immediate, and each field is an instruction helper
 * produced by the shared LOAD_OP / STORE_OP / UNARY_OP / BINARY_OP factories.
 * Opcode numbers are written in hex to match the WebAssembly binary-format
 * tables.
 */
var F64 = class {
  constructor(cg) {
    this.cg = cg;
  }

  /** Value-type byte for `f64`. */
  get typeId() {
    return 0x7c;
  }

  get name() {
    return "f64";
  }

  /** Emit `f64.const f`: opcode 0x44 followed by 8 little-endian bytes of `f`. */
  const(f) {
    this.cg._emit(0x44);
    const view = new DataView(new ArrayBuffer(8));
    view.setFloat64(0, f, true);
    for (let offset = 0; offset < 8; offset++) this.cg._emit(view.getUint8(offset));
    this.cg._push(this);
  }

  // Memory access.
  load = LOAD_OP("load", 0x2b, "f64");
  store = STORE_OP("store", 0x39, "f64");

  // Comparisons (f64, f64) -> i32.
  eq = BINARY_OP("eq", 0x61, "f64", "f64", "i32");
  ne = BINARY_OP("ne", 0x62, "f64", "f64", "i32");
  lt = BINARY_OP("lt", 0x63, "f64", "f64", "i32");
  gt = BINARY_OP("gt", 0x64, "f64", "f64", "i32");
  le = BINARY_OP("le", 0x65, "f64", "f64", "i32");
  ge = BINARY_OP("ge", 0x66, "f64", "f64", "i32");

  // Unary arithmetic f64 -> f64.
  abs = UNARY_OP("abs", 0x99, "f64", "f64");
  neg = UNARY_OP("neg", 0x9a, "f64", "f64");
  ceil = UNARY_OP("ceil", 0x9b, "f64", "f64");
  floor = UNARY_OP("floor", 0x9c, "f64", "f64");
  trunc = UNARY_OP("trunc", 0x9d, "f64", "f64");
  nearest = UNARY_OP("nearest", 0x9e, "f64", "f64");
  sqrt = UNARY_OP("sqrt", 0x9f, "f64", "f64");

  // Binary arithmetic (f64, f64) -> f64.
  add = BINARY_OP("add", 0xa0, "f64", "f64", "f64");
  sub = BINARY_OP("sub", 0xa1, "f64", "f64", "f64");
  mul = BINARY_OP("mul", 0xa2, "f64", "f64", "f64");
  div = BINARY_OP("div", 0xa3, "f64", "f64", "f64");
  min = BINARY_OP("min", 0xa4, "f64", "f64", "f64");
  max = BINARY_OP("max", 0xa5, "f64", "f64", "f64");
  copysign = BINARY_OP("copysign", 0xa6, "f64", "f64", "f64");

  // Conversions into f64.
  convert_i32_s = UNARY_OP("convert_i32_s", 0xb7, "i32", "f64");
  convert_i32_u = UNARY_OP("convert_i32_u", 0xb8, "i32", "f64");
  promote_f32 = UNARY_OP("promote_f32", 0xbb, "f32", "f64");
};
3390
3476
  function VECTOR_OP(op, vopcode, inTypes, outType) {
3391
3477
  return function() {
3392
3478
  for (const inType of inTypes.toReversed()) {
@@ -3639,10 +3725,10 @@ function codegenWasm(kernel) {
3639
3725
  cg.local.get(acc);
3640
3726
  if (re.dtype === DType.Bool) cg.i32.and();
3641
3727
  else dty(cg, re.op, re.dtype).mul();
3642
- } else if (re.op === AluOp.Min || re.op === AluOp.Max) if (re.dtype === DType.Float32) {
3728
+ } else if (re.op === AluOp.Min || re.op === AluOp.Max) if (isFloatDtype(re.dtype)) {
3643
3729
  cg.local.get(acc);
3644
- if (re.op === AluOp.Min) cg.f32.min();
3645
- else cg.f32.max();
3730
+ if (re.op === AluOp.Min) dtyF(cg, re.op, re.dtype).min();
3731
+ else dtyF(cg, re.op, re.dtype).max();
3646
3732
  } else if ([
3647
3733
  DType.Int32,
3648
3734
  DType.Uint32,
@@ -3704,27 +3790,30 @@ function translateExp(cg, funcs, exp, ctx) {
3704
3790
  else if (op === AluOp.Sub) dty(cg, op, dtype).sub();
3705
3791
  else if (op === AluOp.Mul) if (dtype === DType.Bool) cg.i32.and();
3706
3792
  else dty(cg, op, dtype).mul();
3707
- else if (op === AluOp.Idiv) if (dtype === DType.Float32) cg.f32.div(), cg.f32.trunc();
3708
- else if (dtype === DType.Uint32) cg.i32.div_u();
3793
+ else if (op === AluOp.Idiv) if (isFloatDtype(dtype)) {
3794
+ dtyF(cg, op, dtype).div();
3795
+ dtyF(cg, op, dtype).trunc();
3796
+ } else if (dtype === DType.Uint32) cg.i32.div_u();
3709
3797
  else if (dtype === DType.Int32) cg.i32.div_s();
3710
3798
  else throw new UnsupportedOpError(op, dtype, "wasm");
3711
- else if (op === AluOp.Mod) if (dtype === DType.Float32) {
3712
- const a = cg.local.declare(cg.f32);
3713
- const b = cg.local.declare(cg.f32);
3799
+ else if (op === AluOp.Mod) if (isFloatDtype(dtype)) {
3800
+ const dt = dtyF(cg, op, dtype);
3801
+ const a = cg.local.declare(dt);
3802
+ const b = cg.local.declare(dt);
3714
3803
  cg.local.set(b);
3715
3804
  cg.local.tee(a);
3716
3805
  cg.local.get(a);
3717
3806
  cg.local.get(b);
3718
- cg.f32.div();
3719
- cg.f32.trunc();
3807
+ dt.div();
3808
+ dt.trunc();
3720
3809
  cg.local.get(b);
3721
- cg.f32.mul();
3722
- cg.f32.sub();
3810
+ dt.mul();
3811
+ dt.sub();
3723
3812
  } else if (dtype === DType.Uint32) cg.i32.rem_u();
3724
3813
  else if (dtype === DType.Int32) cg.i32.rem_s();
3725
3814
  else throw new UnsupportedOpError(op, dtype, "wasm");
3726
- else if (op === AluOp.Min || op === AluOp.Max) if (dtype === DType.Float32) if (op === AluOp.Min) cg.f32.min();
3727
- else cg.f32.max();
3815
+ else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
3816
+ else dtyF(cg, op, dtype).max();
3728
3817
  else if (dtype === DType.Int32 || dtype === DType.Uint32) {
3729
3818
  const a = cg.local.declare(cg.i32);
3730
3819
  const b = cg.local.declare(cg.i32);
@@ -3741,54 +3830,76 @@ function translateExp(cg, funcs, exp, ctx) {
3741
3830
  } else throw new UnsupportedOpError(op, dtype, "wasm");
3742
3831
  else if (op === AluOp.Cmplt) {
3743
3832
  const srcDtype = src[0].dtype;
3744
- if (srcDtype === DType.Float32) cg.f32.lt();
3833
+ if (isFloatDtype(srcDtype)) dtyF(cg, op, srcDtype).lt();
3745
3834
  else if (srcDtype === DType.Int32) cg.i32.lt_s();
3746
3835
  else if (srcDtype === DType.Uint32) cg.i32.lt_u();
3747
3836
  else throw new UnsupportedOpError(op, dtype, "wasm");
3748
3837
  } else if (op === AluOp.Cmpne) dty(cg, op, src[0].dtype).ne();
3749
3838
  else throw new UnsupportedOpError(op, dtype, "wasm");
3750
- } else if (AluGroup.Unary.has(op)) if (op === AluOp.Sin) gen(src[0]), cg.call(funcs.sin);
3751
- else if (op === AluOp.Cos) gen(src[0]), cg.call(funcs.cos);
3752
- else if (op === AluOp.Asin) gen(src[0]), cg.call(funcs.asin);
3753
- else if (op === AluOp.Atan) gen(src[0]), cg.call(funcs.atan);
3754
- else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
3755
- else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
3756
- else if (op === AluOp.Erf) gen(src[0]), cg.call(funcs.erf);
3757
- else if (op === AluOp.Erfc) gen(src[0]), cg.call(funcs.erfc);
3758
- else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
3759
- else if (op === AluOp.Reciprocal) cg.f32.const(1), gen(src[0]), cg.f32.div();
3760
- else if (op === AluOp.Cast) {
3761
- gen(src[0]);
3762
- const dtype0 = src[0].dtype;
3763
- const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
3764
- if (dtype === DType.Int32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_s();
3765
- else if (i32repr);
3766
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3767
- else if (dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_u();
3768
- else if (i32repr);
3769
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3770
- else if (dtype === DType.Float32) if (dtype0 === DType.Float32);
3771
- else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32.convert_i32_s();
3772
- else if (dtype0 === DType.Uint32) cg.f32.convert_i32_u();
3773
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3774
- else if (dtype === DType.Bool) if (dtype0 === DType.Bool);
3775
- else if (i32repr) cg.i32.const(0), cg.i32.ne();
3776
- else if (dtype0 === DType.Float32) cg.f32.const(0), cg.f32.ne();
3777
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3778
- else throw new UnsupportedOpError(op, dtype, "wasm");
3779
- } else if (op === AluOp.Bitcast) {
3780
- gen(src[0]);
3781
- const dtype0 = src[0].dtype;
3782
- const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32;
3783
- if (dtype === DType.Int32 || dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.reinterpret_f32();
3784
- else if (i32repr);
3785
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3786
- else if (dtype === DType.Float32) if (i32repr) cg.f32.reinterpret_i32();
3787
- else if (dtype0 === DType.Float32);
3788
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3789
- else throw new UnsupportedOpError(op, dtype, "wasm");
3790
- } else throw new UnsupportedOpError(op, dtype, "wasm");
3791
- else if (op === AluOp.Where) {
3839
+ } else if (AluGroup.Unary.has(op)) {
3840
+ const callFuncF32 = (func) => {
3841
+ if (dtype !== DType.Float32) if (dtype === DType.Float64) cg.f32.demote_f64();
3842
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3843
+ cg.call(func);
3844
+ if (dtype === DType.Float64) cg.f64.promote_f32();
3845
+ };
3846
+ if (op === AluOp.Sin) gen(src[0]), callFuncF32(funcs.sin);
3847
+ else if (op === AluOp.Cos) gen(src[0]), callFuncF32(funcs.cos);
3848
+ else if (op === AluOp.Asin) gen(src[0]), callFuncF32(funcs.asin);
3849
+ else if (op === AluOp.Atan) gen(src[0]), callFuncF32(funcs.atan);
3850
+ else if (op === AluOp.Exp) gen(src[0]), callFuncF32(funcs.exp);
3851
+ else if (op === AluOp.Log) gen(src[0]), callFuncF32(funcs.log);
3852
+ else if (op === AluOp.Erf) gen(src[0]), callFuncF32(funcs.erf);
3853
+ else if (op === AluOp.Erfc) gen(src[0]), callFuncF32(funcs.erfc);
3854
+ else if (op === AluOp.Sqrt) gen(src[0]), dtyF(cg, op, dtype).sqrt();
3855
+ else if (op === AluOp.Reciprocal) {
3856
+ const dt = dtyF(cg, op, dtype);
3857
+ dt.const(1), gen(src[0]), dt.div();
3858
+ } else if (op === AluOp.Floor) gen(src[0]), dtyF(cg, op, dtype).floor();
3859
+ else if (op === AluOp.Ceil) gen(src[0]), dtyF(cg, op, dtype).ceil();
3860
+ else if (op === AluOp.Cast) {
3861
+ gen(src[0]);
3862
+ const dtype0 = src[0].dtype;
3863
+ const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
3864
+ if (dtype === DType.Int32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_s();
3865
+ else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_s();
3866
+ else if (i32repr);
3867
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3868
+ else if (dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_u();
3869
+ else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_u();
3870
+ else if (i32repr);
3871
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3872
+ else if (dtype === DType.Float32) if (dtype0 === DType.Float32);
3873
+ else if (dtype0 === DType.Float64) cg.f32.demote_f64();
3874
+ else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32.convert_i32_s();
3875
+ else if (dtype0 === DType.Uint32) cg.f32.convert_i32_u();
3876
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3877
+ else if (dtype === DType.Float64) if (dtype0 === DType.Float32) cg.f64.promote_f32();
3878
+ else if (dtype0 === DType.Float64);
3879
+ else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f64.convert_i32_s();
3880
+ else if (dtype0 === DType.Uint32) cg.f64.convert_i32_u();
3881
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3882
+ else if (dtype === DType.Bool) if (dtype0 === DType.Bool);
3883
+ else if (i32repr) cg.i32.const(0), cg.i32.ne();
3884
+ else if (dtype0 === DType.Float32) cg.f32.const(0), cg.f32.ne();
3885
+ else if (dtype0 === DType.Float64) cg.f64.const(0), cg.f64.ne();
3886
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3887
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3888
+ } else if (op === AluOp.Bitcast) {
3889
+ gen(src[0]);
3890
+ const dtype0 = src[0].dtype;
3891
+ if (dtype !== dtype0) {
3892
+ const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32;
3893
+ if (dtype === DType.Int32 || dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.reinterpret_f32();
3894
+ else if (i32repr);
3895
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3896
+ else if (dtype === DType.Float32) if (i32repr) cg.f32.reinterpret_i32();
3897
+ else if (dtype0 === DType.Float32);
3898
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3899
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3900
+ }
3901
+ } else throw new UnsupportedOpError(op, dtype, "wasm");
3902
+ } else if (op === AluOp.Where) {
3792
3903
  gen(src[1]);
3793
3904
  gen(src[2]);
3794
3905
  gen(src[0]);
@@ -3833,12 +3944,20 @@ function translateExp(cg, funcs, exp, ctx) {
3833
3944
  function dty(cg, op, dtype) {
3834
3945
  switch (dtype) {
3835
3946
  case DType.Float32: return cg.f32;
3947
+ case DType.Float64: return cg.f64;
3836
3948
  case DType.Int32:
3837
3949
  case DType.Uint32:
3838
3950
  case DType.Bool: return cg.i32;
3839
3951
  default: throw new UnsupportedOpError(op, dtype, "wasm");
3840
3952
  }
3841
3953
  }
3954
/**
 * Select the floating-point code-generator namespace for `dtype`: `cg.f32`
 * for float32, `cg.f64` for float64.
 *
 * @throws {UnsupportedOpError} for any other dtype, including float16 and the
 *   integer/bool dtypes (reported against `op`).
 */
function dtyF(cg, op, dtype) {
  if (dtype === DType.Float32) return cg.f32;
  if (dtype === DType.Float64) return cg.f64;
  throw new UnsupportedOpError(op, dtype, "wasm");
}
3842
3961
 
3843
3962
  //#endregion
3844
3963
  //#region src/backend.ts
@@ -3847,10 +3966,10 @@ const devices = [
3847
3966
  "wasm",
3848
3967
  "webgpu"
3849
3968
  ];
3850
/** Backends constructed eagerly at module load, keyed by device name. */
const initializedBackends = /* @__PURE__ */ new Map([["cpu", new CpuBackend()]]);
// The `WebAssembly` global may be missing in some environments, so the wasm
// backend is only registered when it exists.
if (typeof WebAssembly !== "undefined") initializedBackends.set("wasm", new WasmBackend());
/** Device used when none is requested: wasm if it was registered, otherwise cpu. */
let defaultBackend = initializedBackends.has("wasm") ? "wasm" : "cpu";
3854
3973
  /** Configure the default device for arrays. */
3855
3974
  function defaultDevice(device) {
3856
3975
  if (device !== void 0) if (initializedBackends.has(device)) defaultBackend = device;
@@ -3877,12 +3996,14 @@ async function init(...devicesToInit) {
3877
3996
  /** Create a backend, if available. Internal function called by `init()`. */
3878
3997
  async function createBackend(device) {
3879
3998
  if (device === "cpu") return new CpuBackend();
3880
- else if (device === "wasm") return new WasmBackend();
3881
- else if (device === "webgpu") {
3999
+ else if (device === "wasm") {
4000
+ if (typeof WebAssembly === "undefined") return null;
4001
+ return new WasmBackend();
4002
+ } else if (device === "webgpu") {
3882
4003
  if (!navigator.gpu) return null;
3883
4004
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
3884
4005
  if (!adapter) return null;
3885
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BE7zA_01.cjs"));
4006
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-CcGP160M.cjs"));
3886
4007
  const importantLimits = [
3887
4008
  "maxBufferSize",
3888
4009
  "maxComputeInvocationsPerWorkgroup",
@@ -4229,4 +4350,4 @@ Object.defineProperty(exports, 'zipn', {
4229
4350
  return zipn;
4230
4351
  }
4231
4352
  });
4232
- //# sourceMappingURL=backend-FtkbO6pI.cjs.map
4353
+ //# sourceMappingURL=backend-DeVfWEFS.cjs.map