@genai-fi/nanogpt 0.7.3 → 0.8.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197) hide show
  1. package/dist/Generator.d.ts +25 -2
  2. package/dist/Generator.js +152 -49
  3. package/dist/{RealDiv-Dy0p8Bvo.js → RealDiv-D_q39E3A.js} +13 -13
  4. package/dist/{Reshape-DvudQDvJ.js → Reshape-41YpQqEo.js} +1 -1
  5. package/dist/{Reshape-DH5srBP0.js → Reshape-Bh_jzKzV.js} +5 -5
  6. package/dist/TeachableLLM.d.ts +6 -6
  7. package/dist/TeachableLLM.js +33 -31
  8. package/dist/Trainer.d.ts +13 -2
  9. package/dist/Trainer.js +21 -12
  10. package/dist/{axis_util-BzbKo31C.js → axis_util-Did9235A.js} +3 -3
  11. package/dist/backend.js +2 -2
  12. package/dist/{backend_util-TE7aTPhZ.js → backend_util-yC3YH1jo.js} +58 -58
  13. package/dist/{broadcast_to-CdbwV-Dj.js → broadcast_to-CUvOdOT5.js} +2 -2
  14. package/dist/checks/appendCache.d.ts +1 -0
  15. package/dist/checks/appendCache.js +22 -0
  16. package/dist/checks/attentionMask.d.ts +1 -0
  17. package/dist/checks/attentionMask.js +37 -0
  18. package/dist/checks/check.d.ts +9 -0
  19. package/dist/checks/check.js +20 -0
  20. package/dist/checks/gelu.d.ts +1 -0
  21. package/dist/checks/gelu.js +18 -0
  22. package/dist/checks/index.d.ts +19 -0
  23. package/dist/checks/index.js +21 -0
  24. package/dist/checks/normRMS.d.ts +1 -0
  25. package/dist/checks/normRMS.js +16 -0
  26. package/dist/checks/normRMSGrad.d.ts +1 -0
  27. package/dist/checks/normRMSGrad.js +12 -0
  28. package/dist/checks/qkv.d.ts +1 -0
  29. package/dist/checks/qkv.js +25 -0
  30. package/dist/checks/rope.d.ts +1 -0
  31. package/dist/checks/rope.js +21 -0
  32. package/dist/{concat-CsxrgovM.js → concat-pHiVqR3L.js} +1 -1
  33. package/dist/{dataset-CtdBYwjo.js → dataset-DPPl-iLT.js} +9 -9
  34. package/dist/{dropout-DYs5QFGQ.js → dropout-CcKSfOYE.js} +18 -18
  35. package/dist/exports_initializers-DKk7-bsx.js +16 -0
  36. package/dist/{gather-CMMy2KEG.js → gather-CPg6ZlQA.js} +1 -1
  37. package/dist/{gelu-C-dPj6Ku.js → gelu-BkcmEEyD.js} +1 -1
  38. package/dist/{gpgpu_math-DGNLNL4I.js → gpgpu_math-D_ODOLix.js} +26 -26
  39. package/dist/{index-BoWRt-10.js → index-DdmHGZjq.js} +659 -650
  40. package/dist/{index-CLthM0TO.js → index-evZ57wr4.js} +185 -185
  41. package/dist/{kernel_funcs_utils-BYKWV8Aa.js → kernel_funcs_utils-CDfFpUab.js} +21 -21
  42. package/dist/layers/BaseLayer.d.ts +8 -13
  43. package/dist/layers/BaseLayer.js +25 -13
  44. package/dist/layers/CausalSelfAttention.d.ts +3 -2
  45. package/dist/layers/CausalSelfAttention.js +28 -28
  46. package/dist/layers/MLP.d.ts +3 -2
  47. package/dist/layers/MLP.js +16 -20
  48. package/dist/layers/PositionEmbedding.d.ts +9 -0
  49. package/dist/layers/PositionEmbedding.js +45 -0
  50. package/dist/layers/RMSNorm.d.ts +3 -2
  51. package/dist/layers/RMSNorm.js +6 -6
  52. package/dist/layers/RoPECache.d.ts +1 -1
  53. package/dist/layers/RoPECache.js +4 -4
  54. package/dist/layers/TiedEmbedding.d.ts +3 -2
  55. package/dist/layers/TiedEmbedding.js +29 -7
  56. package/dist/layers/TransformerBlock.d.ts +3 -2
  57. package/dist/layers/TransformerBlock.js +1 -1
  58. package/dist/loader/load.d.ts +2 -2
  59. package/dist/loader/loadHF.d.ts +2 -2
  60. package/dist/loader/loadTransformers.d.ts +4 -2
  61. package/dist/loader/loadTransformers.js +10 -9
  62. package/dist/loader/newZipLoad.d.ts +2 -2
  63. package/dist/loader/oldZipLoad.d.ts +2 -2
  64. package/dist/loader/oldZipLoad.js +44 -51
  65. package/dist/loader/save.d.ts +8 -0
  66. package/dist/loader/save.js +62 -0
  67. package/dist/{log_sum_exp-DbjkV734.js → log_sum_exp-C8yFJfZz.js} +45 -24
  68. package/dist/main.d.ts +6 -4
  69. package/dist/main.js +24 -18
  70. package/dist/{mat_mul-8m8pfdcx.js → mat_mul-Dpy2mMRu.js} +1 -1
  71. package/dist/mod-CbibJi3D.js +27 -0
  72. package/dist/models/NanoGPTV1.d.ts +15 -0
  73. package/dist/models/NanoGPTV1.js +71 -0
  74. package/dist/{config.d.ts → models/config.d.ts} +1 -0
  75. package/dist/{config.js → models/config.js} +1 -0
  76. package/dist/models/factory.d.ts +3 -0
  77. package/dist/models/factory.js +14 -0
  78. package/dist/models/model.d.ts +26 -0
  79. package/dist/models/model.js +70 -0
  80. package/dist/{mulmat_packed_gpu-VSekgsNv.js → mulmat_packed_gpu-q_Gmwyld.js} +1 -1
  81. package/dist/{ones-Dj0SDhHf.js → ones-BAqVh-eA.js} +2 -2
  82. package/dist/ops/adamAdjust.js +1 -1
  83. package/dist/ops/adamMoments.js +1 -1
  84. package/dist/ops/appendCache.js +3 -3
  85. package/dist/ops/attentionMask.js +1 -1
  86. package/dist/ops/cpu/adamAdjust.js +9 -9
  87. package/dist/ops/cpu/adamMoments.js +2 -2
  88. package/dist/ops/cpu/appendCache.js +2 -2
  89. package/dist/ops/cpu/attentionMask.js +5 -5
  90. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  91. package/dist/ops/cpu/gatherSub.js +5 -5
  92. package/dist/ops/cpu/gelu.js +1 -1
  93. package/dist/ops/cpu/matMulGelu.js +2 -2
  94. package/dist/ops/cpu/matMulMul.js +1 -1
  95. package/dist/ops/cpu/mulDropout.js +1 -1
  96. package/dist/ops/cpu/normRMS.js +1 -1
  97. package/dist/ops/cpu/qkv.js +3 -3
  98. package/dist/ops/cpu/rope.js +5 -5
  99. package/dist/ops/cpu/scatterSub.js +7 -7
  100. package/dist/ops/fusedSoftmax.js +1 -1
  101. package/dist/ops/gatherSub.js +1 -1
  102. package/dist/ops/gelu.js +2 -2
  103. package/dist/ops/grads/attentionMask.js +1 -1
  104. package/dist/ops/grads/fusedSoftmax.js +2 -2
  105. package/dist/ops/grads/gelu.js +2 -2
  106. package/dist/ops/grads/matMulGelu.js +1 -1
  107. package/dist/ops/grads/normRMS.js +1 -1
  108. package/dist/ops/grads/qkv.js +1 -1
  109. package/dist/ops/grads/rope.js +1 -1
  110. package/dist/ops/matMulGelu.js +1 -1
  111. package/dist/ops/matMulMul.js +1 -1
  112. package/dist/ops/mulDrop.js +1 -1
  113. package/dist/ops/normRMS.js +1 -1
  114. package/dist/ops/qkv.js +1 -1
  115. package/dist/ops/rope.js +4 -4
  116. package/dist/ops/scatterSub.js +1 -1
  117. package/dist/ops/webgl/adamAdjust.js +2 -2
  118. package/dist/ops/webgl/adamMoments.js +1 -1
  119. package/dist/ops/webgl/appendCache.js +1 -1
  120. package/dist/ops/webgl/attentionMask.js +1 -1
  121. package/dist/ops/webgl/fusedSoftmax.js +4 -4
  122. package/dist/ops/webgl/gatherSub.js +1 -1
  123. package/dist/ops/webgl/gelu.js +2 -2
  124. package/dist/ops/webgl/log.js +3 -3
  125. package/dist/ops/webgl/matMulGelu.js +10 -10
  126. package/dist/ops/webgl/matMulMul.js +1 -1
  127. package/dist/ops/webgl/mulDropout.js +1 -1
  128. package/dist/ops/webgl/normRMS.js +2 -2
  129. package/dist/ops/webgl/qkv.js +1 -1
  130. package/dist/ops/webgl/rope.js +1 -1
  131. package/dist/ops/webgl/scatterSub.js +1 -1
  132. package/dist/ops/webgpu/adamAdjust.js +3 -3
  133. package/dist/ops/webgpu/adamMoments.js +3 -3
  134. package/dist/ops/webgpu/appendCache.js +3 -3
  135. package/dist/ops/webgpu/attentionMask.js +3 -3
  136. package/dist/ops/webgpu/gatherSub.js +3 -3
  137. package/dist/ops/webgpu/gelu.js +3 -3
  138. package/dist/ops/webgpu/normRMS.js +2 -2
  139. package/dist/ops/webgpu/normRMSGrad.js +5 -5
  140. package/dist/ops/webgpu/qkv.js +3 -3
  141. package/dist/ops/webgpu/rope.js +3 -3
  142. package/dist/ops/webgpu/scatterSub.js +3 -3
  143. package/dist/ops/webgpu/utils/reductions.js +4 -4
  144. package/dist/ops-542ai2vG.js +1525 -0
  145. package/dist/{random_width-sZORGo5k.js → random_width-DKGeiFuR.js} +1471 -1538
  146. package/dist/{range-CRuAh-gd.js → range-BcUvLuf5.js} +1 -1
  147. package/dist/{reciprocal-BvGAyKyu.js → reciprocal-DhDWSKiD.js} +1 -1
  148. package/dist/{register_all_kernels-BwDSRN-f.js → register_all_kernels-Do9VvZmo.js} +2488 -2534
  149. package/dist/{max-Ddnnb5xe.js → relu-B1AXs7p5.js} +6 -6
  150. package/dist/{reshape-CdBq1WJ6.js → reshape-WeJkT3ja.js} +1 -1
  151. package/dist/{scatter_nd_util-DUstGbU1.js → scatter_nd_util-B7yDhiQr.js} +1 -1
  152. package/dist/{selu_util-BJEXVvjX.js → selu_util-BgUO9gHY.js} +125 -146
  153. package/dist/{shared-wS99K7_n.js → shared-CZiWmQCI.js} +1 -1
  154. package/dist/{shared-B8ztnyEk.js → shared-V6D_md-c.js} +72 -72
  155. package/dist/{sin-BeA3tsEd.js → sin-CPxad7Am.js} +1 -1
  156. package/dist/{slice-BiOsknYS.js → slice-B7jXtPnp.js} +1 -1
  157. package/dist/{softmax-Bv_6lyMX.js → softmax-BfsyI4As.js} +1 -1
  158. package/dist/{split-B-dikLRw.js → split-BPxr8_8m.js} +1 -1
  159. package/dist/{stack-B17UN2nn.js → stack-BNwLzE43.js} +1 -1
  160. package/dist/{sum-66ew2byf.js → sum-ByFINZgi.js} +3 -3
  161. package/dist/{tensor-JwS7ZYY6.js → tensor-DbqgIV9B.js} +1 -1
  162. package/dist/tensor1d-CtJq5BOv.js +27 -0
  163. package/dist/{tensor2d-wxPAnDQy.js → tensor2d-CObBWBkW.js} +1 -1
  164. package/dist/tensor3d-BOukqWwr.js +30 -0
  165. package/dist/tensor4d-DLtk7Nxh.js +30 -0
  166. package/dist/training/Adam.js +2 -2
  167. package/dist/training/AdamExt.js +1 -1
  168. package/dist/training/DatasetBuilder.js +2 -2
  169. package/dist/training/Evaluator.d.ts +2 -2
  170. package/dist/training/FullTrainer.d.ts +3 -3
  171. package/dist/training/FullTrainer.js +61 -69
  172. package/dist/training/Trainer.d.ts +15 -3
  173. package/dist/training/Trainer.js +39 -47
  174. package/dist/training/sparseCrossEntropy.js +12 -13
  175. package/dist/utilities/arrayClose.d.ts +1 -1
  176. package/dist/utilities/arrayClose.js +16 -7
  177. package/dist/utilities/dummy.d.ts +4 -4
  178. package/dist/utilities/dummy.js +13 -13
  179. package/dist/utilities/multinomialCPU.js +2 -2
  180. package/dist/utilities/parameters.d.ts +1 -1
  181. package/dist/utilities/performance.js +1 -1
  182. package/dist/utilities/profile.js +1 -1
  183. package/dist/utilities/safetensors.js +2 -2
  184. package/dist/utilities/weights.js +2 -2
  185. package/dist/{variable-BuddVFLa.js → variable-DPFOJyRG.js} +1 -1
  186. package/dist/{webgpu_program-PFzf1hAQ.js → webgpu_program-Dhk9R5aG.js} +1 -1
  187. package/dist/{webgpu_util-D____QpY.js → webgpu_util-BqGnZg8t.js} +27 -27
  188. package/dist/{zeros--BdLQ3oG.js → zeros-Dnwix0p4.js} +1 -1
  189. package/package.json +2 -3
  190. package/dist/NanoGPTModel.d.ts +0 -52
  191. package/dist/NanoGPTModel.js +0 -203
  192. package/dist/TiedEmbedding-BxOerUmB.js +0 -43
  193. package/dist/ops-BFGCx8Ri.js +0 -1202
  194. package/dist/utilities/generate.d.ts +0 -3
  195. package/dist/utilities/generate.js +0 -22
  196. package/dist/utilities/save.d.ts +0 -9
  197. package/dist/utilities/save.js +0 -61
@@ -1,6 +1,6 @@
1
- import { k as B, j as G, am as K, a6 as W, an as z, ao as V, ac as N, ap as F, u as S } from "./index-BoWRt-10.js";
2
- import { u as O, f as Y } from "./gpgpu_math-DGNLNL4I.js";
3
- import { f as v } from "./backend_util-TE7aTPhZ.js";
1
+ import { l as B, j as G, az as K, aa as z, at as W, aA as V, ag as N, au as F, u as S } from "./index-DdmHGZjq.js";
2
+ import { u as O, f as Y } from "./gpgpu_math-D_ODOLix.js";
3
+ import { f as v } from "./backend_util-yC3YH1jo.js";
4
4
  /**
5
5
  * @license
6
6
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -295,7 +295,7 @@ function L(t) {
295
295
  return o.complexTensorInfos = { real: i, imag: a }, n;
296
296
  }
297
297
  const me = {
298
- kernelName: W,
298
+ kernelName: z,
299
299
  backendName: "webgl",
300
300
  kernelFunc: L
301
301
  };
@@ -315,16 +315,16 @@ const me = {
315
315
  * limitations under the License.
316
316
  * =============================================================================
317
317
  */
318
- const w = "return (a < 0.) ? b * a : a;", k = `
318
+ const w = "return (a < 0.) ? b * a : a;", R = `
319
319
  vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
320
320
  return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
321
321
  `;
322
322
  function oe(t) {
323
- const { inputs: e, backend: s, attrs: r } = t, { x: u } = e, { alpha: n } = r, o = s.makeTensorInfo([], "float32", V(n, "float32")), i = N().getBool("WEBGL_PACK_BINARY_OPERATIONS") ? new E(k, u.shape, o.shape) : new b(w, u.shape, o.shape), a = s.runWebGLProgram(i, [u, o], "float32");
323
+ const { inputs: e, backend: s, attrs: r } = t, { x: u } = e, { alpha: n } = r, o = s.makeTensorInfo([], "float32", V(n, "float32")), i = N().getBool("WEBGL_PACK_BINARY_OPERATIONS") ? new E(R, u.shape, o.shape) : new b(w, u.shape, o.shape), a = s.runWebGLProgram(i, [u, o], "float32");
324
324
  return s.disposeIntermediateTensorInfo(o), a;
325
325
  }
326
326
  const be = {
327
- kernelName: z,
327
+ kernelName: W,
328
328
  backendName: "webgl",
329
329
  kernelFunc: oe
330
330
  };
@@ -344,12 +344,12 @@ const be = {
344
344
  * limitations under the License.
345
345
  * =============================================================================
346
346
  */
347
- const R = "return (a < 0.) ? b * a : a;", U = `
347
+ const k = "return (a < 0.) ? b * a : a;", U = `
348
348
  vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
349
349
  return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
350
350
  `;
351
351
  function ue(t) {
352
- const { inputs: e, backend: s } = t, { x: r, alpha: u } = e, n = N().getBool("WEBGL_PACK_BINARY_OPERATIONS") ? new E(U, r.shape, u.shape) : new b(R, r.shape, u.shape);
352
+ const { inputs: e, backend: s } = t, { x: r, alpha: u } = e, n = N().getBool("WEBGL_PACK_BINARY_OPERATIONS") ? new E(U, r.shape, u.shape) : new b(k, r.shape, u.shape);
353
353
  return s.runWebGLProgram(n, [r, u], "float32");
354
354
  }
355
355
  const Ne = {
@@ -386,7 +386,7 @@ function ye({ opSnippet: t, packedOpSnippet: e, cpuKernelImpl: s, dtype: r }) {
386
386
  return c ? l = new ne(o.shape, e) : l = new q(o.shape, t), i.runWebGLProgram(l, [o], a);
387
387
  };
388
388
  }
389
- function Ie({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: s = !1, supportsComplex: r = !1, cpuKernelImpl: u, dtype: n }) {
389
+ function Ae({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: s = !1, supportsComplex: r = !1, cpuKernelImpl: u, dtype: n }) {
390
390
  return ({ inputs: o, backend: i }) => {
391
391
  const { a, b: c } = o, l = i;
392
392
  if (r && a.dtype === "complex64") {
@@ -404,8 +404,8 @@ function Ie({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: s = !1, suppor
404
404
  shape: c.shape
405
405
  }, D = new b(t, a.shape, c.shape);
406
406
  return l.runWebGLProgram(D, [$, _], S(p.dtype, x.dtype));
407
- }), A = L({ inputs: { real: g, imag: m }, backend: l });
408
- return l.disposeIntermediateTensorInfo(g), l.disposeIntermediateTensorInfo(m), A;
407
+ }), I = L({ inputs: { real: g, imag: m }, backend: l });
408
+ return l.disposeIntermediateTensorInfo(g), l.disposeIntermediateTensorInfo(m), I;
409
409
  }
410
410
  const d = n || S(a.dtype, c.dtype);
411
411
  if ((a.dtype === "string" || c.dtype === "string" || l.shouldExecuteOnCPU([a, c])) && u != null) {
@@ -415,15 +415,15 @@ function Ie({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: s = !1, suppor
415
415
  ) : h, m = a.dtype === "string" ? (
416
416
  // tslint:disable-next-line: no-any
417
417
  v(f)
418
- ) : f, [A, C] = u(a.shape, c.shape, g, m, d), p = l.makeTensorInfo(C, d), x = l.texData.get(p.dataId);
419
- return x.values = A, p;
418
+ ) : f, [I, C] = u(a.shape, c.shape, g, m, d), p = l.makeTensorInfo(C, d), x = l.texData.get(p.dataId);
419
+ return x.values = I, p;
420
420
  }
421
421
  const y = N().getBool("WEBGL_PACK_BINARY_OPERATIONS") && e != null;
422
- let I;
423
- return y ? I = new E(e, a.shape, c.shape, s) : I = new b(t, a.shape, c.shape), l.runWebGLProgram(I, [a, c], d);
422
+ let A;
423
+ return y ? A = new E(e, a.shape, c.shape, s) : A = new b(t, a.shape, c.shape), l.runWebGLProgram(A, [a, c], d);
424
424
  };
425
425
  }
426
- function Ae(t, e = !1) {
426
+ function Ie(t, e = !1) {
427
427
  if (t === "linear")
428
428
  return e ? ee : j;
429
429
  if (t === "relu")
@@ -433,9 +433,9 @@ function Ae(t, e = !1) {
433
433
  if (t === "relu6")
434
434
  return e ? ae : Q;
435
435
  if (t === "prelu")
436
- return e ? U : R;
436
+ return e ? U : k;
437
437
  if (t === "leakyrelu")
438
- return e ? k : w;
438
+ return e ? R : w;
439
439
  if (t === "sigmoid")
440
440
  return e ? re : X;
441
441
  throw new Error(`Activation ${t} has not been implemented for the WebGL backend.`);
@@ -446,7 +446,7 @@ export {
446
446
  T as C,
447
447
  ne as U,
448
448
  Z as a,
449
- Ie as b,
449
+ Ae as b,
450
450
  pe as c,
451
451
  he as d,
452
452
  q as e,
@@ -457,7 +457,7 @@ export {
457
457
  fe as j,
458
458
  xe as k,
459
459
  Oe as l,
460
- Ae as m,
460
+ Ie as m,
461
461
  me as n,
462
462
  ge as o,
463
463
  be as p,
@@ -1,27 +1,22 @@
1
- import { GPTConfig } from '../config';
1
+ import { GPTConfig } from '../models/config';
2
2
  import { default as MemoryProfiler } from '../utilities/profile';
3
3
  import { default as RoPECache } from './RoPECache';
4
4
  import { Tensor, Variable } from '@tensorflow/tfjs-core';
5
- export interface LayerConfig {
6
- checkpointing?: boolean;
7
- profiler?: MemoryProfiler;
8
- ropeCache?: RoPECache;
9
- }
10
- export interface GPTLayerConfig {
11
- gpt: GPTConfig;
12
- layerConfig: LayerConfig;
13
- }
14
5
  export interface ForwardAttributes {
15
6
  training: boolean;
7
+ checkpointing?: boolean;
8
+ ropeCache?: RoPECache;
16
9
  }
17
10
  export default abstract class BaseLayer<ATTR extends ForwardAttributes = ForwardAttributes> {
18
11
  readonly parent?: BaseLayer;
19
- readonly config: GPTLayerConfig;
12
+ readonly config: GPTConfig;
20
13
  private _variables;
21
14
  private _trainable;
22
15
  readonly children: BaseLayer[];
23
- constructor(config: GPTLayerConfig, parent?: BaseLayer);
16
+ private profiler?;
17
+ constructor(config: GPTConfig, parent?: BaseLayer);
24
18
  getProfiler(): MemoryProfiler | undefined;
19
+ setProfiler(profiler: MemoryProfiler | null): void;
25
20
  startMemory(): void;
26
21
  endMemory(label: string): void;
27
22
  addVariable(name: string, variable?: Variable): void;
@@ -29,7 +24,7 @@ export default abstract class BaseLayer<ATTR extends ForwardAttributes = Forward
29
24
  get trainableVariables(): Variable[];
30
25
  get trainable(): boolean;
31
26
  set trainable(value: boolean);
32
- getVariable(name: string): Variable;
27
+ getVariable(name: string, recursive?: boolean): Variable;
33
28
  hasVariable(name: string): boolean;
34
29
  setVariable(name: string, variable: Variable): void;
35
30
  saveWeights(map: Map<string, Tensor[]>): void;
@@ -1,22 +1,28 @@
1
- import { T as g, y as p, e as o, A as v } from "../index-BoWRt-10.js";
2
- import { v as _ } from "../variable-BuddVFLa.js";
3
- class M {
1
+ import { T as p, J as g, e as o, K as v } from "../index-DdmHGZjq.js";
2
+ import { v as _ } from "../variable-DPFOJyRG.js";
3
+ class T {
4
4
  parent;
5
5
  config;
6
6
  _variables = /* @__PURE__ */ new Map();
7
7
  _trainable = !0;
8
8
  children = [];
9
+ profiler;
9
10
  constructor(t, r) {
10
11
  this.config = t, this.parent = r, this.parent && this.parent.children.push(this);
11
12
  }
12
13
  getProfiler() {
13
- return this.config.layerConfig.profiler;
14
+ return this.profiler;
15
+ }
16
+ setProfiler(t) {
17
+ this.profiler = t || void 0, this.children.forEach((r) => {
18
+ r.setProfiler(t);
19
+ });
14
20
  }
15
21
  startMemory() {
16
- this.config.layerConfig.profiler?.startMemory();
22
+ this.profiler?.startMemory();
17
23
  }
18
24
  endMemory(t) {
19
- this.config.layerConfig.profiler?.endMemory(t);
25
+ this.profiler?.endMemory(t);
20
26
  }
21
27
  addVariable(t, r) {
22
28
  this._variables.set(t, r || null);
@@ -41,11 +47,17 @@ class M {
41
47
  r.trainable = t;
42
48
  });
43
49
  }
44
- getVariable(t) {
45
- const r = this._variables.get(t);
46
- if (!r)
50
+ getVariable(t, r = !1) {
51
+ const e = this._variables.get(t);
52
+ if (!e && r)
53
+ for (const i of this.children) {
54
+ const s = i.getVariable(t, !0);
55
+ if (s)
56
+ return s;
57
+ }
58
+ if (!e)
47
59
  throw new Error(`Variable ${t} not found`);
48
- return r;
60
+ return e;
49
61
  }
50
62
  hasVariable(t) {
51
63
  return this._variables.get(t) !== null;
@@ -85,7 +97,7 @@ class M {
85
97
  call(t, ...r) {
86
98
  this.build();
87
99
  const e = this.forward(t, ...r);
88
- if (t.training && e instanceof g) {
100
+ if (t.training && e instanceof p) {
89
101
  const i = this.dropout(e);
90
102
  return i !== e && e.dispose(), i;
91
103
  } else
@@ -95,7 +107,7 @@ class M {
95
107
  return this.build(), this.checkpointingFn(t, ...r);
96
108
  }
97
109
  checkpointingFn(t, ...r) {
98
- const e = this.trainableVariables, s = p((...a) => {
110
+ const e = this.trainableVariables, s = g((...a) => {
99
111
  const l = a[a.length - 1], n = a.slice(0, r.length), h = this.forward(t, ...n);
100
112
  return l(n), { value: h, gradFunc: (c, f) => {
101
113
  const u = o().state.activeTape;
@@ -112,5 +124,5 @@ class M {
112
124
  }
113
125
  }
114
126
  export {
115
- M as default
127
+ T as default
116
128
  };
@@ -1,5 +1,6 @@
1
- import { default as BaseLayer, ForwardAttributes, GPTLayerConfig } from './BaseLayer';
1
+ import { default as BaseLayer, ForwardAttributes } from './BaseLayer';
2
2
  import { Tensor } from '@tensorflow/tfjs-core';
3
+ import { GPTConfig } from '../models/config';
3
4
  export type KVCache = {
4
5
  k?: Tensor;
5
6
  v?: Tensor;
@@ -22,7 +23,7 @@ export default class CausalSelfAttention extends BaseLayer<AttentionForwardAttri
22
23
  private projUnits;
23
24
  private ATTN;
24
25
  private PROJ;
25
- constructor(index: number, config: GPTLayerConfig, parent?: BaseLayer);
26
+ constructor(index: number, config: GPTConfig, parent?: BaseLayer);
26
27
  protected build(): void;
27
28
  private getAttentionScores;
28
29
  private getAttentionScoresWithPast;
@@ -3,14 +3,14 @@ import O from "./BaseLayer.js";
3
3
  import { qkv as P } from "../ops/qkv.js";
4
4
  import { rope as v } from "../ops/rope.js";
5
5
  import { appendCache as V } from "../ops/appendCache.js";
6
- import { w as c, t as C } from "../index-BoWRt-10.js";
6
+ import { k as c, t as C } from "../index-DdmHGZjq.js";
7
7
  import { fusedSoftmax as T } from "../ops/fusedSoftmax.js";
8
- import { d as y } from "../random_width-sZORGo5k.js";
9
- import { v as b } from "../variable-BuddVFLa.js";
10
- import { r as k, d as L } from "../dropout-DYs5QFGQ.js";
11
- import { r as N } from "../reshape-CdBq1WJ6.js";
12
- import { m as R } from "../mat_mul-8m8pfdcx.js";
13
- class W extends O {
8
+ import { d as L } from "../random_width-DKGeiFuR.js";
9
+ import { v as b } from "../variable-DPFOJyRG.js";
10
+ import { r as k, d as y } from "../dropout-CcKSfOYE.js";
11
+ import { r as N } from "../reshape-WeJkT3ja.js";
12
+ import { m as R } from "../mat_mul-Dpy2mMRu.js";
13
+ class $ extends O {
14
14
  divisor;
15
15
  index;
16
16
  units;
@@ -18,27 +18,27 @@ class W extends O {
18
18
  ATTN;
19
19
  PROJ;
20
20
  constructor(t, i, s) {
21
- super(i, s), this.index = t, this.units = i.gpt.nEmbed * 3, this.projUnits = i.gpt.nEmbed, this.ATTN = `block_${this.index}_cAttn`, this.PROJ = `block_${this.index}_cProj`, this.addVariable(this.ATTN), this.addVariable(this.PROJ), this.divisor = 1 / Math.sqrt(i.gpt.nEmbed / i.gpt.nHead);
21
+ super(i, s), this.index = t, this.units = i.nEmbed * 3, this.projUnits = i.nEmbed, this.ATTN = `block_${this.index}_cAttn`, this.PROJ = `block_${this.index}_cProj`, this.addVariable(this.ATTN), this.addVariable(this.PROJ), this.divisor = 1 / Math.sqrt(i.nEmbed / i.nHead);
22
22
  }
23
23
  build() {
24
24
  this.hasVariable(this.ATTN) === !1 && this.setVariable(
25
25
  this.ATTN,
26
26
  b(
27
- k([this.config.gpt.nEmbed, this.units], 0, 0.02),
27
+ k([this.config.nEmbed, this.units], 0, 0.02),
28
28
  !0
29
29
  //`block_${this.index}_attn_cAttn_kernel`
30
30
  )
31
31
  ), this.hasVariable(this.PROJ) === !1 && this.setVariable(
32
32
  this.PROJ,
33
33
  b(
34
- k([this.projUnits, this.config.gpt.nEmbed], 0, 0.02),
34
+ k([this.projUnits, this.config.nEmbed], 0, 0.02),
35
35
  !0
36
36
  //`block_${this.index}_attn_cProj_kernel`
37
37
  )
38
38
  );
39
39
  }
40
40
  getAttentionScores(t, i, s, o) {
41
- const e = g(t, i, this.divisor), n = T(e, s ? this.config.gpt.dropout : 0, o);
41
+ const e = g(t, i, this.divisor), n = T(e, s ? this.config.dropout : 0, o);
42
42
  return e.dispose(), n;
43
43
  }
44
44
  // Attention with optional past. If pastLen > 0 and T_cur == 1, no mask needed.
@@ -47,50 +47,50 @@ class W extends O {
47
47
  return o.dispose(), e;
48
48
  }
49
49
  getQKV(t) {
50
- return P(t, this.getVariable(this.ATTN), this.config.gpt.nHead);
50
+ return P(t, this.getVariable(this.ATTN), this.config.nHead);
51
51
  }
52
52
  getOutputProjection(t) {
53
- const i = t.shape[0], s = t.shape[2], o = this.config.gpt.nEmbed, e = t.transpose([0, 2, 1, 3]), n = N(e, [i, s, o]), p = y(n, this.getVariable(this.PROJ));
54
- return n.dispose(), e.dispose(), p;
53
+ const i = t.shape[0], s = t.shape[2], o = this.config.nEmbed, e = t.transpose([0, 2, 1, 3]), n = N(e, [i, s, o]), r = L(n, this.getVariable(this.PROJ));
54
+ return n.dispose(), e.dispose(), r;
55
55
  }
56
56
  updateCache(t, i, s) {
57
- const o = this.config.gpt.blockSize, e = t.shape[2], n = s.length || 0, p = V(t, o, n, s.k);
57
+ const o = this.config.blockSize, e = t.shape[2], n = s.length || 0, r = V(t, o, n, s.k);
58
58
  t.dispose(), s.k && s.k.dispose();
59
- const a = V(i, o, n, s.v);
59
+ const p = V(i, o, n, s.v);
60
60
  i.dispose(), s.v && s.v.dispose();
61
61
  const d = Math.min(n + e, o), h = s.cumulativeLength + e;
62
- s.length = d, s.cumulativeLength = h, s.k = c(p), s.v = c(a);
62
+ s.length = d, s.cumulativeLength = h, s.k = c(r), s.v = c(p);
63
63
  }
64
64
  forward(t, i) {
65
65
  return C(() => {
66
66
  this.startMemory();
67
- const [s, o, e] = this.getQKV(i), n = t.pastKV ? t.pastKV.cumulativeLength : 0, p = this.config.layerConfig.ropeCache, a = p ? v(s, p, n) : s, d = p ? v(o, p, n) : o;
68
- p && (s.dispose(), o.dispose());
67
+ const [s, o, e] = this.getQKV(i), n = t.pastKV ? t.pastKV.cumulativeLength : 0, r = t.ropeCache, p = r ? v(s, r, n) : s, d = r ? v(o, r, n) : o;
68
+ r && (s.dispose(), o.dispose());
69
69
  const h = t.pastKV ? t.pastKV.length : 0;
70
70
  t.pastKV && !t.training && this.updateCache(d, e, t.pastKV);
71
71
  const u = t.pastKV?.k ? t.pastKV.k : d, m = t.pastKV?.v ? t.pastKV.v : e;
72
- let r;
73
- h > 0 ? r = this.getAttentionScoresWithPast(a, u, h) : r = this.getAttentionScores(a, u, t.training, t.seed || 0), a.dispose(), t.pastKV || u.dispose();
74
- const l = R(r, m), f = t.attentionScores !== void 0 && t.attentionScores.attentionOut !== void 0;
75
- f || r.dispose(), t.pastKV || m.dispose();
72
+ let a;
73
+ h > 0 ? a = this.getAttentionScoresWithPast(p, u, h) : a = this.getAttentionScores(p, u, t.training, t.seed || 0), p.dispose(), t.pastKV || u.dispose();
74
+ const l = R(a, m), f = t.attentionScores !== void 0 && t.attentionScores.attentionOut !== void 0;
75
+ f || a.dispose(), t.pastKV || m.dispose();
76
76
  const A = this.getOutputProjection(l);
77
77
  if (l.dispose(), f && t.attentionScores && t.attentionScores.attentionOut !== void 0) {
78
- const K = r.shape[1], S = r.shape[2];
78
+ const K = a.shape[1], S = a.shape[2];
79
79
  t.attentionScores.attentionOut?.push(
80
- c(r.slice([0, 0, 0, 0], [1, -1, -1, -1]).reshape([K, S, -1]))
80
+ c(a.slice([0, 0, 0, 0], [1, -1, -1, -1]).reshape([K, S, -1]))
81
81
  );
82
82
  }
83
83
  return this.endMemory("CausalSelfAttention"), A;
84
84
  });
85
85
  }
86
86
  dropout(t) {
87
- if (this.config.gpt.dropout > 0) {
88
- const i = L(t, this.config.gpt.dropout);
87
+ if (this.config.dropout > 0) {
88
+ const i = y(t, this.config.dropout);
89
89
  return t.dispose(), i;
90
90
  } else
91
91
  return t;
92
92
  }
93
93
  }
94
94
  export {
95
- W as default
95
+ $ as default
96
96
  };
@@ -1,11 +1,12 @@
1
1
  import { Tensor } from '@tensorflow/tfjs-core';
2
- import { default as BaseLayer, ForwardAttributes, GPTLayerConfig } from './BaseLayer';
2
+ import { default as BaseLayer, ForwardAttributes } from './BaseLayer';
3
+ import { GPTConfig } from '../main';
3
4
  export default class MLP extends BaseLayer {
4
5
  private index;
5
6
  private hiddenUnits;
6
7
  private MLPHIDDEN;
7
8
  private MLPOUT;
8
- constructor(index: number, config: GPTLayerConfig, parent?: BaseLayer);
9
+ constructor(index: number, config: GPTConfig, parent?: BaseLayer);
9
10
  protected build(): void;
10
11
  forward(_: ForwardAttributes, x: Tensor): Tensor;
11
12
  protected dropout(x: Tensor): Tensor;
@@ -1,56 +1,52 @@
1
- import { t as l } from "../index-BoWRt-10.js";
1
+ import { t as p } from "../index-DdmHGZjq.js";
2
2
  import u from "./BaseLayer.js";
3
3
  import { matMulGelu as M } from "../ops/matMulGelu.js";
4
- import { v as o } from "../variable-BuddVFLa.js";
5
- import { r as h, d as f } from "../dropout-DYs5QFGQ.js";
6
- import { r as d } from "../reshape-CdBq1WJ6.js";
7
- import { m as c } from "../mat_mul-8m8pfdcx.js";
8
- class V extends u {
4
+ import { v as o } from "../variable-DPFOJyRG.js";
5
+ import { r as h, d as f } from "../dropout-CcKSfOYE.js";
6
+ import { r as d } from "../reshape-WeJkT3ja.js";
7
+ import { m as c } from "../mat_mul-Dpy2mMRu.js";
8
+ class H extends u {
9
9
  index;
10
10
  hiddenUnits;
11
11
  MLPHIDDEN;
12
12
  MLPOUT;
13
13
  constructor(i, t, s) {
14
- super(t, s), this.index = i, this.hiddenUnits = t.gpt.mlpFactor * t.gpt.nEmbed, this.MLPHIDDEN = `block_${this.index}_mlpHidden`, this.MLPOUT = `block_${this.index}_mlpOut`, this.addVariable(this.MLPHIDDEN), this.addVariable(this.MLPOUT);
14
+ super(t, s), this.index = i, this.hiddenUnits = t.mlpFactor * t.nEmbed, this.MLPHIDDEN = `block_${this.index}_mlpHidden`, this.MLPOUT = `block_${this.index}_mlpOut`, this.addVariable(this.MLPHIDDEN), this.addVariable(this.MLPOUT);
15
15
  }
16
16
  build() {
17
17
  this.hasVariable(this.MLPHIDDEN) === !1 && this.setVariable(
18
18
  this.MLPHIDDEN,
19
19
  o(
20
- h([this.config.gpt.nEmbed, this.hiddenUnits], 0, 0.02),
20
+ h([this.config.nEmbed, this.hiddenUnits], 0, 0.02),
21
21
  !0
22
22
  //`block_${this.index}_attn_cAttn_kernel`
23
23
  )
24
24
  ), this.hasVariable(this.MLPOUT) === !1 && this.setVariable(
25
25
  this.MLPOUT,
26
26
  o(
27
- h(
28
- [this.hiddenUnits, this.config.gpt.nEmbed],
29
- 0,
30
- 0.02 / Math.sqrt(2 * this.config.gpt.nLayer)
31
- ),
27
+ h([this.hiddenUnits, this.config.nEmbed], 0, 0.02 / Math.sqrt(2 * this.config.nLayer)),
32
28
  !0
33
29
  //`block_${this.index}_attn_cProj_kernel`
34
30
  )
35
31
  );
36
32
  }
37
33
  forward(i, t) {
38
- return l(() => {
34
+ return p(() => {
39
35
  this.startMemory();
40
- const [s, r, e] = t.shape, n = d(t, [s * r, e]), a = M(n, this.getVariable(this.MLPHIDDEN)), p = c(a, this.getVariable(this.MLPOUT));
36
+ const [s, r, e] = t.shape, n = d(t, [s * r, e]), a = M(n, this.getVariable(this.MLPHIDDEN)), m = c(a, this.getVariable(this.MLPOUT));
41
37
  a.dispose();
42
- const m = d(p, [s, r, e]);
43
- return this.endMemory("MLP"), m;
38
+ const l = d(m, [s, r, e]);
39
+ return this.endMemory("MLP"), l;
44
40
  });
45
41
  }
46
42
  dropout(i) {
47
- if (this.config.gpt.dropout > 0) {
48
- const t = f(i, this.config.gpt.dropout);
43
+ if (this.config.dropout > 0) {
44
+ const t = f(i, this.config.dropout);
49
45
  return i.dispose(), t;
50
46
  }
51
47
  return i;
52
48
  }
53
49
  }
54
50
  export {
55
- V as default
51
+ H as default
56
52
  };
@@ -0,0 +1,9 @@
1
+ import { Tensor } from '@tensorflow/tfjs-core';
2
+ import { default as BaseLayer } from './BaseLayer';
3
+ import { GPTConfig, ModelForwardAttributes } from '../main';
4
+ export default class PositionEmbedding extends BaseLayer {
5
+ private wpe?;
6
+ private drop;
7
+ constructor(config: GPTConfig, name?: string, parent?: BaseLayer);
8
+ forward(attrs: ModelForwardAttributes, x: Tensor): Tensor;
9
+ }
@@ -0,0 +1,45 @@
1
+ import { t as c, a9 as u, b as i } from "../index-DdmHGZjq.js";
2
+ import f from "./BaseLayer.js";
3
+ import { E as g, D as h } from "../random_width-DKGeiFuR.js";
4
+ import { r as b } from "../exports_initializers-DKk7-bsx.js";
5
+ import { m as l } from "../mod-CbibJi3D.js";
6
+ import { r as w } from "../range-BcUvLuf5.js";
7
+ /**
8
+ * @license
9
+ * Copyright 2018 Google LLC
10
+ *
11
+ * Use of this source code is governed by an MIT-style
12
+ * license that can be found in the LICENSE file or at
13
+ * https://opensource.org/licenses/MIT.
14
+ * =============================================================================
15
+ */
16
+ function E(t) {
17
+ return new h(t);
18
+ }
19
+ function x(t) {
20
+ return new g(t);
21
+ }
22
+ class q extends f {
23
+ wpe;
24
+ // Position embeddings
25
+ drop;
26
+ // Dropout
27
+ constructor(o, n = "", r) {
28
+ super(o, r), this.wpe = x({
29
+ inputDim: this.config.blockSize,
30
+ outputDim: this.config.nEmbed,
31
+ name: n,
32
+ embeddingsInitializer: b({ mean: 0, stddev: 0.02 })
33
+ }), this.drop = E({ rate: this.config.dropout });
34
+ }
35
+ forward(o, n) {
36
+ const r = o.cache?.[0]?.length ?? 0;
37
+ return c(() => {
38
+ const [, s] = n.shape, e = this.config.blockSize, a = w(0, s, 1, "int32"), m = l(u(a, i(r, "int32")), i(e, "int32")), d = this.wpe.apply(m), p = n.add(d);
39
+ return this.drop.apply(p, { training: o.training });
40
+ });
41
+ }
42
+ }
43
+ export {
44
+ q as default
45
+ };
@@ -1,7 +1,8 @@
1
1
  import { Tensor } from '@tensorflow/tfjs-core';
2
- import { default as BaseLayer, ForwardAttributes, GPTLayerConfig } from './BaseLayer';
2
+ import { default as BaseLayer, ForwardAttributes } from './BaseLayer';
3
+ import { GPTConfig } from '../main';
3
4
  export default class RMSNorm extends BaseLayer {
4
5
  private GAMMA;
5
- constructor(config: GPTLayerConfig, name?: string, parent?: BaseLayer);
6
+ constructor(config: GPTConfig, name?: string, parent?: BaseLayer);
6
7
  forward(_: ForwardAttributes, x: Tensor): Tensor;
7
8
  }
@@ -1,12 +1,12 @@
1
- import { t as s } from "../index-BoWRt-10.js";
1
+ import { t as s } from "../index-DdmHGZjq.js";
2
2
  import e from "./BaseLayer.js";
3
3
  import { normRMS as a } from "../ops/normRMS.js";
4
- import { v as i } from "../variable-BuddVFLa.js";
5
- import { o as m } from "../ones-Dj0SDhHf.js";
6
- class f extends e {
4
+ import { v as i } from "../variable-DPFOJyRG.js";
5
+ import { o as m } from "../ones-BAqVh-eA.js";
6
+ class l extends e {
7
7
  GAMMA;
8
8
  constructor(r, t = "", o) {
9
- super(r, o), this.GAMMA = t, this.addVariable(this.GAMMA, i(m([r.gpt.nEmbed]), !0, this.GAMMA, "float32"));
9
+ super(r, o), this.GAMMA = t, this.addVariable(this.GAMMA, i(m([r.nEmbed]), !0, this.GAMMA, "float32"));
10
10
  }
11
11
  forward(r, t) {
12
12
  return s(() => {
@@ -17,5 +17,5 @@ class f extends e {
17
17
  }
18
18
  }
19
19
  export {
20
- f as default
20
+ l as default
21
21
  };
@@ -1,5 +1,5 @@
1
1
  import { Tensor } from '@tensorflow/tfjs-core';
2
- import { GPTConfig } from '../config';
2
+ import { GPTConfig } from '../models/config';
3
3
  export default class RoPECache {
4
4
  private readonly config;
5
5
  readonly rotaryDim: number;
@@ -1,7 +1,7 @@
1
- import { b as t, x as h, t as n, w as p } from "../index-BoWRt-10.js";
2
- import { r as c } from "../reciprocal-BvGAyKyu.js";
3
- import { c as f, s as m } from "../sin-BeA3tsEd.js";
4
- import { r as a } from "../range-CRuAh-gd.js";
1
+ import { b as t, x as h, t as n, k as p } from "../index-DdmHGZjq.js";
2
+ import { r as c } from "../reciprocal-DhDWSKiD.js";
3
+ import { c as f, s as m } from "../sin-CPxad7Am.js";
4
+ import { r as a } from "../range-BcUvLuf5.js";
5
5
  class D {
6
6
  constructor(o) {
7
7
  this.config = o;
@@ -1,11 +1,12 @@
1
1
  import { Tensor } from '@tensorflow/tfjs-core';
2
- import { default as BaseLayer, ForwardAttributes, GPTLayerConfig } from './BaseLayer';
2
+ import { default as BaseLayer, ForwardAttributes } from './BaseLayer';
3
+ import { GPTConfig } from '../models/config';
3
4
  export default class TiedEmbeddingOutputLayer extends BaseLayer {
4
5
  private vocabSize;
5
6
  private embedDim;
6
7
  private initializer;
7
8
  private WEIGHTS;
8
- constructor(config: GPTLayerConfig, name: string, parent?: BaseLayer);
9
+ constructor(config: GPTConfig, name: string, parent?: BaseLayer);
9
10
  embed(inputs: Tensor): Tensor;
10
11
  project(inputs: Tensor): Tensor;
11
12
  forward(_: ForwardAttributes, x: Tensor): Tensor;
@@ -1,9 +1,31 @@
1
- import "../random_width-sZORGo5k.js";
2
- import "../index-BoWRt-10.js";
3
- import { T as e } from "../TiedEmbedding-BxOerUmB.js";
4
- import "./BaseLayer.js";
5
- import "../variable-BuddVFLa.js";
6
- import "../gather-CMMy2KEG.js";
1
+ import { d as r } from "../random_width-DKGeiFuR.js";
2
+ import "../index-DdmHGZjq.js";
3
+ import { r as a } from "../exports_initializers-DKk7-bsx.js";
4
+ import s from "./BaseLayer.js";
5
+ import { v as m } from "../variable-DPFOJyRG.js";
6
+ import { g as o } from "../gather-CPg6ZlQA.js";
7
+ class S extends s {
8
+ vocabSize;
9
+ embedDim;
10
+ initializer;
11
+ WEIGHTS;
12
+ constructor(i, e, t) {
13
+ super(i, t), this.WEIGHTS = e, this.vocabSize = i.vocabSize, this.embedDim = i.nEmbed, this.initializer = a({
14
+ mean: 0,
15
+ stddev: 0.02
16
+ }), this.addVariable(this.WEIGHTS, m(this.initializer.apply([this.vocabSize, this.embedDim]), !0));
17
+ }
18
+ embed(i) {
19
+ return o(this.getVariable(this.WEIGHTS), i, 0);
20
+ }
21
+ project(i) {
22
+ return r(i, this.getVariable(this.WEIGHTS).transpose());
23
+ }
24
+ // Dummy, should not be used.
25
+ forward(i, e) {
26
+ return this.project(e);
27
+ }
28
+ }
7
29
  export {
8
- e as default
30
+ S as default
9
31
  };