@jax-js/jax 0.1.1 → 0.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -289,10 +289,11 @@ var FpHash = class FpHash {
289
289
  };
290
290
  /** Run a function while caching it inline inside a `Map`. */
291
291
  function runWithCache(cache, key, thunk) {
292
- if (cache.has(key)) return cache.get(key);
292
+ const keyStr = JSON.stringify(key);
293
+ if (cache.has(keyStr)) return cache.get(keyStr);
293
294
  else {
294
295
  const value = thunk();
295
- cache.set(key, value);
296
+ cache.set(keyStr, value);
296
297
  return value;
297
298
  }
298
299
  }
@@ -449,6 +450,14 @@ var AluExp = class AluExp {
449
450
  static sqrt(a) {
450
451
  return new AluExp(AluOp.Sqrt, a.dtype, [a]);
451
452
  }
453
+ static floor(a) {
454
+ if (!isFloatDtype(a.dtype)) return a;
455
+ return new AluExp(AluOp.Floor, a.dtype, [a]);
456
+ }
457
+ static ceil(a) {
458
+ if (!isFloatDtype(a.dtype)) return a;
459
+ return new AluExp(AluOp.Ceil, a.dtype, [a]);
460
+ }
452
461
  static reciprocal(a) {
453
462
  return new AluExp(AluOp.Reciprocal, a.dtype, [a]);
454
463
  }
@@ -629,6 +638,12 @@ var AluExp = class AluExp {
629
638
  case AluOp.Sqrt:
630
639
  ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
631
640
  break;
641
+ case AluOp.Floor:
642
+ ret = [Math.floor(src[0].min), Math.floor(src[0].max)];
643
+ break;
644
+ case AluOp.Ceil:
645
+ ret = [Math.ceil(src[0].min), Math.ceil(src[0].max)];
646
+ break;
632
647
  case AluOp.Reciprocal:
633
648
  if (src[0].min <= 0 && src[0].max >= 0) return [-Infinity, Infinity];
634
649
  ret = [1 / src[0].max, 1 / src[0].min];
@@ -861,7 +876,7 @@ var AluExp = class AluExp {
861
876
  else return p(p(src[0].src[0], src[1]), src[0].src[1]).simplify(cache);
862
877
  if (src[1].op === op && src[1].src[1].op === AluOp.Const) return p(p(src[0], src[1].src[0]), src[1].src[1]).simplify(cache);
863
878
  }
864
- if (op === AluOp.Mod || op === AluOp.Idiv && src[1].#isConstInt()) {
879
+ if ((op === AluOp.Mod || op === AluOp.Idiv) && src[1].#isConstInt()) {
865
880
  const [x, y] = src;
866
881
  {
867
882
  const factors = [];
@@ -951,6 +966,8 @@ var AluExp = class AluExp {
951
966
  case AluOp.Erf: return erf(x);
952
967
  case AluOp.Erfc: return erfc(x);
953
968
  case AluOp.Sqrt: return Math.sqrt(x);
969
+ case AluOp.Floor: return Math.floor(x);
970
+ case AluOp.Ceil: return Math.ceil(x);
954
971
  case AluOp.Reciprocal: return 1 / x;
955
972
  case AluOp.Cast: {
956
973
  const wasFloat = isFloatDtype(this.src[0].dtype);
@@ -1140,6 +1157,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
1140
1157
  AluOp$1["Erf"] = "Erf";
1141
1158
  AluOp$1["Erfc"] = "Erfc";
1142
1159
  AluOp$1["Sqrt"] = "Sqrt";
1160
+ AluOp$1["Floor"] = "Floor";
1161
+ AluOp$1["Ceil"] = "Ceil";
1143
1162
  AluOp$1["Reciprocal"] = "Reciprocal";
1144
1163
  AluOp$1["Cast"] = "Cast";
1145
1164
  AluOp$1["Bitcast"] = "Bitcast";
@@ -1174,6 +1193,8 @@ const AluGroup = {
1174
1193
  AluOp.Erf,
1175
1194
  AluOp.Erfc,
1176
1195
  AluOp.Sqrt,
1196
+ AluOp.Floor,
1197
+ AluOp.Ceil,
1177
1198
  AluOp.Reciprocal,
1178
1199
  AluOp.Cast,
1179
1200
  AluOp.Bitcast
@@ -1201,7 +1222,9 @@ const AluGroup = {
1201
1222
  AluOp.Erf,
1202
1223
  AluOp.Erfc,
1203
1224
  AluOp.Sqrt,
1204
- AluOp.Reciprocal
1225
+ AluOp.Reciprocal,
1226
+ AluOp.Floor,
1227
+ AluOp.Ceil
1205
1228
  ])
1206
1229
  };
1207
1230
  /** Common variables that can be substituted in expressions. */
@@ -3831,7 +3854,9 @@ function translateExp(cg, funcs, exp, ctx) {
3831
3854
  else if (op === AluOp.Reciprocal) {
3832
3855
  const dt = dtyF(cg, op, dtype);
3833
3856
  dt.const(1), gen(src[0]), dt.div();
3834
- } else if (op === AluOp.Cast) {
3857
+ } else if (op === AluOp.Floor) gen(src[0]), dtyF(cg, op, dtype).floor();
3858
+ else if (op === AluOp.Ceil) gen(src[0]), dtyF(cg, op, dtype).ceil();
3859
+ else if (op === AluOp.Cast) {
3835
3860
  gen(src[0]);
3836
3861
  const dtype0 = src[0].dtype;
3837
3862
  const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
@@ -3977,7 +4002,7 @@ async function createBackend(device) {
3977
4002
  if (!navigator.gpu) return null;
3978
4003
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
3979
4004
  if (!adapter) return null;
3980
- const { WebGPUBackend } = await import("./webgpu-B3UVme6n.js");
4005
+ const { WebGPUBackend } = await import("./webgpu-BGuG58KZ.js");
3981
4006
  const importantLimits = [
3982
4007
  "maxBufferSize",
3983
4008
  "maxComputeInvocationsPerWorkgroup",
@@ -4031,4 +4056,4 @@ var UnsupportedOpError = class extends Error {
4031
4056
 
4032
4057
  //#endregion
4033
4058
  export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
4034
- //# sourceMappingURL=backend-CoVtc9dx.js.map
4059
+ //# sourceMappingURL=backend-BqymqzuU.js.map
@@ -290,10 +290,11 @@ var FpHash = class FpHash {
290
290
  };
291
291
  /** Run a function while caching it inline inside a `Map`. */
292
292
  function runWithCache(cache, key, thunk) {
293
- if (cache.has(key)) return cache.get(key);
293
+ const keyStr = JSON.stringify(key);
294
+ if (cache.has(keyStr)) return cache.get(keyStr);
294
295
  else {
295
296
  const value = thunk();
296
- cache.set(key, value);
297
+ cache.set(keyStr, value);
297
298
  return value;
298
299
  }
299
300
  }
@@ -450,6 +451,14 @@ var AluExp = class AluExp {
450
451
  static sqrt(a) {
451
452
  return new AluExp(AluOp.Sqrt, a.dtype, [a]);
452
453
  }
454
+ static floor(a) {
455
+ if (!isFloatDtype(a.dtype)) return a;
456
+ return new AluExp(AluOp.Floor, a.dtype, [a]);
457
+ }
458
+ static ceil(a) {
459
+ if (!isFloatDtype(a.dtype)) return a;
460
+ return new AluExp(AluOp.Ceil, a.dtype, [a]);
461
+ }
453
462
  static reciprocal(a) {
454
463
  return new AluExp(AluOp.Reciprocal, a.dtype, [a]);
455
464
  }
@@ -630,6 +639,12 @@ var AluExp = class AluExp {
630
639
  case AluOp.Sqrt:
631
640
  ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
632
641
  break;
642
+ case AluOp.Floor:
643
+ ret = [Math.floor(src[0].min), Math.floor(src[0].max)];
644
+ break;
645
+ case AluOp.Ceil:
646
+ ret = [Math.ceil(src[0].min), Math.ceil(src[0].max)];
647
+ break;
633
648
  case AluOp.Reciprocal:
634
649
  if (src[0].min <= 0 && src[0].max >= 0) return [-Infinity, Infinity];
635
650
  ret = [1 / src[0].max, 1 / src[0].min];
@@ -862,7 +877,7 @@ var AluExp = class AluExp {
862
877
  else return p(p(src[0].src[0], src[1]), src[0].src[1]).simplify(cache);
863
878
  if (src[1].op === op && src[1].src[1].op === AluOp.Const) return p(p(src[0], src[1].src[0]), src[1].src[1]).simplify(cache);
864
879
  }
865
- if (op === AluOp.Mod || op === AluOp.Idiv && src[1].#isConstInt()) {
880
+ if ((op === AluOp.Mod || op === AluOp.Idiv) && src[1].#isConstInt()) {
866
881
  const [x, y] = src;
867
882
  {
868
883
  const factors = [];
@@ -952,6 +967,8 @@ var AluExp = class AluExp {
952
967
  case AluOp.Erf: return erf(x);
953
968
  case AluOp.Erfc: return erfc(x);
954
969
  case AluOp.Sqrt: return Math.sqrt(x);
970
+ case AluOp.Floor: return Math.floor(x);
971
+ case AluOp.Ceil: return Math.ceil(x);
955
972
  case AluOp.Reciprocal: return 1 / x;
956
973
  case AluOp.Cast: {
957
974
  const wasFloat = isFloatDtype(this.src[0].dtype);
@@ -1141,6 +1158,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
1141
1158
  AluOp$1["Erf"] = "Erf";
1142
1159
  AluOp$1["Erfc"] = "Erfc";
1143
1160
  AluOp$1["Sqrt"] = "Sqrt";
1161
+ AluOp$1["Floor"] = "Floor";
1162
+ AluOp$1["Ceil"] = "Ceil";
1144
1163
  AluOp$1["Reciprocal"] = "Reciprocal";
1145
1164
  AluOp$1["Cast"] = "Cast";
1146
1165
  AluOp$1["Bitcast"] = "Bitcast";
@@ -1175,6 +1194,8 @@ const AluGroup = {
1175
1194
  AluOp.Erf,
1176
1195
  AluOp.Erfc,
1177
1196
  AluOp.Sqrt,
1197
+ AluOp.Floor,
1198
+ AluOp.Ceil,
1178
1199
  AluOp.Reciprocal,
1179
1200
  AluOp.Cast,
1180
1201
  AluOp.Bitcast
@@ -1202,7 +1223,9 @@ const AluGroup = {
1202
1223
  AluOp.Erf,
1203
1224
  AluOp.Erfc,
1204
1225
  AluOp.Sqrt,
1205
- AluOp.Reciprocal
1226
+ AluOp.Reciprocal,
1227
+ AluOp.Floor,
1228
+ AluOp.Ceil
1206
1229
  ])
1207
1230
  };
1208
1231
  /** Common variables that can be substituted in expressions. */
@@ -3832,7 +3855,9 @@ function translateExp(cg, funcs, exp, ctx) {
3832
3855
  else if (op === AluOp.Reciprocal) {
3833
3856
  const dt = dtyF(cg, op, dtype);
3834
3857
  dt.const(1), gen(src[0]), dt.div();
3835
- } else if (op === AluOp.Cast) {
3858
+ } else if (op === AluOp.Floor) gen(src[0]), dtyF(cg, op, dtype).floor();
3859
+ else if (op === AluOp.Ceil) gen(src[0]), dtyF(cg, op, dtype).ceil();
3860
+ else if (op === AluOp.Cast) {
3836
3861
  gen(src[0]);
3837
3862
  const dtype0 = src[0].dtype;
3838
3863
  const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
@@ -3978,7 +4003,7 @@ async function createBackend(device) {
3978
4003
  if (!navigator.gpu) return null;
3979
4004
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
3980
4005
  if (!adapter) return null;
3981
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-DGYNVHma.cjs"));
4006
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-CcGP160M.cjs"));
3982
4007
  const importantLimits = [
3983
4008
  "maxBufferSize",
3984
4009
  "maxComputeInvocationsPerWorkgroup",
@@ -4325,4 +4350,4 @@ Object.defineProperty(exports, 'zipn', {
4325
4350
  return zipn;
4326
4351
  }
4327
4352
  });
4328
- //# sourceMappingURL=backend-BbrKEB18.cjs.map
4353
+ //# sourceMappingURL=backend-DeVfWEFS.cjs.map