@jax-js/jax 0.1.1 → 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/{backend-CoVtc9dx.js → backend-BqymqzuU.js} +32 -7
- package/dist/{backend-BbrKEB18.cjs → backend-DeVfWEFS.cjs} +32 -7
- package/dist/index.cjs +2679 -2194
- package/dist/index.d.cts +1087 -981
- package/dist/index.d.ts +1087 -981
- package/dist/index.js +2663 -2178
- package/dist/{webgpu-B3UVme6n.js → webgpu-BGuG58KZ.js} +13 -11
- package/dist/{webgpu-DGYNVHma.cjs → webgpu-CcGP160M.cjs} +13 -11
- package/package.json +13 -21
|
@@ -289,10 +289,11 @@ var FpHash = class FpHash {
|
|
|
289
289
|
};
|
|
290
290
|
/** Run a function while caching it inline inside a `Map`. */
|
|
291
291
|
function runWithCache(cache, key, thunk) {
|
|
292
|
-
|
|
292
|
+
const keyStr = JSON.stringify(key);
|
|
293
|
+
if (cache.has(keyStr)) return cache.get(keyStr);
|
|
293
294
|
else {
|
|
294
295
|
const value = thunk();
|
|
295
|
-
cache.set(
|
|
296
|
+
cache.set(keyStr, value);
|
|
296
297
|
return value;
|
|
297
298
|
}
|
|
298
299
|
}
|
|
@@ -449,6 +450,14 @@ var AluExp = class AluExp {
|
|
|
449
450
|
static sqrt(a) {
|
|
450
451
|
return new AluExp(AluOp.Sqrt, a.dtype, [a]);
|
|
451
452
|
}
|
|
453
|
+
static floor(a) {
|
|
454
|
+
if (!isFloatDtype(a.dtype)) return a;
|
|
455
|
+
return new AluExp(AluOp.Floor, a.dtype, [a]);
|
|
456
|
+
}
|
|
457
|
+
static ceil(a) {
|
|
458
|
+
if (!isFloatDtype(a.dtype)) return a;
|
|
459
|
+
return new AluExp(AluOp.Ceil, a.dtype, [a]);
|
|
460
|
+
}
|
|
452
461
|
static reciprocal(a) {
|
|
453
462
|
return new AluExp(AluOp.Reciprocal, a.dtype, [a]);
|
|
454
463
|
}
|
|
@@ -629,6 +638,12 @@ var AluExp = class AluExp {
|
|
|
629
638
|
case AluOp.Sqrt:
|
|
630
639
|
ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
|
|
631
640
|
break;
|
|
641
|
+
case AluOp.Floor:
|
|
642
|
+
ret = [Math.floor(src[0].min), Math.floor(src[0].max)];
|
|
643
|
+
break;
|
|
644
|
+
case AluOp.Ceil:
|
|
645
|
+
ret = [Math.ceil(src[0].min), Math.ceil(src[0].max)];
|
|
646
|
+
break;
|
|
632
647
|
case AluOp.Reciprocal:
|
|
633
648
|
if (src[0].min <= 0 && src[0].max >= 0) return [-Infinity, Infinity];
|
|
634
649
|
ret = [1 / src[0].max, 1 / src[0].min];
|
|
@@ -861,7 +876,7 @@ var AluExp = class AluExp {
|
|
|
861
876
|
else return p(p(src[0].src[0], src[1]), src[0].src[1]).simplify(cache);
|
|
862
877
|
if (src[1].op === op && src[1].src[1].op === AluOp.Const) return p(p(src[0], src[1].src[0]), src[1].src[1]).simplify(cache);
|
|
863
878
|
}
|
|
864
|
-
if (op === AluOp.Mod || op === AluOp.Idiv && src[1].#isConstInt()) {
|
|
879
|
+
if ((op === AluOp.Mod || op === AluOp.Idiv) && src[1].#isConstInt()) {
|
|
865
880
|
const [x, y] = src;
|
|
866
881
|
{
|
|
867
882
|
const factors = [];
|
|
@@ -951,6 +966,8 @@ var AluExp = class AluExp {
|
|
|
951
966
|
case AluOp.Erf: return erf(x);
|
|
952
967
|
case AluOp.Erfc: return erfc(x);
|
|
953
968
|
case AluOp.Sqrt: return Math.sqrt(x);
|
|
969
|
+
case AluOp.Floor: return Math.floor(x);
|
|
970
|
+
case AluOp.Ceil: return Math.ceil(x);
|
|
954
971
|
case AluOp.Reciprocal: return 1 / x;
|
|
955
972
|
case AluOp.Cast: {
|
|
956
973
|
const wasFloat = isFloatDtype(this.src[0].dtype);
|
|
@@ -1140,6 +1157,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
1140
1157
|
AluOp$1["Erf"] = "Erf";
|
|
1141
1158
|
AluOp$1["Erfc"] = "Erfc";
|
|
1142
1159
|
AluOp$1["Sqrt"] = "Sqrt";
|
|
1160
|
+
AluOp$1["Floor"] = "Floor";
|
|
1161
|
+
AluOp$1["Ceil"] = "Ceil";
|
|
1143
1162
|
AluOp$1["Reciprocal"] = "Reciprocal";
|
|
1144
1163
|
AluOp$1["Cast"] = "Cast";
|
|
1145
1164
|
AluOp$1["Bitcast"] = "Bitcast";
|
|
@@ -1174,6 +1193,8 @@ const AluGroup = {
|
|
|
1174
1193
|
AluOp.Erf,
|
|
1175
1194
|
AluOp.Erfc,
|
|
1176
1195
|
AluOp.Sqrt,
|
|
1196
|
+
AluOp.Floor,
|
|
1197
|
+
AluOp.Ceil,
|
|
1177
1198
|
AluOp.Reciprocal,
|
|
1178
1199
|
AluOp.Cast,
|
|
1179
1200
|
AluOp.Bitcast
|
|
@@ -1201,7 +1222,9 @@ const AluGroup = {
|
|
|
1201
1222
|
AluOp.Erf,
|
|
1202
1223
|
AluOp.Erfc,
|
|
1203
1224
|
AluOp.Sqrt,
|
|
1204
|
-
AluOp.Reciprocal
|
|
1225
|
+
AluOp.Reciprocal,
|
|
1226
|
+
AluOp.Floor,
|
|
1227
|
+
AluOp.Ceil
|
|
1205
1228
|
])
|
|
1206
1229
|
};
|
|
1207
1230
|
/** Common variables that can be substituted in expressions. */
|
|
@@ -3831,7 +3854,9 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3831
3854
|
else if (op === AluOp.Reciprocal) {
|
|
3832
3855
|
const dt = dtyF(cg, op, dtype);
|
|
3833
3856
|
dt.const(1), gen(src[0]), dt.div();
|
|
3834
|
-
} else if (op === AluOp.
|
|
3857
|
+
} else if (op === AluOp.Floor) gen(src[0]), dtyF(cg, op, dtype).floor();
|
|
3858
|
+
else if (op === AluOp.Ceil) gen(src[0]), dtyF(cg, op, dtype).ceil();
|
|
3859
|
+
else if (op === AluOp.Cast) {
|
|
3835
3860
|
gen(src[0]);
|
|
3836
3861
|
const dtype0 = src[0].dtype;
|
|
3837
3862
|
const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
|
|
@@ -3977,7 +4002,7 @@ async function createBackend(device) {
|
|
|
3977
4002
|
if (!navigator.gpu) return null;
|
|
3978
4003
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
3979
4004
|
if (!adapter) return null;
|
|
3980
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
4005
|
+
const { WebGPUBackend } = await import("./webgpu-BGuG58KZ.js");
|
|
3981
4006
|
const importantLimits = [
|
|
3982
4007
|
"maxBufferSize",
|
|
3983
4008
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4031,4 +4056,4 @@ var UnsupportedOpError = class extends Error {
|
|
|
4031
4056
|
|
|
4032
4057
|
//#endregion
|
|
4033
4058
|
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|
|
4034
|
-
//# sourceMappingURL=backend-
|
|
4059
|
+
//# sourceMappingURL=backend-BqymqzuU.js.map
|
|
@@ -290,10 +290,11 @@ var FpHash = class FpHash {
|
|
|
290
290
|
};
|
|
291
291
|
/** Run a function while caching it inline inside a `Map`. */
|
|
292
292
|
function runWithCache(cache, key, thunk) {
|
|
293
|
-
|
|
293
|
+
const keyStr = JSON.stringify(key);
|
|
294
|
+
if (cache.has(keyStr)) return cache.get(keyStr);
|
|
294
295
|
else {
|
|
295
296
|
const value = thunk();
|
|
296
|
-
cache.set(
|
|
297
|
+
cache.set(keyStr, value);
|
|
297
298
|
return value;
|
|
298
299
|
}
|
|
299
300
|
}
|
|
@@ -450,6 +451,14 @@ var AluExp = class AluExp {
|
|
|
450
451
|
static sqrt(a) {
|
|
451
452
|
return new AluExp(AluOp.Sqrt, a.dtype, [a]);
|
|
452
453
|
}
|
|
454
|
+
static floor(a) {
|
|
455
|
+
if (!isFloatDtype(a.dtype)) return a;
|
|
456
|
+
return new AluExp(AluOp.Floor, a.dtype, [a]);
|
|
457
|
+
}
|
|
458
|
+
static ceil(a) {
|
|
459
|
+
if (!isFloatDtype(a.dtype)) return a;
|
|
460
|
+
return new AluExp(AluOp.Ceil, a.dtype, [a]);
|
|
461
|
+
}
|
|
453
462
|
static reciprocal(a) {
|
|
454
463
|
return new AluExp(AluOp.Reciprocal, a.dtype, [a]);
|
|
455
464
|
}
|
|
@@ -630,6 +639,12 @@ var AluExp = class AluExp {
|
|
|
630
639
|
case AluOp.Sqrt:
|
|
631
640
|
ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
|
|
632
641
|
break;
|
|
642
|
+
case AluOp.Floor:
|
|
643
|
+
ret = [Math.floor(src[0].min), Math.floor(src[0].max)];
|
|
644
|
+
break;
|
|
645
|
+
case AluOp.Ceil:
|
|
646
|
+
ret = [Math.ceil(src[0].min), Math.ceil(src[0].max)];
|
|
647
|
+
break;
|
|
633
648
|
case AluOp.Reciprocal:
|
|
634
649
|
if (src[0].min <= 0 && src[0].max >= 0) return [-Infinity, Infinity];
|
|
635
650
|
ret = [1 / src[0].max, 1 / src[0].min];
|
|
@@ -862,7 +877,7 @@ var AluExp = class AluExp {
|
|
|
862
877
|
else return p(p(src[0].src[0], src[1]), src[0].src[1]).simplify(cache);
|
|
863
878
|
if (src[1].op === op && src[1].src[1].op === AluOp.Const) return p(p(src[0], src[1].src[0]), src[1].src[1]).simplify(cache);
|
|
864
879
|
}
|
|
865
|
-
if (op === AluOp.Mod || op === AluOp.Idiv && src[1].#isConstInt()) {
|
|
880
|
+
if ((op === AluOp.Mod || op === AluOp.Idiv) && src[1].#isConstInt()) {
|
|
866
881
|
const [x, y] = src;
|
|
867
882
|
{
|
|
868
883
|
const factors = [];
|
|
@@ -952,6 +967,8 @@ var AluExp = class AluExp {
|
|
|
952
967
|
case AluOp.Erf: return erf(x);
|
|
953
968
|
case AluOp.Erfc: return erfc(x);
|
|
954
969
|
case AluOp.Sqrt: return Math.sqrt(x);
|
|
970
|
+
case AluOp.Floor: return Math.floor(x);
|
|
971
|
+
case AluOp.Ceil: return Math.ceil(x);
|
|
955
972
|
case AluOp.Reciprocal: return 1 / x;
|
|
956
973
|
case AluOp.Cast: {
|
|
957
974
|
const wasFloat = isFloatDtype(this.src[0].dtype);
|
|
@@ -1141,6 +1158,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
1141
1158
|
AluOp$1["Erf"] = "Erf";
|
|
1142
1159
|
AluOp$1["Erfc"] = "Erfc";
|
|
1143
1160
|
AluOp$1["Sqrt"] = "Sqrt";
|
|
1161
|
+
AluOp$1["Floor"] = "Floor";
|
|
1162
|
+
AluOp$1["Ceil"] = "Ceil";
|
|
1144
1163
|
AluOp$1["Reciprocal"] = "Reciprocal";
|
|
1145
1164
|
AluOp$1["Cast"] = "Cast";
|
|
1146
1165
|
AluOp$1["Bitcast"] = "Bitcast";
|
|
@@ -1175,6 +1194,8 @@ const AluGroup = {
|
|
|
1175
1194
|
AluOp.Erf,
|
|
1176
1195
|
AluOp.Erfc,
|
|
1177
1196
|
AluOp.Sqrt,
|
|
1197
|
+
AluOp.Floor,
|
|
1198
|
+
AluOp.Ceil,
|
|
1178
1199
|
AluOp.Reciprocal,
|
|
1179
1200
|
AluOp.Cast,
|
|
1180
1201
|
AluOp.Bitcast
|
|
@@ -1202,7 +1223,9 @@ const AluGroup = {
|
|
|
1202
1223
|
AluOp.Erf,
|
|
1203
1224
|
AluOp.Erfc,
|
|
1204
1225
|
AluOp.Sqrt,
|
|
1205
|
-
AluOp.Reciprocal
|
|
1226
|
+
AluOp.Reciprocal,
|
|
1227
|
+
AluOp.Floor,
|
|
1228
|
+
AluOp.Ceil
|
|
1206
1229
|
])
|
|
1207
1230
|
};
|
|
1208
1231
|
/** Common variables that can be substituted in expressions. */
|
|
@@ -3832,7 +3855,9 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3832
3855
|
else if (op === AluOp.Reciprocal) {
|
|
3833
3856
|
const dt = dtyF(cg, op, dtype);
|
|
3834
3857
|
dt.const(1), gen(src[0]), dt.div();
|
|
3835
|
-
} else if (op === AluOp.
|
|
3858
|
+
} else if (op === AluOp.Floor) gen(src[0]), dtyF(cg, op, dtype).floor();
|
|
3859
|
+
else if (op === AluOp.Ceil) gen(src[0]), dtyF(cg, op, dtype).ceil();
|
|
3860
|
+
else if (op === AluOp.Cast) {
|
|
3836
3861
|
gen(src[0]);
|
|
3837
3862
|
const dtype0 = src[0].dtype;
|
|
3838
3863
|
const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
|
|
@@ -3978,7 +4003,7 @@ async function createBackend(device) {
|
|
|
3978
4003
|
if (!navigator.gpu) return null;
|
|
3979
4004
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
3980
4005
|
if (!adapter) return null;
|
|
3981
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
4006
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-CcGP160M.cjs"));
|
|
3982
4007
|
const importantLimits = [
|
|
3983
4008
|
"maxBufferSize",
|
|
3984
4009
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4325,4 +4350,4 @@ Object.defineProperty(exports, 'zipn', {
|
|
|
4325
4350
|
return zipn;
|
|
4326
4351
|
}
|
|
4327
4352
|
});
|
|
4328
|
-
//# sourceMappingURL=backend-
|
|
4353
|
+
//# sourceMappingURL=backend-DeVfWEFS.cjs.map
|