@jax-js/jax 0.1.6 → 0.1.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +73 -7
- package/dist/{backend-D7s-Retx.cjs → backend-B3foXiV_.cjs} +4 -4
- package/dist/{backend-Dx6Ob2D1.js → backend-nEolvdLv.js} +4 -4
- package/dist/index.cjs +94 -23
- package/dist/index.d.cts +1561 -1538
- package/dist/index.d.ts +1561 -1538
- package/dist/index.js +94 -23
- package/dist/{webgl-CyfzNW8T.cjs → webgl-DIIbKJ0G.cjs} +1 -1
- package/dist/{webgl-CLLvzJlO.js → webgl-DweKSWEm.js} +1 -1
- package/dist/{webgpu-C-VfevQW.js → webgpu-B96vzWGE.js} +1 -1
- package/dist/{webgpu-rraa6dfz.cjs → webgpu-BykvF26B.cjs} +1 -1
- package/package.json +12 -1
package/README.md
CHANGED
|
@@ -43,6 +43,55 @@ way to get started on a blank HTML page.
|
|
|
43
43
|
</script>
|
|
44
44
|
```
|
|
45
45
|
|
|
46
|
+
## Feature comparison
|
|
47
|
+
|
|
48
|
+
Here's a quick, high-level comparison with other popular web ML runtimes:
|
|
49
|
+
|
|
50
|
+
| Feature | jax-js | TensorFlow.js | onnxruntime-web |
|
|
51
|
+
| ------------------------------- | ---------- | --------------- | ------------------ |
|
|
52
|
+
| **Overview** | | | |
|
|
53
|
+
| API style | JAX/NumPy | TensorFlow-like | Static ONNX graphs |
|
|
54
|
+
| Latest release | 2026 | ⚠️ 2024 | 2026 |
|
|
55
|
+
| Speed | Fastest | Fast | Fastest |
|
|
56
|
+
| Bundle size (gzip) | 80 KB | 269 KB | 90 KB + 24 MB Wasm |
|
|
57
|
+
| **Autodiff & JIT** | | | |
|
|
58
|
+
| Gradients | ✅ | ✅ | ❌ |
|
|
59
|
+
| Jacobian and Hessian | ✅ | ❌ | ❌ |
|
|
60
|
+
| `jvp()` forward differentiation | ✅ | ❌ | ❌ |
|
|
61
|
+
| `jit()` kernel fusion | ✅ | ❌ | ❌ |
|
|
62
|
+
| `vmap()` auto-vectorization | ✅ | ❌ | ❌ |
|
|
63
|
+
| Graph capture | ✅ | ❌ | ✅ |
|
|
64
|
+
| **Backends & Data** | | | |
|
|
65
|
+
| WebGPU backend | ✅ | 🟡 Preview | ✅ |
|
|
66
|
+
| WebGL backend | ✅ | ✅ | ✅ |
|
|
67
|
+
| Wasm (CPU) backend | ✅ | ✅ | ✅ |
|
|
68
|
+
| Eager array API | ✅ | ✅ | ❌ |
|
|
69
|
+
| Run ONNX models | 🟡 Partial | ❌ | ✅ |
|
|
70
|
+
| Read safetensors | ✅ | ❌ | ❌ |
|
|
71
|
+
| Float64 | ✅ | ❌ | ❌ |
|
|
72
|
+
| Float32 | ✅ | ✅ | ✅ |
|
|
73
|
+
| Float16 | ✅ | ❌ | ✅ |
|
|
74
|
+
| BFloat16 | ❌ | ❌ | ❌ |
|
|
75
|
+
| Packed Uint8 | ❌ | ❌ | 🟡 Partial |
|
|
76
|
+
| Mixed precision | ✅ | ❌ | ✅ |
|
|
77
|
+
| Mixed devices | ✅ | ❌ | ❌ |
|
|
78
|
+
| **Ops & Numerics** | | | |
|
|
79
|
+
| Arithmetic functions | ✅ | ✅ | ✅ |
|
|
80
|
+
| Matrix multiplication | ✅ | ✅ | ✅ |
|
|
81
|
+
| General einsum | ✅ | 🟡 Partial | 🟡 Partial |
|
|
82
|
+
| Sorting | ✅ | ❌ | ❌ |
|
|
83
|
+
| Activation functions | ✅ | ✅ | ✅ |
|
|
84
|
+
| NaN/Inf numerics | ✅ | ✅ | ✅ |
|
|
85
|
+
| Basic convolutions | ✅ | ✅ | ✅ |
|
|
86
|
+
| n-d convolutions | ✅ | ❌ | ✅ |
|
|
87
|
+
| Strided/dilated convolution | ✅ | ✅ | ✅ |
|
|
88
|
+
| Cholesky, Lstsq | ✅ | ❌ | ❌ |
|
|
89
|
+
| LU, Solve, Determinant | ✅ | ❌ | ❌ |
|
|
90
|
+
| SVD | ❌ | ❌ | ❌ |
|
|
91
|
+
| FFT | ✅ | ✅ | ✅ |
|
|
92
|
+
| Basic RNG (Uniform, Normal) | ✅ | ✅ | ✅ |
|
|
93
|
+
| Advanced RNG | ✅ | ❌ | ❌ |
|
|
94
|
+
|
|
46
95
|
## Tutorial
|
|
47
96
|
|
|
48
97
|
Programming in `jax-js` looks [very similar to JAX](https://docs.jax.dev/en/latest/jax-101.html),
|
|
@@ -271,15 +320,18 @@ self-contained way in other projects.
|
|
|
271
320
|
|
|
272
321
|
### Performance
|
|
273
322
|
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
chip ([try it](https://jax-js.com/bench/matmul)).
|
|
323
|
+
The WebGPU runtime includes an ML compiler with tile-aware optimizations, tuned for individual
|
|
324
|
+
browsers. Also, this library uniquely has the `jit()` feature that fuses operations together and
|
|
325
|
+
records an execution graph. jax-js achieves **over 7000 GFLOP/s** for matrix multiplication on an
|
|
326
|
+
Apple M4 Max chip ([try it](https://jax-js.com/bench/matmul)).
|
|
278
327
|
|
|
279
|
-
For that example, it's
|
|
328
|
+
For that example, it's significantly faster than both
|
|
280
329
|
[TensorFlow.js](https://github.com/tensorflow/tfjs) and
|
|
281
330
|
[ONNX Runtime Web](https://www.npmjs.com/package/onnxruntime-web), which both use handwritten
|
|
282
|
-
libraries of custom kernels
|
|
331
|
+
libraries of custom kernels.
|
|
332
|
+
|
|
333
|
+
It's still early though. There's a lot of low-hanging fruit to continue optimizing the library, as
|
|
334
|
+
well as unique optimizations such as FlashAttention variants.
|
|
283
335
|
|
|
284
336
|
### API Reference
|
|
285
337
|
|
|
@@ -291,8 +343,9 @@ That's all for this short tutorial. Please see the generated
|
|
|
291
343
|
If you make something cool with jax-js, don't be a stranger! We can feature it here.
|
|
292
344
|
|
|
293
345
|
- [Training neural networks on MNIST](https://jax-js.com/mnist)
|
|
346
|
+
- [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
|
|
294
347
|
- [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
|
|
295
|
-
- [Object detection
|
|
348
|
+
- [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
|
|
296
349
|
- [In-browser REPL](https://jax-js.com/repl)
|
|
297
350
|
- [Matmul benchmark](https://jax-js.com/bench/matmul)
|
|
298
351
|
- [Conv2d benchmark](https://jax-js.com/bench/conv2d)
|
|
@@ -310,6 +363,19 @@ pnpm install
|
|
|
310
363
|
pnpm run build:watch
|
|
311
364
|
```
|
|
312
365
|
|
|
366
|
+
The `pnpm install` command automatically sets up Git hooks via
|
|
367
|
+
[Husky](https://typicode.github.io/husky/). Pre-commit hooks will run ESLint and Prettier on staged
|
|
368
|
+
files to ensure code quality.
|
|
369
|
+
|
|
370
|
+
You can also run linting and formatting manually:
|
|
371
|
+
|
|
372
|
+
```bash
|
|
373
|
+
pnpm lint # Run ESLint
|
|
374
|
+
pnpm format # Format all files with Prettier
|
|
375
|
+
pnpm format:check # Check formatting without writing
|
|
376
|
+
pnpm check # Run TypeScript type checking
|
|
377
|
+
```
|
|
378
|
+
|
|
313
379
|
Then you can run tests in a headless browser using [Vitest](https://vitest.dev/).
|
|
314
380
|
|
|
315
381
|
```bash
|
|
@@ -910,7 +910,7 @@ var AluExp = class AluExp {
|
|
|
910
910
|
return ret.simplify(cache);
|
|
911
911
|
}
|
|
912
912
|
}
|
|
913
|
-
if (y.arg > 0) {
|
|
913
|
+
if (y.arg > 0 && x.min >= 0) {
|
|
914
914
|
let [xNoConst, constVal] = [x, 0];
|
|
915
915
|
if (x.op === AluOp.Add && x.src[1].op === AluOp.Const) [xNoConst, constVal] = [x.src[0], x.src[1].arg];
|
|
916
916
|
const terms = [];
|
|
@@ -932,7 +932,7 @@ var AluExp = class AluExp {
|
|
|
932
932
|
rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(remainders[i] / gcdVal)), terms[i]));
|
|
933
933
|
quo = AluExp.add(quo, AluExp.mul(AluExp.const(x.dtype, quotients[i]), terms[i]));
|
|
934
934
|
}
|
|
935
|
-
if (
|
|
935
|
+
if (rem.min >= 0) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
|
|
936
936
|
else return AluExp.add(AluExp.idiv(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal))), quo).simplify(cache);
|
|
937
937
|
}
|
|
938
938
|
}
|
|
@@ -4253,7 +4253,7 @@ async function createBackend(device) {
|
|
|
4253
4253
|
if (!navigator.gpu) return null;
|
|
4254
4254
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4255
4255
|
if (!adapter) return null;
|
|
4256
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
4256
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BykvF26B.cjs"));
|
|
4257
4257
|
const importantLimits = [
|
|
4258
4258
|
"maxBufferSize",
|
|
4259
4259
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4291,7 +4291,7 @@ async function createBackend(device) {
|
|
|
4291
4291
|
});
|
|
4292
4292
|
if (!gl) return null;
|
|
4293
4293
|
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
4294
|
-
const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-
|
|
4294
|
+
const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-DIIbKJ0G.cjs"));
|
|
4295
4295
|
return new WebGLBackend(gl);
|
|
4296
4296
|
} else throw new Error(`Backend not found: ${device}`);
|
|
4297
4297
|
}
|
|
@@ -909,7 +909,7 @@ var AluExp = class AluExp {
|
|
|
909
909
|
return ret.simplify(cache);
|
|
910
910
|
}
|
|
911
911
|
}
|
|
912
|
-
if (y.arg > 0) {
|
|
912
|
+
if (y.arg > 0 && x.min >= 0) {
|
|
913
913
|
let [xNoConst, constVal] = [x, 0];
|
|
914
914
|
if (x.op === AluOp.Add && x.src[1].op === AluOp.Const) [xNoConst, constVal] = [x.src[0], x.src[1].arg];
|
|
915
915
|
const terms = [];
|
|
@@ -931,7 +931,7 @@ var AluExp = class AluExp {
|
|
|
931
931
|
rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(remainders[i] / gcdVal)), terms[i]));
|
|
932
932
|
quo = AluExp.add(quo, AluExp.mul(AluExp.const(x.dtype, quotients[i]), terms[i]));
|
|
933
933
|
}
|
|
934
|
-
if (
|
|
934
|
+
if (rem.min >= 0) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
|
|
935
935
|
else return AluExp.add(AluExp.idiv(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal))), quo).simplify(cache);
|
|
936
936
|
}
|
|
937
937
|
}
|
|
@@ -4252,7 +4252,7 @@ async function createBackend(device) {
|
|
|
4252
4252
|
if (!navigator.gpu) return null;
|
|
4253
4253
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4254
4254
|
if (!adapter) return null;
|
|
4255
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
4255
|
+
const { WebGPUBackend } = await import("./webgpu-B96vzWGE.js");
|
|
4256
4256
|
const importantLimits = [
|
|
4257
4257
|
"maxBufferSize",
|
|
4258
4258
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4290,7 +4290,7 @@ async function createBackend(device) {
|
|
|
4290
4290
|
});
|
|
4291
4291
|
if (!gl) return null;
|
|
4292
4292
|
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
4293
|
-
const { WebGLBackend } = await import("./webgl-
|
|
4293
|
+
const { WebGLBackend } = await import("./webgl-DweKSWEm.js");
|
|
4294
4294
|
return new WebGLBackend(gl);
|
|
4295
4295
|
} else throw new Error(`Backend not found: ${device}`);
|
|
4296
4296
|
}
|
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-B3foXiV_.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
@@ -337,11 +337,11 @@ function map(fn, tree, ...rest) {
|
|
|
337
337
|
}
|
|
338
338
|
/** Take a reference of every array in a tree. */
|
|
339
339
|
function ref(tree) {
|
|
340
|
-
return map((x) => x.ref, tree);
|
|
340
|
+
return map((x) => x instanceof Tracer ? x.ref : x, tree);
|
|
341
341
|
}
|
|
342
342
|
/** Dispose every array in a tree. */
|
|
343
343
|
function dispose(tree) {
|
|
344
|
-
if (tree) map((x) => x.dispose(), tree);
|
|
344
|
+
if (tree) map((x) => x instanceof Tracer ? x.dispose() : void 0, tree);
|
|
345
345
|
}
|
|
346
346
|
|
|
347
347
|
//#endregion
|
|
@@ -615,14 +615,20 @@ function shrink(x, slice) {
|
|
|
615
615
|
}
|
|
616
616
|
function pad$1(x, width) {
|
|
617
617
|
const nd = ndim$1(x);
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
else if (
|
|
621
|
-
if (width
|
|
622
|
-
const
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
618
|
+
let w;
|
|
619
|
+
if (typeof width === "number") w = [[width, width]];
|
|
620
|
+
else if (require_backend.isNumberPair(width)) w = [width];
|
|
621
|
+
else if (!Array.isArray(width)) {
|
|
622
|
+
const indicesAndPairs = Object.entries(width);
|
|
623
|
+
w = require_backend.rep(nd, [0, 0]);
|
|
624
|
+
for (const [k, v] of indicesAndPairs) w[require_backend.checkAxis(parseInt(k), nd)] = v;
|
|
625
|
+
} else if (!width.every(require_backend.isNumberPair)) throw new TypeError(`Invalid pad() type: ${JSON.stringify(width)}`);
|
|
626
|
+
else w = width;
|
|
627
|
+
if (w.length === 1) {
|
|
628
|
+
const [w0, w1] = w[0];
|
|
629
|
+
w = require_backend.rep(nd, () => [w0, w1]);
|
|
630
|
+
} else if (w.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${w.length}`);
|
|
631
|
+
return bind1(Primitive.Pad, [x], { width: w });
|
|
626
632
|
}
|
|
627
633
|
function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
628
634
|
const as = getShape(a);
|
|
@@ -798,6 +804,22 @@ var Tracer = class Tracer {
|
|
|
798
804
|
const result = reduce(this.astype(castDtype), require_backend.AluOp.Add, axis, opts);
|
|
799
805
|
return result.mul(1 / n).astype(originalDtype);
|
|
800
806
|
}
|
|
807
|
+
/** Minimum of the elements of the array along a given axis. */
|
|
808
|
+
min(axis = null, opts) {
|
|
809
|
+
return reduce(this, require_backend.AluOp.Min, axis, opts);
|
|
810
|
+
}
|
|
811
|
+
/** Maximum of the elements of the array along a given axis. */
|
|
812
|
+
max(axis = null, opts) {
|
|
813
|
+
return reduce(this, require_backend.AluOp.Max, axis, opts);
|
|
814
|
+
}
|
|
815
|
+
/** Test whether all array elements along a given axis evaluate to true. */
|
|
816
|
+
all(axis = null, opts) {
|
|
817
|
+
return this.astype(require_backend.DType.Bool).min(axis, opts);
|
|
818
|
+
}
|
|
819
|
+
/** Test whether any array element along a given axis evaluates to true. */
|
|
820
|
+
any(axis = null, opts) {
|
|
821
|
+
return this.astype(require_backend.DType.Bool).max(axis, opts);
|
|
822
|
+
}
|
|
801
823
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
802
824
|
transpose(perm) {
|
|
803
825
|
return transpose$1(this, perm);
|
|
@@ -5307,6 +5329,7 @@ function lstsq(a, b) {
|
|
|
5307
5329
|
});
|
|
5308
5330
|
const llb = triangularSolve(l, lb, {
|
|
5309
5331
|
leftSide: true,
|
|
5332
|
+
lower: true,
|
|
5310
5333
|
transposeA: true
|
|
5311
5334
|
});
|
|
5312
5335
|
return matmul(at, llb.ref);
|
|
@@ -5320,6 +5343,7 @@ function lstsq(a, b) {
|
|
|
5320
5343
|
});
|
|
5321
5344
|
const llb = triangularSolve(l, lb, {
|
|
5322
5345
|
leftSide: true,
|
|
5346
|
+
lower: true,
|
|
5323
5347
|
transposeA: true
|
|
5324
5348
|
});
|
|
5325
5349
|
return llb;
|
|
@@ -5607,6 +5631,7 @@ __export(numpy_exports, {
|
|
|
5607
5631
|
moveaxis: () => moveaxis$1,
|
|
5608
5632
|
multiply: () => multiply,
|
|
5609
5633
|
nan: () => nan,
|
|
5634
|
+
nanToNum: () => nanToNum,
|
|
5610
5635
|
ndim: () => ndim,
|
|
5611
5636
|
negative: () => negative,
|
|
5612
5637
|
notEqual: () => notEqual,
|
|
@@ -5804,24 +5829,22 @@ function max(a, axis = null, opts) {
|
|
|
5804
5829
|
return reduce(a, require_backend.AluOp.Max, axis, opts);
|
|
5805
5830
|
}
|
|
5806
5831
|
/**
|
|
5807
|
-
* Test whether
|
|
5832
|
+
* Test whether any array element along a given axis evaluates to True.
|
|
5808
5833
|
*
|
|
5809
5834
|
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
5810
5835
|
* removed. If axis is None, returns a scalar.
|
|
5811
5836
|
*/
|
|
5812
|
-
function
|
|
5813
|
-
|
|
5814
|
-
return min(a, axis, opts);
|
|
5837
|
+
function any(a, axis = null, opts) {
|
|
5838
|
+
return fudgeArray(a).any(axis, opts);
|
|
5815
5839
|
}
|
|
5816
5840
|
/**
|
|
5817
|
-
* Test whether
|
|
5841
|
+
* Test whether all array elements along a given axis evaluate to True.
|
|
5818
5842
|
*
|
|
5819
5843
|
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
5820
5844
|
* removed. If axis is None, returns a scalar.
|
|
5821
5845
|
*/
|
|
5822
|
-
function
|
|
5823
|
-
|
|
5824
|
-
return max(a, axis, opts);
|
|
5846
|
+
function all(a, axis = null, opts) {
|
|
5847
|
+
return fudgeArray(a).all(axis, opts);
|
|
5825
5848
|
}
|
|
5826
5849
|
/** Return the peak-to-peak range along a given axis (`max - min`). */
|
|
5827
5850
|
function ptp(a, axis = null, opts) {
|
|
@@ -5922,7 +5945,7 @@ function split$1(a, indicesOrSections, axis = 0) {
|
|
|
5922
5945
|
const partSize = size$1 / indicesOrSections;
|
|
5923
5946
|
sizes = require_backend.rep(indicesOrSections, partSize);
|
|
5924
5947
|
} else {
|
|
5925
|
-
const indices = indicesOrSections;
|
|
5948
|
+
const indices = indicesOrSections.map((i) => i < 0 ? i + size$1 : i);
|
|
5926
5949
|
sizes = [indices[0]];
|
|
5927
5950
|
for (let i = 1; i < indices.length; i++) sizes.push(indices[i] - indices[i - 1]);
|
|
5928
5951
|
sizes.push(size$1 - indices[indices.length - 1]);
|
|
@@ -6870,6 +6893,21 @@ function isposinf(x) {
|
|
|
6870
6893
|
return require_backend.isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
|
|
6871
6894
|
}
|
|
6872
6895
|
/**
|
|
6896
|
+
* Replace NaN and infinite entries in an array.
|
|
6897
|
+
*
|
|
6898
|
+
* By default, NaNs are replaced with `0.0`, and infinities are substituted
|
|
6899
|
+
* with the corresponding maximum or minimum finite values.
|
|
6900
|
+
*/
|
|
6901
|
+
function nanToNum(x, { nan: nan$1 = 0, posinf = null, neginf = null } = {}) {
|
|
6902
|
+
x = fudgeArray(x);
|
|
6903
|
+
x = where(isnan(x.ref), nan$1, x);
|
|
6904
|
+
posinf ??= require_backend.isFloatDtype(x.dtype) ? finfo(x.dtype).max : iinfo(x.dtype).max;
|
|
6905
|
+
neginf ??= require_backend.isFloatDtype(x.dtype) ? finfo(x.dtype).min : iinfo(x.dtype).min;
|
|
6906
|
+
x = where(isposinf(x.ref), posinf, x);
|
|
6907
|
+
x = where(isneginf(x.ref), neginf, x);
|
|
6908
|
+
return x;
|
|
6909
|
+
}
|
|
6910
|
+
/**
|
|
6873
6911
|
* @function
|
|
6874
6912
|
* Test element-wise for finite values (not infinity or NaN).
|
|
6875
6913
|
*/
|
|
@@ -6967,7 +7005,10 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
|
|
|
6967
7005
|
b = fudgeArray(b);
|
|
6968
7006
|
if (!leftSide) transposeA = !transposeA;
|
|
6969
7007
|
else b = moveaxis$1(b, -2, -1);
|
|
6970
|
-
if (transposeA)
|
|
7008
|
+
if (transposeA) {
|
|
7009
|
+
a = moveaxis$1(a, -2, -1);
|
|
7010
|
+
lower = !lower;
|
|
7011
|
+
}
|
|
6971
7012
|
let x = triangularSolve$1(a, b, {
|
|
6972
7013
|
lower,
|
|
6973
7014
|
unitDiagonal
|
|
@@ -7578,8 +7619,6 @@ function oneHot(x, numClasses) {
|
|
|
7578
7619
|
* `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
|
|
7579
7620
|
*/
|
|
7580
7621
|
function dotProductAttention(query, key$1, value, opts = {}) {
|
|
7581
|
-
if (opts.querySeqLengths !== void 0 || opts.keyValueSeqLengths !== void 0) throw new Error("Sequence length masking is not yet implemented");
|
|
7582
|
-
if (opts.localWindowSize !== void 0) throw new Error("Local attention is not yet implemented");
|
|
7583
7622
|
query = fudgeArray(query);
|
|
7584
7623
|
key$1 = fudgeArray(key$1);
|
|
7585
7624
|
value = fudgeArray(value);
|
|
@@ -7617,6 +7656,38 @@ function dotProductAttention(query, key$1, value, opts = {}) {
|
|
|
7617
7656
|
const causalMask = tri(L, S, 0, { dtype: require_backend.DType.Bool });
|
|
7618
7657
|
scores = where(causalMask, scores, -Infinity);
|
|
7619
7658
|
}
|
|
7659
|
+
if (opts.localWindowSize !== void 0) {
|
|
7660
|
+
const [before, after] = typeof opts.localWindowSize === "number" ? [opts.localWindowSize, opts.localWindowSize] : opts.localWindowSize;
|
|
7661
|
+
if (before < 0 || after < 0 || !Number.isInteger(before) || !Number.isInteger(after)) throw new Error(`dotProductAttention: localWindowSize values must be non-negative, got ${opts.localWindowSize}`);
|
|
7662
|
+
const localMask = tri(L, S, after, { dtype: require_backend.DType.Bool }).mul(tri(L, S, -before - 1, { dtype: require_backend.DType.Bool }).notEqual(true));
|
|
7663
|
+
scores = where(localMask, scores, -Infinity);
|
|
7664
|
+
}
|
|
7665
|
+
if (opts.querySeqLengths !== void 0) {
|
|
7666
|
+
const sl = expandDims(opts.querySeqLengths, [
|
|
7667
|
+
-1,
|
|
7668
|
+
-2,
|
|
7669
|
+
-3
|
|
7670
|
+
]);
|
|
7671
|
+
scores = where(arange(L).reshape([
|
|
7672
|
+
1,
|
|
7673
|
+
1,
|
|
7674
|
+
L,
|
|
7675
|
+
1
|
|
7676
|
+
]).less(sl), scores, -Infinity);
|
|
7677
|
+
}
|
|
7678
|
+
if (opts.keyValueSeqLengths !== void 0) {
|
|
7679
|
+
const sl = expandDims(opts.keyValueSeqLengths, [
|
|
7680
|
+
-1,
|
|
7681
|
+
-2,
|
|
7682
|
+
-3
|
|
7683
|
+
]);
|
|
7684
|
+
scores = where(arange(S).reshape([
|
|
7685
|
+
1,
|
|
7686
|
+
1,
|
|
7687
|
+
1,
|
|
7688
|
+
S
|
|
7689
|
+
]).less(sl), scores, -Infinity);
|
|
7690
|
+
}
|
|
7620
7691
|
const attn = softmax(scores, -1);
|
|
7621
7692
|
const out = einsum("BNLS,BSNH->BLNH", attn, value);
|
|
7622
7693
|
return isRank3 ? out.reshape([
|