@jax-js/jax 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/{backend-FtkbO6pI.cjs → backend-BbrKEB18.cjs} +165 -69
- package/dist/{backend-DwIAd0AG.js → backend-CoVtc9dx.js} +165 -69
- package/dist/index.cjs +42 -8
- package/dist/index.d.cts +20 -6
- package/dist/index.d.ts +20 -6
- package/dist/index.js +42 -8
- package/dist/{webgpu-LGi2A3mS.js → webgpu-B3UVme6n.js} +9 -4
- package/dist/{webgpu-BE7zA_01.cjs → webgpu-DGYNVHma.cjs} +9 -4
- package/package.json +21 -13
package/dist/index.d.ts
CHANGED
|
@@ -165,9 +165,10 @@ declare enum DType {
|
|
|
165
165
|
Uint32 = "uint32",
|
|
166
166
|
Bool = "bool",
|
|
167
167
|
Float16 = "float16",
|
|
168
|
+
Float64 = "float64",
|
|
168
169
|
}
|
|
169
170
|
/** @inline */
|
|
170
|
-
type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
|
|
171
|
+
type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer> | Float64Array<ArrayBuffer>;
|
|
171
172
|
/**
|
|
172
173
|
* Promote two dtypes to their join according to the type lattice.
|
|
173
174
|
*
|
|
@@ -177,7 +178,7 @@ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Arr
|
|
|
177
178
|
*
|
|
178
179
|
* **Type lattice:**
|
|
179
180
|
* ```text
|
|
180
|
-
* bool -> uint32 -> int32 -> float16 -> float32
|
|
181
|
+
* bool -> uint32 -> int32 -> float16 -> float32 -> float64
|
|
181
182
|
* weakType --^
|
|
182
183
|
* ```
|
|
183
184
|
*
|
|
@@ -240,6 +241,7 @@ declare class AluExp implements FpHashable {
|
|
|
240
241
|
static u32(value: number): AluExp;
|
|
241
242
|
static bool(value: boolean): AluExp;
|
|
242
243
|
static f16(value: number): AluExp;
|
|
244
|
+
static f64(value: number): AluExp;
|
|
243
245
|
not(): AluExp;
|
|
244
246
|
/** Compute a reasonable expression hash with low collision rate. */
|
|
245
247
|
getHash(): bigint;
|
|
@@ -627,11 +629,9 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
627
629
|
/** Type of parameters taken by each primitive. */
|
|
628
630
|
type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
|
|
629
631
|
declare enum CompareOp {
|
|
630
|
-
Greater = "greater",
|
|
631
632
|
Less = "less",
|
|
632
633
|
Equal = "equal",
|
|
633
634
|
NotEqual = "not_equal",
|
|
634
|
-
GreaterEqual = "greater_equal",
|
|
635
635
|
LessEqual = "less_equal",
|
|
636
636
|
}
|
|
637
637
|
/** @inline */
|
|
@@ -982,7 +982,7 @@ declare class Array extends Tracer {
|
|
|
982
982
|
_putSync(backend: Backend): Array;
|
|
983
983
|
}
|
|
984
984
|
/** Constructor for creating a new array from data. */
|
|
985
|
-
declare function array(values: Array |
|
|
985
|
+
declare function array(values: Array | DataArray | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
986
986
|
shape,
|
|
987
987
|
dtype,
|
|
988
988
|
device
|
|
@@ -1055,13 +1055,14 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
1055
1055
|
device
|
|
1056
1056
|
}?: DTypeAndDevice): Array;
|
|
1057
1057
|
declare namespace numpy_d_exports {
|
|
1058
|
-
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
1058
|
+
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, float64, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
1059
1059
|
}
|
|
1060
1060
|
declare const float32 = DType.Float32;
|
|
1061
1061
|
declare const int32 = DType.Int32;
|
|
1062
1062
|
declare const uint32 = DType.Uint32;
|
|
1063
1063
|
declare const bool = DType.Bool;
|
|
1064
1064
|
declare const float16 = DType.Float16;
|
|
1065
|
+
declare const float64 = DType.Float64;
|
|
1065
1066
|
/** Euler's constant, `e = 2.7182818284590...` */
|
|
1066
1067
|
declare const e: number;
|
|
1067
1068
|
/** Euler-Mascheroni constant, `γ = 0.5772156649...` */
|
|
@@ -1532,6 +1533,19 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1532
1533
|
mean?: ArrayLike;
|
|
1533
1534
|
correction?: number;
|
|
1534
1535
|
} & ReduceOpts): Array;
|
|
1536
|
+
/** Test element-wise for positive or negative infinity, return bool array. */
|
|
1537
|
+
declare function isinf(x: ArrayLike): Array;
|
|
1538
|
+
/** Test element-wise for NaN (Not a Number). */
|
|
1539
|
+
declare function isnan(x: ArrayLike): Array;
|
|
1540
|
+
/** Test element-wise for negative infinity, return bool array. */
|
|
1541
|
+
declare function isneginf(x: ArrayLike): Array;
|
|
1542
|
+
/** Test element-wise for positive infinity, return bool array. */
|
|
1543
|
+
declare function isposinf(x: ArrayLike): Array;
|
|
1544
|
+
/**
|
|
1545
|
+
* @function
|
|
1546
|
+
* Test element-wise for finite values (not infinity or NaN).
|
|
1547
|
+
*/
|
|
1548
|
+
declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1535
1549
|
//# sourceMappingURL=numpy.d.ts.map
|
|
1536
1550
|
//#endregion
|
|
1537
1551
|
//#region src/frontend/jaxpr.d.ts
|
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-CoVtc9dx.js";
|
|
3
3
|
|
|
4
4
|
//#region src/tree.ts
|
|
5
5
|
var tree_exports = {};
|
|
@@ -370,11 +370,9 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
370
370
|
return Primitive$1;
|
|
371
371
|
}({});
|
|
372
372
|
let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
|
|
373
|
-
CompareOp$1["Greater"] = "greater";
|
|
374
373
|
CompareOp$1["Less"] = "less";
|
|
375
374
|
CompareOp$1["Equal"] = "equal";
|
|
376
375
|
CompareOp$1["NotEqual"] = "not_equal";
|
|
377
|
-
CompareOp$1["GreaterEqual"] = "greater_equal";
|
|
378
376
|
CompareOp$1["LessEqual"] = "less_equal";
|
|
379
377
|
return CompareOp$1;
|
|
380
378
|
}({});
|
|
@@ -470,7 +468,7 @@ function compare(x, y, op) {
|
|
|
470
468
|
return bind1(Primitive.Compare, [x, y], { op });
|
|
471
469
|
}
|
|
472
470
|
function greater$1(x, y) {
|
|
473
|
-
return compare(
|
|
471
|
+
return compare(y, x, CompareOp.Less);
|
|
474
472
|
}
|
|
475
473
|
function less$1(x, y) {
|
|
476
474
|
return compare(x, y, CompareOp.Less);
|
|
@@ -482,7 +480,7 @@ function notEqual$1(x, y) {
|
|
|
482
480
|
return compare(x, y, CompareOp.NotEqual);
|
|
483
481
|
}
|
|
484
482
|
function greaterEqual$1(x, y) {
|
|
485
|
-
return compare(
|
|
483
|
+
return compare(y, x, CompareOp.LessEqual);
|
|
486
484
|
}
|
|
487
485
|
function lessEqual$1(x, y) {
|
|
488
486
|
return compare(x, y, CompareOp.LessEqual);
|
|
@@ -2209,6 +2207,9 @@ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
|
|
|
2209
2207
|
} else if (data instanceof Float16Array) {
|
|
2210
2208
|
if (dtype && dtype !== DType.Float16) throw new Error("Float16Array must have float16 type");
|
|
2211
2209
|
dtype ??= DType.Float16;
|
|
2210
|
+
} else if (data instanceof Float64Array) {
|
|
2211
|
+
if (dtype && dtype !== DType.Float64) throw new Error("Float64Array must have float64 type");
|
|
2212
|
+
dtype ??= DType.Float64;
|
|
2212
2213
|
} else throw new Error("Unsupported data array type: " + data.constructor.name);
|
|
2213
2214
|
if (data.length < inlineArrayLimit) {
|
|
2214
2215
|
let allEqual = true;
|
|
@@ -2420,11 +2421,9 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
2420
2421
|
}
|
|
2421
2422
|
function aluCompare(a, b, op) {
|
|
2422
2423
|
switch (op) {
|
|
2423
|
-
case CompareOp.Greater: return AluExp.mul(AluExp.cmpne(a, b), AluExp.cmplt(a, b).not());
|
|
2424
2424
|
case CompareOp.Less: return AluExp.cmplt(a, b);
|
|
2425
2425
|
case CompareOp.Equal: return AluExp.cmpne(a, b).not();
|
|
2426
2426
|
case CompareOp.NotEqual: return AluExp.cmpne(a, b);
|
|
2427
|
-
case CompareOp.GreaterEqual: return AluExp.cmplt(a, b).not();
|
|
2428
2427
|
case CompareOp.LessEqual: return AluExp.add(AluExp.cmplt(a, b), AluExp.cmpne(a, b).not());
|
|
2429
2428
|
}
|
|
2430
2429
|
}
|
|
@@ -2557,7 +2556,7 @@ var JaxprEqn = class {
|
|
|
2557
2556
|
const paramsList = Object.entries(this.params).map(([k, v]) => PPrint.pp(`${k}=${v}`));
|
|
2558
2557
|
if (paramsList.length > 0) rhs = rhs.stack(PPrint.pp(" [ ")).stack(PPrint.prototype.concat(...paramsList)).stack(PPrint.pp(" ] "));
|
|
2559
2558
|
else rhs = rhs.stack(PPrint.pp(" "));
|
|
2560
|
-
rhs = rhs.stack(PPrint.pp(this.inputs.map((x) => x instanceof Var ? vp.name(x) :
|
|
2559
|
+
rhs = rhs.stack(PPrint.pp(this.inputs.map((x) => x instanceof Var ? vp.name(x) : String(x.value)).join(" ")));
|
|
2561
2560
|
return lhs.stack(PPrint.pp(" = ")).stack(rhs);
|
|
2562
2561
|
}
|
|
2563
2562
|
toString() {
|
|
@@ -4342,6 +4341,7 @@ __export(numpy_exports, {
|
|
|
4342
4341
|
flipud: () => flipud,
|
|
4343
4342
|
float16: () => float16,
|
|
4344
4343
|
float32: () => float32,
|
|
4344
|
+
float64: () => float64,
|
|
4345
4345
|
full: () => full,
|
|
4346
4346
|
fullLike: () => fullLike$1,
|
|
4347
4347
|
greater: () => greater,
|
|
@@ -4355,6 +4355,11 @@ __export(numpy_exports, {
|
|
|
4355
4355
|
inf: () => inf,
|
|
4356
4356
|
inner: () => inner,
|
|
4357
4357
|
int32: () => int32,
|
|
4358
|
+
isfinite: () => isfinite,
|
|
4359
|
+
isinf: () => isinf,
|
|
4360
|
+
isnan: () => isnan,
|
|
4361
|
+
isneginf: () => isneginf,
|
|
4362
|
+
isposinf: () => isposinf,
|
|
4358
4363
|
less: () => less,
|
|
4359
4364
|
lessEqual: () => lessEqual,
|
|
4360
4365
|
linspace: () => linspace,
|
|
@@ -4425,6 +4430,7 @@ const int32 = DType.Int32;
|
|
|
4425
4430
|
const uint32 = DType.Uint32;
|
|
4426
4431
|
const bool = DType.Bool;
|
|
4427
4432
|
const float16 = DType.Float16;
|
|
4433
|
+
const float64 = DType.Float64;
|
|
4428
4434
|
/** Euler's constant, `e = 2.7182818284590...` */
|
|
4429
4435
|
const e = Math.E;
|
|
4430
4436
|
/** Euler-Mascheroni constant, `γ = 0.5772156649...` */
|
|
@@ -5225,6 +5231,34 @@ function var_(x, axis = null, opts) {
|
|
|
5225
5231
|
function std(x, axis = null, opts) {
|
|
5226
5232
|
return sqrt(var_(x, axis, opts));
|
|
5227
5233
|
}
|
|
5234
|
+
/** Test element-wise for positive or negative infinity, return bool array. */
|
|
5235
|
+
function isinf(x) {
|
|
5236
|
+
x = fudgeArray(x);
|
|
5237
|
+
return isFloatDtype(x.dtype) ? x.ref.equal(Infinity).add(x.equal(-Infinity)) : fullLike$1(x, false);
|
|
5238
|
+
}
|
|
5239
|
+
/** Test element-wise for NaN (Not a Number). */
|
|
5240
|
+
function isnan(x) {
|
|
5241
|
+
x = fudgeArray(x);
|
|
5242
|
+
return isFloatDtype(x.dtype) ? x.ref.notEqual(x) : fullLike$1(x, false);
|
|
5243
|
+
}
|
|
5244
|
+
/** Test element-wise for negative infinity, return bool array. */
|
|
5245
|
+
function isneginf(x) {
|
|
5246
|
+
x = fudgeArray(x);
|
|
5247
|
+
return isFloatDtype(x.dtype) ? x.equal(-Infinity) : fullLike$1(x, false);
|
|
5248
|
+
}
|
|
5249
|
+
/** Test element-wise for positive infinity, return bool array. */
|
|
5250
|
+
function isposinf(x) {
|
|
5251
|
+
x = fudgeArray(x);
|
|
5252
|
+
return isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
|
|
5253
|
+
}
|
|
5254
|
+
/**
|
|
5255
|
+
* @function
|
|
5256
|
+
* Test element-wise for finite values (not infinity or NaN).
|
|
5257
|
+
*/
|
|
5258
|
+
const isfinite = jit$1(function isfinite$1(x) {
|
|
5259
|
+
if (!isFloatDtype(x.dtype)) return fullLike$1(x, true);
|
|
5260
|
+
return isnan(x.ref).add(isinf(x)).notEqual(true);
|
|
5261
|
+
});
|
|
5228
5262
|
|
|
5229
5263
|
//#endregion
|
|
5230
5264
|
//#region src/nn.ts
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, mapSetUnion, strip1, tuneWebgpu } from "./backend-
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, mapSetUnion, strip1, tuneWebgpu } from "./backend-CoVtc9dx.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -323,7 +323,7 @@ function dtypeToWgsl(dtype, storage = false) {
|
|
|
323
323
|
case DType.Uint32: return "u32";
|
|
324
324
|
case DType.Float32: return "f32";
|
|
325
325
|
case DType.Float16: return "f16";
|
|
326
|
-
default: throw new Error(`Unsupported dtype: ${dtype}`);
|
|
326
|
+
default: throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
|
|
327
327
|
}
|
|
328
328
|
}
|
|
329
329
|
function constToWgsl(dtype, value) {
|
|
@@ -397,6 +397,7 @@ function pipelineSource(device, kernel) {
|
|
|
397
397
|
}
|
|
398
398
|
let gensymCount = 0;
|
|
399
399
|
const gensym = () => `alu${gensymCount++}`;
|
|
400
|
+
const isGensym = (text) => text.match(/^alu[0-9]+$/);
|
|
400
401
|
for (let i = 0; i < args.length; i++) if (!usedArgs[i]) emit(`_ = &${args[i]};`);
|
|
401
402
|
const references = /* @__PURE__ */ new Map();
|
|
402
403
|
const seen = /* @__PURE__ */ new Set();
|
|
@@ -425,7 +426,11 @@ function pipelineSource(device, kernel) {
|
|
|
425
426
|
else if (op === AluOp.Min) source = `min(${strip1(a)}, ${strip1(b)})`;
|
|
426
427
|
else if (op === AluOp.Max) source = `max(${strip1(a)}, ${strip1(b)})`;
|
|
427
428
|
else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
|
|
428
|
-
else if (op === AluOp.Cmpne)
|
|
429
|
+
else if (op === AluOp.Cmpne) if (isFloatDtype(src[0].dtype)) {
|
|
430
|
+
const x = isGensym(a) ? a : gensym();
|
|
431
|
+
if (x !== a) emit(`let ${x} = ${a};`);
|
|
432
|
+
source = `(${x} != ${b} || min(${x}, ${dtypeToWgsl(src[0].dtype)}(inf())) != ${x})`;
|
|
433
|
+
} else source = `(${a} != ${b})`;
|
|
429
434
|
} else if (AluGroup.Unary.has(op)) if (op === AluOp.Reciprocal && src[0].op === AluOp.Sqrt) {
|
|
430
435
|
const a = gen(src[0].src[0]);
|
|
431
436
|
source = `inverseSqrt(${a})`;
|
|
@@ -645,4 +650,4 @@ async function compileError(shaderModule, scope, code) {
|
|
|
645
650
|
|
|
646
651
|
//#endregion
|
|
647
652
|
export { WebGPUBackend };
|
|
648
|
-
//# sourceMappingURL=webgpu-
|
|
653
|
+
//# sourceMappingURL=webgpu-B3UVme6n.js.map
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-
|
|
1
|
+
const require_backend = require('./backend-BbrKEB18.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -323,7 +323,7 @@ function dtypeToWgsl(dtype, storage = false) {
|
|
|
323
323
|
case require_backend.DType.Uint32: return "u32";
|
|
324
324
|
case require_backend.DType.Float32: return "f32";
|
|
325
325
|
case require_backend.DType.Float16: return "f16";
|
|
326
|
-
default: throw new Error(`Unsupported dtype: ${dtype}`);
|
|
326
|
+
default: throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
|
|
327
327
|
}
|
|
328
328
|
}
|
|
329
329
|
function constToWgsl(dtype, value) {
|
|
@@ -397,6 +397,7 @@ function pipelineSource(device, kernel) {
|
|
|
397
397
|
}
|
|
398
398
|
let gensymCount = 0;
|
|
399
399
|
const gensym = () => `alu${gensymCount++}`;
|
|
400
|
+
const isGensym = (text) => text.match(/^alu[0-9]+$/);
|
|
400
401
|
for (let i = 0; i < args.length; i++) if (!usedArgs[i]) emit(`_ = &${args[i]};`);
|
|
401
402
|
const references = /* @__PURE__ */ new Map();
|
|
402
403
|
const seen = /* @__PURE__ */ new Set();
|
|
@@ -425,7 +426,11 @@ function pipelineSource(device, kernel) {
|
|
|
425
426
|
else if (op === require_backend.AluOp.Min) source = `min(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
|
|
426
427
|
else if (op === require_backend.AluOp.Max) source = `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
|
|
427
428
|
else if (op === require_backend.AluOp.Cmplt) source = `(${a} < ${b})`;
|
|
428
|
-
else if (op === require_backend.AluOp.Cmpne)
|
|
429
|
+
else if (op === require_backend.AluOp.Cmpne) if (require_backend.isFloatDtype(src[0].dtype)) {
|
|
430
|
+
const x = isGensym(a) ? a : gensym();
|
|
431
|
+
if (x !== a) emit(`let ${x} = ${a};`);
|
|
432
|
+
source = `(${x} != ${b} || min(${x}, ${dtypeToWgsl(src[0].dtype)}(inf())) != ${x})`;
|
|
433
|
+
} else source = `(${a} != ${b})`;
|
|
429
434
|
} else if (require_backend.AluGroup.Unary.has(op)) if (op === require_backend.AluOp.Reciprocal && src[0].op === require_backend.AluOp.Sqrt) {
|
|
430
435
|
const a = gen(src[0].src[0]);
|
|
431
436
|
source = `inverseSqrt(${a})`;
|
|
@@ -645,4 +650,4 @@ async function compileError(shaderModule, scope, code) {
|
|
|
645
650
|
|
|
646
651
|
//#endregion
|
|
647
652
|
exports.WebGPUBackend = WebGPUBackend;
|
|
648
|
-
//# sourceMappingURL=webgpu-
|
|
653
|
+
//# sourceMappingURL=webgpu-DGYNVHma.cjs.map
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@jax-js/jax",
|
|
3
|
-
"version": "0.1.
|
|
3
|
+
"version": "0.1.1",
|
|
4
4
|
"description": "Numerical computing and ML in the browser",
|
|
5
5
|
"keywords": [
|
|
6
6
|
"machine learning",
|
|
@@ -35,6 +35,19 @@
|
|
|
35
35
|
"types": "dist/index.d.ts",
|
|
36
36
|
"module": "dist/index.js",
|
|
37
37
|
"license": "MIT",
|
|
38
|
+
"scripts": {
|
|
39
|
+
"build": "tsdown",
|
|
40
|
+
"build:watch": "TSDOWN_WATCH_MODE=1 tsdown",
|
|
41
|
+
"check": "tsc",
|
|
42
|
+
"docs": "tsx typedoc.ts",
|
|
43
|
+
"format": "prettier --write .",
|
|
44
|
+
"format:check": "prettier --check .",
|
|
45
|
+
"lint": "eslint",
|
|
46
|
+
"test": "vitest",
|
|
47
|
+
"test:coverage": "vitest run --coverage && open coverage/index.html",
|
|
48
|
+
"prepublishOnly": "pnpm build",
|
|
49
|
+
"postpublish": "git tag jax/v$npm_package_version && git push --tags"
|
|
50
|
+
},
|
|
38
51
|
"devDependencies": {
|
|
39
52
|
"@eslint/js": "^9.31.0",
|
|
40
53
|
"@types/debug": "^4.1.12",
|
|
@@ -55,9 +68,15 @@
|
|
|
55
68
|
"typescript-eslint": "^8.46.4",
|
|
56
69
|
"vitest": "^4.0.9"
|
|
57
70
|
},
|
|
71
|
+
"packageManager": "pnpm@10.22.0",
|
|
58
72
|
"engines": {
|
|
59
73
|
"pnpm": ">=10.0.0"
|
|
60
74
|
},
|
|
75
|
+
"pnpm": {
|
|
76
|
+
"overrides": {
|
|
77
|
+
"@tensorflow/tfjs-core>@webgpu/types": "^0.1.68"
|
|
78
|
+
}
|
|
79
|
+
},
|
|
61
80
|
"prettier": {
|
|
62
81
|
"plugins": [
|
|
63
82
|
"prettier-plugin-svelte"
|
|
@@ -73,16 +92,5 @@
|
|
|
73
92
|
}
|
|
74
93
|
],
|
|
75
94
|
"proseWrap": "always"
|
|
76
|
-
},
|
|
77
|
-
"scripts": {
|
|
78
|
-
"build": "tsdown",
|
|
79
|
-
"build:watch": "TSDOWN_WATCH_MODE=1 tsdown",
|
|
80
|
-
"check": "tsc",
|
|
81
|
-
"docs": "tsx typedoc.ts",
|
|
82
|
-
"format": "prettier --write .",
|
|
83
|
-
"format:check": "prettier --check .",
|
|
84
|
-
"lint": "eslint",
|
|
85
|
-
"test": "vitest",
|
|
86
|
-
"test:coverage": "vitest run --coverage && open coverage/index.html"
|
|
87
95
|
}
|
|
88
|
-
}
|
|
96
|
+
}
|