@genai-fi/nanogpt 0.7.3 → 0.8.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197) hide show
  1. package/dist/Generator.d.ts +25 -2
  2. package/dist/Generator.js +152 -49
  3. package/dist/{RealDiv-Dy0p8Bvo.js → RealDiv-D_q39E3A.js} +13 -13
  4. package/dist/{Reshape-DvudQDvJ.js → Reshape-41YpQqEo.js} +1 -1
  5. package/dist/{Reshape-DH5srBP0.js → Reshape-Bh_jzKzV.js} +5 -5
  6. package/dist/TeachableLLM.d.ts +6 -6
  7. package/dist/TeachableLLM.js +33 -31
  8. package/dist/Trainer.d.ts +13 -2
  9. package/dist/Trainer.js +21 -12
  10. package/dist/{axis_util-BzbKo31C.js → axis_util-Did9235A.js} +3 -3
  11. package/dist/backend.js +2 -2
  12. package/dist/{backend_util-TE7aTPhZ.js → backend_util-yC3YH1jo.js} +58 -58
  13. package/dist/{broadcast_to-CdbwV-Dj.js → broadcast_to-CUvOdOT5.js} +2 -2
  14. package/dist/checks/appendCache.d.ts +1 -0
  15. package/dist/checks/appendCache.js +22 -0
  16. package/dist/checks/attentionMask.d.ts +1 -0
  17. package/dist/checks/attentionMask.js +37 -0
  18. package/dist/checks/check.d.ts +9 -0
  19. package/dist/checks/check.js +20 -0
  20. package/dist/checks/gelu.d.ts +1 -0
  21. package/dist/checks/gelu.js +18 -0
  22. package/dist/checks/index.d.ts +19 -0
  23. package/dist/checks/index.js +21 -0
  24. package/dist/checks/normRMS.d.ts +1 -0
  25. package/dist/checks/normRMS.js +16 -0
  26. package/dist/checks/normRMSGrad.d.ts +1 -0
  27. package/dist/checks/normRMSGrad.js +12 -0
  28. package/dist/checks/qkv.d.ts +1 -0
  29. package/dist/checks/qkv.js +25 -0
  30. package/dist/checks/rope.d.ts +1 -0
  31. package/dist/checks/rope.js +21 -0
  32. package/dist/{concat-CsxrgovM.js → concat-pHiVqR3L.js} +1 -1
  33. package/dist/{dataset-CtdBYwjo.js → dataset-DPPl-iLT.js} +9 -9
  34. package/dist/{dropout-DYs5QFGQ.js → dropout-CcKSfOYE.js} +18 -18
  35. package/dist/exports_initializers-DKk7-bsx.js +16 -0
  36. package/dist/{gather-CMMy2KEG.js → gather-CPg6ZlQA.js} +1 -1
  37. package/dist/{gelu-C-dPj6Ku.js → gelu-BkcmEEyD.js} +1 -1
  38. package/dist/{gpgpu_math-DGNLNL4I.js → gpgpu_math-D_ODOLix.js} +26 -26
  39. package/dist/{index-BoWRt-10.js → index-DdmHGZjq.js} +659 -650
  40. package/dist/{index-CLthM0TO.js → index-evZ57wr4.js} +185 -185
  41. package/dist/{kernel_funcs_utils-BYKWV8Aa.js → kernel_funcs_utils-CDfFpUab.js} +21 -21
  42. package/dist/layers/BaseLayer.d.ts +8 -13
  43. package/dist/layers/BaseLayer.js +25 -13
  44. package/dist/layers/CausalSelfAttention.d.ts +3 -2
  45. package/dist/layers/CausalSelfAttention.js +28 -28
  46. package/dist/layers/MLP.d.ts +3 -2
  47. package/dist/layers/MLP.js +16 -20
  48. package/dist/layers/PositionEmbedding.d.ts +9 -0
  49. package/dist/layers/PositionEmbedding.js +45 -0
  50. package/dist/layers/RMSNorm.d.ts +3 -2
  51. package/dist/layers/RMSNorm.js +6 -6
  52. package/dist/layers/RoPECache.d.ts +1 -1
  53. package/dist/layers/RoPECache.js +4 -4
  54. package/dist/layers/TiedEmbedding.d.ts +3 -2
  55. package/dist/layers/TiedEmbedding.js +29 -7
  56. package/dist/layers/TransformerBlock.d.ts +3 -2
  57. package/dist/layers/TransformerBlock.js +1 -1
  58. package/dist/loader/load.d.ts +2 -2
  59. package/dist/loader/loadHF.d.ts +2 -2
  60. package/dist/loader/loadTransformers.d.ts +4 -2
  61. package/dist/loader/loadTransformers.js +10 -9
  62. package/dist/loader/newZipLoad.d.ts +2 -2
  63. package/dist/loader/oldZipLoad.d.ts +2 -2
  64. package/dist/loader/oldZipLoad.js +44 -51
  65. package/dist/loader/save.d.ts +8 -0
  66. package/dist/loader/save.js +62 -0
  67. package/dist/{log_sum_exp-DbjkV734.js → log_sum_exp-C8yFJfZz.js} +45 -24
  68. package/dist/main.d.ts +6 -4
  69. package/dist/main.js +24 -18
  70. package/dist/{mat_mul-8m8pfdcx.js → mat_mul-Dpy2mMRu.js} +1 -1
  71. package/dist/mod-CbibJi3D.js +27 -0
  72. package/dist/models/NanoGPTV1.d.ts +15 -0
  73. package/dist/models/NanoGPTV1.js +71 -0
  74. package/dist/{config.d.ts → models/config.d.ts} +1 -0
  75. package/dist/{config.js → models/config.js} +1 -0
  76. package/dist/models/factory.d.ts +3 -0
  77. package/dist/models/factory.js +14 -0
  78. package/dist/models/model.d.ts +26 -0
  79. package/dist/models/model.js +70 -0
  80. package/dist/{mulmat_packed_gpu-VSekgsNv.js → mulmat_packed_gpu-q_Gmwyld.js} +1 -1
  81. package/dist/{ones-Dj0SDhHf.js → ones-BAqVh-eA.js} +2 -2
  82. package/dist/ops/adamAdjust.js +1 -1
  83. package/dist/ops/adamMoments.js +1 -1
  84. package/dist/ops/appendCache.js +3 -3
  85. package/dist/ops/attentionMask.js +1 -1
  86. package/dist/ops/cpu/adamAdjust.js +9 -9
  87. package/dist/ops/cpu/adamMoments.js +2 -2
  88. package/dist/ops/cpu/appendCache.js +2 -2
  89. package/dist/ops/cpu/attentionMask.js +5 -5
  90. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  91. package/dist/ops/cpu/gatherSub.js +5 -5
  92. package/dist/ops/cpu/gelu.js +1 -1
  93. package/dist/ops/cpu/matMulGelu.js +2 -2
  94. package/dist/ops/cpu/matMulMul.js +1 -1
  95. package/dist/ops/cpu/mulDropout.js +1 -1
  96. package/dist/ops/cpu/normRMS.js +1 -1
  97. package/dist/ops/cpu/qkv.js +3 -3
  98. package/dist/ops/cpu/rope.js +5 -5
  99. package/dist/ops/cpu/scatterSub.js +7 -7
  100. package/dist/ops/fusedSoftmax.js +1 -1
  101. package/dist/ops/gatherSub.js +1 -1
  102. package/dist/ops/gelu.js +2 -2
  103. package/dist/ops/grads/attentionMask.js +1 -1
  104. package/dist/ops/grads/fusedSoftmax.js +2 -2
  105. package/dist/ops/grads/gelu.js +2 -2
  106. package/dist/ops/grads/matMulGelu.js +1 -1
  107. package/dist/ops/grads/normRMS.js +1 -1
  108. package/dist/ops/grads/qkv.js +1 -1
  109. package/dist/ops/grads/rope.js +1 -1
  110. package/dist/ops/matMulGelu.js +1 -1
  111. package/dist/ops/matMulMul.js +1 -1
  112. package/dist/ops/mulDrop.js +1 -1
  113. package/dist/ops/normRMS.js +1 -1
  114. package/dist/ops/qkv.js +1 -1
  115. package/dist/ops/rope.js +4 -4
  116. package/dist/ops/scatterSub.js +1 -1
  117. package/dist/ops/webgl/adamAdjust.js +2 -2
  118. package/dist/ops/webgl/adamMoments.js +1 -1
  119. package/dist/ops/webgl/appendCache.js +1 -1
  120. package/dist/ops/webgl/attentionMask.js +1 -1
  121. package/dist/ops/webgl/fusedSoftmax.js +4 -4
  122. package/dist/ops/webgl/gatherSub.js +1 -1
  123. package/dist/ops/webgl/gelu.js +2 -2
  124. package/dist/ops/webgl/log.js +3 -3
  125. package/dist/ops/webgl/matMulGelu.js +10 -10
  126. package/dist/ops/webgl/matMulMul.js +1 -1
  127. package/dist/ops/webgl/mulDropout.js +1 -1
  128. package/dist/ops/webgl/normRMS.js +2 -2
  129. package/dist/ops/webgl/qkv.js +1 -1
  130. package/dist/ops/webgl/rope.js +1 -1
  131. package/dist/ops/webgl/scatterSub.js +1 -1
  132. package/dist/ops/webgpu/adamAdjust.js +3 -3
  133. package/dist/ops/webgpu/adamMoments.js +3 -3
  134. package/dist/ops/webgpu/appendCache.js +3 -3
  135. package/dist/ops/webgpu/attentionMask.js +3 -3
  136. package/dist/ops/webgpu/gatherSub.js +3 -3
  137. package/dist/ops/webgpu/gelu.js +3 -3
  138. package/dist/ops/webgpu/normRMS.js +2 -2
  139. package/dist/ops/webgpu/normRMSGrad.js +5 -5
  140. package/dist/ops/webgpu/qkv.js +3 -3
  141. package/dist/ops/webgpu/rope.js +3 -3
  142. package/dist/ops/webgpu/scatterSub.js +3 -3
  143. package/dist/ops/webgpu/utils/reductions.js +4 -4
  144. package/dist/ops-542ai2vG.js +1525 -0
  145. package/dist/{random_width-sZORGo5k.js → random_width-DKGeiFuR.js} +1471 -1538
  146. package/dist/{range-CRuAh-gd.js → range-BcUvLuf5.js} +1 -1
  147. package/dist/{reciprocal-BvGAyKyu.js → reciprocal-DhDWSKiD.js} +1 -1
  148. package/dist/{register_all_kernels-BwDSRN-f.js → register_all_kernels-Do9VvZmo.js} +2488 -2534
  149. package/dist/{max-Ddnnb5xe.js → relu-B1AXs7p5.js} +6 -6
  150. package/dist/{reshape-CdBq1WJ6.js → reshape-WeJkT3ja.js} +1 -1
  151. package/dist/{scatter_nd_util-DUstGbU1.js → scatter_nd_util-B7yDhiQr.js} +1 -1
  152. package/dist/{selu_util-BJEXVvjX.js → selu_util-BgUO9gHY.js} +125 -146
  153. package/dist/{shared-wS99K7_n.js → shared-CZiWmQCI.js} +1 -1
  154. package/dist/{shared-B8ztnyEk.js → shared-V6D_md-c.js} +72 -72
  155. package/dist/{sin-BeA3tsEd.js → sin-CPxad7Am.js} +1 -1
  156. package/dist/{slice-BiOsknYS.js → slice-B7jXtPnp.js} +1 -1
  157. package/dist/{softmax-Bv_6lyMX.js → softmax-BfsyI4As.js} +1 -1
  158. package/dist/{split-B-dikLRw.js → split-BPxr8_8m.js} +1 -1
  159. package/dist/{stack-B17UN2nn.js → stack-BNwLzE43.js} +1 -1
  160. package/dist/{sum-66ew2byf.js → sum-ByFINZgi.js} +3 -3
  161. package/dist/{tensor-JwS7ZYY6.js → tensor-DbqgIV9B.js} +1 -1
  162. package/dist/tensor1d-CtJq5BOv.js +27 -0
  163. package/dist/{tensor2d-wxPAnDQy.js → tensor2d-CObBWBkW.js} +1 -1
  164. package/dist/tensor3d-BOukqWwr.js +30 -0
  165. package/dist/tensor4d-DLtk7Nxh.js +30 -0
  166. package/dist/training/Adam.js +2 -2
  167. package/dist/training/AdamExt.js +1 -1
  168. package/dist/training/DatasetBuilder.js +2 -2
  169. package/dist/training/Evaluator.d.ts +2 -2
  170. package/dist/training/FullTrainer.d.ts +3 -3
  171. package/dist/training/FullTrainer.js +61 -69
  172. package/dist/training/Trainer.d.ts +15 -3
  173. package/dist/training/Trainer.js +39 -47
  174. package/dist/training/sparseCrossEntropy.js +12 -13
  175. package/dist/utilities/arrayClose.d.ts +1 -1
  176. package/dist/utilities/arrayClose.js +16 -7
  177. package/dist/utilities/dummy.d.ts +4 -4
  178. package/dist/utilities/dummy.js +13 -13
  179. package/dist/utilities/multinomialCPU.js +2 -2
  180. package/dist/utilities/parameters.d.ts +1 -1
  181. package/dist/utilities/performance.js +1 -1
  182. package/dist/utilities/profile.js +1 -1
  183. package/dist/utilities/safetensors.js +2 -2
  184. package/dist/utilities/weights.js +2 -2
  185. package/dist/{variable-BuddVFLa.js → variable-DPFOJyRG.js} +1 -1
  186. package/dist/{webgpu_program-PFzf1hAQ.js → webgpu_program-Dhk9R5aG.js} +1 -1
  187. package/dist/{webgpu_util-D____QpY.js → webgpu_util-BqGnZg8t.js} +27 -27
  188. package/dist/{zeros--BdLQ3oG.js → zeros-Dnwix0p4.js} +1 -1
  189. package/package.json +2 -3
  190. package/dist/NanoGPTModel.d.ts +0 -52
  191. package/dist/NanoGPTModel.js +0 -203
  192. package/dist/TiedEmbedding-BxOerUmB.js +0 -43
  193. package/dist/ops-BFGCx8Ri.js +0 -1202
  194. package/dist/utilities/generate.d.ts +0 -3
  195. package/dist/utilities/generate.js +0 -22
  196. package/dist/utilities/save.d.ts +0 -9
  197. package/dist/utilities/save.js +0 -61
package/dist/Trainer.js CHANGED
@@ -1,11 +1,13 @@
1
1
  import { E as l } from "./index-Dwqa6Zy2.js";
2
2
  import h from "./training/FullTrainer.js";
3
- class p extends l {
3
+ class m extends l {
4
4
  trainer;
5
5
  hasTrained = !1;
6
6
  trainDataset;
7
7
  validationDataset;
8
8
  totalSamples = 0;
9
+ log = [];
10
+ progress = null;
9
11
  constructor(t, e) {
10
12
  super(), this.trainer = new h(t, e, 1e-3);
11
13
  }
@@ -13,7 +15,7 @@ class p extends l {
13
15
  this.trainer.stop();
14
16
  }
15
17
  reset() {
16
- this.hasTrained = !1, this.trainer.reset();
18
+ this.hasTrained = !1, this.log = [], this.trainer.reset();
17
19
  }
18
20
  async prepare(t, e) {
19
21
  const { trainDataset: a, validationDataset: s } = await this.trainer.createTrainValidationSplit(
@@ -26,7 +28,7 @@ class p extends l {
26
28
  async train(t) {
27
29
  if (!this.trainDataset || !this.validationDataset)
28
30
  throw new Error("Datasets not prepared");
29
- this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), await this.trainer.trainOnDataset(
31
+ this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), this.trainer.setGradientCheckpointing(t?.gradientCheckpointing || !1), await this.trainer.trainOnDataset(
30
32
  this.trainDataset,
31
33
  {
32
34
  prompt: t?.prompt,
@@ -35,16 +37,17 @@ class p extends l {
35
37
  maxSteps: t?.maxSteps || 1e3,
36
38
  advancedMetrics: t?.advancedMetrics || !1,
37
39
  onStep: async (e, a) => {
40
+ this.log.push(e), this.progress = {
41
+ ...a,
42
+ progress: a.totalSamples / this.totalSamples,
43
+ remaining: Math.max(
44
+ 0,
45
+ (this.totalSamples - a.totalSamples) / a.totalSamples * a.duration
46
+ )
47
+ };
38
48
  const s = this.listeners("log");
39
49
  for (const i of s)
40
- await i(e, {
41
- ...a,
42
- progress: a.totalSamples / this.totalSamples,
43
- remaining: Math.max(
44
- 0,
45
- (this.totalSamples - a.totalSamples) / a.totalSamples * a.duration
46
- )
47
- });
50
+ await i(e, this.progress);
48
51
  }
49
52
  },
50
53
  this.validationDataset
@@ -76,7 +79,13 @@ class p extends l {
76
79
  });
77
80
  this.emit("stop");
78
81
  }
82
+ getLog() {
83
+ return this.log;
84
+ }
85
+ getProgress() {
86
+ return this.progress;
87
+ }
79
88
  }
80
89
  export {
81
- p as default
90
+ m as default
82
91
  };
@@ -1,4 +1,4 @@
1
- import { l as c } from "./index-BoWRt-10.js";
1
+ import { n as c } from "./index-DdmHGZjq.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2017 Google LLC. All Rights Reserved.
@@ -28,7 +28,7 @@ function a(e, n, t) {
28
28
  t.indexOf(u) === -1 ? s.push(e[o++]) : s.push(n[f++]);
29
29
  return s;
30
30
  }
31
- function p(e, n) {
31
+ function l(e, n) {
32
32
  const t = [], r = e.length;
33
33
  for (let o = 0; o < r; o++)
34
34
  n.indexOf(o) === -1 && t.push(e[o]);
@@ -62,7 +62,7 @@ function x(e, n) {
62
62
  export {
63
63
  x as a,
64
64
  m as b,
65
- p as c,
65
+ l as c,
66
66
  i as d,
67
67
  h as e,
68
68
  a as f,
package/dist/backend.js CHANGED
@@ -1,6 +1,6 @@
1
- import { g as a, s as i, r as o } from "./index-BoWRt-10.js";
1
+ import { g as a, s as i, r as o } from "./index-DdmHGZjq.js";
2
2
  async function e(t) {
3
- a() !== t && (t === "webgpu" && (await import("./index-CLthM0TO.js"), await import("./ops/webgpu/index.js")), await i(t), await o(), console.log(`Backend set to ${t}`));
3
+ a() !== t && (t === "webgpu" && (await import("./index-evZ57wr4.js"), await import("./ops/webgpu/index.js")), await i(t), await o(), console.log(`Backend set to ${t}`));
4
4
  }
5
5
  export {
6
6
  e as selectBackend
@@ -1,7 +1,7 @@
1
- import { j as m, a1 as O, l as g, aK as $, aL as R, aM as M, k as _, aa as y, aw as D, aN as T, u as b, aO as F } from "./index-BoWRt-10.js";
2
- import { b as L, d as W, f as v, c as N, e as x, g as P, a as C, h as z } from "./axis_util-BzbKo31C.js";
3
- import { S as U, a as B, b as V, c as j, d as k, e as G, f as H, g as q, h as Z, i as K, j as X, k as J, l as Y, m as Q, s as ee, n as te, o as ne, t as se } from "./selu_util-BJEXVvjX.js";
4
- import { c as re, v as oe, a as ae } from "./scatter_nd_util-DUstGbU1.js";
1
+ import { j as m, a3 as R, n as g, aN as $, aO as O, aP as _, l as M, ae as y, ax as D, aQ as T, u as b, aR as F } from "./index-DdmHGZjq.js";
2
+ import { b as L, d as W, f as v, c as N, e as x, g as P, a as C, h as z } from "./axis_util-Did9235A.js";
3
+ import { S as U, a as B, b as V, c as j, d as G, e as H, f as k, g as q, h as Z, i as X, j as J, k as K, l as Q, m as Y, s as ee, n as te, o as ne, t as se } from "./selu_util-BgUO9gHY.js";
4
+ import { c as re, v as oe, a as ae } from "./scatter_nd_util-B7yDhiQr.js";
5
5
  function ie(e, n) {
6
6
  const r = e.shape.length, t = n.shape.length;
7
7
  if (r < 1)
@@ -24,7 +24,7 @@ function ie(e, n) {
24
24
  for (let i = o; i < r; ++i)
25
25
  h *= u[i], c.push(u[i]);
26
26
  const d = [
27
- ...O(e.shape).map((i) => i / h),
27
+ ...R(e.shape).map((i) => i / h),
28
28
  1
29
29
  ].slice(0, o);
30
30
  return [c, a, h, d];
@@ -233,7 +233,7 @@ function Ie(e, n) {
233
233
  r.push(e[t][0]);
234
234
  return r;
235
235
  }
236
- function we(e, n, r) {
236
+ function Se(e, n, r) {
237
237
  const t = e.slice(0, 1);
238
238
  for (let s = 0; s < r; ++s)
239
239
  t.push(e[s + 1] - n[s][0] - n[s][1]);
@@ -255,7 +255,7 @@ function we(e, n, r) {
255
255
  * limitations under the License.
256
256
  * =============================================================================
257
257
  */
258
- const Se = 0.3275911, Ae = 0.254829592, Oe = -0.284496736, Re = 1.421413741, Me = -1.453152027, _e = 1.061405429;
258
+ const we = 0.3275911, Ae = 0.254829592, Re = -0.284496736, Oe = 1.421413741, _e = -1.453152027, Me = 1.061405429;
259
259
  /**
260
260
  * @license
261
261
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -333,7 +333,7 @@ function ve(e, n, r) {
333
333
  * limitations under the License.
334
334
  * =============================================================================
335
335
  */
336
- const E = "->", Ne = /->/g, w = ",", S = "...";
336
+ const E = "->", Ne = /->/g, S = ",", w = "...";
337
337
  function xe(e, n) {
338
338
  e = e.replace(/\s/g, "");
339
339
  const r = (e.length - e.replace(Ne, "").length) / E.length;
@@ -342,8 +342,8 @@ function xe(e, n) {
342
342
  if (r > 1)
343
343
  throw new Error(`Equation must contain exactly one arrow ("${E}").`);
344
344
  const [t, s] = e.split(E);
345
- g(t.indexOf(S) === -1, () => `The ellipsis notation ("${S}") is not supported yet.`);
346
- const o = t.split(w), a = o.length;
345
+ g(t.indexOf(w) === -1, () => `The ellipsis notation ("${w}") is not supported yet.`);
346
+ const o = t.split(S), a = o.length;
347
347
  if (n !== a)
348
348
  throw new Error(`Expected ${a} input tensors, received ${n}`);
349
349
  if (a > 2)
@@ -357,7 +357,7 @@ function xe(e, n) {
357
357
  }
358
358
  for (let l = 0; l < t.length; ++l) {
359
359
  const f = t[l];
360
- u.indexOf(f) === -1 && f !== w && u.push(f);
360
+ u.indexOf(f) === -1 && f !== S && u.push(f);
361
361
  }
362
362
  const c = new Array(o.length);
363
363
  for (let l = 0; l < a; ++l) {
@@ -449,10 +449,10 @@ function je(e) {
449
449
  return `Received SparseTensor with denseShape[0] = 0 but
450
450
  indices.shape[0] = ${e}`;
451
451
  }
452
- function ke(e, n) {
452
+ function Ge(e, n) {
453
453
  return `indices(${e}, 0) is invalid: ${n} < 0`;
454
454
  }
455
- function Ge(e, n, r) {
455
+ function He(e, n, r) {
456
456
  return `indices(${e}, 0) is invalid: ${n} >= ${r}`;
457
457
  }
458
458
  /**
@@ -471,7 +471,7 @@ function Ge(e, n, r) {
471
471
  * limitations under the License.
472
472
  * =============================================================================
473
473
  */
474
- function He(e, n) {
474
+ function ke(e, n) {
475
475
  return `only one output dimension may be -1, not both ${e} and ${n}`;
476
476
  }
477
477
  function qe(e, n) {
@@ -480,12 +480,12 @@ function qe(e, n) {
480
480
  function Ze() {
481
481
  return "reshape cannot infer the missing input size for an empty tensor unless all specified input sizes are non-zero";
482
482
  }
483
- function Ke(e, n) {
483
+ function Xe(e, n) {
484
484
  const r = m(e), t = m(n);
485
485
  return `Input to reshape is a SparseTensor with ${r}
486
486
  dense values, but the requested shape requires a multiple of ${t}. inputShape=${e} outputShape= ${n}`;
487
487
  }
488
- function Xe(e, n) {
488
+ function Je(e, n) {
489
489
  const r = m(e), t = m(n);
490
490
  return `Input to reshape is a tensor with ${r} dense values, but the requested shape has ${t}. inputShape=${e} outputShape=${n}`;
491
491
  }
@@ -505,13 +505,13 @@ function Xe(e, n) {
505
505
  * limitations under the License.
506
506
  * =============================================================================
507
507
  */
508
- function Je() {
508
+ function Ke() {
509
509
  return "segment ids must be >= 0";
510
510
  }
511
- function Ye() {
511
+ function Qe() {
512
512
  return "segment ids are not increasing";
513
513
  }
514
- function Qe(e, n) {
514
+ function Ye(e, n) {
515
515
  return `Segment id ${e} out of range [0, ${n}), possibly because segmentIds input is not sorted.`;
516
516
  }
517
517
  function et(e, n, r) {
@@ -593,22 +593,22 @@ const rt = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
593
593
  */
594
594
  function ot(e) {
595
595
  try {
596
- return e.map((n) => R(n));
596
+ return e.map((n) => O(n));
597
597
  } catch (n) {
598
598
  throw new Error(`Failed to decode encoded string bytes into utf-8, error: ${n}`);
599
599
  }
600
600
  }
601
601
  function at(e) {
602
- return e.map((n) => M(n));
602
+ return e.map((n) => _(n));
603
603
  }
604
604
  const ht = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
605
605
  __proto__: null,
606
606
  ERF_A1: Ae,
607
- ERF_A2: Oe,
608
- ERF_A3: Re,
609
- ERF_A4: Me,
610
- ERF_A5: _e,
611
- ERF_P: Se,
607
+ ERF_A2: Re,
608
+ ERF_A3: Oe,
609
+ ERF_A4: _e,
610
+ ERF_A5: Me,
611
+ ERF_P: we,
612
612
  PARALLELIZE_THRESHOLD: I,
613
613
  get RowPartitionType() {
614
614
  return p;
@@ -616,7 +616,7 @@ const ht = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
616
616
  SELU_SCALE: U,
617
617
  SELU_SCALEALPHA: B,
618
618
  applyActivation: V,
619
- assertAndGetBroadcastShape: _,
619
+ assertAndGetBroadcastShape: M,
620
620
  assertAxesAreInnerMostDims: L,
621
621
  assertParamsConsistent: ue,
622
622
  assignToTypedArray: Le,
@@ -628,18 +628,18 @@ const ht = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
628
628
  combineRaggedTensorToTensorShapes: ce,
629
629
  complexWithEvenIndex: Te,
630
630
  complexWithOddIndex: be,
631
- computeConv2DInfo: k,
632
- computeConv3DInfo: G,
633
- computeDefaultPad: H,
631
+ computeConv2DInfo: G,
632
+ computeConv3DInfo: H,
633
+ computeDefaultPad: k,
634
634
  computeDilation2DInfo: q,
635
635
  computeOptimalWindowSize: ge,
636
636
  computeOutAndReduceShapes: N,
637
637
  computeOutShape: le,
638
638
  computePool2DInfo: Z,
639
- computePool3DInfo: K,
640
- convertConv2DDataFormat: X,
639
+ computePool3DInfo: X,
640
+ convertConv2DDataFormat: J,
641
641
  decodeEinsumEquation: xe,
642
- eitherStridesOrDilationsAreOne: J,
642
+ eitherStridesOrDilationsAreOne: K,
643
643
  expandShapeToKeepDim: x,
644
644
  exponent: ve,
645
645
  exponents: We,
@@ -650,8 +650,8 @@ const ht = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
650
650
  getComplexWithIndex: Fe,
651
651
  getEinsumComputePath: ze,
652
652
  getEinsumPermutation: Pe,
653
- getFusedBiasGradient: Y,
654
- getFusedDyActivation: Q,
653
+ getFusedBiasGradient: Q,
654
+ getFusedDyActivation: Y,
655
655
  getImageCenter: de,
656
656
  getInnerMostAxes: C,
657
657
  getPermuted: Ee,
@@ -661,19 +661,19 @@ const ht = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
661
661
  getReshapedPermuted: $e,
662
662
  getRowPartitionTypesHelper: he,
663
663
  getSliceBeginCoords: Ie,
664
- getSliceSize: we,
664
+ getSliceSize: Se,
665
665
  getSparseFillEmptyRowsIndicesDenseShapeMismatch: je,
666
- getSparseFillEmptyRowsNegativeIndexErrorMessage: ke,
667
- getSparseFillEmptyRowsOutOfRangeIndexErrorMessage: Ge,
666
+ getSparseFillEmptyRowsNegativeIndexErrorMessage: Ge,
667
+ getSparseFillEmptyRowsOutOfRangeIndexErrorMessage: He,
668
668
  getSparseReshapeEmptyTensorZeroOutputDimErrorMessage: Ze,
669
- getSparseReshapeInputOutputMismatchErrorMessage: Xe,
670
- getSparseReshapeInputOutputMultipleErrorMessage: Ke,
671
- getSparseReshapeMultipleNegativeOneOutputDimErrorMessage: He,
669
+ getSparseReshapeInputOutputMismatchErrorMessage: Je,
670
+ getSparseReshapeInputOutputMultipleErrorMessage: Xe,
671
+ getSparseReshapeMultipleNegativeOneOutputDimErrorMessage: ke,
672
672
  getSparseReshapeNegativeOutputDimErrorMessage: qe,
673
673
  getSparseSegmentReductionIndicesOutOfRangeErrorMessage: et,
674
- getSparseSegmentReductionNegativeSegmentIdsErrorMessage: Je,
675
- getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage: Ye,
676
- getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage: Qe,
674
+ getSparseSegmentReductionNegativeSegmentIdsErrorMessage: Ke,
675
+ getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage: Qe,
676
+ getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage: Ye,
677
677
  getUndoAxesPermutation: z,
678
678
  isIdentityPermutation: Ue,
679
679
  log: T,
@@ -697,8 +697,8 @@ export {
697
697
  Ee as B,
698
698
  $e as C,
699
699
  Ie as D,
700
- Se as E,
701
- we as F,
700
+ we as E,
701
+ Se as F,
702
702
  le as G,
703
703
  ue as H,
704
704
  xe as I,
@@ -728,22 +728,22 @@ export {
728
728
  ot as f,
729
729
  he as g,
730
730
  je as h,
731
- ke as i,
732
- Ge as j,
733
- He as k,
731
+ Ge as i,
732
+ He as j,
733
+ ke as k,
734
734
  qe as l,
735
735
  ye as m,
736
736
  Ze as n,
737
- Ke as o,
738
- Xe as p,
739
- Je as q,
740
- Ye as r,
741
- Qe as s,
737
+ Xe as o,
738
+ Je as p,
739
+ Ke as q,
740
+ Qe as r,
741
+ Ye as s,
742
742
  et as t,
743
743
  Ae as u,
744
744
  pe as v,
745
- Oe as w,
746
- Re as x,
747
- Me as y,
748
- _e as z
745
+ Re as w,
746
+ Oe as x,
747
+ _e as y,
748
+ Me as z
749
749
  };
@@ -1,5 +1,5 @@
1
- import { B as h, C as f, F as p, M as g, E as u, N as b } from "./index-BoWRt-10.js";
2
- import { r as T } from "./reshape-CdBq1WJ6.js";
1
+ import { C as h, D as f, M as p, H as g, E as u, X as b } from "./index-DdmHGZjq.js";
2
+ import { r as T } from "./reshape-WeJkT3ja.js";
3
3
  /**
4
4
  * @license
5
5
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -0,0 +1 @@
1
+ export declare function execute(backend: string): Promise<number | number[] | number[][] | number[][][] | number[][][][] | number[][][][][] | number[][][][][][]>;
@@ -0,0 +1,22 @@
1
+ import { s, e as a } from "../index-DdmHGZjq.js";
2
+ import { t } from "../tensor4d-DLtk7Nxh.js";
3
+ async function u(e) {
4
+ await s(e);
5
+ const n = t(
6
+ [
7
+ [
8
+ [
9
+ [0.1, 0.2, 0, 0],
10
+ [0.1, 0.2, 0, 0],
11
+ [0, 0, 0, 0],
12
+ [0, 0, 0, 0]
13
+ ]
14
+ ]
15
+ ],
16
+ [1, 1, 4, 4]
17
+ ), r = t([[[[0.1, 0.2, 0.3, 0.4]]]], [1, 1, 1, 4]);
18
+ return await a().runKernel("AppendCache", { cache: n, item: r }, { maxSize: 4, pastLen: 2 }).array();
19
+ }
20
+ export {
21
+ u as execute
22
+ };
@@ -0,0 +1 @@
1
+ export declare function execute(backend: string): Promise<number | number[] | number[][] | number[][][] | number[][][][] | number[][][][][] | number[][][][][][]>;
@@ -0,0 +1,37 @@
1
+ import { s as i, e } from "../index-DdmHGZjq.js";
2
+ import { t } from "../tensor4d-DLtk7Nxh.js";
3
+ import { t as a } from "../tensor2d-CObBWBkW.js";
4
+ async function k(n) {
5
+ await i(n);
6
+ const s = t(
7
+ [
8
+ [
9
+ [
10
+ [0.1, 0.2, 0.3, 0.4],
11
+ [0.3, 0.4, 0.5, 0.6]
12
+ ]
13
+ ]
14
+ ],
15
+ [1, 1, 2, 4]
16
+ ), o = t(
17
+ [
18
+ [
19
+ [
20
+ [0.5, 0.6, 0.5, 0.6],
21
+ [0.7, 0.8, 0.7, 0.8]
22
+ ]
23
+ ]
24
+ ],
25
+ [1, 1, 2, 4]
26
+ ), r = a(
27
+ [
28
+ [0, -1 / 0, -1 / 0, -1 / 0],
29
+ [0, 0, 0, -1 / 0]
30
+ ],
31
+ [2, 4]
32
+ );
33
+ return await e().runKernel("AttentionMask", { q: s, k: o, mask: r }, { divisor: 0.5, pastLen: 0 }).array();
34
+ }
35
+ export {
36
+ k as execute
37
+ };
@@ -0,0 +1,9 @@
1
+ interface Result {
2
+ backend: string;
3
+ result: unknown;
4
+ error?: string;
5
+ passed: boolean;
6
+ maxError?: number;
7
+ }
8
+ export default function runCheck(check: (backend: string) => Promise<unknown>, epsilon?: number): Promise<Result[]>;
9
+ export {};
@@ -0,0 +1,20 @@
1
+ import { arraysClose as l } from "../utilities/arrayClose.js";
2
+ async function f(c, a) {
3
+ const n = ["cpu", "webgl", "webgpu"], t = [];
4
+ for (const e of n)
5
+ try {
6
+ const r = await c(e);
7
+ t.push({ backend: e, result: r, passed: !0 });
8
+ } catch (r) {
9
+ t.push({ backend: e, error: r.message, result: [], passed: !1 });
10
+ }
11
+ const s = await Promise.all(t), u = s[0].result;
12
+ for (let e = 1; e < s.length; e++) {
13
+ const r = s[e].result, o = l(u, r);
14
+ s[e].passed = o <= (a ?? 1e-6), s[e].maxError = o;
15
+ }
16
+ return s;
17
+ }
18
+ export {
19
+ f as default
20
+ };
@@ -0,0 +1 @@
1
+ export declare function execute(backend: string): Promise<number | number[] | number[][] | number[][][] | number[][][][] | number[][][][][] | number[][][][][][]>;
@@ -0,0 +1,18 @@
1
+ import { s as e, e as o } from "../index-DdmHGZjq.js";
2
+ import { t as s } from "../tensor2d-CObBWBkW.js";
3
+ async function m(t) {
4
+ await e(t);
5
+ const r = s(
6
+ [
7
+ [0.1, 0.2, 0, 0],
8
+ [0.1, 0.2, 0, 0],
9
+ [0, 0, 0, 0],
10
+ [0, 0, 0, 0]
11
+ ],
12
+ [4, 4]
13
+ );
14
+ return await o().runKernel("Gelu", { x: r }).array();
15
+ }
16
+ export {
17
+ m as execute
18
+ };
@@ -0,0 +1,19 @@
1
+ import { execute as rope } from './rope';
2
+ import { execute as normRMS } from './normRMS';
3
+ import { execute as qkv } from './qkv';
4
+ import { execute as gelu } from './gelu';
5
+ import { execute as normRMSGrad } from './normRMSGrad';
6
+ import { execute as appendCache } from './appendCache';
7
+ import { execute as attentionMask } from './attentionMask';
8
+ import { default as runCheck } from './check';
9
+ declare const checks: {
10
+ rope: typeof rope;
11
+ qkv: typeof qkv;
12
+ gelu: typeof gelu;
13
+ normRMS: typeof normRMS;
14
+ normRMSGrad: typeof normRMSGrad;
15
+ appendCache: typeof appendCache;
16
+ attentionMask: typeof attentionMask;
17
+ runCheck: typeof runCheck;
18
+ };
19
+ export default checks;
@@ -0,0 +1,21 @@
1
+ import { execute as e } from "./rope.js";
2
+ import { execute as t } from "./normRMS.js";
3
+ import { execute as o } from "./qkv.js";
4
+ import { execute as r } from "./gelu.js";
5
+ import { execute as c } from "./normRMSGrad.js";
6
+ import { execute as m } from "./appendCache.js";
7
+ import { execute as u } from "./attentionMask.js";
8
+ import x from "./check.js";
9
+ const d = {
10
+ rope: e,
11
+ qkv: o,
12
+ gelu: r,
13
+ normRMS: t,
14
+ normRMSGrad: c,
15
+ appendCache: m,
16
+ attentionMask: u,
17
+ runCheck: x
18
+ };
19
+ export {
20
+ d as default
21
+ };
@@ -0,0 +1 @@
1
+ export declare function execute(backend: string): Promise<(number | number[] | number[][] | number[][][] | number[][][][] | number[][][][][] | number[][][][][][])[]>;
@@ -0,0 +1,16 @@
1
+ import { s as u, y as A, e as y } from "../index-DdmHGZjq.js";
2
+ import { a as h } from "../ops-542ai2vG.js";
3
+ import { t as p } from "../tensor1d-CtJq5BOv.js";
4
+ import { t as a } from "../tensor-DbqgIV9B.js";
5
+ const w = Array.from({ length: 2048 * 192 }, () => Math.random()), x = Array.from({ length: 192 }, () => Math.random()), M = Array.from({ length: 2048 * 192 }, () => Math.random());
6
+ async function k(t) {
7
+ await u(t);
8
+ const o = p(x, "float32"), n = a(w, [16, 128, 192], "float32"), s = a(M, [16, 128, 192], "float32"), e = (d, g) => {
9
+ const i = y().runKernel("RMSNorm", { x: d, gamma: g });
10
+ return h.meanSquaredError(i, s);
11
+ }, { value: m, grads: r } = A(e)([n, o]), c = await m.array(), f = await r[0].array(), l = await r[1].array();
12
+ return [c, f, l];
13
+ }
14
+ export {
15
+ k as execute
16
+ };
@@ -0,0 +1 @@
1
+ export declare function execute(backend: string): Promise<(number | number[] | number[][] | number[][][] | number[][][][] | number[][][][][] | number[][][][][][])[]>;
@@ -0,0 +1,12 @@
1
+ import { s as c, e as d } from "../index-DdmHGZjq.js";
2
+ import { t as f } from "../tensor1d-CtJq5BOv.js";
3
+ import { t as r } from "../tensor-DbqgIV9B.js";
4
+ const y = Array.from({ length: 2048 * 192 }, () => Math.random()), i = Array.from({ length: 192 }, () => Math.random()), l = Array.from({ length: 2048 * 192 }, () => Math.random());
5
+ async function x(t) {
6
+ await c(t);
7
+ const o = f(i, "float32"), n = r(y, [16, 128, 192], "float32"), m = r(l, [16, 128, 192], "float32"), a = d().runKernel("RMSNormGrad", { x: n, gamma: o, dy: m }), s = await a[0].array(), e = await a[1].array();
8
+ return [s, e];
9
+ }
10
+ export {
11
+ x as execute
12
+ };
@@ -0,0 +1 @@
1
+ export declare function execute(backend: string): Promise<(number | number[] | number[][] | number[][][] | number[][][][] | number[][][][][] | number[][][][][][])[]>;
@@ -0,0 +1,25 @@
1
+ import { s as c, e as m } from "../index-DdmHGZjq.js";
2
+ import { t as i } from "../tensor3d-BOukqWwr.js";
3
+ import { t as u } from "../tensor2d-CObBWBkW.js";
4
+ async function w(a) {
5
+ await c(a);
6
+ const o = i(
7
+ [
8
+ [
9
+ [0.1, 0.2],
10
+ [0.3, 0.4]
11
+ ]
12
+ ],
13
+ [1, 2, 2]
14
+ ), r = u(
15
+ [
16
+ [0.5, 0.6, 0.9, 1, 1.3, 1.4],
17
+ [0.7, 0.8, 1.1, 1.2, 1.5, 1.6]
18
+ ],
19
+ [2, 6]
20
+ ), t = m().runKernel("QKV", { x: o, kernel: r }, { heads: 1 }), s = await t[0].array(), n = await t[1].array(), e = await t[2].array();
21
+ return [s, n, e];
22
+ }
23
+ export {
24
+ w as execute
25
+ };
@@ -0,0 +1 @@
1
+ export declare function execute(backend: string): Promise<number | number[] | number[][] | number[][][] | number[][][][] | number[][][][][] | number[][][][][][] | Promise<number | number[] | number[][] | number[][][] | number[][][][] | number[][][][][] | number[][][][][][]>[]>;
@@ -0,0 +1,21 @@
1
+ import { s as c, e as i } from "../index-DdmHGZjq.js";
2
+ import { t as m } from "../tensor4d-DLtk7Nxh.js";
3
+ import { t } from "../tensor3d-BOukqWwr.js";
4
+ async function y(n) {
5
+ await c(n);
6
+ const s = m(
7
+ [
8
+ [
9
+ [
10
+ [0.1, 0.2],
11
+ [0.3, 0.4]
12
+ ]
13
+ ]
14
+ ],
15
+ [1, 1, 2, 2]
16
+ ), e = t([0.5, 0.6], [2, 1, 1]), o = t([0.9, 1], [2, 1, 1]), r = i().runKernel("Rope", { x: s, sin: e, cos: o }, { pastLen: 0 });
17
+ return Array.isArray(r) ? r.map((a) => a.array()) : r.array();
18
+ }
19
+ export {
20
+ y as execute
21
+ };
@@ -1,4 +1,4 @@
1
- import { B as s, l as a, D as p, M as i, E as l, Q as f } from "./index-BoWRt-10.js";
1
+ import { C as s, n as a, F as p, H as i, E as l, I as f } from "./index-DdmHGZjq.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2020 Google LLC. All Rights Reserved.