@genai-fi/nanogpt 0.2.11 → 0.2.12

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,6 +13,5 @@ export default class RoPECache {
13
13
  ensureRopeCache(needed: number): void;
14
14
  getCos(): TF.Tensor | null;
15
15
  getSin(): TF.Tensor | null;
16
- applyRoPE(q: TF.Tensor, k: TF.Tensor, pastLen: number): [TF.Tensor, TF.Tensor];
17
16
  dispose(): void;
18
17
  }
@@ -1,12 +1,12 @@
1
- class b {
2
- constructor(s, r) {
3
- this.tf = s, this.config = r;
4
- const o = this.config.nEmbed / this.config.nHead;
5
- if (this.rotaryDim = o, this.rotaryDim % 2 !== 0)
1
+ class n {
2
+ constructor(i, e) {
3
+ this.tf = i, this.config = e;
4
+ const t = this.config.nEmbed / this.config.nHead;
5
+ if (this.rotaryDim = t, this.rotaryDim % 2 !== 0)
6
6
  throw new Error("rotaryDim must be even");
7
7
  this.ropeBase = 1e4;
8
- const i = this.tf.range(0, this.rotaryDim, 2, "float32"), t = i.div(this.tf.scalar(this.rotaryDim, "float32")), e = this.tf.pow(this.tf.scalar(this.ropeBase, "float32"), t);
9
- this.ropeInvFreq = this.tf.reciprocal(e), t.dispose(), e.dispose(), i.dispose(), this.config.useRope === !1 ? (this.ropeCos = null, this.ropeSin = null, this.ropeCacheLen = 0) : this.tf.tidy(() => {
8
+ const s = this.tf.range(0, this.rotaryDim, 2, "float32"), o = s.div(this.tf.scalar(this.rotaryDim, "float32")), r = this.tf.pow(this.tf.scalar(this.ropeBase, "float32"), o);
9
+ this.ropeInvFreq = this.tf.reciprocal(r), o.dispose(), r.dispose(), s.dispose(), this.config.useRope === !1 ? (this.ropeCos = null, this.ropeSin = null, this.ropeCacheLen = 0) : this.tf.tidy(() => {
10
10
  this.ensureRopeCache(this.config.blockSize * 4);
11
11
  });
12
12
  }
@@ -18,11 +18,11 @@ class b {
18
18
  ropeSin = null;
19
19
  // [cacheLen, rotaryDim/2]
20
20
  ropeCacheLen = 0;
21
- ensureRopeCache(s) {
22
- if (s <= this.ropeCacheLen) return;
21
+ ensureRopeCache(i) {
22
+ if (i <= this.ropeCacheLen) return;
23
23
  this.ropeCos && this.ropeCos.dispose(), this.ropeSin && this.ropeSin.dispose();
24
- const o = this.tf.range(0, s, 1, "float32").expandDims(1).mul(this.ropeInvFreq.expandDims(0));
25
- this.ropeCos = this.tf.keep(this.tf.cos(o).expandDims(-1)), this.ropeSin = this.tf.keep(this.tf.sin(o).expandDims(-1)), this.ropeCacheLen = s;
24
+ const e = Math.max(i, this.ropeCacheLen + this.config.blockSize * 4), s = this.tf.range(0, e, 1, "float32").expandDims(1).mul(this.ropeInvFreq.expandDims(0));
25
+ this.ropeCos = this.tf.keep(this.tf.cos(s).expandDims(-1)), this.ropeSin = this.tf.keep(this.tf.sin(s).expandDims(-1)), this.ropeCacheLen = e;
26
26
  }
27
27
  getCos() {
28
28
  return this.ropeCos;
@@ -30,21 +30,10 @@ class b {
30
30
  getSin() {
31
31
  return this.ropeSin;
32
32
  }
33
- applyRoPE(s, r, o) {
34
- const i = s.shape[3], t = this.rotaryDim;
35
- if (t > i) return [s, r];
36
- const e = s.shape[2], R = o + e;
37
- this.ensureRopeCache(R);
38
- const n = t / 2, c = this.ropeCos.slice([o, 0, 0], [e, n, 1]).reshape([1, 1, e, n]), a = this.ropeSin.slice([o, 0, 0], [e, n, 1]).reshape([1, 1, e, n]), h = s.shape[0], p = s.shape[1], f = this.tf.range(0, t, 2, "int32"), l = this.tf.range(1, t, 2, "int32"), d = (u) => {
39
- const m = u.slice([0, 0, 0, 0], [h, p, e, t]), C = t < i ? u.slice([0, 0, 0, t], [h, p, e, i - t]) : null, g = this.tf.gather(m, f, 3), D = this.tf.gather(m, l, 3), x = g.mul(c).sub(D.mul(a)), k = D.mul(c).add(g.mul(a)), S = this.tf.stack([x, k], -1).reshape([h, p, e, t]);
40
- return C ? this.tf.concat([S, C], 3) : S;
41
- }, v = d(s), y = d(r);
42
- return f.dispose(), l.dispose(), [v, y];
43
- }
44
33
  dispose() {
45
34
  this.ropeCos && this.ropeCos.dispose(), this.ropeSin && this.ropeSin.dispose(), this.ropeInvFreq.dispose();
46
35
  }
47
36
  }
48
37
  export {
49
- b as default
38
+ n as default
50
39
  };
package/dist/ops/rope.js CHANGED
@@ -1,6 +1,6 @@
1
1
  import { engine as D } from "@tensorflow/tfjs";
2
- import { o as G, l as F, k as _, n as z, E as K, p as O, d as T, q as U, r as g, c as A } from "../index-YPKosni4.js";
3
- import { r as $, s as B } from "../stack-BtKpB0Ry.js";
2
+ import { o as G, l as F, k as _, n as U, E as K, p as z, d as I, q as O, r as f, c as A } from "../index-YPKosni4.js";
3
+ import { r as T, s as B } from "../stack-BtKpB0Ry.js";
4
4
  /**
5
5
  * @license
6
6
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -17,17 +17,17 @@ import { r as $, s as B } from "../stack-BtKpB0Ry.js";
17
17
  * limitations under the License.
18
18
  * =============================================================================
19
19
  */
20
- function W(e, t = 0) {
21
- F(e.length >= 1, () => "Pass at least one tensor to concat");
22
- const o = _(e, "tensors", "concat", "string_or_numeric");
23
- if (o[0].dtype === "complex64" && o.forEach((s) => {
24
- if (s.dtype !== "complex64")
20
+ function W(o, t = 0) {
21
+ F(o.length >= 1, () => "Pass at least one tensor to concat");
22
+ const e = _(o, "tensors", "concat", "string_or_numeric");
23
+ if (e[0].dtype === "complex64" && e.forEach((r) => {
24
+ if (r.dtype !== "complex64")
25
25
  throw new Error(`Cannot concatenate complex64 tensors with a tensor
26
- with dtype ${s.dtype}. `);
27
- }), o.length === 1)
28
- return z(o[0]);
29
- const n = o, r = { axis: t };
30
- return K.runKernel(O, n, r);
26
+ with dtype ${r.dtype}. `);
27
+ }), e.length === 1)
28
+ return U(e[0]);
29
+ const n = e, s = { axis: t };
30
+ return K.runKernel(z, n, s);
31
31
  }
32
32
  const j = /* @__PURE__ */ G({ concat_: W });
33
33
  /**
@@ -46,9 +46,9 @@ const j = /* @__PURE__ */ G({ concat_: W });
46
46
  * limitations under the License.
47
47
  * =============================================================================
48
48
  */
49
- function H(e, t, o = 0, n = 0) {
50
- const r = T(e, "x", "gather"), s = T(t, "indices", "gather", "int32"), c = { x: r, indices: s }, a = { axis: o, batchDims: n };
51
- return K.runKernel(U, c, a);
49
+ function H(o, t, e = 0, n = 0) {
50
+ const s = I(o, "x", "gather"), r = I(t, "indices", "gather", "int32"), c = { x: s, indices: r }, a = { axis: e, batchDims: n };
51
+ return K.runKernel(O, c, a);
52
52
  }
53
53
  const E = /* @__PURE__ */ G({ gather_: H });
54
54
  class J {
@@ -56,8 +56,9 @@ class J {
56
56
  outputShape;
57
57
  userCode;
58
58
  // enableShapeUniforms = true;
59
- constructor(t, o, n, r, s) {
60
- this.outputShape = [t, o, n, r], this.userCode = `
59
+ customUniforms = [{ name: "pastLen", type: "int" }];
60
+ constructor(t, e, n, s) {
61
+ this.outputShape = [t, e, n, s], this.userCode = `
61
62
  void main() {
62
63
  ivec4 coords = getOutputCoords(); // [b, h, t, d]
63
64
  int b = coords.x;
@@ -65,14 +66,14 @@ class J {
65
66
  int t = coords.z;
66
67
  int d = coords.w;
67
68
 
68
- int rotaryDim = ${r};
69
+ int rotaryDim = ${s};
69
70
 
70
71
  float outVal = 0.0;
71
72
 
72
73
  if (d < rotaryDim) {
73
74
  int pairIdx = d / 2;
74
- float cos = getCos(t + ${s}, pairIdx, 0);
75
- float sin = getSin(t + ${s}, pairIdx, 0);
75
+ float cos = getCos(t + pastLen, pairIdx, 0);
76
+ float sin = getSin(t + pastLen, pairIdx, 0);
76
77
 
77
78
  if (d % 2 == 0) {
78
79
  // even index
@@ -95,59 +96,59 @@ class J {
95
96
  `;
96
97
  }
97
98
  }
98
- function M(e) {
99
- const { x: t, sin: o, cos: n } = e.inputs, { pastLen: r } = e.attrs, s = e.backend, c = t.shape[0], a = t.shape[1], i = t.shape[2], d = t.shape[3], p = new J(c, a, i, d, r);
100
- return s.runWebGLProgram(p, [t, o, n], "float32");
99
+ function M(o) {
100
+ const { x: t, sin: e, cos: n } = o.inputs, { pastLen: s } = o.attrs, r = o.backend, c = t.shape[0], a = t.shape[1], i = t.shape[2], d = t.shape[3], p = new J(c, a, i, d);
101
+ return r.runWebGLProgram(p, [t, e, n], "float32", [[s]]);
101
102
  }
102
103
  const Q = {
103
104
  kernelName: "Rope",
104
105
  backendName: "webgl",
105
106
  kernelFunc: M
106
107
  };
107
- g(Q);
108
- function V(e, t, o, n, r) {
109
- const s = n.shape[3], c = o;
110
- if (c > s) return n;
111
- const a = n.shape[2], i = c / 2, d = t.slice([r, 0, 0], [a, i, 1]).reshape([1, 1, a, i]), p = e.slice([r, 0, 0], [a, i, 1]).reshape([1, 1, a, i]), u = n.shape[0], l = n.shape[1], m = $(0, c, 2, "int32"), x = $(1, c, 2, "int32"), X = ((b) => {
112
- const v = b.slice([0, 0, 0, 0], [u, l, a, c]), k = c < s ? b.slice([0, 0, 0, c], [u, l, a, s - c]) : null, h = E(v, m, 3), f = E(v, x, 3), C = h.mul(d), y = f.mul(p), R = C.sub(y), N = f.mul(d), S = h.mul(p), w = N.add(S);
113
- h.dispose(), f.dispose(), d.dispose(), p.dispose(), C.dispose(), y.dispose(), N.dispose(), S.dispose();
114
- const P = B([R, w], -1);
108
+ f(Q);
109
+ function V(o, t, e, n, s) {
110
+ const r = n.shape[3], c = e;
111
+ if (c > r) return n;
112
+ const a = n.shape[2], i = c / 2, d = t.slice([s, 0, 0], [a, i, 1]).reshape([1, 1, a, i]), p = o.slice([s, 0, 0], [a, i, 1]).reshape([1, 1, a, i]), u = n.shape[0], l = n.shape[1], g = T(0, c, 2, "int32"), x = T(1, c, 2, "int32"), $ = ((b) => {
113
+ const v = b.slice([0, 0, 0, 0], [u, l, a, c]), k = c < r ? b.slice([0, 0, 0, c], [u, l, a, r - c]) : null, h = E(v, g, 3), m = E(v, x, 3), C = h.mul(d), y = m.mul(p), R = C.sub(y), N = m.mul(d), S = h.mul(p), w = N.add(S);
114
+ h.dispose(), m.dispose(), d.dispose(), p.dispose(), C.dispose(), y.dispose(), N.dispose(), S.dispose();
115
+ const L = B([R, w], -1);
115
116
  R.dispose(), w.dispose();
116
- const I = P.reshape([u, l, a, c]);
117
- return P.dispose(), k ? j([I, k], 3) : I;
117
+ const P = L.reshape([u, l, a, c]);
118
+ return L.dispose(), k ? j([P, k], 3) : P;
118
119
  })(n);
119
- return m.dispose(), x.dispose(), X;
120
+ return g.dispose(), x.dispose(), $;
120
121
  }
121
- function L(e) {
122
- const { x: t, sin: o, cos: n } = e.inputs, { pastLen: r } = e.attrs, s = t.shape[3];
123
- return V(o, n, s, t, r);
122
+ function X(o) {
123
+ const { x: t, sin: e, cos: n } = o.inputs, { pastLen: s } = o.attrs, r = t.shape[3];
124
+ return V(e, n, r, t, s);
124
125
  }
125
126
  const Y = {
126
127
  kernelName: "Rope",
127
128
  backendName: "cpu",
128
- kernelFunc: L
129
+ kernelFunc: X
129
130
  };
130
- g(Y);
131
+ f(Y);
131
132
  const Z = {
132
133
  kernelName: "Rope",
133
134
  backendName: "tensorflow",
134
- kernelFunc: L
135
+ kernelFunc: X
135
136
  };
136
- g(Z);
137
- function st(e, t, o) {
138
- return t.ensureRopeCache(e.shape[1]), D().runKernel("Rope", { x: e, sin: t.getSin(), cos: t.getCos() }, { pastLen: o });
137
+ f(Z);
138
+ function st(o, t, e) {
139
+ return t.ensureRopeCache(o.shape[1] + e), D().runKernel("Rope", { x: o, sin: t.getSin(), cos: t.getCos() }, { pastLen: e });
139
140
  }
140
141
  const q = {
141
142
  kernelName: "Rope",
142
143
  inputsToSave: ["x", "sin", "cos"],
143
144
  outputsToSave: [],
144
- gradFunc: (e, t) => {
145
- const [o, n, r] = t, s = n.neg(), c = o.shape[3], i = V(s, r, c, e, 0);
146
- return s.dispose(), { x: () => i };
145
+ gradFunc: (o, t) => {
146
+ const [e, n, s] = t, r = n.neg(), c = e.shape[3], i = V(r, s, c, o, 0);
147
+ return r.dispose(), { x: () => i };
147
148
  }
148
149
  };
149
150
  A(q);
150
151
  export {
151
152
  st as rope,
152
- L as ropeCPU
153
+ X as ropeCPU
153
154
  };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@genai-fi/nanogpt",
3
- "version": "0.2.11",
3
+ "version": "0.2.12",
4
4
  "type": "module",
5
5
  "main": "dist/main.js",
6
6
  "types": "dist/main.d.ts",