@genai-fi/nanogpt 0.4.4 → 0.4.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -132,7 +132,7 @@ class wt extends B {
132
132
  }) : (this.ropeCache = new K(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = st({ rate: this.config.gpt.dropout }), this.blocks = [];
133
133
  for (let e = 0; e < this.config.gpt.nLayer; e++)
134
134
  this.blocks.push(new W(e, this.config));
135
- this.lnF = new N(this.config, 1e-8, "final_rms_norm");
135
+ this.lnF = new N(this.config, "final_rms_norm");
136
136
  }
137
137
  get checkpointing() {
138
138
  return this.config.layerConfig.checkpointAttention === !0 || this.config.layerConfig.checkpointMLP === !0;
@@ -3,8 +3,8 @@ import l from "./NanoGPTModel.js";
3
3
  import { saveModel as d } from "./utilities/save.js";
4
4
  import { loadModel as f } from "./utilities/load.js";
5
5
  import u from "./Generator.js";
6
- import _ from "./Trainer.js";
7
- import { E as p } from "./index-Dwqa6Zy2.js";
6
+ import p from "./Trainer.js";
7
+ import { E as _ } from "./index-Dwqa6Zy2.js";
8
8
  import { dummyPassAsync as m } from "./utilities/dummy.js";
9
9
  import c from "./tokeniser/CharTokeniser.js";
10
10
  import g from "./tokeniser/bpe.js";
@@ -37,9 +37,12 @@ import "./ops/grads/matMulGelu.js";
37
37
  import "./ops/cpu/gelu.js";
38
38
  import "./ops/webgl/gelu.js";
39
39
  import "./ops/grads/gelu.js";
40
+ import "./ops/cpu/normRMS.js";
41
+ import "./ops/webgl/normRMS.js";
42
+ import "./ops/grads/normRMS.js";
40
43
  import w from "./utilities/profile.js";
41
44
  class a {
42
- ee = new p();
45
+ ee = new _();
43
46
  _config;
44
47
  _model;
45
48
  _tokeniser;
@@ -126,7 +129,7 @@ class a {
126
129
  trainer() {
127
130
  if (!this._model || !this._tokeniser)
128
131
  throw new Error("Model or tokeniser is not initialized.");
129
- const t = new _(this._model, this._tokeniser);
132
+ const t = new p(this._model, this._tokeniser);
130
133
  return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e) => {
131
134
  const i = this.ee.listeners("trainStep");
132
135
  for (const o of i)
@@ -6,12 +6,12 @@ import { appendCache as E } from "../ops/appendCache.js";
6
6
  import { D as z, F as S, t as $, c as L, e as j, H as O } from "../index--6vO-cOz.js";
7
7
  import { fusedSoftmax as _ } from "../ops/fusedSoftmax.js";
8
8
  import { l as W, w as M, d as x } from "../tfjs_backend-DuKis_xG.js";
9
- import { o as N } from "../ones-D6kB8bdY.js";
10
- import { v as A } from "../variable-BJTZ3jOy.js";
11
- import { z as q } from "../zeros-8xl-W2DC.js";
9
+ import { o as q } from "../ones-D6kB8bdY.js";
10
+ import { v as b } from "../variable-BJTZ3jOy.js";
11
+ import { z as B } from "../zeros-8xl-W2DC.js";
12
12
  import { r as C, d as I } from "../dropout-DFEXTPV0.js";
13
- import { r as B } from "../reshape-z51Eu-re.js";
14
- import { m as F } from "../mat_mul-BEHRPMh0.js";
13
+ import { r as F } from "../reshape-z51Eu-re.js";
14
+ import { m as H } from "../mat_mul-BEHRPMh0.js";
15
15
  class nt extends T {
16
16
  cAttn = null;
17
17
  cProj = null;
@@ -23,16 +23,16 @@ class nt extends T {
23
23
  units;
24
24
  projUnits;
25
25
  constructor(t, s) {
26
- super(s), this.index = t, this.units = s.gpt.nEmbed * 3, this.projUnits = s.gpt.nEmbed, this.bias = W.bandPart(N([s.gpt.blockSize, s.gpt.blockSize]), -1, 0).cast("bool"), this.divisor = 1 / Math.sqrt(s.gpt.nEmbed / s.gpt.nHead);
27
- const o = q([s.gpt.blockSize, s.gpt.blockSize]), e = z([s.gpt.blockSize, s.gpt.blockSize], Number.NEGATIVE_INFINITY);
28
- this.maskInf = M(this.bias, o, e);
26
+ super(s), this.index = t, this.units = s.gpt.nEmbed * 3, this.projUnits = s.gpt.nEmbed, this.bias = W.bandPart(q([s.gpt.blockSize, s.gpt.blockSize]), -1, 0).cast("bool"), this.divisor = 1 / Math.sqrt(s.gpt.nEmbed / s.gpt.nHead);
27
+ const e = B([s.gpt.blockSize, s.gpt.blockSize]), o = z([s.gpt.blockSize, s.gpt.blockSize], Number.NEGATIVE_INFINITY);
28
+ this.maskInf = M(this.bias, e, o);
29
29
  }
30
30
  build() {
31
- this.cAttn === null && (this.cAttn = A(
31
+ this.cAttn === null && (this.cAttn = b(
32
32
  C([this.config.gpt.nEmbed, this.units], 0, 0.02),
33
33
  !0
34
34
  //`block_${this.index}_attn_cAttn_kernel`
35
- )), this.cProj === null && (this.cProj = A(
35
+ )), this.cProj === null && (this.cProj = b(
36
36
  C([this.projUnits, this.config.gpt.nEmbed], 0, 0.02),
37
37
  !0
38
38
  //`block_${this.index}_attn_cProj_kernel`
@@ -53,57 +53,58 @@ class nt extends T {
53
53
  t.set(`block_${this.index}_cAttn`, this.cAttn ? [this.cAttn.clone()] : []), t.set(`block_${this.index}_cProj`, this.cProj ? [this.cProj.clone()] : []);
54
54
  }
55
55
  loadWeights(t) {
56
- const s = t.get(`block_${this.index}_cAttn`)?.[0], o = t.get(`block_${this.index}_cProj`)?.[0];
56
+ const s = t.get(`block_${this.index}_cAttn`)?.[0], e = t.get(`block_${this.index}_cProj`)?.[0];
57
57
  if (!s) throw new Error(`Weights for block_${this.index}_cAttn not found`);
58
- if (!o) throw new Error(`Weights for block_${this.index}_cProj not found`);
59
- this.cAttn ? this.cAttn.assign(s) : this.cAttn = A(s, !0), this.cProj ? this.cProj.assign(o) : this.cProj = A(o, !0);
58
+ if (!e) throw new Error(`Weights for block_${this.index}_cProj not found`);
59
+ this.cAttn ? this.cAttn.assign(s) : this.cAttn = b(s, !0), this.cProj ? this.cProj.assign(e) : this.cProj = b(e, !0);
60
60
  }
61
- getAttentionScores(t, s, o, e) {
61
+ getAttentionScores(t, s, e, o) {
62
62
  const i = P(t, s, this.divisor, this.maskInf);
63
- return _(i, o ? this.config.gpt.dropout : 0, e);
63
+ return _(i, e ? this.config.gpt.dropout : 0, o);
64
64
  }
65
65
  // Attention with optional past. If pastLen > 0 and T_cur == 1, no mask needed.
66
- getAttentionScoresWithPast(t, s, o) {
67
- const e = P(t, s, this.divisor, void 0, o);
68
- return _(e, 0, 0);
66
+ getAttentionScoresWithPast(t, s, e) {
67
+ const o = P(t, s, this.divisor, void 0, e);
68
+ return _(o, 0, 0);
69
69
  }
70
70
  getQKV(t) {
71
71
  return y(t, this.cAttn, this.config.gpt.nHead);
72
72
  }
73
73
  getOutputProjection(t) {
74
- const s = t.shape[0], o = t.shape[2], e = this.config.gpt.nEmbed, i = t.transpose([0, 2, 1, 3]), n = B(i, [s, o, e]);
74
+ const s = t.shape[0], e = t.shape[2], o = this.config.gpt.nEmbed, i = t.transpose([0, 2, 1, 3]), n = F(i, [s, e, o]);
75
75
  return x(n, this.cProj);
76
76
  }
77
- updateCache(t, s, o, e) {
78
- const i = this.config.gpt.blockSize, n = t.shape[2], r = e?.length || 0, a = o ? t : E(t, i, r, e?.k), p = o ? s : E(s, i, r, e?.v);
79
- return {
77
+ updateCache(t, s, e, o) {
78
+ const i = this.config.gpt.blockSize, n = t.shape[2], r = o?.length || 0, a = e ? t : E(t, i, r, o?.k);
79
+ e || (t.dispose(), o?.k.dispose());
80
+ const p = e ? s : E(s, i, r, o?.v);
81
+ return e || (s.dispose(), o?.v.dispose()), {
80
82
  k: S(a),
81
83
  v: S(p),
82
84
  length: Math.min(r + n, i),
83
- cumulativeLength: e ? e.cumulativeLength + n : n
85
+ cumulativeLength: o ? o.cumulativeLength + n : n
84
86
  };
85
87
  }
86
- forward(t, s = !1, o, e = !1, i) {
88
+ forward(t, s = !1, e, o = !1, i) {
87
89
  return $(() => {
88
90
  this.startMemory();
89
- const [n, r, a] = this.getQKV(t), p = i ? i.cumulativeLength : 0, c = this.config.layerConfig.ropeCache, u = c ? w(n, c, p) : n, f = c ? w(r, c, p) : r;
91
+ const [n, r, a] = this.getQKV(t), p = i ? i.cumulativeLength : 0, c = this.config.layerConfig.ropeCache, u = c ? w(n, c, p) : n, A = c ? w(r, c, p) : r;
90
92
  c && (n.dispose(), r.dispose());
91
- const g = i ? i.length : 0, d = this.updateCache(f, a, s, i), l = d.k, m = d.v;
92
- i && (f.dispose(), a.dispose());
93
+ const f = i ? i.length : 0, d = this.updateCache(A, a, s, i), l = d.k, g = d.v;
93
94
  let h;
94
- g > 0 ? h = this.getAttentionScoresWithPast(u, l, g) : h = this.getAttentionScores(u, l, s, o), u.dispose(), s && l.dispose();
95
- const b = F(h, m);
96
- e || h.dispose(), s && m.dispose();
97
- const k = this.getOutputProjection(b);
98
- b.dispose();
99
- const v = e ? h.mean(1) : void 0;
95
+ f > 0 ? h = this.getAttentionScoresWithPast(u, l, f) : h = this.getAttentionScores(u, l, s, e), u.dispose(), s && l.dispose();
96
+ const m = H(h, g);
97
+ o || h.dispose(), s && g.dispose();
98
+ const k = this.getOutputProjection(m);
99
+ m.dispose();
100
+ const v = o ? h.mean(1) : void 0;
100
101
  return this.endMemory("CausalSelfAttention"), { output: k, attention: v, presentKV: s ? void 0 : d };
101
102
  });
102
103
  }
103
- call(t, s = !1, o = !1, e) {
104
- if (e && !this.config.gpt.useRope)
104
+ call(t, s = !1, e = !1, o) {
105
+ if (o && !this.config.gpt.useRope)
105
106
  throw new Error("Cannot use pastKV without RoPE enabled");
106
- if (s && e)
107
+ if (s && o)
107
108
  throw new Error("Cannot use pastKV during training");
108
109
  if (t.shape.length !== 3)
109
110
  throw new Error(`Input tensor must be rank 3 [B, T, C], got shape ${t.shape}`);
@@ -115,15 +116,15 @@ class nt extends T {
115
116
  const r = L(
116
117
  // @ts-expect-error Invalid params
117
118
  (a, p, c, u) => {
118
- const f = this.forward(a, !0, i);
119
+ const A = this.forward(a, !0, i);
119
120
  u([a]);
120
- const g = (d, l) => {
121
- const [m] = l, h = j().state.activeTape;
121
+ const f = (d, l) => {
122
+ const [g] = l, h = j().state.activeTape;
122
123
  j().state.activeTape = [];
123
- const b = O((k, v, R) => this.forward(k, !0, i).output)([m, p, c], d);
124
- return j().state.activeTape = h, b;
124
+ const m = O((k, v, R) => this.forward(k, !0, i).output)([g, p, c], d);
125
+ return j().state.activeTape = h, m;
125
126
  };
126
- return { value: f.output, gradFunc: g };
127
+ return { value: A.output, gradFunc: f };
127
128
  }
128
129
  )(t, this.cAttn, this.cProj);
129
130
  if (this.config.gpt.dropout > 0) {
@@ -132,7 +133,7 @@ class nt extends T {
132
133
  } else
133
134
  return { output: r };
134
135
  } else {
135
- const n = this.forward(t, s, i, o, e);
136
+ const n = this.forward(t, s, i, e, o);
136
137
  if (this.config.gpt.dropout > 0) {
137
138
  const r = I(n.output, this.config.gpt.dropout);
138
139
  return n.output.dispose(), { output: r, attention: n.attention, presentKV: n.presentKV };
@@ -2,8 +2,7 @@ import { Tensor, Variable } from '@tensorflow/tfjs-core';
2
2
  import { default as BaseLayer, GPTLayerConfig } from './BaseLayer';
3
3
  export default class RMSNorm extends BaseLayer {
4
4
  private gamma;
5
- private epsilon;
6
- constructor(config: GPTLayerConfig, epsilon?: number, name?: string);
5
+ constructor(config: GPTLayerConfig, name?: string);
7
6
  get trainableWeights(): Variable[];
8
7
  set trainable(value: boolean);
9
8
  getWeights(): Tensor[];
@@ -1,12 +1,12 @@
1
1
  import { t as r } from "../index--6vO-cOz.js";
2
2
  import m from "./BaseLayer.js";
3
- import { v as i } from "../variable-BJTZ3jOy.js";
4
- import { o } from "../ones-D6kB8bdY.js";
5
- class d extends m {
3
+ import { normRMS as s } from "../ops/normRMS.js";
4
+ import { v as e } from "../variable-BJTZ3jOy.js";
5
+ import { o as i } from "../ones-D6kB8bdY.js";
6
+ class u extends m {
6
7
  gamma;
7
- epsilon;
8
- constructor(t, s = 1e-8, a = "") {
9
- super(t), this.epsilon = s, this.gamma = i(o([t.gpt.nEmbed]), !0, `${a}_gamma`, "float32");
8
+ constructor(t, a = "") {
9
+ super(t), this.gamma = e(i([t.gpt.nEmbed]), !0, `${a}_gamma`, "float32");
10
10
  }
11
11
  get trainableWeights() {
12
12
  return [this.gamma];
@@ -23,8 +23,8 @@ class d extends m {
23
23
  apply(t) {
24
24
  return r(() => {
25
25
  this.startMemory();
26
- const a = t.square().mean(-1, !0).add(this.epsilon).rsqrt(), e = t.mul(a).mul(this.gamma);
27
- return this.endMemory("RMSNorm"), e;
26
+ const a = s(t, this.gamma);
27
+ return this.endMemory("RMSNorm"), a;
28
28
  });
29
29
  }
30
30
  dispose() {
@@ -32,5 +32,5 @@ class d extends m {
32
32
  }
33
33
  }
34
34
  export {
35
- d as default
35
+ u as default
36
36
  };
@@ -12,7 +12,7 @@ class W extends p {
12
12
  _trainable = !0;
13
13
  skipped = !1;
14
14
  constructor(t, s) {
15
- super(s), this.index = t, this.ln1 = new a(s, 1e-8, `block_${this.index}_rms1`), this.attn = new h(this.index, s), this.ln2 = new a(s, 1e-8, `block_${this.index}_rms2`), this.mlp = new o(this.index, s);
15
+ super(s), this.index = t, this.ln1 = new a(s, `block_${this.index}_rms1`), this.attn = new h(this.index, s), this.ln2 = new a(s, `block_${this.index}_rms2`), this.mlp = new o(this.index, s);
16
16
  }
17
17
  get variables() {
18
18
  return [
package/dist/main.js CHANGED
@@ -1,10 +1,10 @@
1
- import { default as w } from "./NanoGPTModel.js";
2
- import { default as D } from "./TeachableLLM.js";
3
- import { default as F } from "./tokeniser/CharTokeniser.js";
4
- import { default as N } from "./tokeniser/bpe.js";
5
- import { default as j } from "./utilities/waitForModel.js";
6
- import { default as z } from "./data/textLoader.js";
7
- import { estimateMemoryUsage as H, estimateParameterCount as I, estimateResources as J, estimateTrainingMemoryUsage as K, validateConfig as O } from "./utilities/parameters.js";
1
+ import { default as E } from "./NanoGPTModel.js";
2
+ import { default as G } from "./TeachableLLM.js";
3
+ import { default as R } from "./tokeniser/CharTokeniser.js";
4
+ import { default as q } from "./tokeniser/bpe.js";
5
+ import { default as A } from "./utilities/waitForModel.js";
6
+ import { default as I } from "./data/textLoader.js";
7
+ import { estimateMemoryUsage as K, estimateParameterCount as O, estimateResources as Q, estimateTrainingMemoryUsage as S, validateConfig as V } from "./utilities/parameters.js";
8
8
  import "./index--6vO-cOz.js";
9
9
  import "./ops/cpu/scatterSub.js";
10
10
  import "./ops/webgl/scatterSub.js";
@@ -31,16 +31,19 @@ import "./ops/grads/matMulGelu.js";
31
31
  import "./ops/cpu/gelu.js";
32
32
  import "./ops/webgl/gelu.js";
33
33
  import "./ops/grads/gelu.js";
34
+ import "./ops/cpu/normRMS.js";
35
+ import "./ops/webgl/normRMS.js";
36
+ import "./ops/grads/normRMS.js";
34
37
  export {
35
- N as BPETokeniser,
36
- F as CharTokeniser,
37
- w as NanoGPT,
38
- D as TeachableLLM,
39
- H as estimateMemoryUsage,
40
- I as estimateParameterCount,
41
- J as estimateResources,
42
- K as estimateTrainingMemoryUsage,
43
- z as loadTextData,
44
- O as validateConfig,
45
- j as waitForModel
38
+ q as BPETokeniser,
39
+ R as CharTokeniser,
40
+ E as NanoGPT,
41
+ G as TeachableLLM,
42
+ K as estimateMemoryUsage,
43
+ O as estimateParameterCount,
44
+ Q as estimateResources,
45
+ S as estimateTrainingMemoryUsage,
46
+ I as loadTextData,
47
+ V as validateConfig,
48
+ A as waitForModel
46
49
  };
@@ -0,0 +1 @@
1
+ export {};
@@ -0,0 +1,39 @@
1
+ import { r as o, t as d } from "../../index--6vO-cOz.js";
2
+ function i(t) {
3
+ const { inputs: e } = t, { x: n, gamma: s } = e, r = n, a = s;
4
+ return d(() => {
5
+ const u = r.square().mean(-1, !0).add(1e-8).rsqrt();
6
+ return r.mul(u).mul(a);
7
+ });
8
+ }
9
+ const f = {
10
+ kernelName: "RMSNorm",
11
+ backendName: "cpu",
12
+ kernelFunc: i
13
+ };
14
+ o(f);
15
+ const g = {
16
+ kernelName: "RMSNorm",
17
+ backendName: "tensorflow",
18
+ kernelFunc: i
19
+ };
20
+ o(g);
21
+ function N(t) {
22
+ const { dy: e, x: n, gamma: s } = t.inputs;
23
+ return d(() => {
24
+ const r = n.shape[n.shape.length - 1], a = n.square().mean(-1, !0), m = a.add(1e-8).rsqrt(), u = n.mul(m), l = e.mul(u).sum([0, 1]), c = e.mul(s), k = c.mul(n).sum(-1, !0).div(r);
25
+ return [c.mul(m).sub(n.mul(k).mul(m).div(a.add(1e-8))), l];
26
+ });
27
+ }
28
+ const S = {
29
+ kernelName: "RMSNormGrad",
30
+ backendName: "cpu",
31
+ kernelFunc: N
32
+ };
33
+ o(S);
34
+ const R = {
35
+ kernelName: "RMSNormGrad",
36
+ backendName: "tensorflow",
37
+ kernelFunc: N
38
+ };
39
+ o(R);
@@ -0,0 +1,2 @@
1
+ import { GradConfig } from '@tensorflow/tfjs-core';
2
+ export declare const normRMSGradConfig: GradConfig;
@@ -0,0 +1,20 @@
1
+ import { g as t, e as g } from "../../index--6vO-cOz.js";
2
+ function s(r, a, n) {
3
+ return g().runKernel("RMSNormGrad", { dy: r, x: a, gamma: n });
4
+ }
5
+ const u = {
6
+ kernelName: "RMSNorm",
7
+ inputsToSave: ["x", "gamma"],
8
+ outputsToSave: [],
9
+ gradFunc: (r, a) => {
10
+ const [n, e] = a, [m, o] = s(r, n, e);
11
+ return {
12
+ x: () => m,
13
+ gamma: () => o
14
+ };
15
+ }
16
+ };
17
+ t(u);
18
+ export {
19
+ u as normRMSGradConfig
20
+ };
@@ -0,0 +1,2 @@
1
+ import { Tensor } from '@tensorflow/tfjs-core';
2
+ export declare function normRMS(x: Tensor, gamma: Tensor): Tensor;
@@ -0,0 +1,10 @@
1
+ import { e as n } from "../index--6vO-cOz.js";
2
+ import "./cpu/normRMS.js";
3
+ import "./webgl/normRMS.js";
4
+ import "./grads/normRMS.js";
5
+ function p(r, o) {
6
+ return n().runKernel("RMSNorm", { x: r, gamma: o });
7
+ }
8
+ export {
9
+ p as normRMS
10
+ };
@@ -7,9 +7,10 @@ type BatchMatMulConfig = {
7
7
  transposeA: boolean;
8
8
  transposeB: boolean;
9
9
  backend: MathBackendWebGL;
10
- activationSnippet: string;
10
+ activationSnippet?: string;
11
+ multiplier?: TensorInfo;
11
12
  };
12
- export declare function batchMatMulGeluImpl({ a, b, transposeA, transposeB, backend, activationSnippet, }: BatchMatMulConfig): TensorInfo;
13
+ export declare function batchMatMulGeluImpl({ a, b, transposeA, transposeB, backend, activationSnippet, multiplier, }: BatchMatMulConfig): TensorInfo;
13
14
  export declare function batchMatMulKernel(args: {
14
15
  inputs: {
15
16
  x: TensorInfo;
@@ -1,7 +1,7 @@
1
- import { r as G, t as P, e as R, b as I, n as k, O as L, j as F, Q as U } from "../../index--6vO-cOz.js";
2
- import { r as g } from "../../Reshape-CiAY8ltP.js";
1
+ import { r as C, t as R, e as I, n as G, O as L, j as F, Q as U } from "../../index--6vO-cOz.js";
2
+ import { r as S } from "../../Reshape-CiAY8ltP.js";
3
3
  import { u as H } from "../../gpgpu_math-CUzjlO9A.js";
4
- import { m as z } from "../../mat_mul-BEHRPMh0.js";
4
+ import { m as B } from "../../mat_mul-BEHRPMh0.js";
5
5
  /**
6
6
  * @license
7
7
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -19,39 +19,39 @@ import { m as z } from "../../mat_mul-BEHRPMh0.js";
19
19
  * =============================================================================
20
20
  */
21
21
  class W {
22
- constructor(e, s, a, n = !1, c = !1, o = !1, r = null, i = !1, u = !1) {
23
- this.variableNames = ["matrixA", "matrixB"], this.packedInputs = !0, this.packedOutput = !0, this.outputShape = a, this.enableShapeUniforms = H(this.outputShape.length);
24
- const p = n ? e[1] : e[2], l = Math.ceil(p / 2), b = n ? "i * 2, rc.y" : "rc.y, i * 2", M = c ? "rc.z, i * 2" : "i * 2, rc.z", h = n ? ["a.xxyy", "a.zzww"] : ["a.xxzz", "a.yyww"], d = c ? ["b.xzxz", "b.ywyw"] : ["b.xyxy", "b.zwzw"];
25
- let m = "", v = "";
26
- r && (i ? m = `vec4 activation(vec4 a) {
22
+ constructor(e, s, n, a = !1, c = !1, o = !1, r = null, u = !1, l = !1) {
23
+ this.variableNames = ["matrixA", "matrixB"], this.packedInputs = !0, this.packedOutput = !0, this.outputShape = n, this.enableShapeUniforms = H(this.outputShape.length);
24
+ const h = a ? e[1] : e[2], p = Math.ceil(h / 2), d = a ? "i * 2, rc.y" : "rc.y, i * 2", $ = c ? "rc.z, i * 2" : "i * 2, rc.z", x = a ? ["a.xxyy", "a.zzww"] : ["a.xxzz", "a.yyww"], m = c ? ["b.xzxz", "b.ywyw"] : ["b.xyxy", "b.zwzw"];
25
+ let i = "", b = "";
26
+ r && (u ? i = `vec4 activation(vec4 a) {
27
27
  vec4 b = getPreluActivationWeightsAtOutCoords();
28
28
  ${r}
29
- }` : u ? m = `vec4 activation(vec4 a) {
29
+ }` : l ? i = `vec4 activation(vec4 a) {
30
30
  vec4 b = getLeakyreluAlphaAtOutCoords();
31
31
  ${r}
32
- }` : m = `vec4 activation(vec4 x) {
32
+ }` : i = `vec4 activation(vec4 x) {
33
33
  ${r}
34
- }`, v = "result = activation(result);");
35
- const $ = o ? "result += getBiasAtOutCoords();" : "";
36
- o && this.variableNames.push("bias"), i && this.variableNames.push("preluActivationWeights"), u && this.variableNames.push("leakyreluAlpha");
37
- let f = "rc.x", x = "rc.x";
38
- e[0] < s[0] ? f = `imod(rc.x, ${e[0]})` : s[0] < e[0] && (x = `imod(rc.x, ${s[0]})`), this.userCode = `
39
- ${m}
34
+ }`, b = "result = activation(result);");
35
+ const M = o ? "result += getBiasAtOutCoords();" : "";
36
+ o && this.variableNames.push("bias"), u && this.variableNames.push("preluActivationWeights"), l && this.variableNames.push("leakyreluAlpha");
37
+ let f = "rc.x", v = "rc.x";
38
+ e[0] < s[0] ? f = `imod(rc.x, ${e[0]})` : s[0] < e[0] && (v = `imod(rc.x, ${s[0]})`), this.userCode = `
39
+ ${i}
40
40
  // Don't use uniform for sharedDimensionPacked for performance.
41
- const float sharedDimension = ${l}.0;
41
+ const float sharedDimension = ${p}.0;
42
42
 
43
43
  vec4 dot2x2ARowBCol(ivec3 rc) {
44
44
  vec4 result = vec4(0);
45
45
  int batchA = ${f};
46
- int batchB = ${x};
47
- for (int i = 0; i < ${l}; i++) {
48
- vec4 a = getMatrixA(batchA, ${b});
49
- vec4 b = getMatrixB(batchB, ${M});
46
+ int batchB = ${v};
47
+ for (int i = 0; i < ${p}; i++) {
48
+ vec4 a = getMatrixA(batchA, ${d});
49
+ vec4 b = getMatrixB(batchB, ${$});
50
50
 
51
51
  // These swizzled products need to be separately added.
52
52
  // See: https://github.com/tensorflow/tfjs/issues/1735
53
- result += (${h[0]} * ${d[0]});
54
- result += (${h[1]} * ${d[1]});
53
+ result += (${x[0]} * ${m[0]});
54
+ result += (${x[1]} * ${m[1]});
55
55
  }
56
56
  return result;
57
57
  }
@@ -60,69 +60,72 @@ class W {
60
60
  ivec3 rc = getOutputCoords();
61
61
  vec4 result = dot2x2ARowBCol(rc);
62
62
 
63
- ${$}
63
+ ${M}
64
64
 
65
- ${v}
65
+ ${b}
66
66
 
67
67
  setOutput(result);
68
68
  }
69
69
  `;
70
70
  }
71
71
  }
72
- const S = 0.7978845608028654, w = 0.044715, j = `
72
+ const g = 0.7978845608028654, w = 0.044715, j = `
73
73
  vec4 x3 = x * x * x;
74
74
  vec4 inner = x + ${w} * x3;
75
- inner = ${S} * inner;
75
+ inner = ${g} * inner;
76
76
  inner = tanh(inner);
77
77
  inner = 0.5 * (1.0 + inner);
78
78
  vec4 result = x * inner;
79
79
  return result;
80
80
  `, q = `
81
- vec4 x2 = x * x;
82
- vec4 x3 = x2 * x;
83
- vec4 u = ${S} * (x + ${w} * x3);
81
+ vec4 a2 = a * a;
82
+ vec4 a3 = a2 * a;
83
+ vec4 u = ${g} * (a + ${w} * a3);
84
84
  vec4 t = tanh(u);
85
85
  vec4 sech2 = 1.0 - t * t;
86
- vec4 du_dx = ${S} * (1.0 + 3.0 * ${w} * x2);
87
- vec4 dgelu = 0.5 * (1.0 + t) + 0.5 * x * sech2 * du_dx;
88
- return dgelu;
86
+ vec4 du_dx = ${g} * (1.0 + 3.0 * ${w} * a2);
87
+ vec4 dgelu = 0.5 * (1.0 + t) + 0.5 * a * sech2 * du_dx;
88
+ return dgelu * b;
89
89
  `, se = 1e3;
90
- function B({
90
+ function O({
91
91
  a: t,
92
92
  b: e,
93
93
  transposeA: s,
94
- transposeB: a,
95
- backend: n,
96
- activationSnippet: c
94
+ transposeB: n,
95
+ backend: a,
96
+ activationSnippet: c,
97
+ multiplier: o
97
98
  }) {
98
- const o = t.shape.length, r = e.shape.length, i = s ? t.shape[o - 2] : t.shape[o - 1], u = a ? e.shape[r - 1] : e.shape[r - 2], p = s ? t.shape[o - 1] : t.shape[o - 2], l = a ? e.shape[r - 2] : e.shape[r - 1], b = t.shape.slice(0, -2), M = e.shape.slice(0, -2), h = k(b), d = k(M), v = L(t.shape.slice(0, -2), e.shape.slice(0, -2)).concat([p, l]);
99
+ const r = t.shape.length, u = e.shape.length, l = s ? t.shape[r - 2] : t.shape[r - 1], h = n ? e.shape[u - 1] : e.shape[u - 2], p = s ? t.shape[r - 1] : t.shape[r - 2], d = n ? e.shape[u - 2] : e.shape[u - 1], $ = t.shape.slice(0, -2), x = e.shape.slice(0, -2), m = G($), i = G(x), M = L(t.shape.slice(0, -2), e.shape.slice(0, -2)).concat([p, d]);
99
100
  F(
100
- i === u,
101
- () => `Error in matMul: inner shapes (${i}) and (${u}) of Tensors with shapes ${t.shape} and ${e.shape} and transposeA=${s} and transposeB=${a} must match.`
101
+ l === h,
102
+ () => `Error in matMul: inner shapes (${l}) and (${h}) of Tensors with shapes ${t.shape} and ${e.shape} and transposeA=${s} and transposeB=${n} must match.`
102
103
  );
103
- const $ = s ? [h, i, p] : [h, p, i], f = a ? [d, l, u] : [d, u, l], x = g({ inputs: { x: t }, backend: n, attrs: { shape: $ } }), A = g({ inputs: { x: e }, backend: n, attrs: { shape: f } }), y = [x, A], C = Math.max(h, d), O = c, E = U(t.dtype, e.dtype), N = new W(
104
- $,
104
+ const f = s ? [m, l, p] : [m, p, l], v = n ? [i, d, h] : [i, h, d], A = S({ inputs: { x: t }, backend: a, attrs: { shape: f } }), y = S({ inputs: { x: e }, backend: a, attrs: { shape: v } }), D = [A, y], E = Math.max(m, i), N = c, T = U(t.dtype, e.dtype), _ = new W(
105
105
  f,
106
- [C, p, l],
106
+ v,
107
+ [E, p, d],
107
108
  s,
108
- a,
109
- !1,
110
- O,
109
+ n,
111
110
  !1,
111
+ N,
112
+ !!o,
112
113
  !1
113
- ), T = [x, A], D = n.runWebGLProgram(N, T, E), _ = g({ inputs: { x: D }, backend: n, attrs: { shape: v } });
114
- y.push(D);
115
- for (const K of y)
116
- n.disposeIntermediateTensorInfo(K);
117
- return _;
114
+ ), k = [A, y];
115
+ o && k.push(o);
116
+ const z = a.runWebGLProgram(_, k, T), K = S({ inputs: { x: z }, backend: a, attrs: { shape: M } });
117
+ D.push(z);
118
+ for (const P of D)
119
+ a.disposeIntermediateTensorInfo(P);
120
+ return K;
118
121
  }
119
122
  function Q(t) {
120
- const { inputs: e, backend: s } = t, { x: a, kernel: n } = e;
121
- if (a === void 0 || n === void 0)
123
+ const { inputs: e, backend: s } = t, { x: n, kernel: a } = e;
124
+ if (n === void 0 || a === void 0)
122
125
  throw new Error("BatchMatMul requires two input tensors.");
123
- return B({
124
- a,
125
- b: n,
126
+ return O({
127
+ a: n,
128
+ b: a,
126
129
  transposeA: !1,
127
130
  transposeB: !1,
128
131
  backend: s,
@@ -134,23 +137,22 @@ const J = {
134
137
  backendName: "webgl",
135
138
  kernelFunc: Q
136
139
  };
137
- G(J);
140
+ C(J);
138
141
  function V(t) {
139
- const { dy: e, x: s, kernel: a } = t.inputs, n = t.backend;
140
- return P(() => {
141
- const c = R().makeTensorFromTensorInfo(
142
- B({
142
+ const { dy: e, x: s, kernel: n } = t.inputs, a = t.backend;
143
+ return R(() => {
144
+ const c = I().makeTensorFromTensorInfo(
145
+ O({
143
146
  a: s,
144
- b: a,
147
+ b: n,
145
148
  transposeA: !1,
146
149
  transposeB: !1,
147
- backend: n,
148
- activationSnippet: q
150
+ backend: a,
151
+ activationSnippet: q,
152
+ multiplier: e
149
153
  })
150
- ), o = I(e, c);
151
- c.dispose();
152
- const r = z(o, a, !1, !0), i = z(s, o, !0, !1);
153
- return [r, i];
154
+ ), o = B(c, n, !1, !0), r = B(s, c, !0, !1);
155
+ return [o, r];
154
156
  });
155
157
  }
156
158
  const X = {
@@ -158,9 +160,9 @@ const X = {
158
160
  backendName: "webgl",
159
161
  kernelFunc: V
160
162
  };
161
- G(X);
163
+ C(X);
162
164
  export {
163
165
  se as MATMUL_SHARED_DIM_THRESHOLD,
164
- B as batchMatMulGeluImpl,
166
+ O as batchMatMulGeluImpl,
165
167
  Q as batchMatMulKernel
166
168
  };
@@ -0,0 +1 @@
1
+ export {};
@@ -0,0 +1,78 @@
1
+ import { r as c, e as h } from "../../index--6vO-cOz.js";
2
+ import { s as q } from "../../sum-DdkDf2MG.js";
3
+ class G {
4
+ variableNames = ["x", "meanSquare", "gamma"];
5
+ outputShape;
6
+ userCode;
7
+ constructor(e, a, o) {
8
+ this.outputShape = [e, a, o], this.userCode = `
9
+ void main() {
10
+ ivec3 coords = getOutputCoords();
11
+ float x = getXAtOutCoords();
12
+ float meanSquare = getMeanSquare(coords.x, coords.y, 0);
13
+ float gamma = getGammaAtOutCoords();
14
+ float invRms = inversesqrt(meanSquare + 1e-8);
15
+ float normalized = x * invRms;
16
+ float outVal = normalized * gamma;
17
+ setOutput(outVal);
18
+ }
19
+ `;
20
+ }
21
+ }
22
+ function v(t) {
23
+ const { x: e, gamma: a } = t.inputs, o = t.backend, r = e.shape[0], s = e.shape[1], n = e.shape[2], m = e.square().mean(-1, !0), u = new G(r, s, n);
24
+ return o.runWebGLProgram(u, [e, m, a], "float32");
25
+ }
26
+ const x = {
27
+ kernelName: "RMSNorm",
28
+ backendName: "webgl",
29
+ kernelFunc: v
30
+ };
31
+ c(x);
32
+ class y {
33
+ variableNames = ["x", "meanSquare", "dyGamma", "dyXMean"];
34
+ outputShape;
35
+ userCode;
36
+ constructor(e, a, o) {
37
+ this.outputShape = [e, a, o], this.userCode = `
38
+ void main() {
39
+ ivec3 coords = getOutputCoords();
40
+ float x = getXAtOutCoords();
41
+ float meanSquare = getMeanSquare(coords.x, coords.y, 0) + 1e-8;
42
+ float dyGamma = getDyGammaAtOutCoords();
43
+ float dyXMean = getDyXMean(coords.x, coords.y, 0) / ${o}.0;
44
+ float invRms = inversesqrt(meanSquare);
45
+ float dx = dyGamma * invRms - x * dyXMean * invRms / meanSquare;
46
+ setOutput(dx);
47
+ }
48
+ `;
49
+ }
50
+ }
51
+ class C {
52
+ variableNames = ["x", "meanSquare", "dy"];
53
+ outputShape;
54
+ userCode;
55
+ constructor(e, a, o) {
56
+ this.outputShape = [e, a, o], this.userCode = `
57
+ void main() {
58
+ ivec3 coords = getOutputCoords();
59
+ float x = getXAtOutCoords();
60
+ float meanSquare = getMeanSquare(coords.x, coords.y, 0) + 1e-8;
61
+ float dy = getDyAtOutCoords();
62
+ float invRms = inversesqrt(meanSquare);
63
+ float dGamma = dy * (x * invRms);
64
+ setOutput(dGamma);
65
+ }
66
+ `;
67
+ }
68
+ }
69
+ function b(t) {
70
+ const { dy: e, x: a, gamma: o } = t.inputs, r = t.backend, s = a.shape[0], n = a.shape[1], m = a.shape[2], u = a.square().mean(-1, !0), d = e.mul(o), l = d.mul(a).sum(-1, !0), i = new y(s, n, m), g = r.runWebGLProgram(i, [a, u, d, l], "float32"), p = new C(s, n, m), S = r.runWebGLProgram(p, [a, u, e], "float32"), f = q(h().makeTensorFromTensorInfo(S), [0, 1]);
71
+ return [g, f];
72
+ }
73
+ const N = {
74
+ kernelName: "RMSNormGrad",
75
+ backendName: "webgl",
76
+ kernelFunc: b
77
+ };
78
+ c(N);
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@genai-fi/nanogpt",
3
- "version": "0.4.4",
3
+ "version": "0.4.5",
4
4
  "type": "module",
5
5
  "main": "dist/main.js",
6
6
  "types": "dist/main.d.ts",