@jax-js/jax 0.1.8 → 0.1.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -43,6 +43,23 @@ way to get started on a blank HTML page.
43
43
  </script>
44
44
  ```
45
45
 
46
+ ## Examples
47
+
48
+ Cool things that the community has made with jax-js:
49
+
50
+ - [**tanh.xyz**: Interactive ML visualizations](https://tanh.xyz/)
51
+
52
+ And here are some more demos from the official website:
53
+
54
+ - [Training neural networks on MNIST](https://jax-js.com/mnist)
55
+ - [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
56
+ - [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
57
+ - [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
58
+ - [In-browser REPL](https://jax-js.com/repl)
59
+ - [Matmul benchmark](https://jax-js.com/bench/matmul)
60
+ - [Conv2d benchmark](https://jax-js.com/bench/conv2d)
61
+ - [Mandelbrot set](https://jax-js.com/mandelbrot)
62
+
46
63
  ## Feature comparison
47
64
 
48
65
  Here's a quick, high-level comparison with other popular web ML runtimes:
@@ -338,19 +355,6 @@ well as unique optimizations such as FlashAttention variants.
338
355
  That's all for this short tutorial. Please see the generated
339
356
  [API reference](https://jax-js.com/docs) for detailed documentation.
340
357
 
341
- ## Examples
342
-
343
- If you make something cool with jax-js, don't be a stranger! We can feature it here.
344
-
345
- - [Training neural networks on MNIST](https://jax-js.com/mnist)
346
- - [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
347
- - [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
348
- - [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
349
- - [In-browser REPL](https://jax-js.com/repl)
350
- - [Matmul benchmark](https://jax-js.com/bench/matmul)
351
- - [Conv2d benchmark](https://jax-js.com/bench/conv2d)
352
- - [Mandelbrot set](https://jax-js.com/mandelbrot)
353
-
354
358
  ## Development
355
359
 
356
360
  _The following technical details are for contributing to jax-js and modifying its internals._
@@ -1479,9 +1479,14 @@ var Routine = class {
1479
1479
  };
1480
1480
  /** One of the valid `Routine` that can be dispatched to backend. */
1481
1481
  let Routines = /* @__PURE__ */ function(Routines$1) {
1482
- /** Stable sorting algorithm along the last axis. */
1482
+ /**
1483
+ * Sort along the last axis.
1484
+ *
1485
+ * This may be _unstable_, but that rarely matters: for numbers, equal values are
1486
+ * bitwise identical except for signed zeros and NaNs.
1487
+ */
1483
1488
  Routines$1["Sort"] = "Sort";
1484
- /** Returns `int32` indices of the stably sorted array. */
1489
+ /** Stable sorting, returns `int32` indices and values of the sorted array. */
1485
1490
  Routines$1["Argsort"] = "Argsort";
1486
1491
  /**
1487
1492
  * Solve a triangular system of equations.
@@ -1545,7 +1550,13 @@ function runArgsort(type, [x], [y, yi]) {
1545
1550
  const out = y.subarray(offset, offset + n);
1546
1551
  const outi = yi.subarray(offset, offset + n);
1547
1552
  for (let i = 0; i < n; i++) outi[i] = i;
1548
- outi.sort((a, b) => ar[a] - ar[b]);
1553
+ outi.sort((a, b) => {
1554
+ const x$1 = ar[a];
1555
+ const y$1 = ar[b];
1556
+ if (isNaN(x$1)) return isNaN(y$1) ? 0 : 1;
1557
+ if (isNaN(y$1)) return -1;
1558
+ return x$1 === y$1 ? 0 : x$1 < y$1 ? -1 : 1;
1559
+ });
1549
1560
  for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
1550
1561
  }
1551
1562
  }
@@ -2321,7 +2332,7 @@ function tuneWebgpu(kernel) {
2321
2332
  if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
2322
2333
  const s = dim.st.shape[dim.unroll - 1];
2323
2334
  if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
2324
- else for (const splits of [8, 4]) if (s % splits === 0) {
2335
+ else for (const splits of [4, 2]) if (s % splits === 0) {
2325
2336
  dim.applyUnroll(dim.unroll - 1, splits);
2326
2337
  break;
2327
2338
  }
@@ -4252,7 +4263,7 @@ async function createBackend(device) {
4252
4263
  if (!navigator.gpu) return null;
4253
4264
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4254
4265
  if (!adapter) return null;
4255
- const { WebGPUBackend } = await import("./webgpu-B96vzWGE.js");
4266
+ const { WebGPUBackend } = await import("./webgpu-AN0cG_nB.js");
4256
4267
  const importantLimits = [
4257
4268
  "maxBufferSize",
4258
4269
  "maxComputeInvocationsPerWorkgroup",
@@ -4290,7 +4301,7 @@ async function createBackend(device) {
4290
4301
  });
4291
4302
  if (!gl) return null;
4292
4303
  if (!gl.getExtension("EXT_color_buffer_float")) return null;
4293
- const { WebGLBackend } = await import("./webgl-DweKSWEm.js");
4304
+ const { WebGLBackend } = await import("./webgl-DnGrclTz.js");
4294
4305
  return new WebGLBackend(gl);
4295
4306
  } else throw new Error(`Backend not found: ${device}`);
4296
4307
  }
@@ -1480,9 +1480,14 @@ var Routine = class {
1480
1480
  };
1481
1481
  /** One of the valid `Routine` that can be dispatched to backend. */
1482
1482
  let Routines = /* @__PURE__ */ function(Routines$1) {
1483
- /** Stable sorting algorithm along the last axis. */
1483
+ /**
1484
+ * Sort along the last axis.
1485
+ *
1486
+ * This may be _unstable_, but that rarely matters: for numbers, equal values are
1487
+ * bitwise identical except for signed zeros and NaNs.
1488
+ */
1484
1489
  Routines$1["Sort"] = "Sort";
1485
- /** Returns `int32` indices of the stably sorted array. */
1490
+ /** Stable sorting, returns `int32` indices and values of the sorted array. */
1486
1491
  Routines$1["Argsort"] = "Argsort";
1487
1492
  /**
1488
1493
  * Solve a triangular system of equations.
@@ -1546,7 +1551,13 @@ function runArgsort(type, [x], [y, yi]) {
1546
1551
  const out = y.subarray(offset, offset + n);
1547
1552
  const outi = yi.subarray(offset, offset + n);
1548
1553
  for (let i = 0; i < n; i++) outi[i] = i;
1549
- outi.sort((a, b) => ar[a] - ar[b]);
1554
+ outi.sort((a, b) => {
1555
+ const x$1 = ar[a];
1556
+ const y$1 = ar[b];
1557
+ if (isNaN(x$1)) return isNaN(y$1) ? 0 : 1;
1558
+ if (isNaN(y$1)) return -1;
1559
+ return x$1 === y$1 ? 0 : x$1 < y$1 ? -1 : 1;
1560
+ });
1550
1561
  for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
1551
1562
  }
1552
1563
  }
@@ -2322,7 +2333,7 @@ function tuneWebgpu(kernel) {
2322
2333
  if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
2323
2334
  const s = dim.st.shape[dim.unroll - 1];
2324
2335
  if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
2325
- else for (const splits of [8, 4]) if (s % splits === 0) {
2336
+ else for (const splits of [4, 2]) if (s % splits === 0) {
2326
2337
  dim.applyUnroll(dim.unroll - 1, splits);
2327
2338
  break;
2328
2339
  }
@@ -4253,7 +4264,7 @@ async function createBackend(device) {
4253
4264
  if (!navigator.gpu) return null;
4254
4265
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4255
4266
  if (!adapter) return null;
4256
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BykvF26B.cjs"));
4267
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-CdjiJSa7.cjs"));
4257
4268
  const importantLimits = [
4258
4269
  "maxBufferSize",
4259
4270
  "maxComputeInvocationsPerWorkgroup",
@@ -4291,7 +4302,7 @@ async function createBackend(device) {
4291
4302
  });
4292
4303
  if (!gl) return null;
4293
4304
  if (!gl.getExtension("EXT_color_buffer_float")) return null;
4294
- const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-DIIbKJ0G.cjs"));
4305
+ const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-C5NjXc1p.cjs"));
4295
4306
  return new WebGLBackend(gl);
4296
4307
  } else throw new Error(`Backend not found: ${device}`);
4297
4308
  }
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
30
30
  }) : target, mod$1));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-B3foXiV_.cjs');
33
+ const require_backend = require('./backend-DpI0riom.cjs');
34
34
 
35
35
  //#region src/frontend/convolution.ts
36
36
  /**
@@ -920,18 +920,25 @@ var Tracer = class Tracer {
920
920
  return sort$1(this.transpose(perm)).transpose(require_backend.invertPermutation(perm));
921
921
  }
922
922
  /**
923
- * Return the indices that would sort an array. This may not be a stable
924
- * sorting algorithm; it need not preserve order of indices in ties.
923
+ * Return the indices that would sort an array. Unlike `sort`, this is
924
+ * guaranteed to be a stable sorting algorithm; it always returns the smaller
925
+ * index first in event of ties.
925
926
  *
926
927
  * See `jax.numpy.argsort` for full docs.
927
928
  */
928
929
  argsort(axis = -1) {
929
930
  axis = require_backend.checkAxis(axis, this.ndim);
930
- if (axis === this.ndim - 1) return argsort$1(this)[1];
931
+ if (axis === this.ndim - 1) {
932
+ const [y$1, yi$1] = argsort$1(this);
933
+ y$1.dispose();
934
+ return yi$1;
935
+ }
931
936
  const perm = require_backend.range(this.ndim);
932
937
  perm.splice(axis, 1);
933
938
  perm.push(axis);
934
- return argsort$1(this.transpose(perm))[1].transpose(require_backend.invertPermutation(perm));
939
+ const [y, yi] = argsort$1(this.transpose(perm));
940
+ y.dispose();
941
+ return yi.transpose(require_backend.invertPermutation(perm));
935
942
  }
936
943
  /**
937
944
  * Slice an array along one or more axes.
@@ -3416,32 +3423,26 @@ function fullInternal(aval, fillValue, device) {
3416
3423
  committed: device != void 0
3417
3424
  });
3418
3425
  }
3419
- function zerosLike$1(val, dtype) {
3420
- return fullLike(val, 0, dtype);
3426
+ function zerosLike$1(val, opts) {
3427
+ return fullLike(val, 0, opts);
3421
3428
  }
3422
- function onesLike$1(val, dtype) {
3423
- return fullLike(val, 1, dtype);
3429
+ function onesLike$1(val, opts) {
3430
+ return fullLike(val, 1, opts);
3424
3431
  }
3425
- function fullLike(val, fillValue, dtype) {
3432
+ function fullLike(val, fillValue, { dtype, shape: shape$1, device } = {}) {
3426
3433
  const aval = getAval(val);
3427
3434
  if (val instanceof Tracer) val.dispose();
3428
3435
  if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
3429
- const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
3430
- return fullInternal(sa, fillValue);
3436
+ const sa = new ShapedArray(shape$1 ?? aval.shape, dtype ?? aval.dtype, aval.weakType && dtype === void 0);
3437
+ return fullInternal(sa, fillValue, device);
3431
3438
  }
3432
3439
  /** Return a new array of given shape and type, filled with zeros. */
3433
- function zeros(shape$1, { dtype, device } = {}) {
3434
- return full(shape$1, 0, {
3435
- dtype,
3436
- device
3437
- });
3440
+ function zeros(shape$1, opts) {
3441
+ return full(shape$1, 0, opts);
3438
3442
  }
3439
3443
  /** Return a new array of given shape and type, filled with ones. */
3440
- function ones(shape$1, { dtype, device } = {}) {
3441
- return full(shape$1, 1, {
3442
- dtype,
3443
- device
3444
- });
3444
+ function ones(shape$1, opts) {
3445
+ return full(shape$1, 1, opts);
3445
3446
  }
3446
3447
  /** Return a new array of given shape and type, filled with `fill_value`. */
3447
3448
  function full(shape$1, fillValue, { dtype, device } = {}) {
@@ -5332,7 +5333,7 @@ function lstsq(a, b) {
5332
5333
  lower: true,
5333
5334
  transposeA: true
5334
5335
  });
5335
- return matmul(at, llb.ref);
5336
+ return matmul(at, llb);
5336
5337
  } else {
5337
5338
  const ata = matmul(at.ref, a);
5338
5339
  const l = cholesky(ata, { symmetrizeInput: false });
@@ -5423,7 +5424,7 @@ function solve(a, b) {
5423
5424
  lower: true,
5424
5425
  unitDiagonal: true
5425
5426
  });
5426
- let x = triangularSolve(lu$2, LPb.ref, {
5427
+ let x = triangularSolve(lu$2, LPb, {
5427
5428
  leftSide: true,
5428
5429
  lower: false
5429
5430
  });
@@ -6234,8 +6235,9 @@ function sort(a, axis = -1) {
6234
6235
  return fudgeArray(a).sort(axis);
6235
6236
  }
6236
6237
  /**
6237
- * Return indices that would sort an array. This may be an unstable sorting
6238
- * algorithm; it need not preserve order of indices in ties.
6238
+ * Return indices that would sort an array. Unlike `sort`, this is guaranteed to
6239
+ * be a stable sorting algorithm; it always returns the smaller index first in
6240
+ * event of ties.
6239
6241
  *
6240
6242
  * Returns an array of `int32` indices.
6241
6243
  *
@@ -6537,7 +6539,7 @@ function absolute(x) {
6537
6539
  /** Return an element-wise indication of sign of the input. */
6538
6540
  function sign(x) {
6539
6541
  x = fudgeArray(x);
6540
- return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
6542
+ return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
6541
6543
  }
6542
6544
  /** @function Return element-wise positive values of the input (no-op). */
6543
6545
  const positive = fudgeArray;
@@ -7030,7 +7032,8 @@ __export(lax_exports, {
7030
7032
  erfc: () => erfc,
7031
7033
  linalg: () => lax_linalg_exports,
7032
7034
  reduceWindow: () => reduceWindow,
7033
- stopGradient: () => stopGradient$1
7035
+ stopGradient: () => stopGradient$1,
7036
+ topK: () => topK
7034
7037
  });
7035
7038
  const JsArray = globalThis.Array;
7036
7039
  /**
@@ -7254,6 +7257,39 @@ function erfc(x) {
7254
7257
  function stopGradient$1(x) {
7255
7258
  return stopGradient(x);
7256
7259
  }
7260
+ /**
7261
+ * Returns top `k` values and their indices along the specified axis of operand.
7262
+ *
7263
+ * This is a _stable_ algorithm: If two elements are equal, the lower-index
7264
+ * element appears first.
7265
+ *
7266
+ * @returns A tuple of `(values, indices)`, where `values` and `indices` have
7267
+ * the same shape as `x`, except along `axis` where they have size `k`.
7268
+ */
7269
+ function topK(x, k, axis = -1) {
7270
+ x = fudgeArray(x);
7271
+ axis = require_backend.checkAxis(axis, x.ndim);
7272
+ const size$1 = x.shape[axis];
7273
+ if (k < 0 || k > size$1) throw new Error(`topK: k must be in the range [0, ${size$1}], got ${k}`);
7274
+ if (k === 0) {
7275
+ const outShape = x.shape.slice();
7276
+ outShape[axis] = 0;
7277
+ const y$1 = zerosLike$1(x.ref, { shape: outShape });
7278
+ const yi$1 = zerosLike$1(x, {
7279
+ dtype: require_backend.DType.Int32,
7280
+ shape: outShape
7281
+ });
7282
+ return [y$1, yi$1];
7283
+ }
7284
+ x = flip$1(x, [axis]);
7285
+ x = moveaxis(x, axis, -1);
7286
+ const [y, yi] = argsort$1(x);
7287
+ const extract = (a) => {
7288
+ a = a.slice(...require_backend.rep(a.ndim - 1, []), [-k]);
7289
+ return flip$1(moveaxis(a, -1, axis), [axis]);
7290
+ };
7291
+ return [extract(y), extract(yi.neg().add(size$1 - 1))];
7292
+ }
7257
7293
 
7258
7294
  //#endregion
7259
7295
  //#region src/library/nn.ts
@@ -7445,7 +7481,7 @@ const gelu = jit$1(function gelu$1(x, opts) {
7445
7481
  if (opts?.approximate ?? true) {
7446
7482
  const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
7447
7483
  return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
7448
- } else return x.ref.mul(.5).mul(erfc$1(negative(x.ref.mul(Math.SQRT1_2))));
7484
+ } else return x.ref.mul(.5).mul(erfc$1(negative(x.mul(Math.SQRT1_2))));
7449
7485
  }, { staticArgnums: [1] });
7450
7486
  /**
7451
7487
  * Gated linear unit (GLU) activation function.
@@ -7703,6 +7739,7 @@ var random_exports = {};
7703
7739
  __export(random_exports, {
7704
7740
  bernoulli: () => bernoulli,
7705
7741
  bits: () => bits,
7742
+ categorical: () => categorical,
7706
7743
  cauchy: () => cauchy,
7707
7744
  exponential: () => exponential,
7708
7745
  gumbel: () => gumbel,
@@ -7774,6 +7811,47 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
7774
7811
  }
7775
7812
  /**
7776
7813
  * @function
7814
+ * Sample random values from categorical distributions.
7815
+ *
7816
+ * Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
7817
+ * trick for sampling without replacement.
7818
+ *
7819
+ * Note: Sampling without replacement uses `lax.topK`, which currently performs a
7820
+ * full argsort and slices the last k elements; a dedicated top-k kernel would be faster.
7821
+ *
7822
+ * - `key` - PRNG key
7823
+ * - `logits` - Unnormalized log probabilities of the categorical distribution(s).
7824
+ * `softmax(logits, axis)` gives the corresponding probabilities.
7825
+ * - `axis` - Axis along which logits belong to the same categorical distribution.
7826
+ * - `shape` - Result batch shape. Must be broadcast-compatible with
7827
+ * `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
7828
+ * - `replace` - If true (default), sample with replacement. If false, sample
7829
+ * without replacement (each category can only be selected once per batch).
7830
+ * @returns A random array with int dtype and shape given by `shape` if provided,
7831
+ * otherwise `logits.shape` with `axis` removed.
7832
+ */
7833
+ const categorical = jit$1(function categorical$1(key$1, logits, { axis = -1, shape: shape$1, replace = true } = {}) {
7834
+ logits = fudgeArray(logits);
7835
+ axis = require_backend.checkAxis(axis, logits.ndim);
7836
+ const numCategories = logits.shape[axis];
7837
+ const batchShape = logits.shape.toSpliced(axis, 1);
7838
+ if (shape$1 === void 0) shape$1 = batchShape;
7839
+ else if (!require_backend.deepEqual(require_backend.generalBroadcast(shape$1, batchShape), shape$1)) throw new Error(`Shape ${shape$1} is not broadcast-compatible with batch shape ${batchShape}.`);
7840
+ const shapePrefix = shape$1.slice(0, shape$1.length - batchShape.length);
7841
+ if (replace) {
7842
+ const noise = gumbel(key$1, [...shapePrefix, ...logits.shape]);
7843
+ return argmax(noise.add(logits), axis + shapePrefix.length);
7844
+ } else {
7845
+ const k = shapePrefix.reduce((a, b) => a * b, 1);
7846
+ if (k > numCategories) throw new Error(`Number of samples without replacement (${k}) cannot exceed number of categories (${numCategories}).`);
7847
+ const noise = gumbel(key$1, logits.shape);
7848
+ const [values, indices] = topK(noise.add(logits), k, axis);
7849
+ values.dispose();
7850
+ return indices.reshape(shape$1);
7851
+ }
7852
+ }, { staticArgnums: [2] });
7853
+ /**
7854
+ * @function
7777
7855
  * Sample from a Cauchy distribution with location 0 and scale 1.
7778
7856
  *
7779
7857
  * Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
package/dist/index.d.cts CHANGED
@@ -436,9 +436,14 @@ declare class Routine {
436
436
  }
437
437
  /** One of the valid `Routine` that can be dispatched to backend. */
438
438
  declare enum Routines {
439
- /** Stable sorting algorithm along the last axis. */
439
+ /**
440
+ * Sort along the last axis.
441
+ *
442
+ * This may be _unstable_, but that rarely matters: for numbers, equal values are
443
+ * bitwise identical except for signed zeros and NaNs.
444
+ */
440
445
  Sort = "Sort",
441
- /** Returns `int32` indices of the stably sorted array. */
446
+ /** Stable sorting, returns `int32` indices and values of the sorted array. */
442
447
  Argsort = "Argsort",
443
448
  /**
444
449
  * Solve a triangular system of equations.
@@ -750,9 +755,9 @@ declare enum Primitive {
750
755
  Shrink = "shrink",
751
756
  Pad = "pad",
752
757
  Sort = "sort",
753
- // sort(x, axis=-1)
758
+ // sort(x, axis=-1), unstable
754
759
  Argsort = "argsort",
755
- // argsort(x, axis=-1)
760
+ // argsort(x, axis=-1), stable
756
761
  TriangularSolve = "triangular_solve",
757
762
  // A is upper triangular, A @ X.T = B.T
758
763
  Cholesky = "cholesky",
@@ -1029,8 +1034,9 @@ declare abstract class Tracer {
1029
1034
  */
1030
1035
  sort(axis?: number): this;
1031
1036
  /**
1032
- * Return the indices that would sort an array. This may not be a stable
1033
- * sorting algorithm; it need not preserve order of indices in ties.
1037
+ * Return the indices that would sort an array. Unlike `sort`, this is
1038
+ * guaranteed to be a stable sorting algorithm; it always returns the smaller
1039
+ * index first in event of ties.
1034
1040
  *
1035
1041
  * See `jax.numpy.argsort` for full docs.
1036
1042
  */
@@ -1112,6 +1118,12 @@ type DTypeAndDevice = {
1112
1118
  dtype?: DType;
1113
1119
  device?: Device;
1114
1120
  };
1121
+ /** @inline */
1122
+ type DTypeShapeAndDevice = {
1123
+ dtype?: DType;
1124
+ shape?: number[];
1125
+ device?: Device;
1126
+ };
1115
1127
  type ArrayConstructorArgs = {
1116
1128
  source: AluExp | Slot;
1117
1129
  st: ShapeTracker;
@@ -1221,15 +1233,9 @@ declare function array(values: Array | DataArray | RecursiveArray<number> | Recu
1221
1233
  type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
1222
1234
  declare const implRules: { [P in Primitive]: ImplRule<P> };
1223
1235
  /** Return a new array of given shape and type, filled with zeros. */
1224
- declare function zeros(shape: number[], {
1225
- dtype,
1226
- device
1227
- }?: DTypeAndDevice): Array;
1236
+ declare function zeros(shape: number[], opts?: DTypeAndDevice): Array;
1228
1237
  /** Return a new array of given shape and type, filled with ones. */
1229
- declare function ones(shape: number[], {
1230
- dtype,
1231
- device
1232
- }?: DTypeAndDevice): Array;
1238
+ declare function ones(shape: number[], opts?: DTypeAndDevice): Array;
1233
1239
  /** Return a new array of given shape and type, filled with `fill_value`. */
1234
1240
  declare function full(shape: number[], fillValue: number | boolean | Array, {
1235
1241
  dtype,
@@ -1421,7 +1427,7 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
1421
1427
  unitDiagonal?: boolean;
1422
1428
  }): Array;
1423
1429
  declare namespace lax_d_exports {
1424
- export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
1430
+ export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
1425
1431
  }
1426
1432
  /**
1427
1433
  * Dimension numbers for general `dot()` primitive.
@@ -1527,6 +1533,16 @@ declare function erfc(x: ArrayLike): Array;
1527
1533
  * forward or reverse-mode automatic differentiation.
1528
1534
  */
1529
1535
  declare function stopGradient(x: ArrayLike): Array;
1536
+ /**
1537
+ * Returns top `k` values and their indices along the specified axis of operand.
1538
+ *
1539
+ * This is a _stable_ algorithm: If two elements are equal, the lower-index
1540
+ * element appears first.
1541
+ *
1542
+ * @returns A tuple of `(values, indices)`, where `values` and `indices` have
1543
+ * the same shape as `x`, except along `axis` where they have size `k`.
1544
+ */
1545
+ declare function topK(x: ArrayLike, k: number, axis?: number): [Array, Array];
1530
1546
  declare namespace numpy_fft_d_exports {
1531
1547
  export { ComplexPair, fft, ifft };
1532
1548
  }
@@ -1752,17 +1768,17 @@ declare const shape$1: (x: ArrayLike) => number[];
1752
1768
  * @function
1753
1769
  * Return an array of zeros with the same shape and type as a given array.
1754
1770
  */
1755
- declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
1771
+ declare const zerosLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
1756
1772
  /**
1757
1773
  * @function
1758
1774
  * Return an array of ones with the same shape and type as a given array.
1759
1775
  */
1760
- declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
1776
+ declare const onesLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
1761
1777
  /**
1762
1778
  * @function
1763
1779
  * Return a full array with the same shape and type as a given array.
1764
1780
  */
1765
- declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
1781
+ declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, opts?: DTypeShapeAndDevice) => Array;
1766
1782
  /**
1767
1783
  * Return the number of elements in an array, optionally along an axis.
1768
1784
  * Does not consume array reference.
@@ -1951,8 +1967,9 @@ declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: nu
1951
1967
  */
1952
1968
  declare function sort(a: ArrayLike, axis?: number): Array;
1953
1969
  /**
1954
- * Return indices that would sort an array. This may be an unstable sorting
1955
- * algorithm; it need not preserve order of indices in ties.
1970
+ * Return indices that would sort an array. Unlike `sort`, this is guaranteed to
1971
+ * be a stable sorting algorithm; it always returns the smaller index first in
1972
+ * event of ties.
1956
1973
  *
1957
1974
  * Returns an array of `int32` indices.
1958
1975
  *
@@ -2564,7 +2581,7 @@ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: Ar
2564
2581
  localWindowSize?: number | [number, number];
2565
2582
  }): Array;
2566
2583
  declare namespace random_d_exports {
2567
- export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2584
+ export { bernoulli, bits, categorical, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2568
2585
  }
2569
2586
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
2570
2587
  declare function key(seed: ArrayLike): Array;
@@ -2587,6 +2604,32 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
2587
2604
  * and must be broadcastable to `shape`.
2588
2605
  */
2589
2606
  declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
2607
+ /**
2608
+ * @function
2609
+ * Sample random values from categorical distributions.
2610
+ *
2611
+ * Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
2612
+ * trick for sampling without replacement.
2613
+ *
2614
+ * Note: Sampling without replacement uses `lax.topK`, which currently performs a
2615
+ * full argsort and slices the last k elements; a dedicated top-k kernel would be faster.
2616
+ *
2617
+ * - `key` - PRNG key
2618
+ * - `logits` - Unnormalized log probabilities of the categorical distribution(s).
2619
+ * `softmax(logits, axis)` gives the corresponding probabilities.
2620
+ * - `axis` - Axis along which logits belong to the same categorical distribution.
2621
+ * - `shape` - Result batch shape. Must be broadcast-compatible with
2622
+ * `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
2623
+ * - `replace` - If true (default), sample with replacement. If false, sample
2624
+ * without replacement (each category can only be selected once per batch).
2625
+ * @returns A random array with int dtype and shape given by `shape` if provided,
2626
+ * otherwise `logits.shape` with `axis` removed.
2627
+ */
2628
+ declare const categorical: OwnedFunction<(key: ArrayLike, logits: ArrayLike, args_2?: {
2629
+ axis?: number | undefined;
2630
+ shape?: number[] | undefined;
2631
+ replace?: boolean | undefined;
2632
+ } | undefined) => Array>;
2590
2633
  /**
2591
2634
  * @function
2592
2635
  * Sample from a Cauchy distribution with location 0 and scale 1.
package/dist/index.d.ts CHANGED
@@ -433,9 +433,14 @@ declare class Routine {
433
433
  }
434
434
  /** One of the valid `Routine` that can be dispatched to backend. */
435
435
  declare enum Routines {
436
- /** Stable sorting algorithm along the last axis. */
436
+ /**
437
+ * Sort along the last axis.
438
+ *
439
+ * This may be _unstable_, but that rarely matters: for numbers, equal values are
440
+ * bitwise identical except for signed zeros and NaNs.
441
+ */
437
442
  Sort = "Sort",
438
- /** Returns `int32` indices of the stably sorted array. */
443
+ /** Stable sorting, returns `int32` indices and values of the sorted array. */
439
444
  Argsort = "Argsort",
440
445
  /**
441
446
  * Solve a triangular system of equations.
@@ -747,9 +752,9 @@ declare enum Primitive {
747
752
  Shrink = "shrink",
748
753
  Pad = "pad",
749
754
  Sort = "sort",
750
- // sort(x, axis=-1)
755
+ // sort(x, axis=-1), unstable
751
756
  Argsort = "argsort",
752
- // argsort(x, axis=-1)
757
+ // argsort(x, axis=-1), stable
753
758
  TriangularSolve = "triangular_solve",
754
759
  // A is upper triangular, A @ X.T = B.T
755
760
  Cholesky = "cholesky",
@@ -1026,8 +1031,9 @@ declare abstract class Tracer {
1026
1031
  */
1027
1032
  sort(axis?: number): this;
1028
1033
  /**
1029
- * Return the indices that would sort an array. This may not be a stable
1030
- * sorting algorithm; it need not preserve order of indices in ties.
1034
+ * Return the indices that would sort an array. Unlike `sort`, this is
1035
+ * guaranteed to be a stable sorting algorithm; it always returns the smaller
1036
+ * index first in event of ties.
1031
1037
  *
1032
1038
  * See `jax.numpy.argsort` for full docs.
1033
1039
  */
@@ -1109,6 +1115,12 @@ type DTypeAndDevice = {
1109
1115
  dtype?: DType;
1110
1116
  device?: Device;
1111
1117
  };
1118
+ /** @inline */
1119
+ type DTypeShapeAndDevice = {
1120
+ dtype?: DType;
1121
+ shape?: number[];
1122
+ device?: Device;
1123
+ };
1112
1124
  type ArrayConstructorArgs = {
1113
1125
  source: AluExp | Slot;
1114
1126
  st: ShapeTracker;
@@ -1218,15 +1230,9 @@ declare function array(values: Array | DataArray | RecursiveArray<number> | Recu
1218
1230
  type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
1219
1231
  declare const implRules: { [P in Primitive]: ImplRule<P> };
1220
1232
  /** Return a new array of given shape and type, filled with zeros. */
1221
- declare function zeros(shape: number[], {
1222
- dtype,
1223
- device
1224
- }?: DTypeAndDevice): Array;
1233
+ declare function zeros(shape: number[], opts?: DTypeAndDevice): Array;
1225
1234
  /** Return a new array of given shape and type, filled with ones. */
1226
- declare function ones(shape: number[], {
1227
- dtype,
1228
- device
1229
- }?: DTypeAndDevice): Array;
1235
+ declare function ones(shape: number[], opts?: DTypeAndDevice): Array;
1230
1236
  /** Return a new array of given shape and type, filled with `fill_value`. */
1231
1237
  declare function full(shape: number[], fillValue: number | boolean | Array, {
1232
1238
  dtype,
@@ -1418,7 +1424,7 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
1418
1424
  unitDiagonal?: boolean;
1419
1425
  }): Array;
1420
1426
  declare namespace lax_d_exports {
1421
- export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
1427
+ export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
1422
1428
  }
1423
1429
  /**
1424
1430
  * Dimension numbers for general `dot()` primitive.
@@ -1524,6 +1530,16 @@ declare function erfc(x: ArrayLike): Array;
1524
1530
  * forward or reverse-mode automatic differentiation.
1525
1531
  */
1526
1532
  declare function stopGradient(x: ArrayLike): Array;
1533
/**
 * Returns the top `k` values and their indices along an axis of `x`.
 *
 * Results are ordered from largest to smallest along `axis`. This is a
 * _stable_ algorithm: if two elements are equal, the lower-index element
 * appears first.
 *
 * @param x - Input array.
 * @param k - Number of elements to select, an integer in `[0, x.shape[axis]]`.
 *   Values outside that range throw an error.
 * @param axis - Axis along which to select. Defaults to -1 (the last axis).
 * @returns A tuple of `(values, indices)`, where `values` and `indices` have
 *   the same shape as `x`, except along `axis` where they have size `k`.
 *   `indices` has dtype `int32`.
 */
declare function topK(x: ArrayLike, k: number, axis?: number): [Array, Array];
1527
1543
  declare namespace numpy_fft_d_exports {
1528
1544
  export { ComplexPair, fft, ifft };
1529
1545
  }
@@ -1749,17 +1765,17 @@ declare const shape$1: (x: ArrayLike) => number[];
1749
1765
  * @function
1750
1766
  * Return an array of zeros with the same shape and type as a given array.
1751
1767
  */
1752
- declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
1768
+ declare const zerosLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
1753
1769
  /**
1754
1770
  * @function
1755
1771
  * Return an array of ones with the same shape and type as a given array.
1756
1772
  */
1757
- declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
1773
+ declare const onesLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
1758
1774
  /**
1759
1775
  * @function
1760
1776
  * Return a full array with the same shape and type as a given array.
1761
1777
  */
1762
- declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
1778
+ declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, opts?: DTypeShapeAndDevice) => Array;
1763
1779
  /**
1764
1780
  * Return the number of elements in an array, optionally along an axis.
1765
1781
  * Does not consume array reference.
@@ -1948,8 +1964,9 @@ declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: nu
1948
1964
  */
1949
1965
  declare function sort(a: ArrayLike, axis?: number): Array;
1950
1966
  /**
1951
- * Return indices that would sort an array. This may be an unstable sorting
1952
- * algorithm; it need not preserve order of indices in ties.
1967
+ * Return indices that would sort an array. Unlike `sort`, this is guaranteed to
1968
+ * be a stable sorting algorithm; it always returns the smaller index first in
1969
+ * event of ties.
1953
1970
  *
1954
1971
  * Returns an array of `int32` indices.
1955
1972
  *
@@ -2561,7 +2578,7 @@ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: Ar
2561
2578
  localWindowSize?: number | [number, number];
2562
2579
  }): Array;
2563
2580
  declare namespace random_d_exports {
2564
- export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2581
+ export { bernoulli, bits, categorical, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2565
2582
  }
2566
2583
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
2567
2584
  declare function key(seed: ArrayLike): Array;
@@ -2584,6 +2601,32 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
2584
2601
  * and must be broadcastable to `shape`.
2585
2602
  */
2586
2603
  declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
2604
/**
 * @function
 * Sample random values from categorical distributions.
 *
 * Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
 * trick for sampling without replacement.
 *
 * Note: Sampling without replacement goes through `lax.topK`, which is
 * currently implemented with a full argsort along `axis`.
 *
 * - `key` - PRNG key
 * - `logits` - Unnormalized log probabilities of the categorical distribution(s).
 *   `softmax(logits, axis)` gives the corresponding probabilities.
 * - `axis` - Axis along which logits belong to the same categorical
 *   distribution. Defaults to -1.
 * - `shape` - Result batch shape. Must be broadcast-compatible with
 *   `logits.shape` with `axis` removed. Default is `logits.shape` with `axis`
 *   removed. Any leading dimensions beyond that batch shape enumerate the
 *   samples drawn from each distribution.
 * - `replace` - If true (default), sample with replacement. If false, sample
 *   without replacement (each category can only be selected once per batch);
 *   the number of samples per distribution must then not exceed the number of
 *   categories, otherwise an error is thrown.
 * @returns A random array with int dtype and shape given by `shape` if provided,
 *   otherwise `logits.shape` with `axis` removed.
 */
declare const categorical: OwnedFunction<(key: ArrayLike, logits: ArrayLike, args_2?: {
  axis?: number | undefined;
  shape?: number[] | undefined;
  replace?: boolean | undefined;
} | undefined) => Array>;
2587
2630
  /**
2588
2631
  * @function
2589
2632
  * Sample from a Cauchy distribution with location 0 and scale 1.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-nEolvdLv.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-BId79r5b.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
@@ -889,18 +889,25 @@ var Tracer = class Tracer {
889
889
  return sort$1(this.transpose(perm)).transpose(invertPermutation(perm));
890
890
  }
891
891
  /**
892
- * Return the indices that would sort an array. This may not be a stable
893
- * sorting algorithm; it need not preserve order of indices in ties.
892
+ * Return the indices that would sort an array. Unlike `sort`, this is
893
+ * guaranteed to be a stable sorting algorithm; it always returns the smaller
894
+ * index first in event of ties.
894
895
  *
895
896
  * See `jax.numpy.argsort` for full docs.
896
897
  */
897
898
  argsort(axis = -1) {
898
899
  axis = checkAxis(axis, this.ndim);
899
- if (axis === this.ndim - 1) return argsort$1(this)[1];
900
+ if (axis === this.ndim - 1) {
901
+ const [y$1, yi$1] = argsort$1(this);
902
+ y$1.dispose();
903
+ return yi$1;
904
+ }
900
905
  const perm = range(this.ndim);
901
906
  perm.splice(axis, 1);
902
907
  perm.push(axis);
903
- return argsort$1(this.transpose(perm))[1].transpose(invertPermutation(perm));
908
+ const [y, yi] = argsort$1(this.transpose(perm));
909
+ y.dispose();
910
+ return yi.transpose(invertPermutation(perm));
904
911
  }
905
912
  /**
906
913
  * Slice an array along one or more axes.
@@ -3381,32 +3388,26 @@ function fullInternal(aval, fillValue, device) {
3381
3388
  committed: device != void 0
3382
3389
  });
3383
3390
  }
3384
/** Return an array of zeros shaped like `val`; see `fullLike` for `opts`. */
function zerosLike$1(val, opts) {
  return fullLike(val, 0, opts);
}
/** Return an array of ones shaped like `val`; see `fullLike` for `opts`. */
function onesLike$1(val, opts) {
  return fullLike(val, 1, opts);
}
/**
 * Return an array filled with `fillValue`, shaped like `val`.
 *
 * The shape and dtype default to those of `val`. `opts.shape` and `opts.dtype`
 * override them, and `opts.device` is forwarded to `fullInternal`. The result
 * keeps `val`'s weak-type flag only when no explicit dtype is requested. If
 * `val` is a Tracer, it is disposed.
 *
 * Before 0.1.9 the third argument was a bare dtype (`fullLike(val, v, dtype)`).
 * That form is still accepted and treated as `{ dtype }`. Otherwise,
 * destructuring the dtype value would silently yield `dtype === undefined` and
 * drop the requested dtype.
 *
 * @throws {Error} If `fillValue` is an array (not implemented yet).
 */
function fullLike(val, fillValue, opts = {}) {
  const { dtype, shape: newShape, device } =
    typeof opts === "string" || typeof opts === "number" ? { dtype: opts } : opts;
  const aval = getAval(val);
  if (val instanceof Tracer) val.dispose();
  if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
  const sa = new ShapedArray(newShape ?? aval.shape, dtype ?? aval.dtype, aval.weakType && dtype === void 0);
  return fullInternal(sa, fillValue, device);
}
/** Return a new array of given shape and type, filled with zeros. */
function zeros(shape$1, opts) {
  return full(shape$1, 0, opts);
}
/** Return a new array of given shape and type, filled with ones. */
function ones(shape$1, opts) {
  return full(shape$1, 1, opts);
}
3411
3412
  /** Return a new array of given shape and type, filled with `fill_value`. */
3412
3413
  function full(shape$1, fillValue, { dtype, device } = {}) {
@@ -5295,7 +5296,7 @@ function lstsq(a, b) {
5295
5296
  lower: true,
5296
5297
  transposeA: true
5297
5298
  });
5298
- return matmul(at, llb.ref);
5299
+ return matmul(at, llb);
5299
5300
  } else {
5300
5301
  const ata = matmul(at.ref, a);
5301
5302
  const l = cholesky(ata, { symmetrizeInput: false });
@@ -5386,7 +5387,7 @@ function solve(a, b) {
5386
5387
  lower: true,
5387
5388
  unitDiagonal: true
5388
5389
  });
5389
- let x = triangularSolve(lu$2, LPb.ref, {
5390
+ let x = triangularSolve(lu$2, LPb, {
5390
5391
  leftSide: true,
5391
5392
  lower: false
5392
5393
  });
@@ -6197,8 +6198,9 @@ function sort(a, axis = -1) {
6197
6198
  return fudgeArray(a).sort(axis);
6198
6199
  }
6199
6200
  /**
6200
- * Return indices that would sort an array. This may be an unstable sorting
6201
- * algorithm; it need not preserve order of indices in ties.
6201
+ * Return indices that would sort an array. Unlike `sort`, this is guaranteed to
6202
+ * be a stable sorting algorithm; it always returns the smaller index first in
6203
+ * event of ties.
6202
6204
  *
6203
6205
  * Returns an array of `int32` indices.
6204
6206
  *
@@ -6500,7 +6502,7 @@ function absolute(x) {
6500
6502
  /** Return an element-wise indication of sign of the input. */
6501
6503
  function sign(x) {
6502
6504
  x = fudgeArray(x);
6503
- return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
6505
+ return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
6504
6506
  }
6505
6507
  /** @function Return element-wise positive values of the input (no-op). */
6506
6508
  const positive = fudgeArray;
@@ -6993,7 +6995,8 @@ __export(lax_exports, {
6993
6995
  erfc: () => erfc,
6994
6996
  linalg: () => lax_linalg_exports,
6995
6997
  reduceWindow: () => reduceWindow,
6996
- stopGradient: () => stopGradient$1
6998
+ stopGradient: () => stopGradient$1,
6999
+ topK: () => topK
6997
7000
  });
6998
7001
  const JsArray = globalThis.Array;
6999
7002
  /**
@@ -7217,6 +7220,39 @@ function erfc(x) {
7217
7220
  function stopGradient$1(x) {
7218
7221
  return stopGradient(x);
7219
7222
  }
7223
/**
 * Returns the top `k` values and their indices along an axis of `x`.
 *
 * Results are ordered from largest to smallest along `axis`. This is a
 * _stable_ algorithm: if two elements are equal, the lower-index element
 * appears first.
 *
 * Implemented with a full (stable) `argsort` along `axis`.
 *
 * @param x - Input array.
 * @param k - Number of elements to select, an integer in `[0, x.shape[axis]]`.
 * @param axis - Axis to select along. Defaults to -1 (the last axis).
 * @returns A tuple of `(values, indices)`, where `values` and `indices` have
 *   the same shape as `x`, except along `axis` where they have size `k`.
 * @throws {Error} If `k` is not an integer or lies outside `[0, x.shape[axis]]`.
 */
function topK(x, k, axis = -1) {
  x = fudgeArray(x);
  axis = checkAxis(axis, x.ndim);
  const size$1 = x.shape[axis];
  // NaN and fractional k would otherwise slip through the range check below
  // (NaN compares false both ways) and produce a malformed `slice([-k])`.
  if (!Number.isInteger(k)) throw new Error(`topK: k must be an integer, got ${k}`);
  if (k < 0 || k > size$1) throw new Error(`topK: k must be in the range [0, ${size$1}], got ${k}`);
  if (k === 0) {
    // Empty selection: return correctly shaped empty arrays without sorting.
    const outShape = x.shape.slice();
    outShape[axis] = 0;
    const emptyValues = zerosLike$1(x.ref, { shape: outShape });
    const emptyIndices = zerosLike$1(x, {
      dtype: DType.Int32,
      shape: outShape
    });
    return [emptyValues, emptyIndices];
  }
  // Reverse along `axis` before the stable ascending sort. Among equal values,
  // the one with the lower original index then sorts later; after taking the
  // last k entries and reversing back, it comes first.
  x = flip$1(x, [axis]);
  x = moveaxis(x, axis, -1);
  const [y, yi] = argsort$1(x);
  // Keep the last k (largest) entries, restore the axis position, and flip
  // into descending order.
  const extract = (a) => {
    a = a.slice(...rep(a.ndim - 1, []), [-k]);
    return flip$1(moveaxis(a, -1, axis), [axis]);
  };
  // `yi` indexes the flipped array; map back to original positions via
  // i -> size - 1 - i.
  return [extract(y), extract(yi.neg().add(size$1 - 1))];
}
7220
7256
 
7221
7257
  //#endregion
7222
7258
  //#region src/library/nn.ts
@@ -7408,7 +7444,7 @@ const gelu = jit$1(function gelu$1(x, opts) {
7408
7444
  if (opts?.approximate ?? true) {
7409
7445
  const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
7410
7446
  return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
7411
- } else return x.ref.mul(.5).mul(erfc$1(negative(x.ref.mul(Math.SQRT1_2))));
7447
+ } else return x.ref.mul(.5).mul(erfc$1(negative(x.mul(Math.SQRT1_2))));
7412
7448
  }, { staticArgnums: [1] });
7413
7449
  /**
7414
7450
  * Gated linear unit (GLU) activation function.
@@ -7666,6 +7702,7 @@ var random_exports = {};
7666
7702
  __export(random_exports, {
7667
7703
  bernoulli: () => bernoulli,
7668
7704
  bits: () => bits,
7705
+ categorical: () => categorical,
7669
7706
  cauchy: () => cauchy,
7670
7707
  exponential: () => exponential,
7671
7708
  gumbel: () => gumbel,
@@ -7737,6 +7774,47 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
7737
7774
  }
7738
7775
  /**
7739
7776
  * @function
7777
+ * Sample random values from categorical distributions.
7778
+ *
7779
+ * Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
7780
+ * trick for sampling without replacement.
7781
+ *
7782
+ * Note: Sampling without replacement currently uses argsort and slices the last
7783
+ * k elements. This should be replaced with a more efficient topK implementation.
7784
+ *
7785
+ * - `key` - PRNG key
7786
+ * - `logits` - Unnormalized log probabilities of the categorical distribution(s).
7787
+ * `softmax(logits, axis)` gives the corresponding probabilities.
7788
+ * - `axis` - Axis along which logits belong to the same categorical distribution.
7789
+ * - `shape` - Result batch shape. Must be broadcast-compatible with
7790
+ * `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
7791
+ * - `replace` - If true (default), sample with replacement. If false, sample
7792
+ * without replacement (each category can only be selected once per batch).
7793
+ * @returns A random array with int dtype and shape given by `shape` if provided,
7794
+ * otherwise `logits.shape` with `axis` removed.
7795
+ */
7796
+ const categorical = jit$1(function categorical$1(key$1, logits, { axis = -1, shape: shape$1, replace = true } = {}) {
7797
+ logits = fudgeArray(logits);
7798
+ axis = checkAxis(axis, logits.ndim);
7799
+ const numCategories = logits.shape[axis];
7800
+ const batchShape = logits.shape.toSpliced(axis, 1);
7801
+ if (shape$1 === void 0) shape$1 = batchShape;
7802
+ else if (!deepEqual(generalBroadcast(shape$1, batchShape), shape$1)) throw new Error(`Shape ${shape$1} is not broadcast-compatible with batch shape ${batchShape}.`);
7803
+ const shapePrefix = shape$1.slice(0, shape$1.length - batchShape.length);
7804
+ if (replace) {
7805
+ const noise = gumbel(key$1, [...shapePrefix, ...logits.shape]);
7806
+ return argmax(noise.add(logits), axis + shapePrefix.length);
7807
+ } else {
7808
+ const k = shapePrefix.reduce((a, b) => a * b, 1);
7809
+ if (k > numCategories) throw new Error(`Number of samples without replacement (${k}) cannot exceed number of categories (${numCategories}).`);
7810
+ const noise = gumbel(key$1, logits.shape);
7811
+ const [values, indices] = topK(noise.add(logits), k, axis);
7812
+ values.dispose();
7813
+ return indices.reshape(shape$1);
7814
+ }
7815
+ }, { staticArgnums: [2] });
7816
+ /**
7817
+ * @function
7740
7818
  * Sample from a Cauchy distribution with location 0 and scale 1.
7741
7819
  *
7742
7820
  * Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-B3foXiV_.cjs');
1
+ const require_backend = require('./backend-DpI0riom.cjs');
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -1,4 +1,4 @@
1
- import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-nEolvdLv.js";
1
+ import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-BId79r5b.js";
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -1,4 +1,4 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-nEolvdLv.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-BId79r5b.js";
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -247,6 +247,10 @@ function bitonicSortUniform(pass) {
247
247
  * `2^(step+1)` with multiple workgroups. This doesn't use shared memory.
248
248
  *
249
249
  * The total number of passes is roughly `log2(n / workgroupSize)^2 / 2`.
250
+ *
251
+ * If `outputIndices` is true, the shader also tracks the original indices of
252
+ * the sorted elements (argsort) and outputs them to a separate buffer. This
253
+ * also makes the sorting algorithm stable.
250
254
  */
251
255
  function bitonicSortShader(device, dtype, n, batches, outputIndices) {
252
256
  const ty = dtypeToWgsl(dtype, true);
@@ -286,14 +290,21 @@ ${isFloatDtype(dtype) ? `
286
290
  fn compare_and_swap(i: u32, j: u32) {
287
291
  let val_i = shared_vals[i];
288
292
  let val_j = shared_vals[j];
289
- if (compare(val_j, val_i)) {
293
+ ${outputIndices ? `
294
+ if (
295
+ compare(val_j, val_i) ||
296
+ (!compare(val_i, val_j) && shared_idx[j] < shared_idx[i])
297
+ ) {
290
298
  shared_vals[i] = val_j;
291
299
  shared_vals[j] = val_i;
292
- ${outputIndices ? `
293
300
  let tmp_idx = shared_idx[i];
294
301
  shared_idx[i] = shared_idx[j];
295
- shared_idx[j] = tmp_idx;` : ""}
296
- }
302
+ shared_idx[j] = tmp_idx;
303
+ }` : `
304
+ if (compare(val_j, val_i)) {
305
+ shared_vals[i] = val_j;
306
+ shared_vals[j] = val_i;
307
+ }`}
297
308
  }
298
309
 
299
310
  @compute @workgroup_size(${workgroupSize})
@@ -370,13 +381,17 @@ ${outputIndices ? `
370
381
  if (j < ${n}u) {
371
382
  let val_i = output[base + i];
372
383
  let val_j = output[base + j];
373
- if (compare(val_j, val_i)) {
384
+ ${outputIndices ? `
385
+ let idx_i = output_idx[base + i];
386
+ let idx_j = output_idx[base + j];
387
+ if (compare(val_j, val_i) || (!compare(val_i, val_j) && idx_j < idx_i)) {
374
388
  output[base + i] = val_j;
375
389
  output[base + j] = val_i;
376
- ${outputIndices ? `
377
- let tmp_idx = output_idx[base + i];
378
- output_idx[base + i] = output_idx[base + j];
379
- output_idx[base + j] = tmp_idx;` : ""}
390
+ output_idx[base + i] = idx_j;
391
+ output_idx[base + j] = idx_i;` : `
392
+ if (compare(val_j, val_i)) {
393
+ output[base + i] = val_j;
394
+ output[base + j] = val_i;`}
380
395
  }
381
396
  }
382
397
  }
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-B3foXiV_.cjs');
1
+ const require_backend = require('./backend-DpI0riom.cjs');
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -247,6 +247,10 @@ function bitonicSortUniform(pass) {
247
247
  * `2^(step+1)` with multiple workgroups. This doesn't use shared memory.
248
248
  *
249
249
  * The total number of passes is roughly `log2(n / workgroupSize)^2 / 2`.
250
+ *
251
+ * If `outputIndices` is true, the shader also tracks the original indices of
252
+ * the sorted elements (argsort) and outputs them to a separate buffer. This
253
+ * also makes the sorting algorithm stable.
250
254
  */
251
255
  function bitonicSortShader(device, dtype, n, batches, outputIndices) {
252
256
  const ty = dtypeToWgsl(dtype, true);
@@ -286,14 +290,21 @@ ${require_backend.isFloatDtype(dtype) ? `
286
290
  fn compare_and_swap(i: u32, j: u32) {
287
291
  let val_i = shared_vals[i];
288
292
  let val_j = shared_vals[j];
289
- if (compare(val_j, val_i)) {
293
+ ${outputIndices ? `
294
+ if (
295
+ compare(val_j, val_i) ||
296
+ (!compare(val_i, val_j) && shared_idx[j] < shared_idx[i])
297
+ ) {
290
298
  shared_vals[i] = val_j;
291
299
  shared_vals[j] = val_i;
292
- ${outputIndices ? `
293
300
  let tmp_idx = shared_idx[i];
294
301
  shared_idx[i] = shared_idx[j];
295
- shared_idx[j] = tmp_idx;` : ""}
296
- }
302
+ shared_idx[j] = tmp_idx;
303
+ }` : `
304
+ if (compare(val_j, val_i)) {
305
+ shared_vals[i] = val_j;
306
+ shared_vals[j] = val_i;
307
+ }`}
297
308
  }
298
309
 
299
310
  @compute @workgroup_size(${workgroupSize})
@@ -370,13 +381,17 @@ ${outputIndices ? `
370
381
  if (j < ${n}u) {
371
382
  let val_i = output[base + i];
372
383
  let val_j = output[base + j];
373
- if (compare(val_j, val_i)) {
384
+ ${outputIndices ? `
385
+ let idx_i = output_idx[base + i];
386
+ let idx_j = output_idx[base + j];
387
+ if (compare(val_j, val_i) || (!compare(val_i, val_j) && idx_j < idx_i)) {
374
388
  output[base + i] = val_j;
375
389
  output[base + j] = val_i;
376
- ${outputIndices ? `
377
- let tmp_idx = output_idx[base + i];
378
- output_idx[base + i] = output_idx[base + j];
379
- output_idx[base + j] = tmp_idx;` : ""}
390
+ output_idx[base + i] = idx_j;
391
+ output_idx[base + j] = idx_i;` : `
392
+ if (compare(val_j, val_i)) {
393
+ output[base + i] = val_j;
394
+ output[base + j] = val_i;`}
380
395
  }
381
396
  }
382
397
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.1.8",
3
+ "version": "0.1.9",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",