@jax-js/jax 0.1.8 → 0.1.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +17 -13
- package/dist/{backend-nEolvdLv.js → backend-BId79r5b.js} +17 -6
- package/dist/{backend-B3foXiV_.cjs → backend-DpI0riom.cjs} +17 -6
- package/dist/index.cjs +107 -29
- package/dist/index.d.cts +64 -21
- package/dist/index.d.ts +64 -21
- package/dist/index.js +107 -29
- package/dist/{webgl-DIIbKJ0G.cjs → webgl-C5NjXc1p.cjs} +1 -1
- package/dist/{webgl-DweKSWEm.js → webgl-DnGrclTz.js} +1 -1
- package/dist/{webgpu-B96vzWGE.js → webgpu-AN0cG_nB.js} +25 -10
- package/dist/{webgpu-BykvF26B.cjs → webgpu-CdjiJSa7.cjs} +25 -10
- package/package.json +1 -1
package/README.md
CHANGED
|
@@ -43,6 +43,23 @@ way to get started on a blank HTML page.
|
|
|
43
43
|
</script>
|
|
44
44
|
```
|
|
45
45
|
|
|
46
|
+
## Examples
|
|
47
|
+
|
|
48
|
+
Cool things that the community has made with jax-js:
|
|
49
|
+
|
|
50
|
+
- [**tanh.xyz**: Interactive ML visualizations](https://tanh.xyz/)
|
|
51
|
+
|
|
52
|
+
And some more demos on the official website.
|
|
53
|
+
|
|
54
|
+
- [Training neural networks on MNIST](https://jax-js.com/mnist)
|
|
55
|
+
- [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
|
|
56
|
+
- [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
|
|
57
|
+
- [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
|
|
58
|
+
- [In-browser REPL](https://jax-js.com/repl)
|
|
59
|
+
- [Matmul benchmark](https://jax-js.com/bench/matmul)
|
|
60
|
+
- [Conv2d benchmark](https://jax-js.com/bench/conv2d)
|
|
61
|
+
- [Mandelbrot set](https://jax-js.com/mandelbrot)
|
|
62
|
+
|
|
46
63
|
## Feature comparison
|
|
47
64
|
|
|
48
65
|
Here's a quick, high-level comparison with other popular web ML runtimes:
|
|
@@ -338,19 +355,6 @@ well as unique optimizations such as FlashAttention variants.
|
|
|
338
355
|
That's all for this short tutorial. Please see the generated
|
|
339
356
|
[API reference](https://jax-js.com/docs) for detailed documentation.
|
|
340
357
|
|
|
341
|
-
## Examples
|
|
342
|
-
|
|
343
|
-
If you make something cool with jax-js, don't be a stranger! We can feature it here.
|
|
344
|
-
|
|
345
|
-
- [Training neural networks on MNIST](https://jax-js.com/mnist)
|
|
346
|
-
- [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
|
|
347
|
-
- [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
|
|
348
|
-
- [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
|
|
349
|
-
- [In-browser REPL](https://jax-js.com/repl)
|
|
350
|
-
- [Matmul benchmark](https://jax-js.com/bench/matmul)
|
|
351
|
-
- [Conv2d benchmark](https://jax-js.com/bench/conv2d)
|
|
352
|
-
- [Mandelbrot set](https://jax-js.com/mandelbrot)
|
|
353
|
-
|
|
354
358
|
## Development
|
|
355
359
|
|
|
356
360
|
_The following technical details are for contributing to jax-js and modifying its internals._
|
|
@@ -1479,9 +1479,14 @@ var Routine = class {
|
|
|
1479
1479
|
};
|
|
1480
1480
|
/** One of the valid `Routine` that can be dispatched to backend. */
|
|
1481
1481
|
let Routines = /* @__PURE__ */ function(Routines$1) {
|
|
1482
|
-
/**
|
|
1482
|
+
/**
|
|
1483
|
+
* Sort along the last axis.
|
|
1484
|
+
*
|
|
1485
|
+
* This may be _unstable_ but it often doesn't matter, sorting numbers is
|
|
1486
|
+
* bitwise unique up to signed zeros and NaNs.
|
|
1487
|
+
*/
|
|
1483
1488
|
Routines$1["Sort"] = "Sort";
|
|
1484
|
-
/**
|
|
1489
|
+
/** Stable sorting, returns `int32` indices and values of the sorted array. */
|
|
1485
1490
|
Routines$1["Argsort"] = "Argsort";
|
|
1486
1491
|
/**
|
|
1487
1492
|
* Solve a triangular system of equations.
|
|
@@ -1545,7 +1550,13 @@ function runArgsort(type, [x], [y, yi]) {
|
|
|
1545
1550
|
const out = y.subarray(offset, offset + n);
|
|
1546
1551
|
const outi = yi.subarray(offset, offset + n);
|
|
1547
1552
|
for (let i = 0; i < n; i++) outi[i] = i;
|
|
1548
|
-
outi.sort((a, b) =>
|
|
1553
|
+
outi.sort((a, b) => {
|
|
1554
|
+
const x$1 = ar[a];
|
|
1555
|
+
const y$1 = ar[b];
|
|
1556
|
+
if (isNaN(x$1)) return isNaN(y$1) ? 0 : 1;
|
|
1557
|
+
if (isNaN(y$1)) return -1;
|
|
1558
|
+
return x$1 === y$1 ? 0 : x$1 < y$1 ? -1 : 1;
|
|
1559
|
+
});
|
|
1549
1560
|
for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
|
|
1550
1561
|
}
|
|
1551
1562
|
}
|
|
@@ -2321,7 +2332,7 @@ function tuneWebgpu(kernel) {
|
|
|
2321
2332
|
if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
|
|
2322
2333
|
const s = dim.st.shape[dim.unroll - 1];
|
|
2323
2334
|
if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
2324
|
-
else for (const splits of [
|
|
2335
|
+
else for (const splits of [4, 2]) if (s % splits === 0) {
|
|
2325
2336
|
dim.applyUnroll(dim.unroll - 1, splits);
|
|
2326
2337
|
break;
|
|
2327
2338
|
}
|
|
@@ -4252,7 +4263,7 @@ async function createBackend(device) {
|
|
|
4252
4263
|
if (!navigator.gpu) return null;
|
|
4253
4264
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4254
4265
|
if (!adapter) return null;
|
|
4255
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
4266
|
+
const { WebGPUBackend } = await import("./webgpu-AN0cG_nB.js");
|
|
4256
4267
|
const importantLimits = [
|
|
4257
4268
|
"maxBufferSize",
|
|
4258
4269
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4290,7 +4301,7 @@ async function createBackend(device) {
|
|
|
4290
4301
|
});
|
|
4291
4302
|
if (!gl) return null;
|
|
4292
4303
|
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
4293
|
-
const { WebGLBackend } = await import("./webgl-
|
|
4304
|
+
const { WebGLBackend } = await import("./webgl-DnGrclTz.js");
|
|
4294
4305
|
return new WebGLBackend(gl);
|
|
4295
4306
|
} else throw new Error(`Backend not found: ${device}`);
|
|
4296
4307
|
}
|
|
@@ -1480,9 +1480,14 @@ var Routine = class {
|
|
|
1480
1480
|
};
|
|
1481
1481
|
/** One of the valid `Routine` that can be dispatched to backend. */
|
|
1482
1482
|
let Routines = /* @__PURE__ */ function(Routines$1) {
|
|
1483
|
-
/**
|
|
1483
|
+
/**
|
|
1484
|
+
* Sort along the last axis.
|
|
1485
|
+
*
|
|
1486
|
+
* This may be _unstable_ but it often doesn't matter, sorting numbers is
|
|
1487
|
+
* bitwise unique up to signed zeros and NaNs.
|
|
1488
|
+
*/
|
|
1484
1489
|
Routines$1["Sort"] = "Sort";
|
|
1485
|
-
/**
|
|
1490
|
+
/** Stable sorting, returns `int32` indices and values of the sorted array. */
|
|
1486
1491
|
Routines$1["Argsort"] = "Argsort";
|
|
1487
1492
|
/**
|
|
1488
1493
|
* Solve a triangular system of equations.
|
|
@@ -1546,7 +1551,13 @@ function runArgsort(type, [x], [y, yi]) {
|
|
|
1546
1551
|
const out = y.subarray(offset, offset + n);
|
|
1547
1552
|
const outi = yi.subarray(offset, offset + n);
|
|
1548
1553
|
for (let i = 0; i < n; i++) outi[i] = i;
|
|
1549
|
-
outi.sort((a, b) =>
|
|
1554
|
+
outi.sort((a, b) => {
|
|
1555
|
+
const x$1 = ar[a];
|
|
1556
|
+
const y$1 = ar[b];
|
|
1557
|
+
if (isNaN(x$1)) return isNaN(y$1) ? 0 : 1;
|
|
1558
|
+
if (isNaN(y$1)) return -1;
|
|
1559
|
+
return x$1 === y$1 ? 0 : x$1 < y$1 ? -1 : 1;
|
|
1560
|
+
});
|
|
1550
1561
|
for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
|
|
1551
1562
|
}
|
|
1552
1563
|
}
|
|
@@ -2322,7 +2333,7 @@ function tuneWebgpu(kernel) {
|
|
|
2322
2333
|
if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
|
|
2323
2334
|
const s = dim.st.shape[dim.unroll - 1];
|
|
2324
2335
|
if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
2325
|
-
else for (const splits of [
|
|
2336
|
+
else for (const splits of [4, 2]) if (s % splits === 0) {
|
|
2326
2337
|
dim.applyUnroll(dim.unroll - 1, splits);
|
|
2327
2338
|
break;
|
|
2328
2339
|
}
|
|
@@ -4253,7 +4264,7 @@ async function createBackend(device) {
|
|
|
4253
4264
|
if (!navigator.gpu) return null;
|
|
4254
4265
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4255
4266
|
if (!adapter) return null;
|
|
4256
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
4267
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-CdjiJSa7.cjs"));
|
|
4257
4268
|
const importantLimits = [
|
|
4258
4269
|
"maxBufferSize",
|
|
4259
4270
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4291,7 +4302,7 @@ async function createBackend(device) {
|
|
|
4291
4302
|
});
|
|
4292
4303
|
if (!gl) return null;
|
|
4293
4304
|
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
4294
|
-
const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-
|
|
4305
|
+
const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-C5NjXc1p.cjs"));
|
|
4295
4306
|
return new WebGLBackend(gl);
|
|
4296
4307
|
} else throw new Error(`Backend not found: ${device}`);
|
|
4297
4308
|
}
|
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-DpI0riom.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
@@ -920,18 +920,25 @@ var Tracer = class Tracer {
|
|
|
920
920
|
return sort$1(this.transpose(perm)).transpose(require_backend.invertPermutation(perm));
|
|
921
921
|
}
|
|
922
922
|
/**
|
|
923
|
-
* Return the indices that would sort an array.
|
|
924
|
-
* sorting algorithm; it
|
|
923
|
+
* Return the indices that would sort an array. Unlike `sort`, this is
|
|
924
|
+
* guaranteed to be a stable sorting algorithm; it always returns the smaller
|
|
925
|
+
* index first in event of ties.
|
|
925
926
|
*
|
|
926
927
|
* See `jax.numpy.argsort` for full docs.
|
|
927
928
|
*/
|
|
928
929
|
argsort(axis = -1) {
|
|
929
930
|
axis = require_backend.checkAxis(axis, this.ndim);
|
|
930
|
-
if (axis === this.ndim - 1)
|
|
931
|
+
if (axis === this.ndim - 1) {
|
|
932
|
+
const [y$1, yi$1] = argsort$1(this);
|
|
933
|
+
y$1.dispose();
|
|
934
|
+
return yi$1;
|
|
935
|
+
}
|
|
931
936
|
const perm = require_backend.range(this.ndim);
|
|
932
937
|
perm.splice(axis, 1);
|
|
933
938
|
perm.push(axis);
|
|
934
|
-
|
|
939
|
+
const [y, yi] = argsort$1(this.transpose(perm));
|
|
940
|
+
y.dispose();
|
|
941
|
+
return yi.transpose(require_backend.invertPermutation(perm));
|
|
935
942
|
}
|
|
936
943
|
/**
|
|
937
944
|
* Slice an array along one or more axes.
|
|
@@ -3416,32 +3423,26 @@ function fullInternal(aval, fillValue, device) {
|
|
|
3416
3423
|
committed: device != void 0
|
|
3417
3424
|
});
|
|
3418
3425
|
}
|
|
3419
|
-
function zerosLike$1(val,
|
|
3420
|
-
return fullLike(val, 0,
|
|
3426
|
+
function zerosLike$1(val, opts) {
|
|
3427
|
+
return fullLike(val, 0, opts);
|
|
3421
3428
|
}
|
|
3422
|
-
function onesLike$1(val,
|
|
3423
|
-
return fullLike(val, 1,
|
|
3429
|
+
function onesLike$1(val, opts) {
|
|
3430
|
+
return fullLike(val, 1, opts);
|
|
3424
3431
|
}
|
|
3425
|
-
function fullLike(val, fillValue, dtype) {
|
|
3432
|
+
function fullLike(val, fillValue, { dtype, shape: shape$1, device } = {}) {
|
|
3426
3433
|
const aval = getAval(val);
|
|
3427
3434
|
if (val instanceof Tracer) val.dispose();
|
|
3428
3435
|
if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
|
|
3429
|
-
const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
|
|
3430
|
-
return fullInternal(sa, fillValue);
|
|
3436
|
+
const sa = new ShapedArray(shape$1 ?? aval.shape, dtype ?? aval.dtype, aval.weakType && dtype === void 0);
|
|
3437
|
+
return fullInternal(sa, fillValue, device);
|
|
3431
3438
|
}
|
|
3432
3439
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
3433
|
-
function zeros(shape$1,
|
|
3434
|
-
return full(shape$1, 0,
|
|
3435
|
-
dtype,
|
|
3436
|
-
device
|
|
3437
|
-
});
|
|
3440
|
+
function zeros(shape$1, opts) {
|
|
3441
|
+
return full(shape$1, 0, opts);
|
|
3438
3442
|
}
|
|
3439
3443
|
/** Return a new array of given shape and type, filled with ones. */
|
|
3440
|
-
function ones(shape$1,
|
|
3441
|
-
return full(shape$1, 1,
|
|
3442
|
-
dtype,
|
|
3443
|
-
device
|
|
3444
|
-
});
|
|
3444
|
+
function ones(shape$1, opts) {
|
|
3445
|
+
return full(shape$1, 1, opts);
|
|
3445
3446
|
}
|
|
3446
3447
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
3447
3448
|
function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
@@ -5332,7 +5333,7 @@ function lstsq(a, b) {
|
|
|
5332
5333
|
lower: true,
|
|
5333
5334
|
transposeA: true
|
|
5334
5335
|
});
|
|
5335
|
-
return matmul(at, llb
|
|
5336
|
+
return matmul(at, llb);
|
|
5336
5337
|
} else {
|
|
5337
5338
|
const ata = matmul(at.ref, a);
|
|
5338
5339
|
const l = cholesky(ata, { symmetrizeInput: false });
|
|
@@ -5423,7 +5424,7 @@ function solve(a, b) {
|
|
|
5423
5424
|
lower: true,
|
|
5424
5425
|
unitDiagonal: true
|
|
5425
5426
|
});
|
|
5426
|
-
let x = triangularSolve(lu$2, LPb
|
|
5427
|
+
let x = triangularSolve(lu$2, LPb, {
|
|
5427
5428
|
leftSide: true,
|
|
5428
5429
|
lower: false
|
|
5429
5430
|
});
|
|
@@ -6234,8 +6235,9 @@ function sort(a, axis = -1) {
|
|
|
6234
6235
|
return fudgeArray(a).sort(axis);
|
|
6235
6236
|
}
|
|
6236
6237
|
/**
|
|
6237
|
-
* Return indices that would sort an array.
|
|
6238
|
-
* algorithm; it
|
|
6238
|
+
* Return indices that would sort an array. Unlike `sort`, this is guaranteed to
|
|
6239
|
+
* be a stable sorting algorithm; it always returns the smaller index first in
|
|
6240
|
+
* event of ties.
|
|
6239
6241
|
*
|
|
6240
6242
|
* Returns an array of `int32` indices.
|
|
6241
6243
|
*
|
|
@@ -6537,7 +6539,7 @@ function absolute(x) {
|
|
|
6537
6539
|
/** Return an element-wise indication of sign of the input. */
|
|
6538
6540
|
function sign(x) {
|
|
6539
6541
|
x = fudgeArray(x);
|
|
6540
|
-
return where(notEqual(x.ref, 0), where(less(x
|
|
6542
|
+
return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
|
|
6541
6543
|
}
|
|
6542
6544
|
/** @function Return element-wise positive values of the input (no-op). */
|
|
6543
6545
|
const positive = fudgeArray;
|
|
@@ -7030,7 +7032,8 @@ __export(lax_exports, {
|
|
|
7030
7032
|
erfc: () => erfc,
|
|
7031
7033
|
linalg: () => lax_linalg_exports,
|
|
7032
7034
|
reduceWindow: () => reduceWindow,
|
|
7033
|
-
stopGradient: () => stopGradient$1
|
|
7035
|
+
stopGradient: () => stopGradient$1,
|
|
7036
|
+
topK: () => topK
|
|
7034
7037
|
});
|
|
7035
7038
|
const JsArray = globalThis.Array;
|
|
7036
7039
|
/**
|
|
@@ -7254,6 +7257,39 @@ function erfc(x) {
|
|
|
7254
7257
|
function stopGradient$1(x) {
|
|
7255
7258
|
return stopGradient(x);
|
|
7256
7259
|
}
|
|
7260
|
+
/**
|
|
7261
|
+
* Returns top `k` values and their indices along the specified axis of operand.
|
|
7262
|
+
*
|
|
7263
|
+
* This is a _stable_ algorithm: If two elements are equal, the lower-index
|
|
7264
|
+
* element appears first.
|
|
7265
|
+
*
|
|
7266
|
+
* @returns A tuple of `(values, indices)`, where `values` and `indices` have
|
|
7267
|
+
* the same shape as `x`, except along `axis` where they have size `k`.
|
|
7268
|
+
*/
|
|
7269
|
+
function topK(x, k, axis = -1) {
|
|
7270
|
+
x = fudgeArray(x);
|
|
7271
|
+
axis = require_backend.checkAxis(axis, x.ndim);
|
|
7272
|
+
const size$1 = x.shape[axis];
|
|
7273
|
+
if (k < 0 || k > size$1) throw new Error(`topK: k must be in the range [0, ${size$1}], got ${k}`);
|
|
7274
|
+
if (k === 0) {
|
|
7275
|
+
const outShape = x.shape.slice();
|
|
7276
|
+
outShape[axis] = 0;
|
|
7277
|
+
const y$1 = zerosLike$1(x.ref, { shape: outShape });
|
|
7278
|
+
const yi$1 = zerosLike$1(x, {
|
|
7279
|
+
dtype: require_backend.DType.Int32,
|
|
7280
|
+
shape: outShape
|
|
7281
|
+
});
|
|
7282
|
+
return [y$1, yi$1];
|
|
7283
|
+
}
|
|
7284
|
+
x = flip$1(x, [axis]);
|
|
7285
|
+
x = moveaxis(x, axis, -1);
|
|
7286
|
+
const [y, yi] = argsort$1(x);
|
|
7287
|
+
const extract = (a) => {
|
|
7288
|
+
a = a.slice(...require_backend.rep(a.ndim - 1, []), [-k]);
|
|
7289
|
+
return flip$1(moveaxis(a, -1, axis), [axis]);
|
|
7290
|
+
};
|
|
7291
|
+
return [extract(y), extract(yi.neg().add(size$1 - 1))];
|
|
7292
|
+
}
|
|
7257
7293
|
|
|
7258
7294
|
//#endregion
|
|
7259
7295
|
//#region src/library/nn.ts
|
|
@@ -7445,7 +7481,7 @@ const gelu = jit$1(function gelu$1(x, opts) {
|
|
|
7445
7481
|
if (opts?.approximate ?? true) {
|
|
7446
7482
|
const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
|
|
7447
7483
|
return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
|
|
7448
|
-
} else return x.ref.mul(.5).mul(erfc$1(negative(x.
|
|
7484
|
+
} else return x.ref.mul(.5).mul(erfc$1(negative(x.mul(Math.SQRT1_2))));
|
|
7449
7485
|
}, { staticArgnums: [1] });
|
|
7450
7486
|
/**
|
|
7451
7487
|
* Gated linear unit (GLU) activation function.
|
|
@@ -7703,6 +7739,7 @@ var random_exports = {};
|
|
|
7703
7739
|
__export(random_exports, {
|
|
7704
7740
|
bernoulli: () => bernoulli,
|
|
7705
7741
|
bits: () => bits,
|
|
7742
|
+
categorical: () => categorical,
|
|
7706
7743
|
cauchy: () => cauchy,
|
|
7707
7744
|
exponential: () => exponential,
|
|
7708
7745
|
gumbel: () => gumbel,
|
|
@@ -7774,6 +7811,47 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
|
|
|
7774
7811
|
}
|
|
7775
7812
|
/**
|
|
7776
7813
|
* @function
|
|
7814
|
+
* Sample random values from categorical distributions.
|
|
7815
|
+
*
|
|
7816
|
+
* Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
|
|
7817
|
+
* trick for sampling without replacement.
|
|
7818
|
+
*
|
|
7819
|
+
* Note: Sampling without replacement currently uses argsort and slices the last
|
|
7820
|
+
* k elements. This should be replaced with a more efficient topK implementation.
|
|
7821
|
+
*
|
|
7822
|
+
* - `key` - PRNG key
|
|
7823
|
+
* - `logits` - Unnormalized log probabilities of the categorical distribution(s).
|
|
7824
|
+
* `softmax(logits, axis)` gives the corresponding probabilities.
|
|
7825
|
+
* - `axis` - Axis along which logits belong to the same categorical distribution.
|
|
7826
|
+
* - `shape` - Result batch shape. Must be broadcast-compatible with
|
|
7827
|
+
* `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
|
|
7828
|
+
* - `replace` - If true (default), sample with replacement. If false, sample
|
|
7829
|
+
* without replacement (each category can only be selected once per batch).
|
|
7830
|
+
* @returns A random array with int dtype and shape given by `shape` if provided,
|
|
7831
|
+
* otherwise `logits.shape` with `axis` removed.
|
|
7832
|
+
*/
|
|
7833
|
+
const categorical = jit$1(function categorical$1(key$1, logits, { axis = -1, shape: shape$1, replace = true } = {}) {
|
|
7834
|
+
logits = fudgeArray(logits);
|
|
7835
|
+
axis = require_backend.checkAxis(axis, logits.ndim);
|
|
7836
|
+
const numCategories = logits.shape[axis];
|
|
7837
|
+
const batchShape = logits.shape.toSpliced(axis, 1);
|
|
7838
|
+
if (shape$1 === void 0) shape$1 = batchShape;
|
|
7839
|
+
else if (!require_backend.deepEqual(require_backend.generalBroadcast(shape$1, batchShape), shape$1)) throw new Error(`Shape ${shape$1} is not broadcast-compatible with batch shape ${batchShape}.`);
|
|
7840
|
+
const shapePrefix = shape$1.slice(0, shape$1.length - batchShape.length);
|
|
7841
|
+
if (replace) {
|
|
7842
|
+
const noise = gumbel(key$1, [...shapePrefix, ...logits.shape]);
|
|
7843
|
+
return argmax(noise.add(logits), axis + shapePrefix.length);
|
|
7844
|
+
} else {
|
|
7845
|
+
const k = shapePrefix.reduce((a, b) => a * b, 1);
|
|
7846
|
+
if (k > numCategories) throw new Error(`Number of samples without replacement (${k}) cannot exceed number of categories (${numCategories}).`);
|
|
7847
|
+
const noise = gumbel(key$1, logits.shape);
|
|
7848
|
+
const [values, indices] = topK(noise.add(logits), k, axis);
|
|
7849
|
+
values.dispose();
|
|
7850
|
+
return indices.reshape(shape$1);
|
|
7851
|
+
}
|
|
7852
|
+
}, { staticArgnums: [2] });
|
|
7853
|
+
/**
|
|
7854
|
+
* @function
|
|
7777
7855
|
* Sample from a Cauchy distribution with location 0 and scale 1.
|
|
7778
7856
|
*
|
|
7779
7857
|
* Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
|
package/dist/index.d.cts
CHANGED
|
@@ -436,9 +436,14 @@ declare class Routine {
|
|
|
436
436
|
}
|
|
437
437
|
/** One of the valid `Routine` that can be dispatched to backend. */
|
|
438
438
|
declare enum Routines {
|
|
439
|
-
/**
|
|
439
|
+
/**
|
|
440
|
+
* Sort along the last axis.
|
|
441
|
+
*
|
|
442
|
+
* This may be _unstable_ but it often doesn't matter, sorting numbers is
|
|
443
|
+
* bitwise unique up to signed zeros and NaNs.
|
|
444
|
+
*/
|
|
440
445
|
Sort = "Sort",
|
|
441
|
-
/**
|
|
446
|
+
/** Stable sorting, returns `int32` indices and values of the sorted array. */
|
|
442
447
|
Argsort = "Argsort",
|
|
443
448
|
/**
|
|
444
449
|
* Solve a triangular system of equations.
|
|
@@ -750,9 +755,9 @@ declare enum Primitive {
|
|
|
750
755
|
Shrink = "shrink",
|
|
751
756
|
Pad = "pad",
|
|
752
757
|
Sort = "sort",
|
|
753
|
-
// sort(x, axis=-1)
|
|
758
|
+
// sort(x, axis=-1), unstable
|
|
754
759
|
Argsort = "argsort",
|
|
755
|
-
// argsort(x, axis=-1)
|
|
760
|
+
// argsort(x, axis=-1), stable
|
|
756
761
|
TriangularSolve = "triangular_solve",
|
|
757
762
|
// A is upper triangular, A @ X.T = B.T
|
|
758
763
|
Cholesky = "cholesky",
|
|
@@ -1029,8 +1034,9 @@ declare abstract class Tracer {
|
|
|
1029
1034
|
*/
|
|
1030
1035
|
sort(axis?: number): this;
|
|
1031
1036
|
/**
|
|
1032
|
-
* Return the indices that would sort an array.
|
|
1033
|
-
* sorting algorithm; it
|
|
1037
|
+
* Return the indices that would sort an array. Unlike `sort`, this is
|
|
1038
|
+
* guaranteed to be a stable sorting algorithm; it always returns the smaller
|
|
1039
|
+
* index first in event of ties.
|
|
1034
1040
|
*
|
|
1035
1041
|
* See `jax.numpy.argsort` for full docs.
|
|
1036
1042
|
*/
|
|
@@ -1112,6 +1118,12 @@ type DTypeAndDevice = {
|
|
|
1112
1118
|
dtype?: DType;
|
|
1113
1119
|
device?: Device;
|
|
1114
1120
|
};
|
|
1121
|
+
/** @inline */
|
|
1122
|
+
type DTypeShapeAndDevice = {
|
|
1123
|
+
dtype?: DType;
|
|
1124
|
+
shape?: number[];
|
|
1125
|
+
device?: Device;
|
|
1126
|
+
};
|
|
1115
1127
|
type ArrayConstructorArgs = {
|
|
1116
1128
|
source: AluExp | Slot;
|
|
1117
1129
|
st: ShapeTracker;
|
|
@@ -1221,15 +1233,9 @@ declare function array(values: Array | DataArray | RecursiveArray<number> | Recu
|
|
|
1221
1233
|
type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
|
|
1222
1234
|
declare const implRules: { [P in Primitive]: ImplRule<P> };
|
|
1223
1235
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
1224
|
-
declare function zeros(shape: number[],
|
|
1225
|
-
dtype,
|
|
1226
|
-
device
|
|
1227
|
-
}?: DTypeAndDevice): Array;
|
|
1236
|
+
declare function zeros(shape: number[], opts?: DTypeAndDevice): Array;
|
|
1228
1237
|
/** Return a new array of given shape and type, filled with ones. */
|
|
1229
|
-
declare function ones(shape: number[],
|
|
1230
|
-
dtype,
|
|
1231
|
-
device
|
|
1232
|
-
}?: DTypeAndDevice): Array;
|
|
1238
|
+
declare function ones(shape: number[], opts?: DTypeAndDevice): Array;
|
|
1233
1239
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
1234
1240
|
declare function full(shape: number[], fillValue: number | boolean | Array, {
|
|
1235
1241
|
dtype,
|
|
@@ -1421,7 +1427,7 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
|
|
|
1421
1427
|
unitDiagonal?: boolean;
|
|
1422
1428
|
}): Array;
|
|
1423
1429
|
declare namespace lax_d_exports {
|
|
1424
|
-
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
|
|
1430
|
+
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
|
|
1425
1431
|
}
|
|
1426
1432
|
/**
|
|
1427
1433
|
* Dimension numbers for general `dot()` primitive.
|
|
@@ -1527,6 +1533,16 @@ declare function erfc(x: ArrayLike): Array;
|
|
|
1527
1533
|
* forward or reverse-mode automatic differentiation.
|
|
1528
1534
|
*/
|
|
1529
1535
|
declare function stopGradient(x: ArrayLike): Array;
|
|
1536
|
+
/**
|
|
1537
|
+
* Returns top `k` values and their indices along the specified axis of operand.
|
|
1538
|
+
*
|
|
1539
|
+
* This is a _stable_ algorithm: If two elements are equal, the lower-index
|
|
1540
|
+
* element appears first.
|
|
1541
|
+
*
|
|
1542
|
+
* @returns A tuple of `(values, indices)`, where `values` and `indices` have
|
|
1543
|
+
* the same shape as `x`, except along `axis` where they have size `k`.
|
|
1544
|
+
*/
|
|
1545
|
+
declare function topK(x: ArrayLike, k: number, axis?: number): [Array, Array];
|
|
1530
1546
|
declare namespace numpy_fft_d_exports {
|
|
1531
1547
|
export { ComplexPair, fft, ifft };
|
|
1532
1548
|
}
|
|
@@ -1752,17 +1768,17 @@ declare const shape$1: (x: ArrayLike) => number[];
|
|
|
1752
1768
|
* @function
|
|
1753
1769
|
* Return an array of zeros with the same shape and type as a given array.
|
|
1754
1770
|
*/
|
|
1755
|
-
declare const zerosLike: (a: ArrayLike,
|
|
1771
|
+
declare const zerosLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
|
|
1756
1772
|
/**
|
|
1757
1773
|
* @function
|
|
1758
1774
|
* Return an array of ones with the same shape and type as a given array.
|
|
1759
1775
|
*/
|
|
1760
|
-
declare const onesLike: (a: ArrayLike,
|
|
1776
|
+
declare const onesLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
|
|
1761
1777
|
/**
|
|
1762
1778
|
* @function
|
|
1763
1779
|
* Return a full array with the same shape and type as a given array.
|
|
1764
1780
|
*/
|
|
1765
|
-
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array,
|
|
1781
|
+
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, opts?: DTypeShapeAndDevice) => Array;
|
|
1766
1782
|
/**
|
|
1767
1783
|
* Return the number of elements in an array, optionally along an axis.
|
|
1768
1784
|
* Does not consume array reference.
|
|
@@ -1951,8 +1967,9 @@ declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: nu
|
|
|
1951
1967
|
*/
|
|
1952
1968
|
declare function sort(a: ArrayLike, axis?: number): Array;
|
|
1953
1969
|
/**
|
|
1954
|
-
* Return indices that would sort an array.
|
|
1955
|
-
* algorithm; it
|
|
1970
|
+
* Return indices that would sort an array. Unlike `sort`, this is guaranteed to
|
|
1971
|
+
* be a stable sorting algorithm; it always returns the smaller index first in
|
|
1972
|
+
* event of ties.
|
|
1956
1973
|
*
|
|
1957
1974
|
* Returns an array of `int32` indices.
|
|
1958
1975
|
*
|
|
@@ -2564,7 +2581,7 @@ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: Ar
|
|
|
2564
2581
|
localWindowSize?: number | [number, number];
|
|
2565
2582
|
}): Array;
|
|
2566
2583
|
declare namespace random_d_exports {
|
|
2567
|
-
export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
|
|
2584
|
+
export { bernoulli, bits, categorical, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
|
|
2568
2585
|
}
|
|
2569
2586
|
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
|
|
2570
2587
|
declare function key(seed: ArrayLike): Array;
|
|
@@ -2587,6 +2604,32 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
|
|
|
2587
2604
|
* and must be broadcastable to `shape`.
|
|
2588
2605
|
*/
|
|
2589
2606
|
declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
|
|
2607
|
+
/**
|
|
2608
|
+
* @function
|
|
2609
|
+
* Sample random values from categorical distributions.
|
|
2610
|
+
*
|
|
2611
|
+
* Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
|
|
2612
|
+
* trick for sampling without replacement.
|
|
2613
|
+
*
|
|
2614
|
+
* Note: Sampling without replacement currently uses argsort and slices the last
|
|
2615
|
+
* k elements. This should be replaced with a more efficient topK implementation.
|
|
2616
|
+
*
|
|
2617
|
+
* - `key` - PRNG key
|
|
2618
|
+
* - `logits` - Unnormalized log probabilities of the categorical distribution(s).
|
|
2619
|
+
* `softmax(logits, axis)` gives the corresponding probabilities.
|
|
2620
|
+
* - `axis` - Axis along which logits belong to the same categorical distribution.
|
|
2621
|
+
* - `shape` - Result batch shape. Must be broadcast-compatible with
|
|
2622
|
+
* `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
|
|
2623
|
+
* - `replace` - If true (default), sample with replacement. If false, sample
|
|
2624
|
+
* without replacement (each category can only be selected once per batch).
|
|
2625
|
+
* @returns A random array with int dtype and shape given by `shape` if provided,
|
|
2626
|
+
* otherwise `logits.shape` with `axis` removed.
|
|
2627
|
+
*/
|
|
2628
|
+
declare const categorical: OwnedFunction<(key: ArrayLike, logits: ArrayLike, args_2?: {
|
|
2629
|
+
axis?: number | undefined;
|
|
2630
|
+
shape?: number[] | undefined;
|
|
2631
|
+
replace?: boolean | undefined;
|
|
2632
|
+
} | undefined) => Array>;
|
|
2590
2633
|
/**
|
|
2591
2634
|
* @function
|
|
2592
2635
|
* Sample from a Cauchy distribution with location 0 and scale 1.
|
package/dist/index.d.ts
CHANGED
|
@@ -433,9 +433,14 @@ declare class Routine {
|
|
|
433
433
|
}
|
|
434
434
|
/** One of the valid `Routine` that can be dispatched to backend. */
|
|
435
435
|
declare enum Routines {
|
|
436
|
-
/**
|
|
436
|
+
/**
|
|
437
|
+
* Sort along the last axis.
|
|
438
|
+
*
|
|
439
|
+
* This may be _unstable_ but it often doesn't matter, sorting numbers is
|
|
440
|
+
* bitwise unique up to signed zeros and NaNs.
|
|
441
|
+
*/
|
|
437
442
|
Sort = "Sort",
|
|
438
|
-
/**
|
|
443
|
+
/** Stable sorting, returns `int32` indices and values of the sorted array. */
|
|
439
444
|
Argsort = "Argsort",
|
|
440
445
|
/**
|
|
441
446
|
* Solve a triangular system of equations.
|
|
@@ -747,9 +752,9 @@ declare enum Primitive {
|
|
|
747
752
|
Shrink = "shrink",
|
|
748
753
|
Pad = "pad",
|
|
749
754
|
Sort = "sort",
|
|
750
|
-
// sort(x, axis=-1)
|
|
755
|
+
// sort(x, axis=-1), unstable
|
|
751
756
|
Argsort = "argsort",
|
|
752
|
-
// argsort(x, axis=-1)
|
|
757
|
+
// argsort(x, axis=-1), stable
|
|
753
758
|
TriangularSolve = "triangular_solve",
|
|
754
759
|
// A is upper triangular, A @ X.T = B.T
|
|
755
760
|
Cholesky = "cholesky",
|
|
@@ -1026,8 +1031,9 @@ declare abstract class Tracer {
|
|
|
1026
1031
|
*/
|
|
1027
1032
|
sort(axis?: number): this;
|
|
1028
1033
|
/**
|
|
1029
|
-
* Return the indices that would sort an array.
|
|
1030
|
-
* sorting algorithm; it
|
|
1034
|
+
* Return the indices that would sort an array. Unlike `sort`, this is
|
|
1035
|
+
* guaranteed to be a stable sorting algorithm; it always returns the smaller
|
|
1036
|
+
* index first in event of ties.
|
|
1031
1037
|
*
|
|
1032
1038
|
* See `jax.numpy.argsort` for full docs.
|
|
1033
1039
|
*/
|
|
@@ -1109,6 +1115,12 @@ type DTypeAndDevice = {
|
|
|
1109
1115
|
dtype?: DType;
|
|
1110
1116
|
device?: Device;
|
|
1111
1117
|
};
|
|
1118
|
+
/** @inline */
|
|
1119
|
+
type DTypeShapeAndDevice = {
|
|
1120
|
+
dtype?: DType;
|
|
1121
|
+
shape?: number[];
|
|
1122
|
+
device?: Device;
|
|
1123
|
+
};
|
|
1112
1124
|
type ArrayConstructorArgs = {
|
|
1113
1125
|
source: AluExp | Slot;
|
|
1114
1126
|
st: ShapeTracker;
|
|
@@ -1218,15 +1230,9 @@ declare function array(values: Array | DataArray | RecursiveArray<number> | Recu
|
|
|
1218
1230
|
type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
|
|
1219
1231
|
declare const implRules: { [P in Primitive]: ImplRule<P> };
|
|
1220
1232
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
1221
|
-
declare function zeros(shape: number[], {
|
|
1222
|
-
dtype,
|
|
1223
|
-
device
|
|
1224
|
-
}?: DTypeAndDevice): Array;
|
|
1233
|
+
declare function zeros(shape: number[], opts?: DTypeAndDevice): Array;
|
|
1225
1234
|
/** Return a new array of given shape and type, filled with ones. */
|
|
1226
|
-
declare function ones(shape: number[], {
|
|
1227
|
-
dtype,
|
|
1228
|
-
device
|
|
1229
|
-
}?: DTypeAndDevice): Array;
|
|
1235
|
+
declare function ones(shape: number[], opts?: DTypeAndDevice): Array;
|
|
1230
1236
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
1231
1237
|
declare function full(shape: number[], fillValue: number | boolean | Array, {
|
|
1232
1238
|
dtype,
|
|
@@ -1418,7 +1424,7 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
|
|
|
1418
1424
|
unitDiagonal?: boolean;
|
|
1419
1425
|
}): Array;
|
|
1420
1426
|
declare namespace lax_d_exports {
|
|
1421
|
-
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
|
|
1427
|
+
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
|
|
1422
1428
|
}
|
|
1423
1429
|
/**
|
|
1424
1430
|
* Dimension numbers for general `dot()` primitive.
|
|
@@ -1524,6 +1530,16 @@ declare function erfc(x: ArrayLike): Array;
|
|
|
1524
1530
|
* forward or reverse-mode automatic differentiation.
|
|
1525
1531
|
*/
|
|
1526
1532
|
declare function stopGradient(x: ArrayLike): Array;
|
|
1533
|
+
/**
|
|
1534
|
+
* Returns top `k` values and their indices along the specified axis of operand.
|
|
1535
|
+
*
|
|
1536
|
+
* This is a _stable_ algorithm: If two elements are equal, the lower-index
|
|
1537
|
+
* element appears first.
|
|
1538
|
+
*
|
|
1539
|
+
* @returns A tuple of `(values, indices)`, where `values` and `indices` have
|
|
1540
|
+
* the same shape as `x`, except along `axis` where they have size `k`.
|
|
1541
|
+
*/
|
|
1542
|
+
declare function topK(x: ArrayLike, k: number, axis?: number): [Array, Array];
|
|
1527
1543
|
declare namespace numpy_fft_d_exports {
|
|
1528
1544
|
export { ComplexPair, fft, ifft };
|
|
1529
1545
|
}
|
|
@@ -1749,17 +1765,17 @@ declare const shape$1: (x: ArrayLike) => number[];
|
|
|
1749
1765
|
* @function
|
|
1750
1766
|
* Return an array of zeros with the same shape and type as a given array.
|
|
1751
1767
|
*/
|
|
1752
|
-
declare const zerosLike: (a: ArrayLike,
|
|
1768
|
+
declare const zerosLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
|
|
1753
1769
|
/**
|
|
1754
1770
|
* @function
|
|
1755
1771
|
* Return an array of ones with the same shape and type as a given array.
|
|
1756
1772
|
*/
|
|
1757
|
-
declare const onesLike: (a: ArrayLike,
|
|
1773
|
+
declare const onesLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
|
|
1758
1774
|
/**
|
|
1759
1775
|
* @function
|
|
1760
1776
|
* Return a full array with the same shape and type as a given array.
|
|
1761
1777
|
*/
|
|
1762
|
-
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array,
|
|
1778
|
+
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, opts?: DTypeShapeAndDevice) => Array;
|
|
1763
1779
|
/**
|
|
1764
1780
|
* Return the number of elements in an array, optionally along an axis.
|
|
1765
1781
|
* Does not consume array reference.
|
|
@@ -1948,8 +1964,9 @@ declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: nu
|
|
|
1948
1964
|
*/
|
|
1949
1965
|
declare function sort(a: ArrayLike, axis?: number): Array;
|
|
1950
1966
|
/**
|
|
1951
|
-
* Return indices that would sort an array.
|
|
1952
|
-
* algorithm; it
|
|
1967
|
+
* Return indices that would sort an array. Unlike `sort`, this is guaranteed to
|
|
1968
|
+
* be a stable sorting algorithm; it always returns the smaller index first in
|
|
1969
|
+
* event of ties.
|
|
1953
1970
|
*
|
|
1954
1971
|
* Returns an array of `int32` indices.
|
|
1955
1972
|
*
|
|
@@ -2561,7 +2578,7 @@ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: Ar
|
|
|
2561
2578
|
localWindowSize?: number | [number, number];
|
|
2562
2579
|
}): Array;
|
|
2563
2580
|
declare namespace random_d_exports {
|
|
2564
|
-
export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
|
|
2581
|
+
export { bernoulli, bits, categorical, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
|
|
2565
2582
|
}
|
|
2566
2583
|
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
|
|
2567
2584
|
declare function key(seed: ArrayLike): Array;
|
|
@@ -2584,6 +2601,32 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
|
|
|
2584
2601
|
* and must be broadcastable to `shape`.
|
|
2585
2602
|
*/
|
|
2586
2603
|
declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
|
|
2604
|
+
/**
|
|
2605
|
+
* @function
|
|
2606
|
+
* Sample random values from categorical distributions.
|
|
2607
|
+
*
|
|
2608
|
+
* Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
|
|
2609
|
+
* trick for sampling without replacement.
|
|
2610
|
+
*
|
|
2611
|
+
* Note: Sampling without replacement currently uses argsort and slices the last
|
|
2612
|
+
* k elements. This should be replaced with a more efficient topK implementation.
|
|
2613
|
+
*
|
|
2614
|
+
* - `key` - PRNG key
|
|
2615
|
+
* - `logits` - Unnormalized log probabilities of the categorical distribution(s).
|
|
2616
|
+
* `softmax(logits, axis)` gives the corresponding probabilities.
|
|
2617
|
+
* - `axis` - Axis along which logits belong to the same categorical distribution.
|
|
2618
|
+
* - `shape` - Result batch shape. Must be broadcast-compatible with
|
|
2619
|
+
* `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
|
|
2620
|
+
* - `replace` - If true (default), sample with replacement. If false, sample
|
|
2621
|
+
* without replacement (each category can only be selected once per batch).
|
|
2622
|
+
* @returns A random array with int dtype and shape given by `shape` if provided,
|
|
2623
|
+
* otherwise `logits.shape` with `axis` removed.
|
|
2624
|
+
*/
|
|
2625
|
+
declare const categorical: OwnedFunction<(key: ArrayLike, logits: ArrayLike, args_2?: {
|
|
2626
|
+
axis?: number | undefined;
|
|
2627
|
+
shape?: number[] | undefined;
|
|
2628
|
+
replace?: boolean | undefined;
|
|
2629
|
+
} | undefined) => Array>;
|
|
2587
2630
|
/**
|
|
2588
2631
|
* @function
|
|
2589
2632
|
* Sample from a Cauchy distribution with location 0 and scale 1.
|
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-nEolvdLv.js";
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-BId79r5b.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
@@ -889,18 +889,25 @@ var Tracer = class Tracer {
|
|
|
889
889
|
return sort$1(this.transpose(perm)).transpose(invertPermutation(perm));
|
|
890
890
|
}
|
|
891
891
|
/**
|
|
892
|
-
* Return the indices that would sort an array.
|
|
893
|
-
* sorting algorithm; it
|
|
892
|
+
* Return the indices that would sort an array. Unlike `sort`, this is
|
|
893
|
+
* guaranteed to be a stable sorting algorithm; it always returns the smaller
|
|
894
|
+
* index first in event of ties.
|
|
894
895
|
*
|
|
895
896
|
* See `jax.numpy.argsort` for full docs.
|
|
896
897
|
*/
|
|
897
898
|
argsort(axis = -1) {
|
|
898
899
|
axis = checkAxis(axis, this.ndim);
|
|
899
|
-
if (axis === this.ndim - 1)
|
|
900
|
+
if (axis === this.ndim - 1) {
|
|
901
|
+
const [y$1, yi$1] = argsort$1(this);
|
|
902
|
+
y$1.dispose();
|
|
903
|
+
return yi$1;
|
|
904
|
+
}
|
|
900
905
|
const perm = range(this.ndim);
|
|
901
906
|
perm.splice(axis, 1);
|
|
902
907
|
perm.push(axis);
|
|
903
|
-
|
|
908
|
+
const [y, yi] = argsort$1(this.transpose(perm));
|
|
909
|
+
y.dispose();
|
|
910
|
+
return yi.transpose(invertPermutation(perm));
|
|
904
911
|
}
|
|
905
912
|
/**
|
|
906
913
|
* Slice an array along one or more axes.
|
|
@@ -3381,32 +3388,26 @@ function fullInternal(aval, fillValue, device) {
|
|
|
3381
3388
|
committed: device != void 0
|
|
3382
3389
|
});
|
|
3383
3390
|
}
|
|
3384
|
-
function zerosLike$1(val,
|
|
3385
|
-
return fullLike(val, 0,
|
|
3391
|
+
function zerosLike$1(val, opts) {
|
|
3392
|
+
return fullLike(val, 0, opts);
|
|
3386
3393
|
}
|
|
3387
|
-
function onesLike$1(val,
|
|
3388
|
-
return fullLike(val, 1,
|
|
3394
|
+
function onesLike$1(val, opts) {
|
|
3395
|
+
return fullLike(val, 1, opts);
|
|
3389
3396
|
}
|
|
3390
|
-
function fullLike(val, fillValue, dtype) {
|
|
3397
|
+
function fullLike(val, fillValue, { dtype, shape: shape$1, device } = {}) {
|
|
3391
3398
|
const aval = getAval(val);
|
|
3392
3399
|
if (val instanceof Tracer) val.dispose();
|
|
3393
3400
|
if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
|
|
3394
|
-
const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
|
|
3395
|
-
return fullInternal(sa, fillValue);
|
|
3401
|
+
const sa = new ShapedArray(shape$1 ?? aval.shape, dtype ?? aval.dtype, aval.weakType && dtype === void 0);
|
|
3402
|
+
return fullInternal(sa, fillValue, device);
|
|
3396
3403
|
}
|
|
3397
3404
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
3398
|
-
function zeros(shape$1, { dtype, device } = {}) {
|
|
3399
|
-
return full(shape$1, 0, {
|
|
3400
|
-
dtype,
|
|
3401
|
-
device
|
|
3402
|
-
});
|
|
3405
|
+
function zeros(shape$1, opts) {
|
|
3406
|
+
return full(shape$1, 0, opts);
|
|
3403
3407
|
}
|
|
3404
3408
|
/** Return a new array of given shape and type, filled with ones. */
|
|
3405
|
-
function ones(shape$1, { dtype, device } = {}) {
|
|
3406
|
-
return full(shape$1, 1, {
|
|
3407
|
-
dtype,
|
|
3408
|
-
device
|
|
3409
|
-
});
|
|
3409
|
+
function ones(shape$1, opts) {
|
|
3410
|
+
return full(shape$1, 1, opts);
|
|
3410
3411
|
}
|
|
3411
3412
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
3412
3413
|
function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
@@ -5295,7 +5296,7 @@ function lstsq(a, b) {
|
|
|
5295
5296
|
lower: true,
|
|
5296
5297
|
transposeA: true
|
|
5297
5298
|
});
|
|
5298
|
-
return matmul(at, llb
|
|
5299
|
+
return matmul(at, llb);
|
|
5299
5300
|
} else {
|
|
5300
5301
|
const ata = matmul(at.ref, a);
|
|
5301
5302
|
const l = cholesky(ata, { symmetrizeInput: false });
|
|
@@ -5386,7 +5387,7 @@ function solve(a, b) {
|
|
|
5386
5387
|
lower: true,
|
|
5387
5388
|
unitDiagonal: true
|
|
5388
5389
|
});
|
|
5389
|
-
let x = triangularSolve(lu$2, LPb
|
|
5390
|
+
let x = triangularSolve(lu$2, LPb, {
|
|
5390
5391
|
leftSide: true,
|
|
5391
5392
|
lower: false
|
|
5392
5393
|
});
|
|
@@ -6197,8 +6198,9 @@ function sort(a, axis = -1) {
|
|
|
6197
6198
|
return fudgeArray(a).sort(axis);
|
|
6198
6199
|
}
|
|
6199
6200
|
/**
|
|
6200
|
-
* Return indices that would sort an array.
|
|
6201
|
-
* algorithm; it
|
|
6201
|
+
* Return indices that would sort an array. Unlike `sort`, this is guaranteed to
|
|
6202
|
+
* be a stable sorting algorithm; it always returns the smaller index first in
|
|
6203
|
+
* event of ties.
|
|
6202
6204
|
*
|
|
6203
6205
|
* Returns an array of `int32` indices.
|
|
6204
6206
|
*
|
|
@@ -6500,7 +6502,7 @@ function absolute(x) {
|
|
|
6500
6502
|
/** Return an element-wise indication of sign of the input. */
|
|
6501
6503
|
function sign(x) {
|
|
6502
6504
|
x = fudgeArray(x);
|
|
6503
|
-
return where(notEqual(x.ref, 0), where(less(x
|
|
6505
|
+
return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
|
|
6504
6506
|
}
|
|
6505
6507
|
/** @function Return element-wise positive values of the input (no-op). */
|
|
6506
6508
|
const positive = fudgeArray;
|
|
@@ -6993,7 +6995,8 @@ __export(lax_exports, {
|
|
|
6993
6995
|
erfc: () => erfc,
|
|
6994
6996
|
linalg: () => lax_linalg_exports,
|
|
6995
6997
|
reduceWindow: () => reduceWindow,
|
|
6996
|
-
stopGradient: () => stopGradient$1
|
|
6998
|
+
stopGradient: () => stopGradient$1,
|
|
6999
|
+
topK: () => topK
|
|
6997
7000
|
});
|
|
6998
7001
|
const JsArray = globalThis.Array;
|
|
6999
7002
|
/**
|
|
@@ -7217,6 +7220,39 @@ function erfc(x) {
|
|
|
7217
7220
|
function stopGradient$1(x) {
|
|
7218
7221
|
return stopGradient(x);
|
|
7219
7222
|
}
|
|
7223
|
+
/**
|
|
7224
|
+
* Returns top `k` values and their indices along the specified axis of operand.
|
|
7225
|
+
*
|
|
7226
|
+
* This is a _stable_ algorithm: If two elements are equal, the lower-index
|
|
7227
|
+
* element appears first.
|
|
7228
|
+
*
|
|
7229
|
+
* @returns A tuple of `(values, indices)`, where `values` and `indices` have
|
|
7230
|
+
* the same shape as `x`, except along `axis` where they have size `k`.
|
|
7231
|
+
*/
|
|
7232
|
+
function topK(x, k, axis = -1) {
|
|
7233
|
+
x = fudgeArray(x);
|
|
7234
|
+
axis = checkAxis(axis, x.ndim);
|
|
7235
|
+
const size$1 = x.shape[axis];
|
|
7236
|
+
if (k < 0 || k > size$1) throw new Error(`topK: k must be in the range [0, ${size$1}], got ${k}`);
|
|
7237
|
+
if (k === 0) {
|
|
7238
|
+
const outShape = x.shape.slice();
|
|
7239
|
+
outShape[axis] = 0;
|
|
7240
|
+
const y$1 = zerosLike$1(x.ref, { shape: outShape });
|
|
7241
|
+
const yi$1 = zerosLike$1(x, {
|
|
7242
|
+
dtype: DType.Int32,
|
|
7243
|
+
shape: outShape
|
|
7244
|
+
});
|
|
7245
|
+
return [y$1, yi$1];
|
|
7246
|
+
}
|
|
7247
|
+
x = flip$1(x, [axis]);
|
|
7248
|
+
x = moveaxis(x, axis, -1);
|
|
7249
|
+
const [y, yi] = argsort$1(x);
|
|
7250
|
+
const extract = (a) => {
|
|
7251
|
+
a = a.slice(...rep(a.ndim - 1, []), [-k]);
|
|
7252
|
+
return flip$1(moveaxis(a, -1, axis), [axis]);
|
|
7253
|
+
};
|
|
7254
|
+
return [extract(y), extract(yi.neg().add(size$1 - 1))];
|
|
7255
|
+
}
|
|
7220
7256
|
|
|
7221
7257
|
//#endregion
|
|
7222
7258
|
//#region src/library/nn.ts
|
|
@@ -7408,7 +7444,7 @@ const gelu = jit$1(function gelu$1(x, opts) {
|
|
|
7408
7444
|
if (opts?.approximate ?? true) {
|
|
7409
7445
|
const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
|
|
7410
7446
|
return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
|
|
7411
|
-
} else return x.ref.mul(.5).mul(erfc$1(negative(x.
|
|
7447
|
+
} else return x.ref.mul(.5).mul(erfc$1(negative(x.mul(Math.SQRT1_2))));
|
|
7412
7448
|
}, { staticArgnums: [1] });
|
|
7413
7449
|
/**
|
|
7414
7450
|
* Gated linear unit (GLU) activation function.
|
|
@@ -7666,6 +7702,7 @@ var random_exports = {};
|
|
|
7666
7702
|
__export(random_exports, {
|
|
7667
7703
|
bernoulli: () => bernoulli,
|
|
7668
7704
|
bits: () => bits,
|
|
7705
|
+
categorical: () => categorical,
|
|
7669
7706
|
cauchy: () => cauchy,
|
|
7670
7707
|
exponential: () => exponential,
|
|
7671
7708
|
gumbel: () => gumbel,
|
|
@@ -7737,6 +7774,47 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
|
|
|
7737
7774
|
}
|
|
7738
7775
|
/**
|
|
7739
7776
|
* @function
|
|
7777
|
+
* Sample random values from categorical distributions.
|
|
7778
|
+
*
|
|
7779
|
+
* Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
|
|
7780
|
+
* trick for sampling without replacement.
|
|
7781
|
+
*
|
|
7782
|
+
* Note: Sampling without replacement currently uses argsort and slices the last
|
|
7783
|
+
* k elements. This should be replaced with a more efficient topK implementation.
|
|
7784
|
+
*
|
|
7785
|
+
* - `key` - PRNG key
|
|
7786
|
+
* - `logits` - Unnormalized log probabilities of the categorical distribution(s).
|
|
7787
|
+
* `softmax(logits, axis)` gives the corresponding probabilities.
|
|
7788
|
+
* - `axis` - Axis along which logits belong to the same categorical distribution.
|
|
7789
|
+
* - `shape` - Result batch shape. Must be broadcast-compatible with
|
|
7790
|
+
* `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
|
|
7791
|
+
* - `replace` - If true (default), sample with replacement. If false, sample
|
|
7792
|
+
* without replacement (each category can only be selected once per batch).
|
|
7793
|
+
* @returns A random array with int dtype and shape given by `shape` if provided,
|
|
7794
|
+
* otherwise `logits.shape` with `axis` removed.
|
|
7795
|
+
*/
|
|
7796
|
+
const categorical = jit$1(function categorical$1(key$1, logits, { axis = -1, shape: shape$1, replace = true } = {}) {
|
|
7797
|
+
logits = fudgeArray(logits);
|
|
7798
|
+
axis = checkAxis(axis, logits.ndim);
|
|
7799
|
+
const numCategories = logits.shape[axis];
|
|
7800
|
+
const batchShape = logits.shape.toSpliced(axis, 1);
|
|
7801
|
+
if (shape$1 === void 0) shape$1 = batchShape;
|
|
7802
|
+
else if (!deepEqual(generalBroadcast(shape$1, batchShape), shape$1)) throw new Error(`Shape ${shape$1} is not broadcast-compatible with batch shape ${batchShape}.`);
|
|
7803
|
+
const shapePrefix = shape$1.slice(0, shape$1.length - batchShape.length);
|
|
7804
|
+
if (replace) {
|
|
7805
|
+
const noise = gumbel(key$1, [...shapePrefix, ...logits.shape]);
|
|
7806
|
+
return argmax(noise.add(logits), axis + shapePrefix.length);
|
|
7807
|
+
} else {
|
|
7808
|
+
const k = shapePrefix.reduce((a, b) => a * b, 1);
|
|
7809
|
+
if (k > numCategories) throw new Error(`Number of samples without replacement (${k}) cannot exceed number of categories (${numCategories}).`);
|
|
7810
|
+
const noise = gumbel(key$1, logits.shape);
|
|
7811
|
+
const [values, indices] = topK(noise.add(logits), k, axis);
|
|
7812
|
+
values.dispose();
|
|
7813
|
+
return indices.reshape(shape$1);
|
|
7814
|
+
}
|
|
7815
|
+
}, { staticArgnums: [2] });
|
|
7816
|
+
/**
|
|
7817
|
+
* @function
|
|
7740
7818
|
* Sample from a Cauchy distribution with location 0 and scale 1.
|
|
7741
7819
|
*
|
|
7742
7820
|
* Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-nEolvdLv.js";
|
|
1
|
+
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-BId79r5b.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgl/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-nEolvdLv.js";
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-BId79r5b.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -247,6 +247,10 @@ function bitonicSortUniform(pass) {
|
|
|
247
247
|
* `2^(step+1)` with multiple workgroups. This doesn't use shared memory.
|
|
248
248
|
*
|
|
249
249
|
* The total number of passes is roughly `log2(n / workgroupSize)^2 / 2`.
|
|
250
|
+
*
|
|
251
|
+
* If `outputIndices` is true, the shader also tracks the original indices of
|
|
252
|
+
* the sorted elements (argsort) and outputs them to a separate buffer. This
|
|
253
|
+
* also makes the sorting algorithm stable.
|
|
250
254
|
*/
|
|
251
255
|
function bitonicSortShader(device, dtype, n, batches, outputIndices) {
|
|
252
256
|
const ty = dtypeToWgsl(dtype, true);
|
|
@@ -286,14 +290,21 @@ ${isFloatDtype(dtype) ? `
|
|
|
286
290
|
fn compare_and_swap(i: u32, j: u32) {
|
|
287
291
|
let val_i = shared_vals[i];
|
|
288
292
|
let val_j = shared_vals[j];
|
|
289
|
-
|
|
293
|
+
${outputIndices ? `
|
|
294
|
+
if (
|
|
295
|
+
compare(val_j, val_i) ||
|
|
296
|
+
(!compare(val_i, val_j) && shared_idx[j] < shared_idx[i])
|
|
297
|
+
) {
|
|
290
298
|
shared_vals[i] = val_j;
|
|
291
299
|
shared_vals[j] = val_i;
|
|
292
|
-
${outputIndices ? `
|
|
293
300
|
let tmp_idx = shared_idx[i];
|
|
294
301
|
shared_idx[i] = shared_idx[j];
|
|
295
|
-
shared_idx[j] = tmp_idx
|
|
296
|
-
}
|
|
302
|
+
shared_idx[j] = tmp_idx;
|
|
303
|
+
}` : `
|
|
304
|
+
if (compare(val_j, val_i)) {
|
|
305
|
+
shared_vals[i] = val_j;
|
|
306
|
+
shared_vals[j] = val_i;
|
|
307
|
+
}`}
|
|
297
308
|
}
|
|
298
309
|
|
|
299
310
|
@compute @workgroup_size(${workgroupSize})
|
|
@@ -370,13 +381,17 @@ ${outputIndices ? `
|
|
|
370
381
|
if (j < ${n}u) {
|
|
371
382
|
let val_i = output[base + i];
|
|
372
383
|
let val_j = output[base + j];
|
|
373
|
-
|
|
384
|
+
${outputIndices ? `
|
|
385
|
+
let idx_i = output_idx[base + i];
|
|
386
|
+
let idx_j = output_idx[base + j];
|
|
387
|
+
if (compare(val_j, val_i) || (!compare(val_i, val_j) && idx_j < idx_i)) {
|
|
374
388
|
output[base + i] = val_j;
|
|
375
389
|
output[base + j] = val_i;
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
390
|
+
output_idx[base + i] = idx_j;
|
|
391
|
+
output_idx[base + j] = idx_i;` : `
|
|
392
|
+
if (compare(val_j, val_i)) {
|
|
393
|
+
output[base + i] = val_j;
|
|
394
|
+
output[base + j] = val_i;`}
|
|
380
395
|
}
|
|
381
396
|
}
|
|
382
397
|
}
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-B3foXiV_.cjs');
|
|
1
|
+
const require_backend = require('./backend-DpI0riom.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -247,6 +247,10 @@ function bitonicSortUniform(pass) {
|
|
|
247
247
|
* `2^(step+1)` with multiple workgroups. This doesn't use shared memory.
|
|
248
248
|
*
|
|
249
249
|
* The total number of passes is roughly `log2(n / workgroupSize)^2 / 2`.
|
|
250
|
+
*
|
|
251
|
+
* If `outputIndices` is true, the shader also tracks the original indices of
|
|
252
|
+
* the sorted elements (argsort) and outputs them to a separate buffer. This
|
|
253
|
+
* also makes the sorting algorithm stable.
|
|
250
254
|
*/
|
|
251
255
|
function bitonicSortShader(device, dtype, n, batches, outputIndices) {
|
|
252
256
|
const ty = dtypeToWgsl(dtype, true);
|
|
@@ -286,14 +290,21 @@ ${require_backend.isFloatDtype(dtype) ? `
|
|
|
286
290
|
fn compare_and_swap(i: u32, j: u32) {
|
|
287
291
|
let val_i = shared_vals[i];
|
|
288
292
|
let val_j = shared_vals[j];
|
|
289
|
-
|
|
293
|
+
${outputIndices ? `
|
|
294
|
+
if (
|
|
295
|
+
compare(val_j, val_i) ||
|
|
296
|
+
(!compare(val_i, val_j) && shared_idx[j] < shared_idx[i])
|
|
297
|
+
) {
|
|
290
298
|
shared_vals[i] = val_j;
|
|
291
299
|
shared_vals[j] = val_i;
|
|
292
|
-
${outputIndices ? `
|
|
293
300
|
let tmp_idx = shared_idx[i];
|
|
294
301
|
shared_idx[i] = shared_idx[j];
|
|
295
|
-
shared_idx[j] = tmp_idx
|
|
296
|
-
}
|
|
302
|
+
shared_idx[j] = tmp_idx;
|
|
303
|
+
}` : `
|
|
304
|
+
if (compare(val_j, val_i)) {
|
|
305
|
+
shared_vals[i] = val_j;
|
|
306
|
+
shared_vals[j] = val_i;
|
|
307
|
+
}`}
|
|
297
308
|
}
|
|
298
309
|
|
|
299
310
|
@compute @workgroup_size(${workgroupSize})
|
|
@@ -370,13 +381,17 @@ ${outputIndices ? `
|
|
|
370
381
|
if (j < ${n}u) {
|
|
371
382
|
let val_i = output[base + i];
|
|
372
383
|
let val_j = output[base + j];
|
|
373
|
-
|
|
384
|
+
${outputIndices ? `
|
|
385
|
+
let idx_i = output_idx[base + i];
|
|
386
|
+
let idx_j = output_idx[base + j];
|
|
387
|
+
if (compare(val_j, val_i) || (!compare(val_i, val_j) && idx_j < idx_i)) {
|
|
374
388
|
output[base + i] = val_j;
|
|
375
389
|
output[base + j] = val_i;
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
390
|
+
output_idx[base + i] = idx_j;
|
|
391
|
+
output_idx[base + j] = idx_i;` : `
|
|
392
|
+
if (compare(val_j, val_i)) {
|
|
393
|
+
output[base + i] = val_j;
|
|
394
|
+
output[base + j] = val_i;`}
|
|
380
395
|
}
|
|
381
396
|
}
|
|
382
397
|
}
|