@jax-js/jax 0.1.11 → 0.1.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -1
- package/dist/{backend-DZvR7mZV.js → backend-DI-V78Rk.js} +4 -2
- package/dist/{backend-DlYlOYqN.cjs → backend-x-6vqzIM.cjs} +4 -2
- package/dist/index.cjs +233 -18
- package/dist/index.d.cts +106 -1
- package/dist/index.d.ts +106 -1
- package/dist/index.js +233 -18
- package/dist/{webgl-D8-14NzA.js → webgl-BhsnpeB0.js} +1 -1
- package/dist/{webgl-Ovaaa-Qx.cjs → webgl-CD3WK_Me.cjs} +1 -1
- package/dist/{webgpu-Dg8FpYrH.js → webgpu-C2kLdkUh.js} +299 -154
- package/dist/{webgpu-uU9nnttc.cjs → webgpu-C4S8Uq9e.cjs} +299 -154
- package/package.json +1 -1
package/README.md
CHANGED
|
@@ -63,6 +63,7 @@ of late 2025.
|
|
|
63
63
|
|
|
64
64
|
Community usage:
|
|
65
65
|
|
|
66
|
+
- [**g9-jaxjs**: Automatically interactive graphics with forward-mode AD](https://srush.github.io/g9jax/)
|
|
66
67
|
- [**autoresearch-webgpu**: autoresearch, in the browser](https://autoresearch.lucasgelfond.online/)
|
|
67
68
|
- [**tanh.xyz**: Interactive ML visualizations](https://tanh.xyz/)
|
|
68
69
|
- [**jax-js-bayes**: Declarative Bayesian modeling library](https://github.com/StefanSko/jax-js-bayes)
|
|
@@ -72,10 +73,12 @@ Demos on the jax-js website:
|
|
|
72
73
|
- [Training neural networks on MNIST](https://jax-js.com/mnist)
|
|
73
74
|
- [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
|
|
74
75
|
- [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
|
|
76
|
+
- [Object detection: D-FINE (ONNX)](https://jax-js.com/d-fine)
|
|
75
77
|
- [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
|
|
76
78
|
- [Fluid simulation (Navier-Stokes)](https://jax-js.com/fluid-sim)
|
|
77
79
|
- [In-browser REPL](https://jax-js.com/repl)
|
|
78
80
|
- [Matmul benchmark](https://jax-js.com/bench/matmul)
|
|
81
|
+
- [Matvec benchmark](https://jax-js.com/bench/matvec)
|
|
79
82
|
- [Conv2d benchmark](https://jax-js.com/bench/conv2d)
|
|
80
83
|
- [Mandelbrot set](https://jax-js.com/mandelbrot)
|
|
81
84
|
|
|
@@ -422,7 +425,6 @@ Contributions are welcomed! Some fruitful areas to look into:
|
|
|
422
425
|
- Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
|
|
423
426
|
- Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
|
|
424
427
|
and multithreading. (Even single-threaded Wasm could be ~20x faster.)
|
|
425
|
-
- Helping the JIT compiler to fuse operations in more cases, like `tanh` branches.
|
|
426
428
|
- Making a fast transformer inference engine, comparing against onnxruntime-web.
|
|
427
429
|
|
|
428
430
|
You may join our [Discord server](https://discord.gg/BW6YsCd4Tf) and chat with the community.
|
|
package/dist/backend-DI-V78Rk.js
CHANGED
|
@@ -1430,11 +1430,13 @@ var Reduction = class {
|
|
|
1430
1430
|
/**
 * Build an expression that loads element `index` of global buffer `gid`,
 * where `index`/`valid` come from mapping `indices` through shape tracker
 * `st`. Positions whose validity expression is false produce `0` instead.
 */
function accessorGlobal(dtype, gid, st, indices) {
	const [index, valid] = st.toAluExp(indices);
	const len = st.views[0].dataRange()[1];
	const load = AluExp.globalIndex(dtype, gid, len, index);
	// If the validity mask folds to a truthy constant, the `where` guard is redundant.
	return valid.resolve() ? load : AluExp.where(valid, load, AluExp.const(dtype, 0));
}
|
|
1435
1436
|
/**
 * Expression for accessing `indices` in an array recipe with variable "idx".
 *
 * `exp` is the recipe; its "idx" variable is substituted by the index that
 * shape tracker `st` computes for `indices`. Where the tracker reports the
 * position as invalid, the result is a `0` constant of `exp.dtype`.
 */
function accessorAluExp(exp, st, indices) {
	const [index, valid] = st.toAluExp(indices);
	const gathered = exp.substitute({ idx: index });
	// A constant-truthy validity mask means no masking is needed.
	return valid.resolve() ? gathered : AluExp.where(valid, gathered, AluExp.const(exp.dtype, 0));
}
|
|
1440
1442
|
function threefry2x32(k0, k1, c0, c1) {
|
|
@@ -5059,7 +5061,7 @@ async function createBackend(device) {
|
|
|
5059
5061
|
if (!navigator.gpu) return null;
|
|
5060
5062
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
5061
5063
|
if (!adapter) return null;
|
|
5062
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
5064
|
+
const { WebGPUBackend } = await import("./webgpu-C2kLdkUh.js");
|
|
5063
5065
|
const importantLimits = [
|
|
5064
5066
|
"maxBufferSize",
|
|
5065
5067
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -5097,7 +5099,7 @@ async function createBackend(device) {
|
|
|
5097
5099
|
});
|
|
5098
5100
|
if (!gl) return null;
|
|
5099
5101
|
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
5100
|
-
const { WebGLBackend } = await import("./webgl-
|
|
5102
|
+
const { WebGLBackend } = await import("./webgl-BhsnpeB0.js");
|
|
5101
5103
|
return new WebGLBackend(gl);
|
|
5102
5104
|
} else throw new Error(`Backend not found: ${device}`);
|
|
5103
5105
|
}
|
|
package/dist/backend-x-6vqzIM.cjs
CHANGED
|
@@ -1431,11 +1431,13 @@ var Reduction = class {
|
|
|
1431
1431
|
/**
 * Build an expression that loads element `index` of global buffer `gid`,
 * where `index`/`valid` come from mapping `indices` through shape tracker
 * `st`. Positions whose validity expression is false produce `0` instead.
 */
function accessorGlobal(dtype, gid, st, indices) {
	const [index, valid] = st.toAluExp(indices);
	const len = st.views[0].dataRange()[1];
	const load = AluExp.globalIndex(dtype, gid, len, index);
	// If the validity mask folds to a truthy constant, the `where` guard is redundant.
	return valid.resolve() ? load : AluExp.where(valid, load, AluExp.const(dtype, 0));
}
|
|
1436
1437
|
/**
 * Expression for accessing `indices` in an array recipe with variable "idx".
 *
 * `exp` is the recipe; its "idx" variable is substituted by the index that
 * shape tracker `st` computes for `indices`. Where the tracker reports the
 * position as invalid, the result is a `0` constant of `exp.dtype`.
 */
function accessorAluExp(exp, st, indices) {
	const [index, valid] = st.toAluExp(indices);
	const gathered = exp.substitute({ idx: index });
	// A constant-truthy validity mask means no masking is needed.
	return valid.resolve() ? gathered : AluExp.where(valid, gathered, AluExp.const(exp.dtype, 0));
}
|
|
1441
1443
|
function threefry2x32(k0, k1, c0, c1) {
|
|
@@ -5060,7 +5062,7 @@ async function createBackend(device) {
|
|
|
5060
5062
|
if (!navigator.gpu) return null;
|
|
5061
5063
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
5062
5064
|
if (!adapter) return null;
|
|
5063
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
5065
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-C4S8Uq9e.cjs"));
|
|
5064
5066
|
const importantLimits = [
|
|
5065
5067
|
"maxBufferSize",
|
|
5066
5068
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -5098,7 +5100,7 @@ async function createBackend(device) {
|
|
|
5098
5100
|
});
|
|
5099
5101
|
if (!gl) return null;
|
|
5100
5102
|
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
5101
|
-
const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-
|
|
5103
|
+
const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-CD3WK_Me.cjs"));
|
|
5102
5104
|
return new WebGLBackend(gl);
|
|
5103
5105
|
} else throw new Error(`Backend not found: ${device}`);
|
|
5104
5106
|
}
|
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-x-6vqzIM.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
@@ -240,7 +240,7 @@ __export(tree_exports, {
|
|
|
240
240
|
structure: () => structure,
|
|
241
241
|
unflatten: () => unflatten
|
|
242
242
|
});
|
|
243
|
-
const JsArray$
|
|
243
|
+
// Alias for the built-in JS Array, since this module shadows the name `Array`.
const { Array: JsArray$3 } = globalThis;
|
|
244
244
|
let NodeType = /* @__PURE__ */ function(NodeType$1) {
|
|
245
245
|
NodeType$1["Array"] = "Array";
|
|
246
246
|
NodeType$1["Object"] = "Object";
|
|
@@ -288,7 +288,7 @@ function flatten(tree) {
|
|
|
288
288
|
return [leaves$1, treedef];
|
|
289
289
|
}
|
|
290
290
|
function _flatten(tree, leaves$1) {
|
|
291
|
-
if (JsArray$
|
|
291
|
+
if (JsArray$3.isArray(tree)) {
|
|
292
292
|
const childTrees = tree.map((c) => _flatten(c, leaves$1));
|
|
293
293
|
return new JsTreeDef(NodeType.Array, null, childTrees);
|
|
294
294
|
} else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
|
|
@@ -2495,7 +2495,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2495
2495
|
|
|
2496
2496
|
//#endregion
|
|
2497
2497
|
//#region src/frontend/array.ts
|
|
2498
|
-
const JsArray$
|
|
2498
|
+
// Alias for the built-in JS Array, since this module shadows the name `Array`.
const { Array: JsArray$2 } = globalThis;
// Element-count threshold used when deciding whether a small array may be
// handled inline; its exact use lives in the Array class below.
const inlineArrayLimit = 128;
/** Version of pureArray with fudged types. */
const fudgeArray = pureArray;
|
|
@@ -2935,6 +2935,15 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2935
2935
|
this.#check();
|
|
2936
2936
|
const indices = require_backend.unravelAlu(this.#st.shape, require_backend.AluVar.gidx);
|
|
2937
2937
|
if (this.#source instanceof require_backend.AluExp) {
|
|
2938
|
+
let resolvedSource;
|
|
2939
|
+
if (this.#st.contiguous && this.#st.size < inlineArrayLimit && (resolvedSource = this.#source.resolve()) !== void 0) {
|
|
2940
|
+
const byteLength = this.#st.size * require_backend.byteWidth(this.#dtype);
|
|
2941
|
+
const initialData = new Uint8Array(byteLength);
|
|
2942
|
+
require_backend.dtypedArray(this.#dtype, initialData).fill(resolvedSource);
|
|
2943
|
+
this.#source = this.#backend.malloc(byteLength, initialData);
|
|
2944
|
+
this.#st = require_backend.ShapeTracker.fromShape(this.shape);
|
|
2945
|
+
return;
|
|
2946
|
+
}
|
|
2938
2947
|
const exp$2 = require_backend.accessorAluExp(this.#source, this.#st, indices);
|
|
2939
2948
|
const kernel = new require_backend.Kernel(0, this.#st.size, exp$2);
|
|
2940
2949
|
const output = this.#backend.malloc(kernel.bytes);
|
|
@@ -3385,7 +3394,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
3385
3394
|
if (!shape$1) {
|
|
3386
3395
|
shape$1 = [];
|
|
3387
3396
|
let cur = values;
|
|
3388
|
-
while (JsArray$
|
|
3397
|
+
while (JsArray$2.isArray(cur)) {
|
|
3389
3398
|
shape$1.push(cur.length);
|
|
3390
3399
|
cur = cur[0];
|
|
3391
3400
|
}
|
|
@@ -4269,7 +4278,7 @@ const jvpRules = {
|
|
|
4269
4278
|
return [[L], [dL]];
|
|
4270
4279
|
},
|
|
4271
4280
|
[Primitive.LU]([a], [da]) {
|
|
4272
|
-
const [luMatrix, pivots, permutation] = lu$1(a);
|
|
4281
|
+
const [luMatrix, pivots, permutation$1] = lu$1(a);
|
|
4273
4282
|
const [m, n] = a.shape.slice(-2);
|
|
4274
4283
|
const k = Math.min(m, n);
|
|
4275
4284
|
const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
|
|
@@ -4281,7 +4290,7 @@ const jvpRules = {
|
|
|
4281
4290
|
const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
|
|
4282
4291
|
const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
|
|
4283
4292
|
const U = uPadded.add(uEye);
|
|
4284
|
-
const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
|
|
4293
|
+
const P = permutation$1.ref.reshape([...permutation$1.shape, 1]).equal(arange(m)).astype(da.dtype);
|
|
4285
4294
|
const pda = batchMatmulT(P, mT(da));
|
|
4286
4295
|
const la = mT(triangularSolve$1(L.ref, mT(pda), {
|
|
4287
4296
|
lower: true,
|
|
@@ -4293,11 +4302,11 @@ const jvpRules = {
|
|
|
4293
4302
|
return [[
|
|
4294
4303
|
luMatrix,
|
|
4295
4304
|
pivots,
|
|
4296
|
-
permutation
|
|
4305
|
+
permutation$1
|
|
4297
4306
|
], [
|
|
4298
4307
|
lDot.add(uDot),
|
|
4299
4308
|
zerosLike$1(pivots.ref),
|
|
4300
|
-
zerosLike$1(permutation.ref)
|
|
4309
|
+
zerosLike$1(permutation$1.ref)
|
|
4301
4310
|
]];
|
|
4302
4311
|
},
|
|
4303
4312
|
[Primitive.Jit](primals, tangents, { name, jaxpr }) {
|
|
@@ -5379,8 +5388,8 @@ function cross$1(x1, x2, axis = -1) {
|
|
|
5379
5388
|
function det(a) {
|
|
5380
5389
|
a = fudgeArray(a);
|
|
5381
5390
|
const n = checkSquare("det", a);
|
|
5382
|
-
const [lu$2, pivots, permutation] = lu(a);
|
|
5383
|
-
permutation.dispose();
|
|
5391
|
+
const [lu$2, pivots, permutation$1] = lu(a);
|
|
5392
|
+
permutation$1.dispose();
|
|
5384
5393
|
const parity = pivots.notEqual(arange(n)).astype(int32).sum(-1).mod(2);
|
|
5385
5394
|
const sign$1 = parity.mul(-2).add(1);
|
|
5386
5395
|
const diag$1 = lu$2.diagonal(0, -1, -2);
|
|
@@ -5469,8 +5478,8 @@ function matrixPower(a, n) {
|
|
|
5469
5478
|
function slogdet(a) {
|
|
5470
5479
|
a = fudgeArray(a);
|
|
5471
5480
|
const n = checkSquare("slogdet", a);
|
|
5472
|
-
const [lu$2, pivots, permutation] = lu(a);
|
|
5473
|
-
permutation.dispose();
|
|
5481
|
+
const [lu$2, pivots, permutation$1] = lu(a);
|
|
5482
|
+
permutation$1.dispose();
|
|
5474
5483
|
let parity = pivots.notEqual(arange(n)).astype(int32).sum(-1);
|
|
5475
5484
|
const diag$1 = lu$2.diagonal(0, -1, -2);
|
|
5476
5485
|
parity = parity.add(diag$1.ref.less(0).astype(int32).sum(-1)).mod(2);
|
|
@@ -5508,9 +5517,9 @@ function solve(a, b) {
|
|
|
5508
5517
|
n,
|
|
5509
5518
|
m
|
|
5510
5519
|
]);
|
|
5511
|
-
const [lu$2, pivots, permutation] = lu(a);
|
|
5520
|
+
const [lu$2, pivots, permutation$1] = lu(a);
|
|
5512
5521
|
pivots.dispose();
|
|
5513
|
-
const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
|
|
5522
|
+
const P = arange(n).equal(permutation$1.reshape([...permutation$1.shape, 1])).astype(b.dtype);
|
|
5514
5523
|
const LPb = triangularSolve(lu$2.ref, matmul(P, b), {
|
|
5515
5524
|
leftSide: true,
|
|
5516
5525
|
lower: true,
|
|
@@ -7366,7 +7375,7 @@ __export(lax_exports, {
|
|
|
7366
7375
|
stopGradient: () => stopGradient$1,
|
|
7367
7376
|
topK: () => topK
|
|
7368
7377
|
});
|
|
7369
|
-
const JsArray = globalThis.Array;
|
|
7378
|
+
// Alias for the built-in JS Array, since this module shadows the name `Array`.
const { Array: JsArray$1 } = globalThis;
|
|
7370
7379
|
/** Elementwise bitcast an array into a new dtype. */
|
|
7371
7380
|
function bitcastConvertType(x, newDtype) {
|
|
7372
7381
|
return fudgeArray(x).view(newDtype);
|
|
@@ -7553,7 +7562,7 @@ function convTransposePadding(k, s, padding) {
|
|
|
7553
7562
|
} else if (padding === "VALID") {
|
|
7554
7563
|
padLen = k + s - 2 + Math.max(k - s, 0);
|
|
7555
7564
|
pad1 = k - 1;
|
|
7556
|
-
} else if (JsArray.isArray(padding)) {
|
|
7565
|
+
} else if (JsArray$1.isArray(padding)) {
|
|
7557
7566
|
const pads = [k - 1 - padding[0], k - 1 - padding[1]];
|
|
7558
7567
|
pad1 = pads[0];
|
|
7559
7568
|
padLen = pads[0] + pads[1];
|
|
@@ -8072,19 +8081,34 @@ function dotProductAttention(query, key$1, value, opts = {}) {
|
|
|
8072
8081
|
//#region src/library/random.ts
|
|
8073
8082
|
// Namespace object behind `random.*`. Each entry is a thunk returning the
// corresponding binding (presumably installed as a getter by `__export`), so
// the functions can be declared further down in this region.
var random_exports = {};
__export(random_exports, {
	ball: function() { return ball; },
	bernoulli: function() { return bernoulli; },
	bits: function() { return bits; },
	categorical: function() { return categorical; },
	cauchy: function() { return cauchy; },
	choice: function() { return choice; },
	doubleSidedMaxwell: function() { return doubleSidedMaxwell; },
	exponential: function() { return exponential; },
	geometric: function() { return geometric; },
	gumbel: function() { return gumbel; },
	key: function() { return key; },
	laplace: function() { return laplace; },
	logistic: function() { return logistic; },
	lognormal: function() { return lognormal; },
	maxwell: function() { return maxwell; },
	multivariateNormal: function() { return multivariateNormal; },
	normal: function() { return normal; },
	pareto: function() { return pareto; },
	permutation: function() { return permutation; },
	rademacher: function() { return rademacher; },
	randint: function() { return randint; },
	rayleigh: function() { return rayleigh; },
	split: function() { return split; },
	triangular: function() { return triangular; },
	uniform: function() { return uniform; },
	weibullMin: function() { return weibullMin; }
});
// Alias for the built-in JS Array, since this module shadows the name `Array`.
const { Array: JsArray } = globalThis;
|
|
8088
8112
|
function validateKeyShape(key$1, scalar = false) {
|
|
8089
8113
|
if (key$1.ndim === 0) throw new Error("Key must have at least one dimension.");
|
|
8090
8114
|
if (key$1.shape[key$1.shape.length - 1] !== 2) throw new Error(`Invalid key shape: ${key$1.shape}. Expected last dimension to be 2.`);
|
|
@@ -8137,6 +8161,21 @@ const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxv
|
|
|
8137
8161
|
else return rand.mul(maxval - minval).add(minval);
|
|
8138
8162
|
}, { staticArgnums: [1, 2] });
|
|
8139
8163
|
/**
 * @function
 * Draw points distributed uniformly inside the `d`-dimensional Euclidean unit
 * ball.
 *
 * A standard-normal vector is projected onto the unit sphere, then scaled by
 * `u^(1/d)` with `u ~ Uniform(0, 1)`. Only the Euclidean `p=2` case is
 * supported; other `p` values throw.
 */
const ball = jit$1(function ball$1(key$1, d, { p = 2, shape: shape$1 = [] } = {}) {
	if (!Number.isInteger(d) || d <= 0) throw new Error(`ball: dimension must be a positive integer, got ${d}`);
	if (p !== 2) throw new Error("ball: only the Euclidean p=2 case is supported");
	const [directionKey, radiusKey] = split(key$1, 2);
	const gauss = normal(directionKey, [...shape$1, d]);
	// `.ref` is taken for each non-final use of `gauss`, following the
	// ownership pattern used throughout this module.
	const sumOfSquares = gauss.ref.mul(gauss.ref).sum(-1, { keepdims: true });
	const length = sqrt(sumOfSquares);
	// u^(1/d), computed as exp(log(u) * (1/d)).
	const r = exp(log(uniform(radiusKey, [...shape$1, 1])).mul(1 / d));
	return gauss.div(length).mul(r);
}, { staticArgnums: [1, 2] });
|
|
8178
|
+
/**
|
|
8140
8179
|
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
8141
8180
|
*
|
|
8142
8181
|
* Returns a random Boolean array with the specified shape. `p` can be an array
|
|
@@ -8198,6 +8237,57 @@ const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
|
|
|
8198
8237
|
return tan(u.sub(.5).mul(Math.PI));
|
|
8199
8238
|
}, { staticArgnums: [1] });
|
|
8200
8239
|
/**
 * Sample from a population with optional replacement and optional probabilities.
 *
 * Supports the common JAX-compatible cases:
 * - an integer population `a`, where the sampled values are indices in `[0, a)`;
 * - an array population `a`, whose slices along `axis` are sampled.
 *
 * If `p` is provided, it must be a 1-D array with one probability per
 * population element. Sampling then goes through `categorical(log(p))`.
 *
 * @param key$1 PRNG key, passed to exactly one sampler below.
 * @param a Population size (non-negative integer) or population array.
 * @param options.shape Output batch shape (default `[]`).
 * @param options.replace Whether to sample with replacement (default `true`).
 * @param options.p Optional per-element probabilities.
 * @param options.axis Axis of `a` to sample along when `a` is an array (default 0).
 * @returns Sampled indices (integer `a`) or sampled slices of `a`.
 * @throws If `a` is not a non-negative integer, `p` does not have shape `[n]`,
 *   the population is empty while samples are requested, or more samples than
 *   population elements are requested without replacement.
 */
function choice(key$1, a, { shape: shape$1 = [], replace = true, p, axis = 0 } = {}) {
	let n;
	let values = null;
	if (typeof a === "number") {
		if (!Number.isInteger(a) || a < 0) throw new Error(`choice: a must be a non-negative integer, got ${a}`);
		n = a;
	} else {
		values = fudgeArray(a);
		axis = require_backend.checkAxis(axis, values.ndim);
		n = values.shape[axis];
	}
	// Total number of samples requested.
	const k = shape$1.reduce((acc, x) => acc * x, 1);
	// As in JAX, an empty population can only satisfy a request for zero samples.
	if (n === 0 && k > 0) throw new Error("choice: a must be non-empty unless no samples are taken");
	let indices;
	if (p !== void 0) {
		const probs = fudgeArray(p);
		// As in JAX, p must supply exactly one probability per population element.
		if (probs.ndim !== 1 || probs.shape[0] !== n) throw new Error(`choice: p must be a 1-D array of length ${n}, got shape [${probs.shape}]`);
		indices = categorical(key$1, log(probs), {
			shape: shape$1,
			replace
		});
	} else if (replace) {
		// Math.max keeps randint's range non-empty when n === 0. That case is only
		// reachable with k === 0, so the output is empty either way.
		indices = randint(key$1, {
			minval: 0,
			maxval: Math.max(n, 1),
			shape: shape$1
		});
	} else {
		if (k > n) throw new Error(`Number of samples without replacement (${k}) cannot exceed population size (${n}).`);
		indices = permutation(key$1, n).slice([0, k]).reshape(shape$1);
	}
	if (values === null) return indices;
	// Leave the leading `axis` dimensions unindexed and gather along `axis`.
	const index = JsArray(axis).fill([]);
	index.push(indices);
	return values.slice(...index);
}
|
|
8277
|
+
/**
 * @function
 * Draw double-sided Maxwell samples: a Maxwell-distributed magnitude with a
 * uniformly random sign, multiplied by `scale` and shifted by `loc`.
 */
const doubleSidedMaxwell = jit$1(function doubleSidedMaxwell$1(key$1, loc, scale, shape$1 = []) {
	const location = fudgeArray(loc);
	const spread = fudgeArray(scale);
	const [signKey, magnitudeKey] = split(key$1, 2);
	const sign = rademacher(signKey, {
		shape: shape$1,
		dtype: require_backend.DType.Float32
	});
	const magnitude = maxwell(magnitudeKey, shape$1);
	return sign.mul(magnitude).mul(spread).add(location);
}, { staticArgnums: [3] });
|
|
8290
|
+
/**
|
|
8201
8291
|
* @function
|
|
8202
8292
|
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
8203
8293
|
*/
|
|
@@ -8207,6 +8297,14 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
|
8207
8297
|
}, { staticArgnums: [1] });
|
|
8208
8298
|
/**
 * @function
 * Draw geometric samples: the number of Bernoulli(`p`) trials up to and
 * including the first success, so the support starts at 1.
 *
 * Uses inverse-transform sampling: `floor(log1p(-u) / log1p(-p)) + 1`.
 */
const geometric = jit$1(function geometric$1(key$1, p, { shape: shape$1 = [], dtype = require_backend.DType.Int32 } = {}) {
	const prob = fudgeArray(p);
	const logOneMinusU = log1p(negative(uniform(key$1, shape$1)));
	const logOneMinusP = log1p(negative(prob));
	const trials = floor(logOneMinusU.div(logOneMinusP)).add(1);
	return trials.astype(dtype);
}, { staticArgnums: [2] });
|
|
8306
|
+
/**
|
|
8307
|
+
* @function
|
|
8210
8308
|
* Sample from a Gumbel distribution with location 0 and scale 1.
|
|
8211
8309
|
*
|
|
8212
8310
|
* Uses inverse transform sampling: `x = -log(-log(u))` where u ~ Uniform(0, 1).
|
|
@@ -8231,6 +8329,32 @@ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
|
|
|
8231
8329
|
}, { staticArgnums: [1] });
|
|
8232
8330
|
/**
 * @function
 * Draw standard logistic samples (location 0, scale 1).
 *
 * Inverse-transform sampling with `u ~ Uniform(0, 1)`: `x = log(u) - log(1-u)`.
 */
const logistic = jit$1(function logistic$1(key$1, shape$1 = []) {
	const u = uniform(key$1, shape$1);
	const logU = log(u.ref);
	const logOneMinusU = log1p(negative(u));
	return logU.sub(logOneMinusU);
}, { staticArgnums: [1] });
|
|
8340
|
+
/**
 * @function
 * Draw log-normal samples as `exp(sigma * z)` with `z` standard normal.
 */
const lognormal = jit$1(function lognormal$1(key$1, sigma = 1, shape$1 = []) {
	const sigmaArr = fudgeArray(sigma);
	const z = normal(key$1, shape$1);
	return exp(z.mul(sigmaArr));
}, { staticArgnums: [2] });
|
|
8348
|
+
/**
 * @function
 * Draw Maxwell samples, computed as the Euclidean norm of a 3-component
 * standard-normal vector.
 */
const maxwell = jit$1(function maxwell$1(key$1, shape$1 = []) {
	const components = normal(key$1, [...shape$1, 3]);
	const squared = components.ref.mul(components);
	return sqrt(squared.sum(-1));
}, { staticArgnums: [1] });
|
|
8356
|
+
/**
|
|
8357
|
+
* @function
|
|
8234
8358
|
* Sample multivariate normal random values with given mean and covariance.
|
|
8235
8359
|
*
|
|
8236
8360
|
* The values are returned with the given shape, along with the final dimension
|
|
@@ -8271,6 +8395,97 @@ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
|
|
|
8271
8395
|
const theta = u2.mul(2 * Math.PI);
|
|
8272
8396
|
return radius.mul(cos(theta));
|
|
8273
8397
|
}, { staticArgnums: [1] });
|
|
8398
|
+
/**
 * @function
 * Draw Pareto samples with shape parameter `b` on the support [1, ∞).
 *
 * Computed as `exp(E / b)` with `E` a standard exponential sample.
 */
const pareto = jit$1(function pareto$1(key$1, b, shape$1 = []) {
	const alpha = fudgeArray(b);
	const e = exponential(key$1, shape$1);
	return exp(e.div(alpha));
}, { staticArgnums: [2] });
|
|
8406
|
+
/**
 * Return a random permutation of an integer range or of an array along `axis`.
 *
 * For an integer `x`, the result is an int32 permutation of `0..x-1`, computed
 * by argsorting `x` uniform samples. For an array, a permutation of the indices
 * along `axis` is drawn the same way and used to reorder that axis.
 */
function permutation(key$1, x, axis = 0) {
	if (typeof x !== "number") {
		const arr = fudgeArray(x);
		const ax = require_backend.checkAxis(axis, arr.ndim);
		const order = permutation(key$1, arr.shape[ax]);
		// Leading axes get `[]` (left unindexed); the permutation indexes `ax`.
		const index = [...JsArray(ax).fill([]), order];
		return arr.slice(...index);
	}
	if (!Number.isInteger(x) || x < 0) throw new Error(`permutation: x must be a non-negative integer, got ${x}`);
	return argsort(uniform(key$1, [x])).astype(require_backend.DType.Int32);
}
|
|
8421
|
+
/**
 * @function
 * Draw Rademacher samples, each equal to +1 or -1 with equal probability.
 *
 * Throws for uint32 and bool dtypes, which cannot represent -1.
 */
const rademacher = jit$1(function rademacher$1(key$1, { shape: shape$1 = [], dtype = require_backend.DType.Int32 } = {}) {
	const unsupported = dtype === require_backend.DType.Uint32 || dtype === require_backend.DType.Bool;
	if (unsupported) throw new Error(`rademacher: unsupported dtype ${dtype}`);
	const device = key$1.device;
	const plus = array(1, { dtype, device });
	const minus = array(-1, { dtype, device });
	const heads = bernoulli(key$1, .5, shape$1);
	return where(heads, plus, minus);
}, { staticArgnums: [1] });
|
|
8437
|
+
/**
 * @function
 * Sample integer values uniformly from `[minval, maxval)`.
 *
 * Reduces uniform 32-bit random bits modulo the range width. When the width
 * does not divide 2^32, this introduces a very small modulo bias.
 *
 * @throws If `minval`/`maxval` are not integers, the range is empty, `dtype`
 *   is not int32/uint32, or the range does not fit in `dtype`.
 */
const randint = jit$1(function randint$1(key$1, { minval, maxval, shape: shape$1 = [], dtype = require_backend.DType.Int32 }) {
	if (!Number.isInteger(minval) || !Number.isInteger(maxval)) throw new Error("randint: minval and maxval must be integers");
	if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
	if (dtype !== require_backend.DType.Int32 && dtype !== require_backend.DType.Uint32) throw new Error(`randint: dtype must be int32 or uint32, got ${dtype}`);
	if (dtype === require_backend.DType.Uint32 && minval < 0) throw new Error("randint: uint32 dtype requires minval >= 0");
	// Reject bounds that are not representable in `dtype`. Otherwise the range
	// width and `minval` would silently wrap when treated as 32-bit values.
	const [lo, hi] = dtype === require_backend.DType.Uint32 ? [0, 2 ** 32] : [-(2 ** 31), 2 ** 31];
	if (minval < lo || maxval > hi) throw new Error(`randint: range [${minval}, ${maxval}) does not fit in dtype ${dtype}`);
	const range$1 = maxval - minval;
	const raw = bits(key$1, shape$1);
	// A width of exactly 2^32 already covers every 32-bit pattern, and 2^32
	// cannot be represented as a 32-bit modulus, so skip the reduction.
	const reduced = range$1 === 2 ** 32 ? raw : raw.mod(range$1);
	return reduced.astype(dtype).add(minval);
}, { staticArgnums: [1] });
|
|
8452
|
+
/**
 * @function
 * Draw Rayleigh samples with the given `scale`.
 *
 * Computed as `scale * sqrt(2 * E)` with `E` a standard exponential sample.
 */
const rayleigh = jit$1(function rayleigh$1(key$1, scale = 1, shape$1 = []) {
	const sigma = fudgeArray(scale);
	const unitSample = sqrt(exponential(key$1, shape$1).mul(2));
	return unitSample.mul(sigma);
}, { staticArgnums: [2] });
|
|
8460
|
+
/**
 * @function
 * Draw triangular samples on `[left, right]` peaking at `mode`, by
 * inverse-transform sampling of `u ~ Uniform(0, 1)`:
 * - `u < (mode-left)/(right-left)`: `left + sqrt(u*(right-left)*(mode-left))`
 * - otherwise: `right - sqrt((1-u)*(right-left)*(right-mode))`
 */
const triangular = jit$1(function triangular$1(key$1, left, mode, right, shape$1 = []) {
	const a = fudgeArray(left);
	const c = fudgeArray(mode);
	const b = fudgeArray(right);
	const u = uniform(key$1, shape$1);
	// `.ref` is taken for every non-final use of an array; the last use consumes it.
	const width = b.ref.sub(a.ref);
	const leftSpan = c.ref.sub(a.ref);
	const rightSpan = b.ref.sub(c);
	const onLeftSide = u.ref.less(leftSpan.ref.div(width.ref));
	const leftSample = a.add(sqrt(u.ref.mul(width.ref).mul(leftSpan)));
	const rightSample = b.sub(sqrt(negative(u).add(1).mul(width).mul(rightSpan)));
	return where(onLeftSide, leftSample, rightSample);
}, { staticArgnums: [4] });
|
|
8478
|
+
/**
 * @function
 * Draw Weibull (minimum) samples as `scale * E^(1/concentration)` with `E` a
 * standard exponential sample. The power is evaluated as
 * `exp(log(E) / concentration)`.
 */
const weibullMin = jit$1(function weibullMin$1(key$1, scale, concentration, shape$1 = []) {
	const lambda = fudgeArray(scale);
	const k = fudgeArray(concentration);
	const e = exponential(key$1, shape$1);
	const powered = exp(log(e).div(k));
	return lambda.mul(powered);
}, { staticArgnums: [3] });
|
|
8274
8489
|
|
|
8275
8490
|
//#endregion
|
|
8276
8491
|
//#region src/library/scipy-special.ts
|
package/dist/index.d.cts
CHANGED
|
@@ -2722,7 +2722,7 @@ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: Ar
|
|
|
2722
2722
|
localWindowSize?: number | [number, number];
|
|
2723
2723
|
}): Array;
|
|
2724
2724
|
// Public surface of the `random` namespace.
declare namespace random_d_exports {
  export {
    ball, bernoulli, bits, categorical, cauchy, choice, doubleSidedMaxwell,
    exponential, geometric, gumbel, key, laplace, logistic, lognormal, maxwell,
    multivariateNormal, normal, pareto, permutation, rademacher, randint,
    rayleigh, split, triangular, uniform, weibullMin,
  };
}
|
|
2727
2727
|
/** Build a pseudo-random number generator (PRNG) key from a 32-bit integer seed. */
declare function key(seed: ArrayLike): Array;
|
|
@@ -2738,6 +2738,16 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
|
|
|
2738
2738
|
minval?: number | undefined;
|
|
2739
2739
|
maxval?: number | undefined;
|
|
2740
2740
|
} | undefined) => Array>;
|
|
2741
|
+
/**
 * @function
 * Draw points distributed uniformly inside the `d`-dimensional Euclidean unit
 * ball.
 *
 * Only `p=2` (the Euclidean ball) is currently supported.
 */
declare const ball: OwnedFunction<(key: ArrayLike, d: number, args_2?: undefined | {
  p?: undefined | number;
  shape?: undefined | number[];
}) => Array>;
|
|
2741
2751
|
/**
|
|
2742
2752
|
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
2743
2753
|
*
|
|
@@ -2778,11 +2788,42 @@ declare const categorical: OwnedFunction<(key: ArrayLike, logits: ArrayLike, arg
|
|
|
2778
2788
|
* Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
|
|
2779
2789
|
*/
|
|
2780
2790
|
declare const cauchy: OwnedFunction<(key: ArrayLike, shape?: undefined | number[]) => Array>;
|
|
2791
|
+
/**
 * Sample from a population with optional replacement and optional probabilities.
 *
 * This implements the common JAX-compatible cases: integer populations and
 * array populations along `axis`. Probabilities `p`, if provided, are sampled
 * via `categorical(log(p))`.
 *
 * `key` accepts any `ArrayLike`, consistent with the other samplers in this
 * namespace; it is only forwarded to samplers that take `ArrayLike` keys.
 */
declare function choice(key: ArrayLike, a: number | ArrayLike, {
  shape,
  replace,
  p,
  axis
}?: {
  shape?: number[];
  replace?: boolean;
  p?: ArrayLike;
  axis?: number;
}): Array;
|
|
2809
|
+
/**
 * @function
 * Draw double-sided Maxwell samples located at `loc` and scaled by `scale`.
 */
declare const doubleSidedMaxwell: OwnedFunction<(key: ArrayLike, loc: ArrayLike, scale: ArrayLike, shape?: undefined | number[]) => Array>;
|
|
2781
2814
|
/**
 * @function
 * Draw standard exponential samples with density `p(x) = exp(-x)`.
 */
declare const exponential: OwnedFunction<(key: ArrayLike, shape?: undefined | number[]) => Array>;
|
|
2819
|
+
/**
 * @function
 * Draw geometric samples: how many trials it takes to reach the first success.
 */
declare const geometric: OwnedFunction<(key: ArrayLike, p: ArrayLike, args_2?: undefined | {
  shape?: undefined | number[];
  dtype?: undefined | DType;
}) => Array>;
|
|
2786
2827
|
/**
|
|
2787
2828
|
* @function
|
|
2788
2829
|
* Sample from a Gumbel distribution with location 0 and scale 1.
|
|
@@ -2798,6 +2839,23 @@ declare const gumbel: OwnedFunction<(key: ArrayLike, shape?: number[] | undefine
|
|
|
2798
2839
|
* Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
|
|
2799
2840
|
*/
|
|
2800
2841
|
declare const laplace: OwnedFunction<(key: ArrayLike, shape?: undefined | number[]) => Array>;
|
|
2842
|
+
/**
 * @function
 * Draw standard logistic samples (location 0, scale 1).
 *
 * Inverse-transform sampling: `x = log(u) - log(1-u)`.
 */
declare const logistic: OwnedFunction<(key: ArrayLike, shape?: undefined | number[]) => Array>;
|
|
2849
|
+
/**
 * @function
 * Draw log-normal samples, `exp(sigma * normal(key, shape))`.
 */
declare const lognormal: OwnedFunction<(key: ArrayLike, sigma?: undefined | ArrayLike, shape?: undefined | number[]) => Array>;
|
|
2854
|
+
/**
 * @function
 * Draw samples from the Maxwell distribution.
 */
declare const maxwell: OwnedFunction<(key: ArrayLike, shape?: undefined | number[]) => Array>;
|
|
2801
2859
|
/**
|
|
2802
2860
|
* @function
|
|
2803
2861
|
* Sample multivariate normal random values with given mean and covariance.
|
|
@@ -2824,6 +2882,53 @@ declare const multivariateNormal: OwnedFunction<(key: ArrayLike, mean: ArrayLike
|
|
|
2824
2882
|
* bitwise identical to JAX.
|
|
2825
2883
|
*/
|
|
2826
2884
|
declare const normal: OwnedFunction<(key: ArrayLike, shape?: undefined | number[]) => Array>;
|
|
2885
|
+
/**
 * @function
 * Draw Pareto samples with shape parameter `b`, supported on [1, ∞).
 */
declare const pareto: OwnedFunction<(key: ArrayLike, b: ArrayLike, shape?: undefined | number[]) => Array>;
|
|
2890
|
+
/**
 * Return a random permutation of an integer range or of an array along `axis`.
 *
 * `key` accepts any `ArrayLike`, consistent with the other samplers in this
 * namespace; it is only forwarded to `uniform`.
 */
declare function permutation(key: ArrayLike, x: number | ArrayLike, axis?: number): Array;
|
|
2894
|
+
/**
 * @function
 * Draw Rademacher samples, each uniformly chosen from {-1, 1}.
 */
declare const rademacher: OwnedFunction<(key: ArrayLike, args_1?: undefined | {
  shape?: undefined | number[];
  dtype?: undefined | DType;
}) => Array>;
|
|
2902
|
+
/**
 * @function
 * Draw integers uniformly from the half-open range `[minval, maxval)`.
 *
 * Reduces uniform 32-bit random bits modulo the range width. Widths that do
 * not divide 2^32 therefore carry a very small modulo bias.
 */
declare const randint: OwnedFunction<(key: ArrayLike, args_1: {
  minval: number;
  maxval: number;
  shape?: undefined | number[];
  dtype?: undefined | DType;
}) => Array>;
|
|
2915
|
+
/**
 * @function
 * Draw Rayleigh samples with the given scale parameter.
 */
declare const rayleigh: OwnedFunction<(key: ArrayLike, scale?: undefined | ArrayLike, shape?: undefined | number[]) => Array>;
|
|
2920
|
+
/**
 * @function
 * Draw triangular samples on `[left, right]` whose density peaks at `mode`.
 */
declare const triangular: OwnedFunction<(key: ArrayLike, left: ArrayLike, mode: ArrayLike, right: ArrayLike, shape?: undefined | number[]) => Array>;
|
|
2925
|
+
/**
 * @function
 * Draw samples from the Weibull (minimum) distribution.
 *
 * Computed as `scale * exponential(key) ** (1 / concentration)`.
 */
declare const weibullMin: OwnedFunction<(key: ArrayLike, scale: ArrayLike, concentration: ArrayLike, shape?: undefined | number[]) => Array>;
|
|
2827
2932
|
// Public surface of the `scipy.special` namespace.
declare namespace scipy_special_d_exports {
  export {
    erf,
    erfc,
    logSoftmax,
    logit,
    logsumexp,
    softmax,
  };
}
|