@genai-fi/nanogpt 0.4.4 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116) hide show
  1. package/dist/BaseLayer-BhrMN8JO.js +135 -0
  2. package/dist/Generator.js +44 -41
  3. package/dist/NanoGPTModel.d.ts +12 -16
  4. package/dist/NanoGPTModel.js +128 -138
  5. package/dist/{Reshape-CiAY8ltP.js → Reshape-BE5rA4rT.js} +8 -8
  6. package/dist/TeachableLLM.js +8 -5
  7. package/dist/{TiedEmbedding-DznFwzcB.js → TiedEmbedding-DsDRvLB0.js} +751 -768
  8. package/dist/{axis_util-QP0LdI1v.js → axis_util-97KkkyRQ.js} +1 -1
  9. package/dist/broadcast_to-CMlkG8NS.js +44 -0
  10. package/dist/{concat-DvWM7HGZ.js → concat-Cxbo2sOz.js} +3 -3
  11. package/dist/{dropout-DFEXTPV0.js → dropout-kbDY39Ci.js} +1 -1
  12. package/dist/{gather-C5D8PxwA.js → gather-Bxe1Qip8.js} +4 -4
  13. package/dist/{gpgpu_math-CUzjlO9A.js → gpgpu_math-C0zyxKFi.js} +1 -1
  14. package/dist/{index--6vO-cOz.js → index-iNhkcAEQ.js} +82 -82
  15. package/dist/{kernel_funcs_utils-C6YBCuOt.js → kernel_funcs_utils-C4eIk4fE.js} +20 -20
  16. package/dist/layers/BaseLayer.d.ts +28 -4
  17. package/dist/layers/BaseLayer.js +3 -16
  18. package/dist/layers/CausalSelfAttention.d.ts +22 -24
  19. package/dist/layers/CausalSelfAttention.js +73 -127
  20. package/dist/layers/MLP.d.ts +8 -15
  21. package/dist/layers/MLP.js +43 -81
  22. package/dist/layers/RMSNorm.d.ts +5 -11
  23. package/dist/layers/RMSNorm.js +13 -29
  24. package/dist/layers/RoPECache.js +14 -12
  25. package/dist/layers/TiedEmbedding.d.ts +6 -16
  26. package/dist/layers/TiedEmbedding.js +5 -5
  27. package/dist/layers/TransformerBlock.d.ts +12 -16
  28. package/dist/layers/TransformerBlock.js +20 -41
  29. package/dist/{log_sum_exp-CiEy1aUe.js → log_sum_exp-CkumwesB.js} +11 -11
  30. package/dist/main.js +22 -19
  31. package/dist/{mat_mul-BEHRPMh0.js → mat_mul-D0SifYfJ.js} +3 -3
  32. package/dist/{max-BUShNgfh.js → max-CYaAjEEp.js} +3 -3
  33. package/dist/{moments-DYOHXoRV.js → moments-B06NlR_V.js} +6 -6
  34. package/dist/{norm-DSva3hI3.js → norm-D3676xIo.js} +7 -7
  35. package/dist/{ones-D6kB8bdY.js → ones-BIeFnPHR.js} +2 -2
  36. package/dist/ops/appendCache.js +4 -4
  37. package/dist/ops/attentionMask.d.ts +1 -1
  38. package/dist/ops/attentionMask.js +4 -4
  39. package/dist/ops/cpu/appendCache.js +2 -2
  40. package/dist/ops/cpu/attentionMask.js +14 -15
  41. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  42. package/dist/ops/cpu/gatherSub.js +5 -5
  43. package/dist/ops/cpu/gelu.js +1 -1
  44. package/dist/ops/cpu/matMulGelu.js +1 -1
  45. package/dist/ops/cpu/matMulMul.d.ts +1 -0
  46. package/dist/ops/cpu/matMulMul.js +17 -0
  47. package/dist/ops/cpu/mulDropout.js +1 -1
  48. package/dist/ops/cpu/normRMS.d.ts +1 -0
  49. package/dist/ops/cpu/normRMS.js +39 -0
  50. package/dist/ops/cpu/qkv.js +3 -3
  51. package/dist/ops/cpu/rope.js +5 -5
  52. package/dist/ops/cpu/scatterSub.js +8 -8
  53. package/dist/ops/fusedSoftmax.js +1 -1
  54. package/dist/ops/gatherSub.js +1 -1
  55. package/dist/ops/gelu.js +1 -1
  56. package/dist/ops/grads/attentionMask.js +13 -9
  57. package/dist/ops/grads/fusedSoftmax.js +12 -9
  58. package/dist/ops/grads/gelu.js +1 -1
  59. package/dist/ops/grads/matMulGelu.js +1 -1
  60. package/dist/ops/grads/normRMS.d.ts +2 -0
  61. package/dist/ops/grads/normRMS.js +20 -0
  62. package/dist/ops/grads/qkv.js +19 -9
  63. package/dist/ops/grads/rope.js +1 -1
  64. package/dist/ops/matMulGelu.js +1 -1
  65. package/dist/ops/matMulMul.d.ts +2 -0
  66. package/dist/ops/matMulMul.js +9 -0
  67. package/dist/ops/mulDrop.js +1 -1
  68. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  69. package/dist/ops/normRMS.d.ts +2 -0
  70. package/dist/ops/normRMS.js +10 -0
  71. package/dist/ops/qkv.js +1 -1
  72. package/dist/ops/scatterSub.js +1 -1
  73. package/dist/ops/webgl/appendCache.js +1 -1
  74. package/dist/ops/webgl/attentionMask.js +13 -12
  75. package/dist/ops/webgl/fusedSoftmax.js +43 -40
  76. package/dist/ops/webgl/gatherSub.js +1 -1
  77. package/dist/ops/webgl/gelu.js +2 -2
  78. package/dist/ops/webgl/matMulGelu.d.ts +3 -2
  79. package/dist/ops/webgl/matMulGelu.js +77 -75
  80. package/dist/ops/webgl/matMulMul.d.ts +14 -0
  81. package/dist/ops/webgl/matMulMul.js +28 -0
  82. package/dist/ops/webgl/mulDropout.js +1 -1
  83. package/dist/ops/webgl/normRMS.d.ts +1 -0
  84. package/dist/ops/webgl/normRMS.js +86 -0
  85. package/dist/ops/webgl/qkv.js +1 -1
  86. package/dist/ops/webgl/rope.js +1 -1
  87. package/dist/ops/webgl/scatterSub.js +1 -1
  88. package/dist/ops-ObfXLHYQ.js +1269 -0
  89. package/dist/{range-C_vpUjBu.js → range-BsFU-SNG.js} +1 -1
  90. package/dist/{reshape-z51Eu-re.js → reshape-DxTPgnwL.js} +3 -3
  91. package/dist/{sin-H567uayl.js → sin-BOX-JVAj.js} +5 -5
  92. package/dist/slice_util-D-kaD4ZV.js +49 -0
  93. package/dist/{softmax-Dsxflvdl.js → softmax-BjsptB07.js} +2 -2
  94. package/dist/{split-B_k_jwud.js → split-BCbrzthj.js} +4 -4
  95. package/dist/{stack-CmqSdsfs.js → stack--cqr9Dgc.js} +2 -2
  96. package/dist/{sum-DdkDf2MG.js → sum-B_92TaHD.js} +5 -5
  97. package/dist/{tensor-BGYi41cj.js → tensor-CfiPXsW4.js} +1 -1
  98. package/dist/{tensor2d-DUr_htjt.js → tensor2d-tSxWdFMH.js} +1 -1
  99. package/dist/tfjs_backend-NucKez4s.js +1010 -0
  100. package/dist/training/AdamExt.js +1 -1
  101. package/dist/training/DatasetBuilder.js +44 -44
  102. package/dist/training/Evaluator.js +6 -6
  103. package/dist/training/FullTrainer.js +1 -1
  104. package/dist/training/Trainer.js +7 -7
  105. package/dist/training/sparseCrossEntropy.js +4 -4
  106. package/dist/utilities/dummy.js +10 -10
  107. package/dist/utilities/generate.js +3 -3
  108. package/dist/utilities/load.js +1 -1
  109. package/dist/utilities/profile.js +1 -1
  110. package/dist/utilities/save.js +10 -8
  111. package/dist/utilities/weights.js +2 -2
  112. package/dist/{zeros-8xl-W2DC.js → zeros-NMYTayy7.js} +3 -3
  113. package/package.json +1 -1
  114. package/dist/slice_util-BdhYwFY_.js +0 -90
  115. package/dist/tfjs_backend-DuKis_xG.js +0 -2271
  116. package/dist/variable-BJTZ3jOy.js +0 -23
@@ -1,7 +1,7 @@
1
- import { r as G, t as P, e as R, b as I, n as k, O as L, j as F, Q as U } from "../../index--6vO-cOz.js";
2
- import { r as g } from "../../Reshape-CiAY8ltP.js";
3
- import { u as H } from "../../gpgpu_math-CUzjlO9A.js";
4
- import { m as z } from "../../mat_mul-BEHRPMh0.js";
1
+ import { r as C, t as R, e as I, p as G, N as L, k as F, O as U } from "../../index-iNhkcAEQ.js";
2
+ import { r as S } from "../../Reshape-BE5rA4rT.js";
3
+ import { u as H } from "../../gpgpu_math-C0zyxKFi.js";
4
+ import { m as B } from "../../mat_mul-D0SifYfJ.js";
5
5
  /**
6
6
  * @license
7
7
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -19,39 +19,39 @@ import { m as z } from "../../mat_mul-BEHRPMh0.js";
19
19
  * =============================================================================
20
20
  */
21
21
  class W {
22
- constructor(e, s, a, n = !1, c = !1, o = !1, r = null, i = !1, u = !1) {
23
- this.variableNames = ["matrixA", "matrixB"], this.packedInputs = !0, this.packedOutput = !0, this.outputShape = a, this.enableShapeUniforms = H(this.outputShape.length);
24
- const p = n ? e[1] : e[2], l = Math.ceil(p / 2), b = n ? "i * 2, rc.y" : "rc.y, i * 2", M = c ? "rc.z, i * 2" : "i * 2, rc.z", h = n ? ["a.xxyy", "a.zzww"] : ["a.xxzz", "a.yyww"], d = c ? ["b.xzxz", "b.ywyw"] : ["b.xyxy", "b.zwzw"];
25
- let m = "", v = "";
26
- r && (i ? m = `vec4 activation(vec4 a) {
22
+ constructor(e, s, n, a = !1, c = !1, o = !1, r = null, u = !1, l = !1) {
23
+ this.variableNames = ["matrixA", "matrixB"], this.packedInputs = !0, this.packedOutput = !0, this.outputShape = n, this.enableShapeUniforms = H(this.outputShape.length);
24
+ const h = a ? e[1] : e[2], p = Math.ceil(h / 2), d = a ? "i * 2, rc.y" : "rc.y, i * 2", $ = c ? "rc.z, i * 2" : "i * 2, rc.z", x = a ? ["a.xxyy", "a.zzww"] : ["a.xxzz", "a.yyww"], m = c ? ["b.xzxz", "b.ywyw"] : ["b.xyxy", "b.zwzw"];
25
+ let i = "", b = "";
26
+ r && (u ? i = `vec4 activation(vec4 a) {
27
27
  vec4 b = getPreluActivationWeightsAtOutCoords();
28
28
  ${r}
29
- }` : u ? m = `vec4 activation(vec4 a) {
29
+ }` : l ? i = `vec4 activation(vec4 a) {
30
30
  vec4 b = getLeakyreluAlphaAtOutCoords();
31
31
  ${r}
32
- }` : m = `vec4 activation(vec4 x) {
32
+ }` : i = `vec4 activation(vec4 x) {
33
33
  ${r}
34
- }`, v = "result = activation(result);");
35
- const $ = o ? "result += getBiasAtOutCoords();" : "";
36
- o && this.variableNames.push("bias"), i && this.variableNames.push("preluActivationWeights"), u && this.variableNames.push("leakyreluAlpha");
37
- let f = "rc.x", x = "rc.x";
38
- e[0] < s[0] ? f = `imod(rc.x, ${e[0]})` : s[0] < e[0] && (x = `imod(rc.x, ${s[0]})`), this.userCode = `
39
- ${m}
34
+ }`, b = "result = activation(result);");
35
+ const M = o ? "result += getBiasAtOutCoords();" : "";
36
+ o && this.variableNames.push("bias"), u && this.variableNames.push("preluActivationWeights"), l && this.variableNames.push("leakyreluAlpha");
37
+ let f = "rc.x", v = "rc.x";
38
+ e[0] < s[0] ? f = `imod(rc.x, ${e[0]})` : s[0] < e[0] && (v = `imod(rc.x, ${s[0]})`), this.userCode = `
39
+ ${i}
40
40
  // Don't use uniform for sharedDimensionPacked for performance.
41
- const float sharedDimension = ${l}.0;
41
+ const float sharedDimension = ${p}.0;
42
42
 
43
43
  vec4 dot2x2ARowBCol(ivec3 rc) {
44
44
  vec4 result = vec4(0);
45
45
  int batchA = ${f};
46
- int batchB = ${x};
47
- for (int i = 0; i < ${l}; i++) {
48
- vec4 a = getMatrixA(batchA, ${b});
49
- vec4 b = getMatrixB(batchB, ${M});
46
+ int batchB = ${v};
47
+ for (int i = 0; i < ${p}; i++) {
48
+ vec4 a = getMatrixA(batchA, ${d});
49
+ vec4 b = getMatrixB(batchB, ${$});
50
50
 
51
51
  // These swizzled products need to be separately added.
52
52
  // See: https://github.com/tensorflow/tfjs/issues/1735
53
- result += (${h[0]} * ${d[0]});
54
- result += (${h[1]} * ${d[1]});
53
+ result += (${x[0]} * ${m[0]});
54
+ result += (${x[1]} * ${m[1]});
55
55
  }
56
56
  return result;
57
57
  }
@@ -60,97 +60,99 @@ class W {
60
60
  ivec3 rc = getOutputCoords();
61
61
  vec4 result = dot2x2ARowBCol(rc);
62
62
 
63
- ${$}
63
+ ${M}
64
64
 
65
- ${v}
65
+ ${b}
66
66
 
67
67
  setOutput(result);
68
68
  }
69
69
  `;
70
70
  }
71
71
  }
72
- const S = 0.7978845608028654, w = 0.044715, j = `
72
+ const g = 0.7978845608028654, w = 0.044715, j = `
73
73
  vec4 x3 = x * x * x;
74
74
  vec4 inner = x + ${w} * x3;
75
- inner = ${S} * inner;
75
+ inner = ${g} * inner;
76
76
  inner = tanh(inner);
77
77
  inner = 0.5 * (1.0 + inner);
78
78
  vec4 result = x * inner;
79
79
  return result;
80
80
  `, q = `
81
- vec4 x2 = x * x;
82
- vec4 x3 = x2 * x;
83
- vec4 u = ${S} * (x + ${w} * x3);
81
+ vec4 a2 = a * a;
82
+ vec4 a3 = a2 * a;
83
+ vec4 u = ${g} * (a + ${w} * a3);
84
84
  vec4 t = tanh(u);
85
85
  vec4 sech2 = 1.0 - t * t;
86
- vec4 du_dx = ${S} * (1.0 + 3.0 * ${w} * x2);
87
- vec4 dgelu = 0.5 * (1.0 + t) + 0.5 * x * sech2 * du_dx;
88
- return dgelu;
86
+ vec4 du_dx = ${g} * (1.0 + 3.0 * ${w} * a2);
87
+ vec4 dgelu = 0.5 * (1.0 + t) + 0.5 * a * sech2 * du_dx;
88
+ return dgelu * b;
89
89
  `, se = 1e3;
90
- function B({
90
+ function O({
91
91
  a: t,
92
92
  b: e,
93
93
  transposeA: s,
94
- transposeB: a,
95
- backend: n,
96
- activationSnippet: c
94
+ transposeB: n,
95
+ backend: a,
96
+ activationSnippet: c,
97
+ multiplier: o
97
98
  }) {
98
- const o = t.shape.length, r = e.shape.length, i = s ? t.shape[o - 2] : t.shape[o - 1], u = a ? e.shape[r - 1] : e.shape[r - 2], p = s ? t.shape[o - 1] : t.shape[o - 2], l = a ? e.shape[r - 2] : e.shape[r - 1], b = t.shape.slice(0, -2), M = e.shape.slice(0, -2), h = k(b), d = k(M), v = L(t.shape.slice(0, -2), e.shape.slice(0, -2)).concat([p, l]);
99
+ const r = t.shape.length, u = e.shape.length, l = s ? t.shape[r - 2] : t.shape[r - 1], h = n ? e.shape[u - 1] : e.shape[u - 2], p = s ? t.shape[r - 1] : t.shape[r - 2], d = n ? e.shape[u - 2] : e.shape[u - 1], $ = t.shape.slice(0, -2), x = e.shape.slice(0, -2), m = G($), i = G(x), M = L(t.shape.slice(0, -2), e.shape.slice(0, -2)).concat([p, d]);
99
100
  F(
100
- i === u,
101
- () => `Error in matMul: inner shapes (${i}) and (${u}) of Tensors with shapes ${t.shape} and ${e.shape} and transposeA=${s} and transposeB=${a} must match.`
101
+ l === h,
102
+ () => `Error in matMul: inner shapes (${l}) and (${h}) of Tensors with shapes ${t.shape} and ${e.shape} and transposeA=${s} and transposeB=${n} must match.`
102
103
  );
103
- const $ = s ? [h, i, p] : [h, p, i], f = a ? [d, l, u] : [d, u, l], x = g({ inputs: { x: t }, backend: n, attrs: { shape: $ } }), A = g({ inputs: { x: e }, backend: n, attrs: { shape: f } }), y = [x, A], C = Math.max(h, d), O = c, E = U(t.dtype, e.dtype), N = new W(
104
- $,
104
+ const f = s ? [m, l, p] : [m, p, l], v = n ? [i, d, h] : [i, h, d], A = S({ inputs: { x: t }, backend: a, attrs: { shape: f } }), y = S({ inputs: { x: e }, backend: a, attrs: { shape: v } }), k = [A, y], N = Math.max(m, i), E = c, T = U(t.dtype, e.dtype), _ = new W(
105
105
  f,
106
- [C, p, l],
106
+ v,
107
+ [N, p, d],
107
108
  s,
108
- a,
109
- !1,
110
- O,
109
+ n,
111
110
  !1,
111
+ E,
112
+ !!o,
112
113
  !1
113
- ), T = [x, A], D = n.runWebGLProgram(N, T, E), _ = g({ inputs: { x: D }, backend: n, attrs: { shape: v } });
114
- y.push(D);
115
- for (const K of y)
116
- n.disposeIntermediateTensorInfo(K);
117
- return _;
114
+ ), D = [A, y];
115
+ o && D.push(o);
116
+ const z = a.runWebGLProgram(_, D, T), K = S({ inputs: { x: z }, backend: a, attrs: { shape: M } });
117
+ k.push(z);
118
+ for (const P of k)
119
+ a.disposeIntermediateTensorInfo(P);
120
+ return K;
118
121
  }
119
- function Q(t) {
120
- const { inputs: e, backend: s } = t, { x: a, kernel: n } = e;
121
- if (a === void 0 || n === void 0)
122
+ function J(t) {
123
+ const { inputs: e, backend: s } = t, { x: n, kernel: a } = e;
124
+ if (n === void 0 || a === void 0)
122
125
  throw new Error("BatchMatMul requires two input tensors.");
123
- return B({
124
- a,
125
- b: n,
126
+ return O({
127
+ a: n,
128
+ b: a,
126
129
  transposeA: !1,
127
130
  transposeB: !1,
128
131
  backend: s,
129
132
  activationSnippet: j
130
133
  });
131
134
  }
132
- const J = {
135
+ const Q = {
133
136
  kernelName: "MatMulGelu",
134
137
  backendName: "webgl",
135
- kernelFunc: Q
138
+ kernelFunc: J
136
139
  };
137
- G(J);
140
+ C(Q);
138
141
  function V(t) {
139
- const { dy: e, x: s, kernel: a } = t.inputs, n = t.backend;
140
- return P(() => {
141
- const c = R().makeTensorFromTensorInfo(
142
- B({
142
+ const { dy: e, x: s, kernel: n } = t.inputs, a = t.backend;
143
+ return R(() => {
144
+ const c = I().makeTensorFromTensorInfo(
145
+ O({
143
146
  a: s,
144
- b: a,
147
+ b: n,
145
148
  transposeA: !1,
146
149
  transposeB: !1,
147
- backend: n,
148
- activationSnippet: q
150
+ backend: a,
151
+ activationSnippet: q,
152
+ multiplier: e
149
153
  })
150
- ), o = I(e, c);
151
- c.dispose();
152
- const r = z(o, a, !1, !0), i = z(s, o, !0, !1);
153
- return [r, i];
154
+ ), o = B(c, n, !1, !0), r = B(s, c, !0, !1);
155
+ return [o, r];
154
156
  });
155
157
  }
156
158
  const X = {
@@ -158,9 +160,9 @@ const X = {
158
160
  backendName: "webgl",
159
161
  kernelFunc: V
160
162
  };
161
- G(X);
163
+ C(X);
162
164
  export {
163
165
  se as MATMUL_SHARED_DIM_THRESHOLD,
164
- B as batchMatMulGeluImpl,
165
- Q as batchMatMulKernel
166
+ O as batchMatMulGeluImpl,
167
+ J as batchMatMulKernel
166
168
  };
@@ -0,0 +1,14 @@
1
+ import { TensorInfo } from '@tensorflow/tfjs-core';
2
+ import { MathBackendWebGL } from '@tensorflow/tfjs-backend-webgl';
3
+ export declare function batchMatMulKernel(args: {
4
+ inputs: {
5
+ x: TensorInfo;
6
+ kernel: TensorInfo;
7
+ y: TensorInfo;
8
+ };
9
+ attrs: {
10
+ transposeA: boolean;
11
+ transposeB: boolean;
12
+ };
13
+ backend: MathBackendWebGL;
14
+ }): TensorInfo;
@@ -0,0 +1,28 @@
1
+ import { r as u } from "../../index-iNhkcAEQ.js";
2
+ import { batchMatMulGeluImpl as c } from "./matMulGelu.js";
3
+ const M = `
4
+ return a * b;
5
+ `;
6
+ function p(r) {
7
+ const { inputs: n, backend: o, attrs: a } = r, { x: t, kernel: e, y: l } = n, { transposeA: i, transposeB: s } = a;
8
+ if (t === void 0 || e === void 0)
9
+ throw new Error("BatchMatMul requires two input tensors.");
10
+ return c({
11
+ a: t,
12
+ b: e,
13
+ transposeA: i,
14
+ transposeB: s,
15
+ backend: o,
16
+ activationSnippet: M,
17
+ multiplier: l
18
+ });
19
+ }
20
+ const m = {
21
+ kernelName: "MatMulMul",
22
+ backendName: "webgl",
23
+ kernelFunc: p
24
+ };
25
+ u(m);
26
+ export {
27
+ p as batchMatMulKernel
28
+ };
@@ -1,4 +1,4 @@
1
- import { r as m } from "../../index--6vO-cOz.js";
1
+ import { r as m } from "../../index-iNhkcAEQ.js";
2
2
  class f {
3
3
  variableNames = ["a", "b"];
4
4
  outputShape;
@@ -0,0 +1 @@
1
+ export {};
@@ -0,0 +1,86 @@
1
+ import { r as p, e as G } from "../../index-iNhkcAEQ.js";
2
+ import { s as x } from "../../sum-B_92TaHD.js";
3
+ class y {
4
+ variableNames = ["x", "meanSquare", "gamma"];
5
+ outputShape;
6
+ userCode;
7
+ constructor(a, e, o) {
8
+ this.outputShape = [a, e, o], this.userCode = `
9
+ void main() {
10
+ ivec3 coords = getOutputCoords();
11
+ float x = getXAtOutCoords();
12
+ float meanSquare = getMeanSquare(coords.x, coords.y, 0);
13
+ float gamma = getGammaAtOutCoords();
14
+ float invRms = inversesqrt(meanSquare + 1e-8);
15
+ float normalized = x * invRms;
16
+ float outVal = normalized * gamma;
17
+ setOutput(outVal);
18
+ }
19
+ `;
20
+ }
21
+ }
22
+ function v(t) {
23
+ const { x: a, gamma: e } = t.inputs, o = t.backend, r = a.shape[0], n = a.shape[1], m = a.shape[2], u = a.square().mean(-1, !0), s = new y(r, n, m);
24
+ return o.runWebGLProgram(s, [a, u, e], "float32");
25
+ }
26
+ const C = {
27
+ kernelName: "RMSNorm",
28
+ backendName: "webgl",
29
+ kernelFunc: v
30
+ };
31
+ p(C);
32
+ class b {
33
+ variableNames = ["x", "meanSquare", "dyGamma", "dyXMean"];
34
+ outputShape;
35
+ userCode;
36
+ constructor(a, e, o) {
37
+ this.outputShape = [a, e, o], this.userCode = `
38
+ void main() {
39
+ ivec3 coords = getOutputCoords();
40
+ float x = getXAtOutCoords();
41
+ float meanSquare = getMeanSquare(coords.x, coords.y, 0) + 1e-8;
42
+ float dyGamma = getDyGammaAtOutCoords();
43
+ float dyXMean = getDyXMean(coords.x, coords.y, 0) / ${o}.0;
44
+ float invRms = inversesqrt(meanSquare);
45
+ float dx = dyGamma * invRms - x * dyXMean * invRms / meanSquare;
46
+ setOutput(dx);
47
+ }
48
+ `;
49
+ }
50
+ }
51
+ class N {
52
+ variableNames = ["x", "meanSquare", "dy"];
53
+ outputShape;
54
+ userCode;
55
+ constructor(a, e, o) {
56
+ this.outputShape = [a, e, o], this.userCode = `
57
+ void main() {
58
+ ivec3 coords = getOutputCoords();
59
+ float x = getXAtOutCoords();
60
+ float meanSquare = getMeanSquare(coords.x, coords.y, 0) + 1e-8;
61
+ float dy = getDyAtOutCoords();
62
+ float invRms = inversesqrt(meanSquare);
63
+ float dGamma = dy * (x * invRms);
64
+ setOutput(dGamma);
65
+ }
66
+ `;
67
+ }
68
+ }
69
+ function M(t) {
70
+ const { dy: a, x: e, gamma: o } = t.inputs, r = t.backend, n = e.shape[0], m = e.shape[1], u = e.shape[2], s = a.mul(o), c = s.mul(e), i = c.sum(-1, !0);
71
+ c.dispose();
72
+ const l = e.square(), d = l.mean(-1, !0);
73
+ l.dispose();
74
+ const f = new b(n, m, u), S = r.runWebGLProgram(f, [e, d, s, i], "float32");
75
+ s.dispose(), i.dispose();
76
+ const h = new N(n, m, u), g = r.runWebGLProgram(h, [e, d, a], "float32");
77
+ d.dispose();
78
+ const q = x(G().makeTensorFromTensorInfo(g), [0, 1]);
79
+ return r.disposeIntermediateTensorInfo(g), [S, q];
80
+ }
81
+ const k = {
82
+ kernelName: "RMSNormGrad",
83
+ backendName: "webgl",
84
+ kernelFunc: M
85
+ };
86
+ p(k);
@@ -1,4 +1,4 @@
1
- import { r as i } from "../../index--6vO-cOz.js";
1
+ import { r as i } from "../../index-iNhkcAEQ.js";
2
2
  class l {
3
3
  variableNames = ["x", "kernel"];
4
4
  outputShape;
@@ -1,4 +1,4 @@
1
- import { r as u } from "../../index--6vO-cOz.js";
1
+ import { r as u } from "../../index-iNhkcAEQ.js";
2
2
  class l {
3
3
  variableNames = ["x", "sin", "cos"];
4
4
  outputShape;
@@ -1,4 +1,4 @@
1
- import { r as i } from "../../index--6vO-cOz.js";
1
+ import { r as i } from "../../index-iNhkcAEQ.js";
2
2
  class u {
3
3
  variableNames = ["labels", "softmaxProbs", "dy"];
4
4
  outputShape;