@jax-js/jax 0.1.11 → 0.1.13

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -63,6 +63,7 @@ of late 2025.
63
63
 
64
64
  Community usage:
65
65
 
66
+ - [**g9-jaxjs**: Automatically interactive graphics with forward-mode AD](https://srush.github.io/g9jax/)
66
67
  - [**autoresearch-webgpu**: autoresearch, in the browser](https://autoresearch.lucasgelfond.online/)
67
68
  - [**tanh.xyz**: Interactive ML visualizations](https://tanh.xyz/)
68
69
  - [**jax-js-bayes**: Declarative Bayesian modeling library](https://github.com/StefanSko/jax-js-bayes)
@@ -72,10 +73,13 @@ Demos on the jax-js website:
72
73
  - [Training neural networks on MNIST](https://jax-js.com/mnist)
73
74
  - [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
74
75
  - [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
76
+ - [Object detection: D-FINE (ONNX)](https://jax-js.com/d-fine)
75
77
  - [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
76
78
  - [Fluid simulation (Navier-Stokes)](https://jax-js.com/fluid-sim)
79
+ - [Neural cellular automata](https://jax-js.com/nca-growing)
77
80
  - [In-browser REPL](https://jax-js.com/repl)
78
81
  - [Matmul benchmark](https://jax-js.com/bench/matmul)
82
+ - [Matvec benchmark](https://jax-js.com/bench/matvec)
79
83
  - [Conv2d benchmark](https://jax-js.com/bench/conv2d)
80
84
  - [Mandelbrot set](https://jax-js.com/mandelbrot)
81
85
 
@@ -422,7 +426,6 @@ Contributions are welcomed! Some fruitful areas to look into:
422
426
  - Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
423
427
  - Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
424
428
  and multithreading. (Even single-threaded Wasm could be ~20x faster.)
425
- - Helping the JIT compiler to fuse operations in more cases, like `tanh` branches.
426
429
  - Making a fast transformer inference engine, comparing against onnxruntime-web.
427
430
 
428
431
  You may join our [Discord server](https://discord.gg/BW6YsCd4Tf) and chat with the community.
@@ -1430,11 +1430,13 @@ var Reduction = class {
1430
1430
  function accessorGlobal(dtype, gid, st, indices) {
1431
1431
  const [index, valid] = st.toAluExp(indices);
1432
1432
  const [, len] = st.views[0].dataRange();
1433
+ if (valid.resolve()) return AluExp.globalIndex(dtype, gid, len, index);
1433
1434
  return AluExp.where(valid, AluExp.globalIndex(dtype, gid, len, index), AluExp.const(dtype, 0));
1434
1435
  }
1435
1436
  /** Expression for accessing `indices` in an array recipe with variable "idx". */
1436
1437
  function accessorAluExp(exp, st, indices) {
1437
1438
  const [index, valid] = st.toAluExp(indices);
1439
+ if (valid.resolve()) return exp.substitute({ idx: index });
1438
1440
  return AluExp.where(valid, exp.substitute({ idx: index }), AluExp.const(exp.dtype, 0));
1439
1441
  }
1440
1442
  function threefry2x32(k0, k1, c0, c1) {
@@ -4520,7 +4522,7 @@ var WasmBackend = class {
4520
4522
  const buffer = this.#getBuffer(slot);
4521
4523
  if (start === void 0) start = 0;
4522
4524
  if (count === void 0) count = buffer.byteLength - start;
4523
- if (buffer.buffer instanceof SharedArrayBuffer) return new Uint8Array(buffer.slice(start, start + count));
4525
+ if (hasSharedArrayBuffer() && buffer.buffer instanceof SharedArrayBuffer) return new Uint8Array(buffer.slice(start, start + count));
4524
4526
  else return buffer.slice(start, start + count);
4525
4527
  }
4526
4528
  async prepareKernel(kernel) {
@@ -5059,7 +5061,7 @@ async function createBackend(device) {
5059
5061
  if (!navigator.gpu) return null;
5060
5062
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
5061
5063
  if (!adapter) return null;
5062
- const { WebGPUBackend } = await import("./webgpu-Dg8FpYrH.js");
5064
+ const { WebGPUBackend } = await import("./webgpu-NkF1TZ0t.js");
5063
5065
  const importantLimits = [
5064
5066
  "maxBufferSize",
5065
5067
  "maxComputeInvocationsPerWorkgroup",
@@ -5097,7 +5099,7 @@ async function createBackend(device) {
5097
5099
  });
5098
5100
  if (!gl) return null;
5099
5101
  if (!gl.getExtension("EXT_color_buffer_float")) return null;
5100
- const { WebGLBackend } = await import("./webgl-D8-14NzA.js");
5102
+ const { WebGLBackend } = await import("./webgl-NsFtyIts.js");
5101
5103
  return new WebGLBackend(gl);
5102
5104
  } else throw new Error(`Backend not found: ${device}`);
5103
5105
  }
@@ -1431,11 +1431,13 @@ var Reduction = class {
1431
1431
  function accessorGlobal(dtype, gid, st, indices) {
1432
1432
  const [index, valid] = st.toAluExp(indices);
1433
1433
  const [, len] = st.views[0].dataRange();
1434
+ if (valid.resolve()) return AluExp.globalIndex(dtype, gid, len, index);
1434
1435
  return AluExp.where(valid, AluExp.globalIndex(dtype, gid, len, index), AluExp.const(dtype, 0));
1435
1436
  }
1436
1437
  /** Expression for accessing `indices` in an array recipe with variable "idx". */
1437
1438
  function accessorAluExp(exp, st, indices) {
1438
1439
  const [index, valid] = st.toAluExp(indices);
1440
+ if (valid.resolve()) return exp.substitute({ idx: index });
1439
1441
  return AluExp.where(valid, exp.substitute({ idx: index }), AluExp.const(exp.dtype, 0));
1440
1442
  }
1441
1443
  function threefry2x32(k0, k1, c0, c1) {
@@ -4521,7 +4523,7 @@ var WasmBackend = class {
4521
4523
  const buffer = this.#getBuffer(slot);
4522
4524
  if (start === void 0) start = 0;
4523
4525
  if (count === void 0) count = buffer.byteLength - start;
4524
- if (buffer.buffer instanceof SharedArrayBuffer) return new Uint8Array(buffer.slice(start, start + count));
4526
+ if (hasSharedArrayBuffer() && buffer.buffer instanceof SharedArrayBuffer) return new Uint8Array(buffer.slice(start, start + count));
4525
4527
  else return buffer.slice(start, start + count);
4526
4528
  }
4527
4529
  async prepareKernel(kernel) {
@@ -5060,7 +5062,7 @@ async function createBackend(device) {
5060
5062
  if (!navigator.gpu) return null;
5061
5063
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
5062
5064
  if (!adapter) return null;
5063
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-uU9nnttc.cjs"));
5065
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-DDGCYtHa.cjs"));
5064
5066
  const importantLimits = [
5065
5067
  "maxBufferSize",
5066
5068
  "maxComputeInvocationsPerWorkgroup",
@@ -5098,7 +5100,7 @@ async function createBackend(device) {
5098
5100
  });
5099
5101
  if (!gl) return null;
5100
5102
  if (!gl.getExtension("EXT_color_buffer_float")) return null;
5101
- const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-Ovaaa-Qx.cjs"));
5103
+ const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-pbfUGDA6.cjs"));
5102
5104
  return new WebGLBackend(gl);
5103
5105
  } else throw new Error(`Backend not found: ${device}`);
5104
5106
  }
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
30
30
  }) : target, mod$1));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-DlYlOYqN.cjs');
33
+ const require_backend = require('./backend-DMyuoWi2.cjs');
34
34
 
35
35
  //#region src/frontend/convolution.ts
36
36
  /**
@@ -240,7 +240,7 @@ __export(tree_exports, {
240
240
  structure: () => structure,
241
241
  unflatten: () => unflatten
242
242
  });
243
- const JsArray$2 = globalThis.Array;
243
+ const JsArray$3 = globalThis.Array;
244
244
  let NodeType = /* @__PURE__ */ function(NodeType$1) {
245
245
  NodeType$1["Array"] = "Array";
246
246
  NodeType$1["Object"] = "Object";
@@ -288,7 +288,7 @@ function flatten(tree) {
288
288
  return [leaves$1, treedef];
289
289
  }
290
290
  function _flatten(tree, leaves$1) {
291
- if (JsArray$2.isArray(tree)) {
291
+ if (JsArray$3.isArray(tree)) {
292
292
  const childTrees = tree.map((c) => _flatten(c, leaves$1));
293
293
  return new JsTreeDef(NodeType.Array, null, childTrees);
294
294
  } else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
@@ -2495,7 +2495,7 @@ function splitGraphDataflow(backend, jaxpr) {
2495
2495
 
2496
2496
  //#endregion
2497
2497
  //#region src/frontend/array.ts
2498
- const JsArray$1 = globalThis.Array;
2498
+ const JsArray$2 = globalThis.Array;
2499
2499
  const inlineArrayLimit = 128;
2500
2500
  /** Version of pureArray with fudged types. */
2501
2501
  const fudgeArray = pureArray;
@@ -2935,6 +2935,15 @@ var Array$1 = class Array$1 extends Tracer {
2935
2935
  this.#check();
2936
2936
  const indices = require_backend.unravelAlu(this.#st.shape, require_backend.AluVar.gidx);
2937
2937
  if (this.#source instanceof require_backend.AluExp) {
2938
+ let resolvedSource;
2939
+ if (this.#st.contiguous && this.#st.size < inlineArrayLimit && (resolvedSource = this.#source.resolve()) !== void 0) {
2940
+ const byteLength = this.#st.size * require_backend.byteWidth(this.#dtype);
2941
+ const initialData = new Uint8Array(byteLength);
2942
+ require_backend.dtypedArray(this.#dtype, initialData).fill(resolvedSource);
2943
+ this.#source = this.#backend.malloc(byteLength, initialData);
2944
+ this.#st = require_backend.ShapeTracker.fromShape(this.shape);
2945
+ return;
2946
+ }
2938
2947
  const exp$2 = require_backend.accessorAluExp(this.#source, this.#st, indices);
2939
2948
  const kernel = new require_backend.Kernel(0, this.#st.size, exp$2);
2940
2949
  const output = this.#backend.malloc(kernel.bytes);
@@ -3385,7 +3394,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
3385
3394
  if (!shape$1) {
3386
3395
  shape$1 = [];
3387
3396
  let cur = values;
3388
- while (JsArray$1.isArray(cur)) {
3397
+ while (JsArray$2.isArray(cur)) {
3389
3398
  shape$1.push(cur.length);
3390
3399
  cur = cur[0];
3391
3400
  }
@@ -4269,7 +4278,7 @@ const jvpRules = {
4269
4278
  return [[L], [dL]];
4270
4279
  },
4271
4280
  [Primitive.LU]([a], [da]) {
4272
- const [luMatrix, pivots, permutation] = lu$1(a);
4281
+ const [luMatrix, pivots, permutation$1] = lu$1(a);
4273
4282
  const [m, n] = a.shape.slice(-2);
4274
4283
  const k = Math.min(m, n);
4275
4284
  const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
@@ -4281,7 +4290,7 @@ const jvpRules = {
4281
4290
  const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
4282
4291
  const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
4283
4292
  const U = uPadded.add(uEye);
4284
- const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
4293
+ const P = permutation$1.ref.reshape([...permutation$1.shape, 1]).equal(arange(m)).astype(da.dtype);
4285
4294
  const pda = batchMatmulT(P, mT(da));
4286
4295
  const la = mT(triangularSolve$1(L.ref, mT(pda), {
4287
4296
  lower: true,
@@ -4293,11 +4302,11 @@ const jvpRules = {
4293
4302
  return [[
4294
4303
  luMatrix,
4295
4304
  pivots,
4296
- permutation
4305
+ permutation$1
4297
4306
  ], [
4298
4307
  lDot.add(uDot),
4299
4308
  zerosLike$1(pivots.ref),
4300
- zerosLike$1(permutation.ref)
4309
+ zerosLike$1(permutation$1.ref)
4301
4310
  ]];
4302
4311
  },
4303
4312
  [Primitive.Jit](primals, tangents, { name, jaxpr }) {
@@ -5379,8 +5388,8 @@ function cross$1(x1, x2, axis = -1) {
5379
5388
  function det(a) {
5380
5389
  a = fudgeArray(a);
5381
5390
  const n = checkSquare("det", a);
5382
- const [lu$2, pivots, permutation] = lu(a);
5383
- permutation.dispose();
5391
+ const [lu$2, pivots, permutation$1] = lu(a);
5392
+ permutation$1.dispose();
5384
5393
  const parity = pivots.notEqual(arange(n)).astype(int32).sum(-1).mod(2);
5385
5394
  const sign$1 = parity.mul(-2).add(1);
5386
5395
  const diag$1 = lu$2.diagonal(0, -1, -2);
@@ -5469,8 +5478,8 @@ function matrixPower(a, n) {
5469
5478
  function slogdet(a) {
5470
5479
  a = fudgeArray(a);
5471
5480
  const n = checkSquare("slogdet", a);
5472
- const [lu$2, pivots, permutation] = lu(a);
5473
- permutation.dispose();
5481
+ const [lu$2, pivots, permutation$1] = lu(a);
5482
+ permutation$1.dispose();
5474
5483
  let parity = pivots.notEqual(arange(n)).astype(int32).sum(-1);
5475
5484
  const diag$1 = lu$2.diagonal(0, -1, -2);
5476
5485
  parity = parity.add(diag$1.ref.less(0).astype(int32).sum(-1)).mod(2);
@@ -5508,9 +5517,9 @@ function solve(a, b) {
5508
5517
  n,
5509
5518
  m
5510
5519
  ]);
5511
- const [lu$2, pivots, permutation] = lu(a);
5520
+ const [lu$2, pivots, permutation$1] = lu(a);
5512
5521
  pivots.dispose();
5513
- const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
5522
+ const P = arange(n).equal(permutation$1.reshape([...permutation$1.shape, 1])).astype(b.dtype);
5514
5523
  const LPb = triangularSolve(lu$2.ref, matmul(P, b), {
5515
5524
  leftSide: true,
5516
5525
  lower: true,
@@ -7366,7 +7375,7 @@ __export(lax_exports, {
7366
7375
  stopGradient: () => stopGradient$1,
7367
7376
  topK: () => topK
7368
7377
  });
7369
- const JsArray = globalThis.Array;
7378
+ const JsArray$1 = globalThis.Array;
7370
7379
  /** Elementwise bitcast an array into a new dtype. */
7371
7380
  function bitcastConvertType(x, newDtype) {
7372
7381
  return fudgeArray(x).view(newDtype);
@@ -7553,7 +7562,7 @@ function convTransposePadding(k, s, padding) {
7553
7562
  } else if (padding === "VALID") {
7554
7563
  padLen = k + s - 2 + Math.max(k - s, 0);
7555
7564
  pad1 = k - 1;
7556
- } else if (JsArray.isArray(padding)) {
7565
+ } else if (JsArray$1.isArray(padding)) {
7557
7566
  const pads = [k - 1 - padding[0], k - 1 - padding[1]];
7558
7567
  pad1 = pads[0];
7559
7568
  padLen = pads[0] + pads[1];
@@ -8072,19 +8081,34 @@ function dotProductAttention(query, key$1, value, opts = {}) {
8072
8081
  //#region src/library/random.ts
8073
8082
  // Namespace object for the `random` module.
  var random_exports = {};
  // Each entry is a thunk returning the module-level binding of the same name;
  // presumably the bundler's `__export` helper (not visible here) installs them
  // as getters on `random_exports`.
  __export(random_exports, {
    ball: () => ball,
    bernoulli: () => bernoulli,
    bits: () => bits,
    categorical: () => categorical,
    cauchy: () => cauchy,
    choice: () => choice,
    doubleSidedMaxwell: () => doubleSidedMaxwell,
    exponential: () => exponential,
    geometric: () => geometric,
    gumbel: () => gumbel,
    key: () => key,
    laplace: () => laplace,
    logistic: () => logistic,
    lognormal: () => lognormal,
    maxwell: () => maxwell,
    multivariateNormal: () => multivariateNormal,
    normal: () => normal,
    pareto: () => pareto,
    permutation: () => permutation,
    rademacher: () => rademacher,
    randint: () => randint,
    rayleigh: () => rayleigh,
    split: () => split,
    triangular: () => triangular,
    uniform: () => uniform,
    weibullMin: () => weibullMin
  });
  // Built-in Array constructor under a distinct name (presumably because `Array`
  // names the jax-js array type in the original TypeScript source). Used by
  // `choice` and `permutation` to build slice-index lists.
  const JsArray = globalThis.Array;
8088
8112
  function validateKeyShape(key$1, scalar = false) {
8089
8113
  if (key$1.ndim === 0) throw new Error("Key must have at least one dimension.");
8090
8114
  if (key$1.shape[key$1.shape.length - 1] !== 2) throw new Error(`Invalid key shape: ${key$1.shape}. Expected last dimension to be 2.`);
@@ -8137,6 +8161,21 @@ const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxv
8137
8161
  else return rand.mul(maxval - minval).add(minval);
8138
8162
  }, { staticArgnums: [1, 2] });
8139
8163
  /**
8164
+ * @function
8165
+ * Sample points uniformly from the Euclidean unit ball in `d` dimensions.
8166
+ *
8167
+ * Only the Euclidean `p=2` case is currently supported.
8168
+ */
8169
+ const ball = jit$1(function ball$1(key$1, d, { p = 2, shape: shape$1 = [] } = {}) {
8170
+ if (!Number.isInteger(d) || d <= 0) throw new Error(`ball: dimension must be a positive integer, got ${d}`);
8171
+ if (p !== 2) throw new Error("ball: only the Euclidean p=2 case is supported");
8172
+ const [k1, k2] = split(key$1, 2);
8173
+ const z = normal(k1, [...shape$1, d]);
8174
+ const norm = sqrt(z.ref.mul(z.ref).sum(-1, { keepdims: true }));
8175
+ const radius = exp(log(uniform(k2, [...shape$1, 1])).mul(1 / d));
8176
+ return z.div(norm).mul(radius);
8177
+ }, { staticArgnums: [1, 2] });
8178
+ /**
8140
8179
  * Sample Bernoulli random variables with given mean (0,1 categorical).
8141
8180
  *
8142
8181
  * Returns a random Boolean array with the specified shape. `p` can be an array
@@ -8198,6 +8237,57 @@ const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
8198
8237
  return tan(u.sub(.5).mul(Math.PI));
8199
8238
  }, { staticArgnums: [1] });
8200
8239
  /**
8240
+ * Sample from a population with optional replacement and optional probabilities.
8241
+ *
8242
+ * This implements the common JAX-compatible cases: integer populations and
8243
+ * array populations along `axis`. Probabilities `p`, if provided, are sampled
8244
+ * via `categorical(log(p))`.
8245
+ */
8246
+ function choice(key$1, a, { shape: shape$1 = [], replace = true, p, axis = 0 } = {}) {
8247
+ let n;
8248
+ let values = null;
8249
+ if (typeof a === "number") {
8250
+ if (!Number.isInteger(a) || a < 0) throw new Error(`choice: a must be a non-negative integer, got ${a}`);
8251
+ n = a;
8252
+ } else {
8253
+ values = fudgeArray(a);
8254
+ axis = require_backend.checkAxis(axis, values.ndim);
8255
+ n = values.shape[axis];
8256
+ }
8257
+ let indices;
8258
+ if (p !== void 0) indices = categorical(key$1, log(p), {
8259
+ shape: shape$1,
8260
+ replace
8261
+ });
8262
+ else if (replace) indices = randint(key$1, {
8263
+ minval: 0,
8264
+ maxval: n,
8265
+ shape: shape$1
8266
+ });
8267
+ else {
8268
+ const k = shape$1.reduce((acc, x) => acc * x, 1);
8269
+ if (k > n) throw new Error(`Number of samples without replacement (${k}) cannot exceed population size (${n}).`);
8270
+ indices = permutation(key$1, n).slice([0, k]).reshape(shape$1);
8271
+ }
8272
+ if (values === null) return indices;
8273
+ const index = JsArray(axis).fill([]);
8274
+ index.push(indices);
8275
+ return values.slice(...index);
8276
+ }
8277
+ /**
8278
+ * @function
8279
+ * Sample double-sided Maxwell random values with the provided location and scale.
8280
+ */
8281
+ const doubleSidedMaxwell = jit$1(function doubleSidedMaxwell$1(key$1, loc, scale, shape$1 = []) {
8282
+ loc = fudgeArray(loc);
8283
+ scale = fudgeArray(scale);
8284
+ const [k1, k2] = split(key$1, 2);
8285
+ return rademacher(k1, {
8286
+ shape: shape$1,
8287
+ dtype: require_backend.DType.Float32
8288
+ }).mul(maxwell(k2, shape$1)).mul(scale).add(loc);
8289
+ }, { staticArgnums: [3] });
8290
+ /**
8201
8291
  * @function
8202
8292
  * Sample exponential random values according to `p(x) = exp(-x)`.
8203
8293
  */
@@ -8207,6 +8297,14 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
8207
8297
  }, { staticArgnums: [1] });
8208
8298
  /**
8209
8299
  * @function
8300
+ * Sample geometric random values: the number of trials until first success.
8301
+ */
8302
+ const geometric = jit$1(function geometric$1(key$1, p, { shape: shape$1 = [], dtype = require_backend.DType.Int32 } = {}) {
8303
+ p = fudgeArray(p);
8304
+ return floor(log1p(negative(uniform(key$1, shape$1))).div(log1p(negative(p)))).add(1).astype(dtype);
8305
+ }, { staticArgnums: [2] });
8306
+ /**
8307
+ * @function
8210
8308
  * Sample from a Gumbel distribution with location 0 and scale 1.
8211
8309
  *
8212
8310
  * Uses inverse transform sampling: `x = -log(-log(u))` where u ~ Uniform(0, 1).
@@ -8231,6 +8329,32 @@ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
8231
8329
  }, { staticArgnums: [1] });
8232
8330
  /**
8233
8331
  * @function
8332
+ * Sample from a logistic distribution with location 0 and scale 1.
8333
+ *
8334
+ * Uses inverse transform sampling: `x = log(u) - log(1-u)`.
8335
+ */
8336
+ const logistic = jit$1(function logistic$1(key$1, shape$1 = []) {
8337
+ const u = uniform(key$1, shape$1);
8338
+ return log(u.ref).sub(log1p(negative(u)));
8339
+ }, { staticArgnums: [1] });
8340
+ /**
8341
+ * @function
8342
+ * Sample log-normal random values: `exp(sigma * normal(key, shape))`.
8343
+ */
8344
+ const lognormal = jit$1(function lognormal$1(key$1, sigma = 1, shape$1 = []) {
8345
+ sigma = fudgeArray(sigma);
8346
+ return exp(normal(key$1, shape$1).mul(sigma));
8347
+ }, { staticArgnums: [2] });
8348
+ /**
8349
+ * @function
8350
+ * Sample Maxwell-distributed random values.
8351
+ */
8352
+ const maxwell = jit$1(function maxwell$1(key$1, shape$1 = []) {
8353
+ const z = normal(key$1, [...shape$1, 3]);
8354
+ return sqrt(z.ref.mul(z).sum(-1));
8355
+ }, { staticArgnums: [1] });
8356
+ /**
8357
+ * @function
8234
8358
  * Sample multivariate normal random values with given mean and covariance.
8235
8359
  *
8236
8360
  * The values are returned with the given shape, along with the final dimension
@@ -8271,6 +8395,97 @@ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
8271
8395
  const theta = u2.mul(2 * Math.PI);
8272
8396
  return radius.mul(cos(theta));
8273
8397
  }, { staticArgnums: [1] });
8398
+ /**
8399
+ * @function
8400
+ * Sample from a Pareto distribution with shape parameter `b` and support [1, ∞).
8401
+ */
8402
+ const pareto = jit$1(function pareto$1(key$1, b, shape$1 = []) {
8403
+ b = fudgeArray(b);
8404
+ return exp(exponential(key$1, shape$1).div(b));
8405
+ }, { staticArgnums: [2] });
8406
+ /**
8407
+ * Return a random permutation of an integer range or of an array along `axis`.
8408
+ */
8409
+ function permutation(key$1, x, axis = 0) {
8410
+ if (typeof x === "number") {
8411
+ if (!Number.isInteger(x) || x < 0) throw new Error(`permutation: x must be a non-negative integer, got ${x}`);
8412
+ return argsort(uniform(key$1, [x])).astype(require_backend.DType.Int32);
8413
+ }
8414
+ const arr = fudgeArray(x);
8415
+ axis = require_backend.checkAxis(axis, arr.ndim);
8416
+ const perm = permutation(key$1, arr.shape[axis]);
8417
+ const index = JsArray(axis).fill([]);
8418
+ index.push(perm);
8419
+ return arr.slice(...index);
8420
+ }
8421
+ /**
8422
+ * @function
8423
+ * Sample Rademacher random values, uniformly from {-1, 1}.
8424
+ */
8425
+ const rademacher = jit$1(function rademacher$1(key$1, { shape: shape$1 = [], dtype = require_backend.DType.Int32 } = {}) {
8426
+ if (dtype === require_backend.DType.Uint32 || dtype === require_backend.DType.Bool) throw new Error(`rademacher: unsupported dtype ${dtype}`);
8427
+ const one = array(1, {
8428
+ dtype,
8429
+ device: key$1.device
8430
+ });
8431
+ const minusOne = array(-1, {
8432
+ dtype,
8433
+ device: key$1.device
8434
+ });
8435
+ return where(bernoulli(key$1, .5, shape$1), one, minusOne);
8436
+ }, { staticArgnums: [1] });
8437
+ /**
8438
+ * @function
8439
+ * Sample integer values uniformly from `[minval, maxval)`.
8440
+ *
8441
+ * This uses modulo reduction of uniform 32-bit random bits. For ranges that do
8442
+ * not divide 2^32, this introduces a very small modulo bias.
8443
+ */
8444
+ const randint = jit$1(function randint$1(key$1, { minval, maxval, shape: shape$1 = [], dtype = require_backend.DType.Int32 }) {
8445
+ if (!Number.isInteger(minval) || !Number.isInteger(maxval)) throw new Error("randint: minval and maxval must be integers");
8446
+ if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
8447
+ if (dtype !== require_backend.DType.Int32 && dtype !== require_backend.DType.Uint32) throw new Error(`randint: dtype must be int32 or uint32, got ${dtype}`);
8448
+ if (dtype === require_backend.DType.Uint32 && minval < 0) throw new Error("randint: uint32 dtype requires minval >= 0");
8449
+ const range$1 = maxval - minval;
8450
+ return bits(key$1, shape$1).mod(range$1).astype(dtype).add(minval);
8451
+ }, { staticArgnums: [1] });
8452
+ /**
8453
+ * @function
8454
+ * Sample Rayleigh random values with the provided scale parameter.
8455
+ */
8456
+ const rayleigh = jit$1(function rayleigh$1(key$1, scale = 1, shape$1 = []) {
8457
+ scale = fudgeArray(scale);
8458
+ return sqrt(exponential(key$1, shape$1).mul(2)).mul(scale);
8459
+ }, { staticArgnums: [2] });
8460
+ /**
8461
+ * @function
8462
+ * Sample triangular random values on `[left, right]` with the given mode.
8463
+ */
8464
+ const triangular = jit$1(function triangular$1(key$1, left, mode, right, shape$1 = []) {
8465
+ left = fudgeArray(left);
8466
+ mode = fudgeArray(mode);
8467
+ right = fudgeArray(right);
8468
+ const u = uniform(key$1, shape$1);
8469
+ const width = right.ref.sub(left.ref);
8470
+ const leftSpan = mode.ref.sub(left.ref);
8471
+ const rightSpan = right.ref.sub(mode);
8472
+ const cutoff = leftSpan.ref.div(width.ref);
8473
+ const cond = u.ref.less(cutoff);
8474
+ const lower = left.add(sqrt(u.ref.mul(width.ref).mul(leftSpan)));
8475
+ const upper = right.sub(sqrt(negative(u).add(1).mul(width).mul(rightSpan)));
8476
+ return where(cond, lower, upper);
8477
+ }, { staticArgnums: [4] });
8478
+ /**
8479
+ * @function
8480
+ * Sample Weibull minimum random values.
8481
+ *
8482
+ * Uses `scale * exponential(key) ** (1 / concentration)`.
8483
+ */
8484
+ const weibullMin = jit$1(function weibullMin$1(key$1, scale, concentration, shape$1 = []) {
8485
+ scale = fudgeArray(scale);
8486
+ concentration = fudgeArray(concentration);
8487
+ return scale.mul(exp(log(exponential(key$1, shape$1)).div(concentration)));
8488
+ }, { staticArgnums: [3] });
8274
8489
 
8275
8490
  //#endregion
8276
8491
  //#region src/library/scipy-special.ts
package/dist/index.d.cts CHANGED
@@ -2722,7 +2722,7 @@ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: Ar
2722
2722
  localWindowSize?: number | [number, number];
2723
2723
  }): Array;
2724
2724
  declare namespace random_d_exports {
2725
- export { bernoulli, bits, categorical, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2725
+ export { ball, bernoulli, bits, categorical, cauchy, choice, doubleSidedMaxwell, exponential, geometric, gumbel, key, laplace, logistic, lognormal, maxwell, multivariateNormal, normal, pareto, permutation, rademacher, randint, rayleigh, split, triangular, uniform, weibullMin };
2726
2726
  }
2727
2727
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
2728
2728
  declare function key(seed: ArrayLike): Array;
@@ -2738,6 +2738,16 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
2738
2738
  minval?: number | undefined;
2739
2739
  maxval?: number | undefined;
2740
2740
  } | undefined) => Array>;
2741
+ /**
2742
+ * @function
2743
+ * Sample points uniformly from the Euclidean unit ball in `d` dimensions.
2744
+ *
2745
+ * Only the Euclidean `p=2` case is currently supported.
2746
+ */
2747
+ declare const ball: OwnedFunction<(key: ArrayLike, d: number, args_2?: {
2748
+ p?: number | undefined;
2749
+ shape?: number[] | undefined;
2750
+ } | undefined) => Array>;
2741
2751
  /**
2742
2752
  * Sample Bernoulli random variables with given mean (0,1 categorical).
2743
2753
  *
@@ -2778,11 +2788,42 @@ declare const categorical: OwnedFunction<(key: ArrayLike, logits: ArrayLike, arg
2778
2788
  * Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
2779
2789
  */
2780
2790
  declare const cauchy: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2791
+ /**
2792
+ * Sample from a population with optional replacement and optional probabilities.
2793
+ *
2794
+ * This implements the common JAX-compatible cases: integer populations and
2795
+ * array populations along `axis`. Probabilities `p`, if provided, are sampled
2796
+ * via `categorical(log(p))`.
2797
+ */
2798
+ declare function choice(key: Array, a: number | ArrayLike, {
2799
+ shape,
2800
+ replace,
2801
+ p,
2802
+ axis
2803
+ }?: {
2804
+ shape?: number[];
2805
+ replace?: boolean;
2806
+ p?: ArrayLike;
2807
+ axis?: number;
2808
+ }): Array;
2809
+ /**
2810
+ * @function
2811
+ * Sample double-sided Maxwell random values with the provided location and scale.
2812
+ */
2813
+ declare const doubleSidedMaxwell: OwnedFunction<(key: ArrayLike, loc: ArrayLike, scale: ArrayLike, shape?: number[] | undefined) => Array>;
2781
2814
  /**
2782
2815
  * @function
2783
2816
  * Sample exponential random values according to `p(x) = exp(-x)`.
2784
2817
  */
2785
2818
  declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2819
+ /**
2820
+ * @function
2821
+ * Sample geometric random values: the number of trials until first success.
2822
+ */
2823
+ declare const geometric: OwnedFunction<(key: ArrayLike, p: ArrayLike, args_2?: {
2824
+ shape?: number[] | undefined;
2825
+ dtype?: DType | undefined;
2826
+ } | undefined) => Array>;
2786
2827
  /**
2787
2828
  * @function
2788
2829
  * Sample from a Gumbel distribution with location 0 and scale 1.
@@ -2798,6 +2839,23 @@ declare const gumbel: OwnedFunction<(key: ArrayLike, shape?: number[] | undefine
2798
2839
  * Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
2799
2840
  */
2800
2841
  declare const laplace: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2842
+ /**
2843
+ * @function
2844
+ * Sample from a logistic distribution with location 0 and scale 1.
2845
+ *
2846
+ * Uses inverse transform sampling: `x = log(u) - log(1-u)`.
2847
+ */
2848
+ declare const logistic: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2849
+ /**
2850
+ * @function
2851
+ * Sample log-normal random values: `exp(sigma * normal(key, shape))`.
2852
+ */
2853
+ declare const lognormal: OwnedFunction<(key: ArrayLike, sigma?: ArrayLike | undefined, shape?: number[] | undefined) => Array>;
2854
+ /**
2855
+ * @function
2856
+ * Sample Maxwell-distributed random values.
2857
+ */
2858
+ declare const maxwell: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2801
2859
  /**
2802
2860
  * @function
2803
2861
  * Sample multivariate normal random values with given mean and covariance.
@@ -2824,6 +2882,53 @@ declare const multivariateNormal: OwnedFunction<(key: ArrayLike, mean: ArrayLike
2824
2882
  * bitwise identical to JAX.
2825
2883
  */
2826
2884
  declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2885
+ /**
2886
+ * @function
2887
+ * Sample from a Pareto distribution with shape parameter `b` and support [1, ∞).
2888
+ */
2889
+ declare const pareto: OwnedFunction<(key: ArrayLike, b: ArrayLike, shape?: number[] | undefined) => Array>;
2890
+ /**
2891
+ * Return a random permutation of an integer range or of an array along `axis`.
2892
+ */
2893
+ declare function permutation(key: Array, x: number | ArrayLike, axis?: number): Array;
2894
+ /**
2895
+ * @function
2896
+ * Sample Rademacher random values, uniformly from {-1, 1}.
2897
+ */
2898
+ declare const rademacher: OwnedFunction<(key: ArrayLike, args_1?: {
2899
+ shape?: number[] | undefined;
2900
+ dtype?: DType | undefined;
2901
+ } | undefined) => Array>;
2902
+ /**
2903
+ * @function
2904
+ * Sample integer values uniformly from `[minval, maxval)`.
2905
+ *
2906
+ * This uses modulo reduction of uniform 32-bit random bits. For ranges that do
2907
+ * not divide 2^32, this introduces a very small modulo bias.
2908
+ */
2909
+ declare const randint: OwnedFunction<(key: ArrayLike, args_1: {
2910
+ minval: number;
2911
+ maxval: number;
2912
+ shape?: number[] | undefined;
2913
+ dtype?: DType | undefined;
2914
+ }) => Array>;
2915
+ /**
2916
+ * @function
2917
+ * Sample Rayleigh random values with the provided scale parameter.
2918
+ */
2919
+ declare const rayleigh: OwnedFunction<(key: ArrayLike, scale?: ArrayLike | undefined, shape?: number[] | undefined) => Array>;
2920
+ /**
2921
+ * @function
2922
+ * Sample triangular random values on `[left, right]` with the given mode.
2923
+ */
2924
+ declare const triangular: OwnedFunction<(key: ArrayLike, left: ArrayLike, mode: ArrayLike, right: ArrayLike, shape?: number[] | undefined) => Array>;
2925
+ /**
2926
+ * @function
2927
+ * Sample Weibull minimum random values.
2928
+ *
2929
+ * Uses `scale * exponential(key) ** (1 / concentration)`.
2930
+ */
2931
+ declare const weibullMin: OwnedFunction<(key: ArrayLike, scale: ArrayLike, concentration: ArrayLike, shape?: number[] | undefined) => Array>;
2827
2932
  declare namespace scipy_special_d_exports {
2828
2933
  export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
2829
2934
  }