@jax-js/jax 0.1.6 → 0.1.8

package/README.md CHANGED
@@ -43,6 +43,55 @@ way to get started on a blank HTML page.
  </script>
  ```
 
+ ## Feature comparison
+
+ Here's a quick, high-level comparison with other popular web ML runtimes:
+
+ | Feature | jax-js | TensorFlow.js | onnxruntime-web |
+ | ------------------------------- | ---------- | --------------- | ------------------ |
+ | **Overview** | | | |
+ | API style | JAX/NumPy | TensorFlow-like | Static ONNX graphs |
+ | Latest release | 2026 | ⚠️ 2024 | 2026 |
+ | Speed | Fastest | Fast | Fastest |
+ | Bundle size (gzip) | 80 KB | 269 KB | 90 KB + 24 MB Wasm |
+ | **Autodiff & JIT** | | | |
+ | Gradients | ✅ | ✅ | ❌ |
+ | Jacobian and Hessian | ✅ | ❌ | ❌ |
+ | `jvp()` forward differentiation | ✅ | ❌ | ❌ |
+ | `jit()` kernel fusion | ✅ | ❌ | ❌ |
+ | `vmap()` auto-vectorization | ✅ | ❌ | ❌ |
+ | Graph capture | ✅ | ❌ | ✅ |
+ | **Backends & Data** | | | |
+ | WebGPU backend | ✅ | 🟡 Preview | ✅ |
+ | WebGL backend | ✅ | ✅ | ✅ |
+ | Wasm (CPU) backend | ✅ | ✅ | ✅ |
+ | Eager array API | ✅ | ✅ | ❌ |
+ | Run ONNX models | 🟡 Partial | ❌ | ✅ |
+ | Read safetensors | ✅ | ❌ | ❌ |
+ | Float64 | ✅ | ❌ | ❌ |
+ | Float32 | ✅ | ✅ | ✅ |
+ | Float16 | ✅ | ❌ | ✅ |
+ | BFloat16 | ❌ | ❌ | ❌ |
+ | Packed Uint8 | ❌ | ❌ | 🟡 Partial |
+ | Mixed precision | ✅ | ❌ | ✅ |
+ | Mixed devices | ✅ | ❌ | ❌ |
+ | **Ops & Numerics** | | | |
+ | Arithmetic functions | ✅ | ✅ | ✅ |
+ | Matrix multiplication | ✅ | ✅ | ✅ |
+ | General einsum | ✅ | 🟡 Partial | 🟡 Partial |
+ | Sorting | ✅ | ❌ | ❌ |
+ | Activation functions | ✅ | ✅ | ✅ |
+ | NaN/Inf numerics | ✅ | ✅ | ✅ |
+ | Basic convolutions | ✅ | ✅ | ✅ |
+ | n-d convolutions | ✅ | ❌ | ✅ |
+ | Strided/dilated convolution | ✅ | ✅ | ✅ |
+ | Cholesky, Lstsq | ✅ | ❌ | ❌ |
+ | LU, Solve, Determinant | ✅ | ❌ | ❌ |
+ | SVD | ❌ | ❌ | ❌ |
+ | FFT | ✅ | ✅ | ✅ |
+ | Basic RNG (Uniform, Normal) | ✅ | ✅ | ✅ |
+ | Advanced RNG | ✅ | ❌ | ❌ |
+
  ## Tutorial
 
  Programming in `jax-js` looks [very similar to JAX](https://docs.jax.dev/en/latest/jax-101.html),
@@ -271,15 +320,18 @@ self-contained way in other projects.
 
  ### Performance
 
- We haven't spent a ton of time optimizing yet, but performance is generally pretty good. `jit` is
- very helpful for fusing operations together, and it's a feature only available on the web in jax-js.
- The default kernel-tuning heuristics get about 3000 GFLOP/s for matrix multiplication on an M4 Pro
- chip ([try it](https://jax-js.com/bench/matmul)).
+ The WebGPU runtime includes an ML compiler with tile-aware optimizations, tuned for individual
+ browsers. Also, this library uniquely has the `jit()` feature that fuses operations together and
+ records an execution graph. jax-js achieves **over 7000 GFLOP/s** for matrix multiplication on an
+ Apple M4 Max chip ([try it](https://jax-js.com/bench/matmul)).
 
- For that example, it's around the same GFLOP/s as
+ For that example, it's significantly faster than both
  [TensorFlow.js](https://github.com/tensorflow/tfjs) and
  [ONNX Runtime Web](https://www.npmjs.com/package/onnxruntime-web), which both use handwritten
- libraries of custom kernels (versus jax-js, which generates kernels with an ML compiler).
+ libraries of custom kernels.
+
+ It's still early though. There's a lot of low-hanging fruit to continue optimizing the library, as
+ well as unique optimizations such as FlashAttention variants.
 
  ### API Reference
 
@@ -291,8 +343,9 @@ That's all for this short tutorial. Please see the generated
  If you make something cool with jax-js, don't be a stranger! We can feature it here.
 
  - [Training neural networks on MNIST](https://jax-js.com/mnist)
+ - [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
  - [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
- - [Object detection with DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
+ - [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
  - [In-browser REPL](https://jax-js.com/repl)
  - [Matmul benchmark](https://jax-js.com/bench/matmul)
  - [Conv2d benchmark](https://jax-js.com/bench/conv2d)
@@ -310,6 +363,19 @@ pnpm install
  pnpm run build:watch
  ```
 
+ The `pnpm install` command automatically sets up Git hooks via
+ [Husky](https://typicode.github.io/husky/). Pre-commit hooks will run ESLint and Prettier on staged
+ files to ensure code quality.
+
+ You can also run linting and formatting manually:
+
+ ```bash
+ pnpm lint # Run ESLint
+ pnpm format # Format all files with Prettier
+ pnpm format:check # Check formatting without writing
+ pnpm check # Run TypeScript type checking
+ ```
+
  Then you can run tests in a headless browser using [Vitest](https://vitest.dev/).
 
  ```bash
@@ -910,7 +910,7 @@ var AluExp = class AluExp {
  return ret.simplify(cache);
  }
  }
- if (y.arg > 0) {
+ if (y.arg > 0 && x.min >= 0) {
  let [xNoConst, constVal] = [x, 0];
  if (x.op === AluOp.Add && x.src[1].op === AluOp.Const) [xNoConst, constVal] = [x.src[0], x.src[1].arg];
  const terms = [];
@@ -932,7 +932,7 @@ var AluExp = class AluExp {
  rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(remainders[i] / gcdVal)), terms[i]));
  quo = AluExp.add(quo, AluExp.mul(AluExp.const(x.dtype, quotients[i]), terms[i]));
  }
- if (!((x.min < 0 || rem.min < 0) && remainders.some((r) => r !== 0))) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
+ if (rem.min >= 0) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
  else return AluExp.add(AluExp.idiv(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal))), quo).simplify(cache);
  }
  }
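
The new `x.min >= 0` and `rem.min >= 0` guards are consistent with integer remainder that truncates toward zero, as JavaScript's `%` does. That reading is an assumption, since the diff does not say which semantics the backends use. The sketch below uses plain numbers instead of `AluExp` nodes to test the factoring identity that this rewrite relies on, `(g*r + c) % (g*m) === g*(r % m) + c` with `0 <= c < g`. The names `factoredModHolds` and `counterexamples` are illustrative and do not exist in the package.

```javascript
// Checks the identity (g*r + c) % (g*m) === g*(r % m) + c using JavaScript's
// truncating remainder. Assumes 0 <= c < g and m > 0.
function factoredModHolds(r, g, m, c) {
  const x = g * r + c;
  return x % (g * m) === g * (r % m) + c;
}

// Collects every r in [lo, hi] for which the identity breaks.
function counterexamples(lo, hi, g, m, c) {
  const bad = [];
  for (let r = lo; r <= hi; r++) {
    if (!factoredModHolds(r, g, m, c)) bad.push(r);
  }
  return bad;
}

const nonNegative = counterexamples(0, 50, 2, 2, 1);
const negative = counterexamples(-4, -1, 2, 2, 1);
console.log(nonNegative.length, negative);
```

With `g = 2`, `m = 2`, `c = 1`, no non-negative `r` breaks the identity, while `r = -4` and `r = -2` do. Under the truncation assumption, that is why a rewrite of this shape is only safe when the operands are known to be non-negative.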
@@ -4253,7 +4253,7 @@ async function createBackend(device) {
  if (!navigator.gpu) return null;
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
  if (!adapter) return null;
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-rraa6dfz.cjs"));
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BykvF26B.cjs"));
  const importantLimits = [
  "maxBufferSize",
  "maxComputeInvocationsPerWorkgroup",
@@ -4291,7 +4291,7 @@ async function createBackend(device) {
  });
  if (!gl) return null;
  if (!gl.getExtension("EXT_color_buffer_float")) return null;
- const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-CyfzNW8T.cjs"));
+ const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-DIIbKJ0G.cjs"));
  return new WebGLBackend(gl);
  } else throw new Error(`Backend not found: ${device}`);
  }
@@ -909,7 +909,7 @@ var AluExp = class AluExp {
  return ret.simplify(cache);
  }
  }
- if (y.arg > 0) {
+ if (y.arg > 0 && x.min >= 0) {
  let [xNoConst, constVal] = [x, 0];
  if (x.op === AluOp.Add && x.src[1].op === AluOp.Const) [xNoConst, constVal] = [x.src[0], x.src[1].arg];
  const terms = [];
@@ -931,7 +931,7 @@ var AluExp = class AluExp {
  rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(remainders[i] / gcdVal)), terms[i]));
  quo = AluExp.add(quo, AluExp.mul(AluExp.const(x.dtype, quotients[i]), terms[i]));
  }
- if (!((x.min < 0 || rem.min < 0) && remainders.some((r) => r !== 0))) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
+ if (rem.min >= 0) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
  else return AluExp.add(AluExp.idiv(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal))), quo).simplify(cache);
  }
  }
@@ -4252,7 +4252,7 @@ async function createBackend(device) {
  if (!navigator.gpu) return null;
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
  if (!adapter) return null;
- const { WebGPUBackend } = await import("./webgpu-C-VfevQW.js");
+ const { WebGPUBackend } = await import("./webgpu-B96vzWGE.js");
  const importantLimits = [
  "maxBufferSize",
  "maxComputeInvocationsPerWorkgroup",
@@ -4290,7 +4290,7 @@ async function createBackend(device) {
  });
  if (!gl) return null;
  if (!gl.getExtension("EXT_color_buffer_float")) return null;
- const { WebGLBackend } = await import("./webgl-CLLvzJlO.js");
+ const { WebGLBackend } = await import("./webgl-DweKSWEm.js");
  return new WebGLBackend(gl);
  } else throw new Error(`Backend not found: ${device}`);
  }
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
  }) : target, mod$1));
 
  //#endregion
- const require_backend = require('./backend-D7s-Retx.cjs');
+ const require_backend = require('./backend-B3foXiV_.cjs');
 
  //#region src/frontend/convolution.ts
  /**
@@ -337,11 +337,11 @@ function map(fn, tree, ...rest) {
  }
  /** Take a reference of every array in a tree. */
  function ref(tree) {
- return map((x) => x.ref, tree);
+ return map((x) => x instanceof Tracer ? x.ref : x, tree);
  }
  /** Dispose every array in a tree. */
  function dispose(tree) {
- if (tree) map((x) => x.dispose(), tree);
+ if (tree) map((x) => x instanceof Tracer ? x.dispose() : void 0, tree);
  }
 
  //#endregion
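
The `instanceof Tracer` guards let a tree carry leaves that are not arrays, such as plain numbers, without `ref` turning them into `undefined` or `dispose` throwing on them. A minimal, self-contained sketch of the same pattern follows. `FakeTracer`, `mapTree`, `refTree`, and `disposeTree` are hypothetical stand-ins, not the library's `Tracer`, `map`, `ref`, or `dispose`.

```javascript
// Hypothetical stand-in for a reference-counted array handle.
class FakeTracer {
  constructor() {
    this.refCount = 1;
  }
  get ref() {
    this.refCount += 1;
    return this;
  }
  dispose() {
    this.refCount -= 1;
  }
}

// Hypothetical tree map over nested arrays and plain objects.
function mapTree(fn, tree) {
  if (Array.isArray(tree)) return tree.map((t) => mapTree(fn, t));
  if (tree !== null && typeof tree === "object" && !(tree instanceof FakeTracer)) {
    return Object.fromEntries(Object.entries(tree).map(([k, v]) => [k, mapTree(fn, v)]));
  }
  return fn(tree);
}

// Only tracer leaves are touched; every other leaf passes through unchanged.
const refTree = (tree) => mapTree((x) => (x instanceof FakeTracer ? x.ref : x), tree);
const disposeTree = (tree) => {
  if (tree) mapTree((x) => (x instanceof FakeTracer ? x.dispose() : undefined), tree);
};

const weights = new FakeTracer();
const params = { weights, learningRate: 0.01, shape: [2, 3] };
const copy = refTree(params); // number leaves survive as-is
disposeTree(copy); // number leaves are skipped instead of calling .dispose() on them
```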
@@ -615,14 +615,20 @@ function shrink(x, slice) {
  }
  function pad$1(x, width) {
  const nd = ndim$1(x);
- if (typeof width === "number") width = [[width, width]];
- else if (require_backend.isNumberPair(width)) width = [width];
- else if (!Array.isArray(width) || !width.every(require_backend.isNumberPair)) throw new TypeError(`Invalid pad() type: ${JSON.stringify(width)}`);
- if (width.length === 1) {
- const [w0, w1] = width[0];
- width = require_backend.rep(nd, () => [w0, w1]);
- } else if (width.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${width.length}`);
- return bind1(Primitive.Pad, [x], { width });
+ let w;
+ if (typeof width === "number") w = [[width, width]];
+ else if (require_backend.isNumberPair(width)) w = [width];
+ else if (!Array.isArray(width)) {
+ const indicesAndPairs = Object.entries(width);
+ w = require_backend.rep(nd, [0, 0]);
+ for (const [k, v] of indicesAndPairs) w[require_backend.checkAxis(parseInt(k), nd)] = v;
+ } else if (!width.every(require_backend.isNumberPair)) throw new TypeError(`Invalid pad() type: ${JSON.stringify(width)}`);
+ else w = width;
+ if (w.length === 1) {
+ const [w0, w1] = w[0];
+ w = require_backend.rep(nd, () => [w0, w1]);
+ } else if (w.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${w.length}`);
+ return bind1(Primitive.Pad, [x], { width: w });
  }
  function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
  const as = getShape(a);
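
The rewritten `pad$1` accepts a new form for `width`: an object that maps axis indices to `[before, after]` pairs, with unlisted axes left unpadded. The sketch below restates only the width normalization as a standalone function so the accepted forms are easy to compare. `normalizePadWidth`, `normalizeAxis`, and this `isNumberPair` are hypothetical stand-ins; in particular, `normalizeAxis` assumes that the library's `checkAxis` maps negative axes to non-negative ones, which the diff does not show.

```javascript
// Hypothetical stand-in for the library's checkAxis (assumed behavior).
function normalizeAxis(axis, nd) {
  if (!Number.isInteger(axis) || axis < -nd || axis >= nd) {
    throw new RangeError(`Axis ${axis} out of range for ${nd} dimensions`);
  }
  return axis < 0 ? axis + nd : axis;
}

// Hypothetical stand-in for the library's isNumberPair.
const isNumberPair = (v) =>
  Array.isArray(v) && v.length === 2 && v.every((n) => typeof n === "number");

// Turns any accepted `width` form into one [before, after] pair per axis.
function normalizePadWidth(width, nd) {
  let w;
  if (typeof width === "number") w = [[width, width]];
  else if (isNumberPair(width)) w = [width];
  else if (!Array.isArray(width)) {
    // Object form: axes that are not listed get no padding.
    w = Array.from({ length: nd }, () => [0, 0]);
    for (const [k, v] of Object.entries(width)) w[normalizeAxis(parseInt(k, 10), nd)] = v;
  } else if (!width.every(isNumberPair)) {
    throw new TypeError(`Invalid pad() type: ${JSON.stringify(width)}`);
  } else w = width;
  if (w.length === 1) {
    // A single pair is broadcast to every axis.
    const [w0, w1] = w[0];
    w = Array.from({ length: nd }, () => [w0, w1]);
  } else if (w.length !== nd) {
    throw new Error(`Invalid pad(): expected ${nd} axes, got ${w.length}`);
  }
  return w;
}

console.log(normalizePadWidth({ 1: [2, 3] }, 3)); // pads only axis 1
console.log(normalizePadWidth(1, 2)); // pads every axis by 1 on both sides
```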
@@ -798,6 +804,22 @@ var Tracer = class Tracer {
  const result = reduce(this.astype(castDtype), require_backend.AluOp.Add, axis, opts);
  return result.mul(1 / n).astype(originalDtype);
  }
+ /** Minimum of the elements of the array along a given axis. */
+ min(axis = null, opts) {
+ return reduce(this, require_backend.AluOp.Min, axis, opts);
+ }
+ /** Maximum of the elements of the array along a given axis. */
+ max(axis = null, opts) {
+ return reduce(this, require_backend.AluOp.Max, axis, opts);
+ }
+ /** Test whether all array elements along a given axis evaluate to true. */
+ all(axis = null, opts) {
+ return this.astype(require_backend.DType.Bool).min(axis, opts);
+ }
+ /** Test whether any array element along a given axis evaluates to true. */
+ any(axis = null, opts) {
+ return this.astype(require_backend.DType.Bool).max(axis, opts);
+ }
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
  transpose(perm) {
  return transpose$1(this, perm);
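
The new `all` and `any` methods cast to boolean and then reduce with `min` and `max`. The plain-JavaScript sketch below shows why that works on a flat list of values: with booleans encoded as 0 and 1, the minimum is 1 only if every element is truthy, and the maximum is 1 if at least one is. `asBits`, `all`, and `any` here are illustrative helpers over ordinary arrays, not the library's methods.

```javascript
// Encode truthiness as 0/1, mirroring a cast to a boolean dtype.
const asBits = (xs) => xs.map((x) => (x ? 1 : 0));

// all() is a minimum; the extra 1 is the identity for an empty input.
const all = (xs) => Math.min(1, ...asBits(xs)) === 1;

// any() is a maximum; the extra 0 is the identity for an empty input.
const any = (xs) => Math.max(0, ...asBits(xs)) === 1;

console.log(all([1, 2, 0]), any([0, 0, 3]));
```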
@@ -5307,6 +5329,7 @@ function lstsq(a, b) {
  });
  const llb = triangularSolve(l, lb, {
  leftSide: true,
+ lower: true,
  transposeA: true
  });
  return matmul(at, llb.ref);
@@ -5320,6 +5343,7 @@ function lstsq(a, b) {
  });
  const llb = triangularSolve(l, lb, {
  leftSide: true,
+ lower: true,
  transposeA: true
  });
  return llb;
@@ -5607,6 +5631,7 @@ __export(numpy_exports, {
  moveaxis: () => moveaxis$1,
  multiply: () => multiply,
  nan: () => nan,
+ nanToNum: () => nanToNum,
  ndim: () => ndim,
  negative: () => negative,
  notEqual: () => notEqual,
@@ -5804,24 +5829,22 @@ function max(a, axis = null, opts) {
  return reduce(a, require_backend.AluOp.Max, axis, opts);
  }
  /**
- * Test whether all array elements along a given axis evaluate to True.
+ * Test whether any array element along a given axis evaluates to True.
  *
  * Returns a boolean array with the same shape as `a` with the specified axis
  * removed. If axis is None, returns a scalar.
  */
- function all(a, axis = null, opts) {
- a = fudgeArray(a).astype(require_backend.DType.Bool);
- return min(a, axis, opts);
+ function any(a, axis = null, opts) {
+ return fudgeArray(a).any(axis, opts);
  }
  /**
- * Test whether any array element along a given axis evaluates to True.
+ * Test whether all array elements along a given axis evaluate to True.
  *
  * Returns a boolean array with the same shape as `a` with the specified axis
  * removed. If axis is None, returns a scalar.
  */
- function any(a, axis = null, opts) {
- a = fudgeArray(a).astype(require_backend.DType.Bool);
- return max(a, axis, opts);
+ function all(a, axis = null, opts) {
+ return fudgeArray(a).all(axis, opts);
  }
  /** Return the peak-to-peak range along a given axis (`max - min`). */
  function ptp(a, axis = null, opts) {
@@ -5922,7 +5945,7 @@ function split$1(a, indicesOrSections, axis = 0) {
  const partSize = size$1 / indicesOrSections;
  sizes = require_backend.rep(indicesOrSections, partSize);
  } else {
- const indices = indicesOrSections;
+ const indices = indicesOrSections.map((i) => i < 0 ? i + size$1 : i);
  sizes = [indices[0]];
  for (let i = 1; i < indices.length; i++) sizes.push(indices[i] - indices[i - 1]);
  sizes.push(size$1 - indices[indices.length - 1]);
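
The change to `split$1` normalizes negative split indices before they are turned into section sizes. The sketch below isolates that arithmetic for a single axis; `splitSizes` is a hypothetical helper name and operates on plain numbers rather than arrays.

```javascript
// Converts split points along an axis of length `size` into section sizes.
// Negative indices count from the end of the axis.
function splitSizes(size, indices) {
  const normalized = indices.map((i) => (i < 0 ? i + size : i));
  const sizes = [normalized[0]];
  for (let i = 1; i < normalized.length; i++) {
    sizes.push(normalized[i] - normalized[i - 1]);
  }
  sizes.push(size - normalized[normalized.length - 1]);
  return sizes;
}

console.log(splitSizes(10, [3, -2])); // split points 3 and 8
```

For an axis of length 10 and indices `[3, -2]`, the normalized split points are 3 and 8, giving sections of size 3, 5, and 2. Without the normalization step the same input would produce sizes 3, -5, and 12.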
@@ -6870,6 +6893,21 @@ function isposinf(x) {
  return require_backend.isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
  }
  /**
+ * Replace NaN and infinite entries in an array.
+ *
+ * By default, NaNs are replaced with `0.0`, and infinities are substituted
+ * with the corresponding maximum or minimum finite values.
+ */
+ function nanToNum(x, { nan: nan$1 = 0, posinf = null, neginf = null } = {}) {
+ x = fudgeArray(x);
+ x = where(isnan(x.ref), nan$1, x);
+ posinf ??= require_backend.isFloatDtype(x.dtype) ? finfo(x.dtype).max : iinfo(x.dtype).max;
+ neginf ??= require_backend.isFloatDtype(x.dtype) ? finfo(x.dtype).min : iinfo(x.dtype).min;
+ x = where(isposinf(x.ref), posinf, x);
+ x = where(isneginf(x.ref), neginf, x);
+ return x;
+ }
+ /**
  * @function
  * Test element-wise for finite values (not infinity or NaN).
  */
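
The replacement rules in the new `nanToNum` can be mirrored on an ordinary array of JavaScript numbers. The sketch below does that under one stated assumption: values are 64-bit floats, so the default substitutes are `Number.MAX_VALUE` and `-Number.MAX_VALUE` in place of the dtype-dependent `finfo`/`iinfo` limits used in the diff. `nanToNumSketch` is a hypothetical name, not the exported function.

```javascript
// Replaces NaN and infinite entries in a plain array of float64 numbers.
function nanToNumSketch(values, { nan = 0, posinf = null, neginf = null } = {}) {
  // Fall back to the largest finite float64 magnitudes when not specified.
  const hi = posinf ?? Number.MAX_VALUE;
  const lo = neginf ?? -Number.MAX_VALUE;
  return values.map((v) => {
    if (Number.isNaN(v)) return nan;
    if (v === Infinity) return hi;
    if (v === -Infinity) return lo;
    return v;
  });
}

const cleaned = nanToNumSketch([NaN, Infinity, -Infinity, 1.5]);
const custom = nanToNumSketch([NaN, Infinity, 2], { nan: -1, posinf: 100 });
console.log(cleaned, custom);
```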
@@ -6967,7 +7005,10 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
  b = fudgeArray(b);
  if (!leftSide) transposeA = !transposeA;
  else b = moveaxis$1(b, -2, -1);
- if (transposeA) a = moveaxis$1(a, -2, -1);
+ if (transposeA) {
+ a = moveaxis$1(a, -2, -1);
+ lower = !lower;
+ }
  let x = triangularSolve$1(a, b, {
  lower,
  unitDiagonal
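
Transposing a lower-triangular matrix yields an upper-triangular one, which is why `lower` now flips together with the transpose. The sketch below shows the effect on a 2×2 system with plain nested arrays. `solveTriangular`, `transpose`, and `solveTriangularTransposed` are hypothetical helpers written for this note; they are not the library's `triangularSolve`.

```javascript
// Solves a x = b for a triangular matrix `a` given as an array of rows.
// Only the triangle named by `lower` is read.
function solveTriangular(a, b, { lower = false } = {}) {
  const n = b.length;
  const x = new Array(n).fill(0);
  const order = [...Array(n).keys()];
  if (!lower) order.reverse(); // back substitution for upper-triangular input
  for (const i of order) {
    let acc = b[i];
    for (let j = 0; j < n; j++) {
      const inTriangle = lower ? j < i : j > i;
      if (inTriangle) acc -= a[i][j] * x[j];
    }
    x[i] = acc / a[i][i];
  }
  return x;
}

const transpose = (m) => m[0].map((_, j) => m.map((row) => row[j]));

// Solves transpose(a) x = b, flipping `lower` together with the transpose.
function solveTriangularTransposed(a, b, { lower = false } = {}) {
  return solveTriangular(transpose(a), b, { lower: !lower });
}

const l = [[2, 0], [1, 3]]; // lower triangular
const flipped = solveTriangularTransposed(l, [4, 3], { lower: true });
const unflipped = solveTriangular(transpose(l), [4, 3], { lower: true });
console.log(flipped, unflipped);
```

Here `flipped` is `[1.5, 1]`, which satisfies the transposed system, while `unflipped` is `[2, 1]` because the solver reads the wrong triangle and ignores the off-diagonal entry.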
@@ -7578,8 +7619,6 @@ function oneHot(x, numClasses) {
  * `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
  */
  function dotProductAttention(query, key$1, value, opts = {}) {
- if (opts.querySeqLengths !== void 0 || opts.keyValueSeqLengths !== void 0) throw new Error("Sequence length masking is not yet implemented");
- if (opts.localWindowSize !== void 0) throw new Error("Local attention is not yet implemented");
  query = fudgeArray(query);
  key$1 = fudgeArray(key$1);
  value = fudgeArray(value);
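
With the "not yet implemented" guards removed, the hunk that follows builds boolean masks for `localWindowSize`, `querySeqLengths`, and `keyValueSeqLengths`. The sketch below restates those masks for a single batch element in plain JavaScript. It assumes that `tri(L, S, k)` follows the NumPy convention (true where `j <= i + k`), which the diff does not confirm; `localWindowMask` and `seqLengthMask` are hypothetical helper names.

```javascript
// mask[i][j] is true when key position j is visible from query position i.
// Combining tri(L, S, after) with NOT tri(L, S, -before - 1) under the
// assumed convention reduces to: i - before <= j <= i + after.
function localWindowMask(L, S, before, after) {
  const mask = [];
  for (let i = 0; i < L; i++) {
    const row = [];
    for (let j = 0; j < S; j++) row.push(j >= i - before && j <= i + after);
    mask.push(row);
  }
  return mask;
}

// Query rows at or beyond queryLen and key columns at or beyond keyValueLen
// are masked out, matching the two arange(...).less(...) comparisons.
function seqLengthMask(L, S, queryLen, keyValueLen) {
  const mask = [];
  for (let i = 0; i < L; i++) {
    const row = [];
    for (let j = 0; j < S; j++) row.push(i < queryLen && j < keyValueLen);
    mask.push(row);
  }
  return mask;
}

const windowMask = localWindowMask(4, 4, 1, 0);
const lengthMask = seqLengthMask(3, 4, 2, 3);
console.log(windowMask[2], lengthMask[0], lengthMask[2]);
```

In the library code, positions where the mask is false have their scores set to `-Infinity` before the softmax.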
@@ -7617,6 +7656,38 @@ function dotProductAttention(query, key$1, value, opts = {}) {
  const causalMask = tri(L, S, 0, { dtype: require_backend.DType.Bool });
  scores = where(causalMask, scores, -Infinity);
  }
+ if (opts.localWindowSize !== void 0) {
+ const [before, after] = typeof opts.localWindowSize === "number" ? [opts.localWindowSize, opts.localWindowSize] : opts.localWindowSize;
+ if (before < 0 || after < 0 || !Number.isInteger(before) || !Number.isInteger(after)) throw new Error(`dotProductAttention: localWindowSize values must be non-negative, got ${opts.localWindowSize}`);
+ const localMask = tri(L, S, after, { dtype: require_backend.DType.Bool }).mul(tri(L, S, -before - 1, { dtype: require_backend.DType.Bool }).notEqual(true));
+ scores = where(localMask, scores, -Infinity);
+ }
+ if (opts.querySeqLengths !== void 0) {
+ const sl = expandDims(opts.querySeqLengths, [
+ -1,
+ -2,
+ -3
+ ]);
+ scores = where(arange(L).reshape([
+ 1,
+ 1,
+ L,
+ 1
+ ]).less(sl), scores, -Infinity);
+ }
+ if (opts.keyValueSeqLengths !== void 0) {
+ const sl = expandDims(opts.keyValueSeqLengths, [
+ -1,
+ -2,
+ -3
+ ]);
+ scores = where(arange(S).reshape([
+ 1,
+ 1,
+ 1,
+ S
+ ]).less(sl), scores, -Infinity);
+ }
  const attn = softmax(scores, -1);
  const out = einsum("BNLS,BSNH->BLNH", attn, value);
  return isRank3 ? out.reshape([