@genai-fi/nanogpt 0.2.10 → 0.2.12

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  import { defaultConfig as h } from "./config.js";
- import d from "./NanoGPTModel.js";
- import { saveModel as m } from "./utilities/save.js";
+ import m from "./NanoGPTModel.js";
+ import { saveModel as d } from "./utilities/save.js";
  import { loadModel as f } from "./utilities/load.js";
  import u from "./Generator.js";
  import _ from "./Trainer.js";
@@ -13,7 +13,9 @@ import "./jszip.min-CjP2V1VV.js";
  import "./ops/scatterSub.js";
  import "./ops/gatherSub.js";
  import "./ops/attentionMask.js";
- import w from "./utilities/profile.js";
+ import "./ops/qkv.js";
+ import "./ops/rope.js";
+ import p from "./utilities/profile.js";
  class a extends c {
  _config;
  _model;
@@ -50,7 +52,7 @@ class a extends c {
  saveModel(t) {
  if (!this._model || !this._tokeniser)
  throw new Error("Model or tokeniser is not initialized.");
- return m(this._model, this._tokeniser, t);
+ return d(this._model, this._tokeniser, t);
  }
  static loadModel(t, r) {
  const e = new a(t);
@@ -65,7 +67,7 @@ class a extends c {
  }), e;
  }
  static create(t, r = {}) {
- const e = { ...h, ...r }, o = new g(e.vocabSize), s = new d(t, e), i = new a(t, o, s);
+ const e = { ...h, ...r }, o = new g(e.vocabSize), s = new m(t, e), i = new a(t, o, s);
  return i.setStatus("warmup"), l(s).then(() => {
  i.tokeniser.trained ? i.setStatus("ready") : (i.setStatus("awaitingTokens"), i.tokeniser.once("trainStatus", (n) => {
  n === "trained" && i.setStatus("ready");
@@ -84,7 +86,7 @@ class a extends c {
  if (t) {
  if (!this._model)
  throw new Error("Model is not initialized.");
- this._model.getProfiler() || this._model.setProfiler(new w());
+ this._model.getProfiler() || this._model.setProfiler(new p());
  } else
  this._model && this._model.setProfiler(void 0);
  }
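
The last hunk above is the profiler toggle on the underlying model: a profiler is attached only if one is not already set, and passing `undefined` detaches it. A minimal sketch of the same pattern in TypeScript, with illustrative names (`ProfilerLike`, `ModelLike`, `setProfiling` are not part of the package's API):

```ts
interface ProfilerLike { /* timing hooks elided */ }

class ModelLike {
  private profiler?: ProfilerLike;
  getProfiler(): ProfilerLike | undefined { return this.profiler; }
  setProfiler(p: ProfilerLike | undefined): void { this.profiler = p; }
}

function setProfiling(model: ModelLike, enabled: boolean, make: () => ProfilerLike): void {
  if (enabled) {
    // Reuse an existing profiler rather than replacing it, as in the hunk above.
    if (!model.getProfiler()) model.setProfiler(make());
  } else {
    model.setProfiler(undefined);
  }
}
```
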
@@ -1,4 +1,4 @@
- import { o as c, d as s, g as n, E as m, C as r } from "./index-CWQLouWz.js";
+ import { o as c, d as s, g as n, E as m, C as r } from "./index-YPKosni4.js";
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -383,7 +383,7 @@ function _t(n, t) {
  return e.set(n, s), e.get(n);
  }
  }
- const Ge = "Abs", ne = "Add", Es = "BatchMatMul", se = "Cast", As = "Complex", ze = "ComplexAbs", We = "RealDiv", Bs = "Elu", vs = "Exp", je = "Fill", Ke = "FloorDiv", Ms = "GatherNd", re = "Identity", Fs = "Imag", $s = "LeakyRelu", Rs = "Log", xs = "Max", Ve = "Maximum", qe = "Multiply", Ns = "Neg", Ds = "Pack", He = "Pow", Cs = "Prelu", _s = "Range", Ps = "Real", Os = "Relu", Ls = "Reshape", Us = "Relu6", Gs = "ScatterNd", zs = "Sigmoid", Je = "Sqrt", Ws = "Sum", js = "Softmax", Xe = "Sub", Ks = "Transpose", Ye = "ZerosLike", Vs = "Step", qs = "_FusedMatMul";
+ const Ge = "Abs", ne = "Add", Es = "BatchMatMul", se = "Cast", As = "Complex", ze = "ComplexAbs", Bs = "Concat", We = "RealDiv", vs = "Elu", Ms = "Exp", je = "Fill", Ke = "FloorDiv", Fs = "GatherV2", $s = "GatherNd", re = "Identity", Rs = "Imag", xs = "LeakyRelu", Ns = "Log", Ds = "Max", Ve = "Maximum", qe = "Multiply", Cs = "Neg", _s = "Pack", He = "Pow", Ps = "Prelu", Os = "Range", Ls = "Real", Us = "Relu", Gs = "Reshape", zs = "Relu6", Ws = "ScatterNd", js = "Sigmoid", Je = "Sqrt", Ks = "Sum", Vs = "SplitV", qs = "Softmax", Xe = "Sub", Hs = "Transpose", Ye = "ZerosLike", Js = "Step", Xs = "_FusedMatMul";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -438,11 +438,11 @@ function Wt(n) {
  }
  return e;
  }
- function Hs(n) {
+ function Ys(n) {
  const { kernelName: t, backendName: e } = n, s = ie(t, e);
  ht.has(s) && O(`The kernel '${t}' for backend '${e}' is already registered`), ht.set(s, n);
  }
- function Js(n) {
+ function Qs(n) {
  const { kernelName: t } = n;
  It.has(t) && S().getBool("DEBUG") && O(`Overriding the gradient for '${t}'`), It.set(t, n);
  }
@@ -1902,7 +1902,7 @@ function I(n, t, e, s = "numeric") {
  const a = r !== "string" ? ae(n, r) : at(n, [], !0);
  return g.makeTensor(a, i, r);
  }
- function Xs(n, t, e, s = "numeric") {
+ function Zs(n, t, e, s = "numeric") {
  if (!Array.isArray(n))
  throw new Error(`Argument ${t} passed to ${e} must be a \`Tensor[]\` or \`TensorLike[]\``);
  return n.map((i, o) => I(i, `${t}[${o}]`, e, s));
@@ -2065,10 +2065,10 @@ function Sn(n, t) {
  * limitations under the License.
  * =============================================================================
  */
- function Ys() {
+ function tr() {
  return g;
  }
- function Qs() {
+ function er() {
  return g.memory();
  }
  function E(n, t) {
@@ -2893,7 +2893,7 @@ function Yn(n, t, e) {
  * limitations under the License.
  * =============================================================================
  */
- function Zs(n, t) {
+ function nr(n, t) {
  const e = [];
  for (let s = 0; s < t.length; s++) {
  const r = n[n.length - s - 1], i = t.length - s - 1, o = t[i];
@@ -3061,7 +3061,7 @@ function ss(n, t) {
  a[u] != null && (c[l.name] = a[u]);
  }), s?.forEach((l) => c[l.name] = null), { value: o, grads: c };
  }
- function tr(n) {
+ function sr(n) {
  return g.customGrad(n);
  }
  /**
@@ -3841,55 +3841,59 @@ function bs() {
  */
  bs();
  export {
+ Qn as $,
  ds as A,
  Es as B,
  As as C,
- C as D,
+ w as D,
  g as E,
- zs as F,
- Ms as G,
- Bs as H,
- Fs as I,
- $s as J,
- Cs as K,
- Rs as L,
- xs as M,
- Ns as N,
- Ps as O,
- Ds as P,
- Os as Q,
- _s as R,
- Ws as S,
- Us as T,
- Vs as U,
- Ks as V,
- Zs as W,
- Qn as X,
- qs as _,
+ qs as F,
+ $s as G,
+ sr as H,
+ E as I,
+ C as J,
+ js as K,
+ Ns as L,
+ Ds as M,
+ vs as N,
+ Rs as O,
+ _s as P,
+ xs as Q,
+ Gs as R,
+ Ks as S,
+ Cs as T,
+ Ps as U,
+ Ls as V,
+ Us as W,
+ zs as X,
+ Js as Y,
+ Hs as Z,
+ nr as _,
  p as a,
+ Xs as a0,
  Z as b,
- Js as c,
+ Qs as c,
  I as d,
- Ys as e,
+ tr as e,
  V as f,
  Is as g,
- Xs as h,
- y as i,
- Ls as j,
- $t as k,
- Dt as l,
- Qs as m,
- Zt as n,
+ $t as h,
+ Vs as i,
+ Os as j,
+ Zs as k,
+ y as l,
+ er as m,
+ Gn as n,
  F as o,
- G as p,
- De as q,
- Hs as r,
+ Bs as p,
+ Fs as q,
+ Ys as r,
  K as s,
- Gs as t,
- vs as u,
- Ts as v,
- w,
- js as x,
- tr as y,
- E as z
+ Dt as t,
+ Zt as u,
+ G as v,
+ De as w,
+ Ws as x,
+ Ms as y,
+ Ts as z
  };

@@ -21,7 +21,9 @@ export default class CausalSelfAttention extends BaseLayer {
  private divisor;
  private index;
  private _trainable;
+ private units;
  constructor(tf: typeof TF, index: number, config: GPTConfig, ropeCache?: RoPECache | undefined);
+ private build;
  get variables(): TF.Variable[];
  get trainable(): boolean;
  set trainable(value: boolean);

@@ -1,17 +1,10 @@
- import { attentionMask as z } from "../ops/attentionMask.js";
- import S from "./BaseLayer.js";
- class C extends S {
+ import { attentionMask as x } from "../ops/attentionMask.js";
+ import j from "./BaseLayer.js";
+ import { qkv as w } from "../ops/qkv.js";
+ import { rope as y } from "../ops/rope.js";
+ class N extends j {
  constructor(t, i, s, e) {
- super(), this.ropeCache = e, this.config = s, this.tf = t, this.index = i, this.cAttn = this.tf.layers.dense({
- units: 3 * s.nEmbed,
- useBias: s.biasInLinear,
- name: `block_${i}_attn_cAttn`,
- kernelInitializer: this.tf.initializers.randomNormal({
- mean: 0,
- stddev: 0.02
- }),
- biasInitializer: "zeros"
- }), this.cProj = this.tf.layers.dense({
+ super(), this.ropeCache = e, this.config = s, this.tf = t, this.index = i, this.units = s.nEmbed * 3, this.cProj = this.tf.layers.dense({
  units: s.nEmbed,
  useBias: s.biasInLinear,
  name: `block_${i}_attn_cProj`,
@@ -21,11 +14,11 @@ class C extends S {
  }),
  biasInitializer: "zeros"
  }), this.attnDropout = this.tf.layers.dropout({ rate: s.dropout }), this.residDropout = this.tf.layers.dropout({ rate: s.dropout }), this.bias = this.tf.linalg.bandPart(this.tf.ones([s.blockSize, s.blockSize]), -1, 0).cast("bool"), this.divisor = 1 / Math.sqrt(s.nEmbed / s.nHead);
- const o = this.tf.zeros([s.blockSize, s.blockSize]), c = this.tf.fill([s.blockSize, s.blockSize], Number.NEGATIVE_INFINITY);
- this.maskInf = this.tf.where(this.bias, o, c);
+ const o = this.tf.zeros([s.blockSize, s.blockSize]), a = this.tf.fill([s.blockSize, s.blockSize], Number.NEGATIVE_INFINITY);
+ this.maskInf = this.tf.where(this.bias, o, a);
  }
  config;
- cAttn;
+ cAttn = null;
  cProj;
  attnDropout;
  residDropout;
@@ -35,26 +28,35 @@ class C extends S {
  divisor;
  index;
  _trainable = !0;
+ units;
+ build() {
+ this.cAttn === null && (this.cAttn = this.tf.variable(
+ this.tf.randomNormal([this.config.nEmbed, this.units], 0, 0.02),
+ !0
+ //`block_${this.index}_attn_cAttn_kernel`
+ ));
+ }
  get variables() {
- return [
- ...this.cAttn.trainableWeights.map((t) => t.read()),
- ...this.cProj.trainableWeights.map((t) => t.read())
- ];
+ if (this.cAttn === null)
+ throw new Error("Layer not built yet");
+ return [this.cAttn, ...this.cProj.trainableWeights.map((t) => t.read())];
  }
  get trainable() {
  return this._trainable;
  }
  set trainable(t) {
- this._trainable = t, this.cAttn.trainable = t, this.cProj.trainable = t;
+ this._trainable = t, this.cAttn && (this.cAttn.trainable = t), this.cProj.trainable = t;
  }
  saveWeights(t) {
- t.set(`block_${this.index}_cAttn`, this.cAttn.getWeights()), t.set(`block_${this.index}_cProj`, this.cProj.getWeights());
+ t.set(`block_${this.index}_cAttn`, this.cAttn ? [this.cAttn.clone()] : []), t.set(`block_${this.index}_cProj`, this.cProj.getWeights());
  }
  loadWeights(t) {
- this.cAttn.setWeights(t.get(`block_${this.index}_cAttn`) || []), this.cProj.setWeights(t.get(`block_${this.index}_cProj`) || []);
+ const i = t.get(`block_${this.index}_cAttn`)?.[0];
+ if (!i) throw new Error(`Weights for block_${this.index}_cAttn not found`);
+ this.cAttn ? this.cAttn.assign(i) : this.cAttn = this.tf.variable(i, !0), this.cProj.setWeights(t.get(`block_${this.index}_cProj`) || []);
  }
  getAttentionScores(t, i, s) {
- const e = z(t, i, this.maskInf, this.divisor), o = this.tf.softmax(e, -1);
+ const e = x(t, i, this.maskInf, this.divisor), o = this.tf.softmax(e, -1);
  return this.attnDropout.apply(o, { training: s });
  }
  // Attention with optional past. If pastLen > 0 and T_cur == 1, no mask needed.
@@ -64,61 +66,49 @@ class C extends S {
  if (o > 1 && e > 0)
  throw new Error("Cannot use past with T_cur > 1");
  if (o > 1) {
- const a = this.maskInf.slice([0, 0], [o, o]).expandDims(0).expandDims(0);
- r = r.add(a);
+ const c = this.maskInf.slice([0, 0], [o, o]).expandDims(0).expandDims(0);
+ r = r.add(c);
  }
  const h = this.tf.softmax(r, -1);
  return this.attnDropout.apply(h, { training: s });
  }
  getQKV(t) {
- const [i, s, e] = t.shape, o = this.cAttn.apply(t), [c, r, h] = this.tf.split(o, 3, -1);
- o.dispose();
- const a = e / this.config.nHead, u = this.tf.reshape(c, [i, s, this.config.nHead, a]);
- c.dispose();
- const f = u.transpose([0, 2, 1, 3]);
- u.dispose();
- const d = this.tf.reshape(r, [i, s, this.config.nHead, a]);
- r.dispose();
- const n = d.transpose([0, 2, 1, 3]);
- d.dispose();
- const l = this.tf.reshape(h, [i, s, this.config.nHead, a]);
- h.dispose();
- const p = l.transpose([0, 2, 1, 3]);
- return l.dispose(), [f, n, p];
+ return w(t, this.cAttn, this.config.nHead);
  }
  getOutputProjection(t, i) {
- const s = t.shape[0], e = t.shape[2], o = this.config.nEmbed, c = t.transpose([0, 2, 1, 3]), r = this.tf.reshape(c, [s, e, o]), h = this.cProj.apply(r);
+ const s = t.shape[0], e = t.shape[2], o = this.config.nEmbed, a = t.transpose([0, 2, 1, 3]), r = this.tf.reshape(a, [s, e, o]), h = this.cProj.apply(r);
  return this.residDropout.apply(h, { training: i });
  }
  // Added optional KV cache support (pastKV). Returns presentKV for chaining.
  call(t, i = !1, s = !1, e) {
  if (e && !this.config.useRope)
  throw new Error("Cannot use pastKV without RoPE enabled");
- return this.tf.tidy(() => {
+ return this.build(), this.tf.tidy(() => {
  this.startMemory();
- const [o, c, r] = this.getQKV(t), h = o.shape[2], a = this.config.blockSize, u = e ? e.cumulativeLength : 0, [f, d] = this.ropeCache ? this.ropeCache.applyRoPE(o, c, u) : [o, c];
- let n = d, l = r, p = 0;
- e && (p = e.length, n = this.tf.concat([e.k, d], 2), l = this.tf.concat([e.v, r], 2));
+ const [o, a, r] = this.getQKV(t), h = o.shape[2], c = this.config.blockSize, d = e ? e.cumulativeLength : 0, f = this.ropeCache ? y(o, this.ropeCache, d) : o, m = this.ropeCache ? y(a, this.ropeCache, d) : a;
+ this.ropeCache && (o.dispose(), a.dispose());
+ let n = m, l = r, u = 0;
+ e && (u = e.length, n = this.tf.concat([e.k, m], 2), l = this.tf.concat([e.v, r], 2));
  const b = n.shape[2];
- if (b > a) {
- const k = b - a, g = n.shape[0], A = n.shape[1], I = n.shape[3];
- n = n.slice([0, 0, k, 0], [g, A, a, I]), l = l.slice([0, 0, k, 0], [g, A, a, I]), p = a - h;
+ if (b > c) {
+ const k = b - c, A = n.shape[0], g = n.shape[1], _ = n.shape[3];
+ n = n.slice([0, 0, k, 0], [A, g, c, _]), l = l.slice([0, 0, k, 0], [A, g, c, _]), u = c - h;
  }
- let m;
- p > 0 ? m = this.getAttentionScoresWithPast(f, n, i, p) : m = this.getAttentionScores(f, n, i);
- const _ = this.tf.matMul(m, l), v = this.getOutputProjection(_, i), y = {
+ let p;
+ u > 0 ? p = this.getAttentionScoresWithPast(f, n, i, u) : p = this.getAttentionScores(f, n, i);
+ const P = this.tf.matMul(p, l), S = this.getOutputProjection(P, i), v = {
  k: this.tf.keep(n),
  v: this.tf.keep(l),
- length: p + h,
+ length: u + h,
  cumulativeLength: e ? e.cumulativeLength + h : h
- }, P = s ? m.mean(1) : void 0;
- return this.endMemory("CausalSelfAttention"), { output: v, attention: P, presentKV: y };
+ }, I = s ? p.mean(1) : void 0;
+ return this.endMemory("CausalSelfAttention"), { output: S, attention: I, presentKV: v };
  });
  }
  dispose() {
- this.cAttn.dispose(), this.cProj.dispose(), this.attnDropout.dispose(), this.residDropout.dispose(), this.bias.dispose(), this.maskInf.dispose();
+ this.cAttn?.dispose(), this.cProj.dispose(), this.attnDropout.dispose(), this.residDropout.dispose(), this.bias.dispose(), this.maskInf.dispose();
  }
  }
  export {
- C as default
+ N as default
  };
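
In this version the attention layer's cAttn weight is no longer a dense layer built in the constructor; it is created lazily in `build()` as a raw variable, and the QKV projection and split are delegated to the new `ops/qkv.js` module. A minimal sketch of that lazy-build pattern in TensorFlow.js, with illustrative names (`AttentionSketch`, `qkv` here are not the package's exports; the real op also handles disposal and profiling):

```ts
import * as tf from "@tensorflow/tfjs";

class AttentionSketch {
  private cAttn: tf.Variable | null = null; // built on first call, as in the diff above
  constructor(private nEmbed: number, private nHead: number) {}

  private build(): void {
    if (this.cAttn === null) {
      // [nEmbed, 3 * nEmbed] kernel with N(0, 0.02) initialisation.
      this.cAttn = tf.variable(tf.randomNormal([this.nEmbed, 3 * this.nEmbed], 0, 0.02), true);
    }
  }

  // Project x: [B, T, nEmbed] into q, k, v: [B, nHead, T, headDim].
  qkv(x: tf.Tensor3D): [tf.Tensor4D, tf.Tensor4D, tf.Tensor4D] {
    this.build();
    const [b, t] = x.shape;
    const headDim = this.nEmbed / this.nHead;
    const proj = tf.matMul(x.reshape([b * t, this.nEmbed]), this.cAttn!);
    const [q, k, v] = tf.split(proj.reshape([b, t, 3 * this.nEmbed]), 3, -1);
    const toHeads = (z: tf.Tensor) =>
      z.reshape([b, t, this.nHead, headDim]).transpose([0, 2, 1, 3]) as tf.Tensor4D;
    return [toHeads(q), toHeads(k), toHeads(v)];
  }
}
```

Deferring the variable to `build()` lets `loadWeights` assign restored tensors directly instead of rebuilding a layers.dense kernel, which is what the new `loadWeights` body above does.
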
@@ -3,14 +3,15 @@ import { GPTConfig } from '../config';
  export default class RoPECache {
  private readonly tf;
  private readonly config;
- private rotaryDim;
+ readonly rotaryDim: number;
  private ropeBase;
  private ropeInvFreq;
  private ropeCos;
  private ropeSin;
  private ropeCacheLen;
  constructor(tf: typeof TF, config: GPTConfig);
- private ensureRopeCache;
- applyRoPE(q: TF.Tensor, k: TF.Tensor, pastLen: number): [TF.Tensor, TF.Tensor];
+ ensureRopeCache(needed: number): void;
+ getCos(): TF.Tensor | null;
+ getSin(): TF.Tensor | null;
  dispose(): void;
  }

@@ -1,12 +1,12 @@
- class b {
- constructor(s, r) {
- this.tf = s, this.config = r;
- const o = this.config.nEmbed / this.config.nHead;
- if (this.rotaryDim = o, this.rotaryDim % 2 !== 0)
+ class n {
+ constructor(i, e) {
+ this.tf = i, this.config = e;
+ const t = this.config.nEmbed / this.config.nHead;
+ if (this.rotaryDim = t, this.rotaryDim % 2 !== 0)
  throw new Error("rotaryDim must be even");
  this.ropeBase = 1e4;
- const i = this.tf.range(0, this.rotaryDim, 2, "float32"), t = i.div(this.tf.scalar(this.rotaryDim, "float32")), e = this.tf.pow(this.tf.scalar(this.ropeBase, "float32"), t);
- this.ropeInvFreq = this.tf.reciprocal(e), t.dispose(), e.dispose(), i.dispose(), this.config.useRope === !1 ? (this.ropeCos = null, this.ropeSin = null, this.ropeCacheLen = 0) : this.tf.tidy(() => {
+ const s = this.tf.range(0, this.rotaryDim, 2, "float32"), o = s.div(this.tf.scalar(this.rotaryDim, "float32")), r = this.tf.pow(this.tf.scalar(this.ropeBase, "float32"), o);
+ this.ropeInvFreq = this.tf.reciprocal(r), o.dispose(), r.dispose(), s.dispose(), this.config.useRope === !1 ? (this.ropeCos = null, this.ropeSin = null, this.ropeCacheLen = 0) : this.tf.tidy(() => {
  this.ensureRopeCache(this.config.blockSize * 4);
  });
  }
@@ -18,27 +18,22 @@ class b {
  ropeSin = null;
  // [cacheLen, rotaryDim/2]
  ropeCacheLen = 0;
- ensureRopeCache(s) {
- if (s <= this.ropeCacheLen) return;
+ ensureRopeCache(i) {
+ if (i <= this.ropeCacheLen) return;
  this.ropeCos && this.ropeCos.dispose(), this.ropeSin && this.ropeSin.dispose();
- const o = this.tf.range(0, s, 1, "float32").expandDims(1).mul(this.ropeInvFreq.expandDims(0));
- this.ropeCos = this.tf.keep(this.tf.cos(o).expandDims(-1)), this.ropeSin = this.tf.keep(this.tf.sin(o).expandDims(-1)), this.ropeCacheLen = s;
+ const e = Math.max(i, this.ropeCacheLen + this.config.blockSize * 4), s = this.tf.range(0, e, 1, "float32").expandDims(1).mul(this.ropeInvFreq.expandDims(0));
+ this.ropeCos = this.tf.keep(this.tf.cos(s).expandDims(-1)), this.ropeSin = this.tf.keep(this.tf.sin(s).expandDims(-1)), this.ropeCacheLen = e;
  }
- applyRoPE(s, r, o) {
- const i = s.shape[3], t = this.rotaryDim;
- if (t > i) return [s, r];
- const e = s.shape[2], v = o + e;
- this.ensureRopeCache(v);
- const n = t / 2, p = this.ropeCos.slice([o, 0, 0], [e, n, 1]).reshape([1, 1, e, n]), a = this.ropeSin.slice([o, 0, 0], [e, n, 1]).reshape([1, 1, e, n]), h = s.shape[0], c = s.shape[1], f = this.tf.range(0, t, 2, "int32"), l = this.tf.range(1, t, 2, "int32"), d = (u) => {
- const m = u.slice([0, 0, 0, 0], [h, c, e, t]), C = t < i ? u.slice([0, 0, 0, t], [h, c, e, i - t]) : null, D = this.tf.gather(m, f, 3), g = this.tf.gather(m, l, 3), x = D.mul(p).sub(g.mul(a)), k = g.mul(p).add(D.mul(a)), R = this.tf.stack([x, k], -1).reshape([h, c, e, t]);
- return C ? this.tf.concat([R, C], 3) : R;
- }, y = d(s), S = d(r);
- return f.dispose(), l.dispose(), [y, S];
+ getCos() {
+ return this.ropeCos;
+ }
+ getSin() {
+ return this.ropeSin;
  }
  dispose() {
  this.ropeCos && this.ropeCos.dispose(), this.ropeSin && this.ropeSin.dispose(), this.ropeInvFreq.dispose();
  }
  }
  export {
- b as default
+ n as default
  };
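
RoPECache now only maintains the cos/sin tables (`ensureRopeCache`, `getCos`, `getSin`); the rotation itself has moved into the new `ops/rope.js` module. A minimal TypeScript sketch of that rotation, reconstructed from the removed `applyRoPE` above; it assumes the rotary dimension equals the full head dimension and the `[cacheLen, rotaryDim/2, 1]` cache layout built by `ensureRopeCache` (the package's actual `rope` op may differ in detail):

```ts
import * as tf from "@tensorflow/tfjs";

// Rotate q or k, shaped [B, nHead, T, rotaryDim], by the cached angles starting at pastLen.
function applyRopeSketch(x: tf.Tensor4D, ropeCos: tf.Tensor, ropeSin: tf.Tensor, pastLen: number): tf.Tensor4D {
  const [b, h, t, d] = x.shape; // d (rotaryDim) is assumed even
  const half = d / 2;
  // Slice the cached tables to the current window and broadcast over batch and heads.
  const cos = ropeCos.slice([pastLen, 0, 0], [t, half, 1]).reshape([1, 1, t, half]);
  const sin = ropeSin.slice([pastLen, 0, 0], [t, half, 1]).reshape([1, 1, t, half]);
  // Split even/odd feature pairs, as the removed applyRoPE did with gathered indices.
  const even = tf.gather(x, tf.range(0, d, 2, "int32"), 3);
  const odd = tf.gather(x, tf.range(1, d, 2, "int32"), 3);
  const rotEven = even.mul(cos).sub(odd.mul(sin));
  const rotOdd = odd.mul(cos).add(even.mul(sin));
  // Re-interleave the pairs back into [B, nHead, T, d].
  return tf.stack([rotEven, rotOdd], -1).reshape([b, h, t, d]) as tf.Tensor4D;
}
```
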
@@ -1,7 +1,8 @@
- import { o as h, d as i, E as o, F as V, H as X, I as Y, J as Z, N as ee, K as te, O as se, Q as ne, T as re, U as ue, i as L, z as ae, V as A, a as ie, W as oe, w as le, f as q, p as C, X as P, y as U, _ as H } from "../index-CWQLouWz.js";
- import { s as ce, r as f } from "../sum-CnIf1YOh.js";
- import { m } from "../mat_mul-4v7St11W.js";
- import { c as pe } from "../complex-x7w5HPOS.js";
+ import { o as h, d as i, E as o, K as X, N as Y, O as Z, Q as J, T as ee, U as te, V as se, W as ne, X as re, Y as ue, l as L, I as ae, Z as A, a as ie, _ as oe, D as le, f as q, v as C, $ as P, H as U, a0 as H } from "../index-YPKosni4.js";
+ import { r as f } from "../reshape-DmnmKT6r.js";
+ import { s as ce } from "../sum-D7fu15XL.js";
+ import { m } from "../mat_mul-Bu7bhLms.js";
+ import { c as pe } from "../complex-CJ-qCcLB.js";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -20,7 +21,7 @@ import { c as pe } from "../complex-x7w5HPOS.js";
  */
  function he(t) {
  const s = { x: i(t, "x", "sigmoid", "float32") };
- return o.runKernel(V, s);
+ return o.runKernel(X, s);
  }
  const fe = /* @__PURE__ */ h({ sigmoid_: he });
  /**
@@ -41,7 +42,7 @@ const fe = /* @__PURE__ */ h({ sigmoid_: he });
  */
  function de(t) {
  const s = { x: i(t, "x", "elu", "float32") };
- return o.runKernel(X, s);
+ return o.runKernel(Y, s);
  }
  const me = /* @__PURE__ */ h({ elu_: de });
  /**
@@ -62,7 +63,7 @@ const me = /* @__PURE__ */ h({ elu_: de });
  */
  function ge(t) {
  const s = { input: i(t, "input", "imag") };
- return o.runKernel(Y, s);
+ return o.runKernel(Z, s);
  }
  const $e = /* @__PURE__ */ h({ imag_: ge });
  /**
@@ -83,7 +84,7 @@ const $e = /* @__PURE__ */ h({ imag_: ge });
  */
  function xe(t, e = 0.2) {
  const n = { x: i(t, "x", "leakyRelu") }, r = { alpha: e };
- return o.runKernel(Z, n, r);
+ return o.runKernel(J, n, r);
  }
  const ke = /* @__PURE__ */ h({ leakyRelu_: xe });
  /**
@@ -169,7 +170,7 @@ function Me(t) {
  const s = { x: i(t, "x", "relu") };
  return o.runKernel(ne, s);
  }
- const we = /* @__PURE__ */ h({ relu_: Me });
+ const We = /* @__PURE__ */ h({ relu_: Me });
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -186,11 +187,11 @@ const we = /* @__PURE__ */ h({ relu_: Me });
  * limitations under the License.
  * =============================================================================
  */
- function We(t) {
+ function we(t) {
  const s = { x: i(t, "x", "relu6") };
  return o.runKernel(re, s);
  }
- const ze = /* @__PURE__ */ h({ relu6_: We });
+ const ze = /* @__PURE__ */ h({ relu6_: we });
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -273,7 +274,7 @@ function Te(t, e, s, n) {
  if (e === "linear")
  return t;
  if (e === "relu")
- return we(t);
+ return We(t);
  if (e === "elu")
  return me(t);
  if (e === "relu6")
@@ -310,42 +311,42 @@ function Ne({ a: t, b: e, transposeA: s = !1, transposeB: n = !1, bias: r, activ
  }
  let u = i(t, "a", "fused matMul"), a = i(e, "b", "fused matMul");
  [u, a] = q(u, a);
- const D = s ? u.shape[u.rank - 2] : u.shape[u.rank - 1], b = n ? a.shape[a.rank - 1] : a.shape[a.rank - 2], w = s ? u.shape[u.rank - 1] : u.shape[u.rank - 2], W = n ? a.shape[a.rank - 2] : a.shape[a.rank - 1], T = u.shape.slice(0, -2), y = a.shape.slice(0, -2), B = C(T), N = C(y);
+ const D = s ? u.shape[u.rank - 2] : u.shape[u.rank - 1], b = n ? a.shape[a.rank - 1] : a.shape[a.rank - 2], W = s ? u.shape[u.rank - 1] : u.shape[u.rank - 2], w = n ? a.shape[a.rank - 2] : a.shape[a.rank - 1], T = u.shape.slice(0, -2), y = a.shape.slice(0, -2), B = C(T), N = C(y);
  L(D === b, () => `Error in fused matMul: inner shapes (${D}) and (${b}) of Tensors with shapes ${u.shape} and ${a.shape} and transposeA=${s} and transposeB=${n} must match.`);
- const O = P(u.shape.slice(0, -2), a.shape.slice(0, -2)).concat([w, W]), F = s ? f(u, [B, D, w]) : f(u, [B, w, D]), R = n ? f(a, [N, W, b]) : f(a, [N, b, W]);
+ const O = P(u.shape.slice(0, -2), a.shape.slice(0, -2)).concat([W, w]), F = s ? f(u, [B, D, W]) : f(u, [B, W, D]), R = n ? f(a, [N, w, b]) : f(a, [N, b, w]);
  let S;
  r != null && (S = i(r, "bias", "fused matMul"), [S] = q(S, u), P(O, S.shape));
- let G;
- l != null && (G = i(l, "prelu weights", "fused matMul"));
- const I = (x, M) => {
+ let v;
+ l != null && (v = i(l, "prelu weights", "fused matMul"));
+ const G = (x, M) => {
  const [g, $, k, z] = M, d = Ae(f(x, k.shape), k, c);
  let K, _;
  if (!s && !n ? (K = m(d, $, !1, !0), _ = m(g, d, !0, !1)) : !s && n ? (K = m(d, $, !1, !1), _ = m(d, g, !0, !1)) : s && !n ? (K = m($, d, !1, !0), _ = m(g, d, !1, !1)) : (K = m($, d, !0, !0), _ = m(d, g, !0, !0)), r != null) {
- const Q = Le(z, d);
- return [K, _, Q];
+ const V = Le(z, d);
+ return [K, _, V];
  } else
  return [K, _];
- }, v = {
+ }, I = {
  a: F,
  b: R,
  bias: S,
- preluActivationWeights: G
+ preluActivationWeights: v
  }, j = { transposeA: s, transposeB: n, activation: c, leakyreluAlpha: p };
  return r == null ? U((M, g, $) => {
  const k = (
  // tslint:disable-next-line: no-unnecessary-type-assertion
- o.runKernel(H, v, j)
+ o.runKernel(H, I, j)
  );
- return $([M, g, k]), { value: f(k, O), gradFunc: I };
+ return $([M, g, k]), { value: f(k, O), gradFunc: G };
  })(F, R) : U((M, g, $, k) => {
  const z = (
  // tslint:disable-next-line: no-unnecessary-type-assertion
- o.runKernel(H, v, j)
+ o.runKernel(H, I, j)
  );
- return k([M, g, z, $]), { value: f(z, O), gradFunc: I };
+ return k([M, g, z, $]), { value: f(z, O), gradFunc: G };
  })(F, R, S);
  }
- const J = /* @__PURE__ */ h({ fusedMatMul_: Ne });
+ const Q = /* @__PURE__ */ h({ fusedMatMul_: Ne });
  /**
  * @license
  * Copyright 2018 Google LLC
@@ -369,7 +370,7 @@ class E extends Error {
  * https://opensource.org/licenses/MIT.
  * =============================================================================
  */
- function Ge(t, e, s, n) {
+ function ve(t, e, s, n) {
  if (t.rank < 2 || e.rank < 2)
  throw new E(`dot requires both inputs to be rank >= 2 but got x shape = ${t.shape} and y shape = ${e.shape}`);
  if (e.rank >= 3) {
@@ -378,7 +379,7 @@ function Ge(t, e, s, n) {
  throw new E(`If rank y >= 3, then the second last dim of y must equal the last dim of x but got x shape = ${t.shape} and y shape = ${e.shape}`);
  }
  if (t.rank === 2 && e.rank === 2)
- return J({
+ return Q({
  a: t,
  b: e,
  transposeA: !1,
@@ -392,7 +393,7 @@ function Ge(t, e, s, n) {
  const l = e.shape.slice(), p = l.pop(), u = l.pop(), a = [...l, p], D = Array.from({ length: e.rank }, (T, y) => y === 0 ? e.rank - 2 : y <= e.rank - 2 ? y - 1 : y);
  e = f(Re(e, D), [u, -1]);
  const b = [...r, ...a];
- return f(J({
+ return f(Q({
  a: t,
  b: e,
  transposeA: !1,
@@ -402,7 +403,7 @@
  }), b);
  }
  }
- class Pe {
+ class Ue {
  vocabSize;
  embedDim;
  tf;
@@ -425,7 +426,7 @@ class Pe {
  return this.tf.gather(this.tiedWeights, e, 0);
  }
  project(e) {
- return Ge(e, this.tiedWeights.transpose());
+ return ve(e, this.tiedWeights.transpose());
  }
  getWeights() {
  return [this.tiedWeights];
@@ -444,5 +445,5 @@ class Pe {
  }
  }
  export {
- Pe as default
+ Ue as default
  };
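
The renamed class at the end of this chunk is a tied input/output embedding: one weight matrix serves both token lookup and the output projection (`project` multiplies by its transpose via the fused dot helper). A minimal sketch of that weight-tying idea in TensorFlow.js, with illustrative names that are not the package's API:

```ts
import * as tf from "@tensorflow/tfjs";

class TiedEmbeddingSketch {
  // One [vocabSize, embedDim] matrix shared by both directions.
  private tied: tf.Variable;
  constructor(vocabSize: number, embedDim: number) {
    this.tied = tf.variable(tf.randomNormal([vocabSize, embedDim], 0, 0.02), true);
  }
  // Token ids -> embeddings: gather rows of the shared matrix.
  embed(ids: tf.Tensor): tf.Tensor {
    return tf.gather(this.tied, ids, 0);
  }
  // Hidden states [.., embedDim] -> logits [.., vocabSize]: multiply by the transpose.
  project(h: tf.Tensor2D): tf.Tensor2D {
    return tf.matMul(h, this.tied.transpose()) as tf.Tensor2D;
  }
}
```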