@jax-js/jax 0.1.6 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +60 -7
- package/dist/{backend-D7s-Retx.cjs → backend-B3foXiV_.cjs} +4 -4
- package/dist/{backend-Dx6Ob2D1.js → backend-nEolvdLv.js} +4 -4
- package/dist/index.cjs +88 -22
- package/dist/index.d.cts +1561 -1538
- package/dist/index.d.ts +1561 -1538
- package/dist/index.js +88 -22
- package/dist/{webgl-CyfzNW8T.cjs → webgl-DIIbKJ0G.cjs} +1 -1
- package/dist/{webgl-CLLvzJlO.js → webgl-DweKSWEm.js} +1 -1
- package/dist/{webgpu-C-VfevQW.js → webgpu-B96vzWGE.js} +1 -1
- package/dist/{webgpu-rraa6dfz.cjs → webgpu-BykvF26B.cjs} +1 -1
- package/package.json +1 -1
package/README.md
CHANGED
|
@@ -43,6 +43,55 @@ way to get started on a blank HTML page.
|
|
|
43
43
|
</script>
|
|
44
44
|
```
|
|
45
45
|
|
|
46
|
+
## Feature comparison
|
|
47
|
+
|
|
48
|
+
Here's a quick, high-level comparison with other popular web ML runtimes:
|
|
49
|
+
|
|
50
|
+
| Feature | jax-js | TensorFlow.js | onnxruntime-web |
|
|
51
|
+
| ------------------------------- | ---------- | --------------- | ------------------ |
|
|
52
|
+
| **Overview** | | | |
|
|
53
|
+
| API style | JAX/NumPy | TensorFlow-like | Static ONNX graphs |
|
|
54
|
+
| Latest release | 2026 | ⚠️ 2024 | 2026 |
|
|
55
|
+
| Speed | Fastest | Fast | Fastest |
|
|
56
|
+
| Bundle size (gzip) | 80 KB | 269 KB | 90 KB + 24 MB Wasm |
|
|
57
|
+
| **Autodiff & JIT** | | | |
|
|
58
|
+
| Gradients | ✅ | ✅ | ❌ |
|
|
59
|
+
| Jacobian and Hessian | ✅ | ❌ | ❌ |
|
|
60
|
+
| `jvp()` forward differentiation | ✅ | ❌ | ❌ |
|
|
61
|
+
| `jit()` kernel fusion | ✅ | ❌ | ❌ |
|
|
62
|
+
| `vmap()` auto-vectorization | ✅ | ❌ | ❌ |
|
|
63
|
+
| Graph capture | ✅ | ❌ | ✅ |
|
|
64
|
+
| **Backends & Data** | | | |
|
|
65
|
+
| WebGPU backend | ✅ | 🟡 Preview | ✅ |
|
|
66
|
+
| WebGL backend | ✅ | ✅ | ✅ |
|
|
67
|
+
| Wasm (CPU) backend | ✅ | ✅ | ✅ |
|
|
68
|
+
| Eager array API | ✅ | ✅ | ❌ |
|
|
69
|
+
| Run ONNX models | 🟡 Partial | ❌ | ✅ |
|
|
70
|
+
| Read safetensors | ✅ | ❌ | ❌ |
|
|
71
|
+
| Float64 | ✅ | ❌ | ❌ |
|
|
72
|
+
| Float32 | ✅ | ✅ | ✅ |
|
|
73
|
+
| Float16 | ✅ | ❌ | ✅ |
|
|
74
|
+
| BFloat16 | ❌ | ❌ | ❌ |
|
|
75
|
+
| Packed Uint8 | ❌ | ❌ | 🟡 Partial |
|
|
76
|
+
| Mixed precision | ✅ | ❌ | ✅ |
|
|
77
|
+
| Mixed devices | ✅ | ❌ | ❌ |
|
|
78
|
+
| **Ops & Numerics** | | | |
|
|
79
|
+
| Arithmetic functions | ✅ | ✅ | ✅ |
|
|
80
|
+
| Matrix multiplication | ✅ | ✅ | ✅ |
|
|
81
|
+
| General einsum | ✅ | 🟡 Partial | 🟡 Partial |
|
|
82
|
+
| Sorting | ✅ | ❌ | ❌ |
|
|
83
|
+
| Activation functions | ✅ | ✅ | ✅ |
|
|
84
|
+
| NaN/Inf numerics | ✅ | ✅ | ✅ |
|
|
85
|
+
| Basic convolutions | ✅ | ✅ | ✅ |
|
|
86
|
+
| n-d convolutions | ✅ | ❌ | ✅ |
|
|
87
|
+
| Strided/dilated convolution | ✅ | ✅ | ✅ |
|
|
88
|
+
| Cholesky, Lstsq | ✅ | ❌ | ❌ |
|
|
89
|
+
| LU, Solve, Determinant | ✅ | ❌ | ❌ |
|
|
90
|
+
| SVD | ❌ | ❌ | ❌ |
|
|
91
|
+
| FFT | ✅ | ✅ | ✅ |
|
|
92
|
+
| Basic RNG (Uniform, Normal) | ✅ | ✅ | ✅ |
|
|
93
|
+
| Advanced RNG | ✅ | ❌ | ❌ |
|
|
94
|
+
|
|
46
95
|
## Tutorial
|
|
47
96
|
|
|
48
97
|
Programming in `jax-js` looks [very similar to JAX](https://docs.jax.dev/en/latest/jax-101.html),
|
|
@@ -271,15 +320,18 @@ self-contained way in other projects.
|
|
|
271
320
|
|
|
272
321
|
### Performance
|
|
273
322
|
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
chip ([try it](https://jax-js.com/bench/matmul)).
|
|
323
|
+
The WebGPU runtime includes an ML compiler with tile-aware optimizations, tuned for individual
|
|
324
|
+
browsers. Also, this library uniquely has the `jit()` feature that fuses operations together and
|
|
325
|
+
records an execution graph. jax-js achieves **over 7000 GFLOP/s** for matrix multiplication on an
|
|
326
|
+
Apple M4 Max chip ([try it](https://jax-js.com/bench/matmul)).
|
|
278
327
|
|
|
279
|
-
For that example, it's
|
|
328
|
+
For that example, it's significantly faster than both
|
|
280
329
|
[TensorFlow.js](https://github.com/tensorflow/tfjs) and
|
|
281
330
|
[ONNX Runtime Web](https://www.npmjs.com/package/onnxruntime-web), which both use handwritten
|
|
282
|
-
libraries of custom kernels
|
|
331
|
+
libraries of custom kernels.
|
|
332
|
+
|
|
333
|
+
It's still early, though. There's a lot of low-hanging fruit to continue optimizing the library, as
|
|
334
|
+
well as unique optimizations such as FlashAttention variants.
|
|
283
335
|
|
|
284
336
|
### API Reference
|
|
285
337
|
|
|
@@ -291,8 +343,9 @@ That's all for this short tutorial. Please see the generated
|
|
|
291
343
|
If you make something cool with jax-js, don't be a stranger! We can feature it here.
|
|
292
344
|
|
|
293
345
|
- [Training neural networks on MNIST](https://jax-js.com/mnist)
|
|
346
|
+
- [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
|
|
294
347
|
- [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
|
|
295
|
-
- [Object detection
|
|
348
|
+
- [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
|
|
296
349
|
- [In-browser REPL](https://jax-js.com/repl)
|
|
297
350
|
- [Matmul benchmark](https://jax-js.com/bench/matmul)
|
|
298
351
|
- [Conv2d benchmark](https://jax-js.com/bench/conv2d)
|
|
@@ -910,7 +910,7 @@ var AluExp = class AluExp {
|
|
|
910
910
|
return ret.simplify(cache);
|
|
911
911
|
}
|
|
912
912
|
}
|
|
913
|
-
if (y.arg > 0) {
|
|
913
|
+
if (y.arg > 0 && x.min >= 0) {
|
|
914
914
|
let [xNoConst, constVal] = [x, 0];
|
|
915
915
|
if (x.op === AluOp.Add && x.src[1].op === AluOp.Const) [xNoConst, constVal] = [x.src[0], x.src[1].arg];
|
|
916
916
|
const terms = [];
|
|
@@ -932,7 +932,7 @@ var AluExp = class AluExp {
|
|
|
932
932
|
rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(remainders[i] / gcdVal)), terms[i]));
|
|
933
933
|
quo = AluExp.add(quo, AluExp.mul(AluExp.const(x.dtype, quotients[i]), terms[i]));
|
|
934
934
|
}
|
|
935
|
-
if (
|
|
935
|
+
if (rem.min >= 0) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
|
|
936
936
|
else return AluExp.add(AluExp.idiv(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal))), quo).simplify(cache);
|
|
937
937
|
}
|
|
938
938
|
}
|
|
@@ -4253,7 +4253,7 @@ async function createBackend(device) {
|
|
|
4253
4253
|
if (!navigator.gpu) return null;
|
|
4254
4254
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4255
4255
|
if (!adapter) return null;
|
|
4256
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
4256
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BykvF26B.cjs"));
|
|
4257
4257
|
const importantLimits = [
|
|
4258
4258
|
"maxBufferSize",
|
|
4259
4259
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4291,7 +4291,7 @@ async function createBackend(device) {
|
|
|
4291
4291
|
});
|
|
4292
4292
|
if (!gl) return null;
|
|
4293
4293
|
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
4294
|
-
const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-
|
|
4294
|
+
const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-DIIbKJ0G.cjs"));
|
|
4295
4295
|
return new WebGLBackend(gl);
|
|
4296
4296
|
} else throw new Error(`Backend not found: ${device}`);
|
|
4297
4297
|
}
|
|
@@ -909,7 +909,7 @@ var AluExp = class AluExp {
|
|
|
909
909
|
return ret.simplify(cache);
|
|
910
910
|
}
|
|
911
911
|
}
|
|
912
|
-
if (y.arg > 0) {
|
|
912
|
+
if (y.arg > 0 && x.min >= 0) {
|
|
913
913
|
let [xNoConst, constVal] = [x, 0];
|
|
914
914
|
if (x.op === AluOp.Add && x.src[1].op === AluOp.Const) [xNoConst, constVal] = [x.src[0], x.src[1].arg];
|
|
915
915
|
const terms = [];
|
|
@@ -931,7 +931,7 @@ var AluExp = class AluExp {
|
|
|
931
931
|
rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(remainders[i] / gcdVal)), terms[i]));
|
|
932
932
|
quo = AluExp.add(quo, AluExp.mul(AluExp.const(x.dtype, quotients[i]), terms[i]));
|
|
933
933
|
}
|
|
934
|
-
if (
|
|
934
|
+
if (rem.min >= 0) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
|
|
935
935
|
else return AluExp.add(AluExp.idiv(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal))), quo).simplify(cache);
|
|
936
936
|
}
|
|
937
937
|
}
|
|
@@ -4252,7 +4252,7 @@ async function createBackend(device) {
|
|
|
4252
4252
|
if (!navigator.gpu) return null;
|
|
4253
4253
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4254
4254
|
if (!adapter) return null;
|
|
4255
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
4255
|
+
const { WebGPUBackend } = await import("./webgpu-B96vzWGE.js");
|
|
4256
4256
|
const importantLimits = [
|
|
4257
4257
|
"maxBufferSize",
|
|
4258
4258
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4290,7 +4290,7 @@ async function createBackend(device) {
|
|
|
4290
4290
|
});
|
|
4291
4291
|
if (!gl) return null;
|
|
4292
4292
|
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
4293
|
-
const { WebGLBackend } = await import("./webgl-
|
|
4293
|
+
const { WebGLBackend } = await import("./webgl-DweKSWEm.js");
|
|
4294
4294
|
return new WebGLBackend(gl);
|
|
4295
4295
|
} else throw new Error(`Backend not found: ${device}`);
|
|
4296
4296
|
}
|
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-B3foXiV_.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
@@ -337,11 +337,11 @@ function map(fn, tree, ...rest) {
|
|
|
337
337
|
}
|
|
338
338
|
/** Take a reference of every array in a tree. */
|
|
339
339
|
function ref(tree) {
|
|
340
|
-
return map((x) => x.ref, tree);
|
|
340
|
+
return map((x) => x instanceof Tracer ? x.ref : x, tree);
|
|
341
341
|
}
|
|
342
342
|
/** Dispose every array in a tree. */
|
|
343
343
|
function dispose(tree) {
|
|
344
|
-
if (tree) map((x) => x.dispose(), tree);
|
|
344
|
+
if (tree) map((x) => x instanceof Tracer ? x.dispose() : void 0, tree);
|
|
345
345
|
}
|
|
346
346
|
|
|
347
347
|
//#endregion
|
|
@@ -615,14 +615,20 @@ function shrink(x, slice) {
|
|
|
615
615
|
}
|
|
616
616
|
function pad$1(x, width) {
|
|
617
617
|
const nd = ndim$1(x);
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
else if (
|
|
621
|
-
if (width
|
|
622
|
-
const
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
618
|
+
let w;
|
|
619
|
+
if (typeof width === "number") w = [[width, width]];
|
|
620
|
+
else if (require_backend.isNumberPair(width)) w = [width];
|
|
621
|
+
else if (!Array.isArray(width)) {
|
|
622
|
+
const indicesAndPairs = Object.entries(width);
|
|
623
|
+
w = require_backend.rep(nd, [0, 0]);
|
|
624
|
+
for (const [k, v] of indicesAndPairs) w[require_backend.checkAxis(parseInt(k), nd)] = v;
|
|
625
|
+
} else if (!width.every(require_backend.isNumberPair)) throw new TypeError(`Invalid pad() type: ${JSON.stringify(width)}`);
|
|
626
|
+
else w = width;
|
|
627
|
+
if (w.length === 1) {
|
|
628
|
+
const [w0, w1] = w[0];
|
|
629
|
+
w = require_backend.rep(nd, () => [w0, w1]);
|
|
630
|
+
} else if (w.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${w.length}`);
|
|
631
|
+
return bind1(Primitive.Pad, [x], { width: w });
|
|
626
632
|
}
|
|
627
633
|
function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
628
634
|
const as = getShape(a);
|
|
@@ -798,6 +804,22 @@ var Tracer = class Tracer {
|
|
|
798
804
|
const result = reduce(this.astype(castDtype), require_backend.AluOp.Add, axis, opts);
|
|
799
805
|
return result.mul(1 / n).astype(originalDtype);
|
|
800
806
|
}
|
|
807
|
+
/** Minimum of the elements of the array along a given axis. */
|
|
808
|
+
min(axis = null, opts) {
|
|
809
|
+
return reduce(this, require_backend.AluOp.Min, axis, opts);
|
|
810
|
+
}
|
|
811
|
+
/** Maximum of the elements of the array along a given axis. */
|
|
812
|
+
max(axis = null, opts) {
|
|
813
|
+
return reduce(this, require_backend.AluOp.Max, axis, opts);
|
|
814
|
+
}
|
|
815
|
+
/** Test whether all array elements along a given axis evaluate to true. */
|
|
816
|
+
all(axis = null, opts) {
|
|
817
|
+
return this.astype(require_backend.DType.Bool).min(axis, opts);
|
|
818
|
+
}
|
|
819
|
+
/** Test whether any array element along a given axis evaluates to true. */
|
|
820
|
+
any(axis = null, opts) {
|
|
821
|
+
return this.astype(require_backend.DType.Bool).max(axis, opts);
|
|
822
|
+
}
|
|
801
823
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
802
824
|
transpose(perm) {
|
|
803
825
|
return transpose$1(this, perm);
|
|
@@ -5607,6 +5629,7 @@ __export(numpy_exports, {
|
|
|
5607
5629
|
moveaxis: () => moveaxis$1,
|
|
5608
5630
|
multiply: () => multiply,
|
|
5609
5631
|
nan: () => nan,
|
|
5632
|
+
nanToNum: () => nanToNum,
|
|
5610
5633
|
ndim: () => ndim,
|
|
5611
5634
|
negative: () => negative,
|
|
5612
5635
|
notEqual: () => notEqual,
|
|
@@ -5804,24 +5827,22 @@ function max(a, axis = null, opts) {
|
|
|
5804
5827
|
return reduce(a, require_backend.AluOp.Max, axis, opts);
|
|
5805
5828
|
}
|
|
5806
5829
|
/**
|
|
5807
|
-
* Test whether
|
|
5830
|
+
* Test whether any array element along a given axis evaluates to True.
|
|
5808
5831
|
*
|
|
5809
5832
|
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
5810
5833
|
* removed. If axis is None, returns a scalar.
|
|
5811
5834
|
*/
|
|
5812
|
-
function
|
|
5813
|
-
|
|
5814
|
-
return min(a, axis, opts);
|
|
5835
|
+
function any(a, axis = null, opts) {
|
|
5836
|
+
return fudgeArray(a).any(axis, opts);
|
|
5815
5837
|
}
|
|
5816
5838
|
/**
|
|
5817
|
-
* Test whether
|
|
5839
|
+
* Test whether all array elements along a given axis evaluate to True.
|
|
5818
5840
|
*
|
|
5819
5841
|
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
5820
5842
|
* removed. If axis is None, returns a scalar.
|
|
5821
5843
|
*/
|
|
5822
|
-
function
|
|
5823
|
-
|
|
5824
|
-
return max(a, axis, opts);
|
|
5844
|
+
function all(a, axis = null, opts) {
|
|
5845
|
+
return fudgeArray(a).all(axis, opts);
|
|
5825
5846
|
}
|
|
5826
5847
|
/** Return the peak-to-peak range along a given axis (`max - min`). */
|
|
5827
5848
|
function ptp(a, axis = null, opts) {
|
|
@@ -5922,7 +5943,7 @@ function split$1(a, indicesOrSections, axis = 0) {
|
|
|
5922
5943
|
const partSize = size$1 / indicesOrSections;
|
|
5923
5944
|
sizes = require_backend.rep(indicesOrSections, partSize);
|
|
5924
5945
|
} else {
|
|
5925
|
-
const indices = indicesOrSections;
|
|
5946
|
+
const indices = indicesOrSections.map((i) => i < 0 ? i + size$1 : i);
|
|
5926
5947
|
sizes = [indices[0]];
|
|
5927
5948
|
for (let i = 1; i < indices.length; i++) sizes.push(indices[i] - indices[i - 1]);
|
|
5928
5949
|
sizes.push(size$1 - indices[indices.length - 1]);
|
|
@@ -6870,6 +6891,21 @@ function isposinf(x) {
|
|
|
6870
6891
|
return require_backend.isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
|
|
6871
6892
|
}
|
|
6872
6893
|
/**
|
|
6894
|
+
* Replace NaN and infinite entries in an array.
|
|
6895
|
+
*
|
|
6896
|
+
* By default, NaNs are replaced with `0.0`, and infinities are substituted
|
|
6897
|
+
* with the corresponding maximum or minimum finite values.
|
|
6898
|
+
*/
|
|
6899
|
+
function nanToNum(x, { nan: nan$1 = 0, posinf = null, neginf = null } = {}) {
|
|
6900
|
+
x = fudgeArray(x);
|
|
6901
|
+
x = where(isnan(x.ref), nan$1, x);
|
|
6902
|
+
posinf ??= require_backend.isFloatDtype(x.dtype) ? finfo(x.dtype).max : iinfo(x.dtype).max;
|
|
6903
|
+
neginf ??= require_backend.isFloatDtype(x.dtype) ? finfo(x.dtype).min : iinfo(x.dtype).min;
|
|
6904
|
+
x = where(isposinf(x.ref), posinf, x);
|
|
6905
|
+
x = where(isneginf(x.ref), neginf, x);
|
|
6906
|
+
return x;
|
|
6907
|
+
}
|
|
6908
|
+
/**
|
|
6873
6909
|
* @function
|
|
6874
6910
|
* Test element-wise for finite values (not infinity or NaN).
|
|
6875
6911
|
*/
|
|
@@ -7578,8 +7614,6 @@ function oneHot(x, numClasses) {
|
|
|
7578
7614
|
* `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
|
|
7579
7615
|
*/
|
|
7580
7616
|
function dotProductAttention(query, key$1, value, opts = {}) {
|
|
7581
|
-
if (opts.querySeqLengths !== void 0 || opts.keyValueSeqLengths !== void 0) throw new Error("Sequence length masking is not yet implemented");
|
|
7582
|
-
if (opts.localWindowSize !== void 0) throw new Error("Local attention is not yet implemented");
|
|
7583
7617
|
query = fudgeArray(query);
|
|
7584
7618
|
key$1 = fudgeArray(key$1);
|
|
7585
7619
|
value = fudgeArray(value);
|
|
@@ -7617,6 +7651,38 @@ function dotProductAttention(query, key$1, value, opts = {}) {
|
|
|
7617
7651
|
const causalMask = tri(L, S, 0, { dtype: require_backend.DType.Bool });
|
|
7618
7652
|
scores = where(causalMask, scores, -Infinity);
|
|
7619
7653
|
}
|
|
7654
|
+
if (opts.localWindowSize !== void 0) {
|
|
7655
|
+
const [before, after] = typeof opts.localWindowSize === "number" ? [opts.localWindowSize, opts.localWindowSize] : opts.localWindowSize;
|
|
7656
|
+
if (before < 0 || after < 0 || !Number.isInteger(before) || !Number.isInteger(after)) throw new Error(`dotProductAttention: localWindowSize values must be non-negative, got ${opts.localWindowSize}`);
|
|
7657
|
+
const localMask = tri(L, S, after, { dtype: require_backend.DType.Bool }).mul(tri(L, S, -before - 1, { dtype: require_backend.DType.Bool }).notEqual(true));
|
|
7658
|
+
scores = where(localMask, scores, -Infinity);
|
|
7659
|
+
}
|
|
7660
|
+
if (opts.querySeqLengths !== void 0) {
|
|
7661
|
+
const sl = expandDims(opts.querySeqLengths, [
|
|
7662
|
+
-1,
|
|
7663
|
+
-2,
|
|
7664
|
+
-3
|
|
7665
|
+
]);
|
|
7666
|
+
scores = where(arange(L).reshape([
|
|
7667
|
+
1,
|
|
7668
|
+
1,
|
|
7669
|
+
L,
|
|
7670
|
+
1
|
|
7671
|
+
]).less(sl), scores, -Infinity);
|
|
7672
|
+
}
|
|
7673
|
+
if (opts.keyValueSeqLengths !== void 0) {
|
|
7674
|
+
const sl = expandDims(opts.keyValueSeqLengths, [
|
|
7675
|
+
-1,
|
|
7676
|
+
-2,
|
|
7677
|
+
-3
|
|
7678
|
+
]);
|
|
7679
|
+
scores = where(arange(S).reshape([
|
|
7680
|
+
1,
|
|
7681
|
+
1,
|
|
7682
|
+
1,
|
|
7683
|
+
S
|
|
7684
|
+
]).less(sl), scores, -Infinity);
|
|
7685
|
+
}
|
|
7620
7686
|
const attn = softmax(scores, -1);
|
|
7621
7687
|
const out = einsum("BNLS,BSNH->BLNH", attn, value);
|
|
7622
7688
|
return isRank3 ? out.reshape([
|