@genai-fi/nanogpt 0.2.6 → 0.2.8

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
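At a glance, most of this diff is bundler churn: the shared chunk hash changes (`index-D1SlunD-.js` → `index-DQfEAU9u.js`, likewise the `sum`, `stack`, and `complex` chunks), minified identifiers are reshuffled, and `matMul` is split into its own chunk (`mat_mul-CuHB58-H.js`). The substantive changes are a new fused `AttentionMask` op with WebGL and CPU kernels (`dist/ops/attentionMask.js` plus its `.d.ts`), the attention layer delegating its masked score computation to that op and keeping `divisor` as a plain number, and a gradient-shape fix in the sparse softmax cross-entropy fallback.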
@@ -12,6 +12,7 @@ import "./index-Tf7vU29b.js";
  import "./jszip.min-CjP2V1VV.js";
  import "./ops/scatterSub.js";
  import "./ops/gatherSub.js";
+ import "./ops/attentionMask.js";
  class a extends c {
  _config;
  _model;
@@ -1,4 +1,4 @@
- import { o as t, c as s, b as n, E as m, C as r } from "./index-D1SlunD-.js";
+ import { o as t, c as s, d as n, E as m, C as r } from "./index-DQfEAU9u.js";
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -3858,29 +3858,29 @@ export {
  Qn as V,
  qs as _,
  Z as a,
- Is as b,
+ K as b,
  I as c,
- Js as d,
+ Is as d,
  Xs as e,
- y as f,
- Ls as g,
- Ft as h,
- Nt as i,
- Qt as j,
- U as k,
- Ne as l,
+ Js as f,
+ y as g,
+ Ls as h,
+ Ft as i,
+ Nt as j,
+ Qt as k,
+ U as l,
  p as m,
- Gs as n,
+ Ne as n,
  F as o,
- vs as p,
- Ts as q,
+ Gs as p,
+ vs as q,
  Hs as r,
  j as s,
- w as t,
- js as u,
- Qs as v,
- E as w,
- K as x,
+ Ts as t,
+ w as u,
+ js as v,
+ Qs as w,
+ E as x,
  zs as y,
  C as z
  };
@@ -1,4 +1,5 @@
- class S {
+ import { attentionMask as z } from "../ops/attentionMask.js";
+ class j {
  constructor(t, i, s, e) {
  this.ropeCache = e, this.config = s, this.tf = t, this.index = i, this.cAttn = this.tf.layers.dense({
  units: 3 * s.nEmbed,
@@ -18,9 +19,9 @@ class S {
  stddev: 0.02 / Math.sqrt(2 * s.nLayer)
  }),
  biasInitializer: "zeros"
- }), this.attnDropout = this.tf.layers.dropout({ rate: s.dropout }), this.residDropout = this.tf.layers.dropout({ rate: s.dropout }), this.bias = this.tf.linalg.bandPart(this.tf.ones([s.blockSize, s.blockSize]), -1, 0).cast("bool"), this.divisor = this.tf.scalar(1 / Math.sqrt(s.nEmbed / s.nHead));
- const a = this.tf.zeros([s.blockSize, s.blockSize]), h = this.tf.fill([s.blockSize, s.blockSize], Number.NEGATIVE_INFINITY);
- this.maskInf = this.tf.where(this.bias, a, h);
+ }), this.attnDropout = this.tf.layers.dropout({ rate: s.dropout }), this.residDropout = this.tf.layers.dropout({ rate: s.dropout }), this.bias = this.tf.linalg.bandPart(this.tf.ones([s.blockSize, s.blockSize]), -1, 0).cast("bool"), this.divisor = 1 / Math.sqrt(s.nEmbed / s.nHead);
+ const o = this.tf.zeros([s.blockSize, s.blockSize]), c = this.tf.fill([s.blockSize, s.blockSize], Number.NEGATIVE_INFINITY);
+ this.maskInf = this.tf.where(this.bias, o, c);
  }
  config;
  cAttn;
@@ -52,70 +53,70 @@ class S {
  this.cAttn.setWeights(t.get(`block_${this.index}_cAttn`) || []), this.cProj.setWeights(t.get(`block_${this.index}_cProj`) || []);
  }
  getAttentionScores(t, i, s) {
- const e = t.shape[2], h = this.tf.matMul(t, i, !1, !0).mul(this.divisor), n = this.maskInf.slice([0, 0], [e, e]).expandDims(0).expandDims(0), r = h.add(n), o = this.tf.softmax(r, -1);
+ const e = z(t, i, this.maskInf, this.divisor), o = this.tf.softmax(e, -1);
  return this.attnDropout.apply(o, { training: s });
  }
  // Attention with optional past. If pastLen > 0 and T_cur == 1, no mask needed.
  getAttentionScoresWithPast(t, i, s, e) {
- const a = t.shape[2];
- let n = this.tf.matMul(t, i, !1, !0).mul(this.divisor);
- if (a > 1 && e > 0)
+ const o = t.shape[2];
+ let r = this.tf.matMul(t, i, !1, !0).mul(this.divisor);
+ if (o > 1 && e > 0)
  throw new Error("Cannot use past with T_cur > 1");
- if (a > 1) {
- const o = this.maskInf.slice([0, 0], [a, a]).expandDims(0).expandDims(0);
- n = n.add(o);
+ if (o > 1) {
+ const a = this.maskInf.slice([0, 0], [o, o]).expandDims(0).expandDims(0);
+ r = r.add(a);
  }
- const r = this.tf.softmax(n, -1);
- return this.attnDropout.apply(r, { training: s });
+ const h = this.tf.softmax(r, -1);
+ return this.attnDropout.apply(h, { training: s });
  }
  getQKV(t) {
- const [i, s, e] = t.shape, a = this.cAttn.apply(t), [h, n, r] = this.tf.split(a, 3, -1);
- a.dispose();
- const o = e / this.config.nHead, u = this.tf.reshape(h, [i, s, this.config.nHead, o]);
- h.dispose();
+ const [i, s, e] = t.shape, o = this.cAttn.apply(t), [c, r, h] = this.tf.split(o, 3, -1);
+ o.dispose();
+ const a = e / this.config.nHead, u = this.tf.reshape(c, [i, s, this.config.nHead, a]);
+ c.dispose();
  const f = u.transpose([0, 2, 1, 3]);
  u.dispose();
- const d = this.tf.reshape(n, [i, s, this.config.nHead, o]);
- n.dispose();
- const c = d.transpose([0, 2, 1, 3]);
- d.dispose();
- const l = this.tf.reshape(r, [i, s, this.config.nHead, o]);
+ const d = this.tf.reshape(r, [i, s, this.config.nHead, a]);
  r.dispose();
+ const n = d.transpose([0, 2, 1, 3]);
+ d.dispose();
+ const l = this.tf.reshape(h, [i, s, this.config.nHead, a]);
+ h.dispose();
  const p = l.transpose([0, 2, 1, 3]);
- return l.dispose(), [f, c, p];
+ return l.dispose(), [f, n, p];
  }
  getOutputProjection(t, i) {
- const s = t.shape[0], e = t.shape[2], a = this.config.nEmbed, h = t.transpose([0, 2, 1, 3]), n = this.tf.reshape(h, [s, e, a]), r = this.cProj.apply(n);
- return this.residDropout.apply(r, { training: i });
+ const s = t.shape[0], e = t.shape[2], o = this.config.nEmbed, c = t.transpose([0, 2, 1, 3]), r = this.tf.reshape(c, [s, e, o]), h = this.cProj.apply(r);
+ return this.residDropout.apply(h, { training: i });
  }
  // Added optional KV cache support (pastKV). Returns presentKV for chaining.
  call(t, i = !1, s = !1, e) {
  if (e && !this.config.useRope)
  throw new Error("Cannot use pastKV without RoPE enabled");
  return this.tf.tidy(() => {
- const [a, h, n] = this.getQKV(t), r = a.shape[2], o = this.config.blockSize, u = e ? e.cumulativeLength : 0, [f, d] = this.ropeCache ? this.ropeCache.applyRoPE(a, h, u) : [a, h];
- let c = d, l = n, p = 0;
- e && (p = e.length, c = this.tf.concat([e.k, d], 2), l = this.tf.concat([e.v, n], 2));
- const b = c.shape[2];
- if (b > o) {
- const k = b - o, g = c.shape[0], v = c.shape[1], I = c.shape[3];
- c = c.slice([0, 0, k, 0], [g, v, o, I]), l = l.slice([0, 0, k, 0], [g, v, o, I]), p = o - r;
+ const [o, c, r] = this.getQKV(t), h = o.shape[2], a = this.config.blockSize, u = e ? e.cumulativeLength : 0, [f, d] = this.ropeCache ? this.ropeCache.applyRoPE(o, c, u) : [o, c];
+ let n = d, l = r, p = 0;
+ e && (p = e.length, n = this.tf.concat([e.k, d], 2), l = this.tf.concat([e.v, r], 2));
+ const b = n.shape[2];
+ if (b > a) {
+ const k = b - a, g = n.shape[0], I = n.shape[1], _ = n.shape[3];
+ n = n.slice([0, 0, k, 0], [g, I, a, _]), l = l.slice([0, 0, k, 0], [g, I, a, _]), p = a - h;
  }
  let m;
- p > 0 ? m = this.getAttentionScoresWithPast(f, c, i, p) : m = this.getAttentionScores(f, c, i);
- const _ = this.tf.matMul(m, l), A = this.getOutputProjection(_, i), P = {
- k: this.tf.keep(c),
+ p > 0 ? m = this.getAttentionScoresWithPast(f, n, i, p) : m = this.getAttentionScores(f, n, i);
+ const v = this.tf.matMul(m, l), A = this.getOutputProjection(v, i), P = {
+ k: this.tf.keep(n),
  v: this.tf.keep(l),
- length: p + r,
- cumulativeLength: e ? e.cumulativeLength + r : r
+ length: p + h,
+ cumulativeLength: e ? e.cumulativeLength + h : h
  };
  return { output: A, attention: s ? m.mean(1) : void 0, presentKV: P };
  });
  }
  dispose() {
- this.cAttn.dispose(), this.cProj.dispose(), this.attnDropout.dispose(), this.residDropout.dispose(), this.bias.dispose(), this.maskInf.dispose(), this.divisor.dispose();
+ this.cAttn.dispose(), this.cProj.dispose(), this.attnDropout.dispose(), this.residDropout.dispose(), this.bias.dispose(), this.maskInf.dispose();
  }
  }
  export {
- S as default
+ j as default
  };
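The substantive change in the file above (the causal self-attention layer chunk) is that `getAttentionScores` now delegates the masked, scaled QKᵀ computation to the new fused `attentionMask` op (imported as `z`) instead of composing `matMul`, `mul`, `slice`, `expandDims`, and `add` eagerly, and `divisor` is now stored as a plain number rather than a `tf.scalar`, so `dispose()` no longer needs to release it. A minimal TensorFlow.js sketch of what the fused op computes, mirroring the CPU fallback kernel registered later in this diff (function and variable names here are illustrative, not the package's API):

```ts
import * as tf from '@tensorflow/tfjs';

// Reference (unfused) computation equivalent to attentionMask(q, k, maskInf, divisor).
// q, k: [batch, nHead, T, headDim]; maskInf: [blockSize, blockSize] additive causal
// mask (0 on/below the diagonal, -Infinity above); divisor: 1 / sqrt(headDim).
function attentionMaskReference(
  q: tf.Tensor4D,
  k: tf.Tensor4D,
  maskInf: tf.Tensor2D,
  divisor: number
): tf.Tensor4D {
  const T = q.shape[2];
  // Scaled dot-product scores: (q @ k^T) * divisor -> [batch, nHead, T, T].
  const scores = tf.matMul(q, k, false, true).mul(tf.scalar(divisor));
  // Crop the causal mask to the current sequence length and broadcast over batch and heads.
  const mask = maskInf.slice([0, 0], [T, T]).expandDims(0).expandDims(0);
  return scores.add(mask) as tf.Tensor4D;
}
```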
@@ -1,29 +1,7 @@
- import { o as h, c as u, x as B, E as c, B as V, y as X, D as Y, I as Z, F as ee, N as te, H as se, J as ne, K as re, O as ae, Q as ue, f as L, w as ie, T as A, m as oe, U as le, t as ce, k as C, V as P, v as U, _ as H } from "../index-D1SlunD-.js";
- import { s as pe, r as f } from "../sum-02UQ5Eaq.js";
- import { c as he } from "../complex-D6Bq1XDf.js";
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function fe(t, e, s = !1, n = !1) {
- let r = u(t, "a", "matMul"), i = u(e, "b", "matMul");
- [r, i] = B(r, i);
- const o = { a: r, b: i }, p = { transposeA: s, transposeB: n };
- return c.runKernel(V, o, p);
- }
- const m = /* @__PURE__ */ h({ matMul_: fe });
+ import { o as h, c as i, E as o, y as V, D as X, I as Y, F as Z, N as ee, H as te, J as se, K as ne, O as re, Q as ue, g as L, x as ae, T as A, m as ie, U as oe, u as le, b as q, l as C, V as P, w as U, _ as H } from "../index-DQfEAU9u.js";
+ import { s as ce, r as f } from "../sum-B-O33dgG.js";
+ import { m } from "../mat_mul-CuHB58-H.js";
+ import { c as pe } from "../complex-CeoYJn2o.js";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -40,11 +18,11 @@ const m = /* @__PURE__ */ h({ matMul_: fe });
  * limitations under the License.
  * =============================================================================
  */
- function de(t) {
- const s = { x: u(t, "x", "sigmoid", "float32") };
- return c.runKernel(X, s);
+ function he(t) {
+ const s = { x: i(t, "x", "sigmoid", "float32") };
+ return o.runKernel(V, s);
  }
- const me = /* @__PURE__ */ h({ sigmoid_: de });
+ const fe = /* @__PURE__ */ h({ sigmoid_: he });
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -61,11 +39,11 @@ const me = /* @__PURE__ */ h({ sigmoid_: de });
  * limitations under the License.
  * =============================================================================
  */
- function ge(t) {
- const s = { x: u(t, "x", "elu", "float32") };
- return c.runKernel(Y, s);
+ function de(t) {
+ const s = { x: i(t, "x", "elu", "float32") };
+ return o.runKernel(X, s);
  }
- const $e = /* @__PURE__ */ h({ elu_: ge });
+ const me = /* @__PURE__ */ h({ elu_: de });
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -82,11 +60,11 @@ const $e = /* @__PURE__ */ h({ elu_: ge });
  * limitations under the License.
  * =============================================================================
  */
- function xe(t) {
- const s = { input: u(t, "input", "imag") };
- return c.runKernel(Z, s);
+ function ge(t) {
+ const s = { input: i(t, "input", "imag") };
+ return o.runKernel(Y, s);
  }
- const ke = /* @__PURE__ */ h({ imag_: xe });
+ const $e = /* @__PURE__ */ h({ imag_: ge });
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -103,11 +81,11 @@ const ke = /* @__PURE__ */ h({ imag_: xe });
  * limitations under the License.
  * =============================================================================
  */
- function De(t, e = 0.2) {
- const n = { x: u(t, "x", "leakyRelu") }, r = { alpha: e };
- return c.runKernel(ee, n, r);
+ function xe(t, e = 0.2) {
+ const n = { x: i(t, "x", "leakyRelu") }, r = { alpha: e };
+ return o.runKernel(Z, n, r);
  }
- const be = /* @__PURE__ */ h({ leakyRelu_: De });
+ const ke = /* @__PURE__ */ h({ leakyRelu_: xe });
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -124,11 +102,11 @@ const be = /* @__PURE__ */ h({ leakyRelu_: De });
  * limitations under the License.
  * =============================================================================
  */
- function ye(t) {
- const s = { x: u(t, "x", "neg") };
- return c.runKernel(te, s);
+ function De(t) {
+ const s = { x: i(t, "x", "neg") };
+ return o.runKernel(ee, s);
  }
- const Se = /* @__PURE__ */ h({ neg_: ye });
+ const be = /* @__PURE__ */ h({ neg_: De });
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -145,11 +123,11 @@ const Se = /* @__PURE__ */ h({ neg_: ye });
  * limitations under the License.
  * =============================================================================
  */
- function Me(t, e) {
- const s = u(t, "x", "prelu"), n = u(e, "alpha", "prelu"), r = { x: s, alpha: n };
- return c.runKernel(se, r);
+ function ye(t, e) {
+ const s = i(t, "x", "prelu"), n = i(e, "alpha", "prelu"), r = { x: s, alpha: n };
+ return o.runKernel(te, r);
  }
- const Ke = /* @__PURE__ */ h({ prelu_: Me });
+ const Se = /* @__PURE__ */ h({ prelu_: ye });
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -166,11 +144,11 @@ const Ke = /* @__PURE__ */ h({ prelu_: Me });
  * limitations under the License.
  * =============================================================================
  */
- function _e(t) {
- const s = { input: u(t, "input", "real") };
- return c.runKernel(ne, s);
+ function Ke(t) {
+ const s = { input: i(t, "input", "real") };
+ return o.runKernel(se, s);
  }
- const we = /* @__PURE__ */ h({ real_: _e });
+ const _e = /* @__PURE__ */ h({ real_: Ke });
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -187,11 +165,11 @@ const we = /* @__PURE__ */ h({ real_: _e });
  * limitations under the License.
  * =============================================================================
  */
- function We(t) {
- const s = { x: u(t, "x", "relu") };
- return c.runKernel(re, s);
+ function Me(t) {
+ const s = { x: i(t, "x", "relu") };
+ return o.runKernel(ne, s);
  }
- const ze = /* @__PURE__ */ h({ relu_: We });
+ const we = /* @__PURE__ */ h({ relu_: Me });
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -208,11 +186,11 @@ const ze = /* @__PURE__ */ h({ relu_: We });
  * limitations under the License.
  * =============================================================================
  */
- function Ee(t) {
- const s = { x: u(t, "x", "relu6") };
- return c.runKernel(ae, s);
+ function We(t) {
+ const s = { x: i(t, "x", "relu6") };
+ return o.runKernel(re, s);
  }
- const Oe = /* @__PURE__ */ h({ relu6_: Ee });
+ const ze = /* @__PURE__ */ h({ relu6_: We });
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -229,11 +207,11 @@ const Oe = /* @__PURE__ */ h({ relu6_: Ee });
  * limitations under the License.
  * =============================================================================
  */
- function Fe(t, e = 0) {
- const n = { x: u(t, "x", "step") }, r = { alpha: e };
- return c.runKernel(ue, n, r);
+ function Ee(t, e = 0) {
+ const n = { x: i(t, "x", "step") }, r = { alpha: e };
+ return o.runKernel(ue, n, r);
  }
- const Re = /* @__PURE__ */ h({ step_: Fe });
+ const Oe = /* @__PURE__ */ h({ step_: Ee });
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -250,19 +228,19 @@ const Re = /* @__PURE__ */ h({ step_: Fe });
  * limitations under the License.
  * =============================================================================
  */
- function Ae(t, e, s) {
- const n = u(t, "x", "transpose");
- if (e == null && (e = n.shape.map((o, p) => p).reverse()), L(n.rank === e.length, () => `Error in transpose: rank of input ${n.rank} must match length of perm ${e}.`), e.forEach((o) => {
- L(o >= 0 && o < n.rank, () => `All entries in 'perm' must be between 0 and ${n.rank - 1} but got ${e}`);
+ function Fe(t, e, s) {
+ const n = i(t, "x", "transpose");
+ if (e == null && (e = n.shape.map((l, p) => p).reverse()), L(n.rank === e.length, () => `Error in transpose: rank of input ${n.rank} must match length of perm ${e}.`), e.forEach((l) => {
+ L(l >= 0 && l < n.rank, () => `All entries in 'perm' must be between 0 and ${n.rank - 1} but got ${e}`);
  }), n.rank <= 1)
  return n.clone();
- const r = { x: n }, i = { perm: e };
- return n.dtype === "complex64" ? ie(() => {
- let o = we(n), p = ke(n);
- return o = c.runKernel(A, { x: o }, i), p = c.runKernel(A, { x: p }, i), s && (p = Se(p)), he(o, p);
- }) : c.runKernel(A, r, i);
+ const r = { x: n }, c = { perm: e };
+ return n.dtype === "complex64" ? ae(() => {
+ let l = _e(n), p = $e(n);
+ return l = o.runKernel(A, { x: l }, c), p = o.runKernel(A, { x: p }, c), s && (p = be(p)), pe(l, p);
+ }) : o.runKernel(A, r, c);
  }
- const Be = /* @__PURE__ */ h({ transpose_: Ae });
+ const Re = /* @__PURE__ */ h({ transpose_: Fe });
  /**
  * @license
  * Copyright 2019 Google LLC. All Rights Reserved.
@@ -279,36 +257,36 @@ const Be = /* @__PURE__ */ h({ transpose_: Ae });
  * limitations under the License.
  * =============================================================================
  */
- function Le(t, e, s) {
+ function Ae(t, e, s) {
  if (s == null || s === "linear")
  return t;
  if (s === "relu")
- return oe(t, Re(e));
+ return ie(t, Oe(e));
  throw new Error(`Cannot compute gradient for fused activation ${s}.`);
  }
- function Te(t, e) {
+ function Le(t, e) {
  let s = e;
- const n = le(t.shape, e.shape);
- return n.length > 0 && (s = pe(s, n)), f(s, t.shape);
+ const n = oe(t.shape, e.shape);
+ return n.length > 0 && (s = ce(s, n)), f(s, t.shape);
  }
- function Ne(t, e, s, n) {
+ function Te(t, e, s, n) {
  if (e === "linear")
  return t;
  if (e === "relu")
- return ze(t);
+ return we(t);
  if (e === "elu")
- return $e(t);
+ return me(t);
  if (e === "relu6")
- return Oe(t);
+ return ze(t);
  if (e === "prelu")
- return Ke(t, s);
+ return Se(t, s);
  if (e === "leakyrelu")
- return be(t, n);
+ return ke(t, n);
  if (e === "sigmoid")
- return me(t);
+ return fe(t);
  throw new Error(`Unknown fused activation ${e}.`);
  }
- const ve = (t, e) => !(t > 0) || e === "linear";
+ const Be = (t, e) => !(t > 0) || e === "linear";
  /**
  * @license
  * Copyright 2019 Google LLC. All Rights Reserved.
@@ -325,49 +303,49 @@ const ve = (t, e) => !(t > 0) || e === "linear";
  * limitations under the License.
  * =============================================================================
  */
- function Ge({ a: t, b: e, transposeA: s = !1, transposeB: n = !1, bias: r, activation: i = "linear", preluActivationWeights: o, leakyreluAlpha: p = 0.2 }) {
- if (ve(c.state.gradientDepth, i) === !1) {
+ function Ne({ a: t, b: e, transposeA: s = !1, transposeB: n = !1, bias: r, activation: c = "linear", preluActivationWeights: l, leakyreluAlpha: p = 0.2 }) {
+ if (Be(o.state.gradientDepth, c) === !1) {
  let x = m(t, e, s, n);
- return r != null && (x = ce(x, r)), Ne(x, i, o, p);
+ return r != null && (x = le(x, r)), Te(x, c, l, p);
  }
- let a = u(t, "a", "fused matMul"), l = u(e, "b", "fused matMul");
- [a, l] = B(a, l);
- const D = s ? a.shape[a.rank - 2] : a.shape[a.rank - 1], b = n ? l.shape[l.rank - 1] : l.shape[l.rank - 2], w = s ? a.shape[a.rank - 1] : a.shape[a.rank - 2], W = n ? l.shape[l.rank - 2] : l.shape[l.rank - 1], T = a.shape.slice(0, -2), y = l.shape.slice(0, -2), N = C(T), v = C(y);
- L(D === b, () => `Error in fused matMul: inner shapes (${D}) and (${b}) of Tensors with shapes ${a.shape} and ${l.shape} and transposeA=${s} and transposeB=${n} must match.`);
- const O = P(a.shape.slice(0, -2), l.shape.slice(0, -2)).concat([w, W]), F = s ? f(a, [N, D, w]) : f(a, [N, w, D]), R = n ? f(l, [v, W, b]) : f(l, [v, b, W]);
+ let u = i(t, "a", "fused matMul"), a = i(e, "b", "fused matMul");
+ [u, a] = q(u, a);
+ const D = s ? u.shape[u.rank - 2] : u.shape[u.rank - 1], b = n ? a.shape[a.rank - 1] : a.shape[a.rank - 2], w = s ? u.shape[u.rank - 1] : u.shape[u.rank - 2], W = n ? a.shape[a.rank - 2] : a.shape[a.rank - 1], T = u.shape.slice(0, -2), y = a.shape.slice(0, -2), B = C(T), N = C(y);
+ L(D === b, () => `Error in fused matMul: inner shapes (${D}) and (${b}) of Tensors with shapes ${u.shape} and ${a.shape} and transposeA=${s} and transposeB=${n} must match.`);
+ const O = P(u.shape.slice(0, -2), a.shape.slice(0, -2)).concat([w, W]), F = s ? f(u, [B, D, w]) : f(u, [B, w, D]), R = n ? f(a, [N, W, b]) : f(a, [N, b, W]);
  let S;
- r != null && (S = u(r, "bias", "fused matMul"), [S] = B(S, a), P(O, S.shape));
+ r != null && (S = i(r, "bias", "fused matMul"), [S] = q(S, u), P(O, S.shape));
  let G;
- o != null && (G = u(o, "prelu weights", "fused matMul"));
- const I = (x, _) => {
- const [g, $, k, z] = _, d = Le(f(x, k.shape), k, i);
- let M, K;
- if (!s && !n ? (M = m(d, $, !1, !0), K = m(g, d, !0, !1)) : !s && n ? (M = m(d, $, !1, !1), K = m(d, g, !0, !1)) : s && !n ? (M = m($, d, !1, !0), K = m(g, d, !1, !1)) : (M = m($, d, !0, !0), K = m(d, g, !0, !0)), r != null) {
- const Q = Te(z, d);
- return [M, K, Q];
+ l != null && (G = i(l, "prelu weights", "fused matMul"));
+ const I = (x, M) => {
+ const [g, $, k, z] = M, d = Ae(f(x, k.shape), k, c);
+ let K, _;
+ if (!s && !n ? (K = m(d, $, !1, !0), _ = m(g, d, !0, !1)) : !s && n ? (K = m(d, $, !1, !1), _ = m(d, g, !0, !1)) : s && !n ? (K = m($, d, !1, !0), _ = m(g, d, !1, !1)) : (K = m($, d, !0, !0), _ = m(d, g, !0, !0)), r != null) {
+ const Q = Le(z, d);
+ return [K, _, Q];
  } else
- return [M, K];
- }, j = {
+ return [K, _];
+ }, v = {
  a: F,
  b: R,
  bias: S,
  preluActivationWeights: G
- }, q = { transposeA: s, transposeB: n, activation: i, leakyreluAlpha: p };
- return r == null ? U((_, g, $) => {
+ }, j = { transposeA: s, transposeB: n, activation: c, leakyreluAlpha: p };
+ return r == null ? U((M, g, $) => {
  const k = (
  // tslint:disable-next-line: no-unnecessary-type-assertion
- c.runKernel(H, j, q)
+ o.runKernel(H, v, j)
  );
- return $([_, g, k]), { value: f(k, O), gradFunc: I };
- })(F, R) : U((_, g, $, k) => {
+ return $([M, g, k]), { value: f(k, O), gradFunc: I };
+ })(F, R) : U((M, g, $, k) => {
  const z = (
  // tslint:disable-next-line: no-unnecessary-type-assertion
- c.runKernel(H, j, q)
+ o.runKernel(H, v, j)
  );
- return k([_, g, z, $]), { value: f(z, O), gradFunc: I };
+ return k([M, g, z, $]), { value: f(z, O), gradFunc: I };
  })(F, R, S);
  }
- const J = /* @__PURE__ */ h({ fusedMatMul_: Ge });
+ const J = /* @__PURE__ */ h({ fusedMatMul_: Ne });
  /**
  * @license
  * Copyright 2018 Google LLC
@@ -391,12 +369,12 @@ class E extends Error {
  * https://opensource.org/licenses/MIT.
  * =============================================================================
  */
- function Ie(t, e, s, n) {
+ function Ge(t, e, s, n) {
  if (t.rank < 2 || e.rank < 2)
  throw new E(`dot requires both inputs to be rank >= 2 but got x shape = ${t.shape} and y shape = ${e.shape}`);
  if (e.rank >= 3) {
- const r = t.shape.slice(-1)[0], i = e.shape.slice(-2)[0];
- if (r !== i)
+ const r = t.shape.slice(-1)[0], c = e.shape.slice(-2)[0];
+ if (r !== c)
  throw new E(`If rank y >= 3, then the second last dim of y must equal the last dim of x but got x shape = ${t.shape} and y shape = ${e.shape}`);
  }
  if (t.rank === 2 && e.rank === 2)
@@ -409,11 +387,11 @@ function Ie(t, e, s, n) {
  activation: s
  });
  {
- const r = t.shape.slice(), i = r.pop();
- t = f(t, [-1, i]);
- const o = e.shape.slice(), p = o.pop(), a = o.pop(), l = [...o, p], D = Array.from({ length: e.rank }, (T, y) => y === 0 ? e.rank - 2 : y <= e.rank - 2 ? y - 1 : y);
- e = f(Be(e, D), [a, -1]);
- const b = [...r, ...l];
+ const r = t.shape.slice(), c = r.pop();
+ t = f(t, [-1, c]);
+ const l = e.shape.slice(), p = l.pop(), u = l.pop(), a = [...l, p], D = Array.from({ length: e.rank }, (T, y) => y === 0 ? e.rank - 2 : y <= e.rank - 2 ? y - 1 : y);
+ e = f(Re(e, D), [u, -1]);
+ const b = [...r, ...a];
  return f(J({
  a: t,
  b: e,
@@ -424,7 +402,7 @@ function Ie(t, e, s, n) {
  }), b);
  }
  }
- class Ue {
+ class Pe {
  vocabSize;
  embedDim;
  tf;
@@ -447,7 +425,7 @@ class Ue {
  return this.tf.gather(this.tiedWeights, e, 0);
  }
  project(e) {
- return Ie(e, this.tiedWeights.transpose());
+ return Ge(e, this.tiedWeights.transpose());
  }
  getWeights() {
  return [this.tiedWeights];
@@ -466,5 +444,5 @@ class Ue {
  }
  }
  export {
- Ue as default
+ Pe as default
  };
package/dist/main.js CHANGED
@@ -1,20 +1,21 @@
- import { default as r } from "./NanoGPTModel.js";
- import { default as s } from "./TeachableLLM.js";
- import { default as i } from "./tokeniser/CharTokeniser.js";
+ import { default as m } from "./NanoGPTModel.js";
+ import { default as i } from "./TeachableLLM.js";
+ import { default as l } from "./tokeniser/CharTokeniser.js";
  import { default as d } from "./utilities/waitForModel.js";
- import { default as u } from "./data/textLoader.js";
- import { estimateMemoryUsage as n, estimateParameterCount as T, estimateResources as g, estimateTrainingMemoryUsage as M, validateConfig as C } from "./utilities/parameters.js";
+ import { default as x } from "./data/textLoader.js";
+ import { estimateMemoryUsage as T, estimateParameterCount as g, estimateResources as M, estimateTrainingMemoryUsage as C, validateConfig as c } from "./utilities/parameters.js";
  import "./ops/scatterSub.js";
  import "./ops/gatherSub.js";
+ import "./ops/attentionMask.js";
  export {
- i as CharTokeniser,
- r as NanoGPT,
- s as TeachableLLM,
- n as estimateMemoryUsage,
- T as estimateParameterCount,
- g as estimateResources,
- M as estimateTrainingMemoryUsage,
- u as loadTextData,
- C as validateConfig,
+ l as CharTokeniser,
+ m as NanoGPT,
+ i as TeachableLLM,
+ T as estimateMemoryUsage,
+ g as estimateParameterCount,
+ M as estimateResources,
+ C as estimateTrainingMemoryUsage,
+ x as loadTextData,
+ c as validateConfig,
  d as waitForModel
  };
@@ -0,0 +1,27 @@
+ import { o as c, c as s, b as m, E as M, B as p } from "./index-DQfEAU9u.js";
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function b(e, o, n = !1, l = !1) {
+ let a = s(e, "a", "matMul"), t = s(o, "b", "matMul");
+ [a, t] = m(a, t);
+ const r = { a, b: t }, u = { transposeA: n, transposeB: l };
+ return M.runKernel(p, r, u);
+ }
+ const i = /* @__PURE__ */ c({ matMul_: b });
+ export {
+ i as m
+ };
@@ -0,0 +1,2 @@
+ import { Tensor } from '@tensorflow/tfjs';
+ export declare function attentionMask(q: Tensor, k: Tensor, mask: Tensor, divisor: number): Tensor;
@@ -0,0 +1,62 @@
+ import { engine as d } from "@tensorflow/tfjs";
+ import { r as k, s as u } from "../index-DQfEAU9u.js";
+ import { m as l } from "../mat_mul-CuHB58-H.js";
+ class p {
+ variableNames = ["q", "k", "mask"];
+ outputShape;
+ userCode;
+ // enableShapeUniforms = true;
+ customUniforms = [{ name: "divisor", type: "float" }];
+ constructor(t, e, n, a) {
+ this.outputShape = [t, e, n, n], this.userCode = `
+ void main() {
+ ivec4 coords = getOutputCoords(); // [batch, nh, t1, t2]
+ int b = coords.x;
+ int h = coords.y;
+ int t1 = coords.z;
+ int t2 = coords.w;
+
+ float sum = 0.0;
+ for (int i = 0; i < ${a}; ++i) {
+ float qv = getQ(b, h, t1, i);
+ float kv = getK(b, h, t2, i); // k is transposed on last two dims
+ sum += qv * kv;
+ }
+
+ // Scale by divisor
+ float scaled = sum * divisor;
+
+ // Add mask
+ float maskVal = getMask(t1, t2); // mask is [T,T]
+
+ setOutput(scaled + maskVal);
+ }
+ `;
+ }
+ }
+ function f(s) {
+ const { q: t, k: e, mask: n } = s.inputs, { divisor: a } = s.attrs, o = s.backend, c = t.shape[0], i = t.shape[2], r = t.shape[1], m = new p(c, r, i, t.shape[3]);
+ return o.runWebGLProgram(m, [t, e, n], "float32", [[a]]);
+ }
+ const h = {
+ kernelName: "AttentionMask",
+ backendName: "webgl",
+ kernelFunc: f
+ };
+ k(h);
+ function b(s) {
+ const { q: t, k: e, mask: n } = s.inputs, { divisor: a } = s.attrs, o = t.shape[2], i = l(t, e, !1, !0).mul(u(a)), r = n.slice([0, 0], [o, o]).expandDims(0).expandDims(0);
+ return i.add(r);
+ }
+ const v = {
+ kernelName: "AttentionMask",
+ backendName: "cpu",
+ kernelFunc: b
+ };
+ k(v);
+ function C(s, t, e, n) {
+ return d().runKernel("AttentionMask", { q: s, k: t, mask: e }, { divisor: n });
+ }
+ export {
+ C as attentionMask
+ };
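The new file above registers the `AttentionMask` kernel twice: a WebGL program that fuses QKᵀ, scaling, and mask addition into one shader, and a CPU fallback composed from `matMul`, then exposes `attentionMask`, which dispatches through the TFJS engine. A hedged usage sketch follows; the deep import path is illustrative (within the package the module is loaded for its registration side effect in `main.js` and for this export in the attention layer):

```ts
import * as tf from '@tensorflow/tfjs';
// Illustrative deep import path; adjust to however the build exposes this module.
import { attentionMask } from '@genai-fi/nanogpt/dist/ops/attentionMask.js';

const T = 8, nHead = 4, headDim = 16;
const q = tf.randomNormal([1, nHead, T, headDim]);
const k = tf.randomNormal([1, nHead, T, headDim]);

// Additive causal mask: 0 on/below the diagonal, -Infinity above it,
// built the same way the attention layer's constructor builds maskInf.
const lower = tf.linalg.bandPart(tf.ones([T, T]), -1, 0).cast('bool');
const maskInf = tf.where(lower, tf.zeros([T, T]), tf.fill([T, T], Number.NEGATIVE_INFINITY));

// scores = (q @ k^T) * (1 / sqrt(headDim)) + mask -> shape [1, nHead, T, T].
const scores = attentionMask(q, k, maskInf, 1 / Math.sqrt(headDim));
```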
@@ -1,6 +1,6 @@
  import { engine as l } from "@tensorflow/tfjs";
- import { o as g, c as i, E as b, G as d, r as c, a as h } from "../index-D1SlunD-.js";
- import { r as p, s as f } from "../stack-DB2YLlAs.js";
+ import { o as g, c as i, E as b, G as d, r as c, a as h } from "../index-DQfEAU9u.js";
+ import { r as p, s as f } from "../stack-C9cTkqpq.js";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
- import { r as o } from "../../index-D1SlunD-.js";
+ import { r as o } from "../../index-DQfEAU9u.js";
  function r(e) {
  const { logits: t, labels: n } = e.inputs;
  return e.backend.executeMultipleOutputs("SparseSoftmaxCrossEntropyWithLogits", [], [t, n], 2);
@@ -1,7 +1,7 @@
  import { engine as $ } from "@tensorflow/tfjs";
- import { i as u, j as S, k as h, E as f, l as E, o as N, c as l, n as y, r as p, a as D, m as x } from "../index-D1SlunD-.js";
- import { c as m } from "../complex-D6Bq1XDf.js";
- import { r as v, s as T } from "../stack-DB2YLlAs.js";
+ import { j as u, k as S, l as p, E as f, n as E, o as N, c as l, p as y, r as h, a as D, m as x } from "../index-DQfEAU9u.js";
+ import { c as m } from "../complex-CeoYJn2o.js";
+ import { r as v, s as T } from "../stack-C9cTkqpq.js";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -23,7 +23,7 @@ function i(e, t = "float32") {
  const a = i(e, "float32"), o = i(e, "float32");
  return m(a, o);
  }
- const r = S(h(e), t);
+ const r = S(p(e), t);
  return f.makeTensor(r, e, t);
  }
  /**
@@ -47,7 +47,7 @@ function d(e, t = "float32") {
  const a = d(e, "float32"), o = i(e, "float32");
  return m(a, o);
  }
- const r = E(h(e), t);
+ const r = E(p(e), t);
  return f.makeTensor(r, e, t);
  }
  function C(e, t, r) {
@@ -131,7 +131,7 @@ const K = {
  backendName: "webgl",
  kernelFunc: P
  };
- p(K);
+ h(K);
  function A(e) {
  const { logits: t, labels: r, dy: a } = e.inputs, o = r.shape[0], s = t.shape[1], n = v(0, o, 1, "int32"), c = T([n, r], 1), b = d([o]), g = I(c, b, [o, s]), k = D(t, g), w = a.reshape([o, 1]);
  return x(k, w);
@@ -141,7 +141,7 @@ const F = {
  backendName: "cpu",
  kernelFunc: A
  };
- p(F);
+ h(F);
  function M(e, t, r) {
  return $().runKernel("EfficientScatterSub", { logits: e, labels: t, dy: r }, {});
  }
@@ -1,4 +1,4 @@
- import { E as e, R as c, o as f, d as u, f as a, P as i } from "./index-D1SlunD-.js";
+ import { E as e, R as c, o as f, f as u, g as a, P as i } from "./index-DQfEAU9u.js";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -15,7 +15,7 @@ import { E as e, R as c, o as f, d as u, f as a, P as i } from "./index-D1SlunD-
  * limitations under the License.
  * =============================================================================
  */
- function g(n, s, t = 1, r = "float32") {
+ function l(n, s, t = 1, r = "float32") {
  if (t === 0)
  throw new Error("Cannot have a step of zero");
  const o = { start: n, stop: s, step: t, dtype: r };
@@ -45,6 +45,6 @@ function k(n, s = 0) {
  }
  const h = /* @__PURE__ */ f({ stack_: k });
  export {
- g as r,
+ l as r,
  h as s
  };
@@ -1,4 +1,4 @@
- import { o, c as a, E as u, g as p, h as i, S as x } from "./index-D1SlunD-.js";
+ import { o, c as a, E as u, h as i, i as p, S as x } from "./index-DQfEAU9u.js";
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -17,7 +17,7 @@ import { o, c as a, E as u, g as p, h as i, S as x } from "./index-D1SlunD-.js";
  */
  function l(n, t) {
  const s = { x: a(n, "x", "reshape", "string_or_numeric") }, r = { shape: t };
- return u.runKernel(p, s, r);
+ return u.runKernel(i, s, r);
  }
  const h = /* @__PURE__ */ o({ reshape_: l });
  /**
@@ -38,7 +38,7 @@ const h = /* @__PURE__ */ o({ reshape_: l });
  */
  function m(n, t = null, e = !1) {
  let s = a(n, "x", "sum");
- s.dtype === "bool" && (s = i(s, "int32"));
+ s.dtype === "bool" && (s = p(s, "int32"));
  const r = { x: s }, c = { axis: t, keepDims: e };
  return u.runKernel(x, r, c);
  }
@@ -1,4 +1,4 @@
- import { A as r, m as c, s as h, a as g, e as o } from "../index-D1SlunD-.js";
+ import { A as r, m as c, s as h, a as g, e as o } from "../index-DQfEAU9u.js";
  class u extends r {
  constructor(t, e, s, a, i) {
  super(t, e, s, a), this.config = i, this.startLearningRate = t;
@@ -1,7 +1,7 @@
- import { gatherSub as K } from "../ops/gatherSub.js";
- import { scatterSub as _ } from "../ops/scatterSub.js";
- import { o as l, c as d, E as f, M as G, p as z, L as I, q as N, a as E, t as M, u as T, e as m, v as S, w as $, z as g } from "../index-D1SlunD-.js";
- import { s as F, r as b } from "../sum-02UQ5Eaq.js";
+ import { gatherSub as w } from "../ops/gatherSub.js";
+ import { scatterSub as K } from "../ops/scatterSub.js";
+ import { o as l, c as d, E as f, M as _, q as z, L as I, t as N, a as E, u as M, v as T, e as m, w as g, x as $, z as S } from "../index-DQfEAU9u.js";
+ import { s as F, r as b } from "../sum-B-O33dgG.js";
  /**
  * @license
  * Copyright 2017 Google LLC. All Rights Reserved.
@@ -47,9 +47,9 @@ function q(n, s) {
  */
  function A(n, s = null, t = !1) {
  const e = { x: d(n, "x", "max") }, r = { reductionIndices: s, keepDims: t };
- return f.runKernel(G, e, r);
+ return f.runKernel(_, e, r);
  }
- const k = /* @__PURE__ */ l({ max_: A });
+ const L = /* @__PURE__ */ l({ max_: A });
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -109,7 +109,7 @@ const j = /* @__PURE__ */ l({ log_: W });
  * =============================================================================
  */
  function B(n, s = null, t = !1) {
- const a = d(n, "x", "logSumExp"), e = N(s, a.shape), r = k(
+ const a = d(n, "x", "logSumExp"), e = N(s, a.shape), r = L(
  a,
  e,
  !0
@@ -148,30 +148,30 @@ function J(n, s = -1) {
  const Q = /* @__PURE__ */ l({ softmax_: J });
  function R(n, s) {
  return $(() => {
- const t = n.shape[n.shape.length - 1], e = n.shape.slice(0, -1).reduce((h, x) => h * x, 1), r = n.shape.length > 2 ? n.reshape([e, t]) : n, c = s.shape.length > 1 ? s.reshape([e]).cast("int32") : s.cast("int32"), o = k(r, -1, !0), p = E(r, o), u = H(p, -1);
- return K(u, c, p);
+ const t = n.shape[n.shape.length - 1], e = n.shape.slice(0, -1).reduce((h, x) => h * x, 1), r = n.shape.length > 2 ? n.reshape([e, t]) : n, c = s.shape.length > 1 ? s.reshape([e]).cast("int32") : s.cast("int32"), o = L(r, -1, !0), p = E(r, o), u = H(p, -1);
+ return w(u, c, p);
  });
  }
- function Z() {
- return m().backendName === "tensorflow" ? S((s, t, a) => {
+ function ss() {
+ return m().backendName === "tensorflow" ? g((s, t, a) => {
  const e = s.shape.length > 2 ? s.reshape([-1, s.shape[s.shape.length - 1]]) : s, r = t.shape.length > 1 ? t.reshape([-1]).cast("int32") : t.cast("int32"), [c, o] = m().runKernel(
  "NativeSparseSoftmaxCrossEntropy",
  { logits: e, labels: r },
  {}
  );
- return a([o.reshape(s.shape)]), { value: c, gradFunc: (p, u) => [u[0], g(t)] };
- }) : S(
+ return a([o.reshape(s.shape)]), { value: c, gradFunc: (p, u) => [u[0], S(t)] };
+ }) : g(
  // @ts-expect-error Invalid params
  (s, t, a) => {
  const e = s.shape[s.shape.length - 1], c = s.shape.slice(0, -1).reduce((h, x) => h * x, 1), o = s.reshape([c, e]), p = t.reshape([c]).cast("int32"), u = R(o, p);
  return a([o, p]), o.dispose(), p.dispose(), { value: u, gradFunc: (h, x) => $(() => {
- const y = x[0], C = x[1], L = Q(y), v = _(L, C, h), w = g(t);
- return [v, w];
+ const k = x[0], y = x[1], C = Q(k), G = K(C, y, h), v = S(t);
+ return [G.reshape(s.shape), v];
  }) };
  }
  );
  }
  export {
- Z as createSoftmaxCrossEntropyWithGrad,
+ ss as createSoftmaxCrossEntropyWithGrad,
  R as sparseSoftmaxCrossEntropy
  };
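Aside from identifier churn, the one behavioural change visible in this last chunk is in the fallback gradient path of `createSoftmaxCrossEntropyWithGrad`: the logits gradient produced by `scatterSub` is now reshaped back to the caller's original logits shape (`G.reshape(s.shape)`) instead of being returned in its flattened two-dimensional form. A small self-contained illustration of that reshape (tensor names are illustrative stand-ins, not the package's internals):

```ts
import * as tf from '@tensorflow/tfjs';

// The fix in miniature: the gradient is computed on flattened [batch * T, vocab]
// logits, then restored to the caller's rank-3 logits shape before being returned.
const logits = tf.randomNormal([2, 4, 10]);      // [batch, T, vocab]
const flatGrad = tf.randomNormal([8, 10]);       // stand-in for the flattened gradient
const dLogits = flatGrad.reshape(logits.shape);  // [2, 4, 10], as 0.2.8 now returns
```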
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "@genai-fi/nanogpt",
- "version": "0.2.6",
+ "version": "0.2.8",
  "type": "module",
  "main": "dist/main.js",
  "types": "dist/main.d.ts",