@jax-js/jax 0.1.7 → 0.1.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +30 -13
- package/dist/{backend-nEolvdLv.js → backend-BId79r5b.js} +17 -6
- package/dist/{backend-B3foXiV_.cjs → backend-DpI0riom.cjs} +17 -6
- package/dist/index.cjs +113 -30
- package/dist/index.d.cts +64 -21
- package/dist/index.d.ts +64 -21
- package/dist/index.js +113 -30
- package/dist/{webgl-DIIbKJ0G.cjs → webgl-C5NjXc1p.cjs} +1 -1
- package/dist/{webgl-DweKSWEm.js → webgl-DnGrclTz.js} +1 -1
- package/dist/{webgpu-B96vzWGE.js → webgpu-AN0cG_nB.js} +25 -10
- package/dist/{webgpu-BykvF26B.cjs → webgpu-CdjiJSa7.cjs} +25 -10
- package/package.json +12 -1
package/README.md
CHANGED
|
@@ -43,6 +43,23 @@ way to get started on a blank HTML page.
|
|
|
43
43
|
</script>
|
|
44
44
|
```
|
|
45
45
|
|
|
46
|
+
## Examples
|
|
47
|
+
|
|
48
|
+
Cool things that the community has made with jax-js:
|
|
49
|
+
|
|
50
|
+
- [**tanh.xyz**: Interactive ML visualizations](https://tanh.xyz/)
|
|
51
|
+
|
|
52
|
+
And some more demos on the official website.
|
|
53
|
+
|
|
54
|
+
- [Training neural networks on MNIST](https://jax-js.com/mnist)
|
|
55
|
+
- [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
|
|
56
|
+
- [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
|
|
57
|
+
- [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
|
|
58
|
+
- [In-browser REPL](https://jax-js.com/repl)
|
|
59
|
+
- [Matmul benchmark](https://jax-js.com/bench/matmul)
|
|
60
|
+
- [Conv2d benchmark](https://jax-js.com/bench/conv2d)
|
|
61
|
+
- [Mandelbrot set](https://jax-js.com/mandelbrot)
|
|
62
|
+
|
|
46
63
|
## Feature comparison
|
|
47
64
|
|
|
48
65
|
Here's a quick, high-level comparison with other popular web ML runtimes:
|
|
@@ -338,19 +355,6 @@ well as unique optimizations such as FlashAttention variants.
|
|
|
338
355
|
That's all for this short tutorial. Please see the generated
|
|
339
356
|
[API reference](https://jax-js.com/docs) for detailed documentation.
|
|
340
357
|
|
|
341
|
-
## Examples
|
|
342
|
-
|
|
343
|
-
If you make something cool with jax-js, don't be a stranger! We can feature it here.
|
|
344
|
-
|
|
345
|
-
- [Training neural networks on MNIST](https://jax-js.com/mnist)
|
|
346
|
-
- [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
|
|
347
|
-
- [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
|
|
348
|
-
- [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
|
|
349
|
-
- [In-browser REPL](https://jax-js.com/repl)
|
|
350
|
-
- [Matmul benchmark](https://jax-js.com/bench/matmul)
|
|
351
|
-
- [Conv2d benchmark](https://jax-js.com/bench/conv2d)
|
|
352
|
-
- [Mandelbrot set](https://jax-js.com/mandelbrot)
|
|
353
|
-
|
|
354
358
|
## Development
|
|
355
359
|
|
|
356
360
|
_The following technical details are for contributing to jax-js and modifying its internals._
|
|
@@ -363,6 +367,19 @@ pnpm install
|
|
|
363
367
|
pnpm run build:watch
|
|
364
368
|
```
|
|
365
369
|
|
|
370
|
+
The `pnpm install` command automatically sets up Git hooks via
|
|
371
|
+
[Husky](https://typicode.github.io/husky/). Pre-commit hooks will run ESLint and Prettier on staged
|
|
372
|
+
files to ensure code quality.
|
|
373
|
+
|
|
374
|
+
You can also run linting and formatting manually:
|
|
375
|
+
|
|
376
|
+
```bash
|
|
377
|
+
pnpm lint # Run ESLint
|
|
378
|
+
pnpm format # Format all files with Prettier
|
|
379
|
+
pnpm format:check # Check formatting without writing
|
|
380
|
+
pnpm check # Run TypeScript type checking
|
|
381
|
+
```
|
|
382
|
+
|
|
366
383
|
Then you can run tests in a headless browser using [Vitest](https://vitest.dev/).
|
|
367
384
|
|
|
368
385
|
```bash
|
|
@@ -1479,9 +1479,14 @@ var Routine = class {
|
|
|
1479
1479
|
};
|
|
1480
1480
|
/** One of the valid `Routine` that can be dispatched to backend. */
|
|
1481
1481
|
let Routines = /* @__PURE__ */ function(Routines$1) {
|
|
1482
|
-
/**
|
|
1482
|
+
/**
|
|
1483
|
+
* Sort along the last axis.
|
|
1484
|
+
*
|
|
1485
|
+
* This may be _unstable_ but it often doesn't matter, sorting numbers is
|
|
1486
|
+
* bitwise unique up to signed zeros and NaNs.
|
|
1487
|
+
*/
|
|
1483
1488
|
Routines$1["Sort"] = "Sort";
|
|
1484
|
-
/**
|
|
1489
|
+
/** Stable sorting, returns `int32` indices and values of the sorted array. */
|
|
1485
1490
|
Routines$1["Argsort"] = "Argsort";
|
|
1486
1491
|
/**
|
|
1487
1492
|
* Solve a triangular system of equations.
|
|
@@ -1545,7 +1550,13 @@ function runArgsort(type, [x], [y, yi]) {
|
|
|
1545
1550
|
const out = y.subarray(offset, offset + n);
|
|
1546
1551
|
const outi = yi.subarray(offset, offset + n);
|
|
1547
1552
|
for (let i = 0; i < n; i++) outi[i] = i;
|
|
1548
|
-
outi.sort((a, b) =>
|
|
1553
|
+
outi.sort((a, b) => {
|
|
1554
|
+
const x$1 = ar[a];
|
|
1555
|
+
const y$1 = ar[b];
|
|
1556
|
+
if (isNaN(x$1)) return isNaN(y$1) ? 0 : 1;
|
|
1557
|
+
if (isNaN(y$1)) return -1;
|
|
1558
|
+
return x$1 === y$1 ? 0 : x$1 < y$1 ? -1 : 1;
|
|
1559
|
+
});
|
|
1549
1560
|
for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
|
|
1550
1561
|
}
|
|
1551
1562
|
}
|
|
@@ -2321,7 +2332,7 @@ function tuneWebgpu(kernel) {
|
|
|
2321
2332
|
if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
|
|
2322
2333
|
const s = dim.st.shape[dim.unroll - 1];
|
|
2323
2334
|
if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
2324
|
-
else for (const splits of [
|
|
2335
|
+
else for (const splits of [4, 2]) if (s % splits === 0) {
|
|
2325
2336
|
dim.applyUnroll(dim.unroll - 1, splits);
|
|
2326
2337
|
break;
|
|
2327
2338
|
}
|
|
@@ -4252,7 +4263,7 @@ async function createBackend(device) {
|
|
|
4252
4263
|
if (!navigator.gpu) return null;
|
|
4253
4264
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4254
4265
|
if (!adapter) return null;
|
|
4255
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
4266
|
+
const { WebGPUBackend } = await import("./webgpu-AN0cG_nB.js");
|
|
4256
4267
|
const importantLimits = [
|
|
4257
4268
|
"maxBufferSize",
|
|
4258
4269
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4290,7 +4301,7 @@ async function createBackend(device) {
|
|
|
4290
4301
|
});
|
|
4291
4302
|
if (!gl) return null;
|
|
4292
4303
|
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
4293
|
-
const { WebGLBackend } = await import("./webgl-
|
|
4304
|
+
const { WebGLBackend } = await import("./webgl-DnGrclTz.js");
|
|
4294
4305
|
return new WebGLBackend(gl);
|
|
4295
4306
|
} else throw new Error(`Backend not found: ${device}`);
|
|
4296
4307
|
}
|
|
@@ -1480,9 +1480,14 @@ var Routine = class {
|
|
|
1480
1480
|
};
|
|
1481
1481
|
/** One of the valid `Routine` that can be dispatched to backend. */
|
|
1482
1482
|
let Routines = /* @__PURE__ */ function(Routines$1) {
|
|
1483
|
-
/**
|
|
1483
|
+
/**
|
|
1484
|
+
* Sort along the last axis.
|
|
1485
|
+
*
|
|
1486
|
+
* This may be _unstable_ but it often doesn't matter, sorting numbers is
|
|
1487
|
+
* bitwise unique up to signed zeros and NaNs.
|
|
1488
|
+
*/
|
|
1484
1489
|
Routines$1["Sort"] = "Sort";
|
|
1485
|
-
/**
|
|
1490
|
+
/** Stable sorting, returns `int32` indices and values of the sorted array. */
|
|
1486
1491
|
Routines$1["Argsort"] = "Argsort";
|
|
1487
1492
|
/**
|
|
1488
1493
|
* Solve a triangular system of equations.
|
|
@@ -1546,7 +1551,13 @@ function runArgsort(type, [x], [y, yi]) {
|
|
|
1546
1551
|
const out = y.subarray(offset, offset + n);
|
|
1547
1552
|
const outi = yi.subarray(offset, offset + n);
|
|
1548
1553
|
for (let i = 0; i < n; i++) outi[i] = i;
|
|
1549
|
-
outi.sort((a, b) =>
|
|
1554
|
+
outi.sort((a, b) => {
|
|
1555
|
+
const x$1 = ar[a];
|
|
1556
|
+
const y$1 = ar[b];
|
|
1557
|
+
if (isNaN(x$1)) return isNaN(y$1) ? 0 : 1;
|
|
1558
|
+
if (isNaN(y$1)) return -1;
|
|
1559
|
+
return x$1 === y$1 ? 0 : x$1 < y$1 ? -1 : 1;
|
|
1560
|
+
});
|
|
1550
1561
|
for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
|
|
1551
1562
|
}
|
|
1552
1563
|
}
|
|
@@ -2322,7 +2333,7 @@ function tuneWebgpu(kernel) {
|
|
|
2322
2333
|
if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
|
|
2323
2334
|
const s = dim.st.shape[dim.unroll - 1];
|
|
2324
2335
|
if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
2325
|
-
else for (const splits of [
|
|
2336
|
+
else for (const splits of [4, 2]) if (s % splits === 0) {
|
|
2326
2337
|
dim.applyUnroll(dim.unroll - 1, splits);
|
|
2327
2338
|
break;
|
|
2328
2339
|
}
|
|
@@ -4253,7 +4264,7 @@ async function createBackend(device) {
|
|
|
4253
4264
|
if (!navigator.gpu) return null;
|
|
4254
4265
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4255
4266
|
if (!adapter) return null;
|
|
4256
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
4267
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-CdjiJSa7.cjs"));
|
|
4257
4268
|
const importantLimits = [
|
|
4258
4269
|
"maxBufferSize",
|
|
4259
4270
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4291,7 +4302,7 @@ async function createBackend(device) {
|
|
|
4291
4302
|
});
|
|
4292
4303
|
if (!gl) return null;
|
|
4293
4304
|
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
4294
|
-
const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-
|
|
4305
|
+
const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-C5NjXc1p.cjs"));
|
|
4295
4306
|
return new WebGLBackend(gl);
|
|
4296
4307
|
} else throw new Error(`Backend not found: ${device}`);
|
|
4297
4308
|
}
|
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-DpI0riom.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
@@ -920,18 +920,25 @@ var Tracer = class Tracer {
|
|
|
920
920
|
return sort$1(this.transpose(perm)).transpose(require_backend.invertPermutation(perm));
|
|
921
921
|
}
|
|
922
922
|
/**
|
|
923
|
-
* Return the indices that would sort an array.
|
|
924
|
-
* sorting algorithm; it
|
|
923
|
+
* Return the indices that would sort an array. Unlike `sort`, this is
|
|
924
|
+
* guaranteed to be a stable sorting algorithm; it always returns the smaller
|
|
925
|
+
* index first in event of ties.
|
|
925
926
|
*
|
|
926
927
|
* See `jax.numpy.argsort` for full docs.
|
|
927
928
|
*/
|
|
928
929
|
argsort(axis = -1) {
|
|
929
930
|
axis = require_backend.checkAxis(axis, this.ndim);
|
|
930
|
-
if (axis === this.ndim - 1)
|
|
931
|
+
if (axis === this.ndim - 1) {
|
|
932
|
+
const [y$1, yi$1] = argsort$1(this);
|
|
933
|
+
y$1.dispose();
|
|
934
|
+
return yi$1;
|
|
935
|
+
}
|
|
931
936
|
const perm = require_backend.range(this.ndim);
|
|
932
937
|
perm.splice(axis, 1);
|
|
933
938
|
perm.push(axis);
|
|
934
|
-
|
|
939
|
+
const [y, yi] = argsort$1(this.transpose(perm));
|
|
940
|
+
y.dispose();
|
|
941
|
+
return yi.transpose(require_backend.invertPermutation(perm));
|
|
935
942
|
}
|
|
936
943
|
/**
|
|
937
944
|
* Slice an array along one or more axes.
|
|
@@ -3416,32 +3423,26 @@ function fullInternal(aval, fillValue, device) {
|
|
|
3416
3423
|
committed: device != void 0
|
|
3417
3424
|
});
|
|
3418
3425
|
}
|
|
3419
|
-
function zerosLike$1(val,
|
|
3420
|
-
return fullLike(val, 0,
|
|
3426
|
+
function zerosLike$1(val, opts) {
|
|
3427
|
+
return fullLike(val, 0, opts);
|
|
3421
3428
|
}
|
|
3422
|
-
function onesLike$1(val,
|
|
3423
|
-
return fullLike(val, 1,
|
|
3429
|
+
function onesLike$1(val, opts) {
|
|
3430
|
+
return fullLike(val, 1, opts);
|
|
3424
3431
|
}
|
|
3425
|
-
function fullLike(val, fillValue, dtype) {
|
|
3432
|
+
function fullLike(val, fillValue, { dtype, shape: shape$1, device } = {}) {
|
|
3426
3433
|
const aval = getAval(val);
|
|
3427
3434
|
if (val instanceof Tracer) val.dispose();
|
|
3428
3435
|
if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
|
|
3429
|
-
const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
|
|
3430
|
-
return fullInternal(sa, fillValue);
|
|
3436
|
+
const sa = new ShapedArray(shape$1 ?? aval.shape, dtype ?? aval.dtype, aval.weakType && dtype === void 0);
|
|
3437
|
+
return fullInternal(sa, fillValue, device);
|
|
3431
3438
|
}
|
|
3432
3439
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
3433
|
-
function zeros(shape$1,
|
|
3434
|
-
return full(shape$1, 0,
|
|
3435
|
-
dtype,
|
|
3436
|
-
device
|
|
3437
|
-
});
|
|
3440
|
+
function zeros(shape$1, opts) {
|
|
3441
|
+
return full(shape$1, 0, opts);
|
|
3438
3442
|
}
|
|
3439
3443
|
/** Return a new array of given shape and type, filled with ones. */
|
|
3440
|
-
function ones(shape$1,
|
|
3441
|
-
return full(shape$1, 1,
|
|
3442
|
-
dtype,
|
|
3443
|
-
device
|
|
3444
|
-
});
|
|
3444
|
+
function ones(shape$1, opts) {
|
|
3445
|
+
return full(shape$1, 1, opts);
|
|
3445
3446
|
}
|
|
3446
3447
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
3447
3448
|
function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
@@ -5329,9 +5330,10 @@ function lstsq(a, b) {
|
|
|
5329
5330
|
});
|
|
5330
5331
|
const llb = triangularSolve(l, lb, {
|
|
5331
5332
|
leftSide: true,
|
|
5333
|
+
lower: true,
|
|
5332
5334
|
transposeA: true
|
|
5333
5335
|
});
|
|
5334
|
-
return matmul(at, llb
|
|
5336
|
+
return matmul(at, llb);
|
|
5335
5337
|
} else {
|
|
5336
5338
|
const ata = matmul(at.ref, a);
|
|
5337
5339
|
const l = cholesky(ata, { symmetrizeInput: false });
|
|
@@ -5342,6 +5344,7 @@ function lstsq(a, b) {
|
|
|
5342
5344
|
});
|
|
5343
5345
|
const llb = triangularSolve(l, lb, {
|
|
5344
5346
|
leftSide: true,
|
|
5347
|
+
lower: true,
|
|
5345
5348
|
transposeA: true
|
|
5346
5349
|
});
|
|
5347
5350
|
return llb;
|
|
@@ -5421,7 +5424,7 @@ function solve(a, b) {
|
|
|
5421
5424
|
lower: true,
|
|
5422
5425
|
unitDiagonal: true
|
|
5423
5426
|
});
|
|
5424
|
-
let x = triangularSolve(lu$2, LPb
|
|
5427
|
+
let x = triangularSolve(lu$2, LPb, {
|
|
5425
5428
|
leftSide: true,
|
|
5426
5429
|
lower: false
|
|
5427
5430
|
});
|
|
@@ -6232,8 +6235,9 @@ function sort(a, axis = -1) {
|
|
|
6232
6235
|
return fudgeArray(a).sort(axis);
|
|
6233
6236
|
}
|
|
6234
6237
|
/**
|
|
6235
|
-
* Return indices that would sort an array.
|
|
6236
|
-
* algorithm; it
|
|
6238
|
+
* Return indices that would sort an array. Unlike `sort`, this is guaranteed to
|
|
6239
|
+
* be a stable sorting algorithm; it always returns the smaller index first in
|
|
6240
|
+
* event of ties.
|
|
6237
6241
|
*
|
|
6238
6242
|
* Returns an array of `int32` indices.
|
|
6239
6243
|
*
|
|
@@ -6535,7 +6539,7 @@ function absolute(x) {
|
|
|
6535
6539
|
/** Return an element-wise indication of sign of the input. */
|
|
6536
6540
|
function sign(x) {
|
|
6537
6541
|
x = fudgeArray(x);
|
|
6538
|
-
return where(notEqual(x.ref, 0), where(less(x
|
|
6542
|
+
return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
|
|
6539
6543
|
}
|
|
6540
6544
|
/** @function Return element-wise positive values of the input (no-op). */
|
|
6541
6545
|
const positive = fudgeArray;
|
|
@@ -7003,7 +7007,10 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
|
|
|
7003
7007
|
b = fudgeArray(b);
|
|
7004
7008
|
if (!leftSide) transposeA = !transposeA;
|
|
7005
7009
|
else b = moveaxis$1(b, -2, -1);
|
|
7006
|
-
if (transposeA)
|
|
7010
|
+
if (transposeA) {
|
|
7011
|
+
a = moveaxis$1(a, -2, -1);
|
|
7012
|
+
lower = !lower;
|
|
7013
|
+
}
|
|
7007
7014
|
let x = triangularSolve$1(a, b, {
|
|
7008
7015
|
lower,
|
|
7009
7016
|
unitDiagonal
|
|
@@ -7025,7 +7032,8 @@ __export(lax_exports, {
|
|
|
7025
7032
|
erfc: () => erfc,
|
|
7026
7033
|
linalg: () => lax_linalg_exports,
|
|
7027
7034
|
reduceWindow: () => reduceWindow,
|
|
7028
|
-
stopGradient: () => stopGradient$1
|
|
7035
|
+
stopGradient: () => stopGradient$1,
|
|
7036
|
+
topK: () => topK
|
|
7029
7037
|
});
|
|
7030
7038
|
const JsArray = globalThis.Array;
|
|
7031
7039
|
/**
|
|
@@ -7249,6 +7257,39 @@ function erfc(x) {
|
|
|
7249
7257
|
function stopGradient$1(x) {
|
|
7250
7258
|
return stopGradient(x);
|
|
7251
7259
|
}
|
|
7260
|
+
/**
|
|
7261
|
+
* Returns top `k` values and their indices along the specified axis of operand.
|
|
7262
|
+
*
|
|
7263
|
+
* This is a _stable_ algorithm: If two elements are equal, the lower-index
|
|
7264
|
+
* element appears first.
|
|
7265
|
+
*
|
|
7266
|
+
* @returns A tuple of `(values, indices)`, where `values` and `indices` have
|
|
7267
|
+
* the same shape as `x`, except along `axis` where they have size `k`.
|
|
7268
|
+
*/
|
|
7269
|
+
function topK(x, k, axis = -1) {
|
|
7270
|
+
x = fudgeArray(x);
|
|
7271
|
+
axis = require_backend.checkAxis(axis, x.ndim);
|
|
7272
|
+
const size$1 = x.shape[axis];
|
|
7273
|
+
if (k < 0 || k > size$1) throw new Error(`topK: k must be in the range [0, ${size$1}], got ${k}`);
|
|
7274
|
+
if (k === 0) {
|
|
7275
|
+
const outShape = x.shape.slice();
|
|
7276
|
+
outShape[axis] = 0;
|
|
7277
|
+
const y$1 = zerosLike$1(x.ref, { shape: outShape });
|
|
7278
|
+
const yi$1 = zerosLike$1(x, {
|
|
7279
|
+
dtype: require_backend.DType.Int32,
|
|
7280
|
+
shape: outShape
|
|
7281
|
+
});
|
|
7282
|
+
return [y$1, yi$1];
|
|
7283
|
+
}
|
|
7284
|
+
x = flip$1(x, [axis]);
|
|
7285
|
+
x = moveaxis(x, axis, -1);
|
|
7286
|
+
const [y, yi] = argsort$1(x);
|
|
7287
|
+
const extract = (a) => {
|
|
7288
|
+
a = a.slice(...require_backend.rep(a.ndim - 1, []), [-k]);
|
|
7289
|
+
return flip$1(moveaxis(a, -1, axis), [axis]);
|
|
7290
|
+
};
|
|
7291
|
+
return [extract(y), extract(yi.neg().add(size$1 - 1))];
|
|
7292
|
+
}
|
|
7252
7293
|
|
|
7253
7294
|
//#endregion
|
|
7254
7295
|
//#region src/library/nn.ts
|
|
@@ -7440,7 +7481,7 @@ const gelu = jit$1(function gelu$1(x, opts) {
|
|
|
7440
7481
|
if (opts?.approximate ?? true) {
|
|
7441
7482
|
const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
|
|
7442
7483
|
return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
|
|
7443
|
-
} else return x.ref.mul(.5).mul(erfc$1(negative(x.
|
|
7484
|
+
} else return x.ref.mul(.5).mul(erfc$1(negative(x.mul(Math.SQRT1_2))));
|
|
7444
7485
|
}, { staticArgnums: [1] });
|
|
7445
7486
|
/**
|
|
7446
7487
|
* Gated linear unit (GLU) activation function.
|
|
@@ -7698,6 +7739,7 @@ var random_exports = {};
|
|
|
7698
7739
|
__export(random_exports, {
|
|
7699
7740
|
bernoulli: () => bernoulli,
|
|
7700
7741
|
bits: () => bits,
|
|
7742
|
+
categorical: () => categorical,
|
|
7701
7743
|
cauchy: () => cauchy,
|
|
7702
7744
|
exponential: () => exponential,
|
|
7703
7745
|
gumbel: () => gumbel,
|
|
@@ -7769,6 +7811,47 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
|
|
|
7769
7811
|
}
|
|
7770
7812
|
/**
|
|
7771
7813
|
* @function
|
|
7814
|
+
* Sample random values from categorical distributions.
|
|
7815
|
+
*
|
|
7816
|
+
* Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
|
|
7817
|
+
* trick for sampling without replacement.
|
|
7818
|
+
*
|
|
7819
|
+
* Note: Sampling without replacement currently uses argsort and slices the last
|
|
7820
|
+
* k elements. This should be replaced with a more efficient topK implementation.
|
|
7821
|
+
*
|
|
7822
|
+
* - `key` - PRNG key
|
|
7823
|
+
* - `logits` - Unnormalized log probabilities of the categorical distribution(s).
|
|
7824
|
+
* `softmax(logits, axis)` gives the corresponding probabilities.
|
|
7825
|
+
* - `axis` - Axis along which logits belong to the same categorical distribution.
|
|
7826
|
+
* - `shape` - Result batch shape. Must be broadcast-compatible with
|
|
7827
|
+
* `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
|
|
7828
|
+
* - `replace` - If true (default), sample with replacement. If false, sample
|
|
7829
|
+
* without replacement (each category can only be selected once per batch).
|
|
7830
|
+
* @returns A random array with int dtype and shape given by `shape` if provided,
|
|
7831
|
+
* otherwise `logits.shape` with `axis` removed.
|
|
7832
|
+
*/
|
|
7833
|
+
const categorical = jit$1(function categorical$1(key$1, logits, { axis = -1, shape: shape$1, replace = true } = {}) {
|
|
7834
|
+
logits = fudgeArray(logits);
|
|
7835
|
+
axis = require_backend.checkAxis(axis, logits.ndim);
|
|
7836
|
+
const numCategories = logits.shape[axis];
|
|
7837
|
+
const batchShape = logits.shape.toSpliced(axis, 1);
|
|
7838
|
+
if (shape$1 === void 0) shape$1 = batchShape;
|
|
7839
|
+
else if (!require_backend.deepEqual(require_backend.generalBroadcast(shape$1, batchShape), shape$1)) throw new Error(`Shape ${shape$1} is not broadcast-compatible with batch shape ${batchShape}.`);
|
|
7840
|
+
const shapePrefix = shape$1.slice(0, shape$1.length - batchShape.length);
|
|
7841
|
+
if (replace) {
|
|
7842
|
+
const noise = gumbel(key$1, [...shapePrefix, ...logits.shape]);
|
|
7843
|
+
return argmax(noise.add(logits), axis + shapePrefix.length);
|
|
7844
|
+
} else {
|
|
7845
|
+
const k = shapePrefix.reduce((a, b) => a * b, 1);
|
|
7846
|
+
if (k > numCategories) throw new Error(`Number of samples without replacement (${k}) cannot exceed number of categories (${numCategories}).`);
|
|
7847
|
+
const noise = gumbel(key$1, logits.shape);
|
|
7848
|
+
const [values, indices] = topK(noise.add(logits), k, axis);
|
|
7849
|
+
values.dispose();
|
|
7850
|
+
return indices.reshape(shape$1);
|
|
7851
|
+
}
|
|
7852
|
+
}, { staticArgnums: [2] });
|
|
7853
|
+
/**
|
|
7854
|
+
* @function
|
|
7772
7855
|
* Sample from a Cauchy distribution with location 0 and scale 1.
|
|
7773
7856
|
*
|
|
7774
7857
|
* Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
|
package/dist/index.d.cts
CHANGED
|
@@ -436,9 +436,14 @@ declare class Routine {
|
|
|
436
436
|
}
|
|
437
437
|
/** One of the valid `Routine` that can be dispatched to backend. */
|
|
438
438
|
declare enum Routines {
|
|
439
|
-
/**
|
|
439
|
+
/**
|
|
440
|
+
* Sort along the last axis.
|
|
441
|
+
*
|
|
442
|
+
* This may be _unstable_ but it often doesn't matter, sorting numbers is
|
|
443
|
+
* bitwise unique up to signed zeros and NaNs.
|
|
444
|
+
*/
|
|
440
445
|
Sort = "Sort",
|
|
441
|
-
/**
|
|
446
|
+
/** Stable sorting, returns `int32` indices and values of the sorted array. */
|
|
442
447
|
Argsort = "Argsort",
|
|
443
448
|
/**
|
|
444
449
|
* Solve a triangular system of equations.
|
|
@@ -750,9 +755,9 @@ declare enum Primitive {
|
|
|
750
755
|
Shrink = "shrink",
|
|
751
756
|
Pad = "pad",
|
|
752
757
|
Sort = "sort",
|
|
753
|
-
// sort(x, axis=-1)
|
|
758
|
+
// sort(x, axis=-1), unstable
|
|
754
759
|
Argsort = "argsort",
|
|
755
|
-
// argsort(x, axis=-1)
|
|
760
|
+
// argsort(x, axis=-1), stable
|
|
756
761
|
TriangularSolve = "triangular_solve",
|
|
757
762
|
// A is upper triangular, A @ X.T = B.T
|
|
758
763
|
Cholesky = "cholesky",
|
|
@@ -1029,8 +1034,9 @@ declare abstract class Tracer {
|
|
|
1029
1034
|
*/
|
|
1030
1035
|
sort(axis?: number): this;
|
|
1031
1036
|
/**
|
|
1032
|
-
* Return the indices that would sort an array.
|
|
1033
|
-
* sorting algorithm; it
|
|
1037
|
+
* Return the indices that would sort an array. Unlike `sort`, this is
|
|
1038
|
+
* guaranteed to be a stable sorting algorithm; it always returns the smaller
|
|
1039
|
+
* index first in event of ties.
|
|
1034
1040
|
*
|
|
1035
1041
|
* See `jax.numpy.argsort` for full docs.
|
|
1036
1042
|
*/
|
|
@@ -1112,6 +1118,12 @@ type DTypeAndDevice = {
|
|
|
1112
1118
|
dtype?: DType;
|
|
1113
1119
|
device?: Device;
|
|
1114
1120
|
};
|
|
1121
|
+
/** @inline */
|
|
1122
|
+
type DTypeShapeAndDevice = {
|
|
1123
|
+
dtype?: DType;
|
|
1124
|
+
shape?: number[];
|
|
1125
|
+
device?: Device;
|
|
1126
|
+
};
|
|
1115
1127
|
type ArrayConstructorArgs = {
|
|
1116
1128
|
source: AluExp | Slot;
|
|
1117
1129
|
st: ShapeTracker;
|
|
@@ -1221,15 +1233,9 @@ declare function array(values: Array | DataArray | RecursiveArray<number> | Recu
|
|
|
1221
1233
|
type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
|
|
1222
1234
|
declare const implRules: { [P in Primitive]: ImplRule<P> };
|
|
1223
1235
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
1224
|
-
declare function zeros(shape: number[],
|
|
1225
|
-
dtype,
|
|
1226
|
-
device
|
|
1227
|
-
}?: DTypeAndDevice): Array;
|
|
1236
|
+
declare function zeros(shape: number[], opts?: DTypeAndDevice): Array;
|
|
1228
1237
|
/** Return a new array of given shape and type, filled with ones. */
|
|
1229
|
-
declare function ones(shape: number[],
|
|
1230
|
-
dtype,
|
|
1231
|
-
device
|
|
1232
|
-
}?: DTypeAndDevice): Array;
|
|
1238
|
+
declare function ones(shape: number[], opts?: DTypeAndDevice): Array;
|
|
1233
1239
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
1234
1240
|
declare function full(shape: number[], fillValue: number | boolean | Array, {
|
|
1235
1241
|
dtype,
|
|
@@ -1421,7 +1427,7 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
|
|
|
1421
1427
|
unitDiagonal?: boolean;
|
|
1422
1428
|
}): Array;
|
|
1423
1429
|
declare namespace lax_d_exports {
|
|
1424
|
-
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
|
|
1430
|
+
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
|
|
1425
1431
|
}
|
|
1426
1432
|
/**
|
|
1427
1433
|
* Dimension numbers for general `dot()` primitive.
|
|
@@ -1527,6 +1533,16 @@ declare function erfc(x: ArrayLike): Array;
|
|
|
1527
1533
|
* forward or reverse-mode automatic differentiation.
|
|
1528
1534
|
*/
|
|
1529
1535
|
declare function stopGradient(x: ArrayLike): Array;
|
|
1536
|
+
/**
|
|
1537
|
+
* Returns top `k` values and their indices along the specified axis of operand.
|
|
1538
|
+
*
|
|
1539
|
+
* This is a _stable_ algorithm: If two elements are equal, the lower-index
|
|
1540
|
+
* element appears first.
|
|
1541
|
+
*
|
|
1542
|
+
* @returns A tuple of `(values, indices)`, where `values` and `indices` have
|
|
1543
|
+
* the same shape as `x`, except along `axis` where they have size `k`.
|
|
1544
|
+
*/
|
|
1545
|
+
declare function topK(x: ArrayLike, k: number, axis?: number): [Array, Array];
|
|
1530
1546
|
declare namespace numpy_fft_d_exports {
|
|
1531
1547
|
export { ComplexPair, fft, ifft };
|
|
1532
1548
|
}
|
|
@@ -1752,17 +1768,17 @@ declare const shape$1: (x: ArrayLike) => number[];
|
|
|
1752
1768
|
* @function
|
|
1753
1769
|
* Return an array of zeros with the same shape and type as a given array.
|
|
1754
1770
|
*/
|
|
1755
|
-
declare const zerosLike: (a: ArrayLike,
|
|
1771
|
+
declare const zerosLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
|
|
1756
1772
|
/**
|
|
1757
1773
|
* @function
|
|
1758
1774
|
* Return an array of ones with the same shape and type as a given array.
|
|
1759
1775
|
*/
|
|
1760
|
-
declare const onesLike: (a: ArrayLike,
|
|
1776
|
+
declare const onesLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
|
|
1761
1777
|
/**
|
|
1762
1778
|
* @function
|
|
1763
1779
|
* Return a full array with the same shape and type as a given array.
|
|
1764
1780
|
*/
|
|
1765
|
-
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array,
|
|
1781
|
+
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, opts?: DTypeShapeAndDevice) => Array;
|
|
1766
1782
|
/**
|
|
1767
1783
|
* Return the number of elements in an array, optionally along an axis.
|
|
1768
1784
|
* Does not consume array reference.
|
|
@@ -1951,8 +1967,9 @@ declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: nu
|
|
|
1951
1967
|
*/
|
|
1952
1968
|
declare function sort(a: ArrayLike, axis?: number): Array;
|
|
1953
1969
|
/**
|
|
1954
|
-
* Return indices that would sort an array.
|
|
1955
|
-
* algorithm; it
|
|
1970
|
+
* Return indices that would sort an array. Unlike `sort`, this is guaranteed to
|
|
1971
|
+
* be a stable sorting algorithm; it always returns the smaller index first in
|
|
1972
|
+
* event of ties.
|
|
1956
1973
|
*
|
|
1957
1974
|
* Returns an array of `int32` indices.
|
|
1958
1975
|
*
|
|
@@ -2564,7 +2581,7 @@ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: Ar
|
|
|
2564
2581
|
localWindowSize?: number | [number, number];
|
|
2565
2582
|
}): Array;
|
|
2566
2583
|
declare namespace random_d_exports {
|
|
2567
|
-
export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
|
|
2584
|
+
export { bernoulli, bits, categorical, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
|
|
2568
2585
|
}
|
|
2569
2586
|
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
|
|
2570
2587
|
declare function key(seed: ArrayLike): Array;
|
|
@@ -2587,6 +2604,32 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
|
|
|
2587
2604
|
* and must be broadcastable to `shape`.
|
|
2588
2605
|
*/
|
|
2589
2606
|
declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
|
|
2607
|
+
/**
|
|
2608
|
+
* @function
|
|
2609
|
+
* Sample random values from categorical distributions.
|
|
2610
|
+
*
|
|
2611
|
+
* Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
|
|
2612
|
+
* trick for sampling without replacement.
|
|
2613
|
+
*
|
|
2614
|
+
* Note: Sampling without replacement currently uses argsort and slices the last
|
|
2615
|
+
* k elements. This should be replaced with a more efficient topK implementation.
|
|
2616
|
+
*
|
|
2617
|
+
* - `key` - PRNG key
|
|
2618
|
+
* - `logits` - Unnormalized log probabilities of the categorical distribution(s).
|
|
2619
|
+
* `softmax(logits, axis)` gives the corresponding probabilities.
|
|
2620
|
+
* - `axis` - Axis along which logits belong to the same categorical distribution.
|
|
2621
|
+
* - `shape` - Result batch shape. Must be broadcast-compatible with
|
|
2622
|
+
* `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
|
|
2623
|
+
* - `replace` - If true (default), sample with replacement. If false, sample
|
|
2624
|
+
* without replacement (each category can only be selected once per batch).
|
|
2625
|
+
* @returns A random array with int dtype and shape given by `shape` if provided,
|
|
2626
|
+
* otherwise `logits.shape` with `axis` removed.
|
|
2627
|
+
*/
|
|
2628
|
+
declare const categorical: OwnedFunction<(key: ArrayLike, logits: ArrayLike, args_2?: {
|
|
2629
|
+
axis?: number | undefined;
|
|
2630
|
+
shape?: number[] | undefined;
|
|
2631
|
+
replace?: boolean | undefined;
|
|
2632
|
+
} | undefined) => Array>;
|
|
2590
2633
|
/**
|
|
2591
2634
|
* @function
|
|
2592
2635
|
* Sample from a Cauchy distribution with location 0 and scale 1.
|
package/dist/index.d.ts
CHANGED
|
@@ -433,9 +433,14 @@ declare class Routine {
|
|
|
433
433
|
}
|
|
434
434
|
/** One of the valid `Routine` that can be dispatched to backend. */
|
|
435
435
|
declare enum Routines {
|
|
436
|
-
/**
|
|
436
|
+
/**
|
|
437
|
+
* Sort along the last axis.
|
|
438
|
+
*
|
|
439
|
+
* This may be _unstable_ but it often doesn't matter, sorting numbers is
|
|
440
|
+
* bitwise unique up to signed zeros and NaNs.
|
|
441
|
+
*/
|
|
437
442
|
Sort = "Sort",
|
|
438
|
-
/**
|
|
443
|
+
/** Stable sorting, returns `int32` indices and values of the sorted array. */
|
|
439
444
|
Argsort = "Argsort",
|
|
440
445
|
/**
|
|
441
446
|
* Solve a triangular system of equations.
|
|
@@ -747,9 +752,9 @@ declare enum Primitive {
|
|
|
747
752
|
Shrink = "shrink",
|
|
748
753
|
Pad = "pad",
|
|
749
754
|
Sort = "sort",
|
|
750
|
-
// sort(x, axis=-1)
|
|
755
|
+
// sort(x, axis=-1), unstable
|
|
751
756
|
Argsort = "argsort",
|
|
752
|
-
// argsort(x, axis=-1)
|
|
757
|
+
// argsort(x, axis=-1), stable
|
|
753
758
|
TriangularSolve = "triangular_solve",
|
|
754
759
|
// A is upper triangular, A @ X.T = B.T
|
|
755
760
|
Cholesky = "cholesky",
|
|
@@ -1026,8 +1031,9 @@ declare abstract class Tracer {
|
|
|
1026
1031
|
*/
|
|
1027
1032
|
sort(axis?: number): this;
|
|
1028
1033
|
/**
|
|
1029
|
-
* Return the indices that would sort an array.
|
|
1030
|
-
* sorting algorithm; it
|
|
1034
|
+
* Return the indices that would sort an array. Unlike `sort`, this is
|
|
1035
|
+
* guaranteed to be a stable sorting algorithm; it always returns the smaller
|
|
1036
|
+
* index first in event of ties.
|
|
1031
1037
|
*
|
|
1032
1038
|
* See `jax.numpy.argsort` for full docs.
|
|
1033
1039
|
*/
|
|
@@ -1109,6 +1115,12 @@ type DTypeAndDevice = {
|
|
|
1109
1115
|
dtype?: DType;
|
|
1110
1116
|
device?: Device;
|
|
1111
1117
|
};
|
|
1118
|
+
/** @inline */
|
|
1119
|
+
type DTypeShapeAndDevice = {
|
|
1120
|
+
dtype?: DType;
|
|
1121
|
+
shape?: number[];
|
|
1122
|
+
device?: Device;
|
|
1123
|
+
};
|
|
1112
1124
|
type ArrayConstructorArgs = {
|
|
1113
1125
|
source: AluExp | Slot;
|
|
1114
1126
|
st: ShapeTracker;
|
|
@@ -1218,15 +1230,9 @@ declare function array(values: Array | DataArray | RecursiveArray<number> | Recu
|
|
|
1218
1230
|
type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
|
|
1219
1231
|
declare const implRules: { [P in Primitive]: ImplRule<P> };
|
|
1220
1232
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
1221
|
-
declare function zeros(shape: number[],
|
|
1222
|
-
dtype,
|
|
1223
|
-
device
|
|
1224
|
-
}?: DTypeAndDevice): Array;
|
|
1233
|
+
declare function zeros(shape: number[], opts?: DTypeAndDevice): Array;
|
|
1225
1234
|
/** Return a new array of given shape and type, filled with ones. */
|
|
1226
|
-
declare function ones(shape: number[],
|
|
1227
|
-
dtype,
|
|
1228
|
-
device
|
|
1229
|
-
}?: DTypeAndDevice): Array;
|
|
1235
|
+
declare function ones(shape: number[], opts?: DTypeAndDevice): Array;
|
|
1230
1236
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
1231
1237
|
declare function full(shape: number[], fillValue: number | boolean | Array, {
|
|
1232
1238
|
dtype,
|
|
@@ -1418,7 +1424,7 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
|
|
|
1418
1424
|
unitDiagonal?: boolean;
|
|
1419
1425
|
}): Array;
|
|
1420
1426
|
declare namespace lax_d_exports {
|
|
1421
|
-
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
|
|
1427
|
+
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
|
|
1422
1428
|
}
|
|
1423
1429
|
/**
|
|
1424
1430
|
* Dimension numbers for general `dot()` primitive.
|
|
@@ -1524,6 +1530,16 @@ declare function erfc(x: ArrayLike): Array;
|
|
|
1524
1530
|
* forward or reverse-mode automatic differentiation.
|
|
1525
1531
|
*/
|
|
1526
1532
|
declare function stopGradient(x: ArrayLike): Array;
|
|
1533
|
+
/**
|
|
1534
|
+
* Returns top `k` values and their indices along the specified axis of operand.
|
|
1535
|
+
*
|
|
1536
|
+
* This is a _stable_ algorithm: If two elements are equal, the lower-index
|
|
1537
|
+
* element appears first.
|
|
1538
|
+
*
|
|
1539
|
+
* @returns A tuple of `(values, indices)`, where `values` and `indices` have
|
|
1540
|
+
* the same shape as `x`, except along `axis` where they have size `k`.
|
|
1541
|
+
*/
|
|
1542
|
+
declare function topK(x: ArrayLike, k: number, axis?: number): [Array, Array];
|
|
1527
1543
|
declare namespace numpy_fft_d_exports {
|
|
1528
1544
|
export { ComplexPair, fft, ifft };
|
|
1529
1545
|
}
|
|
@@ -1749,17 +1765,17 @@ declare const shape$1: (x: ArrayLike) => number[];
|
|
|
1749
1765
|
* @function
|
|
1750
1766
|
* Return an array of zeros with the same shape and type as a given array.
|
|
1751
1767
|
*/
|
|
1752
|
-
declare const zerosLike: (a: ArrayLike,
|
|
1768
|
+
declare const zerosLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
|
|
1753
1769
|
/**
|
|
1754
1770
|
* @function
|
|
1755
1771
|
* Return an array of ones with the same shape and type as a given array.
|
|
1756
1772
|
*/
|
|
1757
|
-
declare const onesLike: (a: ArrayLike,
|
|
1773
|
+
declare const onesLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
|
|
1758
1774
|
/**
|
|
1759
1775
|
* @function
|
|
1760
1776
|
* Return a full array with the same shape and type as a given array.
|
|
1761
1777
|
*/
|
|
1762
|
-
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array,
|
|
1778
|
+
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, opts?: DTypeShapeAndDevice) => Array;
|
|
1763
1779
|
/**
|
|
1764
1780
|
* Return the number of elements in an array, optionally along an axis.
|
|
1765
1781
|
* Does not consume array reference.
|
|
@@ -1948,8 +1964,9 @@ declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: nu
|
|
|
1948
1964
|
*/
|
|
1949
1965
|
declare function sort(a: ArrayLike, axis?: number): Array;
|
|
1950
1966
|
/**
|
|
1951
|
-
* Return indices that would sort an array.
|
|
1952
|
-
* algorithm; it
|
|
1967
|
+
* Return indices that would sort an array. Unlike `sort`, this is guaranteed to
|
|
1968
|
+
* be a stable sorting algorithm; it always returns the smaller index first in
|
|
1969
|
+
* event of ties.
|
|
1953
1970
|
*
|
|
1954
1971
|
* Returns an array of `int32` indices.
|
|
1955
1972
|
*
|
|
@@ -2561,7 +2578,7 @@ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: Ar
|
|
|
2561
2578
|
localWindowSize?: number | [number, number];
|
|
2562
2579
|
}): Array;
|
|
2563
2580
|
declare namespace random_d_exports {
|
|
2564
|
-
export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
|
|
2581
|
+
export { bernoulli, bits, categorical, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
|
|
2565
2582
|
}
|
|
2566
2583
|
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
|
|
2567
2584
|
declare function key(seed: ArrayLike): Array;
|
|
@@ -2584,6 +2601,32 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
|
|
|
2584
2601
|
* and must be broadcastable to `shape`.
|
|
2585
2602
|
*/
|
|
2586
2603
|
declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
|
|
2604
|
+
/**
|
|
2605
|
+
* @function
|
|
2606
|
+
* Sample random values from categorical distributions.
|
|
2607
|
+
*
|
|
2608
|
+
* Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
|
|
2609
|
+
* trick for sampling without replacement.
|
|
2610
|
+
*
|
|
2611
|
+
* Note: Sampling without replacement currently uses argsort and slices the last
|
|
2612
|
+
* k elements. This should be replaced with a more efficient topK implementation.
|
|
2613
|
+
*
|
|
2614
|
+
* - `key` - PRNG key
|
|
2615
|
+
* - `logits` - Unnormalized log probabilities of the categorical distribution(s).
|
|
2616
|
+
* `softmax(logits, axis)` gives the corresponding probabilities.
|
|
2617
|
+
* - `axis` - Axis along which logits belong to the same categorical distribution.
|
|
2618
|
+
* - `shape` - Result batch shape. Must be broadcast-compatible with
|
|
2619
|
+
* `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
|
|
2620
|
+
* - `replace` - If true (default), sample with replacement. If false, sample
|
|
2621
|
+
* without replacement (each category can only be selected once per batch).
|
|
2622
|
+
* @returns A random array with int dtype and shape given by `shape` if provided,
|
|
2623
|
+
* otherwise `logits.shape` with `axis` removed.
|
|
2624
|
+
*/
|
|
2625
|
+
declare const categorical: OwnedFunction<(key: ArrayLike, logits: ArrayLike, args_2?: {
|
|
2626
|
+
axis?: number | undefined;
|
|
2627
|
+
shape?: number[] | undefined;
|
|
2628
|
+
replace?: boolean | undefined;
|
|
2629
|
+
} | undefined) => Array>;
|
|
2587
2630
|
/**
|
|
2588
2631
|
* @function
|
|
2589
2632
|
* Sample from a Cauchy distribution with location 0 and scale 1.
|
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-BId79r5b.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
@@ -889,18 +889,25 @@ var Tracer = class Tracer {
|
|
|
889
889
|
return sort$1(this.transpose(perm)).transpose(invertPermutation(perm));
|
|
890
890
|
}
|
|
891
891
|
/**
|
|
892
|
-
* Return the indices that would sort an array.
|
|
893
|
-
* sorting algorithm; it
|
|
892
|
+
* Return the indices that would sort an array. Unlike `sort`, this is
|
|
893
|
+
* guaranteed to be a stable sorting algorithm; it always returns the smaller
|
|
894
|
+
* index first in event of ties.
|
|
894
895
|
*
|
|
895
896
|
* See `jax.numpy.argsort` for full docs.
|
|
896
897
|
*/
|
|
897
898
|
argsort(axis = -1) {
|
|
898
899
|
axis = checkAxis(axis, this.ndim);
|
|
899
|
-
if (axis === this.ndim - 1)
|
|
900
|
+
if (axis === this.ndim - 1) {
|
|
901
|
+
const [y$1, yi$1] = argsort$1(this);
|
|
902
|
+
y$1.dispose();
|
|
903
|
+
return yi$1;
|
|
904
|
+
}
|
|
900
905
|
const perm = range(this.ndim);
|
|
901
906
|
perm.splice(axis, 1);
|
|
902
907
|
perm.push(axis);
|
|
903
|
-
|
|
908
|
+
const [y, yi] = argsort$1(this.transpose(perm));
|
|
909
|
+
y.dispose();
|
|
910
|
+
return yi.transpose(invertPermutation(perm));
|
|
904
911
|
}
|
|
905
912
|
/**
|
|
906
913
|
* Slice an array along one or more axes.
|
|
@@ -3381,32 +3388,26 @@ function fullInternal(aval, fillValue, device) {
|
|
|
3381
3388
|
committed: device != void 0
|
|
3382
3389
|
});
|
|
3383
3390
|
}
|
|
3384
|
-
function zerosLike$1(val,
|
|
3385
|
-
return fullLike(val, 0,
|
|
3391
|
+
function zerosLike$1(val, opts) {
|
|
3392
|
+
return fullLike(val, 0, opts);
|
|
3386
3393
|
}
|
|
3387
|
-
function onesLike$1(val,
|
|
3388
|
-
return fullLike(val, 1,
|
|
3394
|
+
function onesLike$1(val, opts) {
|
|
3395
|
+
return fullLike(val, 1, opts);
|
|
3389
3396
|
}
|
|
3390
|
-
function fullLike(val, fillValue, dtype) {
|
|
3397
|
+
function fullLike(val, fillValue, { dtype, shape: shape$1, device } = {}) {
|
|
3391
3398
|
const aval = getAval(val);
|
|
3392
3399
|
if (val instanceof Tracer) val.dispose();
|
|
3393
3400
|
if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
|
|
3394
|
-
const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
|
|
3395
|
-
return fullInternal(sa, fillValue);
|
|
3401
|
+
const sa = new ShapedArray(shape$1 ?? aval.shape, dtype ?? aval.dtype, aval.weakType && dtype === void 0);
|
|
3402
|
+
return fullInternal(sa, fillValue, device);
|
|
3396
3403
|
}
|
|
3397
3404
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
3398
|
-
function zeros(shape$1,
|
|
3399
|
-
return full(shape$1, 0,
|
|
3400
|
-
dtype,
|
|
3401
|
-
device
|
|
3402
|
-
});
|
|
3405
|
+
function zeros(shape$1, opts) {
|
|
3406
|
+
return full(shape$1, 0, opts);
|
|
3403
3407
|
}
|
|
3404
3408
|
/** Return a new array of given shape and type, filled with ones. */
|
|
3405
|
-
function ones(shape$1,
|
|
3406
|
-
return full(shape$1, 1,
|
|
3407
|
-
dtype,
|
|
3408
|
-
device
|
|
3409
|
-
});
|
|
3409
|
+
function ones(shape$1, opts) {
|
|
3410
|
+
return full(shape$1, 1, opts);
|
|
3410
3411
|
}
|
|
3411
3412
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
3412
3413
|
function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
@@ -5292,9 +5293,10 @@ function lstsq(a, b) {
|
|
|
5292
5293
|
});
|
|
5293
5294
|
const llb = triangularSolve(l, lb, {
|
|
5294
5295
|
leftSide: true,
|
|
5296
|
+
lower: true,
|
|
5295
5297
|
transposeA: true
|
|
5296
5298
|
});
|
|
5297
|
-
return matmul(at, llb
|
|
5299
|
+
return matmul(at, llb);
|
|
5298
5300
|
} else {
|
|
5299
5301
|
const ata = matmul(at.ref, a);
|
|
5300
5302
|
const l = cholesky(ata, { symmetrizeInput: false });
|
|
@@ -5305,6 +5307,7 @@ function lstsq(a, b) {
|
|
|
5305
5307
|
});
|
|
5306
5308
|
const llb = triangularSolve(l, lb, {
|
|
5307
5309
|
leftSide: true,
|
|
5310
|
+
lower: true,
|
|
5308
5311
|
transposeA: true
|
|
5309
5312
|
});
|
|
5310
5313
|
return llb;
|
|
@@ -5384,7 +5387,7 @@ function solve(a, b) {
|
|
|
5384
5387
|
lower: true,
|
|
5385
5388
|
unitDiagonal: true
|
|
5386
5389
|
});
|
|
5387
|
-
let x = triangularSolve(lu$2, LPb
|
|
5390
|
+
let x = triangularSolve(lu$2, LPb, {
|
|
5388
5391
|
leftSide: true,
|
|
5389
5392
|
lower: false
|
|
5390
5393
|
});
|
|
@@ -6195,8 +6198,9 @@ function sort(a, axis = -1) {
|
|
|
6195
6198
|
return fudgeArray(a).sort(axis);
|
|
6196
6199
|
}
|
|
6197
6200
|
/**
|
|
6198
|
-
* Return indices that would sort an array.
|
|
6199
|
-
* algorithm; it
|
|
6201
|
+
* Return indices that would sort an array. Unlike `sort`, this is guaranteed to
|
|
6202
|
+
* be a stable sorting algorithm; it always returns the smaller index first in
|
|
6203
|
+
* event of ties.
|
|
6200
6204
|
*
|
|
6201
6205
|
* Returns an array of `int32` indices.
|
|
6202
6206
|
*
|
|
@@ -6498,7 +6502,7 @@ function absolute(x) {
|
|
|
6498
6502
|
/** Return an element-wise indication of sign of the input. */
|
|
6499
6503
|
function sign(x) {
|
|
6500
6504
|
x = fudgeArray(x);
|
|
6501
|
-
return where(notEqual(x.ref, 0), where(less(x
|
|
6505
|
+
return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
|
|
6502
6506
|
}
|
|
6503
6507
|
/** @function Return element-wise positive values of the input (no-op). */
|
|
6504
6508
|
const positive = fudgeArray;
|
|
@@ -6966,7 +6970,10 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
|
|
|
6966
6970
|
b = fudgeArray(b);
|
|
6967
6971
|
if (!leftSide) transposeA = !transposeA;
|
|
6968
6972
|
else b = moveaxis$1(b, -2, -1);
|
|
6969
|
-
if (transposeA)
|
|
6973
|
+
if (transposeA) {
|
|
6974
|
+
a = moveaxis$1(a, -2, -1);
|
|
6975
|
+
lower = !lower;
|
|
6976
|
+
}
|
|
6970
6977
|
let x = triangularSolve$1(a, b, {
|
|
6971
6978
|
lower,
|
|
6972
6979
|
unitDiagonal
|
|
@@ -6988,7 +6995,8 @@ __export(lax_exports, {
|
|
|
6988
6995
|
erfc: () => erfc,
|
|
6989
6996
|
linalg: () => lax_linalg_exports,
|
|
6990
6997
|
reduceWindow: () => reduceWindow,
|
|
6991
|
-
stopGradient: () => stopGradient$1
|
|
6998
|
+
stopGradient: () => stopGradient$1,
|
|
6999
|
+
topK: () => topK
|
|
6992
7000
|
});
|
|
6993
7001
|
const JsArray = globalThis.Array;
|
|
6994
7002
|
/**
|
|
@@ -7212,6 +7220,39 @@ function erfc(x) {
|
|
|
7212
7220
|
function stopGradient$1(x) {
|
|
7213
7221
|
return stopGradient(x);
|
|
7214
7222
|
}
|
|
7223
|
+
/**
|
|
7224
|
+
* Returns top `k` values and their indices along the specified axis of operand.
|
|
7225
|
+
*
|
|
7226
|
+
* This is a _stable_ algorithm: If two elements are equal, the lower-index
|
|
7227
|
+
* element appears first.
|
|
7228
|
+
*
|
|
7229
|
+
* @returns A tuple of `(values, indices)`, where `values` and `indices` have
|
|
7230
|
+
* the same shape as `x`, except along `axis` where they have size `k`.
|
|
7231
|
+
*/
|
|
7232
|
+
function topK(x, k, axis = -1) {
|
|
7233
|
+
x = fudgeArray(x);
|
|
7234
|
+
axis = checkAxis(axis, x.ndim);
|
|
7235
|
+
const size$1 = x.shape[axis];
|
|
7236
|
+
if (k < 0 || k > size$1) throw new Error(`topK: k must be in the range [0, ${size$1}], got ${k}`);
|
|
7237
|
+
if (k === 0) {
|
|
7238
|
+
const outShape = x.shape.slice();
|
|
7239
|
+
outShape[axis] = 0;
|
|
7240
|
+
const y$1 = zerosLike$1(x.ref, { shape: outShape });
|
|
7241
|
+
const yi$1 = zerosLike$1(x, {
|
|
7242
|
+
dtype: DType.Int32,
|
|
7243
|
+
shape: outShape
|
|
7244
|
+
});
|
|
7245
|
+
return [y$1, yi$1];
|
|
7246
|
+
}
|
|
7247
|
+
x = flip$1(x, [axis]);
|
|
7248
|
+
x = moveaxis(x, axis, -1);
|
|
7249
|
+
const [y, yi] = argsort$1(x);
|
|
7250
|
+
const extract = (a) => {
|
|
7251
|
+
a = a.slice(...rep(a.ndim - 1, []), [-k]);
|
|
7252
|
+
return flip$1(moveaxis(a, -1, axis), [axis]);
|
|
7253
|
+
};
|
|
7254
|
+
return [extract(y), extract(yi.neg().add(size$1 - 1))];
|
|
7255
|
+
}
|
|
7215
7256
|
|
|
7216
7257
|
//#endregion
|
|
7217
7258
|
//#region src/library/nn.ts
|
|
@@ -7403,7 +7444,7 @@ const gelu = jit$1(function gelu$1(x, opts) {
|
|
|
7403
7444
|
if (opts?.approximate ?? true) {
|
|
7404
7445
|
const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
|
|
7405
7446
|
return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
|
|
7406
|
-
} else return x.ref.mul(.5).mul(erfc$1(negative(x.
|
|
7447
|
+
} else return x.ref.mul(.5).mul(erfc$1(negative(x.mul(Math.SQRT1_2))));
|
|
7407
7448
|
}, { staticArgnums: [1] });
|
|
7408
7449
|
/**
|
|
7409
7450
|
* Gated linear unit (GLU) activation function.
|
|
@@ -7661,6 +7702,7 @@ var random_exports = {};
|
|
|
7661
7702
|
__export(random_exports, {
|
|
7662
7703
|
bernoulli: () => bernoulli,
|
|
7663
7704
|
bits: () => bits,
|
|
7705
|
+
categorical: () => categorical,
|
|
7664
7706
|
cauchy: () => cauchy,
|
|
7665
7707
|
exponential: () => exponential,
|
|
7666
7708
|
gumbel: () => gumbel,
|
|
@@ -7732,6 +7774,47 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
|
|
|
7732
7774
|
}
|
|
7733
7775
|
/**
|
|
7734
7776
|
* @function
|
|
7777
|
+
* Sample random values from categorical distributions.
|
|
7778
|
+
*
|
|
7779
|
+
* Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
|
|
7780
|
+
* trick for sampling without replacement.
|
|
7781
|
+
*
|
|
7782
|
+
* Note: Sampling without replacement currently uses argsort and slices the last
|
|
7783
|
+
* k elements. This should be replaced with a more efficient topK implementation.
|
|
7784
|
+
*
|
|
7785
|
+
* - `key` - PRNG key
|
|
7786
|
+
* - `logits` - Unnormalized log probabilities of the categorical distribution(s).
|
|
7787
|
+
* `softmax(logits, axis)` gives the corresponding probabilities.
|
|
7788
|
+
* - `axis` - Axis along which logits belong to the same categorical distribution.
|
|
7789
|
+
* - `shape` - Result batch shape. Must be broadcast-compatible with
|
|
7790
|
+
* `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
|
|
7791
|
+
* - `replace` - If true (default), sample with replacement. If false, sample
|
|
7792
|
+
* without replacement (each category can only be selected once per batch).
|
|
7793
|
+
* @returns A random array with int dtype and shape given by `shape` if provided,
|
|
7794
|
+
* otherwise `logits.shape` with `axis` removed.
|
|
7795
|
+
*/
|
|
7796
|
+
const categorical = jit$1(function categorical$1(key$1, logits, { axis = -1, shape: shape$1, replace = true } = {}) {
|
|
7797
|
+
logits = fudgeArray(logits);
|
|
7798
|
+
axis = checkAxis(axis, logits.ndim);
|
|
7799
|
+
const numCategories = logits.shape[axis];
|
|
7800
|
+
const batchShape = logits.shape.toSpliced(axis, 1);
|
|
7801
|
+
if (shape$1 === void 0) shape$1 = batchShape;
|
|
7802
|
+
else if (!deepEqual(generalBroadcast(shape$1, batchShape), shape$1)) throw new Error(`Shape ${shape$1} is not broadcast-compatible with batch shape ${batchShape}.`);
|
|
7803
|
+
const shapePrefix = shape$1.slice(0, shape$1.length - batchShape.length);
|
|
7804
|
+
if (replace) {
|
|
7805
|
+
const noise = gumbel(key$1, [...shapePrefix, ...logits.shape]);
|
|
7806
|
+
return argmax(noise.add(logits), axis + shapePrefix.length);
|
|
7807
|
+
} else {
|
|
7808
|
+
const k = shapePrefix.reduce((a, b) => a * b, 1);
|
|
7809
|
+
if (k > numCategories) throw new Error(`Number of samples without replacement (${k}) cannot exceed number of categories (${numCategories}).`);
|
|
7810
|
+
const noise = gumbel(key$1, logits.shape);
|
|
7811
|
+
const [values, indices] = topK(noise.add(logits), k, axis);
|
|
7812
|
+
values.dispose();
|
|
7813
|
+
return indices.reshape(shape$1);
|
|
7814
|
+
}
|
|
7815
|
+
}, { staticArgnums: [2] });
|
|
7816
|
+
/**
|
|
7817
|
+
* @function
|
|
7735
7818
|
* Sample from a Cauchy distribution with location 0 and scale 1.
|
|
7736
7819
|
*
|
|
7737
7820
|
* Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-
|
|
1
|
+
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-BId79r5b.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgl/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-BId79r5b.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -247,6 +247,10 @@ function bitonicSortUniform(pass) {
|
|
|
247
247
|
* `2^(step+1)` with multiple workgroups. This doesn't use shared memory.
|
|
248
248
|
*
|
|
249
249
|
* The total number of passes is roughly `log2(n / workgroupSize)^2 / 2`.
|
|
250
|
+
*
|
|
251
|
+
* If `outputIndices` is true, the shader also tracks the original indices of
|
|
252
|
+
* the sorted elements (argsort) and outputs them to a separate buffer. This
|
|
253
|
+
* also makes the sorting algorithm stable.
|
|
250
254
|
*/
|
|
251
255
|
function bitonicSortShader(device, dtype, n, batches, outputIndices) {
|
|
252
256
|
const ty = dtypeToWgsl(dtype, true);
|
|
@@ -286,14 +290,21 @@ ${isFloatDtype(dtype) ? `
|
|
|
286
290
|
fn compare_and_swap(i: u32, j: u32) {
|
|
287
291
|
let val_i = shared_vals[i];
|
|
288
292
|
let val_j = shared_vals[j];
|
|
289
|
-
|
|
293
|
+
${outputIndices ? `
|
|
294
|
+
if (
|
|
295
|
+
compare(val_j, val_i) ||
|
|
296
|
+
(!compare(val_i, val_j) && shared_idx[j] < shared_idx[i])
|
|
297
|
+
) {
|
|
290
298
|
shared_vals[i] = val_j;
|
|
291
299
|
shared_vals[j] = val_i;
|
|
292
|
-
${outputIndices ? `
|
|
293
300
|
let tmp_idx = shared_idx[i];
|
|
294
301
|
shared_idx[i] = shared_idx[j];
|
|
295
|
-
shared_idx[j] = tmp_idx
|
|
296
|
-
}
|
|
302
|
+
shared_idx[j] = tmp_idx;
|
|
303
|
+
}` : `
|
|
304
|
+
if (compare(val_j, val_i)) {
|
|
305
|
+
shared_vals[i] = val_j;
|
|
306
|
+
shared_vals[j] = val_i;
|
|
307
|
+
}`}
|
|
297
308
|
}
|
|
298
309
|
|
|
299
310
|
@compute @workgroup_size(${workgroupSize})
|
|
@@ -370,13 +381,17 @@ ${outputIndices ? `
|
|
|
370
381
|
if (j < ${n}u) {
|
|
371
382
|
let val_i = output[base + i];
|
|
372
383
|
let val_j = output[base + j];
|
|
373
|
-
|
|
384
|
+
${outputIndices ? `
|
|
385
|
+
let idx_i = output_idx[base + i];
|
|
386
|
+
let idx_j = output_idx[base + j];
|
|
387
|
+
if (compare(val_j, val_i) || (!compare(val_i, val_j) && idx_j < idx_i)) {
|
|
374
388
|
output[base + i] = val_j;
|
|
375
389
|
output[base + j] = val_i;
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
390
|
+
output_idx[base + i] = idx_j;
|
|
391
|
+
output_idx[base + j] = idx_i;` : `
|
|
392
|
+
if (compare(val_j, val_i)) {
|
|
393
|
+
output[base + i] = val_j;
|
|
394
|
+
output[base + j] = val_i;`}
|
|
380
395
|
}
|
|
381
396
|
}
|
|
382
397
|
}
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-
|
|
1
|
+
const require_backend = require('./backend-DpI0riom.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -247,6 +247,10 @@ function bitonicSortUniform(pass) {
|
|
|
247
247
|
* `2^(step+1)` with multiple workgroups. This doesn't use shared memory.
|
|
248
248
|
*
|
|
249
249
|
* The total number of passes is roughly `log2(n / workgroupSize)^2 / 2`.
|
|
250
|
+
*
|
|
251
|
+
* If `outputIndices` is true, the shader also tracks the original indices of
|
|
252
|
+
* the sorted elements (argsort) and outputs them to a separate buffer. This
|
|
253
|
+
* also makes the sorting algorithm stable.
|
|
250
254
|
*/
|
|
251
255
|
function bitonicSortShader(device, dtype, n, batches, outputIndices) {
|
|
252
256
|
const ty = dtypeToWgsl(dtype, true);
|
|
@@ -286,14 +290,21 @@ ${require_backend.isFloatDtype(dtype) ? `
|
|
|
286
290
|
fn compare_and_swap(i: u32, j: u32) {
|
|
287
291
|
let val_i = shared_vals[i];
|
|
288
292
|
let val_j = shared_vals[j];
|
|
289
|
-
|
|
293
|
+
${outputIndices ? `
|
|
294
|
+
if (
|
|
295
|
+
compare(val_j, val_i) ||
|
|
296
|
+
(!compare(val_i, val_j) && shared_idx[j] < shared_idx[i])
|
|
297
|
+
) {
|
|
290
298
|
shared_vals[i] = val_j;
|
|
291
299
|
shared_vals[j] = val_i;
|
|
292
|
-
${outputIndices ? `
|
|
293
300
|
let tmp_idx = shared_idx[i];
|
|
294
301
|
shared_idx[i] = shared_idx[j];
|
|
295
|
-
shared_idx[j] = tmp_idx
|
|
296
|
-
}
|
|
302
|
+
shared_idx[j] = tmp_idx;
|
|
303
|
+
}` : `
|
|
304
|
+
if (compare(val_j, val_i)) {
|
|
305
|
+
shared_vals[i] = val_j;
|
|
306
|
+
shared_vals[j] = val_i;
|
|
307
|
+
}`}
|
|
297
308
|
}
|
|
298
309
|
|
|
299
310
|
@compute @workgroup_size(${workgroupSize})
|
|
@@ -370,13 +381,17 @@ ${outputIndices ? `
|
|
|
370
381
|
if (j < ${n}u) {
|
|
371
382
|
let val_i = output[base + i];
|
|
372
383
|
let val_j = output[base + j];
|
|
373
|
-
|
|
384
|
+
${outputIndices ? `
|
|
385
|
+
let idx_i = output_idx[base + i];
|
|
386
|
+
let idx_j = output_idx[base + j];
|
|
387
|
+
if (compare(val_j, val_i) || (!compare(val_i, val_j) && idx_j < idx_i)) {
|
|
374
388
|
output[base + i] = val_j;
|
|
375
389
|
output[base + j] = val_i;
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
390
|
+
output_idx[base + i] = idx_j;
|
|
391
|
+
output_idx[base + j] = idx_i;` : `
|
|
392
|
+
if (compare(val_j, val_i)) {
|
|
393
|
+
output[base + i] = val_j;
|
|
394
|
+
output[base + j] = val_i;`}
|
|
380
395
|
}
|
|
381
396
|
}
|
|
382
397
|
}
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@jax-js/jax",
|
|
3
|
-
"version": "0.1.
|
|
3
|
+
"version": "0.1.9",
|
|
4
4
|
"description": "Numerical computing and ML in the browser",
|
|
5
5
|
"keywords": [
|
|
6
6
|
"machine learning",
|
|
@@ -44,6 +44,8 @@
|
|
|
44
44
|
"eslint": "^9.31.0",
|
|
45
45
|
"eslint-plugin-import": "^2.32.0",
|
|
46
46
|
"globals": "^16.0.0",
|
|
47
|
+
"husky": "^9.1.7",
|
|
48
|
+
"lint-staged": "^16.2.7",
|
|
47
49
|
"playwright": "~1.52.0",
|
|
48
50
|
"prettier": "^3.6.2",
|
|
49
51
|
"prettier-plugin-svelte": "^3.4.0",
|
|
@@ -74,6 +76,15 @@
|
|
|
74
76
|
],
|
|
75
77
|
"proseWrap": "always"
|
|
76
78
|
},
|
|
79
|
+
"lint-staged": {
|
|
80
|
+
"*.{ts,tsx,js,jsx}": [
|
|
81
|
+
"eslint --fix",
|
|
82
|
+
"prettier --write"
|
|
83
|
+
],
|
|
84
|
+
"*.{json,md,yml,yaml,css,svelte,html}": [
|
|
85
|
+
"prettier --write"
|
|
86
|
+
]
|
|
87
|
+
},
|
|
77
88
|
"scripts": {
|
|
78
89
|
"build": "tsdown",
|
|
79
90
|
"build:watch": "TSDOWN_WATCH_MODE=1 tsdown",
|