@jax-js/jax 0.1.10 → 0.1.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +7 -2
- package/dist/{backend-Ctqs8la1.js → backend-DI-V78Rk.js} +732 -21
- package/dist/{backend-DMauYnfl.cjs → backend-x-6vqzIM.cjs} +737 -20
- package/dist/index.cjs +372 -20
- package/dist/index.d.cts +172 -4
- package/dist/index.d.ts +172 -4
- package/dist/index.js +372 -21
- package/dist/{webgl-CvQ1QBX1.js → webgl-BhsnpeB0.js} +7 -1
- package/dist/{webgl-kvVt7-T7.cjs → webgl-CD3WK_Me.cjs} +7 -1
- package/dist/{webgpu-v_W_-oKw.js → webgpu-C2kLdkUh.js} +299 -149
- package/dist/{webgpu-DMSx7a6M.cjs → webgpu-C4S8Uq9e.cjs} +299 -149
- package/package.json +1 -1
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-DMauYnfl.cjs');
|
|
33
|
+
const require_backend = require('./backend-x-6vqzIM.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
@@ -240,7 +240,7 @@ __export(tree_exports, {
|
|
|
240
240
|
structure: () => structure,
|
|
241
241
|
unflatten: () => unflatten
|
|
242
242
|
});
|
|
243
|
-
const JsArray$2 = globalThis.Array;
|
|
243
|
+
const JsArray$3 = globalThis.Array;
|
|
244
244
|
let NodeType = /* @__PURE__ */ function(NodeType$1) {
|
|
245
245
|
NodeType$1["Array"] = "Array";
|
|
246
246
|
NodeType$1["Object"] = "Object";
|
|
@@ -288,7 +288,7 @@ function flatten(tree) {
|
|
|
288
288
|
return [leaves$1, treedef];
|
|
289
289
|
}
|
|
290
290
|
function _flatten(tree, leaves$1) {
|
|
291
|
-
if (JsArray$2.isArray(tree)) {
|
|
291
|
+
if (JsArray$3.isArray(tree)) {
|
|
292
292
|
const childTrees = tree.map((c) => _flatten(c, leaves$1));
|
|
293
293
|
return new JsTreeDef(NodeType.Array, null, childTrees);
|
|
294
294
|
} else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
|
|
@@ -364,6 +364,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
364
364
|
Primitive$1["Mod"] = "mod";
|
|
365
365
|
Primitive$1["Min"] = "min";
|
|
366
366
|
Primitive$1["Max"] = "max";
|
|
367
|
+
Primitive$1["BitCombine"] = "bit_combine";
|
|
368
|
+
Primitive$1["BitShift"] = "bit_shift";
|
|
367
369
|
Primitive$1["Neg"] = "neg";
|
|
368
370
|
Primitive$1["Reciprocal"] = "reciprocal";
|
|
369
371
|
Primitive$1["Floor"] = "floor";
|
|
@@ -437,6 +439,12 @@ function min$1(x, y) {
|
|
|
437
439
|
function max$1(x, y) {
|
|
438
440
|
return bind1(Primitive.Max, [x, y]);
|
|
439
441
|
}
|
|
442
|
+
function bitCombine(x, y, op) {
|
|
443
|
+
return bind1(Primitive.BitCombine, [x, y], { op });
|
|
444
|
+
}
|
|
445
|
+
function bitShift(x, y, op) {
|
|
446
|
+
return bind1(Primitive.BitShift, [x, y], { op });
|
|
447
|
+
}
|
|
440
448
|
function neg(x) {
|
|
441
449
|
return bind1(Primitive.Neg, [x]);
|
|
442
450
|
}
|
|
@@ -1655,6 +1663,16 @@ const abstractEvalRules = {
|
|
|
1655
1663
|
[Primitive.Mod]: binopAbstractEval,
|
|
1656
1664
|
[Primitive.Min]: binopAbstractEval,
|
|
1657
1665
|
[Primitive.Max]: binopAbstractEval,
|
|
1666
|
+
[Primitive.BitCombine]([x, y]) {
|
|
1667
|
+
const aval = promoteAvals(x, y);
|
|
1668
|
+
if (require_backend.isFloatDtype(aval.dtype)) throw new TypeError(`bitwise operations require integer or boolean inputs, got ${aval.dtype}`);
|
|
1669
|
+
return [aval];
|
|
1670
|
+
},
|
|
1671
|
+
[Primitive.BitShift]([x, y]) {
|
|
1672
|
+
const shape$1 = require_backend.generalBroadcast(x.shape, y.shape);
|
|
1673
|
+
if (require_backend.isFloatDtype(x.dtype) || require_backend.isFloatDtype(y.dtype) || x.dtype === require_backend.DType.Bool || y.dtype === require_backend.DType.Bool) throw new TypeError(`bit shift operations require integer inputs, got ${x} and ${y}`);
|
|
1674
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
1675
|
+
},
|
|
1658
1676
|
[Primitive.Neg]: vectorizedUnopAbstractEval,
|
|
1659
1677
|
[Primitive.Reciprocal]: vectorizedUnopAbstractEval,
|
|
1660
1678
|
[Primitive.Floor]: vectorizedUnopAbstractEval,
|
|
@@ -2190,6 +2208,8 @@ const jitRules = {
|
|
|
2190
2208
|
[Primitive.Mod]: broadcastedJit(([a, b]) => require_backend.AluExp.mod(a, b)),
|
|
2191
2209
|
[Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
|
|
2192
2210
|
[Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
|
|
2211
|
+
[Primitive.BitCombine]: broadcastedJit(([a, b], { op }) => require_backend.AluExp.bitCombine(a, b, op)),
|
|
2212
|
+
[Primitive.BitShift]: broadcastedJit(([a, b], { op }) => require_backend.AluExp.bitShift(a, b, op)),
|
|
2193
2213
|
[Primitive.Neg]: unopJit((a) => require_backend.AluExp.sub(require_backend.AluExp.const(a.dtype, 0), a)),
|
|
2194
2214
|
[Primitive.Reciprocal]: unopJit(require_backend.AluExp.reciprocal),
|
|
2195
2215
|
[Primitive.Floor]: unopJit(require_backend.AluExp.floor),
|
|
@@ -2382,7 +2402,9 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2382
2402
|
case Primitive.Idiv:
|
|
2383
2403
|
case Primitive.Mod:
|
|
2384
2404
|
case Primitive.Min:
|
|
2385
|
-
case Primitive.Max:
|
|
2405
|
+
case Primitive.Max:
|
|
2406
|
+
case Primitive.BitCombine:
|
|
2407
|
+
case Primitive.BitShift: {
|
|
2386
2408
|
const otherInput = nextEqn.inputs.find((v) => v !== outVar);
|
|
2387
2409
|
if (otherInput instanceof Lit || require_backend.deepEqual(require_backend.generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
|
|
2388
2410
|
head = usages[0];
|
|
@@ -2473,7 +2495,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2473
2495
|
|
|
2474
2496
|
//#endregion
|
|
2475
2497
|
//#region src/frontend/array.ts
|
|
2476
|
-
const JsArray$1 = globalThis.Array;
|
|
2498
|
+
const JsArray$2 = globalThis.Array;
|
|
2477
2499
|
const inlineArrayLimit = 128;
|
|
2478
2500
|
/** Version of pureArray with fudged types. */
|
|
2479
2501
|
const fudgeArray = pureArray;
|
|
@@ -2913,6 +2935,15 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2913
2935
|
this.#check();
|
|
2914
2936
|
const indices = require_backend.unravelAlu(this.#st.shape, require_backend.AluVar.gidx);
|
|
2915
2937
|
if (this.#source instanceof require_backend.AluExp) {
|
|
2938
|
+
let resolvedSource;
|
|
2939
|
+
if (this.#st.contiguous && this.#st.size < inlineArrayLimit && (resolvedSource = this.#source.resolve()) !== void 0) {
|
|
2940
|
+
const byteLength = this.#st.size * require_backend.byteWidth(this.#dtype);
|
|
2941
|
+
const initialData = new Uint8Array(byteLength);
|
|
2942
|
+
require_backend.dtypedArray(this.#dtype, initialData).fill(resolvedSource);
|
|
2943
|
+
this.#source = this.#backend.malloc(byteLength, initialData);
|
|
2944
|
+
this.#st = require_backend.ShapeTracker.fromShape(this.shape);
|
|
2945
|
+
return;
|
|
2946
|
+
}
|
|
2916
2947
|
const exp$2 = require_backend.accessorAluExp(this.#source, this.#st, indices);
|
|
2917
2948
|
const kernel = new require_backend.Kernel(0, this.#st.size, exp$2);
|
|
2918
2949
|
const output = this.#backend.malloc(kernel.bytes);
|
|
@@ -3021,6 +3052,42 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3021
3052
|
return require_backend.dtypedArray(this.dtype, buf);
|
|
3022
3053
|
}
|
|
3023
3054
|
/**
|
|
3055
|
+
* Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
|
|
3056
|
+
*
|
|
3057
|
+
* Only available on the WebGPU backend. The array's memory is still managed
|
|
3058
|
+
* by jax-js, and it will be freed when the buffer is no longer in use. You
|
|
3059
|
+
* _should not_ mutate the buffer's contents.
|
|
3060
|
+
*
|
|
3061
|
+
* Note that the GPU buffer may be slightly larger than the array's size; it
|
|
3062
|
+
* will always be aligned to 4 bytes.
|
|
3063
|
+
*/
|
|
3064
|
+
async gpuBuffer() {
|
|
3065
|
+
if (this.device !== "webgpu") throw new Error(`gpuBuffer() is only available on WebGPU backend`);
|
|
3066
|
+
this.#realize();
|
|
3067
|
+
const pending = this.#pending;
|
|
3068
|
+
if (pending) {
|
|
3069
|
+
await Promise.all(pending.map((p) => p.prepare()));
|
|
3070
|
+
for (const p of pending) p.submit();
|
|
3071
|
+
}
|
|
3072
|
+
const backend = this.#backend;
|
|
3073
|
+
const { buffer } = backend.buffers.get(this.#source);
|
|
3074
|
+
this.dispose();
|
|
3075
|
+
return buffer;
|
|
3076
|
+
}
|
|
3077
|
+
/** Synchronous version of `Array.gpuBuffer()`. */
|
|
3078
|
+
gpuBufferSync() {
|
|
3079
|
+
if (this.device !== "webgpu") throw new Error(`gpuBufferSync() is only available on WebGPU backend`);
|
|
3080
|
+
this.#realize();
|
|
3081
|
+
for (const p of this.#pending) {
|
|
3082
|
+
p.prepareSync();
|
|
3083
|
+
p.submit();
|
|
3084
|
+
}
|
|
3085
|
+
const backend = this.#backend;
|
|
3086
|
+
const { buffer } = backend.buffers.get(this.#source);
|
|
3087
|
+
this.dispose();
|
|
3088
|
+
return buffer;
|
|
3089
|
+
}
|
|
3090
|
+
/**
|
|
3024
3091
|
* Convert this array into a JavaScript object.
|
|
3025
3092
|
*
|
|
3026
3093
|
* This is a blocking operation that will compile all of the shaders and wait
|
|
@@ -3067,6 +3134,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3067
3134
|
[Primitive.Max]([x, y]) {
|
|
3068
3135
|
return [x.#binary(require_backend.AluOp.Max, y)];
|
|
3069
3136
|
},
|
|
3137
|
+
[Primitive.BitCombine]([x, y], { op }) {
|
|
3138
|
+
const custom = (src) => require_backend.AluExp.bitCombine(src[0], src[1], op);
|
|
3139
|
+
return [Array$1.#naryCustom("bit_combine", custom, [x, y])];
|
|
3140
|
+
},
|
|
3141
|
+
[Primitive.BitShift]([x, y], { op }) {
|
|
3142
|
+
const custom = (src) => require_backend.AluExp.bitShift(src[0], src[1], op);
|
|
3143
|
+
return [Array$1.#naryCustom("bit_shift", custom, [x, y], { dtypeOverride: [void 0, y.dtype] })];
|
|
3144
|
+
},
|
|
3070
3145
|
[Primitive.Neg]([x]) {
|
|
3071
3146
|
return [zerosLike$1(x.ref).#binary(require_backend.AluOp.Sub, x)];
|
|
3072
3147
|
},
|
|
@@ -3319,7 +3394,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
3319
3394
|
if (!shape$1) {
|
|
3320
3395
|
shape$1 = [];
|
|
3321
3396
|
let cur = values;
|
|
3322
|
-
while (JsArray$1.isArray(cur)) {
|
|
3397
|
+
while (JsArray$2.isArray(cur)) {
|
|
3323
3398
|
shape$1.push(cur.length);
|
|
3324
3399
|
cur = cur[0];
|
|
3325
3400
|
}
|
|
@@ -3759,6 +3834,8 @@ const vmapRules = {
|
|
|
3759
3834
|
[Primitive.Mod]: broadcastBatcher(Primitive.Mod),
|
|
3760
3835
|
[Primitive.Min]: broadcastBatcher(Primitive.Min),
|
|
3761
3836
|
[Primitive.Max]: broadcastBatcher(Primitive.Max),
|
|
3837
|
+
[Primitive.BitCombine]: broadcastBatcher(Primitive.BitCombine),
|
|
3838
|
+
[Primitive.BitShift]: broadcastBatcher(Primitive.BitShift),
|
|
3762
3839
|
[Primitive.Neg]: unopBatcher(Primitive.Neg),
|
|
3763
3840
|
[Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
|
|
3764
3841
|
[Primitive.Floor]: unopBatcher(Primitive.Floor),
|
|
@@ -4082,6 +4159,8 @@ const jvpRules = {
|
|
|
4082
4159
|
[Primitive.Max]([x, y], [dx, dy]) {
|
|
4083
4160
|
return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
|
|
4084
4161
|
},
|
|
4162
|
+
[Primitive.BitCombine]: zeroTangentsJvp(Primitive.BitCombine),
|
|
4163
|
+
[Primitive.BitShift]: zeroTangentsJvp(Primitive.BitShift),
|
|
4085
4164
|
[Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
|
|
4086
4165
|
[Primitive.Reciprocal]([x], [dx]) {
|
|
4087
4166
|
const xRecip = reciprocal$1(x.ref);
|
|
@@ -4199,7 +4278,7 @@ const jvpRules = {
|
|
|
4199
4278
|
return [[L], [dL]];
|
|
4200
4279
|
},
|
|
4201
4280
|
[Primitive.LU]([a], [da]) {
|
|
4202
|
-
const [luMatrix, pivots, permutation] = lu$1(a);
|
|
4281
|
+
const [luMatrix, pivots, permutation$1] = lu$1(a);
|
|
4203
4282
|
const [m, n] = a.shape.slice(-2);
|
|
4204
4283
|
const k = Math.min(m, n);
|
|
4205
4284
|
const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
|
|
@@ -4211,7 +4290,7 @@ const jvpRules = {
|
|
|
4211
4290
|
const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
|
|
4212
4291
|
const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
|
|
4213
4292
|
const U = uPadded.add(uEye);
|
|
4214
|
-
const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
|
|
4293
|
+
const P = permutation$1.ref.reshape([...permutation$1.shape, 1]).equal(arange(m)).astype(da.dtype);
|
|
4215
4294
|
const pda = batchMatmulT(P, mT(da));
|
|
4216
4295
|
const la = mT(triangularSolve$1(L.ref, mT(pda), {
|
|
4217
4296
|
lower: true,
|
|
@@ -4223,11 +4302,11 @@ const jvpRules = {
|
|
|
4223
4302
|
return [[
|
|
4224
4303
|
luMatrix,
|
|
4225
4304
|
pivots,
|
|
4226
|
-
permutation
|
|
4305
|
+
permutation$1
|
|
4227
4306
|
], [
|
|
4228
4307
|
lDot.add(uDot),
|
|
4229
4308
|
zerosLike$1(pivots.ref),
|
|
4230
|
-
zerosLike$1(permutation.ref)
|
|
4309
|
+
zerosLike$1(permutation$1.ref)
|
|
4231
4310
|
]];
|
|
4232
4311
|
},
|
|
4233
4312
|
[Primitive.Jit](primals, tangents, { name, jaxpr }) {
|
|
@@ -5273,7 +5352,8 @@ __export(numpy_linalg_exports, {
|
|
|
5273
5352
|
solve: () => solve,
|
|
5274
5353
|
tensordot: () => tensordot,
|
|
5275
5354
|
trace: () => trace,
|
|
5276
|
-
vecdot: () => vecdot
|
|
5355
|
+
vecdot: () => vecdot,
|
|
5356
|
+
vectorNorm: () => vectorNorm
|
|
5277
5357
|
});
|
|
5278
5358
|
function checkSquare(name, a) {
|
|
5279
5359
|
if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
|
|
@@ -5308,8 +5388,8 @@ function cross$1(x1, x2, axis = -1) {
|
|
|
5308
5388
|
function det(a) {
|
|
5309
5389
|
a = fudgeArray(a);
|
|
5310
5390
|
const n = checkSquare("det", a);
|
|
5311
|
-
const [lu$2, pivots, permutation] = lu(a);
|
|
5312
|
-
permutation.dispose();
|
|
5391
|
+
const [lu$2, pivots, permutation$1] = lu(a);
|
|
5392
|
+
permutation$1.dispose();
|
|
5313
5393
|
const parity = pivots.notEqual(arange(n)).astype(int32).sum(-1).mod(2);
|
|
5314
5394
|
const sign$1 = parity.mul(-2).add(1);
|
|
5315
5395
|
const diag$1 = lu$2.diagonal(0, -1, -2);
|
|
@@ -5398,8 +5478,8 @@ function matrixPower(a, n) {
|
|
|
5398
5478
|
function slogdet(a) {
|
|
5399
5479
|
a = fudgeArray(a);
|
|
5400
5480
|
const n = checkSquare("slogdet", a);
|
|
5401
|
-
const [lu$2, pivots, permutation] = lu(a);
|
|
5402
|
-
permutation.dispose();
|
|
5481
|
+
const [lu$2, pivots, permutation$1] = lu(a);
|
|
5482
|
+
permutation$1.dispose();
|
|
5403
5483
|
let parity = pivots.notEqual(arange(n)).astype(int32).sum(-1);
|
|
5404
5484
|
const diag$1 = lu$2.diagonal(0, -1, -2);
|
|
5405
5485
|
parity = parity.add(diag$1.ref.less(0).astype(int32).sum(-1)).mod(2);
|
|
@@ -5437,9 +5517,9 @@ function solve(a, b) {
|
|
|
5437
5517
|
n,
|
|
5438
5518
|
m
|
|
5439
5519
|
]);
|
|
5440
|
-
const [lu$2, pivots, permutation] = lu(a);
|
|
5520
|
+
const [lu$2, pivots, permutation$1] = lu(a);
|
|
5441
5521
|
pivots.dispose();
|
|
5442
|
-
const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
|
|
5522
|
+
const P = arange(n).equal(permutation$1.reshape([...permutation$1.shape, 1])).astype(b.dtype);
|
|
5443
5523
|
const LPb = triangularSolve(lu$2.ref, matmul(P, b), {
|
|
5444
5524
|
leftSide: true,
|
|
5445
5525
|
lower: true,
|
|
@@ -5452,6 +5532,23 @@ function solve(a, b) {
|
|
|
5452
5532
|
if (bIs1d) x = squeeze(x, -1);
|
|
5453
5533
|
return x;
|
|
5454
5534
|
}
|
|
5535
|
+
/**
|
|
5536
|
+
* Compute the vector norm of an array.
|
|
5537
|
+
*
|
|
5538
|
+
* @param x - Input array.
|
|
5539
|
+
* @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
|
|
5540
|
+
* @param axis - Axis/axes to reduce over (default: all axes).
|
|
5541
|
+
* @param keepdims - Whether to keep reduced dimensions as size 1.
|
|
5542
|
+
* @returns The norm of `x`, reduced over the given axes.
|
|
5543
|
+
*/
|
|
5544
|
+
function vectorNorm(x, { ord = 2, axis = null, keepdims = false } = {}) {
|
|
5545
|
+
x = fudgeArray(x);
|
|
5546
|
+
const ax = axis ?? null;
|
|
5547
|
+
if (ord === Infinity) return max(absolute(x), ax, { keepdims });
|
|
5548
|
+
else if (ord === -Infinity) return min(absolute(x), ax, { keepdims });
|
|
5549
|
+
else if (ord === 0) return x.notEqual(0).astype(x.dtype).sum(ax, { keepdims });
|
|
5550
|
+
else return power(power(absolute(x), ord).sum(ax, { keepdims }), 1 / ord);
|
|
5551
|
+
}
|
|
5455
5552
|
|
|
5456
5553
|
//#endregion
|
|
5457
5554
|
//#region src/library/numpy/dtype-info.ts
|
|
@@ -5571,6 +5668,13 @@ __export(numpy_exports, {
|
|
|
5571
5668
|
atan2: () => atan2,
|
|
5572
5669
|
atanh: () => arctanh,
|
|
5573
5670
|
average: () => average,
|
|
5671
|
+
bitwiseAnd: () => bitwiseAnd,
|
|
5672
|
+
bitwiseInvert: () => invert,
|
|
5673
|
+
bitwiseLeftShift: () => leftShift,
|
|
5674
|
+
bitwiseNot: () => invert,
|
|
5675
|
+
bitwiseOr: () => bitwiseOr,
|
|
5676
|
+
bitwiseRightShift: () => rightShift,
|
|
5677
|
+
bitwiseXor: () => bitwiseXor,
|
|
5574
5678
|
bool: () => bool,
|
|
5575
5679
|
broadcastArrays: () => broadcastArrays,
|
|
5576
5680
|
broadcastShapes: () => broadcastShapes,
|
|
@@ -5632,12 +5736,14 @@ __export(numpy_exports, {
|
|
|
5632
5736
|
inf: () => inf,
|
|
5633
5737
|
inner: () => inner,
|
|
5634
5738
|
int32: () => int32,
|
|
5739
|
+
invert: () => invert,
|
|
5635
5740
|
isfinite: () => isfinite,
|
|
5636
5741
|
isinf: () => isinf,
|
|
5637
5742
|
isnan: () => isnan,
|
|
5638
5743
|
isneginf: () => isneginf,
|
|
5639
5744
|
isposinf: () => isposinf,
|
|
5640
5745
|
ldexp: () => ldexp,
|
|
5746
|
+
leftShift: () => leftShift,
|
|
5641
5747
|
less: () => less,
|
|
5642
5748
|
lessEqual: () => lessEqual,
|
|
5643
5749
|
linalg: () => numpy_linalg_exports,
|
|
@@ -5686,6 +5792,7 @@ __export(numpy_exports, {
|
|
|
5686
5792
|
remainder: () => remainder,
|
|
5687
5793
|
repeat: () => repeat,
|
|
5688
5794
|
reshape: () => reshape,
|
|
5795
|
+
rightShift: () => rightShift,
|
|
5689
5796
|
rint: () => rint,
|
|
5690
5797
|
round: () => round,
|
|
5691
5798
|
shape: () => shape,
|
|
@@ -5800,6 +5907,44 @@ function logicalXor(x, y) {
|
|
|
5800
5907
|
function logicalNot(x) {
|
|
5801
5908
|
return notEqual(astype(x, require_backend.DType.Bool), true);
|
|
5802
5909
|
}
|
|
5910
|
+
/** Compute element-wise bitwise AND. */
|
|
5911
|
+
function bitwiseAnd(x, y) {
|
|
5912
|
+
return bitCombine(x, y, "and");
|
|
5913
|
+
}
|
|
5914
|
+
/** Compute element-wise bitwise OR. */
|
|
5915
|
+
function bitwiseOr(x, y) {
|
|
5916
|
+
return bitCombine(x, y, "or");
|
|
5917
|
+
}
|
|
5918
|
+
/** Compute element-wise bitwise XOR. */
|
|
5919
|
+
function bitwiseXor(x, y) {
|
|
5920
|
+
return bitCombine(x, y, "xor");
|
|
5921
|
+
}
|
|
5922
|
+
/** Compute element-wise bitwise NOT (inversion). */
|
|
5923
|
+
function invert(x) {
|
|
5924
|
+
const arr = fudgeArray(x);
|
|
5925
|
+
let allOnes;
|
|
5926
|
+
switch (arr.dtype) {
|
|
5927
|
+
case require_backend.DType.Bool:
|
|
5928
|
+
allOnes = true;
|
|
5929
|
+
break;
|
|
5930
|
+
case require_backend.DType.Uint32:
|
|
5931
|
+
allOnes = 4294967295;
|
|
5932
|
+
break;
|
|
5933
|
+
case require_backend.DType.Int32:
|
|
5934
|
+
allOnes = -1;
|
|
5935
|
+
break;
|
|
5936
|
+
default: throw new TypeError(`invert: unsupported dtype ${arr.dtype}`);
|
|
5937
|
+
}
|
|
5938
|
+
return bitCombine(arr, allOnes, "xor");
|
|
5939
|
+
}
|
|
5940
|
+
/** Compute element-wise left bit shift. */
|
|
5941
|
+
function leftShift(x, y) {
|
|
5942
|
+
return bitShift(x, y, "shl");
|
|
5943
|
+
}
|
|
5944
|
+
/** Compute element-wise right bit shift. */
|
|
5945
|
+
function rightShift(x, y) {
|
|
5946
|
+
return bitShift(x, y, "shr");
|
|
5947
|
+
}
|
|
5803
5948
|
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
5804
5949
|
const where = where$1;
|
|
5805
5950
|
/**
|
|
@@ -7230,7 +7375,7 @@ __export(lax_exports, {
|
|
|
7230
7375
|
stopGradient: () => stopGradient$1,
|
|
7231
7376
|
topK: () => topK
|
|
7232
7377
|
});
|
|
7233
|
-
const JsArray = globalThis.Array;
|
|
7378
|
+
const JsArray$1 = globalThis.Array;
|
|
7234
7379
|
/** Elementwise bitcast an array into a new dtype. */
|
|
7235
7380
|
function bitcastConvertType(x, newDtype) {
|
|
7236
7381
|
return fudgeArray(x).view(newDtype);
|
|
@@ -7417,7 +7562,7 @@ function convTransposePadding(k, s, padding) {
|
|
|
7417
7562
|
} else if (padding === "VALID") {
|
|
7418
7563
|
padLen = k + s - 2 + Math.max(k - s, 0);
|
|
7419
7564
|
pad1 = k - 1;
|
|
7420
|
-
} else if (JsArray.isArray(padding)) {
|
|
7565
|
+
} else if (JsArray$1.isArray(padding)) {
|
|
7421
7566
|
const pads = [k - 1 - padding[0], k - 1 - padding[1]];
|
|
7422
7567
|
pad1 = pads[0];
|
|
7423
7568
|
padLen = pads[0] + pads[1];
|
|
@@ -7936,19 +8081,34 @@ function dotProductAttention(query, key$1, value, opts = {}) {
|
|
|
7936
8081
|
//#region src/library/random.ts
|
|
7937
8082
|
var random_exports = {};
|
|
7938
8083
|
__export(random_exports, {
|
|
8084
|
+
ball: () => ball,
|
|
7939
8085
|
bernoulli: () => bernoulli,
|
|
7940
8086
|
bits: () => bits,
|
|
7941
8087
|
categorical: () => categorical,
|
|
7942
8088
|
cauchy: () => cauchy,
|
|
8089
|
+
choice: () => choice,
|
|
8090
|
+
doubleSidedMaxwell: () => doubleSidedMaxwell,
|
|
7943
8091
|
exponential: () => exponential,
|
|
8092
|
+
geometric: () => geometric,
|
|
7944
8093
|
gumbel: () => gumbel,
|
|
7945
8094
|
key: () => key,
|
|
7946
8095
|
laplace: () => laplace,
|
|
8096
|
+
logistic: () => logistic,
|
|
8097
|
+
lognormal: () => lognormal,
|
|
8098
|
+
maxwell: () => maxwell,
|
|
7947
8099
|
multivariateNormal: () => multivariateNormal,
|
|
7948
8100
|
normal: () => normal,
|
|
8101
|
+
pareto: () => pareto,
|
|
8102
|
+
permutation: () => permutation,
|
|
8103
|
+
rademacher: () => rademacher,
|
|
8104
|
+
randint: () => randint,
|
|
8105
|
+
rayleigh: () => rayleigh,
|
|
7949
8106
|
split: () => split,
|
|
7950
|
-
uniform: () => uniform
|
|
8107
|
+
triangular: () => triangular,
|
|
8108
|
+
uniform: () => uniform,
|
|
8109
|
+
weibullMin: () => weibullMin
|
|
7951
8110
|
});
|
|
8111
|
+
const JsArray = globalThis.Array;
|
|
7952
8112
|
function validateKeyShape(key$1, scalar = false) {
|
|
7953
8113
|
if (key$1.ndim === 0) throw new Error("Key must have at least one dimension.");
|
|
7954
8114
|
if (key$1.shape[key$1.shape.length - 1] !== 2) throw new Error(`Invalid key shape: ${key$1.shape}. Expected last dimension to be 2.`);
|
|
@@ -8001,6 +8161,21 @@ const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxv
|
|
|
8001
8161
|
else return rand.mul(maxval - minval).add(minval);
|
|
8002
8162
|
}, { staticArgnums: [1, 2] });
|
|
8003
8163
|
/**
|
|
8164
|
+
* @function
|
|
8165
|
+
* Sample points uniformly from the Euclidean unit ball in `d` dimensions.
|
|
8166
|
+
*
|
|
8167
|
+
* Only the Euclidean `p=2` case is currently supported.
|
|
8168
|
+
*/
|
|
8169
|
+
const ball = jit$1(function ball$1(key$1, d, { p = 2, shape: shape$1 = [] } = {}) {
|
|
8170
|
+
if (!Number.isInteger(d) || d <= 0) throw new Error(`ball: dimension must be a positive integer, got ${d}`);
|
|
8171
|
+
if (p !== 2) throw new Error("ball: only the Euclidean p=2 case is supported");
|
|
8172
|
+
const [k1, k2] = split(key$1, 2);
|
|
8173
|
+
const z = normal(k1, [...shape$1, d]);
|
|
8174
|
+
const norm = sqrt(z.ref.mul(z.ref).sum(-1, { keepdims: true }));
|
|
8175
|
+
const radius = exp(log(uniform(k2, [...shape$1, 1])).mul(1 / d));
|
|
8176
|
+
return z.div(norm).mul(radius);
|
|
8177
|
+
}, { staticArgnums: [1, 2] });
|
|
8178
|
+
/**
|
|
8004
8179
|
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
8005
8180
|
*
|
|
8006
8181
|
* Returns a random Boolean array with the specified shape. `p` can be an array
|
|
@@ -8062,6 +8237,57 @@ const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
|
|
|
8062
8237
|
return tan(u.sub(.5).mul(Math.PI));
|
|
8063
8238
|
}, { staticArgnums: [1] });
|
|
8064
8239
|
/**
|
|
8240
|
+
* Sample from a population with optional replacement and optional probabilities.
|
|
8241
|
+
*
|
|
8242
|
+
* This implements the common JAX-compatible cases: integer populations and
|
|
8243
|
+
* array populations along `axis`. Probabilities `p`, if provided, are sampled
|
|
8244
|
+
* via `categorical(log(p))`.
|
|
8245
|
+
*/
|
|
8246
|
+
function choice(key$1, a, { shape: shape$1 = [], replace = true, p, axis = 0 } = {}) {
|
|
8247
|
+
let n;
|
|
8248
|
+
let values = null;
|
|
8249
|
+
if (typeof a === "number") {
|
|
8250
|
+
if (!Number.isInteger(a) || a < 0) throw new Error(`choice: a must be a non-negative integer, got ${a}`);
|
|
8251
|
+
n = a;
|
|
8252
|
+
} else {
|
|
8253
|
+
values = fudgeArray(a);
|
|
8254
|
+
axis = require_backend.checkAxis(axis, values.ndim);
|
|
8255
|
+
n = values.shape[axis];
|
|
8256
|
+
}
|
|
8257
|
+
let indices;
|
|
8258
|
+
if (p !== void 0) indices = categorical(key$1, log(p), {
|
|
8259
|
+
shape: shape$1,
|
|
8260
|
+
replace
|
|
8261
|
+
});
|
|
8262
|
+
else if (replace) indices = randint(key$1, {
|
|
8263
|
+
minval: 0,
|
|
8264
|
+
maxval: n,
|
|
8265
|
+
shape: shape$1
|
|
8266
|
+
});
|
|
8267
|
+
else {
|
|
8268
|
+
const k = shape$1.reduce((acc, x) => acc * x, 1);
|
|
8269
|
+
if (k > n) throw new Error(`Number of samples without replacement (${k}) cannot exceed population size (${n}).`);
|
|
8270
|
+
indices = permutation(key$1, n).slice([0, k]).reshape(shape$1);
|
|
8271
|
+
}
|
|
8272
|
+
if (values === null) return indices;
|
|
8273
|
+
const index = JsArray(axis).fill([]);
|
|
8274
|
+
index.push(indices);
|
|
8275
|
+
return values.slice(...index);
|
|
8276
|
+
}
|
|
8277
|
+
/**
|
|
8278
|
+
* @function
|
|
8279
|
+
* Sample double-sided Maxwell random values with the provided location and scale.
|
|
8280
|
+
*/
|
|
8281
|
+
const doubleSidedMaxwell = jit$1(function doubleSidedMaxwell$1(key$1, loc, scale, shape$1 = []) {
|
|
8282
|
+
loc = fudgeArray(loc);
|
|
8283
|
+
scale = fudgeArray(scale);
|
|
8284
|
+
const [k1, k2] = split(key$1, 2);
|
|
8285
|
+
return rademacher(k1, {
|
|
8286
|
+
shape: shape$1,
|
|
8287
|
+
dtype: require_backend.DType.Float32
|
|
8288
|
+
}).mul(maxwell(k2, shape$1)).mul(scale).add(loc);
|
|
8289
|
+
}, { staticArgnums: [3] });
|
|
8290
|
+
/**
|
|
8065
8291
|
* @function
|
|
8066
8292
|
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
8067
8293
|
*/
|
|
@@ -8071,6 +8297,14 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
|
8071
8297
|
}, { staticArgnums: [1] });
|
|
8072
8298
|
/**
|
|
8073
8299
|
* @function
|
|
8300
|
+
* Sample geometric random values: the number of trials until first success.
|
|
8301
|
+
*/
|
|
8302
|
+
const geometric = jit$1(function geometric$1(key$1, p, { shape: shape$1 = [], dtype = require_backend.DType.Int32 } = {}) {
|
|
8303
|
+
p = fudgeArray(p);
|
|
8304
|
+
return floor(log1p(negative(uniform(key$1, shape$1))).div(log1p(negative(p)))).add(1).astype(dtype);
|
|
8305
|
+
}, { staticArgnums: [2] });
|
|
8306
|
+
/**
|
|
8307
|
+
* @function
|
|
8074
8308
|
* Sample from a Gumbel distribution with location 0 and scale 1.
|
|
8075
8309
|
*
|
|
8076
8310
|
* Uses inverse transform sampling: `x = -log(-log(u))` where u ~ Uniform(0, 1).
|
|
@@ -8095,6 +8329,32 @@ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
|
|
|
8095
8329
|
}, { staticArgnums: [1] });
|
|
8096
8330
|
/**
|
|
8097
8331
|
* @function
|
|
8332
|
+
* Sample from a logistic distribution with location 0 and scale 1.
|
|
8333
|
+
*
|
|
8334
|
+
* Uses inverse transform sampling: `x = log(u) - log(1-u)`.
|
|
8335
|
+
*/
|
|
8336
|
+
const logistic = jit$1(function logistic$1(key$1, shape$1 = []) {
|
|
8337
|
+
const u = uniform(key$1, shape$1);
|
|
8338
|
+
return log(u.ref).sub(log1p(negative(u)));
|
|
8339
|
+
}, { staticArgnums: [1] });
|
|
8340
|
+
/**
|
|
8341
|
+
* @function
|
|
8342
|
+
* Sample log-normal random values: `exp(sigma * normal(key, shape))`.
|
|
8343
|
+
*/
|
|
8344
|
+
const lognormal = jit$1(function lognormal$1(key$1, sigma = 1, shape$1 = []) {
|
|
8345
|
+
sigma = fudgeArray(sigma);
|
|
8346
|
+
return exp(normal(key$1, shape$1).mul(sigma));
|
|
8347
|
+
}, { staticArgnums: [2] });
|
|
8348
|
+
/**
|
|
8349
|
+
* @function
|
|
8350
|
+
* Sample Maxwell-distributed random values.
|
|
8351
|
+
*/
|
|
8352
|
+
const maxwell = jit$1(function maxwell$1(key$1, shape$1 = []) {
|
|
8353
|
+
const z = normal(key$1, [...shape$1, 3]);
|
|
8354
|
+
return sqrt(z.ref.mul(z).sum(-1));
|
|
8355
|
+
}, { staticArgnums: [1] });
|
|
8356
|
+
/**
|
|
8357
|
+
* @function
|
|
8098
8358
|
* Sample multivariate normal random values with given mean and covariance.
|
|
8099
8359
|
*
|
|
8100
8360
|
* The values are returned with the given shape, along with the final dimension
|
|
@@ -8135,6 +8395,97 @@ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
|
|
|
8135
8395
|
const theta = u2.mul(2 * Math.PI);
|
|
8136
8396
|
return radius.mul(cos(theta));
|
|
8137
8397
|
}, { staticArgnums: [1] });
|
|
8398
|
+
/**
|
|
8399
|
+
* @function
|
|
8400
|
+
* Sample from a Pareto distribution with shape parameter `b` and support [1, ∞).
|
|
8401
|
+
*/
|
|
8402
|
+
const pareto = jit$1(function pareto$1(key$1, b, shape$1 = []) {
|
|
8403
|
+
b = fudgeArray(b);
|
|
8404
|
+
return exp(exponential(key$1, shape$1).div(b));
|
|
8405
|
+
}, { staticArgnums: [2] });
|
|
8406
|
+
/**
|
|
8407
|
+
* Return a random permutation of an integer range or of an array along `axis`.
|
|
8408
|
+
*/
|
|
8409
|
+
function permutation(key$1, x, axis = 0) {
|
|
8410
|
+
if (typeof x === "number") {
|
|
8411
|
+
if (!Number.isInteger(x) || x < 0) throw new Error(`permutation: x must be a non-negative integer, got ${x}`);
|
|
8412
|
+
return argsort(uniform(key$1, [x])).astype(require_backend.DType.Int32);
|
|
8413
|
+
}
|
|
8414
|
+
const arr = fudgeArray(x);
|
|
8415
|
+
axis = require_backend.checkAxis(axis, arr.ndim);
|
|
8416
|
+
const perm = permutation(key$1, arr.shape[axis]);
|
|
8417
|
+
const index = JsArray(axis).fill([]);
|
|
8418
|
+
index.push(perm);
|
|
8419
|
+
return arr.slice(...index);
|
|
8420
|
+
}
|
|
8421
|
+
/**
|
|
8422
|
+
* @function
|
|
8423
|
+
* Sample Rademacher random values, uniformly from {-1, 1}.
|
|
8424
|
+
*/
|
|
8425
|
+
const rademacher = jit$1(function rademacher$1(key$1, { shape: shape$1 = [], dtype = require_backend.DType.Int32 } = {}) {
|
|
8426
|
+
if (dtype === require_backend.DType.Uint32 || dtype === require_backend.DType.Bool) throw new Error(`rademacher: unsupported dtype ${dtype}`);
|
|
8427
|
+
const one = array(1, {
|
|
8428
|
+
dtype,
|
|
8429
|
+
device: key$1.device
|
|
8430
|
+
});
|
|
8431
|
+
const minusOne = array(-1, {
|
|
8432
|
+
dtype,
|
|
8433
|
+
device: key$1.device
|
|
8434
|
+
});
|
|
8435
|
+
return where(bernoulli(key$1, .5, shape$1), one, minusOne);
|
|
8436
|
+
}, { staticArgnums: [1] });
|
|
8437
|
+
/**
|
|
8438
|
+
* @function
|
|
8439
|
+
* Sample integer values uniformly from `[minval, maxval)`.
|
|
8440
|
+
*
|
|
8441
|
+
* This uses modulo reduction of uniform 32-bit random bits. For ranges that do
|
|
8442
|
+
* not divide 2^32, this introduces a very small modulo bias.
|
|
8443
|
+
*/
|
|
8444
|
+
const randint = jit$1(function randint$1(key$1, { minval, maxval, shape: shape$1 = [], dtype = require_backend.DType.Int32 }) {
|
|
8445
|
+
if (!Number.isInteger(minval) || !Number.isInteger(maxval)) throw new Error("randint: minval and maxval must be integers");
|
|
8446
|
+
if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
|
|
8447
|
+
if (dtype !== require_backend.DType.Int32 && dtype !== require_backend.DType.Uint32) throw new Error(`randint: dtype must be int32 or uint32, got ${dtype}`);
|
|
8448
|
+
if (dtype === require_backend.DType.Uint32 && minval < 0) throw new Error("randint: uint32 dtype requires minval >= 0");
|
|
8449
|
+
const range$1 = maxval - minval;
|
|
8450
|
+
return bits(key$1, shape$1).mod(range$1).astype(dtype).add(minval);
|
|
8451
|
+
}, { staticArgnums: [1] });
|
|
8452
|
+
/**
|
|
8453
|
+
* @function
|
|
8454
|
+
* Sample Rayleigh random values with the provided scale parameter.
|
|
8455
|
+
*/
|
|
8456
|
+
const rayleigh = jit$1(function rayleigh$1(key$1, scale = 1, shape$1 = []) {
|
|
8457
|
+
scale = fudgeArray(scale);
|
|
8458
|
+
return sqrt(exponential(key$1, shape$1).mul(2)).mul(scale);
|
|
8459
|
+
}, { staticArgnums: [2] });
|
|
8460
|
+
/**
|
|
8461
|
+
* @function
|
|
8462
|
+
* Sample triangular random values on `[left, right]` with the given mode.
|
|
8463
|
+
*/
|
|
8464
|
+
const triangular = jit$1(function triangular$1(key$1, left, mode, right, shape$1 = []) {
|
|
8465
|
+
left = fudgeArray(left);
|
|
8466
|
+
mode = fudgeArray(mode);
|
|
8467
|
+
right = fudgeArray(right);
|
|
8468
|
+
const u = uniform(key$1, shape$1);
|
|
8469
|
+
const width = right.ref.sub(left.ref);
|
|
8470
|
+
const leftSpan = mode.ref.sub(left.ref);
|
|
8471
|
+
const rightSpan = right.ref.sub(mode);
|
|
8472
|
+
const cutoff = leftSpan.ref.div(width.ref);
|
|
8473
|
+
const cond = u.ref.less(cutoff);
|
|
8474
|
+
const lower = left.add(sqrt(u.ref.mul(width.ref).mul(leftSpan)));
|
|
8475
|
+
const upper = right.sub(sqrt(negative(u).add(1).mul(width).mul(rightSpan)));
|
|
8476
|
+
return where(cond, lower, upper);
|
|
8477
|
+
}, { staticArgnums: [4] });
|
|
8478
|
+
/**
|
|
8479
|
+
* @function
|
|
8480
|
+
* Sample Weibull minimum random values.
|
|
8481
|
+
*
|
|
8482
|
+
* Uses `scale * exponential(key) ** (1 / concentration)`.
|
|
8483
|
+
*/
|
|
8484
|
+
const weibullMin = jit$1(function weibullMin$1(key$1, scale, concentration, shape$1 = []) {
|
|
8485
|
+
scale = fudgeArray(scale);
|
|
8486
|
+
concentration = fudgeArray(concentration);
|
|
8487
|
+
return scale.mul(exp(log(exponential(key$1, shape$1)).div(concentration)));
|
|
8488
|
+
}, { staticArgnums: [3] });
|
|
8138
8489
|
|
|
8139
8490
|
//#endregion
|
|
8140
8491
|
//#region src/library/scipy-special.ts
|
|
@@ -8336,6 +8687,7 @@ exports.blockUntilReady = blockUntilReady;
|
|
|
8336
8687
|
exports.defaultDevice = require_backend.defaultDevice;
|
|
8337
8688
|
exports.devicePut = devicePut;
|
|
8338
8689
|
exports.devices = require_backend.devices;
|
|
8690
|
+
exports.getWebGPUDevice = require_backend.getWebGPUDevice;
|
|
8339
8691
|
exports.grad = grad;
|
|
8340
8692
|
exports.hessian = hessian;
|
|
8341
8693
|
exports.init = require_backend.init;
|