@genai-fi/nanogpt 0.4.2 → 0.4.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94) hide show
  1. package/dist/Generator.js +3 -3
  2. package/dist/NanoGPTModel.js +73 -76
  3. package/dist/Reshape-CiAY8ltP.js +212 -0
  4. package/dist/TeachableLLM.js +7 -1
  5. package/dist/{TiedEmbedding-CnJ1bx4q.js → TiedEmbedding-DznFwzcB.js} +244 -244
  6. package/dist/{axis_util-BgTGy5w8.js → axis_util-QP0LdI1v.js} +1 -1
  7. package/dist/{concat-CuRsVY-K.js → concat-DvWM7HGZ.js} +1 -1
  8. package/dist/data/parquet.js +9 -6
  9. package/dist/data/textLoader.js +6 -5
  10. package/dist/{dropout-DfDdklfL.js → dropout-DFEXTPV0.js} +4 -4
  11. package/dist/{gather-ZYRWhmXR.js → gather-C5D8PxwA.js} +1 -1
  12. package/dist/gpgpu_math-CUzjlO9A.js +23 -0
  13. package/dist/{index-C4JCoBvj.js → index--6vO-cOz.js} +87 -87
  14. package/dist/{kernel_funcs_utils-CAd1h9X1.js → kernel_funcs_utils-C6YBCuOt.js} +72 -91
  15. package/dist/layers/CausalSelfAttention.js +44 -44
  16. package/dist/layers/MLP.js +31 -33
  17. package/dist/layers/RMSNorm.js +3 -3
  18. package/dist/layers/RoPECache.js +3 -3
  19. package/dist/layers/TiedEmbedding.js +5 -5
  20. package/dist/layers/TransformerBlock.js +1 -1
  21. package/dist/{log_sum_exp-BswFnwOb.js → log_sum_exp-CiEy1aUe.js} +7 -7
  22. package/dist/main.js +25 -19
  23. package/dist/{mat_mul-415y5Qn2.js → mat_mul-BEHRPMh0.js} +1 -1
  24. package/dist/{max-CP_9O2Yd.js → max-BUShNgfh.js} +1 -1
  25. package/dist/{moments-CjeIaVdp.js → moments-DYOHXoRV.js} +5 -5
  26. package/dist/{norm-CZM380I3.js → norm-DSva3hI3.js} +13 -13
  27. package/dist/{ones-Bf3YR48P.js → ones-D6kB8bdY.js} +2 -2
  28. package/dist/ops/appendCache.d.ts +1 -1
  29. package/dist/ops/appendCache.js +10 -4
  30. package/dist/ops/attentionMask.js +1 -1
  31. package/dist/ops/cpu/appendCache.d.ts +1 -2
  32. package/dist/ops/cpu/appendCache.js +15 -20
  33. package/dist/ops/cpu/attentionMask.js +10 -10
  34. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  35. package/dist/ops/cpu/gatherSub.js +4 -4
  36. package/dist/ops/cpu/gelu.js +1 -1
  37. package/dist/ops/cpu/matMulGelu.d.ts +1 -0
  38. package/dist/ops/cpu/matMulGelu.js +40 -0
  39. package/dist/ops/cpu/mulDropout.js +1 -1
  40. package/dist/ops/cpu/qkv.js +3 -3
  41. package/dist/ops/cpu/rope.js +5 -5
  42. package/dist/ops/cpu/scatterSub.js +4 -4
  43. package/dist/ops/fusedSoftmax.js +1 -1
  44. package/dist/ops/gatherSub.js +1 -1
  45. package/dist/ops/gelu.js +2 -2
  46. package/dist/ops/grads/attentionMask.js +1 -1
  47. package/dist/ops/grads/fusedSoftmax.js +2 -2
  48. package/dist/ops/grads/gelu.js +24 -3
  49. package/dist/ops/grads/matMulGelu.d.ts +1 -0
  50. package/dist/ops/grads/matMulGelu.js +17 -0
  51. package/dist/ops/grads/qkv.js +1 -1
  52. package/dist/ops/grads/rope.js +1 -1
  53. package/dist/ops/matMulGelu.d.ts +3 -0
  54. package/dist/ops/matMulGelu.js +14 -0
  55. package/dist/ops/mulDrop.js +1 -1
  56. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  57. package/dist/ops/qkv.js +1 -1
  58. package/dist/ops/scatterSub.js +1 -1
  59. package/dist/ops/webgl/appendCache.js +14 -13
  60. package/dist/ops/webgl/attentionMask.js +1 -1
  61. package/dist/ops/webgl/fusedSoftmax.js +689 -895
  62. package/dist/ops/webgl/gatherSub.js +1 -1
  63. package/dist/ops/webgl/gelu.js +2 -2
  64. package/dist/ops/webgl/matMulGelu.d.ts +20 -0
  65. package/dist/ops/webgl/matMulGelu.js +166 -0
  66. package/dist/ops/webgl/mulDropout.js +1 -1
  67. package/dist/ops/webgl/qkv.js +1 -1
  68. package/dist/ops/webgl/rope.js +1 -1
  69. package/dist/ops/webgl/scatterSub.js +1 -1
  70. package/dist/{range-9AzeApCc.js → range-C_vpUjBu.js} +1 -1
  71. package/dist/{reshape-Boe4DuIO.js → reshape-z51Eu-re.js} +1 -1
  72. package/dist/{sin-KmhiDuMa.js → sin-H567uayl.js} +1 -1
  73. package/dist/{slice_util-19zDNNSn.js → slice_util-BdhYwFY_.js} +2 -2
  74. package/dist/{softmax-Cujsg4ay.js → softmax-Dsxflvdl.js} +1 -1
  75. package/dist/{split-DbcNm1-i.js → split-B_k_jwud.js} +1 -1
  76. package/dist/{stack-D1YjmgKN.js → stack-CmqSdsfs.js} +1 -1
  77. package/dist/{sum-R28pucR5.js → sum-DdkDf2MG.js} +1 -1
  78. package/dist/{tensor-BVeHdl7V.js → tensor-BGYi41cj.js} +1 -1
  79. package/dist/{tensor2d-DqFGNs_K.js → tensor2d-DUr_htjt.js} +1 -1
  80. package/dist/{tfjs_backend-Cug-PH75.js → tfjs_backend-DuKis_xG.js} +46 -46
  81. package/dist/training/AdamExt.js +1 -1
  82. package/dist/training/DatasetBuilder.js +18 -18
  83. package/dist/training/FullTrainer.js +1 -1
  84. package/dist/training/Trainer.js +5 -5
  85. package/dist/training/sparseCrossEntropy.js +4 -4
  86. package/dist/utilities/dummy.js +2 -2
  87. package/dist/utilities/generate.js +3 -3
  88. package/dist/utilities/load.js +1 -1
  89. package/dist/utilities/profile.js +1 -1
  90. package/dist/utilities/weights.js +2 -2
  91. package/dist/{variable-LJT9Ld63.js → variable-BJTZ3jOy.js} +1 -1
  92. package/dist/{zeros-dnQxFgAD.js → zeros-8xl-W2DC.js} +1 -1
  93. package/package.json +1 -1
  94. package/dist/gelu-CnCt17Lk.js +0 -26
package/dist/Generator.js CHANGED
@@ -1,7 +1,7 @@
1
1
  import { E as u } from "./index-Dwqa6Zy2.js";
2
- import "./index-C4JCoBvj.js";
3
- import { t as d } from "./tensor2d-DqFGNs_K.js";
4
- import { c as p } from "./concat-CuRsVY-K.js";
2
+ import "./index--6vO-cOz.js";
3
+ import { t as d } from "./tensor2d-DUr_htjt.js";
4
+ import { c as p } from "./concat-DvWM7HGZ.js";
5
5
  class w extends u {
6
6
  constructor(s, e) {
7
7
  super(), this.model = s, this.tokeniser = e;
@@ -1,17 +1,17 @@
1
1
  import { defaultConfig as x } from "./config.js";
2
2
  import W from "./layers/TransformerBlock.js";
3
- import { E as F, D as P, T as q, r as K, p as T } from "./TiedEmbedding-CnJ1bx4q.js";
4
- import D from "./layers/RoPECache.js";
3
+ import { E as F, D as P, T as q, r as T, p as D } from "./TiedEmbedding-DznFwzcB.js";
4
+ import K from "./layers/RoPECache.js";
5
5
  import N from "./layers/RMSNorm.js";
6
6
  import { estimateParameterCount as R } from "./utilities/parameters.js";
7
7
  import { createSoftmaxCrossEntropyWithGrad as A } from "./training/sparseCrossEntropy.js";
8
8
  import B from "./layers/BaseLayer.js";
9
- import { o as y, h as E, p as G, E as v, a6 as O, a7 as Q, a8 as j, t as w, a5 as U, f as C } from "./index-C4JCoBvj.js";
10
- import { r as _ } from "./reshape-Boe4DuIO.js";
11
- import { r as V } from "./range-9AzeApCc.js";
12
- import { e as X } from "./tfjs_backend-Cug-PH75.js";
13
- import { g as H } from "./gather-ZYRWhmXR.js";
14
- import { s as J } from "./softmax-Cujsg4ay.js";
9
+ import { o as $, h as E, p as G, E as v, a9 as O, aa as j, ab as Q, t as w, a8 as V, f as C } from "./index--6vO-cOz.js";
10
+ import { r as _ } from "./reshape-z51Eu-re.js";
11
+ import { r as X } from "./range-C_vpUjBu.js";
12
+ import { e as H } from "./tfjs_backend-DuKis_xG.js";
13
+ import { g as J } from "./gather-C5D8PxwA.js";
14
+ import { s as U } from "./softmax-Dsxflvdl.js";
15
15
  /**
16
16
  * @license
17
17
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -28,13 +28,13 @@ import { s as J } from "./softmax-Cujsg4ay.js";
28
28
  * limitations under the License.
29
29
  * =============================================================================
30
30
  */
31
- function Y(h, t) {
32
- let e = E(h, "a", "mod"), o = E(t, "b", "mod");
31
+ function Y(c, t) {
32
+ let e = E(c, "a", "mod"), o = E(t, "b", "mod");
33
33
  [e, o] = G(e, o);
34
- const n = { a: e, b: o };
35
- return v.runKernel(O, n);
34
+ const i = { a: e, b: o };
35
+ return v.runKernel(O, i);
36
36
  }
37
- const Z = /* @__PURE__ */ y({ mod_: Y });
37
+ const Z = /* @__PURE__ */ $({ mod_: Y });
38
38
  /**
39
39
  * @license
40
40
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -51,17 +51,17 @@ const Z = /* @__PURE__ */ y({ mod_: Y });
51
51
  * limitations under the License.
52
52
  * =============================================================================
53
53
  */
54
- function tt(h, t, e, o = !1) {
55
- const n = E(h, "logits", "multinomial"), s = n.size, r = n.rank;
54
+ function tt(c, t, e, o = !1) {
55
+ const i = E(c, "logits", "multinomial"), s = i.size, l = i.rank;
56
56
  if (s < 2)
57
57
  throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
58
- if (r > 2)
59
- throw new Error(`Rank of probabilities must be 1 or 2, but is ${r}`);
58
+ if (l > 2)
59
+ throw new Error(`Rank of probabilities must be 1 or 2, but is ${l}`);
60
60
  e = e || Math.random();
61
- const i = { logits: r === 1 ? _(n, [1, -1]) : n }, p = { numSamples: t, seed: e, normalized: o }, l = v.runKernel(Q, i, p);
62
- return r === 1 ? _(l, [l.size]) : l;
61
+ const n = { logits: l === 1 ? _(i, [1, -1]) : i }, h = { numSamples: t, seed: e, normalized: o }, a = v.runKernel(j, n, h);
62
+ return l === 1 ? _(a, [a.size]) : a;
63
63
  }
64
- const M = /* @__PURE__ */ y({ multinomial_: tt });
64
+ const M = /* @__PURE__ */ $({ multinomial_: tt });
65
65
  /**
66
66
  * @license
67
67
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -78,19 +78,19 @@ const M = /* @__PURE__ */ y({ multinomial_: tt });
78
78
  * limitations under the License.
79
79
  * =============================================================================
80
80
  */
81
- function et(h, t = 1, e = !0) {
82
- const o = E(h, "x", "topk");
81
+ function et(c, t = 1, e = !0) {
82
+ const o = E(c, "x", "topk");
83
83
  if (o.rank === 0)
84
84
  throw new Error("topk() expects the input to be of rank 1 or higher");
85
- const n = o.shape[o.shape.length - 1];
85
+ const i = o.shape[o.shape.length - 1];
86
86
  if (t < 0)
87
87
  throw new Error(`'k' passed to topk() must be >= 0 but got ${t}`);
88
- if (t > n)
89
- throw new Error(`'k' passed to topk() must be <= the last dimension (${n}) but got ${t}`);
90
- const s = { x: o }, r = { k: t, sorted: e }, [a, i] = v.runKernel(j, s, r);
91
- return { values: a, indices: i };
88
+ if (t > i)
89
+ throw new Error(`'k' passed to topk() must be <= the last dimension (${i}) but got ${t}`);
90
+ const s = { x: o }, l = { k: t, sorted: e }, [r, n] = v.runKernel(Q, s, l);
91
+ return { values: r, indices: n };
92
92
  }
93
- const ot = /* @__PURE__ */ y({ topk_: et });
93
+ const ot = /* @__PURE__ */ $({ topk_: et });
94
94
  /**
95
95
  * @license
96
96
  * Copyright 2018 Google LLC
@@ -100,11 +100,11 @@ const ot = /* @__PURE__ */ y({ topk_: et });
100
100
  * https://opensource.org/licenses/MIT.
101
101
  * =============================================================================
102
102
  */
103
- function st(h) {
104
- return new P(h);
103
+ function st(c) {
104
+ return new P(c);
105
105
  }
106
- function nt(h) {
107
- return new F(h);
106
+ function it(c) {
107
+ return new F(c);
108
108
  }
109
109
  class wt extends B {
110
110
  wte;
@@ -124,12 +124,12 @@ class wt extends B {
124
124
  vocabSize: this.config.gpt.vocabSize,
125
125
  embedDim: this.config.gpt.nEmbed,
126
126
  name: "token_embedding"
127
- }), this.config.gpt.useRope === !1 ? this.wpe = nt({
127
+ }), this.config.gpt.useRope === !1 ? this.wpe = it({
128
128
  inputDim: this.config.gpt.blockSize,
129
129
  outputDim: this.config.gpt.nEmbed,
130
130
  name: "positional_embedding",
131
- embeddingsInitializer: K({ mean: 0, stddev: 0.02 })
132
- }) : (this.ropeCache = new D(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = st({ rate: this.config.gpt.dropout }), this.blocks = [];
131
+ embeddingsInitializer: T({ mean: 0, stddev: 0.02 })
132
+ }) : (this.ropeCache = new K(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = st({ rate: this.config.gpt.dropout }), this.blocks = [];
133
133
  for (let e = 0; e < this.config.gpt.nLayer; e++)
134
134
  this.blocks.push(new W(e, this.config));
135
135
  this.lnF = new N(this.config, 1e-8, "final_rms_norm");
@@ -163,12 +163,12 @@ class wt extends B {
163
163
  }
164
164
  inputPhase(t, e, o = !1) {
165
165
  return w(() => {
166
- const n = this.wte.embed(t);
166
+ const i = this.wte.embed(t);
167
167
  if (this.config.gpt.useRope === !1) {
168
- const [, s] = t.shape, r = this.config.gpt.blockSize, a = V(0, s, 1, "int32"), i = Z(U(a, C(e, "int32")), C(r, "int32")), p = this.wpe.apply(i), l = n.add(p);
169
- return this.drop.apply(l, { training: o });
168
+ const [, s] = t.shape, l = this.config.gpt.blockSize, r = X(0, s, 1, "int32"), n = Z(V(r, C(e, "int32")), C(l, "int32")), h = this.wpe.apply(n), a = i.add(h);
169
+ return this.drop.apply(a, { training: o });
170
170
  } else
171
- return this.drop.apply(n, { training: o });
171
+ return this.drop.apply(i, { training: o });
172
172
  });
173
173
  }
174
174
  setSkipMask(t) {
@@ -209,67 +209,64 @@ class wt extends B {
209
209
  return w(() => {
210
210
  if (t.length === 0)
211
211
  throw new Error("No attentions for rollout");
212
- const [e, o, n] = t[0].shape;
213
- for (const s of t) {
214
- const [r, a, i] = s.shape;
215
- if (r !== e || a !== o || i !== n)
212
+ const [e, o, i] = t[0].shape;
213
+ for (const n of t) {
214
+ const [h, a, p] = n.shape;
215
+ if (h !== e || a !== o || p !== i)
216
216
  throw new Error(
217
- `Inconsistent attention shapes in rollout: expected [${e},${o},${n}] got [${r},${a},${i}]`
217
+ `Inconsistent attention shapes in rollout: expected [${e},${o},${i}] got [${h},${a},${p}]`
218
218
  );
219
219
  }
220
- if (o === n) {
221
- const s = X(n, n).expandDims(0);
222
- let r = s.tile([e, 1, 1]);
223
- for (const a of t) {
224
- const i = a.add(s);
225
- r = i.div(i.sum(-1, !0)).matMul(r);
226
- }
227
- return r;
220
+ const s = t.map((n) => n.slice([0, 0, 0], [e, o, o])), l = H(o, o).expandDims(0);
221
+ let r = l.tile([e, 1, 1]);
222
+ for (const n of s) {
223
+ const h = n.add(l);
224
+ r = h.div(h.sum(-1, !0)).matMul(r);
228
225
  }
229
- throw new Error(`Unsupported attention shapes for rollout: [B=${e}, Q=${o}, K=${n}]`);
226
+ return r;
230
227
  });
231
228
  }
232
- forward(t, e, o = !1, n = !1, s) {
229
+ forward(t, e, o = !1, i = !1, s) {
233
230
  return this.validateInput(t), w(() => {
234
231
  this.startMemory();
235
- const r = s?.[0]?.length ?? 0;
236
- let a = this.inputPhase(t, r, o);
237
- const i = [];
232
+ const l = s?.[0]?.length ?? 0;
233
+ let r = this.inputPhase(t, l, o);
234
+ const n = [];
238
235
  if (s && s.length !== this.blocks.length)
239
236
  throw console.error("Cache", s), new Error(`Cache length ${s.length} does not match number of blocks ${this.blocks.length}`);
240
- for (let c = 0; c < this.blocks.length; c++) {
241
- const u = a, m = this.blocks[c], {
237
+ for (let g = 0; g < this.blocks.length; g++) {
238
+ const u = r, m = this.blocks[g], {
242
239
  output: b,
243
240
  attention: k,
244
241
  cache: f
245
- } = m.call(a, o, n, s ? s[c] : void 0);
246
- a = b, u.dispose(), n && k && i.push(k), s && f ? (s[c]?.k.dispose(), s[c]?.v.dispose(), s[c] = f) : f && (f.k.dispose(), f.v.dispose());
242
+ } = m.call(r, o, i, s ? s[g] : void 0);
243
+ r = b, u.dispose(), i && k && n.push(k), s && f ? (s[g]?.k.dispose(), s[g]?.v.dispose(), s[g] = f) : f && (f.k.dispose(), f.v.dispose());
247
244
  }
245
+ let h;
246
+ i && n.length > 0 && (h = this.computeAttentionRollout(n)), r = this.lnF.apply(r);
247
+ const a = this.wte.project(r);
248
248
  let p;
249
- n && i.length > 0 && (p = this.computeAttentionRollout(i)), a = this.lnF.apply(a);
250
- const l = this.wte.project(a);
251
- let g;
252
- return e && (g = this.calculateLoss(l, e)), this.endMemory("Forward"), { logits: l, loss: g, attention: n ? p : void 0 };
249
+ return e && (p = this.calculateLoss(a, e)), this.endMemory("Forward"), { logits: a, loss: p, attention: i ? h : void 0 };
253
250
  });
254
251
  }
255
252
  generate(t, e, o) {
256
- const n = o?.temperature ?? 1, s = o?.topK, r = o?.usePadding ?? !1, a = o?.includeAttention ?? !1;
253
+ const i = o?.temperature ?? 1, s = o?.topK, l = o?.usePadding ?? !1, r = o?.includeAttention ?? !1;
257
254
  return w(() => {
258
- const i = t, p = i.shape[1], l = p <= this.config.gpt.blockSize ? i : i.slice(
259
- [0, p - this.config.gpt.blockSize],
260
- [i.shape[0], this.config.gpt.blockSize]
261
- ), g = r ? this.config.gpt.blockSize - l.shape[1] : 0, c = g > 0 ? T(l, [
255
+ const n = t, h = n.shape[1], a = h <= this.config.gpt.blockSize ? n : n.slice(
256
+ [0, h - this.config.gpt.blockSize],
257
+ [n.shape[0], this.config.gpt.blockSize]
258
+ ), p = l ? this.config.gpt.blockSize - a.shape[1] : 0, g = p > 0 ? D(a, [
262
259
  [0, 0],
263
- [0, g]
264
- ]) : l, { logits: u, attention: m } = this.forward(c, void 0, !1, a, e), b = u.shape[1] - 1 - g, k = u.slice([0, b, 0], [u.shape[0], 1, u.shape[2]]), f = m ? m.slice([0, b, 0], [m.shape[0], 1, m.shape[2]]) : void 0, $ = k.div(n);
260
+ [0, p]
261
+ ]) : a, { logits: u, attention: m } = this.forward(g, void 0, !1, r, e), b = u.shape[1] - 1 - p, k = u.slice([0, b, 0], [u.shape[0], 1, u.shape[2]]), f = m ? m.slice([0, b, 0], [m.shape[0], 1, m.shape[2]]) : void 0, y = k.div(i);
265
262
  let d;
266
263
  if (s) {
267
- const { values: S, indices: I } = ot($, s), L = M(S.squeeze([1]), 1);
268
- d = H(I.squeeze([1]), L, 1);
264
+ const { values: S, indices: I } = ot(y, s), L = M(S.squeeze([1]), 1);
265
+ d = J(I.squeeze([1]), L, 1);
269
266
  } else
270
- d = M($.squeeze([1]), 1);
267
+ d = M(y.squeeze([1]), 1);
271
268
  let z;
272
- return o?.includeProbabilities && (z = J($.squeeze([1]))), d = d.reshape([1, 1]), { output: d, attention: f?.squeeze([1]), probabilities: z };
269
+ return o?.includeProbabilities && (z = U(y.squeeze([1]))), d = d.reshape([1, 1]), { output: d, attention: f?.squeeze([1]), probabilities: z };
273
270
  });
274
271
  }
275
272
  getNumParams() {
@@ -0,0 +1,212 @@
1
+ import { ac as f, ad as g, n as p, ae as C, j as x } from "./index--6vO-cOz.js";
2
+ import { u as I } from "./gpgpu_math-CUzjlO9A.js";
3
+ /**
4
+ * @license
5
+ * Copyright 2018 Google LLC. All Rights Reserved.
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ * =============================================================================
18
+ */
19
+ function R(t, e, o = "index") {
20
+ const s = f(e);
21
+ return s.map((n, r) => {
22
+ const i = `int ${t[r]} = ${o} / ${n}`, u = r === s.length - 1 ? `int ${t[r + 1]} = ${o} - ${t[r]} * ${n}` : `index -= ${t[r]} * ${n}`;
23
+ return `${i}; ${u};`;
24
+ }).join("");
25
+ }
26
+ function y(t, e) {
27
+ const o = t.length, s = t.map((r) => `${e}[${r}]`), n = new Array(o - 1);
28
+ n[o - 2] = s[o - 1];
29
+ for (let r = o - 3; r >= 0; --r)
30
+ n[r] = `(${n[r + 1]} * ${s[r + 1]})`;
31
+ return n;
32
+ }
33
+ function S(t, e, o = "index") {
34
+ const s = t.map((r, i) => i), n = y(s, e);
35
+ return n.map((r, i) => {
36
+ const u = `int ${t[i]} = ${o} / ${n[i]}`, a = i === n.length - 1 ? `int ${t[i + 1]} = ${o} - ${t[i]} * ${n[i]}` : `index -= ${t[i]} * ${n[i]}`;
37
+ return `${u}; ${a};`;
38
+ }).join("");
39
+ }
40
+ function F(t) {
41
+ const e = f(t).map((o) => o.toString());
42
+ return `
43
+ int getFlatIndex(ivec3 coords) {
44
+ return coords.x * ${e[0]} + coords.y * ${e[1]} + coords.z;
45
+ }
46
+ `;
47
+ }
48
+ function v() {
49
+ return `
50
+ int getFlatIndex(ivec3 coords) {
51
+ return coords.x * outShapeStrides[0] + coords.y * outShapeStrides[1] + coords.z;
52
+ }
53
+ `;
54
+ }
55
+ /**
56
+ * @license
57
+ * Copyright 2017 Google LLC. All Rights Reserved.
58
+ * Licensed under the Apache License, Version 2.0 (the "License");
59
+ * you may not use this file except in compliance with the License.
60
+ * You may obtain a copy of the License at
61
+ *
62
+ * http://www.apache.org/licenses/LICENSE-2.0
63
+ *
64
+ * Unless required by applicable law or agreed to in writing, software
65
+ * distributed under the License is distributed on an "AS IS" BASIS,
66
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
67
+ * See the License for the specific language governing permissions and
68
+ * limitations under the License.
69
+ * =============================================================================
70
+ */
71
+ function h(t, e = 2) {
72
+ return p(t.slice(0, t.length - e));
73
+ }
74
+ function m(t) {
75
+ if (t.length === 0)
76
+ throw Error("Cannot get rows and columns of an empty shape array.");
77
+ return [
78
+ t.length > 1 ? t[t.length - 2] : 1,
79
+ t[t.length - 1]
80
+ ];
81
+ }
82
+ function d(t) {
83
+ return t % 2 === 0;
84
+ }
85
+ function $(t, e) {
86
+ if (t = t.slice(-2), e = e.slice(-2), g(t, e) || !t.length || !e.length || t[0] === 0 || t[1] === 0 || e[0] === 0 || e[1] === 0)
87
+ return !0;
88
+ if (t.length !== e.length) {
89
+ const o = t[t.length - 1], s = e[e.length - 1];
90
+ if (o === s || d(o) && d(s) && (t[0] === 1 || e[0] === 1))
91
+ return !0;
92
+ }
93
+ return t[1] === e[1] && d(t[0]) && d(e[0]);
94
+ }
95
+ /**
96
+ * @license
97
+ * Copyright 2018 Google LLC. All Rights Reserved.
98
+ * Licensed under the Apache License, Version 2.0 (the "License");
99
+ * you may not use this file except in compliance with the License.
100
+ * You may obtain a copy of the License at
101
+ *
102
+ * http://www.apache.org/licenses/LICENSE-2.0
103
+ *
104
+ * Unless required by applicable law or agreed to in writing, software
105
+ * distributed under the License is distributed on an "AS IS" BASIS,
106
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
107
+ * See the License for the specific language governing permissions and
108
+ * limitations under the License.
109
+ * =============================================================================
110
+ */
111
+ class b {
112
+ constructor(e, o) {
113
+ this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0, this.customUniforms = [{ name: "inputShape", type: "ivec3" }], this.outputShape = e, this.enableShapeUniforms = I(this.outputShape.length);
114
+ let s = "";
115
+ for (let n = 0; n < 4; n++) {
116
+ let r = "thisRC = rc;";
117
+ n % 2 === 1 && (r += "thisRC.z += 1;"), n > 1 && (r += "thisRC.y += 1;"), s += `
118
+ ${r}
119
+ ${n > 0 ? "if(thisRC.y < rows && thisRC.z < cols){" : ""}
120
+ int flatIndex = getFlatIndex(thisRC);
121
+
122
+ ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);
123
+ vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));
124
+
125
+ result[${n}] =
126
+ getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);
127
+ ${n > 0 ? "}" : ""}
128
+ `;
129
+ }
130
+ this.userCode = `
131
+ ${w(o, this.enableShapeUniforms)}
132
+ ${this.enableShapeUniforms ? v() : F(e)}
133
+
134
+ void main() {
135
+ ivec3 rc = getOutputCoords();
136
+
137
+ vec4 result = vec4(0.);
138
+
139
+ ivec3 thisRC;
140
+ int rows = ${this.enableShapeUniforms ? "outShape[1]" : e[1]};
141
+ int cols = ${this.enableShapeUniforms ? "outShape[2]" : e[2]};
142
+
143
+ ${s}
144
+
145
+ setOutput(result);
146
+ }
147
+ `;
148
+ }
149
+ }
150
+ function w(t, e) {
151
+ return `
152
+ ivec3 inputCoordsFromReshapedOutCoords(int index) {
153
+ ${e ? S(["r", "c", "d"], "inputShape") : R(["r", "c", "d"], t)}
154
+ return ivec3(r, c, d);
155
+ }
156
+ `;
157
+ }
158
+ /**
159
+ * @license
160
+ * Copyright 2020 Google LLC. All Rights Reserved.
161
+ * Licensed under the Apache License, Version 2.0 (the "License");
162
+ * you may not use this file except in compliance with the License.
163
+ * You may obtain a copy of the License at
164
+ *
165
+ * http://www.apache.org/licenses/LICENSE-2.0
166
+ *
167
+ * Unless required by applicable law or agreed to in writing, software
168
+ * distributed under the License is distributed on an "AS IS" BASIS,
169
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
170
+ * See the License for the specific language governing permissions and
171
+ * limitations under the License.
172
+ * =============================================================================
173
+ */
174
+ function D(t, e, o) {
175
+ const s = [
176
+ h(t.shape),
177
+ ...m(t.shape)
178
+ ], n = {
179
+ dtype: t.dtype,
180
+ shape: s,
181
+ dataId: t.dataId
182
+ }, r = [
183
+ h(e),
184
+ ...m(e)
185
+ ], i = new b(r, s), u = !0, a = [s], c = o.runWebGLProgram(i, [n], t.dtype, a, u);
186
+ return { dataId: c.dataId, shape: e, dtype: c.dtype };
187
+ }
188
+ /**
189
+ * @license
190
+ * Copyright 2020 Google LLC. All Rights Reserved.
191
+ * Licensed under the Apache License, Version 2.0 (the "License");
192
+ * you may not use this file except in compliance with the License.
193
+ * You may obtain a copy of the License at
194
+ *
195
+ * http://www.apache.org/licenses/LICENSE-2.0
196
+ *
197
+ * Unless required by applicable law or agreed to in writing, software
198
+ * distributed under the License is distributed on an "AS IS" BASIS,
199
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ * See the License for the specific language governing permissions and
201
+ * limitations under the License.
202
+ * =============================================================================
203
+ */
204
+ function k(t) {
205
+ const { inputs: e, backend: o, attrs: s } = t, { x: n } = e, { shape: r } = s, i = o, u = p(n.shape), a = C(r, u), c = p(a);
206
+ x(u === c, () => `The new shape (${a}) has ${c} elements and the old shape (${n.shape}) has ${u} elements. The new shape and old shape must have the same number of elements.`);
207
+ const l = i.texData.get(n.dataId);
208
+ return l.isPacked && !$(n.shape, a) && !(l.texture !== null && $(l.shape, a)) ? D(n, a, i) : (i.incRef(n.dataId), { dataId: n.dataId, shape: a, dtype: n.dtype });
209
+ }
210
+ export {
211
+ k as r
212
+ };
@@ -11,7 +11,7 @@ import g from "./tokeniser/bpe.js";
11
11
  import "./papaparse.min-C8l2Kvo1.js";
12
12
  import "./index-Tf7vU29b.js";
13
13
  import "./jszip.min-CjP2V1VV.js";
14
- import "./index-C4JCoBvj.js";
14
+ import "./index--6vO-cOz.js";
15
15
  import "./ops/cpu/scatterSub.js";
16
16
  import "./ops/webgl/scatterSub.js";
17
17
  import "./ops/cpu/gatherSub.js";
@@ -31,6 +31,12 @@ import "./ops/webgl/appendCache.js";
31
31
  import "./ops/cpu/fusedSoftmax.js";
32
32
  import "./ops/webgl/fusedSoftmax.js";
33
33
  import "./ops/grads/fusedSoftmax.js";
34
+ import "./ops/cpu/matMulGelu.js";
35
+ import "./ops/webgl/matMulGelu.js";
36
+ import "./ops/grads/matMulGelu.js";
37
+ import "./ops/cpu/gelu.js";
38
+ import "./ops/webgl/gelu.js";
39
+ import "./ops/grads/gelu.js";
34
40
  import w from "./utilities/profile.js";
35
41
  class a {
36
42
  ee = new p();