@genai-fi/nanogpt 0.5.6 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116) hide show
  1. package/dist/Generator.js +10 -9
  2. package/dist/NanoGPTModel.js +70 -121
  3. package/dist/RealDiv-7xu-pkZN.js +540 -0
  4. package/dist/Reshape-BYC1oUku.js +127 -0
  5. package/dist/TeachableLLM.d.ts +2 -0
  6. package/dist/TeachableLLM.js +42 -34
  7. package/dist/{TiedEmbedding-8S8xn8e6.js → TiedEmbedding-C1HBot-5.js} +12 -13
  8. package/dist/{axis_util-BczFISHz.js → axis_util-CCNL7jea.js} +14 -12
  9. package/dist/{broadcast_to-B7NGsBSh.js → broadcast_to-CddAF879.js} +2 -2
  10. package/dist/{concat-DdKPyAtw.js → concat-XOK9ANZu.js} +7 -7
  11. package/dist/{dataset-iqT4Otvb.js → dataset-BFFipD1c.js} +5 -5
  12. package/dist/{dropout-B09InSJS.js → dropout-xlKRoJyU.js} +9 -9
  13. package/dist/{gather-D6MsdXqc.js → gather-DKtUaTtA.js} +1 -1
  14. package/dist/gpgpu_math-B_ycgZ4W.js +3115 -0
  15. package/dist/{index-Du-bmOP8.js → index-CamYe_M8.js} +844 -647
  16. package/dist/{kernel_funcs_utils-DShm7-0k.js → kernel_funcs_utils-D5MS0JFg.js} +232 -136
  17. package/dist/layers/BaseLayer.js +2 -2
  18. package/dist/layers/CausalSelfAttention.js +6 -6
  19. package/dist/layers/MLP.js +5 -5
  20. package/dist/layers/RMSNorm.js +3 -3
  21. package/dist/layers/RoPECache.js +13 -33
  22. package/dist/layers/TiedEmbedding.js +6 -7
  23. package/dist/layers/TransformerBlock.js +1 -1
  24. package/dist/{log_sum_exp-CxfBtUaG.js → log_sum_exp-CV_5-TTu.js} +15 -15
  25. package/dist/main.js +24 -20
  26. package/dist/{mat_mul-CbiqIe2d.js → mat_mul-CAbRFWUj.js} +4 -4
  27. package/dist/{max-0Xnlpv8k.js → max-JBBv7aUf.js} +3 -3
  28. package/dist/mulmat_packed_gpu-DW4doKL_.js +71 -0
  29. package/dist/{norm-01kY9I2B.js → norm-B9dQTFYn.js} +12 -12
  30. package/dist/{ones-CrutWGas.js → ones-CMHNqMr6.js} +2 -2
  31. package/dist/ops/appendCache.js +3 -3
  32. package/dist/ops/attentionMask.js +1 -1
  33. package/dist/ops/cpu/appendCache.js +2 -2
  34. package/dist/ops/cpu/attentionMask.js +5 -5
  35. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  36. package/dist/ops/cpu/gatherSub.js +5 -5
  37. package/dist/ops/cpu/gelu.js +1 -1
  38. package/dist/ops/cpu/matMulGelu.js +1 -1
  39. package/dist/ops/cpu/matMulMul.js +1 -1
  40. package/dist/ops/cpu/mulDropout.js +1 -1
  41. package/dist/ops/cpu/normRMS.js +1 -1
  42. package/dist/ops/cpu/qkv.js +3 -3
  43. package/dist/ops/cpu/rope.js +5 -5
  44. package/dist/ops/cpu/scatterSub.js +18 -49
  45. package/dist/ops/fusedSoftmax.js +1 -1
  46. package/dist/ops/gatherSub.js +1 -1
  47. package/dist/ops/gelu.js +1 -1
  48. package/dist/ops/grads/attentionMask.js +15 -11
  49. package/dist/ops/grads/fusedSoftmax.js +12 -10
  50. package/dist/ops/grads/gelu.js +1 -1
  51. package/dist/ops/grads/matMulGelu.js +1 -1
  52. package/dist/ops/grads/normRMS.js +1 -1
  53. package/dist/ops/grads/qkv.js +1 -1
  54. package/dist/ops/grads/rope.js +1 -1
  55. package/dist/ops/log.d.ts +0 -0
  56. package/dist/ops/log.js +1 -0
  57. package/dist/ops/matMulGelu.js +1 -1
  58. package/dist/ops/matMulMul.js +1 -1
  59. package/dist/ops/mulDrop.js +1 -1
  60. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  61. package/dist/ops/normRMS.js +1 -1
  62. package/dist/ops/qkv.js +1 -1
  63. package/dist/ops/rope.js +8 -4
  64. package/dist/ops/scatterSub.js +1 -1
  65. package/dist/ops/webgl/appendCache.js +1 -1
  66. package/dist/ops/webgl/attentionMask.js +1 -1
  67. package/dist/ops/webgl/fusedSoftmax.js +31 -3379
  68. package/dist/ops/webgl/gatherSub.js +1 -1
  69. package/dist/ops/webgl/gelu.js +2 -2
  70. package/dist/{gpgpu_math-BFbOyvk4.js → ops/webgl/log.d.ts} +2 -8
  71. package/dist/ops/webgl/log.js +39 -0
  72. package/dist/ops/webgl/matMulGelu.js +48 -115
  73. package/dist/ops/webgl/matMulMul.js +1 -1
  74. package/dist/ops/webgl/mulDropout.js +1 -1
  75. package/dist/ops/webgl/normRMS.js +2 -2
  76. package/dist/ops/webgl/qkv.js +1 -1
  77. package/dist/ops/webgl/rope.js +1 -1
  78. package/dist/ops/webgl/scatterSub.js +1 -1
  79. package/dist/{ops-CJNniCAV.js → ops-DqtYemmV.js} +143 -135
  80. package/dist/{random_width-C-v-35bY.js → random_width-CLMQG5Jn.js} +6925 -6291
  81. package/dist/{range-Bvs1hidm.js → range-DqYjKnuG.js} +1 -1
  82. package/dist/reciprocal-z49filta.js +25 -0
  83. package/dist/register_all_kernels-COt6wLD0.js +21397 -0
  84. package/dist/{reshape-BH7eBpwq.js → reshape-C45vIIRU.js} +1 -1
  85. package/dist/scatter_nd_util-qgtnviTE.js +46 -0
  86. package/dist/selu_util-4QV_GXTB.js +740 -0
  87. package/dist/shared-ByfrGA97.js +3199 -0
  88. package/dist/{sin-CPAZXNjH.js → sin-9JBrfVaB.js} +1 -1
  89. package/dist/{softmax-DhWoBa7r.js → softmax-DvMvui-_.js} +1 -1
  90. package/dist/{split-BCUhuU7B.js → split-DxrHrPFK.js} +4 -4
  91. package/dist/{stack-BV1v7l3S.js → stack-DgaoDmnF.js} +1 -1
  92. package/dist/{sum-Cvq06317.js → sum-BpcpxNEh.js} +3 -3
  93. package/dist/{tensor-DgTOPY6h.js → tensor-CDz5x1mP.js} +1 -1
  94. package/dist/{tensor2d-CRWjDyUe.js → tensor2d-jO8JY5Jd.js} +1 -1
  95. package/dist/training/AdamExt.js +1 -1
  96. package/dist/training/DatasetBuilder.js +2 -2
  97. package/dist/training/FullTrainer.js +1 -1
  98. package/dist/training/Trainer.js +3 -3
  99. package/dist/training/sparseCrossEntropy.js +4 -4
  100. package/dist/utilities/dummy.d.ts +6 -0
  101. package/dist/utilities/dummy.js +31 -10
  102. package/dist/utilities/generate.js +3 -3
  103. package/dist/utilities/load.d.ts +25 -0
  104. package/dist/utilities/load.js +89 -37
  105. package/dist/utilities/profile.d.ts +5 -0
  106. package/dist/utilities/profile.js +12 -9
  107. package/dist/utilities/safetensors.d.ts +3 -0
  108. package/dist/utilities/safetensors.js +83 -0
  109. package/dist/utilities/save.js +47 -29
  110. package/dist/utilities/weights.js +2 -2
  111. package/dist/{variable-DZ3fF0R2.js → variable-CLVXjN7F.js} +1 -1
  112. package/dist/{zeros-BaHhQTWf.js → zeros-DUkkVccu.js} +8 -8
  113. package/package.json +3 -9
  114. package/dist/Reshape-Biok_3X1.js +0 -212
  115. package/dist/slice_util-DskXqRZa.js +0 -49
  116. package/dist/tfjs_backend-D9Ytje0G.js +0 -1010
@@ -1,8 +1,8 @@
1
- import { o as r, j as p, E as u, a6 as E, a7 as h, a1 as S, s as $, a8 as d } from "./index-Du-bmOP8.js";
2
- import { e as K } from "./axis_util-BczFISHz.js";
3
- import { m as T } from "./max-0Xnlpv8k.js";
4
- import { r as m } from "./reshape-BH7eBpwq.js";
5
- import { s as _ } from "./sum-Cvq06317.js";
1
+ import { q as r, w as p, E as u, a8 as E, a9 as h, p as S, s as $, a7 as d } from "./index-CamYe_M8.js";
2
+ import { e as K } from "./axis_util-CCNL7jea.js";
3
+ import { m as T } from "./max-JBBv7aUf.js";
4
+ import { r as m } from "./reshape-C45vIIRU.js";
5
+ import { s as _ } from "./sum-BpcpxNEh.js";
6
6
  /**
7
7
  * @license
8
8
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -23,7 +23,7 @@ function b(s) {
23
23
  const o = { x: p(s, "x", "exp") };
24
24
  return u.runKernel(E, o);
25
25
  }
26
- const N = /* @__PURE__ */ r({ exp_: b });
26
+ const w = /* @__PURE__ */ r({ exp_: b });
27
27
  /**
28
28
  * @license
29
29
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -40,11 +40,11 @@ const N = /* @__PURE__ */ r({ exp_: b });
40
40
  * limitations under the License.
41
41
  * =============================================================================
42
42
  */
43
- function j(s) {
43
+ function N(s) {
44
44
  const o = { x: p(s, "x", "log", "float32") };
45
45
  return u.runKernel(h, o);
46
46
  }
47
- const v = /* @__PURE__ */ r({ log_: j });
47
+ const q = /* @__PURE__ */ r({ log_: N });
48
48
  /**
49
49
  * @license
50
50
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -61,22 +61,22 @@ const v = /* @__PURE__ */ r({ log_: j });
61
61
  * limitations under the License.
62
62
  * =============================================================================
63
63
  */
64
- function w(s, a = null, o = !1) {
65
- const n = p(s, "x", "logSumExp"), t = S(a, n.shape), x = T(
66
- n,
64
+ function v(s, n = null, o = !1) {
65
+ const a = p(s, "x", "logSumExp"), t = S(n, a.shape), x = T(
66
+ a,
67
67
  t,
68
68
  !0
69
69
  /* keepDims */
70
- ), i = $(n, x), l = N(i), f = _(l, t), c = v(f), e = d(m(x, c.shape), c);
70
+ ), i = $(a, x), l = w(i), f = _(l, t), c = q(f), e = d(m(x, c.shape), c);
71
71
  if (o) {
72
72
  const g = K(e.shape, t);
73
73
  return m(e, g);
74
74
  }
75
75
  return e;
76
76
  }
77
- const M = /* @__PURE__ */ r({ logSumExp_: w });
77
+ const M = /* @__PURE__ */ r({ logSumExp_: v });
78
78
  export {
79
- v as a,
80
- N as e,
79
+ q as a,
80
+ w as e,
81
81
  M as l
82
82
  };
package/dist/main.js CHANGED
@@ -1,11 +1,11 @@
1
- import { default as E } from "./NanoGPTModel.js";
2
- import { default as G } from "./TeachableLLM.js";
3
- import { default as R } from "./tokeniser/CharTokeniser.js";
4
- import { default as q } from "./tokeniser/bpe.js";
5
- import { default as A } from "./utilities/waitForModel.js";
6
- import { default as I } from "./data/textLoader.js";
7
- import { estimateMemoryUsage as K, estimateParameterCount as O, estimateResources as Q, estimateTrainingMemoryUsage as S, validateConfig as V } from "./utilities/parameters.js";
8
- import "./index-Du-bmOP8.js";
1
+ import { default as R } from "./NanoGPTModel.js";
2
+ import { default as q } from "./TeachableLLM.js";
3
+ import { default as A } from "./tokeniser/CharTokeniser.js";
4
+ import { default as I } from "./tokeniser/bpe.js";
5
+ import { default as K } from "./utilities/waitForModel.js";
6
+ import { default as Q } from "./data/textLoader.js";
7
+ import { estimateMemoryUsage as V, estimateParameterCount as W, estimateResources as X, estimateTrainingMemoryUsage as Y, validateConfig as Z } from "./utilities/parameters.js";
8
+ import "./index-CamYe_M8.js";
9
9
  import "./ops/cpu/scatterSub.js";
10
10
  import "./ops/webgl/scatterSub.js";
11
11
  import "./ops/cpu/gatherSub.js";
@@ -16,7 +16,10 @@ import "./ops/grads/attentionMask.js";
16
16
  import "./ops/cpu/qkv.js";
17
17
  import "./ops/webgl/qkv.js";
18
18
  import "./ops/grads/qkv.js";
19
- import "@tensorflow/tfjs";
19
+ import "./random_width-CLMQG5Jn.js";
20
+ import "./register_all_kernels-COt6wLD0.js";
21
+ import "./index-Tf7vU29b.js";
22
+ import "./dataset-BFFipD1c.js";
20
23
  import "./ops/cpu/rope.js";
21
24
  import "./ops/webgl/rope.js";
22
25
  import "./ops/grads/rope.js";
@@ -34,16 +37,17 @@ import "./ops/grads/gelu.js";
34
37
  import "./ops/cpu/normRMS.js";
35
38
  import "./ops/webgl/normRMS.js";
36
39
  import "./ops/grads/normRMS.js";
40
+ import "./ops/webgl/log.js";
37
41
  export {
38
- q as BPETokeniser,
39
- R as CharTokeniser,
40
- E as NanoGPT,
41
- G as TeachableLLM,
42
- K as estimateMemoryUsage,
43
- O as estimateParameterCount,
44
- Q as estimateResources,
45
- S as estimateTrainingMemoryUsage,
46
- I as loadTextData,
47
- V as validateConfig,
48
- A as waitForModel
42
+ I as BPETokeniser,
43
+ A as CharTokeniser,
44
+ R as NanoGPT,
45
+ q as TeachableLLM,
46
+ V as estimateMemoryUsage,
47
+ W as estimateParameterCount,
48
+ X as estimateResources,
49
+ Y as estimateTrainingMemoryUsage,
50
+ Q as loadTextData,
51
+ Z as validateConfig,
52
+ K as waitForModel
49
53
  };
@@ -1,4 +1,4 @@
1
- import { o as m, j as s, u as c, E as M, B as p } from "./index-Du-bmOP8.js";
1
+ import { q as m, w as s, C as c, E as M, D as p } from "./index-CamYe_M8.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -15,10 +15,10 @@ import { o as m, j as s, u as c, E as M, B as p } from "./index-Du-bmOP8.js";
15
15
  * limitations under the License.
16
16
  * =============================================================================
17
17
  */
18
- function f(e, o, n = !1, l = !1) {
19
- let a = s(e, "a", "matMul"), t = s(o, "b", "matMul");
18
+ function f(e, n, o = !1, l = !1) {
19
+ let a = s(e, "a", "matMul"), t = s(n, "b", "matMul");
20
20
  [a, t] = c(a, t);
21
- const r = { a, b: t }, u = { transposeA: n, transposeB: l };
21
+ const r = { a, b: t }, u = { transposeA: o, transposeB: l };
22
22
  return M.runKernel(p, r, u);
23
23
  }
24
24
  const i = /* @__PURE__ */ m({ matMul_: f });
@@ -1,4 +1,4 @@
1
- import { o as r, j as e, E as x, M as c } from "./index-Du-bmOP8.js";
1
+ import { q as r, w as e, E as x, M as c } from "./index-CamYe_M8.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -15,8 +15,8 @@ import { o as r, j as e, E as x, M as c } from "./index-Du-bmOP8.js";
15
15
  * limitations under the License.
16
16
  * =============================================================================
17
17
  */
18
- function m(n, o = null, s = !1) {
19
- const t = { x: e(n, "x", "max") }, a = { reductionIndices: o, keepDims: s };
18
+ function m(n, s = null, o = !1) {
19
+ const t = { x: e(n, "x", "max") }, a = { reductionIndices: s, keepDims: o };
20
20
  return x.runKernel(c, t, a);
21
21
  }
22
22
  const l = /* @__PURE__ */ r({ max_: m });
@@ -0,0 +1,71 @@
1
+ import { u as z } from "./gpgpu_math-B_ycgZ4W.js";
2
+ /**
3
+ * @license
4
+ * Copyright 2018 Google LLC. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ * =============================================================================
17
+ */
18
+ class g {
19
+ constructor(e, s, v, a = !1, r = !1, c = !1, t = null, o = !1, l = !1) {
20
+ this.variableNames = ["matrixA", "matrixB"], this.packedInputs = !0, this.packedOutput = !0, this.outputShape = v, this.enableShapeUniforms = z(this.outputShape.length);
21
+ const m = a ? e[1] : e[2], u = Math.ceil(m / 2), b = a ? "i * 2, rc.y" : "rc.y, i * 2", x = r ? "rc.z, i * 2" : "i * 2, rc.z", n = a ? ["a.xxyy", "a.zzww"] : ["a.xxzz", "a.yyww"], p = r ? ["b.xzxz", "b.ywyw"] : ["b.xyxy", "b.zwzw"];
22
+ let i = "", h = "";
23
+ t && (o ? i = `vec4 activation(vec4 a) {
24
+ vec4 b = getPreluActivationWeightsAtOutCoords();
25
+ ${t}
26
+ }` : l ? i = `vec4 activation(vec4 a) {
27
+ vec4 b = getLeakyreluAlphaAtOutCoords();
28
+ ${t}
29
+ }` : i = `vec4 activation(vec4 x) {
30
+ ${t}
31
+ }`, h = "result = activation(result);");
32
+ const $ = c ? "result += getBiasAtOutCoords();" : "";
33
+ c && this.variableNames.push("bias"), o && this.variableNames.push("preluActivationWeights"), l && this.variableNames.push("leakyreluAlpha");
34
+ let d = "rc.x", f = "rc.x";
35
+ e[0] < s[0] ? d = `imod(rc.x, ${e[0]})` : s[0] < e[0] && (f = `imod(rc.x, ${s[0]})`), this.userCode = `
36
+ ${i}
37
+ // Don't use uniform for sharedDimensionPacked for performance.
38
+ const float sharedDimension = ${u}.0;
39
+
40
+ vec4 dot2x2ARowBCol(ivec3 rc) {
41
+ vec4 result = vec4(0);
42
+ int batchA = ${d};
43
+ int batchB = ${f};
44
+ for (int i = 0; i < ${u}; i++) {
45
+ vec4 a = getMatrixA(batchA, ${b});
46
+ vec4 b = getMatrixB(batchB, ${x});
47
+
48
+ // These swizzled products need to be separately added.
49
+ // See: https://github.com/tensorflow/tfjs/issues/1735
50
+ result += (${n[0]} * ${p[0]});
51
+ result += (${n[1]} * ${p[1]});
52
+ }
53
+ return result;
54
+ }
55
+
56
+ void main() {
57
+ ivec3 rc = getOutputCoords();
58
+ vec4 result = dot2x2ARowBCol(rc);
59
+
60
+ ${$}
61
+
62
+ ${h}
63
+
64
+ setOutput(result);
65
+ }
66
+ `;
67
+ }
68
+ }
69
+ export {
70
+ g as M
71
+ };
@@ -1,8 +1,8 @@
1
- import { o as l, j as c, E as y, a0 as E, a1 as w, a2 as o, a3 as u, W as v, f as I, a4 as A } from "./index-Du-bmOP8.js";
2
- import { e as $ } from "./axis_util-BczFISHz.js";
3
- import { m as f } from "./max-0Xnlpv8k.js";
4
- import { r as h } from "./reshape-BH7eBpwq.js";
5
- import { s as t } from "./sum-Cvq06317.js";
1
+ import { q as l, w as c, E as y, a2 as E, p as w, a3 as o, a4 as u, l as v, f as I, a5 as A } from "./index-CamYe_M8.js";
2
+ import { e as $ } from "./axis_util-CCNL7jea.js";
3
+ import { m as f } from "./max-JBBv7aUf.js";
4
+ import { r as h } from "./reshape-C45vIIRU.js";
5
+ import { s as t } from "./sum-BpcpxNEh.js";
6
6
  /**
7
7
  * @license
8
8
  * Copyright 2020 Google Inc. All Rights Reserved.
@@ -40,21 +40,21 @@ const s = /* @__PURE__ */ l({ min_: k });
40
40
  * limitations under the License.
41
41
  * =============================================================================
42
42
  */
43
- function T(n, e = "euclidean", r = null, m = !1) {
43
+ function q(n, e = "euclidean", r = null, m = !1) {
44
44
  n = c(n, "x", "norm");
45
- const a = d(n, e, r);
45
+ const a = p(n, e, r);
46
46
  let i = a.shape;
47
47
  if (m) {
48
- const p = w(r, n.shape);
49
- i = $(a.shape, p);
48
+ const d = w(r, n.shape);
49
+ i = $(a.shape, d);
50
50
  }
51
51
  return h(a, i);
52
52
  }
53
- function d(n, e, r = null) {
53
+ function p(n, e, r = null) {
54
54
  if (n.rank === 0)
55
55
  return o(n);
56
56
  if (n.rank !== 1 && r === null)
57
- return d(h(n, [-1]), e, r);
57
+ return p(h(n, [-1]), e, r);
58
58
  if (n.rank === 1 || typeof r == "number" || Array.isArray(r) && r.length === 1) {
59
59
  if (e === 1)
60
60
  return t(o(n), r);
@@ -79,7 +79,7 @@ function d(n, e, r = null) {
79
79
  }
80
80
  throw new Error(`Error in norm: invalid axis: ${r}`);
81
81
  }
82
- const N = /* @__PURE__ */ l({ norm_: T });
82
+ const N = /* @__PURE__ */ l({ norm_: q });
83
83
  export {
84
84
  s as m,
85
85
  N as n
@@ -1,5 +1,5 @@
1
- import { n, p as t, q as m, E as i } from "./index-Du-bmOP8.js";
2
- import { z as c, c as f } from "./zeros-BaHhQTWf.js";
1
+ import { y as n, B as t, i as m, E as i } from "./index-CamYe_M8.js";
2
+ import { z as c, c as f } from "./zeros-DUkkVccu.js";
3
3
  /**
4
4
  * @license
5
5
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -1,8 +1,8 @@
1
- import { e as a } from "../index-Du-bmOP8.js";
1
+ import { e as a } from "../index-CamYe_M8.js";
2
2
  import "./cpu/appendCache.js";
3
3
  import "./webgl/appendCache.js";
4
- import { c as s } from "../concat-DdKPyAtw.js";
5
- import { z as c } from "../zeros-BaHhQTWf.js";
4
+ import { c as s } from "../concat-XOK9ANZu.js";
5
+ import { z as c } from "../zeros-DUkkVccu.js";
6
6
  function i(r, p, n, o) {
7
7
  if (!o) {
8
8
  const e = r.shape[2];
@@ -1,4 +1,4 @@
1
- import { e as o } from "../index-Du-bmOP8.js";
1
+ import { e as o } from "../index-CamYe_M8.js";
2
2
  import "./cpu/attentionMask.js";
3
3
  import "./webgl/attentionMask.js";
4
4
  import "./grads/attentionMask.js";
@@ -1,5 +1,5 @@
1
- import { r as d } from "../../index-Du-bmOP8.js";
2
- import { c as h } from "../../concat-DdKPyAtw.js";
1
+ import { r as d } from "../../index-CamYe_M8.js";
2
+ import { c as h } from "../../concat-XOK9ANZu.js";
3
3
  function u(p) {
4
4
  const { cache: n, item: s } = p.inputs, { maxSize: r, pastLen: c } = p.attrs, t = n.shape[0], o = n.shape[1], a = n.shape[3], e = s.shape[2];
5
5
  if (c + e <= r) {
@@ -1,8 +1,8 @@
1
- import { r as a, g as p, f as u } from "../../index-Du-bmOP8.js";
2
- import { l as N, w as b } from "../../ops-CJNniCAV.js";
3
- import { o as g } from "../../ones-CrutWGas.js";
4
- import { z as A } from "../../zeros-BaHhQTWf.js";
5
- import { m as I } from "../../mat_mul-CbiqIe2d.js";
1
+ import { r as a, g as p, f as u } from "../../index-CamYe_M8.js";
2
+ import { l as N, w as b } from "../../ops-DqtYemmV.js";
3
+ import { o as g } from "../../ones-CMHNqMr6.js";
4
+ import { z as A } from "../../zeros-DUkkVccu.js";
5
+ import { m as I } from "../../mat_mul-CAbRFWUj.js";
6
6
  function o(n) {
7
7
  const { q: s, k: e } = n.inputs, { divisor: r } = n.attrs, c = s.shape[2], t = e.shape[2], m = N.bandPart(g([t, t]), -1, 0).cast("bool"), l = A([t, t]), i = p([t, t], Number.NEGATIVE_INFINITY), f = b(m, l, i), k = I(s, e, !1, !0).mul(u(r)), d = f.slice([0, 0], [c, t]).expandDims(0).expandDims(0);
8
8
  return k.add(d);
@@ -1,5 +1,5 @@
1
- import { r as n } from "../../index-Du-bmOP8.js";
2
- import { s as f } from "../../softmax-DhWoBa7r.js";
1
+ import { r as n } from "../../index-CamYe_M8.js";
2
+ import { s as f } from "../../softmax-DvMvui-_.js";
3
3
  function r(t) {
4
4
  const { inputs: s, attrs: i } = t, { logits: o } = s, { dim: a, dropoutRate: e } = i;
5
5
  if (!o)
@@ -1,6 +1,6 @@
1
- import { o as u, j as c, E as g, O as h, r as m, s as p } from "../../index-Du-bmOP8.js";
2
- import { r as l } from "../../range-Bvs1hidm.js";
3
- import { s as N } from "../../stack-BV1v7l3S.js";
1
+ import { q as u, w as c, E as g, Y as h, r as m, s as p } from "../../index-CamYe_M8.js";
2
+ import { r as l } from "../../range-DqYjKnuG.js";
3
+ import { s as N } from "../../stack-DgaoDmnF.js";
4
4
  /**
5
5
  * @license
6
6
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -23,8 +23,8 @@ function f(e, s) {
23
23
  }
24
24
  const b = /* @__PURE__ */ u({ gatherND_: f });
25
25
  function d(e) {
26
- const { values: s, labels: n, logits: r } = e.inputs, t = n.shape[0], o = l(0, t, 1, "int32"), a = N([o, n], 1), i = b(r, a);
27
- return p(s, i);
26
+ const { values: s, labels: n, logits: r } = e.inputs, t = n.shape[0], a = l(0, t, 1, "int32"), i = N([a, n], 1), o = b(r, i);
27
+ return p(s, o);
28
28
  }
29
29
  const k = {
30
30
  kernelName: "EfficientGatherSub",
@@ -1,4 +1,4 @@
1
- import { r as t, t as d } from "../../index-Du-bmOP8.js";
1
+ import { r as t, t as d } from "../../index-CamYe_M8.js";
2
2
  const o = 0.7978845608028654, c = 0.044715;
3
3
  function m(u) {
4
4
  const { inputs: l } = u, { x: e } = l, n = e;
@@ -1,4 +1,4 @@
1
- import { r as a, t as i } from "../../index-Du-bmOP8.js";
1
+ import { r as a, t as i } from "../../index-CamYe_M8.js";
2
2
  const c = 0.7978845608028654, m = 0.044715;
3
3
  function M(o) {
4
4
  const { inputs: s } = o, { x: t, kernel: l } = s, e = t, u = l;
@@ -1,4 +1,4 @@
1
- import { r as n, t as M } from "../../index-Du-bmOP8.js";
1
+ import { r as n, t as M } from "../../index-CamYe_M8.js";
2
2
  function e(t) {
3
3
  const { inputs: r, attrs: o } = t, { transposeA: s, transposeB: l } = o, { x: c, kernel: u, y: a } = r, m = c, i = u, k = a;
4
4
  return M(() => m.matMul(i, s, l).mul(k));
@@ -1,4 +1,4 @@
1
- import { r as e, b as u } from "../../index-Du-bmOP8.js";
1
+ import { r as e, b as u } from "../../index-CamYe_M8.js";
2
2
  function n(o) {
3
3
  const { inputs: r } = o, { a: l, b: t } = r;
4
4
  return console.warn("Using fallback mulDrop implementation without dropout."), u(l, t);
@@ -1,4 +1,4 @@
1
- import { r as o, t as d } from "../../index-Du-bmOP8.js";
1
+ import { r as o, t as d } from "../../index-CamYe_M8.js";
2
2
  function i(t) {
3
3
  const { inputs: e } = t, { x: n, gamma: s } = e, r = n, a = s;
4
4
  return d(() => {
@@ -1,6 +1,6 @@
1
- import { r as q } from "../../index-Du-bmOP8.js";
2
- import { r as o } from "../../reshape-BH7eBpwq.js";
3
- import { s as x } from "../../split-BCUhuU7B.js";
1
+ import { r as q } from "../../index-CamYe_M8.js";
2
+ import { r as o } from "../../reshape-C45vIIRU.js";
3
+ import { s as x } from "../../split-DxrHrPFK.js";
4
4
  function v(p) {
5
5
  const { x: c, kernel: K } = p.inputs, { heads: n } = p.attrs, [s, e, t] = c.shape, a = o(c, [s * e, t]), i = a.dot(K);
6
6
  a.dispose();
@@ -1,8 +1,8 @@
1
- import { r as S } from "../../index-Du-bmOP8.js";
2
- import { r as F } from "../../range-Bvs1hidm.js";
3
- import { g as I } from "../../gather-D6MsdXqc.js";
4
- import { s as E } from "../../stack-BV1v7l3S.js";
5
- import { c as T } from "../../concat-DdKPyAtw.js";
1
+ import { r as S } from "../../index-CamYe_M8.js";
2
+ import { r as F } from "../../range-DqYjKnuG.js";
3
+ import { g as I } from "../../gather-DKtUaTtA.js";
4
+ import { s as E } from "../../stack-DgaoDmnF.js";
5
+ import { c as T } from "../../concat-XOK9ANZu.js";
6
6
  function U(t, c, p, o, r) {
7
7
  const n = o.shape[3], s = p;
8
8
  if (s > n) return o;
@@ -1,39 +1,8 @@
1
- import { o as l, n as g, j as h, E as k, a5 as w, r as $, s as d, b as m } from "../../index-Du-bmOP8.js";
2
- import { r as b } from "../../range-Bvs1hidm.js";
3
- import { s as E } from "../../stack-BV1v7l3S.js";
4
- import { o as D } from "../../ones-CrutWGas.js";
5
- function N(a, r, t) {
6
- const s = r.rank > 1 ? r.shape[r.rank - 1] : 1, e = r.rank > 1 ? r.rank - 1 : 1, o = `Must have updates.shape = indices.shape[:batchDim] + shape[sliceDim:], got updates.shape: ${t.shape}, indices.shape: ${r.shape}, shape: ${a}, sliceDim: ${s}, and batchDim: ${e}.`;
7
- if (t.rank < e)
8
- throw new Error(o + ` update.rank < ${e}. `);
9
- if (a.length < s + (t.rank - e))
10
- throw new Error(o + ` Output shape length < ${s + (t.rank - e)}`);
11
- if (t.rank !== e + a.length - s)
12
- throw new Error(o + ` update.rank != ${e + a.length - s}`);
13
- for (let n = 0; n < e; ++n)
14
- if (t.shape[n] !== r.shape[n])
15
- throw new Error(o + ` updates.shape[${n}] (${t.shape[n]}) != indices.shape[${n}] (${r.shape[n]}).`);
16
- for (let n = 0; n < t.rank - e; ++n)
17
- if (t.shape[n + e] !== a[n + s])
18
- throw new Error(o + ` updates.shape[${n + e}] (${t.shape[n + e]}) != shape[${n + e}] (${a[n + e]})`);
19
- }
20
- function S(a, r, t) {
21
- if (r.rank < 1)
22
- throw new Error(`tf.scatterND() expects the indices to be rank 1 or higher, but the rank was ${r.rank}.`);
23
- if (a.rank < 1)
24
- throw new Error(`tf.scatterND() expects the updates to be rank 1 or higher, but the rank was ${a.rank}.`);
25
- if (r.dtype !== "int32")
26
- throw new Error(`The dtype of 'indices' should be int32, but got dtype: ${r.dtype}`);
27
- if (t.length < 1)
28
- throw new Error(`Output rank must be greater or equal to 1, but got shape: ${t}`);
29
- if (t.length === 0) {
30
- if (r.size === 0)
31
- throw new Error(`Indices specified for empty output. indices shape: ${r.shape}`);
32
- if (a.size === 0)
33
- throw new Error(`Updates specified for empty output. updates shape: ${a.shape}`);
34
- }
35
- N(t, r, a);
36
- }
1
+ import { q as f, y as g, w as o, E as l, X as N, r as b, s as S, b as h } from "../../index-CamYe_M8.js";
2
+ import { v as D } from "../../scatter_nd_util-qgtnviTE.js";
3
+ import { r as k } from "../../range-DqYjKnuG.js";
4
+ import { s as v } from "../../stack-DgaoDmnF.js";
5
+ import { o as E } from "../../ones-CMHNqMr6.js";
37
6
  /**
38
7
  * @license
39
8
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -50,21 +19,21 @@ function S(a, r, t) {
50
19
  * limitations under the License.
51
20
  * =============================================================================
52
21
  */
53
- function y(a, r, t) {
54
- g(t);
55
- const s = h(a, "indices", "scatterND", "int32"), e = h(r, "updates", "scatterND");
56
- S(e, s, t);
57
- const o = { indices: s, updates: e }, n = { shape: t };
58
- return k.runKernel(w, o, n);
22
+ function I(a, e, s) {
23
+ g(s);
24
+ const n = o(a, "indices", "scatterND", "int32"), t = o(e, "updates", "scatterND");
25
+ D(t, n, s);
26
+ const c = { indices: n, updates: t }, r = { shape: s };
27
+ return l.runKernel(N, c, r);
59
28
  }
60
- const v = /* @__PURE__ */ l({ scatterND_: y });
61
- function I(a) {
62
- const { logits: r, labels: t, dy: s } = a.inputs, e = t.shape[0], o = r.shape[1], n = b(0, e, 1, "int32"), i = E([n, t], 1), c = D([e]), p = v(i, c, [e, o]), f = d(r, p), u = s.reshape([e, 1]);
63
- return m(f, u);
29
+ const y = /* @__PURE__ */ f({ scatterND_: I });
30
+ function K(a) {
31
+ const { logits: e, labels: s, dy: n } = a.inputs, t = s.shape[0], c = e.shape[1], r = k(0, t, 1, "int32"), i = v([r, s], 1), d = E([t]), u = y(i, d, [t, c]), p = S(e, u), m = n.reshape([t, 1]);
32
+ return h(p, m);
64
33
  }
65
- const T = {
34
+ const L = {
66
35
  kernelName: "EfficientScatterSub",
67
36
  backendName: "cpu",
68
- kernelFunc: I
37
+ kernelFunc: K
69
38
  };
70
- $(T);
39
+ b(L);
@@ -1,4 +1,4 @@
1
- import { e as t } from "../index-Du-bmOP8.js";
1
+ import { e as t } from "../index-CamYe_M8.js";
2
2
  import "./cpu/fusedSoftmax.js";
3
3
  import "./webgl/fusedSoftmax.js";
4
4
  import "./grads/fusedSoftmax.js";
@@ -1,4 +1,4 @@
1
- import { e as n } from "../index-Du-bmOP8.js";
1
+ import { e as n } from "../index-CamYe_M8.js";
2
2
  import "./cpu/gatherSub.js";
3
3
  import "./webgl/gatherSub.js";
4
4
  function f(r, e, t) {
package/dist/ops/gelu.js CHANGED
@@ -1,4 +1,4 @@
1
- import "../index-Du-bmOP8.js";
1
+ import "../index-CamYe_M8.js";
2
2
  import "./cpu/gelu.js";
3
3
  import "./webgl/gelu.js";
4
4
  import { d as e, g as i } from "./grads/gelu.js";
@@ -1,25 +1,29 @@
1
- import { h as l, f as a } from "../../index-Du-bmOP8.js";
2
- import { matMulMul as i } from "../matMulMul.js";
3
- const m = {
1
+ import { h as m, f as i } from "../../index-CamYe_M8.js";
2
+ import { matMulMul as u } from "../matMulMul.js";
3
+ const p = {
4
4
  kernelName: "AttentionMask",
5
5
  inputsToSave: ["q", "k"],
6
6
  outputsToSave: [],
7
- gradFunc: (t, u, c) => {
7
+ gradFunc: (t, c, l) => {
8
8
  if (Array.isArray(t))
9
9
  throw new Error("Expected dy to be a single Tensor");
10
- const [e, o] = u, { divisor: n } = c;
10
+ const [e, n] = c, { divisor: a } = l;
11
11
  return {
12
- q: () => i(t, o, a(n)),
12
+ q: () => u(t, n, i(a)),
13
13
  k: () => {
14
- const s = e.transpose([0, 1, 3, 2]), r = i(s, t, a(n));
15
- return s.dispose(), r.transpose([0, 1, 3, 2]);
14
+ const s = e.transpose([0, 1, 3, 2]), r = u(s, t, i(a));
15
+ s.dispose();
16
+ const o = r.transpose([0, 1, 3, 2]);
17
+ return r.dispose(), o;
16
18
  },
17
19
  mask: () => t,
18
20
  divisor: () => {
19
- const s = e.matMul(o, !1, !0), r = t.mul(s);
20
- return s.dispose(), r.sum();
21
+ const s = e.matMul(n, !1, !0), r = t.mul(s);
22
+ s.dispose();
23
+ const o = r.sum();
24
+ return r.dispose(), o;
21
25
  }
22
26
  };
23
27
  }
24
28
  };
25
- l(m);
29
+ m(p);
@@ -1,20 +1,22 @@
1
- import { h as d, b as u, s as f } from "../../index-Du-bmOP8.js";
2
- import { mulDrop as l } from "../mulDrop.js";
3
- import { s as g } from "../../sum-Cvq06317.js";
4
- const T = {
1
+ import { h as f, b as i, s as l } from "../../index-CamYe_M8.js";
2
+ import { mulDrop as g } from "../mulDrop.js";
3
+ import { s as T } from "../../sum-BpcpxNEh.js";
4
+ const Y = {
5
5
  kernelName: "FusedSoftmax",
6
6
  outputsToSave: [!0],
7
- gradFunc: (o, i, n) => {
8
- const [s] = i, { dim: a, dropoutRate: t, seed: e } = n, p = !0, r = t && e ? l(o, s, t, e) : u(o, s);
7
+ gradFunc: (o, n, a) => {
8
+ const [s] = n, { dim: p, dropoutRate: t, seed: e } = a, c = !0, r = t && e ? g(o, s, t, e) : i(o, s);
9
9
  return {
10
10
  logits: () => {
11
- const m = g(r, [a], p), c = u(m, s);
12
- return m.dispose(), f(r, c);
11
+ const m = T(r, [p], c), u = i(m, s);
12
+ m.dispose();
13
+ const d = l(r, u);
14
+ return u.dispose(), d;
13
15
  }
14
16
  };
15
17
  }
16
18
  };
17
- d(T);
19
+ f(Y);
18
20
  export {
19
- T as softmaxGradConfig
21
+ Y as softmaxGradConfig
20
22
  };
@@ -1,4 +1,4 @@
1
- import { h as t, e as n } from "../../index-Du-bmOP8.js";
1
+ import { h as t, e as n } from "../../index-CamYe_M8.js";
2
2
  import "../cpu/gelu.js";
3
3
  import "../webgl/gelu.js";
4
4
  const o = {
@@ -1,4 +1,4 @@
1
- import { h as a, e as o } from "../../index-Du-bmOP8.js";
1
+ import { h as a, e as o } from "../../index-CamYe_M8.js";
2
2
  function s(e, n, r) {
3
3
  return o().runKernel("MatMulGeluGrad", { dy: e, x: n, kernel: r });
4
4
  }
@@ -1,4 +1,4 @@
1
- import { h as t, e as g } from "../../index-Du-bmOP8.js";
1
+ import { h as t, e as g } from "../../index-CamYe_M8.js";
2
2
  function s(r, a, n) {
3
3
  return g().runKernel("RMSNormGrad", { dy: r, x: a, gamma: n });
4
4
  }