@jax-js/jax 0.1.7 → 0.1.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -43,6 +43,23 @@ way to get started on a blank HTML page.
43
43
  </script>
44
44
  ```
45
45
 
46
+ ## Examples
47
+
48
+ Cool things that the community has made with jax-js:
49
+
50
+ - [**tanh.xyz**: Interactive ML visualizations](https://tanh.xyz/)
51
+
52
+ There are also more demos on the official website:
53
+
54
+ - [Training neural networks on MNIST](https://jax-js.com/mnist)
55
+ - [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
56
+ - [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
57
+ - [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
58
+ - [In-browser REPL](https://jax-js.com/repl)
59
+ - [Matmul benchmark](https://jax-js.com/bench/matmul)
60
+ - [Conv2d benchmark](https://jax-js.com/bench/conv2d)
61
+ - [Mandelbrot set](https://jax-js.com/mandelbrot)
62
+
46
63
  ## Feature comparison
47
64
 
48
65
  Here's a quick, high-level comparison with other popular web ML runtimes:
@@ -338,19 +355,6 @@ well as unique optimizations such as FlashAttention variants.
338
355
  That's all for this short tutorial. Please see the generated
339
356
  [API reference](https://jax-js.com/docs) for detailed documentation.
340
357
 
341
- ## Examples
342
-
343
- If you make something cool with jax-js, don't be a stranger! We can feature it here.
344
-
345
- - [Training neural networks on MNIST](https://jax-js.com/mnist)
346
- - [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
347
- - [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
348
- - [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
349
- - [In-browser REPL](https://jax-js.com/repl)
350
- - [Matmul benchmark](https://jax-js.com/bench/matmul)
351
- - [Conv2d benchmark](https://jax-js.com/bench/conv2d)
352
- - [Mandelbrot set](https://jax-js.com/mandelbrot)
353
-
354
358
  ## Development
355
359
 
356
360
  _The following technical details are for contributing to jax-js and modifying its internals._
@@ -363,6 +367,19 @@ pnpm install
363
367
  pnpm run build:watch
364
368
  ```
365
369
 
370
+ The `pnpm install` command automatically sets up Git hooks via
371
+ [Husky](https://typicode.github.io/husky/). Pre-commit hooks will run ESLint and Prettier on staged
372
+ files to ensure code quality.
373
+
374
+ You can also run linting and formatting manually:
375
+
376
+ ```bash
377
+ pnpm lint # Run ESLint
378
+ pnpm format # Format all files with Prettier
379
+ pnpm format:check # Check formatting without writing
380
+ pnpm check # Run TypeScript type checking
381
+ ```
382
+
366
383
  Then you can run tests in a headless browser using [Vitest](https://vitest.dev/).
367
384
 
368
385
  ```bash
@@ -1479,9 +1479,14 @@ var Routine = class {
1479
1479
  };
1480
1480
  /** One of the valid `Routine` that can be dispatched to backend. */
1481
1481
  let Routines = /* @__PURE__ */ function(Routines$1) {
1482
- /** Stable sorting algorithm along the last axis. */
1482
+ /**
1483
+ * Sort along the last axis.
1484
+ *
1485
+ * This may be _unstable_, but that rarely matters: tied numbers are
1486
+ * bitwise identical except for signed zeros and NaNs.
1487
+ */
1483
1488
  Routines$1["Sort"] = "Sort";
1484
- /** Returns `int32` indices of the stably sorted array. */
1489
+ /** Stable sorting, returns `int32` indices and values of the sorted array. */
1485
1490
  Routines$1["Argsort"] = "Argsort";
1486
1491
  /**
1487
1492
  * Solve a triangular system of equations.
@@ -1545,7 +1550,13 @@ function runArgsort(type, [x], [y, yi]) {
1545
1550
  const out = y.subarray(offset, offset + n);
1546
1551
  const outi = yi.subarray(offset, offset + n);
1547
1552
  for (let i = 0; i < n; i++) outi[i] = i;
1548
- outi.sort((a, b) => ar[a] - ar[b]);
1553
+ outi.sort((a, b) => {
1554
+ const x$1 = ar[a];
1555
+ const y$1 = ar[b];
1556
+ if (isNaN(x$1)) return isNaN(y$1) ? 0 : 1;
1557
+ if (isNaN(y$1)) return -1;
1558
+ return x$1 === y$1 ? 0 : x$1 < y$1 ? -1 : 1;
1559
+ });
1549
1560
  for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
1550
1561
  }
1551
1562
  }
@@ -2321,7 +2332,7 @@ function tuneWebgpu(kernel) {
2321
2332
  if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
2322
2333
  const s = dim.st.shape[dim.unroll - 1];
2323
2334
  if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
2324
- else for (const splits of [8, 4]) if (s % splits === 0) {
2335
+ else for (const splits of [4, 2]) if (s % splits === 0) {
2325
2336
  dim.applyUnroll(dim.unroll - 1, splits);
2326
2337
  break;
2327
2338
  }
@@ -4252,7 +4263,7 @@ async function createBackend(device) {
4252
4263
  if (!navigator.gpu) return null;
4253
4264
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4254
4265
  if (!adapter) return null;
4255
- const { WebGPUBackend } = await import("./webgpu-B96vzWGE.js");
4266
+ const { WebGPUBackend } = await import("./webgpu-AN0cG_nB.js");
4256
4267
  const importantLimits = [
4257
4268
  "maxBufferSize",
4258
4269
  "maxComputeInvocationsPerWorkgroup",
@@ -4290,7 +4301,7 @@ async function createBackend(device) {
4290
4301
  });
4291
4302
  if (!gl) return null;
4292
4303
  if (!gl.getExtension("EXT_color_buffer_float")) return null;
4293
- const { WebGLBackend } = await import("./webgl-DweKSWEm.js");
4304
+ const { WebGLBackend } = await import("./webgl-DnGrclTz.js");
4294
4305
  return new WebGLBackend(gl);
4295
4306
  } else throw new Error(`Backend not found: ${device}`);
4296
4307
  }
@@ -1480,9 +1480,14 @@ var Routine = class {
1480
1480
  };
1481
1481
  /** One of the valid `Routine` that can be dispatched to backend. */
1482
1482
  let Routines = /* @__PURE__ */ function(Routines$1) {
1483
- /** Stable sorting algorithm along the last axis. */
1483
+ /**
1484
+ * Sort along the last axis.
1485
+ *
1486
+ * This may be _unstable_, but that rarely matters: tied numbers are
1487
+ * bitwise identical except for signed zeros and NaNs.
1488
+ */
1484
1489
  Routines$1["Sort"] = "Sort";
1485
- /** Returns `int32` indices of the stably sorted array. */
1490
+ /** Stable sorting, returns `int32` indices and values of the sorted array. */
1486
1491
  Routines$1["Argsort"] = "Argsort";
1487
1492
  /**
1488
1493
  * Solve a triangular system of equations.
@@ -1546,7 +1551,13 @@ function runArgsort(type, [x], [y, yi]) {
1546
1551
  const out = y.subarray(offset, offset + n);
1547
1552
  const outi = yi.subarray(offset, offset + n);
1548
1553
  for (let i = 0; i < n; i++) outi[i] = i;
1549
- outi.sort((a, b) => ar[a] - ar[b]);
1554
+ outi.sort((a, b) => {
1555
+ const x$1 = ar[a];
1556
+ const y$1 = ar[b];
1557
+ if (isNaN(x$1)) return isNaN(y$1) ? 0 : 1;
1558
+ if (isNaN(y$1)) return -1;
1559
+ return x$1 === y$1 ? 0 : x$1 < y$1 ? -1 : 1;
1560
+ });
1550
1561
  for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
1551
1562
  }
1552
1563
  }
@@ -2322,7 +2333,7 @@ function tuneWebgpu(kernel) {
2322
2333
  if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
2323
2334
  const s = dim.st.shape[dim.unroll - 1];
2324
2335
  if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
2325
- else for (const splits of [8, 4]) if (s % splits === 0) {
2336
+ else for (const splits of [4, 2]) if (s % splits === 0) {
2326
2337
  dim.applyUnroll(dim.unroll - 1, splits);
2327
2338
  break;
2328
2339
  }
@@ -4253,7 +4264,7 @@ async function createBackend(device) {
4253
4264
  if (!navigator.gpu) return null;
4254
4265
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4255
4266
  if (!adapter) return null;
4256
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BykvF26B.cjs"));
4267
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-CdjiJSa7.cjs"));
4257
4268
  const importantLimits = [
4258
4269
  "maxBufferSize",
4259
4270
  "maxComputeInvocationsPerWorkgroup",
@@ -4291,7 +4302,7 @@ async function createBackend(device) {
4291
4302
  });
4292
4303
  if (!gl) return null;
4293
4304
  if (!gl.getExtension("EXT_color_buffer_float")) return null;
4294
- const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-DIIbKJ0G.cjs"));
4305
+ const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-C5NjXc1p.cjs"));
4295
4306
  return new WebGLBackend(gl);
4296
4307
  } else throw new Error(`Backend not found: ${device}`);
4297
4308
  }
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
30
30
  }) : target, mod$1));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-B3foXiV_.cjs');
33
+ const require_backend = require('./backend-DpI0riom.cjs');
34
34
 
35
35
  //#region src/frontend/convolution.ts
36
36
  /**
@@ -920,18 +920,25 @@ var Tracer = class Tracer {
920
920
  return sort$1(this.transpose(perm)).transpose(require_backend.invertPermutation(perm));
921
921
  }
922
922
  /**
923
- * Return the indices that would sort an array. This may not be a stable
924
- * sorting algorithm; it need not preserve order of indices in ties.
923
+ * Return the indices that would sort an array. Unlike `sort`, this is
924
+ * guaranteed to be a stable sorting algorithm; it always returns the smaller
925
+ * index first in the event of ties.
925
926
  *
926
927
  * See `jax.numpy.argsort` for full docs.
927
928
  */
928
929
  argsort(axis = -1) {
929
930
  axis = require_backend.checkAxis(axis, this.ndim);
930
- if (axis === this.ndim - 1) return argsort$1(this)[1];
931
+ if (axis === this.ndim - 1) {
932
+ const [y$1, yi$1] = argsort$1(this);
933
+ y$1.dispose();
934
+ return yi$1;
935
+ }
931
936
  const perm = require_backend.range(this.ndim);
932
937
  perm.splice(axis, 1);
933
938
  perm.push(axis);
934
- return argsort$1(this.transpose(perm))[1].transpose(require_backend.invertPermutation(perm));
939
+ const [y, yi] = argsort$1(this.transpose(perm));
940
+ y.dispose();
941
+ return yi.transpose(require_backend.invertPermutation(perm));
935
942
  }
936
943
  /**
937
944
  * Slice an array along one or more axes.
@@ -3416,32 +3423,26 @@ function fullInternal(aval, fillValue, device) {
3416
3423
  committed: device != void 0
3417
3424
  });
3418
3425
  }
3419
- function zerosLike$1(val, dtype) {
3420
- return fullLike(val, 0, dtype);
3426
+ function zerosLike$1(val, opts) {
3427
+ return fullLike(val, 0, opts);
3421
3428
  }
3422
- function onesLike$1(val, dtype) {
3423
- return fullLike(val, 1, dtype);
3429
+ function onesLike$1(val, opts) {
3430
+ return fullLike(val, 1, opts);
3424
3431
  }
3425
- function fullLike(val, fillValue, dtype) {
3432
+ function fullLike(val, fillValue, { dtype, shape: shape$1, device } = {}) {
3426
3433
  const aval = getAval(val);
3427
3434
  if (val instanceof Tracer) val.dispose();
3428
3435
  if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
3429
- const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
3430
- return fullInternal(sa, fillValue);
3436
+ const sa = new ShapedArray(shape$1 ?? aval.shape, dtype ?? aval.dtype, aval.weakType && dtype === void 0);
3437
+ return fullInternal(sa, fillValue, device);
3431
3438
  }
3432
3439
  /** Return a new array of given shape and type, filled with zeros. */
3433
- function zeros(shape$1, { dtype, device } = {}) {
3434
- return full(shape$1, 0, {
3435
- dtype,
3436
- device
3437
- });
3440
+ function zeros(shape$1, opts) {
3441
+ return full(shape$1, 0, opts);
3438
3442
  }
3439
3443
  /** Return a new array of given shape and type, filled with ones. */
3440
- function ones(shape$1, { dtype, device } = {}) {
3441
- return full(shape$1, 1, {
3442
- dtype,
3443
- device
3444
- });
3444
+ function ones(shape$1, opts) {
3445
+ return full(shape$1, 1, opts);
3445
3446
  }
3446
3447
  /** Return a new array of given shape and type, filled with `fill_value`. */
3447
3448
  function full(shape$1, fillValue, { dtype, device } = {}) {
@@ -5329,9 +5330,10 @@ function lstsq(a, b) {
5329
5330
  });
5330
5331
  const llb = triangularSolve(l, lb, {
5331
5332
  leftSide: true,
5333
+ lower: true,
5332
5334
  transposeA: true
5333
5335
  });
5334
- return matmul(at, llb.ref);
5336
+ return matmul(at, llb);
5335
5337
  } else {
5336
5338
  const ata = matmul(at.ref, a);
5337
5339
  const l = cholesky(ata, { symmetrizeInput: false });
@@ -5342,6 +5344,7 @@ function lstsq(a, b) {
5342
5344
  });
5343
5345
  const llb = triangularSolve(l, lb, {
5344
5346
  leftSide: true,
5347
+ lower: true,
5345
5348
  transposeA: true
5346
5349
  });
5347
5350
  return llb;
@@ -5421,7 +5424,7 @@ function solve(a, b) {
5421
5424
  lower: true,
5422
5425
  unitDiagonal: true
5423
5426
  });
5424
- let x = triangularSolve(lu$2, LPb.ref, {
5427
+ let x = triangularSolve(lu$2, LPb, {
5425
5428
  leftSide: true,
5426
5429
  lower: false
5427
5430
  });
@@ -6232,8 +6235,9 @@ function sort(a, axis = -1) {
6232
6235
  return fudgeArray(a).sort(axis);
6233
6236
  }
6234
6237
  /**
6235
- * Return indices that would sort an array. This may be an unstable sorting
6236
- * algorithm; it need not preserve order of indices in ties.
6238
+ * Return indices that would sort an array. Unlike `sort`, this is guaranteed to
6239
+ * be a stable sorting algorithm; it always returns the smaller index first in
6240
+ * the event of ties.
6237
6241
  *
6238
6242
  * Returns an array of `int32` indices.
6239
6243
  *
@@ -6535,7 +6539,7 @@ function absolute(x) {
6535
6539
  /** Return an element-wise indication of sign of the input. */
6536
6540
  function sign(x) {
6537
6541
  x = fudgeArray(x);
6538
- return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
6542
+ return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
6539
6543
  }
6540
6544
  /** @function Return element-wise positive values of the input (no-op). */
6541
6545
  const positive = fudgeArray;
@@ -7003,7 +7007,10 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
7003
7007
  b = fudgeArray(b);
7004
7008
  if (!leftSide) transposeA = !transposeA;
7005
7009
  else b = moveaxis$1(b, -2, -1);
7006
- if (transposeA) a = moveaxis$1(a, -2, -1);
7010
+ if (transposeA) {
7011
+ a = moveaxis$1(a, -2, -1);
7012
+ lower = !lower;
7013
+ }
7007
7014
  let x = triangularSolve$1(a, b, {
7008
7015
  lower,
7009
7016
  unitDiagonal
@@ -7025,7 +7032,8 @@ __export(lax_exports, {
7025
7032
  erfc: () => erfc,
7026
7033
  linalg: () => lax_linalg_exports,
7027
7034
  reduceWindow: () => reduceWindow,
7028
- stopGradient: () => stopGradient$1
7035
+ stopGradient: () => stopGradient$1,
7036
+ topK: () => topK
7029
7037
  });
7030
7038
  const JsArray = globalThis.Array;
7031
7039
  /**
@@ -7249,6 +7257,39 @@ function erfc(x) {
7249
7257
  function stopGradient$1(x) {
7250
7258
  return stopGradient(x);
7251
7259
  }
7260
+ /**
7261
+ * Returns the top `k` values and their indices along the specified axis of `x`.
7262
+ *
7263
+ * This is a _stable_ algorithm: If two elements are equal, the lower-index
7264
+ * element appears first.
7265
+ *
7266
+ * @returns A tuple of `(values, indices)`, where `values` and `indices` have
7267
+ * the same shape as `x`, except along `axis` where they have size `k`.
7268
+ */
7269
+ function topK(x, k, axis = -1) {
7270
+ x = fudgeArray(x);
7271
+ axis = require_backend.checkAxis(axis, x.ndim);
7272
+ const size$1 = x.shape[axis];
7273
+ if (k < 0 || k > size$1) throw new Error(`topK: k must be in the range [0, ${size$1}], got ${k}`);
7274
+ if (k === 0) {
7275
+ const outShape = x.shape.slice();
7276
+ outShape[axis] = 0;
7277
+ const y$1 = zerosLike$1(x.ref, { shape: outShape });
7278
+ const yi$1 = zerosLike$1(x, {
7279
+ dtype: require_backend.DType.Int32,
7280
+ shape: outShape
7281
+ });
7282
+ return [y$1, yi$1];
7283
+ }
7284
+ x = flip$1(x, [axis]);
7285
+ x = moveaxis(x, axis, -1);
7286
+ const [y, yi] = argsort$1(x);
7287
+ const extract = (a) => {
7288
+ a = a.slice(...require_backend.rep(a.ndim - 1, []), [-k]);
7289
+ return flip$1(moveaxis(a, -1, axis), [axis]);
7290
+ };
7291
+ return [extract(y), extract(yi.neg().add(size$1 - 1))];
7292
+ }
7252
7293
 
7253
7294
  //#endregion
7254
7295
  //#region src/library/nn.ts
@@ -7440,7 +7481,7 @@ const gelu = jit$1(function gelu$1(x, opts) {
7440
7481
  if (opts?.approximate ?? true) {
7441
7482
  const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
7442
7483
  return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
7443
- } else return x.ref.mul(.5).mul(erfc$1(negative(x.ref.mul(Math.SQRT1_2))));
7484
+ } else return x.ref.mul(.5).mul(erfc$1(negative(x.mul(Math.SQRT1_2))));
7444
7485
  }, { staticArgnums: [1] });
7445
7486
  /**
7446
7487
  * Gated linear unit (GLU) activation function.
@@ -7698,6 +7739,7 @@ var random_exports = {};
7698
7739
  __export(random_exports, {
7699
7740
  bernoulli: () => bernoulli,
7700
7741
  bits: () => bits,
7742
+ categorical: () => categorical,
7701
7743
  cauchy: () => cauchy,
7702
7744
  exponential: () => exponential,
7703
7745
  gumbel: () => gumbel,
@@ -7769,6 +7811,47 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
7769
7811
  }
7770
7812
  /**
7771
7813
  * @function
7814
+ * Sample random values from categorical distributions.
7815
+ *
7816
+ * Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
7817
+ * trick for sampling without replacement.
7818
+ *
7819
+ * Note: Sampling without replacement uses `topK`, which is currently
7820
+ * implemented as a full argsort followed by a slice of the last k elements.
7821
+ *
7822
+ * - `key` - PRNG key
7823
+ * - `logits` - Unnormalized log probabilities of the categorical distribution(s).
7824
+ * `softmax(logits, axis)` gives the corresponding probabilities.
7825
+ * - `axis` - Axis along which logits belong to the same categorical distribution.
7826
+ * - `shape` - Result batch shape. Must be broadcast-compatible with
7827
+ * `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
7828
+ * - `replace` - If true (default), sample with replacement. If false, sample
7829
+ * without replacement (each category can only be selected once per batch).
7830
+ * @returns A random array with int dtype and shape given by `shape` if provided,
7831
+ * otherwise `logits.shape` with `axis` removed.
7832
+ */
7833
+ const categorical = jit$1(function categorical$1(key$1, logits, { axis = -1, shape: shape$1, replace = true } = {}) {
7834
+ logits = fudgeArray(logits);
7835
+ axis = require_backend.checkAxis(axis, logits.ndim);
7836
+ const numCategories = logits.shape[axis];
7837
+ const batchShape = logits.shape.toSpliced(axis, 1);
7838
+ if (shape$1 === void 0) shape$1 = batchShape;
7839
+ else if (!require_backend.deepEqual(require_backend.generalBroadcast(shape$1, batchShape), shape$1)) throw new Error(`Shape ${shape$1} is not broadcast-compatible with batch shape ${batchShape}.`);
7840
+ const shapePrefix = shape$1.slice(0, shape$1.length - batchShape.length);
7841
+ if (replace) {
7842
+ const noise = gumbel(key$1, [...shapePrefix, ...logits.shape]);
7843
+ return argmax(noise.add(logits), axis + shapePrefix.length);
7844
+ } else {
7845
+ const k = shapePrefix.reduce((a, b) => a * b, 1);
7846
+ if (k > numCategories) throw new Error(`Number of samples without replacement (${k}) cannot exceed number of categories (${numCategories}).`);
7847
+ const noise = gumbel(key$1, logits.shape);
7848
+ const [values, indices] = topK(noise.add(logits), k, axis);
7849
+ values.dispose();
7850
+ return indices.reshape(shape$1);
7851
+ }
7852
+ }, { staticArgnums: [2] });
7853
+ /**
7854
+ * @function
7772
7855
  * Sample from a Cauchy distribution with location 0 and scale 1.
7773
7856
  *
7774
7857
  * Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
package/dist/index.d.cts CHANGED
@@ -436,9 +436,14 @@ declare class Routine {
436
436
  }
437
437
  /** One of the valid `Routine` that can be dispatched to backend. */
438
438
  declare enum Routines {
439
- /** Stable sorting algorithm along the last axis. */
439
+ /**
440
+ * Sort along the last axis.
441
+ *
442
+ * This may be _unstable_, but that rarely matters: tied numbers are
443
+ * bitwise identical except for signed zeros and NaNs.
444
+ */
440
445
  Sort = "Sort",
441
- /** Returns `int32` indices of the stably sorted array. */
446
+ /** Stable sorting, returns `int32` indices and values of the sorted array. */
442
447
  Argsort = "Argsort",
443
448
  /**
444
449
  * Solve a triangular system of equations.
@@ -750,9 +755,9 @@ declare enum Primitive {
750
755
  Shrink = "shrink",
751
756
  Pad = "pad",
752
757
  Sort = "sort",
753
- // sort(x, axis=-1)
758
+ // sort(x, axis=-1), unstable
754
759
  Argsort = "argsort",
755
- // argsort(x, axis=-1)
760
+ // argsort(x, axis=-1), stable
756
761
  TriangularSolve = "triangular_solve",
757
762
  // A is upper triangular, A @ X.T = B.T
758
763
  Cholesky = "cholesky",
@@ -1029,8 +1034,9 @@ declare abstract class Tracer {
1029
1034
  */
1030
1035
  sort(axis?: number): this;
1031
1036
  /**
1032
- * Return the indices that would sort an array. This may not be a stable
1033
- * sorting algorithm; it need not preserve order of indices in ties.
1037
+ * Return the indices that would sort an array. Unlike `sort`, this is
1038
+ * guaranteed to be a stable sorting algorithm; it always returns the smaller
1039
+ * index first in the event of ties.
1034
1040
  *
1035
1041
  * See `jax.numpy.argsort` for full docs.
1036
1042
  */
@@ -1112,6 +1118,12 @@ type DTypeAndDevice = {
1112
1118
  dtype?: DType;
1113
1119
  device?: Device;
1114
1120
  };
1121
+ /** @inline */
1122
+ type DTypeShapeAndDevice = {
1123
+ dtype?: DType;
1124
+ shape?: number[];
1125
+ device?: Device;
1126
+ };
1115
1127
  type ArrayConstructorArgs = {
1116
1128
  source: AluExp | Slot;
1117
1129
  st: ShapeTracker;
@@ -1221,15 +1233,9 @@ declare function array(values: Array | DataArray | RecursiveArray<number> | Recu
1221
1233
  type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
1222
1234
  declare const implRules: { [P in Primitive]: ImplRule<P> };
1223
1235
  /** Return a new array of given shape and type, filled with zeros. */
1224
- declare function zeros(shape: number[], {
1225
- dtype,
1226
- device
1227
- }?: DTypeAndDevice): Array;
1236
+ declare function zeros(shape: number[], opts?: DTypeAndDevice): Array;
1228
1237
  /** Return a new array of given shape and type, filled with ones. */
1229
- declare function ones(shape: number[], {
1230
- dtype,
1231
- device
1232
- }?: DTypeAndDevice): Array;
1238
+ declare function ones(shape: number[], opts?: DTypeAndDevice): Array;
1233
1239
  /** Return a new array of given shape and type, filled with `fill_value`. */
1234
1240
  declare function full(shape: number[], fillValue: number | boolean | Array, {
1235
1241
  dtype,
@@ -1421,7 +1427,7 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
1421
1427
  unitDiagonal?: boolean;
1422
1428
  }): Array;
1423
1429
  declare namespace lax_d_exports {
1424
- export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
1430
+ export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
1425
1431
  }
1426
1432
  /**
1427
1433
  * Dimension numbers for general `dot()` primitive.
@@ -1527,6 +1533,16 @@ declare function erfc(x: ArrayLike): Array;
1527
1533
  * forward or reverse-mode automatic differentiation.
1528
1534
  */
1529
1535
  declare function stopGradient(x: ArrayLike): Array;
1536
+ /**
1537
+ * Returns the top `k` values and their indices along the specified axis of `x`.
1538
+ *
1539
+ * This is a _stable_ algorithm: If two elements are equal, the lower-index
1540
+ * element appears first.
1541
+ *
1542
+ * @returns A tuple of `(values, indices)`, where `values` and `indices` have
1543
+ * the same shape as `x`, except along `axis` where they have size `k`.
1544
+ */
1545
+ declare function topK(x: ArrayLike, k: number, axis?: number): [Array, Array];
1530
1546
  declare namespace numpy_fft_d_exports {
1531
1547
  export { ComplexPair, fft, ifft };
1532
1548
  }
@@ -1752,17 +1768,17 @@ declare const shape$1: (x: ArrayLike) => number[];
1752
1768
  * @function
1753
1769
  * Return an array of zeros with the same shape and type as a given array.
1754
1770
  */
1755
- declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
1771
+ declare const zerosLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
1756
1772
  /**
1757
1773
  * @function
1758
1774
  * Return an array of ones with the same shape and type as a given array.
1759
1775
  */
1760
- declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
1776
+ declare const onesLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
1761
1777
  /**
1762
1778
  * @function
1763
1779
  * Return a full array with the same shape and type as a given array.
1764
1780
  */
1765
- declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
1781
+ declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, opts?: DTypeShapeAndDevice) => Array;
1766
1782
  /**
1767
1783
  * Return the number of elements in an array, optionally along an axis.
1768
1784
  * Does not consume array reference.
@@ -1951,8 +1967,9 @@ declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: nu
1951
1967
  */
1952
1968
  declare function sort(a: ArrayLike, axis?: number): Array;
1953
1969
  /**
1954
- * Return indices that would sort an array. This may be an unstable sorting
1955
- * algorithm; it need not preserve order of indices in ties.
1970
+ * Return indices that would sort an array. Unlike `sort`, this is guaranteed to
1971
+ * be a stable sorting algorithm; it always returns the smaller index first in
1972
+ * the event of ties.
1956
1973
  *
1957
1974
  * Returns an array of `int32` indices.
1958
1975
  *
@@ -2564,7 +2581,7 @@ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: Ar
2564
2581
  localWindowSize?: number | [number, number];
2565
2582
  }): Array;
2566
2583
  declare namespace random_d_exports {
2567
- export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2584
+ export { bernoulli, bits, categorical, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2568
2585
  }
2569
2586
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
2570
2587
  declare function key(seed: ArrayLike): Array;
@@ -2587,6 +2604,32 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
2587
2604
  * and must be broadcastable to `shape`.
2588
2605
  */
2589
2606
  declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
2607
+ /**
2608
+ * @function
2609
+ * Sample random values from categorical distributions.
2610
+ *
2611
+ * Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
2612
+ * trick for sampling without replacement.
2613
+ *
2614
+ * Note: Sampling without replacement uses `topK`, which is currently
2615
+ * implemented as a full argsort followed by a slice of the last k elements.
2616
+ *
2617
+ * - `key` - PRNG key
2618
+ * - `logits` - Unnormalized log probabilities of the categorical distribution(s).
2619
+ * `softmax(logits, axis)` gives the corresponding probabilities.
2620
+ * - `axis` - Axis along which logits belong to the same categorical distribution.
2621
+ * - `shape` - Result batch shape. Must be broadcast-compatible with
2622
+ * `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
2623
+ * - `replace` - If true (default), sample with replacement. If false, sample
2624
+ * without replacement (each category can only be selected once per batch).
2625
+ * @returns A random array with int dtype and shape given by `shape` if provided,
2626
+ * otherwise `logits.shape` with `axis` removed.
2627
+ */
2628
+ declare const categorical: OwnedFunction<(key: ArrayLike, logits: ArrayLike, args_2?: {
2629
+ axis?: number | undefined;
2630
+ shape?: number[] | undefined;
2631
+ replace?: boolean | undefined;
2632
+ } | undefined) => Array>;
2590
2633
  /**
2591
2634
  * @function
2592
2635
  * Sample from a Cauchy distribution with location 0 and scale 1.
package/dist/index.d.ts CHANGED
@@ -433,9 +433,14 @@ declare class Routine {
433
433
  }
434
434
  /** One of the valid `Routine` that can be dispatched to backend. */
435
435
  declare enum Routines {
436
- /** Stable sorting algorithm along the last axis. */
436
+ /**
437
+ * Sort along the last axis.
438
+ *
439
+ * This may be _unstable_, but that rarely matters: tied numbers are
440
+ * bitwise identical except for signed zeros and NaNs.
441
+ */
437
442
  Sort = "Sort",
438
- /** Returns `int32` indices of the stably sorted array. */
443
+ /** Stable sorting, returns `int32` indices and values of the sorted array. */
439
444
  Argsort = "Argsort",
440
445
  /**
441
446
  * Solve a triangular system of equations.
@@ -747,9 +752,9 @@ declare enum Primitive {
747
752
  Shrink = "shrink",
748
753
  Pad = "pad",
749
754
  Sort = "sort",
750
- // sort(x, axis=-1)
755
+ // sort(x, axis=-1), unstable
751
756
  Argsort = "argsort",
752
- // argsort(x, axis=-1)
757
+ // argsort(x, axis=-1), stable
753
758
  TriangularSolve = "triangular_solve",
754
759
  // A is upper triangular, A @ X.T = B.T
755
760
  Cholesky = "cholesky",
@@ -1026,8 +1031,9 @@ declare abstract class Tracer {
1026
1031
  */
1027
1032
  sort(axis?: number): this;
1028
1033
  /**
1029
- * Return the indices that would sort an array. This may not be a stable
1030
- * sorting algorithm; it need not preserve order of indices in ties.
1034
+ * Return the indices that would sort an array. Unlike `sort`, this is
1035
+ * guaranteed to be a stable sorting algorithm; it always returns the smaller
1036
+ * index first in the event of ties.
1031
1037
  *
1032
1038
  * See `jax.numpy.argsort` for full docs.
1033
1039
  */
@@ -1109,6 +1115,12 @@ type DTypeAndDevice = {
1109
1115
  dtype?: DType;
1110
1116
  device?: Device;
1111
1117
  };
1118
+ /** @inline */
1119
+ type DTypeShapeAndDevice = {
1120
+ dtype?: DType;
1121
+ shape?: number[];
1122
+ device?: Device;
1123
+ };
1112
1124
  type ArrayConstructorArgs = {
1113
1125
  source: AluExp | Slot;
1114
1126
  st: ShapeTracker;
@@ -1218,15 +1230,9 @@ declare function array(values: Array | DataArray | RecursiveArray<number> | Recu
1218
1230
  type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
1219
1231
  declare const implRules: { [P in Primitive]: ImplRule<P> };
1220
1232
  /** Return a new array of given shape and type, filled with zeros. */
1221
- declare function zeros(shape: number[], {
1222
- dtype,
1223
- device
1224
- }?: DTypeAndDevice): Array;
1233
+ declare function zeros(shape: number[], opts?: DTypeAndDevice): Array;
1225
1234
  /** Return a new array of given shape and type, filled with ones. */
1226
- declare function ones(shape: number[], {
1227
- dtype,
1228
- device
1229
- }?: DTypeAndDevice): Array;
1235
+ declare function ones(shape: number[], opts?: DTypeAndDevice): Array;
1230
1236
  /** Return a new array of given shape and type, filled with `fill_value`. */
1231
1237
  declare function full(shape: number[], fillValue: number | boolean | Array, {
1232
1238
  dtype,
@@ -1418,7 +1424,7 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
1418
1424
  unitDiagonal?: boolean;
1419
1425
  }): Array;
1420
1426
  declare namespace lax_d_exports {
1421
- export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
1427
+ export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
1422
1428
  }
1423
1429
  /**
1424
1430
  * Dimension numbers for general `dot()` primitive.
@@ -1524,6 +1530,16 @@ declare function erfc(x: ArrayLike): Array;
1524
1530
  * forward or reverse-mode automatic differentiation.
1525
1531
  */
1526
1532
  declare function stopGradient(x: ArrayLike): Array;
1533
+ /**
1534
+ * Returns top `k` values and their indices along the specified axis of operand.
1535
+ *
1536
+ * This is a _stable_ algorithm: If two elements are equal, the lower-index
1537
+ * element appears first.
1538
+ *
1539
+ * @returns A tuple of `(values, indices)`, where `values` and `indices` have
1540
+ * the same shape as `x`, except along `axis` where they have size `k`.
1541
+ */
1542
+ declare function topK(x: ArrayLike, k: number, axis?: number): [Array, Array];
1527
1543
  declare namespace numpy_fft_d_exports {
1528
1544
  export { ComplexPair, fft, ifft };
1529
1545
  }
@@ -1749,17 +1765,17 @@ declare const shape$1: (x: ArrayLike) => number[];
1749
1765
  * @function
1750
1766
  * Return an array of zeros with the same shape and type as a given array.
1751
1767
  */
1752
- declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
1768
+ declare const zerosLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
1753
1769
  /**
1754
1770
  * @function
1755
1771
  * Return an array of ones with the same shape and type as a given array.
1756
1772
  */
1757
- declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
1773
+ declare const onesLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
1758
1774
  /**
1759
1775
  * @function
1760
1776
  * Return a full array with the same shape and type as a given array.
1761
1777
  */
1762
- declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
1778
+ declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, opts?: DTypeShapeAndDevice) => Array;
1763
1779
  /**
1764
1780
  * Return the number of elements in an array, optionally along an axis.
1765
1781
  * Does not consume array reference.
@@ -1948,8 +1964,9 @@ declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: nu
1948
1964
  */
1949
1965
  declare function sort(a: ArrayLike, axis?: number): Array;
1950
1966
  /**
1951
- * Return indices that would sort an array. This may be an unstable sorting
1952
- * algorithm; it need not preserve order of indices in ties.
1967
+ * Return indices that would sort an array. Unlike `sort`, this is guaranteed to
1968
+ * be a stable sorting algorithm; it always returns the smaller index first in
1969
+ * event of ties.
1953
1970
  *
1954
1971
  * Returns an array of `int32` indices.
1955
1972
  *
@@ -2561,7 +2578,7 @@ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: Ar
2561
2578
  localWindowSize?: number | [number, number];
2562
2579
  }): Array;
2563
2580
  declare namespace random_d_exports {
2564
- export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2581
+ export { bernoulli, bits, categorical, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2565
2582
  }
2566
2583
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
2567
2584
  declare function key(seed: ArrayLike): Array;
@@ -2584,6 +2601,32 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
2584
2601
  * and must be broadcastable to `shape`.
2585
2602
  */
2586
2603
  declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
2604
+ /**
2605
+ * @function
2606
+ * Sample random values from categorical distributions.
2607
+ *
2608
+ * Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
2609
+ * trick for sampling without replacement.
2610
+ *
2611
+ * Note: Sampling without replacement currently uses argsort and slices the last
2612
+ * k elements. This should be replaced with a more efficient topK implementation.
2613
+ *
2614
+ * - `key` - PRNG key
2615
+ * - `logits` - Unnormalized log probabilities of the categorical distribution(s).
2616
+ * `softmax(logits, axis)` gives the corresponding probabilities.
2617
+ * - `axis` - Axis along which logits belong to the same categorical distribution.
2618
+ * - `shape` - Result batch shape. Must be broadcast-compatible with
2619
+ * `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
2620
+ * - `replace` - If true (default), sample with replacement. If false, sample
2621
+ * without replacement (each category can only be selected once per batch).
2622
+ * @returns A random array with int dtype and shape given by `shape` if provided,
2623
+ * otherwise `logits.shape` with `axis` removed.
2624
+ */
2625
+ declare const categorical: OwnedFunction<(key: ArrayLike, logits: ArrayLike, args_2?: {
2626
+ axis?: number | undefined;
2627
+ shape?: number[] | undefined;
2628
+ replace?: boolean | undefined;
2629
+ } | undefined) => Array>;
2587
2630
  /**
2588
2631
  * @function
2589
2632
  * Sample from a Cauchy distribution with location 0 and scale 1.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-nEolvdLv.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-BId79r5b.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
@@ -889,18 +889,25 @@ var Tracer = class Tracer {
889
889
  return sort$1(this.transpose(perm)).transpose(invertPermutation(perm));
890
890
  }
891
891
  /**
892
- * Return the indices that would sort an array. This may not be a stable
893
- * sorting algorithm; it need not preserve order of indices in ties.
892
+ * Return the indices that would sort an array. Unlike `sort`, this is
893
+ * guaranteed to be a stable sorting algorithm; it always returns the smaller
894
+ * index first in event of ties.
894
895
  *
895
896
  * See `jax.numpy.argsort` for full docs.
896
897
  */
897
898
  argsort(axis = -1) {
898
899
  axis = checkAxis(axis, this.ndim);
899
- if (axis === this.ndim - 1) return argsort$1(this)[1];
900
+ if (axis === this.ndim - 1) {
901
+ const [y$1, yi$1] = argsort$1(this);
902
+ y$1.dispose();
903
+ return yi$1;
904
+ }
900
905
  const perm = range(this.ndim);
901
906
  perm.splice(axis, 1);
902
907
  perm.push(axis);
903
- return argsort$1(this.transpose(perm))[1].transpose(invertPermutation(perm));
908
+ const [y, yi] = argsort$1(this.transpose(perm));
909
+ y.dispose();
910
+ return yi.transpose(invertPermutation(perm));
904
911
  }
905
912
  /**
906
913
  * Slice an array along one or more axes.
@@ -3381,32 +3388,26 @@ function fullInternal(aval, fillValue, device) {
3381
3388
  committed: device != void 0
3382
3389
  });
3383
3390
  }
3384
- function zerosLike$1(val, dtype) {
3385
- return fullLike(val, 0, dtype);
3391
+ function zerosLike$1(val, opts) {
3392
+ return fullLike(val, 0, opts);
3386
3393
  }
3387
- function onesLike$1(val, dtype) {
3388
- return fullLike(val, 1, dtype);
3394
+ function onesLike$1(val, opts) {
3395
+ return fullLike(val, 1, opts);
3389
3396
  }
3390
- function fullLike(val, fillValue, dtype) {
3397
+ function fullLike(val, fillValue, { dtype, shape: shape$1, device } = {}) {
3391
3398
  const aval = getAval(val);
3392
3399
  if (val instanceof Tracer) val.dispose();
3393
3400
  if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
3394
- const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
3395
- return fullInternal(sa, fillValue);
3401
+ const sa = new ShapedArray(shape$1 ?? aval.shape, dtype ?? aval.dtype, aval.weakType && dtype === void 0);
3402
+ return fullInternal(sa, fillValue, device);
3396
3403
  }
3397
3404
  /** Return a new array of given shape and type, filled with zeros. */
3398
- function zeros(shape$1, { dtype, device } = {}) {
3399
- return full(shape$1, 0, {
3400
- dtype,
3401
- device
3402
- });
3405
+ function zeros(shape$1, opts) {
3406
+ return full(shape$1, 0, opts);
3403
3407
  }
3404
3408
  /** Return a new array of given shape and type, filled with ones. */
3405
- function ones(shape$1, { dtype, device } = {}) {
3406
- return full(shape$1, 1, {
3407
- dtype,
3408
- device
3409
- });
3409
+ function ones(shape$1, opts) {
3410
+ return full(shape$1, 1, opts);
3410
3411
  }
3411
3412
  /** Return a new array of given shape and type, filled with `fill_value`. */
3412
3413
  function full(shape$1, fillValue, { dtype, device } = {}) {
@@ -5292,9 +5293,10 @@ function lstsq(a, b) {
5292
5293
  });
5293
5294
  const llb = triangularSolve(l, lb, {
5294
5295
  leftSide: true,
5296
+ lower: true,
5295
5297
  transposeA: true
5296
5298
  });
5297
- return matmul(at, llb.ref);
5299
+ return matmul(at, llb);
5298
5300
  } else {
5299
5301
  const ata = matmul(at.ref, a);
5300
5302
  const l = cholesky(ata, { symmetrizeInput: false });
@@ -5305,6 +5307,7 @@ function lstsq(a, b) {
5305
5307
  });
5306
5308
  const llb = triangularSolve(l, lb, {
5307
5309
  leftSide: true,
5310
+ lower: true,
5308
5311
  transposeA: true
5309
5312
  });
5310
5313
  return llb;
@@ -5384,7 +5387,7 @@ function solve(a, b) {
5384
5387
  lower: true,
5385
5388
  unitDiagonal: true
5386
5389
  });
5387
- let x = triangularSolve(lu$2, LPb.ref, {
5390
+ let x = triangularSolve(lu$2, LPb, {
5388
5391
  leftSide: true,
5389
5392
  lower: false
5390
5393
  });
@@ -6195,8 +6198,9 @@ function sort(a, axis = -1) {
6195
6198
  return fudgeArray(a).sort(axis);
6196
6199
  }
6197
6200
  /**
6198
- * Return indices that would sort an array. This may be an unstable sorting
6199
- * algorithm; it need not preserve order of indices in ties.
6201
+ * Return indices that would sort an array. Unlike `sort`, this is guaranteed to
6202
+ * be a stable sorting algorithm; it always returns the smaller index first in
6203
+ * event of ties.
6200
6204
  *
6201
6205
  * Returns an array of `int32` indices.
6202
6206
  *
@@ -6498,7 +6502,7 @@ function absolute(x) {
6498
6502
  /** Return an element-wise indication of sign of the input. */
6499
6503
  function sign(x) {
6500
6504
  x = fudgeArray(x);
6501
- return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
6505
+ return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
6502
6506
  }
6503
6507
  /** @function Return element-wise positive values of the input (no-op). */
6504
6508
  const positive = fudgeArray;
@@ -6966,7 +6970,10 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
6966
6970
  b = fudgeArray(b);
6967
6971
  if (!leftSide) transposeA = !transposeA;
6968
6972
  else b = moveaxis$1(b, -2, -1);
6969
- if (transposeA) a = moveaxis$1(a, -2, -1);
6973
+ if (transposeA) {
6974
+ a = moveaxis$1(a, -2, -1);
6975
+ lower = !lower;
6976
+ }
6970
6977
  let x = triangularSolve$1(a, b, {
6971
6978
  lower,
6972
6979
  unitDiagonal
@@ -6988,7 +6995,8 @@ __export(lax_exports, {
6988
6995
  erfc: () => erfc,
6989
6996
  linalg: () => lax_linalg_exports,
6990
6997
  reduceWindow: () => reduceWindow,
6991
- stopGradient: () => stopGradient$1
6998
+ stopGradient: () => stopGradient$1,
6999
+ topK: () => topK
6992
7000
  });
6993
7001
  const JsArray = globalThis.Array;
6994
7002
  /**
@@ -7212,6 +7220,39 @@ function erfc(x) {
7212
7220
  function stopGradient$1(x) {
7213
7221
  return stopGradient(x);
7214
7222
  }
7223
+ /**
7224
+ * Returns top `k` values and their indices along the specified axis of operand.
7225
+ *
7226
+ * This is a _stable_ algorithm: If two elements are equal, the lower-index
7227
+ * element appears first.
7228
+ *
7229
+ * @returns A tuple of `(values, indices)`, where `values` and `indices` have
7230
+ * the same shape as `x`, except along `axis` where they have size `k`.
7231
+ */
7232
+ function topK(x, k, axis = -1) {
7233
+ x = fudgeArray(x);
7234
+ axis = checkAxis(axis, x.ndim);
7235
+ const size$1 = x.shape[axis];
7236
+ if (k < 0 || k > size$1) throw new Error(`topK: k must be in the range [0, ${size$1}], got ${k}`);
7237
+ if (k === 0) {
7238
+ const outShape = x.shape.slice();
7239
+ outShape[axis] = 0;
7240
+ const y$1 = zerosLike$1(x.ref, { shape: outShape });
7241
+ const yi$1 = zerosLike$1(x, {
7242
+ dtype: DType.Int32,
7243
+ shape: outShape
7244
+ });
7245
+ return [y$1, yi$1];
7246
+ }
7247
+ x = flip$1(x, [axis]);
7248
+ x = moveaxis(x, axis, -1);
7249
+ const [y, yi] = argsort$1(x);
7250
+ const extract = (a) => {
7251
+ a = a.slice(...rep(a.ndim - 1, []), [-k]);
7252
+ return flip$1(moveaxis(a, -1, axis), [axis]);
7253
+ };
7254
+ return [extract(y), extract(yi.neg().add(size$1 - 1))];
7255
+ }
7215
7256
 
7216
7257
  //#endregion
7217
7258
  //#region src/library/nn.ts
@@ -7403,7 +7444,7 @@ const gelu = jit$1(function gelu$1(x, opts) {
7403
7444
  if (opts?.approximate ?? true) {
7404
7445
  const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
7405
7446
  return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
7406
- } else return x.ref.mul(.5).mul(erfc$1(negative(x.ref.mul(Math.SQRT1_2))));
7447
+ } else return x.ref.mul(.5).mul(erfc$1(negative(x.mul(Math.SQRT1_2))));
7407
7448
  }, { staticArgnums: [1] });
7408
7449
  /**
7409
7450
  * Gated linear unit (GLU) activation function.
@@ -7661,6 +7702,7 @@ var random_exports = {};
7661
7702
  __export(random_exports, {
7662
7703
  bernoulli: () => bernoulli,
7663
7704
  bits: () => bits,
7705
+ categorical: () => categorical,
7664
7706
  cauchy: () => cauchy,
7665
7707
  exponential: () => exponential,
7666
7708
  gumbel: () => gumbel,
@@ -7732,6 +7774,47 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
7732
7774
  }
7733
7775
  /**
7734
7776
  * @function
7777
+ * Sample random values from categorical distributions.
7778
+ *
7779
+ * Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
7780
+ * trick for sampling without replacement.
7781
+ *
7782
+ * Note: Sampling without replacement currently uses argsort and slices the last
7783
+ * k elements. This should be replaced with a more efficient topK implementation.
7784
+ *
7785
+ * - `key` - PRNG key
7786
+ * - `logits` - Unnormalized log probabilities of the categorical distribution(s).
7787
+ * `softmax(logits, axis)` gives the corresponding probabilities.
7788
+ * - `axis` - Axis along which logits belong to the same categorical distribution.
7789
+ * - `shape` - Result batch shape. Must be broadcast-compatible with
7790
+ * `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
7791
+ * - `replace` - If true (default), sample with replacement. If false, sample
7792
+ * without replacement (each category can only be selected once per batch).
7793
+ * @returns A random array with int dtype and shape given by `shape` if provided,
7794
+ * otherwise `logits.shape` with `axis` removed.
7795
+ */
7796
+ const categorical = jit$1(function categorical$1(key$1, logits, { axis = -1, shape: shape$1, replace = true } = {}) {
7797
+ logits = fudgeArray(logits);
7798
+ axis = checkAxis(axis, logits.ndim);
7799
+ const numCategories = logits.shape[axis];
7800
+ const batchShape = logits.shape.toSpliced(axis, 1);
7801
+ if (shape$1 === void 0) shape$1 = batchShape;
7802
+ else if (!deepEqual(generalBroadcast(shape$1, batchShape), shape$1)) throw new Error(`Shape ${shape$1} is not broadcast-compatible with batch shape ${batchShape}.`);
7803
+ const shapePrefix = shape$1.slice(0, shape$1.length - batchShape.length);
7804
+ if (replace) {
7805
+ const noise = gumbel(key$1, [...shapePrefix, ...logits.shape]);
7806
+ return argmax(noise.add(logits), axis + shapePrefix.length);
7807
+ } else {
7808
+ const k = shapePrefix.reduce((a, b) => a * b, 1);
7809
+ if (k > numCategories) throw new Error(`Number of samples without replacement (${k}) cannot exceed number of categories (${numCategories}).`);
7810
+ const noise = gumbel(key$1, logits.shape);
7811
+ const [values, indices] = topK(noise.add(logits), k, axis);
7812
+ values.dispose();
7813
+ return indices.reshape(shape$1);
7814
+ }
7815
+ }, { staticArgnums: [2] });
7816
+ /**
7817
+ * @function
7735
7818
  * Sample from a Cauchy distribution with location 0 and scale 1.
7736
7819
  *
7737
7820
  * Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-B3foXiV_.cjs');
1
+ const require_backend = require('./backend-DpI0riom.cjs');
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -1,4 +1,4 @@
1
- import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-nEolvdLv.js";
1
+ import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-BId79r5b.js";
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -1,4 +1,4 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-nEolvdLv.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-BId79r5b.js";
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -247,6 +247,10 @@ function bitonicSortUniform(pass) {
247
247
  * `2^(step+1)` with multiple workgroups. This doesn't use shared memory.
248
248
  *
249
249
  * The total number of passes is roughly `log2(n / workgroupSize)^2 / 2`.
250
+ *
251
+ * If `outputIndices` is true, the shader also tracks the original indices of
252
+ * the sorted elements (argsort) and outputs them to a separate buffer. This
253
+ * also makes the sorting algorithm stable.
250
254
  */
251
255
  function bitonicSortShader(device, dtype, n, batches, outputIndices) {
252
256
  const ty = dtypeToWgsl(dtype, true);
@@ -286,14 +290,21 @@ ${isFloatDtype(dtype) ? `
286
290
  fn compare_and_swap(i: u32, j: u32) {
287
291
  let val_i = shared_vals[i];
288
292
  let val_j = shared_vals[j];
289
- if (compare(val_j, val_i)) {
293
+ ${outputIndices ? `
294
+ if (
295
+ compare(val_j, val_i) ||
296
+ (!compare(val_i, val_j) && shared_idx[j] < shared_idx[i])
297
+ ) {
290
298
  shared_vals[i] = val_j;
291
299
  shared_vals[j] = val_i;
292
- ${outputIndices ? `
293
300
  let tmp_idx = shared_idx[i];
294
301
  shared_idx[i] = shared_idx[j];
295
- shared_idx[j] = tmp_idx;` : ""}
296
- }
302
+ shared_idx[j] = tmp_idx;
303
+ }` : `
304
+ if (compare(val_j, val_i)) {
305
+ shared_vals[i] = val_j;
306
+ shared_vals[j] = val_i;
307
+ }`}
297
308
  }
298
309
 
299
310
  @compute @workgroup_size(${workgroupSize})
@@ -370,13 +381,17 @@ ${outputIndices ? `
370
381
  if (j < ${n}u) {
371
382
  let val_i = output[base + i];
372
383
  let val_j = output[base + j];
373
- if (compare(val_j, val_i)) {
384
+ ${outputIndices ? `
385
+ let idx_i = output_idx[base + i];
386
+ let idx_j = output_idx[base + j];
387
+ if (compare(val_j, val_i) || (!compare(val_i, val_j) && idx_j < idx_i)) {
374
388
  output[base + i] = val_j;
375
389
  output[base + j] = val_i;
376
- ${outputIndices ? `
377
- let tmp_idx = output_idx[base + i];
378
- output_idx[base + i] = output_idx[base + j];
379
- output_idx[base + j] = tmp_idx;` : ""}
390
+ output_idx[base + i] = idx_j;
391
+ output_idx[base + j] = idx_i;` : `
392
+ if (compare(val_j, val_i)) {
393
+ output[base + i] = val_j;
394
+ output[base + j] = val_i;`}
380
395
  }
381
396
  }
382
397
  }
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-B3foXiV_.cjs');
1
+ const require_backend = require('./backend-DpI0riom.cjs');
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -247,6 +247,10 @@ function bitonicSortUniform(pass) {
247
247
  * `2^(step+1)` with multiple workgroups. This doesn't use shared memory.
248
248
  *
249
249
  * The total number of passes is roughly `log2(n / workgroupSize)^2 / 2`.
250
+ *
251
+ * If `outputIndices` is true, the shader also tracks the original indices of
252
+ * the sorted elements (argsort) and outputs them to a separate buffer. This
253
+ * also makes the sorting algorithm stable.
250
254
  */
251
255
  function bitonicSortShader(device, dtype, n, batches, outputIndices) {
252
256
  const ty = dtypeToWgsl(dtype, true);
@@ -286,14 +290,21 @@ ${require_backend.isFloatDtype(dtype) ? `
286
290
  fn compare_and_swap(i: u32, j: u32) {
287
291
  let val_i = shared_vals[i];
288
292
  let val_j = shared_vals[j];
289
- if (compare(val_j, val_i)) {
293
+ ${outputIndices ? `
294
+ if (
295
+ compare(val_j, val_i) ||
296
+ (!compare(val_i, val_j) && shared_idx[j] < shared_idx[i])
297
+ ) {
290
298
  shared_vals[i] = val_j;
291
299
  shared_vals[j] = val_i;
292
- ${outputIndices ? `
293
300
  let tmp_idx = shared_idx[i];
294
301
  shared_idx[i] = shared_idx[j];
295
- shared_idx[j] = tmp_idx;` : ""}
296
- }
302
+ shared_idx[j] = tmp_idx;
303
+ }` : `
304
+ if (compare(val_j, val_i)) {
305
+ shared_vals[i] = val_j;
306
+ shared_vals[j] = val_i;
307
+ }`}
297
308
  }
298
309
 
299
310
  @compute @workgroup_size(${workgroupSize})
@@ -370,13 +381,17 @@ ${outputIndices ? `
370
381
  if (j < ${n}u) {
371
382
  let val_i = output[base + i];
372
383
  let val_j = output[base + j];
373
- if (compare(val_j, val_i)) {
384
+ ${outputIndices ? `
385
+ let idx_i = output_idx[base + i];
386
+ let idx_j = output_idx[base + j];
387
+ if (compare(val_j, val_i) || (!compare(val_i, val_j) && idx_j < idx_i)) {
374
388
  output[base + i] = val_j;
375
389
  output[base + j] = val_i;
376
- ${outputIndices ? `
377
- let tmp_idx = output_idx[base + i];
378
- output_idx[base + i] = output_idx[base + j];
379
- output_idx[base + j] = tmp_idx;` : ""}
390
+ output_idx[base + i] = idx_j;
391
+ output_idx[base + j] = idx_i;` : `
392
+ if (compare(val_j, val_i)) {
393
+ output[base + i] = val_j;
394
+ output[base + j] = val_i;`}
380
395
  }
381
396
  }
382
397
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.1.7",
3
+ "version": "0.1.9",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",
@@ -44,6 +44,8 @@
44
44
  "eslint": "^9.31.0",
45
45
  "eslint-plugin-import": "^2.32.0",
46
46
  "globals": "^16.0.0",
47
+ "husky": "^9.1.7",
48
+ "lint-staged": "^16.2.7",
47
49
  "playwright": "~1.52.0",
48
50
  "prettier": "^3.6.2",
49
51
  "prettier-plugin-svelte": "^3.4.0",
@@ -74,6 +76,15 @@
74
76
  ],
75
77
  "proseWrap": "always"
76
78
  },
79
+ "lint-staged": {
80
+ "*.{ts,tsx,js,jsx}": [
81
+ "eslint --fix",
82
+ "prettier --write"
83
+ ],
84
+ "*.{json,md,yml,yaml,css,svelte,html}": [
85
+ "prettier --write"
86
+ ]
87
+ },
77
88
  "scripts": {
78
89
  "build": "tsdown",
79
90
  "build:watch": "TSDOWN_WATCH_MODE=1 tsdown",