@jax-js/jax 0.1.0 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -306,6 +306,7 @@ let DType = /* @__PURE__ */ function(DType$1) {
306
306
  DType$1["Uint32"] = "uint32";
307
307
  DType$1["Bool"] = "bool";
308
308
  DType$1["Float16"] = "float16";
309
+ DType$1["Float64"] = "float64";
309
310
  return DType$1;
310
311
  }({});
311
312
  const byteWidth = (dtype) => {
@@ -315,10 +316,11 @@ const byteWidth = (dtype) => {
315
316
  case DType.Uint32:
316
317
  case DType.Bool: return 4;
317
318
  case DType.Float16: return 2;
319
+ case DType.Float64: return 8;
318
320
  default: throw new TypeError(`Unknown dtype: ${dtype}`);
319
321
  }
320
322
  };
321
- const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16;
323
+ const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16 || dtype === DType.Float64;
322
324
  /**
323
325
  * Promote two dtypes to their join according to the type lattice.
324
326
  *
@@ -328,7 +330,7 @@ const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float
328
330
  *
329
331
  * **Type lattice:**
330
332
  * ```text
331
- * bool -> uint32 -> int32 -> float16 -> float32
333
+ * bool -> uint32 -> int32 -> float16 -> float32 -> float64
332
334
  * weakType --^
333
335
  * ```
334
336
  *
@@ -350,7 +352,8 @@ function promoteTypes(dtype1, dtype2) {
350
352
  [DType.Uint32]: 1,
351
353
  [DType.Int32]: 2,
352
354
  [DType.Float16]: 3,
353
- [DType.Float32]: 4
355
+ [DType.Float32]: 4,
356
+ [DType.Float64]: 5
354
357
  };
355
358
  return rank[dtype1] > rank[dtype2] ? dtype1 : dtype2;
356
359
  }
@@ -363,6 +366,7 @@ function dtypedArray(dtype, data) {
363
366
  case DType.Bool: return new Int32Array(buffer, byteOffset, length);
364
367
  case DType.Uint32: return new Uint32Array(buffer, byteOffset, length);
365
368
  case DType.Float16: return new Float16Array(buffer, byteOffset, length);
369
+ case DType.Float64: return new Float64Array(buffer, byteOffset, length);
366
370
  default: throw new Error(`Unimplemented dtype: ${dtype}`);
367
371
  }
368
372
  }
@@ -373,6 +377,7 @@ function dtypedJsArray(dtype, data) {
373
377
  case DType.Bool: return new Int32Array(data);
374
378
  case DType.Uint32: return new Uint32Array(data);
375
379
  case DType.Float16: return new Float16Array(data);
380
+ case DType.Float64: return new Float64Array(data);
376
381
  default: throw new Error(`Unimplemented dtype: ${dtype}`);
377
382
  }
378
383
  }
@@ -510,6 +515,9 @@ var AluExp = class AluExp {
510
515
  static f16(value) {
511
516
  return AluExp.const(DType.Float16, value);
512
517
  }
518
+ static f64(value) {
519
+ return AluExp.const(DType.Float64, value);
520
+ }
513
521
  not() {
514
522
  if (this.dtype !== DType.Bool) throw new Error("not() can only be called on boolean expressions");
515
523
  return AluExp.cmpne(this, AluExp.const(DType.Bool, true));
@@ -520,7 +528,8 @@ var AluExp = class AluExp {
520
528
  const hasher = new FpHash();
521
529
  hasher.update(this.op);
522
530
  hasher.update(this.dtype);
523
- hasher.update(JSON.stringify(this.arg));
531
+ if (this.op === AluOp.Const) hasher.update(this.arg);
532
+ else hasher.update(JSON.stringify(this.arg));
524
533
  hasher.update(this.src.length);
525
534
  for (const s of this.src) hasher.update(s);
526
535
  this.#hash = hasher.value;
@@ -757,6 +766,7 @@ var AluExp = class AluExp {
757
766
  if (op === AluOp.Mul && x === 1) return src[1 - i];
758
767
  if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
759
768
  if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
769
+ if (op === AluOp.Cmpne && src[i].dtype === DType.Bool && x === 0) return src[1 - i];
760
770
  }
761
771
  if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
762
772
  const [a, b] = src[1].src;
@@ -958,11 +968,13 @@ var AluExp = class AluExp {
958
968
  else if (fromType === DType.Int32) view.setInt32(0, x, true);
959
969
  else if (fromType === DType.Uint32) view.setUint32(0, x, true);
960
970
  else if (fromType === DType.Float16) view.setFloat16(0, x, true);
971
+ else if (fromType === DType.Float64) view.setFloat64(0, x, true);
961
972
  else throw new Error(`Unsupported bitcast from ${fromType}`);
962
973
  if (this.dtype === DType.Float32) return view.getFloat32(0, true);
963
974
  else if (this.dtype === DType.Int32) return view.getInt32(0, true);
964
975
  else if (this.dtype === DType.Uint32) return view.getUint32(0, true);
965
976
  else if (this.dtype === DType.Float16) return view.getFloat16(0, true);
977
+ else if (this.dtype === DType.Float64) return view.getFloat64(0, true);
966
978
  else throw new Error(`Unsupported bitcast to ${this.dtype}`);
967
979
  }
968
980
  default: throw new Error(`Missing implemementation for ${this.op}`);
@@ -2925,6 +2937,7 @@ var CodeGenerator = class {
2925
2937
  local;
2926
2938
  i32;
2927
2939
  f32;
2940
+ f64;
2928
2941
  v128;
2929
2942
  i32x4;
2930
2943
  f32x4;
@@ -2944,6 +2957,7 @@ var CodeGenerator = class {
2944
2957
  this.local = new Local(this);
2945
2958
  this.i32 = new I32(this);
2946
2959
  this.f32 = new F32(this);
2960
+ this.f64 = new F64(this);
2947
2961
  this.v128 = new V128(this);
2948
2962
  this.i32x4 = new I32x4(this);
2949
2963
  this.f32x4 = new F32x4(this);
@@ -3330,6 +3344,8 @@ var I32 = class {
3330
3344
  ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
3331
3345
  trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
3332
3346
  trunc_f32_u = UNARY_OP("trunc_f32_u", 169, "f32", "i32");
3347
+ trunc_f64_s = UNARY_OP("trunc_f64_s", 170, "f64", "i32");
3348
+ trunc_f64_u = UNARY_OP("trunc_f64_u", 171, "f64", "i32");
3333
3349
  load = LOAD_OP("load", 40, "i32");
3334
3350
  load8_s = LOAD_OP("load8_s", 44, "i32");
3335
3351
  load8_u = LOAD_OP("load8_u", 45, "i32");
@@ -3341,6 +3357,8 @@ var I32 = class {
3341
3357
  reinterpret_f32 = UNARY_OP("reinterpret_f32", 188, "f32", "i32");
3342
3358
  trunc_sat_f32_s = UNARY_OP("trunc_sat_f32_s", [252, 0], "f32", "i32");
3343
3359
  trunc_sat_f32_u = UNARY_OP("trunc_sat_f32_u", [252, 1], "f32", "i32");
3360
+ trunc_sat_f64_s = UNARY_OP("trunc_sat_f64_s", [252, 2], "f64", "i32");
3361
+ trunc_sat_f64_u = UNARY_OP("trunc_sat_f64_u", [252, 3], "f64", "i32");
3344
3362
  };
3345
3363
  var F32 = class {
3346
3364
  constructor(cg) {
@@ -3360,6 +3378,8 @@ var F32 = class {
3360
3378
  for (let i = 0; i < 4; i++) this.cg._emit(bytes[i]);
3361
3379
  this.cg._push(this);
3362
3380
  }
3381
+ load = LOAD_OP("load", 42, "f32");
3382
+ store = STORE_OP("store", 56, "f32");
3363
3383
  eq = BINARY_OP("eq", 91, "f32", "f32", "i32");
3364
3384
  ne = BINARY_OP("ne", 92, "f32", "f32", "i32");
3365
3385
  lt = BINARY_OP("lt", 93, "f32", "f32", "i32");
@@ -3382,10 +3402,53 @@ var F32 = class {
3382
3402
  copysign = BINARY_OP("copysign", 152, "f32", "f32", "f32");
3383
3403
  convert_i32_s = UNARY_OP("convert_i32_s", 178, "i32", "f32");
3384
3404
  convert_i32_u = UNARY_OP("convert_i32_u", 179, "i32", "f32");
3385
- load = LOAD_OP("load", 42, "f32");
3386
- store = STORE_OP("store", 56, "f32");
3405
+ demote_f64 = UNARY_OP("demote_f64", 182, "f64", "f32");
3387
3406
  reinterpret_i32 = UNARY_OP("reinterpret_i32", 190, "i32", "f32");
3388
3407
  };
3408
+ var F64 = class {
3409
+ constructor(cg) {
3410
+ this.cg = cg;
3411
+ }
3412
+ get typeId() {
3413
+ return 124;
3414
+ }
3415
+ get name() {
3416
+ return "f64";
3417
+ }
3418
+ const(f) {
3419
+ this.cg._emit(68);
3420
+ const buffer = /* @__PURE__ */ new ArrayBuffer(8);
3421
+ new DataView(buffer).setFloat64(0, f, true);
3422
+ const bytes = new Uint8Array(buffer);
3423
+ for (let i = 0; i < 8; i++) this.cg._emit(bytes[i]);
3424
+ this.cg._push(this);
3425
+ }
3426
+ load = LOAD_OP("load", 43, "f64");
3427
+ store = STORE_OP("store", 57, "f64");
3428
+ eq = BINARY_OP("eq", 97, "f64", "f64", "i32");
3429
+ ne = BINARY_OP("ne", 98, "f64", "f64", "i32");
3430
+ lt = BINARY_OP("lt", 99, "f64", "f64", "i32");
3431
+ gt = BINARY_OP("gt", 100, "f64", "f64", "i32");
3432
+ le = BINARY_OP("le", 101, "f64", "f64", "i32");
3433
+ ge = BINARY_OP("ge", 102, "f64", "f64", "i32");
3434
+ abs = UNARY_OP("abs", 153, "f64", "f64");
3435
+ neg = UNARY_OP("neg", 154, "f64", "f64");
3436
+ ceil = UNARY_OP("ceil", 155, "f64", "f64");
3437
+ floor = UNARY_OP("floor", 156, "f64", "f64");
3438
+ trunc = UNARY_OP("trunc", 157, "f64", "f64");
3439
+ nearest = UNARY_OP("nearest", 158, "f64", "f64");
3440
+ sqrt = UNARY_OP("sqrt", 159, "f64", "f64");
3441
+ add = BINARY_OP("add", 160, "f64", "f64", "f64");
3442
+ sub = BINARY_OP("sub", 161, "f64", "f64", "f64");
3443
+ mul = BINARY_OP("mul", 162, "f64", "f64", "f64");
3444
+ div = BINARY_OP("div", 163, "f64", "f64", "f64");
3445
+ min = BINARY_OP("min", 164, "f64", "f64", "f64");
3446
+ max = BINARY_OP("max", 165, "f64", "f64", "f64");
3447
+ copysign = BINARY_OP("copysign", 166, "f64", "f64", "f64");
3448
+ convert_i32_s = UNARY_OP("convert_i32_s", 183, "i32", "f64");
3449
+ convert_i32_u = UNARY_OP("convert_i32_u", 184, "i32", "f64");
3450
+ promote_f32 = UNARY_OP("promote_f32", 187, "f32", "f64");
3451
+ };
3389
3452
  function VECTOR_OP(op, vopcode, inTypes, outType) {
3390
3453
  return function() {
3391
3454
  for (const inType of inTypes.toReversed()) {
@@ -3638,10 +3701,10 @@ function codegenWasm(kernel) {
3638
3701
  cg.local.get(acc);
3639
3702
  if (re.dtype === DType.Bool) cg.i32.and();
3640
3703
  else dty(cg, re.op, re.dtype).mul();
3641
- } else if (re.op === AluOp.Min || re.op === AluOp.Max) if (re.dtype === DType.Float32) {
3704
+ } else if (re.op === AluOp.Min || re.op === AluOp.Max) if (isFloatDtype(re.dtype)) {
3642
3705
  cg.local.get(acc);
3643
- if (re.op === AluOp.Min) cg.f32.min();
3644
- else cg.f32.max();
3706
+ if (re.op === AluOp.Min) dtyF(cg, re.op, re.dtype).min();
3707
+ else dtyF(cg, re.op, re.dtype).max();
3645
3708
  } else if ([
3646
3709
  DType.Int32,
3647
3710
  DType.Uint32,
@@ -3703,27 +3766,30 @@ function translateExp(cg, funcs, exp, ctx) {
3703
3766
  else if (op === AluOp.Sub) dty(cg, op, dtype).sub();
3704
3767
  else if (op === AluOp.Mul) if (dtype === DType.Bool) cg.i32.and();
3705
3768
  else dty(cg, op, dtype).mul();
3706
- else if (op === AluOp.Idiv) if (dtype === DType.Float32) cg.f32.div(), cg.f32.trunc();
3707
- else if (dtype === DType.Uint32) cg.i32.div_u();
3769
+ else if (op === AluOp.Idiv) if (isFloatDtype(dtype)) {
3770
+ dtyF(cg, op, dtype).div();
3771
+ dtyF(cg, op, dtype).trunc();
3772
+ } else if (dtype === DType.Uint32) cg.i32.div_u();
3708
3773
  else if (dtype === DType.Int32) cg.i32.div_s();
3709
3774
  else throw new UnsupportedOpError(op, dtype, "wasm");
3710
- else if (op === AluOp.Mod) if (dtype === DType.Float32) {
3711
- const a = cg.local.declare(cg.f32);
3712
- const b = cg.local.declare(cg.f32);
3775
+ else if (op === AluOp.Mod) if (isFloatDtype(dtype)) {
3776
+ const dt = dtyF(cg, op, dtype);
3777
+ const a = cg.local.declare(dt);
3778
+ const b = cg.local.declare(dt);
3713
3779
  cg.local.set(b);
3714
3780
  cg.local.tee(a);
3715
3781
  cg.local.get(a);
3716
3782
  cg.local.get(b);
3717
- cg.f32.div();
3718
- cg.f32.trunc();
3783
+ dt.div();
3784
+ dt.trunc();
3719
3785
  cg.local.get(b);
3720
- cg.f32.mul();
3721
- cg.f32.sub();
3786
+ dt.mul();
3787
+ dt.sub();
3722
3788
  } else if (dtype === DType.Uint32) cg.i32.rem_u();
3723
3789
  else if (dtype === DType.Int32) cg.i32.rem_s();
3724
3790
  else throw new UnsupportedOpError(op, dtype, "wasm");
3725
- else if (op === AluOp.Min || op === AluOp.Max) if (dtype === DType.Float32) if (op === AluOp.Min) cg.f32.min();
3726
- else cg.f32.max();
3791
+ else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
3792
+ else dtyF(cg, op, dtype).max();
3727
3793
  else if (dtype === DType.Int32 || dtype === DType.Uint32) {
3728
3794
  const a = cg.local.declare(cg.i32);
3729
3795
  const b = cg.local.declare(cg.i32);
@@ -3740,54 +3806,74 @@ function translateExp(cg, funcs, exp, ctx) {
3740
3806
  } else throw new UnsupportedOpError(op, dtype, "wasm");
3741
3807
  else if (op === AluOp.Cmplt) {
3742
3808
  const srcDtype = src[0].dtype;
3743
- if (srcDtype === DType.Float32) cg.f32.lt();
3809
+ if (isFloatDtype(srcDtype)) dtyF(cg, op, srcDtype).lt();
3744
3810
  else if (srcDtype === DType.Int32) cg.i32.lt_s();
3745
3811
  else if (srcDtype === DType.Uint32) cg.i32.lt_u();
3746
3812
  else throw new UnsupportedOpError(op, dtype, "wasm");
3747
3813
  } else if (op === AluOp.Cmpne) dty(cg, op, src[0].dtype).ne();
3748
3814
  else throw new UnsupportedOpError(op, dtype, "wasm");
3749
- } else if (AluGroup.Unary.has(op)) if (op === AluOp.Sin) gen(src[0]), cg.call(funcs.sin);
3750
- else if (op === AluOp.Cos) gen(src[0]), cg.call(funcs.cos);
3751
- else if (op === AluOp.Asin) gen(src[0]), cg.call(funcs.asin);
3752
- else if (op === AluOp.Atan) gen(src[0]), cg.call(funcs.atan);
3753
- else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
3754
- else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
3755
- else if (op === AluOp.Erf) gen(src[0]), cg.call(funcs.erf);
3756
- else if (op === AluOp.Erfc) gen(src[0]), cg.call(funcs.erfc);
3757
- else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
3758
- else if (op === AluOp.Reciprocal) cg.f32.const(1), gen(src[0]), cg.f32.div();
3759
- else if (op === AluOp.Cast) {
3760
- gen(src[0]);
3761
- const dtype0 = src[0].dtype;
3762
- const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
3763
- if (dtype === DType.Int32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_s();
3764
- else if (i32repr);
3765
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3766
- else if (dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_u();
3767
- else if (i32repr);
3768
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3769
- else if (dtype === DType.Float32) if (dtype0 === DType.Float32);
3770
- else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32.convert_i32_s();
3771
- else if (dtype0 === DType.Uint32) cg.f32.convert_i32_u();
3772
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3773
- else if (dtype === DType.Bool) if (dtype0 === DType.Bool);
3774
- else if (i32repr) cg.i32.const(0), cg.i32.ne();
3775
- else if (dtype0 === DType.Float32) cg.f32.const(0), cg.f32.ne();
3776
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3777
- else throw new UnsupportedOpError(op, dtype, "wasm");
3778
- } else if (op === AluOp.Bitcast) {
3779
- gen(src[0]);
3780
- const dtype0 = src[0].dtype;
3781
- const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32;
3782
- if (dtype === DType.Int32 || dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.reinterpret_f32();
3783
- else if (i32repr);
3784
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3785
- else if (dtype === DType.Float32) if (i32repr) cg.f32.reinterpret_i32();
3786
- else if (dtype0 === DType.Float32);
3787
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3788
- else throw new UnsupportedOpError(op, dtype, "wasm");
3789
- } else throw new UnsupportedOpError(op, dtype, "wasm");
3790
- else if (op === AluOp.Where) {
3815
+ } else if (AluGroup.Unary.has(op)) {
3816
+ const callFuncF32 = (func) => {
3817
+ if (dtype !== DType.Float32) if (dtype === DType.Float64) cg.f32.demote_f64();
3818
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3819
+ cg.call(func);
3820
+ if (dtype === DType.Float64) cg.f64.promote_f32();
3821
+ };
3822
+ if (op === AluOp.Sin) gen(src[0]), callFuncF32(funcs.sin);
3823
+ else if (op === AluOp.Cos) gen(src[0]), callFuncF32(funcs.cos);
3824
+ else if (op === AluOp.Asin) gen(src[0]), callFuncF32(funcs.asin);
3825
+ else if (op === AluOp.Atan) gen(src[0]), callFuncF32(funcs.atan);
3826
+ else if (op === AluOp.Exp) gen(src[0]), callFuncF32(funcs.exp);
3827
+ else if (op === AluOp.Log) gen(src[0]), callFuncF32(funcs.log);
3828
+ else if (op === AluOp.Erf) gen(src[0]), callFuncF32(funcs.erf);
3829
+ else if (op === AluOp.Erfc) gen(src[0]), callFuncF32(funcs.erfc);
3830
+ else if (op === AluOp.Sqrt) gen(src[0]), dtyF(cg, op, dtype).sqrt();
3831
+ else if (op === AluOp.Reciprocal) {
3832
+ const dt = dtyF(cg, op, dtype);
3833
+ dt.const(1), gen(src[0]), dt.div();
3834
+ } else if (op === AluOp.Cast) {
3835
+ gen(src[0]);
3836
+ const dtype0 = src[0].dtype;
3837
+ const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
3838
+ if (dtype === DType.Int32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_s();
3839
+ else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_s();
3840
+ else if (i32repr);
3841
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3842
+ else if (dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_u();
3843
+ else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_u();
3844
+ else if (i32repr);
3845
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3846
+ else if (dtype === DType.Float32) if (dtype0 === DType.Float32);
3847
+ else if (dtype0 === DType.Float64) cg.f32.demote_f64();
3848
+ else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32.convert_i32_s();
3849
+ else if (dtype0 === DType.Uint32) cg.f32.convert_i32_u();
3850
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3851
+ else if (dtype === DType.Float64) if (dtype0 === DType.Float32) cg.f64.promote_f32();
3852
+ else if (dtype0 === DType.Float64);
3853
+ else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f64.convert_i32_s();
3854
+ else if (dtype0 === DType.Uint32) cg.f64.convert_i32_u();
3855
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3856
+ else if (dtype === DType.Bool) if (dtype0 === DType.Bool);
3857
+ else if (i32repr) cg.i32.const(0), cg.i32.ne();
3858
+ else if (dtype0 === DType.Float32) cg.f32.const(0), cg.f32.ne();
3859
+ else if (dtype0 === DType.Float64) cg.f64.const(0), cg.f64.ne();
3860
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3861
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3862
+ } else if (op === AluOp.Bitcast) {
3863
+ gen(src[0]);
3864
+ const dtype0 = src[0].dtype;
3865
+ if (dtype !== dtype0) {
3866
+ const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32;
3867
+ if (dtype === DType.Int32 || dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.reinterpret_f32();
3868
+ else if (i32repr);
3869
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3870
+ else if (dtype === DType.Float32) if (i32repr) cg.f32.reinterpret_i32();
3871
+ else if (dtype0 === DType.Float32);
3872
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3873
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3874
+ }
3875
+ } else throw new UnsupportedOpError(op, dtype, "wasm");
3876
+ } else if (op === AluOp.Where) {
3791
3877
  gen(src[1]);
3792
3878
  gen(src[2]);
3793
3879
  gen(src[0]);
@@ -3832,12 +3918,20 @@ function translateExp(cg, funcs, exp, ctx) {
3832
3918
  function dty(cg, op, dtype) {
3833
3919
  switch (dtype) {
3834
3920
  case DType.Float32: return cg.f32;
3921
+ case DType.Float64: return cg.f64;
3835
3922
  case DType.Int32:
3836
3923
  case DType.Uint32:
3837
3924
  case DType.Bool: return cg.i32;
3838
3925
  default: throw new UnsupportedOpError(op, dtype, "wasm");
3839
3926
  }
3840
3927
  }
3928
+ function dtyF(cg, op, dtype) {
3929
+ switch (dtype) {
3930
+ case DType.Float32: return cg.f32;
3931
+ case DType.Float64: return cg.f64;
3932
+ default: throw new UnsupportedOpError(op, dtype, "wasm");
3933
+ }
3934
+ }
3841
3935
 
3842
3936
  //#endregion
3843
3937
  //#region src/backend.ts
@@ -3846,10 +3940,10 @@ const devices = [
3846
3940
  "wasm",
3847
3941
  "webgpu"
3848
3942
  ];
3849
- let defaultBackend = "wasm";
3850
3943
  const initializedBackends = /* @__PURE__ */ new Map();
3851
3944
  initializedBackends.set("cpu", new CpuBackend());
3852
- initializedBackends.set("wasm", new WasmBackend());
3945
+ if (typeof WebAssembly !== "undefined") initializedBackends.set("wasm", new WasmBackend());
3946
+ let defaultBackend = initializedBackends.has("wasm") ? "wasm" : "cpu";
3853
3947
  /** Configure the default device for arrays. */
3854
3948
  function defaultDevice(device) {
3855
3949
  if (device !== void 0) if (initializedBackends.has(device)) defaultBackend = device;
@@ -3876,12 +3970,14 @@ async function init(...devicesToInit) {
3876
3970
  /** Create a backend, if available. Internal function called by `init()`. */
3877
3971
  async function createBackend(device) {
3878
3972
  if (device === "cpu") return new CpuBackend();
3879
- else if (device === "wasm") return new WasmBackend();
3880
- else if (device === "webgpu") {
3973
+ else if (device === "wasm") {
3974
+ if (typeof WebAssembly === "undefined") return null;
3975
+ return new WasmBackend();
3976
+ } else if (device === "webgpu") {
3881
3977
  if (!navigator.gpu) return null;
3882
3978
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
3883
3979
  if (!adapter) return null;
3884
- const { WebGPUBackend } = await import("./webgpu-LGi2A3mS.js");
3980
+ const { WebGPUBackend } = await import("./webgpu-B3UVme6n.js");
3885
3981
  const importantLimits = [
3886
3982
  "maxBufferSize",
3887
3983
  "maxComputeInvocationsPerWorkgroup",
@@ -3935,4 +4031,4 @@ var UnsupportedOpError = class extends Error {
3935
4031
 
3936
4032
  //#endregion
3937
4033
  export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
3938
- //# sourceMappingURL=backend-DwIAd0AG.js.map
4034
+ //# sourceMappingURL=backend-CoVtc9dx.js.map
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__ge
30
30
  }) : target, mod));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-FtkbO6pI.cjs');
33
+ const require_backend = require('./backend-BbrKEB18.cjs');
34
34
 
35
35
  //#region src/tree.ts
36
36
  var tree_exports = {};
@@ -401,11 +401,9 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
401
401
  return Primitive$1;
402
402
  }({});
403
403
  let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
404
- CompareOp$1["Greater"] = "greater";
405
404
  CompareOp$1["Less"] = "less";
406
405
  CompareOp$1["Equal"] = "equal";
407
406
  CompareOp$1["NotEqual"] = "not_equal";
408
- CompareOp$1["GreaterEqual"] = "greater_equal";
409
407
  CompareOp$1["LessEqual"] = "less_equal";
410
408
  return CompareOp$1;
411
409
  }({});
@@ -501,7 +499,7 @@ function compare(x, y, op) {
501
499
  return bind1(Primitive.Compare, [x, y], { op });
502
500
  }
503
501
  function greater$1(x, y) {
504
- return compare(x, y, CompareOp.Greater);
502
+ return compare(y, x, CompareOp.Less);
505
503
  }
506
504
  function less$1(x, y) {
507
505
  return compare(x, y, CompareOp.Less);
@@ -513,7 +511,7 @@ function notEqual$1(x, y) {
513
511
  return compare(x, y, CompareOp.NotEqual);
514
512
  }
515
513
  function greaterEqual$1(x, y) {
516
- return compare(x, y, CompareOp.GreaterEqual);
514
+ return compare(y, x, CompareOp.LessEqual);
517
515
  }
518
516
  function lessEqual$1(x, y) {
519
517
  return compare(x, y, CompareOp.LessEqual);
@@ -2240,6 +2238,9 @@ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
2240
2238
  } else if (data instanceof Float16Array) {
2241
2239
  if (dtype && dtype !== require_backend.DType.Float16) throw new Error("Float16Array must have float16 type");
2242
2240
  dtype ??= require_backend.DType.Float16;
2241
+ } else if (data instanceof Float64Array) {
2242
+ if (dtype && dtype !== require_backend.DType.Float64) throw new Error("Float64Array must have float64 type");
2243
+ dtype ??= require_backend.DType.Float64;
2243
2244
  } else throw new Error("Unsupported data array type: " + data.constructor.name);
2244
2245
  if (data.length < inlineArrayLimit) {
2245
2246
  let allEqual = true;
@@ -2451,11 +2452,9 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
2451
2452
  }
2452
2453
  function aluCompare(a, b, op) {
2453
2454
  switch (op) {
2454
- case CompareOp.Greater: return require_backend.AluExp.mul(require_backend.AluExp.cmpne(a, b), require_backend.AluExp.cmplt(a, b).not());
2455
2455
  case CompareOp.Less: return require_backend.AluExp.cmplt(a, b);
2456
2456
  case CompareOp.Equal: return require_backend.AluExp.cmpne(a, b).not();
2457
2457
  case CompareOp.NotEqual: return require_backend.AluExp.cmpne(a, b);
2458
- case CompareOp.GreaterEqual: return require_backend.AluExp.cmplt(a, b).not();
2459
2458
  case CompareOp.LessEqual: return require_backend.AluExp.add(require_backend.AluExp.cmplt(a, b), require_backend.AluExp.cmpne(a, b).not());
2460
2459
  }
2461
2460
  }
@@ -2592,7 +2591,7 @@ var JaxprEqn = class {
2592
2591
  const paramsList = Object.entries(this.params).map(([k, v]) => require_backend.PPrint.pp(`${k}=${v}`));
2593
2592
  if (paramsList.length > 0) rhs = rhs.stack(require_backend.PPrint.pp(" [ ")).stack(require_backend.PPrint.prototype.concat(...paramsList)).stack(require_backend.PPrint.pp(" ] "));
2594
2593
  else rhs = rhs.stack(require_backend.PPrint.pp(" "));
2595
- rhs = rhs.stack(require_backend.PPrint.pp(this.inputs.map((x) => x instanceof Var ? vp.name(x) : JSON.stringify(x.value)).join(" ")));
2594
+ rhs = rhs.stack(require_backend.PPrint.pp(this.inputs.map((x) => x instanceof Var ? vp.name(x) : String(x.value)).join(" ")));
2596
2595
  return lhs.stack(require_backend.PPrint.pp(" = ")).stack(rhs);
2597
2596
  }
2598
2597
  toString() {
@@ -4379,6 +4378,7 @@ __export(numpy_exports, {
4379
4378
  flipud: () => flipud,
4380
4379
  float16: () => float16,
4381
4380
  float32: () => float32,
4381
+ float64: () => float64,
4382
4382
  full: () => full,
4383
4383
  fullLike: () => fullLike$1,
4384
4384
  greater: () => greater,
@@ -4392,6 +4392,11 @@ __export(numpy_exports, {
4392
4392
  inf: () => inf,
4393
4393
  inner: () => inner,
4394
4394
  int32: () => int32,
4395
+ isfinite: () => isfinite,
4396
+ isinf: () => isinf,
4397
+ isnan: () => isnan,
4398
+ isneginf: () => isneginf,
4399
+ isposinf: () => isposinf,
4395
4400
  less: () => less,
4396
4401
  lessEqual: () => lessEqual,
4397
4402
  linspace: () => linspace,
@@ -4462,6 +4467,7 @@ const int32 = require_backend.DType.Int32;
4462
4467
  const uint32 = require_backend.DType.Uint32;
4463
4468
  const bool = require_backend.DType.Bool;
4464
4469
  const float16 = require_backend.DType.Float16;
4470
+ const float64 = require_backend.DType.Float64;
4465
4471
  /** Euler's constant, `e = 2.7182818284590...` */
4466
4472
  const e = Math.E;
4467
4473
  /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
@@ -5262,6 +5268,34 @@ function var_(x, axis = null, opts) {
5262
5268
  function std(x, axis = null, opts) {
5263
5269
  return sqrt(var_(x, axis, opts));
5264
5270
  }
5271
+ /** Test element-wise for positive or negative infinity, return bool array. */
5272
+ function isinf(x) {
5273
+ x = fudgeArray(x);
5274
+ return require_backend.isFloatDtype(x.dtype) ? x.ref.equal(Infinity).add(x.equal(-Infinity)) : fullLike$1(x, false);
5275
+ }
5276
+ /** Test element-wise for NaN (Not a Number). */
5277
+ function isnan(x) {
5278
+ x = fudgeArray(x);
5279
+ return require_backend.isFloatDtype(x.dtype) ? x.ref.notEqual(x) : fullLike$1(x, false);
5280
+ }
5281
+ /** Test element-wise for negative infinity, return bool array. */
5282
+ function isneginf(x) {
5283
+ x = fudgeArray(x);
5284
+ return require_backend.isFloatDtype(x.dtype) ? x.equal(-Infinity) : fullLike$1(x, false);
5285
+ }
5286
+ /** Test element-wise for positive infinity, return bool array. */
5287
+ function isposinf(x) {
5288
+ x = fudgeArray(x);
5289
+ return require_backend.isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
5290
+ }
5291
+ /**
5292
+ * @function
5293
+ * Test element-wise for finite values (not infinity or NaN).
5294
+ */
5295
+ const isfinite = jit$1(function isfinite$1(x) {
5296
+ if (!require_backend.isFloatDtype(x.dtype)) return fullLike$1(x, true);
5297
+ return isnan(x.ref).add(isinf(x)).notEqual(true);
5298
+ });
5265
5299
 
5266
5300
  //#endregion
5267
5301
  //#region src/nn.ts
package/dist/index.d.cts CHANGED
@@ -168,9 +168,10 @@ declare enum DType {
168
168
  Uint32 = "uint32",
169
169
  Bool = "bool",
170
170
  Float16 = "float16",
171
+ Float64 = "float64",
171
172
  }
172
173
  /** @inline */
173
- type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
174
+ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer> | Float64Array<ArrayBuffer>;
174
175
  /**
175
176
  * Promote two dtypes to their join according to the type lattice.
176
177
  *
@@ -180,7 +181,7 @@ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Arr
180
181
  *
181
182
  * **Type lattice:**
182
183
  * ```text
183
- * bool -> uint32 -> int32 -> float16 -> float32
184
+ * bool -> uint32 -> int32 -> float16 -> float32 -> float64
184
185
  * weakType --^
185
186
  * ```
186
187
  *
@@ -243,6 +244,7 @@ declare class AluExp implements FpHashable {
243
244
  static u32(value: number): AluExp;
244
245
  static bool(value: boolean): AluExp;
245
246
  static f16(value: number): AluExp;
247
+ static f64(value: number): AluExp;
246
248
  not(): AluExp;
247
249
  /** Compute a reasonable expression hash with low collision rate. */
248
250
  getHash(): bigint;
@@ -630,11 +632,9 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
630
632
  /** Type of parameters taken by each primitive. */
631
633
  type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
632
634
  declare enum CompareOp {
633
- Greater = "greater",
634
635
  Less = "less",
635
636
  Equal = "equal",
636
637
  NotEqual = "not_equal",
637
- GreaterEqual = "greater_equal",
638
638
  LessEqual = "less_equal",
639
639
  }
640
640
  /** @inline */
@@ -985,7 +985,7 @@ declare class Array extends Tracer {
985
985
  _putSync(backend: Backend): Array;
986
986
  }
987
987
  /** Constructor for creating a new array from data. */
988
- declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
988
+ declare function array(values: Array | DataArray | RecursiveArray<number> | RecursiveArray<boolean>, {
989
989
  shape,
990
990
  dtype,
991
991
  device
@@ -1058,13 +1058,14 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
1058
1058
  device
1059
1059
  }?: DTypeAndDevice): Array;
1060
1060
  declare namespace numpy_d_exports {
1061
- export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1061
+ export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, float64, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1062
1062
  }
1063
1063
  declare const float32 = DType.Float32;
1064
1064
  declare const int32 = DType.Int32;
1065
1065
  declare const uint32 = DType.Uint32;
1066
1066
  declare const bool = DType.Bool;
1067
1067
  declare const float16 = DType.Float16;
1068
+ declare const float64 = DType.Float64;
1068
1069
  /** Euler's constant, `e = 2.7182818284590...` */
1069
1070
  declare const e: number;
1070
1071
  /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
@@ -1535,6 +1536,19 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
1535
1536
  mean?: ArrayLike;
1536
1537
  correction?: number;
1537
1538
  } & ReduceOpts): Array;
1539
+ /** Test element-wise for positive or negative infinity, return bool array. */
1540
+ declare function isinf(x: ArrayLike): Array;
1541
+ /** Test element-wise for NaN (Not a Number). */
1542
+ declare function isnan(x: ArrayLike): Array;
1543
+ /** Test element-wise for negative infinity, return bool array. */
1544
+ declare function isneginf(x: ArrayLike): Array;
1545
+ /** Test element-wise for positive infinity, return bool array. */
1546
+ declare function isposinf(x: ArrayLike): Array;
1547
+ /**
1548
+ * @function
1549
+ * Test element-wise for finite values (not infinity or NaN).
1550
+ */
1551
+ declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
1538
1552
  //# sourceMappingURL=numpy.d.ts.map
1539
1553
  //#endregion
1540
1554
  //#region src/frontend/jaxpr.d.ts