@genai-fi/nanogpt 0.7.3 → 0.8.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.d.ts +25 -2
- package/dist/Generator.js +152 -49
- package/dist/{RealDiv-Dy0p8Bvo.js → RealDiv-D_q39E3A.js} +13 -13
- package/dist/{Reshape-DvudQDvJ.js → Reshape-41YpQqEo.js} +1 -1
- package/dist/{Reshape-DH5srBP0.js → Reshape-Bh_jzKzV.js} +5 -5
- package/dist/TeachableLLM.d.ts +6 -6
- package/dist/TeachableLLM.js +33 -31
- package/dist/Trainer.d.ts +13 -2
- package/dist/Trainer.js +21 -12
- package/dist/{axis_util-BzbKo31C.js → axis_util-Did9235A.js} +3 -3
- package/dist/backend.js +2 -2
- package/dist/{backend_util-TE7aTPhZ.js → backend_util-yC3YH1jo.js} +58 -58
- package/dist/{broadcast_to-CdbwV-Dj.js → broadcast_to-CUvOdOT5.js} +2 -2
- package/dist/checks/appendCache.d.ts +1 -0
- package/dist/checks/appendCache.js +22 -0
- package/dist/checks/attentionMask.d.ts +1 -0
- package/dist/checks/attentionMask.js +37 -0
- package/dist/checks/check.d.ts +9 -0
- package/dist/checks/check.js +20 -0
- package/dist/checks/gelu.d.ts +1 -0
- package/dist/checks/gelu.js +18 -0
- package/dist/checks/index.d.ts +19 -0
- package/dist/checks/index.js +21 -0
- package/dist/checks/normRMS.d.ts +1 -0
- package/dist/checks/normRMS.js +16 -0
- package/dist/checks/normRMSGrad.d.ts +1 -0
- package/dist/checks/normRMSGrad.js +12 -0
- package/dist/checks/qkv.d.ts +1 -0
- package/dist/checks/qkv.js +25 -0
- package/dist/checks/rope.d.ts +1 -0
- package/dist/checks/rope.js +21 -0
- package/dist/{concat-CsxrgovM.js → concat-pHiVqR3L.js} +1 -1
- package/dist/{dataset-CtdBYwjo.js → dataset-DPPl-iLT.js} +9 -9
- package/dist/{dropout-DYs5QFGQ.js → dropout-CcKSfOYE.js} +18 -18
- package/dist/exports_initializers-DKk7-bsx.js +16 -0
- package/dist/{gather-CMMy2KEG.js → gather-CPg6ZlQA.js} +1 -1
- package/dist/{gelu-C-dPj6Ku.js → gelu-BkcmEEyD.js} +1 -1
- package/dist/{gpgpu_math-DGNLNL4I.js → gpgpu_math-D_ODOLix.js} +26 -26
- package/dist/{index-BoWRt-10.js → index-DdmHGZjq.js} +659 -650
- package/dist/{index-CLthM0TO.js → index-evZ57wr4.js} +185 -185
- package/dist/{kernel_funcs_utils-BYKWV8Aa.js → kernel_funcs_utils-CDfFpUab.js} +21 -21
- package/dist/layers/BaseLayer.d.ts +8 -13
- package/dist/layers/BaseLayer.js +25 -13
- package/dist/layers/CausalSelfAttention.d.ts +3 -2
- package/dist/layers/CausalSelfAttention.js +28 -28
- package/dist/layers/MLP.d.ts +3 -2
- package/dist/layers/MLP.js +16 -20
- package/dist/layers/PositionEmbedding.d.ts +9 -0
- package/dist/layers/PositionEmbedding.js +45 -0
- package/dist/layers/RMSNorm.d.ts +3 -2
- package/dist/layers/RMSNorm.js +6 -6
- package/dist/layers/RoPECache.d.ts +1 -1
- package/dist/layers/RoPECache.js +4 -4
- package/dist/layers/TiedEmbedding.d.ts +3 -2
- package/dist/layers/TiedEmbedding.js +29 -7
- package/dist/layers/TransformerBlock.d.ts +3 -2
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/loader/load.d.ts +2 -2
- package/dist/loader/loadHF.d.ts +2 -2
- package/dist/loader/loadTransformers.d.ts +4 -2
- package/dist/loader/loadTransformers.js +10 -9
- package/dist/loader/newZipLoad.d.ts +2 -2
- package/dist/loader/oldZipLoad.d.ts +2 -2
- package/dist/loader/oldZipLoad.js +44 -51
- package/dist/loader/save.d.ts +8 -0
- package/dist/loader/save.js +62 -0
- package/dist/{log_sum_exp-DbjkV734.js → log_sum_exp-C8yFJfZz.js} +45 -24
- package/dist/main.d.ts +6 -4
- package/dist/main.js +24 -18
- package/dist/{mat_mul-8m8pfdcx.js → mat_mul-Dpy2mMRu.js} +1 -1
- package/dist/mod-CbibJi3D.js +27 -0
- package/dist/models/NanoGPTV1.d.ts +15 -0
- package/dist/models/NanoGPTV1.js +71 -0
- package/dist/{config.d.ts → models/config.d.ts} +1 -0
- package/dist/{config.js → models/config.js} +1 -0
- package/dist/models/factory.d.ts +3 -0
- package/dist/models/factory.js +14 -0
- package/dist/models/model.d.ts +26 -0
- package/dist/models/model.js +70 -0
- package/dist/{mulmat_packed_gpu-VSekgsNv.js → mulmat_packed_gpu-q_Gmwyld.js} +1 -1
- package/dist/{ones-Dj0SDhHf.js → ones-BAqVh-eA.js} +2 -2
- package/dist/ops/adamAdjust.js +1 -1
- package/dist/ops/adamMoments.js +1 -1
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/cpu/adamAdjust.js +9 -9
- package/dist/ops/cpu/adamMoments.js +2 -2
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +5 -5
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +5 -5
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +2 -2
- package/dist/ops/cpu/matMulMul.js +1 -1
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +7 -7
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +2 -2
- package/dist/ops/grads/gelu.js +2 -2
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.js +1 -1
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/rope.js +4 -4
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/adamAdjust.js +2 -2
- package/dist/ops/webgl/adamMoments.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +4 -4
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/log.js +3 -3
- package/dist/ops/webgl/matMulGelu.js +10 -10
- package/dist/ops/webgl/matMulMul.js +1 -1
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/ops/webgpu/adamAdjust.js +3 -3
- package/dist/ops/webgpu/adamMoments.js +3 -3
- package/dist/ops/webgpu/appendCache.js +3 -3
- package/dist/ops/webgpu/attentionMask.js +3 -3
- package/dist/ops/webgpu/gatherSub.js +3 -3
- package/dist/ops/webgpu/gelu.js +3 -3
- package/dist/ops/webgpu/normRMS.js +2 -2
- package/dist/ops/webgpu/normRMSGrad.js +5 -5
- package/dist/ops/webgpu/qkv.js +3 -3
- package/dist/ops/webgpu/rope.js +3 -3
- package/dist/ops/webgpu/scatterSub.js +3 -3
- package/dist/ops/webgpu/utils/reductions.js +4 -4
- package/dist/ops-542ai2vG.js +1525 -0
- package/dist/{random_width-sZORGo5k.js → random_width-DKGeiFuR.js} +1471 -1538
- package/dist/{range-CRuAh-gd.js → range-BcUvLuf5.js} +1 -1
- package/dist/{reciprocal-BvGAyKyu.js → reciprocal-DhDWSKiD.js} +1 -1
- package/dist/{register_all_kernels-BwDSRN-f.js → register_all_kernels-Do9VvZmo.js} +2488 -2534
- package/dist/{max-Ddnnb5xe.js → relu-B1AXs7p5.js} +6 -6
- package/dist/{reshape-CdBq1WJ6.js → reshape-WeJkT3ja.js} +1 -1
- package/dist/{scatter_nd_util-DUstGbU1.js → scatter_nd_util-B7yDhiQr.js} +1 -1
- package/dist/{selu_util-BJEXVvjX.js → selu_util-BgUO9gHY.js} +125 -146
- package/dist/{shared-wS99K7_n.js → shared-CZiWmQCI.js} +1 -1
- package/dist/{shared-B8ztnyEk.js → shared-V6D_md-c.js} +72 -72
- package/dist/{sin-BeA3tsEd.js → sin-CPxad7Am.js} +1 -1
- package/dist/{slice-BiOsknYS.js → slice-B7jXtPnp.js} +1 -1
- package/dist/{softmax-Bv_6lyMX.js → softmax-BfsyI4As.js} +1 -1
- package/dist/{split-B-dikLRw.js → split-BPxr8_8m.js} +1 -1
- package/dist/{stack-B17UN2nn.js → stack-BNwLzE43.js} +1 -1
- package/dist/{sum-66ew2byf.js → sum-ByFINZgi.js} +3 -3
- package/dist/{tensor-JwS7ZYY6.js → tensor-DbqgIV9B.js} +1 -1
- package/dist/tensor1d-CtJq5BOv.js +27 -0
- package/dist/{tensor2d-wxPAnDQy.js → tensor2d-CObBWBkW.js} +1 -1
- package/dist/tensor3d-BOukqWwr.js +30 -0
- package/dist/tensor4d-DLtk7Nxh.js +30 -0
- package/dist/training/Adam.js +2 -2
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +2 -2
- package/dist/training/Evaluator.d.ts +2 -2
- package/dist/training/FullTrainer.d.ts +3 -3
- package/dist/training/FullTrainer.js +61 -69
- package/dist/training/Trainer.d.ts +15 -3
- package/dist/training/Trainer.js +39 -47
- package/dist/training/sparseCrossEntropy.js +12 -13
- package/dist/utilities/arrayClose.d.ts +1 -1
- package/dist/utilities/arrayClose.js +16 -7
- package/dist/utilities/dummy.d.ts +4 -4
- package/dist/utilities/dummy.js +13 -13
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/parameters.d.ts +1 -1
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-BuddVFLa.js → variable-DPFOJyRG.js} +1 -1
- package/dist/{webgpu_program-PFzf1hAQ.js → webgpu_program-Dhk9R5aG.js} +1 -1
- package/dist/{webgpu_util-D____QpY.js → webgpu_util-BqGnZg8t.js} +27 -27
- package/dist/{zeros--BdLQ3oG.js → zeros-Dnwix0p4.js} +1 -1
- package/package.json +2 -3
- package/dist/NanoGPTModel.d.ts +0 -52
- package/dist/NanoGPTModel.js +0 -203
- package/dist/TiedEmbedding-BxOerUmB.js +0 -43
- package/dist/ops-BFGCx8Ri.js +0 -1202
- package/dist/utilities/generate.d.ts +0 -3
- package/dist/utilities/generate.js +0 -22
- package/dist/utilities/save.d.ts +0 -9
- package/dist/utilities/save.js +0 -61
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import { u as O, f as Y } from "./gpgpu_math-
|
|
3
|
-
import { f as v } from "./backend_util-
|
|
1
|
+
import { l as B, j as G, az as K, aa as z, at as W, aA as V, ag as N, au as F, u as S } from "./index-DdmHGZjq.js";
|
|
2
|
+
import { u as O, f as Y } from "./gpgpu_math-D_ODOLix.js";
|
|
3
|
+
import { f as v } from "./backend_util-yC3YH1jo.js";
|
|
4
4
|
/**
|
|
5
5
|
* @license
|
|
6
6
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -295,7 +295,7 @@ function L(t) {
|
|
|
295
295
|
return o.complexTensorInfos = { real: i, imag: a }, n;
|
|
296
296
|
}
|
|
297
297
|
const me = {
|
|
298
|
-
kernelName:
|
|
298
|
+
kernelName: z,
|
|
299
299
|
backendName: "webgl",
|
|
300
300
|
kernelFunc: L
|
|
301
301
|
};
|
|
@@ -315,16 +315,16 @@ const me = {
|
|
|
315
315
|
* limitations under the License.
|
|
316
316
|
* =============================================================================
|
|
317
317
|
*/
|
|
318
|
-
const w = "return (a < 0.) ? b * a : a;",
|
|
318
|
+
const w = "return (a < 0.) ? b * a : a;", R = `
|
|
319
319
|
vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
|
|
320
320
|
return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
|
|
321
321
|
`;
|
|
322
322
|
function oe(t) {
|
|
323
|
-
const { inputs: e, backend: s, attrs: r } = t, { x: u } = e, { alpha: n } = r, o = s.makeTensorInfo([], "float32", V(n, "float32")), i = N().getBool("WEBGL_PACK_BINARY_OPERATIONS") ? new E(
|
|
323
|
+
const { inputs: e, backend: s, attrs: r } = t, { x: u } = e, { alpha: n } = r, o = s.makeTensorInfo([], "float32", V(n, "float32")), i = N().getBool("WEBGL_PACK_BINARY_OPERATIONS") ? new E(R, u.shape, o.shape) : new b(w, u.shape, o.shape), a = s.runWebGLProgram(i, [u, o], "float32");
|
|
324
324
|
return s.disposeIntermediateTensorInfo(o), a;
|
|
325
325
|
}
|
|
326
326
|
const be = {
|
|
327
|
-
kernelName:
|
|
327
|
+
kernelName: W,
|
|
328
328
|
backendName: "webgl",
|
|
329
329
|
kernelFunc: oe
|
|
330
330
|
};
|
|
@@ -344,12 +344,12 @@ const be = {
|
|
|
344
344
|
* limitations under the License.
|
|
345
345
|
* =============================================================================
|
|
346
346
|
*/
|
|
347
|
-
const
|
|
347
|
+
const k = "return (a < 0.) ? b * a : a;", U = `
|
|
348
348
|
vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
|
|
349
349
|
return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
|
|
350
350
|
`;
|
|
351
351
|
function ue(t) {
|
|
352
|
-
const { inputs: e, backend: s } = t, { x: r, alpha: u } = e, n = N().getBool("WEBGL_PACK_BINARY_OPERATIONS") ? new E(U, r.shape, u.shape) : new b(
|
|
352
|
+
const { inputs: e, backend: s } = t, { x: r, alpha: u } = e, n = N().getBool("WEBGL_PACK_BINARY_OPERATIONS") ? new E(U, r.shape, u.shape) : new b(k, r.shape, u.shape);
|
|
353
353
|
return s.runWebGLProgram(n, [r, u], "float32");
|
|
354
354
|
}
|
|
355
355
|
const Ne = {
|
|
@@ -386,7 +386,7 @@ function ye({ opSnippet: t, packedOpSnippet: e, cpuKernelImpl: s, dtype: r }) {
|
|
|
386
386
|
return c ? l = new ne(o.shape, e) : l = new q(o.shape, t), i.runWebGLProgram(l, [o], a);
|
|
387
387
|
};
|
|
388
388
|
}
|
|
389
|
-
function
|
|
389
|
+
function Ae({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: s = !1, supportsComplex: r = !1, cpuKernelImpl: u, dtype: n }) {
|
|
390
390
|
return ({ inputs: o, backend: i }) => {
|
|
391
391
|
const { a, b: c } = o, l = i;
|
|
392
392
|
if (r && a.dtype === "complex64") {
|
|
@@ -404,8 +404,8 @@ function Ie({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: s = !1, suppor
|
|
|
404
404
|
shape: c.shape
|
|
405
405
|
}, D = new b(t, a.shape, c.shape);
|
|
406
406
|
return l.runWebGLProgram(D, [$, _], S(p.dtype, x.dtype));
|
|
407
|
-
}),
|
|
408
|
-
return l.disposeIntermediateTensorInfo(g), l.disposeIntermediateTensorInfo(m),
|
|
407
|
+
}), I = L({ inputs: { real: g, imag: m }, backend: l });
|
|
408
|
+
return l.disposeIntermediateTensorInfo(g), l.disposeIntermediateTensorInfo(m), I;
|
|
409
409
|
}
|
|
410
410
|
const d = n || S(a.dtype, c.dtype);
|
|
411
411
|
if ((a.dtype === "string" || c.dtype === "string" || l.shouldExecuteOnCPU([a, c])) && u != null) {
|
|
@@ -415,15 +415,15 @@ function Ie({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: s = !1, suppor
|
|
|
415
415
|
) : h, m = a.dtype === "string" ? (
|
|
416
416
|
// tslint:disable-next-line: no-any
|
|
417
417
|
v(f)
|
|
418
|
-
) : f, [
|
|
419
|
-
return x.values =
|
|
418
|
+
) : f, [I, C] = u(a.shape, c.shape, g, m, d), p = l.makeTensorInfo(C, d), x = l.texData.get(p.dataId);
|
|
419
|
+
return x.values = I, p;
|
|
420
420
|
}
|
|
421
421
|
const y = N().getBool("WEBGL_PACK_BINARY_OPERATIONS") && e != null;
|
|
422
|
-
let
|
|
423
|
-
return y ?
|
|
422
|
+
let A;
|
|
423
|
+
return y ? A = new E(e, a.shape, c.shape, s) : A = new b(t, a.shape, c.shape), l.runWebGLProgram(A, [a, c], d);
|
|
424
424
|
};
|
|
425
425
|
}
|
|
426
|
-
function
|
|
426
|
+
function Ie(t, e = !1) {
|
|
427
427
|
if (t === "linear")
|
|
428
428
|
return e ? ee : j;
|
|
429
429
|
if (t === "relu")
|
|
@@ -433,9 +433,9 @@ function Ae(t, e = !1) {
|
|
|
433
433
|
if (t === "relu6")
|
|
434
434
|
return e ? ae : Q;
|
|
435
435
|
if (t === "prelu")
|
|
436
|
-
return e ? U :
|
|
436
|
+
return e ? U : k;
|
|
437
437
|
if (t === "leakyrelu")
|
|
438
|
-
return e ?
|
|
438
|
+
return e ? R : w;
|
|
439
439
|
if (t === "sigmoid")
|
|
440
440
|
return e ? re : X;
|
|
441
441
|
throw new Error(`Activation ${t} has not been implemented for the WebGL backend.`);
|
|
@@ -446,7 +446,7 @@ export {
|
|
|
446
446
|
T as C,
|
|
447
447
|
ne as U,
|
|
448
448
|
Z as a,
|
|
449
|
-
|
|
449
|
+
Ae as b,
|
|
450
450
|
pe as c,
|
|
451
451
|
he as d,
|
|
452
452
|
q as e,
|
|
@@ -457,7 +457,7 @@ export {
|
|
|
457
457
|
fe as j,
|
|
458
458
|
xe as k,
|
|
459
459
|
Oe as l,
|
|
460
|
-
|
|
460
|
+
Ie as m,
|
|
461
461
|
me as n,
|
|
462
462
|
ge as o,
|
|
463
463
|
be as p,
|
|
@@ -1,27 +1,22 @@
|
|
|
1
|
-
import { GPTConfig } from '../config';
|
|
1
|
+
import { GPTConfig } from '../models/config';
|
|
2
2
|
import { default as MemoryProfiler } from '../utilities/profile';
|
|
3
3
|
import { default as RoPECache } from './RoPECache';
|
|
4
4
|
import { Tensor, Variable } from '@tensorflow/tfjs-core';
|
|
5
|
-
export interface LayerConfig {
|
|
6
|
-
checkpointing?: boolean;
|
|
7
|
-
profiler?: MemoryProfiler;
|
|
8
|
-
ropeCache?: RoPECache;
|
|
9
|
-
}
|
|
10
|
-
export interface GPTLayerConfig {
|
|
11
|
-
gpt: GPTConfig;
|
|
12
|
-
layerConfig: LayerConfig;
|
|
13
|
-
}
|
|
14
5
|
export interface ForwardAttributes {
|
|
15
6
|
training: boolean;
|
|
7
|
+
checkpointing?: boolean;
|
|
8
|
+
ropeCache?: RoPECache;
|
|
16
9
|
}
|
|
17
10
|
export default abstract class BaseLayer<ATTR extends ForwardAttributes = ForwardAttributes> {
|
|
18
11
|
readonly parent?: BaseLayer;
|
|
19
|
-
readonly config:
|
|
12
|
+
readonly config: GPTConfig;
|
|
20
13
|
private _variables;
|
|
21
14
|
private _trainable;
|
|
22
15
|
readonly children: BaseLayer[];
|
|
23
|
-
|
|
16
|
+
private profiler?;
|
|
17
|
+
constructor(config: GPTConfig, parent?: BaseLayer);
|
|
24
18
|
getProfiler(): MemoryProfiler | undefined;
|
|
19
|
+
setProfiler(profiler: MemoryProfiler | null): void;
|
|
25
20
|
startMemory(): void;
|
|
26
21
|
endMemory(label: string): void;
|
|
27
22
|
addVariable(name: string, variable?: Variable): void;
|
|
@@ -29,7 +24,7 @@ export default abstract class BaseLayer<ATTR extends ForwardAttributes = Forward
|
|
|
29
24
|
get trainableVariables(): Variable[];
|
|
30
25
|
get trainable(): boolean;
|
|
31
26
|
set trainable(value: boolean);
|
|
32
|
-
getVariable(name: string): Variable;
|
|
27
|
+
getVariable(name: string, recursive?: boolean): Variable;
|
|
33
28
|
hasVariable(name: string): boolean;
|
|
34
29
|
setVariable(name: string, variable: Variable): void;
|
|
35
30
|
saveWeights(map: Map<string, Tensor[]>): void;
|
package/dist/layers/BaseLayer.js
CHANGED
|
@@ -1,22 +1,28 @@
|
|
|
1
|
-
import { T as
|
|
2
|
-
import { v as _ } from "../variable-
|
|
3
|
-
class
|
|
1
|
+
import { T as p, J as g, e as o, K as v } from "../index-DdmHGZjq.js";
|
|
2
|
+
import { v as _ } from "../variable-DPFOJyRG.js";
|
|
3
|
+
class T {
|
|
4
4
|
parent;
|
|
5
5
|
config;
|
|
6
6
|
_variables = /* @__PURE__ */ new Map();
|
|
7
7
|
_trainable = !0;
|
|
8
8
|
children = [];
|
|
9
|
+
profiler;
|
|
9
10
|
constructor(t, r) {
|
|
10
11
|
this.config = t, this.parent = r, this.parent && this.parent.children.push(this);
|
|
11
12
|
}
|
|
12
13
|
getProfiler() {
|
|
13
|
-
return this.
|
|
14
|
+
return this.profiler;
|
|
15
|
+
}
|
|
16
|
+
setProfiler(t) {
|
|
17
|
+
this.profiler = t || void 0, this.children.forEach((r) => {
|
|
18
|
+
r.setProfiler(t);
|
|
19
|
+
});
|
|
14
20
|
}
|
|
15
21
|
startMemory() {
|
|
16
|
-
this.
|
|
22
|
+
this.profiler?.startMemory();
|
|
17
23
|
}
|
|
18
24
|
endMemory(t) {
|
|
19
|
-
this.
|
|
25
|
+
this.profiler?.endMemory(t);
|
|
20
26
|
}
|
|
21
27
|
addVariable(t, r) {
|
|
22
28
|
this._variables.set(t, r || null);
|
|
@@ -41,11 +47,17 @@ class M {
|
|
|
41
47
|
r.trainable = t;
|
|
42
48
|
});
|
|
43
49
|
}
|
|
44
|
-
getVariable(t) {
|
|
45
|
-
const
|
|
46
|
-
if (!r)
|
|
50
|
+
getVariable(t, r = !1) {
|
|
51
|
+
const e = this._variables.get(t);
|
|
52
|
+
if (!e && r)
|
|
53
|
+
for (const i of this.children) {
|
|
54
|
+
const s = i.getVariable(t, !0);
|
|
55
|
+
if (s)
|
|
56
|
+
return s;
|
|
57
|
+
}
|
|
58
|
+
if (!e)
|
|
47
59
|
throw new Error(`Variable ${t} not found`);
|
|
48
|
-
return
|
|
60
|
+
return e;
|
|
49
61
|
}
|
|
50
62
|
hasVariable(t) {
|
|
51
63
|
return this._variables.get(t) !== null;
|
|
@@ -85,7 +97,7 @@ class M {
|
|
|
85
97
|
call(t, ...r) {
|
|
86
98
|
this.build();
|
|
87
99
|
const e = this.forward(t, ...r);
|
|
88
|
-
if (t.training && e instanceof
|
|
100
|
+
if (t.training && e instanceof p) {
|
|
89
101
|
const i = this.dropout(e);
|
|
90
102
|
return i !== e && e.dispose(), i;
|
|
91
103
|
} else
|
|
@@ -95,7 +107,7 @@ class M {
|
|
|
95
107
|
return this.build(), this.checkpointingFn(t, ...r);
|
|
96
108
|
}
|
|
97
109
|
checkpointingFn(t, ...r) {
|
|
98
|
-
const e = this.trainableVariables, s =
|
|
110
|
+
const e = this.trainableVariables, s = g((...a) => {
|
|
99
111
|
const l = a[a.length - 1], n = a.slice(0, r.length), h = this.forward(t, ...n);
|
|
100
112
|
return l(n), { value: h, gradFunc: (c, f) => {
|
|
101
113
|
const u = o().state.activeTape;
|
|
@@ -112,5 +124,5 @@ class M {
|
|
|
112
124
|
}
|
|
113
125
|
}
|
|
114
126
|
export {
|
|
115
|
-
|
|
127
|
+
T as default
|
|
116
128
|
};
|
|
@@ -1,5 +1,6 @@
|
|
|
1
|
-
import { default as BaseLayer, ForwardAttributes
|
|
1
|
+
import { default as BaseLayer, ForwardAttributes } from './BaseLayer';
|
|
2
2
|
import { Tensor } from '@tensorflow/tfjs-core';
|
|
3
|
+
import { GPTConfig } from '../models/config';
|
|
3
4
|
export type KVCache = {
|
|
4
5
|
k?: Tensor;
|
|
5
6
|
v?: Tensor;
|
|
@@ -22,7 +23,7 @@ export default class CausalSelfAttention extends BaseLayer<AttentionForwardAttri
|
|
|
22
23
|
private projUnits;
|
|
23
24
|
private ATTN;
|
|
24
25
|
private PROJ;
|
|
25
|
-
constructor(index: number, config:
|
|
26
|
+
constructor(index: number, config: GPTConfig, parent?: BaseLayer);
|
|
26
27
|
protected build(): void;
|
|
27
28
|
private getAttentionScores;
|
|
28
29
|
private getAttentionScoresWithPast;
|
|
@@ -3,14 +3,14 @@ import O from "./BaseLayer.js";
|
|
|
3
3
|
import { qkv as P } from "../ops/qkv.js";
|
|
4
4
|
import { rope as v } from "../ops/rope.js";
|
|
5
5
|
import { appendCache as V } from "../ops/appendCache.js";
|
|
6
|
-
import {
|
|
6
|
+
import { k as c, t as C } from "../index-DdmHGZjq.js";
|
|
7
7
|
import { fusedSoftmax as T } from "../ops/fusedSoftmax.js";
|
|
8
|
-
import { d as
|
|
9
|
-
import { v as b } from "../variable-
|
|
10
|
-
import { r as k, d as
|
|
11
|
-
import { r as N } from "../reshape-
|
|
12
|
-
import { m as R } from "../mat_mul-
|
|
13
|
-
class
|
|
8
|
+
import { d as L } from "../random_width-DKGeiFuR.js";
|
|
9
|
+
import { v as b } from "../variable-DPFOJyRG.js";
|
|
10
|
+
import { r as k, d as y } from "../dropout-CcKSfOYE.js";
|
|
11
|
+
import { r as N } from "../reshape-WeJkT3ja.js";
|
|
12
|
+
import { m as R } from "../mat_mul-Dpy2mMRu.js";
|
|
13
|
+
class $ extends O {
|
|
14
14
|
divisor;
|
|
15
15
|
index;
|
|
16
16
|
units;
|
|
@@ -18,27 +18,27 @@ class W extends O {
|
|
|
18
18
|
ATTN;
|
|
19
19
|
PROJ;
|
|
20
20
|
constructor(t, i, s) {
|
|
21
|
-
super(i, s), this.index = t, this.units = i.
|
|
21
|
+
super(i, s), this.index = t, this.units = i.nEmbed * 3, this.projUnits = i.nEmbed, this.ATTN = `block_${this.index}_cAttn`, this.PROJ = `block_${this.index}_cProj`, this.addVariable(this.ATTN), this.addVariable(this.PROJ), this.divisor = 1 / Math.sqrt(i.nEmbed / i.nHead);
|
|
22
22
|
}
|
|
23
23
|
build() {
|
|
24
24
|
this.hasVariable(this.ATTN) === !1 && this.setVariable(
|
|
25
25
|
this.ATTN,
|
|
26
26
|
b(
|
|
27
|
-
k([this.config.
|
|
27
|
+
k([this.config.nEmbed, this.units], 0, 0.02),
|
|
28
28
|
!0
|
|
29
29
|
//`block_${this.index}_attn_cAttn_kernel`
|
|
30
30
|
)
|
|
31
31
|
), this.hasVariable(this.PROJ) === !1 && this.setVariable(
|
|
32
32
|
this.PROJ,
|
|
33
33
|
b(
|
|
34
|
-
k([this.projUnits, this.config.
|
|
34
|
+
k([this.projUnits, this.config.nEmbed], 0, 0.02),
|
|
35
35
|
!0
|
|
36
36
|
//`block_${this.index}_attn_cProj_kernel`
|
|
37
37
|
)
|
|
38
38
|
);
|
|
39
39
|
}
|
|
40
40
|
getAttentionScores(t, i, s, o) {
|
|
41
|
-
const e = g(t, i, this.divisor), n = T(e, s ? this.config.
|
|
41
|
+
const e = g(t, i, this.divisor), n = T(e, s ? this.config.dropout : 0, o);
|
|
42
42
|
return e.dispose(), n;
|
|
43
43
|
}
|
|
44
44
|
// Attention with optional past. If pastLen > 0 and T_cur == 1, no mask needed.
|
|
@@ -47,50 +47,50 @@ class W extends O {
|
|
|
47
47
|
return o.dispose(), e;
|
|
48
48
|
}
|
|
49
49
|
getQKV(t) {
|
|
50
|
-
return P(t, this.getVariable(this.ATTN), this.config.
|
|
50
|
+
return P(t, this.getVariable(this.ATTN), this.config.nHead);
|
|
51
51
|
}
|
|
52
52
|
getOutputProjection(t) {
|
|
53
|
-
const i = t.shape[0], s = t.shape[2], o = this.config.
|
|
54
|
-
return n.dispose(), e.dispose(),
|
|
53
|
+
const i = t.shape[0], s = t.shape[2], o = this.config.nEmbed, e = t.transpose([0, 2, 1, 3]), n = N(e, [i, s, o]), r = L(n, this.getVariable(this.PROJ));
|
|
54
|
+
return n.dispose(), e.dispose(), r;
|
|
55
55
|
}
|
|
56
56
|
updateCache(t, i, s) {
|
|
57
|
-
const o = this.config.
|
|
57
|
+
const o = this.config.blockSize, e = t.shape[2], n = s.length || 0, r = V(t, o, n, s.k);
|
|
58
58
|
t.dispose(), s.k && s.k.dispose();
|
|
59
|
-
const
|
|
59
|
+
const p = V(i, o, n, s.v);
|
|
60
60
|
i.dispose(), s.v && s.v.dispose();
|
|
61
61
|
const d = Math.min(n + e, o), h = s.cumulativeLength + e;
|
|
62
|
-
s.length = d, s.cumulativeLength = h, s.k = c(
|
|
62
|
+
s.length = d, s.cumulativeLength = h, s.k = c(r), s.v = c(p);
|
|
63
63
|
}
|
|
64
64
|
forward(t, i) {
|
|
65
65
|
return C(() => {
|
|
66
66
|
this.startMemory();
|
|
67
|
-
const [s, o, e] = this.getQKV(i), n = t.pastKV ? t.pastKV.cumulativeLength : 0,
|
|
68
|
-
|
|
67
|
+
const [s, o, e] = this.getQKV(i), n = t.pastKV ? t.pastKV.cumulativeLength : 0, r = t.ropeCache, p = r ? v(s, r, n) : s, d = r ? v(o, r, n) : o;
|
|
68
|
+
r && (s.dispose(), o.dispose());
|
|
69
69
|
const h = t.pastKV ? t.pastKV.length : 0;
|
|
70
70
|
t.pastKV && !t.training && this.updateCache(d, e, t.pastKV);
|
|
71
71
|
const u = t.pastKV?.k ? t.pastKV.k : d, m = t.pastKV?.v ? t.pastKV.v : e;
|
|
72
|
-
let
|
|
73
|
-
h > 0 ?
|
|
74
|
-
const l = R(
|
|
75
|
-
f ||
|
|
72
|
+
let a;
|
|
73
|
+
h > 0 ? a = this.getAttentionScoresWithPast(p, u, h) : a = this.getAttentionScores(p, u, t.training, t.seed || 0), p.dispose(), t.pastKV || u.dispose();
|
|
74
|
+
const l = R(a, m), f = t.attentionScores !== void 0 && t.attentionScores.attentionOut !== void 0;
|
|
75
|
+
f || a.dispose(), t.pastKV || m.dispose();
|
|
76
76
|
const A = this.getOutputProjection(l);
|
|
77
77
|
if (l.dispose(), f && t.attentionScores && t.attentionScores.attentionOut !== void 0) {
|
|
78
|
-
const K =
|
|
78
|
+
const K = a.shape[1], S = a.shape[2];
|
|
79
79
|
t.attentionScores.attentionOut?.push(
|
|
80
|
-
c(
|
|
80
|
+
c(a.slice([0, 0, 0, 0], [1, -1, -1, -1]).reshape([K, S, -1]))
|
|
81
81
|
);
|
|
82
82
|
}
|
|
83
83
|
return this.endMemory("CausalSelfAttention"), A;
|
|
84
84
|
});
|
|
85
85
|
}
|
|
86
86
|
dropout(t) {
|
|
87
|
-
if (this.config.
|
|
88
|
-
const i =
|
|
87
|
+
if (this.config.dropout > 0) {
|
|
88
|
+
const i = y(t, this.config.dropout);
|
|
89
89
|
return t.dispose(), i;
|
|
90
90
|
} else
|
|
91
91
|
return t;
|
|
92
92
|
}
|
|
93
93
|
}
|
|
94
94
|
export {
|
|
95
|
-
|
|
95
|
+
$ as default
|
|
96
96
|
};
|
package/dist/layers/MLP.d.ts
CHANGED
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
import { Tensor } from '@tensorflow/tfjs-core';
|
|
2
|
-
import { default as BaseLayer, ForwardAttributes
|
|
2
|
+
import { default as BaseLayer, ForwardAttributes } from './BaseLayer';
|
|
3
|
+
import { GPTConfig } from '../main';
|
|
3
4
|
export default class MLP extends BaseLayer {
|
|
4
5
|
private index;
|
|
5
6
|
private hiddenUnits;
|
|
6
7
|
private MLPHIDDEN;
|
|
7
8
|
private MLPOUT;
|
|
8
|
-
constructor(index: number, config:
|
|
9
|
+
constructor(index: number, config: GPTConfig, parent?: BaseLayer);
|
|
9
10
|
protected build(): void;
|
|
10
11
|
forward(_: ForwardAttributes, x: Tensor): Tensor;
|
|
11
12
|
protected dropout(x: Tensor): Tensor;
|
package/dist/layers/MLP.js
CHANGED
|
@@ -1,56 +1,52 @@
|
|
|
1
|
-
import { t as
|
|
1
|
+
import { t as p } from "../index-DdmHGZjq.js";
|
|
2
2
|
import u from "./BaseLayer.js";
|
|
3
3
|
import { matMulGelu as M } from "../ops/matMulGelu.js";
|
|
4
|
-
import { v as o } from "../variable-
|
|
5
|
-
import { r as h, d as f } from "../dropout-
|
|
6
|
-
import { r as d } from "../reshape-
|
|
7
|
-
import { m as c } from "../mat_mul-
|
|
8
|
-
class
|
|
4
|
+
import { v as o } from "../variable-DPFOJyRG.js";
|
|
5
|
+
import { r as h, d as f } from "../dropout-CcKSfOYE.js";
|
|
6
|
+
import { r as d } from "../reshape-WeJkT3ja.js";
|
|
7
|
+
import { m as c } from "../mat_mul-Dpy2mMRu.js";
|
|
8
|
+
class H extends u {
|
|
9
9
|
index;
|
|
10
10
|
hiddenUnits;
|
|
11
11
|
MLPHIDDEN;
|
|
12
12
|
MLPOUT;
|
|
13
13
|
constructor(i, t, s) {
|
|
14
|
-
super(t, s), this.index = i, this.hiddenUnits = t.
|
|
14
|
+
super(t, s), this.index = i, this.hiddenUnits = t.mlpFactor * t.nEmbed, this.MLPHIDDEN = `block_${this.index}_mlpHidden`, this.MLPOUT = `block_${this.index}_mlpOut`, this.addVariable(this.MLPHIDDEN), this.addVariable(this.MLPOUT);
|
|
15
15
|
}
|
|
16
16
|
build() {
|
|
17
17
|
this.hasVariable(this.MLPHIDDEN) === !1 && this.setVariable(
|
|
18
18
|
this.MLPHIDDEN,
|
|
19
19
|
o(
|
|
20
|
-
h([this.config.
|
|
20
|
+
h([this.config.nEmbed, this.hiddenUnits], 0, 0.02),
|
|
21
21
|
!0
|
|
22
22
|
//`block_${this.index}_attn_cAttn_kernel`
|
|
23
23
|
)
|
|
24
24
|
), this.hasVariable(this.MLPOUT) === !1 && this.setVariable(
|
|
25
25
|
this.MLPOUT,
|
|
26
26
|
o(
|
|
27
|
-
h(
|
|
28
|
-
[this.hiddenUnits, this.config.gpt.nEmbed],
|
|
29
|
-
0,
|
|
30
|
-
0.02 / Math.sqrt(2 * this.config.gpt.nLayer)
|
|
31
|
-
),
|
|
27
|
+
h([this.hiddenUnits, this.config.nEmbed], 0, 0.02 / Math.sqrt(2 * this.config.nLayer)),
|
|
32
28
|
!0
|
|
33
29
|
//`block_${this.index}_attn_cProj_kernel`
|
|
34
30
|
)
|
|
35
31
|
);
|
|
36
32
|
}
|
|
37
33
|
forward(i, t) {
|
|
38
|
-
return
|
|
34
|
+
return p(() => {
|
|
39
35
|
this.startMemory();
|
|
40
|
-
const [s, r, e] = t.shape, n = d(t, [s * r, e]), a = M(n, this.getVariable(this.MLPHIDDEN)),
|
|
36
|
+
const [s, r, e] = t.shape, n = d(t, [s * r, e]), a = M(n, this.getVariable(this.MLPHIDDEN)), m = c(a, this.getVariable(this.MLPOUT));
|
|
41
37
|
a.dispose();
|
|
42
|
-
const
|
|
43
|
-
return this.endMemory("MLP"),
|
|
38
|
+
const l = d(m, [s, r, e]);
|
|
39
|
+
return this.endMemory("MLP"), l;
|
|
44
40
|
});
|
|
45
41
|
}
|
|
46
42
|
dropout(i) {
|
|
47
|
-
if (this.config.
|
|
48
|
-
const t = f(i, this.config.
|
|
43
|
+
if (this.config.dropout > 0) {
|
|
44
|
+
const t = f(i, this.config.dropout);
|
|
49
45
|
return i.dispose(), t;
|
|
50
46
|
}
|
|
51
47
|
return i;
|
|
52
48
|
}
|
|
53
49
|
}
|
|
54
50
|
export {
|
|
55
|
-
|
|
51
|
+
H as default
|
|
56
52
|
};
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
import { Tensor } from '@tensorflow/tfjs-core';
|
|
2
|
+
import { default as BaseLayer } from './BaseLayer';
|
|
3
|
+
import { GPTConfig, ModelForwardAttributes } from '../main';
|
|
4
|
+
export default class PositionEmbedding extends BaseLayer {
|
|
5
|
+
private wpe?;
|
|
6
|
+
private drop;
|
|
7
|
+
constructor(config: GPTConfig, name?: string, parent?: BaseLayer);
|
|
8
|
+
forward(attrs: ModelForwardAttributes, x: Tensor): Tensor;
|
|
9
|
+
}
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import { t as c, a9 as u, b as i } from "../index-DdmHGZjq.js";
|
|
2
|
+
import f from "./BaseLayer.js";
|
|
3
|
+
import { E as g, D as h } from "../random_width-DKGeiFuR.js";
|
|
4
|
+
import { r as b } from "../exports_initializers-DKk7-bsx.js";
|
|
5
|
+
import { m as l } from "../mod-CbibJi3D.js";
|
|
6
|
+
import { r as w } from "../range-BcUvLuf5.js";
|
|
7
|
+
/**
|
|
8
|
+
* @license
|
|
9
|
+
* Copyright 2018 Google LLC
|
|
10
|
+
*
|
|
11
|
+
* Use of this source code is governed by an MIT-style
|
|
12
|
+
* license that can be found in the LICENSE file or at
|
|
13
|
+
* https://opensource.org/licenses/MIT.
|
|
14
|
+
* =============================================================================
|
|
15
|
+
*/
|
|
16
|
+
function E(t) {
|
|
17
|
+
return new h(t);
|
|
18
|
+
}
|
|
19
|
+
function x(t) {
|
|
20
|
+
return new g(t);
|
|
21
|
+
}
|
|
22
|
+
class q extends f {
|
|
23
|
+
wpe;
|
|
24
|
+
// Position embeddings
|
|
25
|
+
drop;
|
|
26
|
+
// Dropout
|
|
27
|
+
constructor(o, n = "", r) {
|
|
28
|
+
super(o, r), this.wpe = x({
|
|
29
|
+
inputDim: this.config.blockSize,
|
|
30
|
+
outputDim: this.config.nEmbed,
|
|
31
|
+
name: n,
|
|
32
|
+
embeddingsInitializer: b({ mean: 0, stddev: 0.02 })
|
|
33
|
+
}), this.drop = E({ rate: this.config.dropout });
|
|
34
|
+
}
|
|
35
|
+
forward(o, n) {
|
|
36
|
+
const r = o.cache?.[0]?.length ?? 0;
|
|
37
|
+
return c(() => {
|
|
38
|
+
const [, s] = n.shape, e = this.config.blockSize, a = w(0, s, 1, "int32"), m = l(u(a, i(r, "int32")), i(e, "int32")), d = this.wpe.apply(m), p = n.add(d);
|
|
39
|
+
return this.drop.apply(p, { training: o.training });
|
|
40
|
+
});
|
|
41
|
+
}
|
|
42
|
+
}
|
|
43
|
+
export {
|
|
44
|
+
q as default
|
|
45
|
+
};
|
package/dist/layers/RMSNorm.d.ts
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
import { Tensor } from '@tensorflow/tfjs-core';
|
|
2
|
-
import { default as BaseLayer, ForwardAttributes
|
|
2
|
+
import { default as BaseLayer, ForwardAttributes } from './BaseLayer';
|
|
3
|
+
import { GPTConfig } from '../main';
|
|
3
4
|
export default class RMSNorm extends BaseLayer {
|
|
4
5
|
private GAMMA;
|
|
5
|
-
constructor(config:
|
|
6
|
+
constructor(config: GPTConfig, name?: string, parent?: BaseLayer);
|
|
6
7
|
forward(_: ForwardAttributes, x: Tensor): Tensor;
|
|
7
8
|
}
|
package/dist/layers/RMSNorm.js
CHANGED
|
@@ -1,12 +1,12 @@
|
|
|
1
|
-
import { t as s } from "../index-
|
|
1
|
+
import { t as s } from "../index-DdmHGZjq.js";
|
|
2
2
|
import e from "./BaseLayer.js";
|
|
3
3
|
import { normRMS as a } from "../ops/normRMS.js";
|
|
4
|
-
import { v as i } from "../variable-
|
|
5
|
-
import { o as m } from "../ones-
|
|
6
|
-
class
|
|
4
|
+
import { v as i } from "../variable-DPFOJyRG.js";
|
|
5
|
+
import { o as m } from "../ones-BAqVh-eA.js";
|
|
6
|
+
class l extends e {
|
|
7
7
|
GAMMA;
|
|
8
8
|
constructor(r, t = "", o) {
|
|
9
|
-
super(r, o), this.GAMMA = t, this.addVariable(this.GAMMA, i(m([r.
|
|
9
|
+
super(r, o), this.GAMMA = t, this.addVariable(this.GAMMA, i(m([r.nEmbed]), !0, this.GAMMA, "float32"));
|
|
10
10
|
}
|
|
11
11
|
forward(r, t) {
|
|
12
12
|
return s(() => {
|
|
@@ -17,5 +17,5 @@ class f extends e {
|
|
|
17
17
|
}
|
|
18
18
|
}
|
|
19
19
|
export {
|
|
20
|
-
|
|
20
|
+
l as default
|
|
21
21
|
};
|
package/dist/layers/RoPECache.js
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
import { b as t, x as h, t as n,
|
|
2
|
-
import { r as c } from "../reciprocal-
|
|
3
|
-
import { c as f, s as m } from "../sin-
|
|
4
|
-
import { r as a } from "../range-
|
|
1
|
+
import { b as t, x as h, t as n, k as p } from "../index-DdmHGZjq.js";
|
|
2
|
+
import { r as c } from "../reciprocal-DhDWSKiD.js";
|
|
3
|
+
import { c as f, s as m } from "../sin-CPxad7Am.js";
|
|
4
|
+
import { r as a } from "../range-BcUvLuf5.js";
|
|
5
5
|
class D {
|
|
6
6
|
constructor(o) {
|
|
7
7
|
this.config = o;
|
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
import { Tensor } from '@tensorflow/tfjs-core';
|
|
2
|
-
import { default as BaseLayer, ForwardAttributes
|
|
2
|
+
import { default as BaseLayer, ForwardAttributes } from './BaseLayer';
|
|
3
|
+
import { GPTConfig } from '../models/config';
|
|
3
4
|
export default class TiedEmbeddingOutputLayer extends BaseLayer {
|
|
4
5
|
private vocabSize;
|
|
5
6
|
private embedDim;
|
|
6
7
|
private initializer;
|
|
7
8
|
private WEIGHTS;
|
|
8
|
-
constructor(config:
|
|
9
|
+
constructor(config: GPTConfig, name: string, parent?: BaseLayer);
|
|
9
10
|
embed(inputs: Tensor): Tensor;
|
|
10
11
|
project(inputs: Tensor): Tensor;
|
|
11
12
|
forward(_: ForwardAttributes, x: Tensor): Tensor;
|
|
@@ -1,9 +1,31 @@
|
|
|
1
|
-
import "../random_width-
|
|
2
|
-
import "../index-
|
|
3
|
-
import {
|
|
4
|
-
import "./BaseLayer.js";
|
|
5
|
-
import "../variable-
|
|
6
|
-
import "../gather-
|
|
1
|
+
import { d as r } from "../random_width-DKGeiFuR.js";
|
|
2
|
+
import "../index-DdmHGZjq.js";
|
|
3
|
+
import { r as a } from "../exports_initializers-DKk7-bsx.js";
|
|
4
|
+
import s from "./BaseLayer.js";
|
|
5
|
+
import { v as m } from "../variable-DPFOJyRG.js";
|
|
6
|
+
import { g as o } from "../gather-CPg6ZlQA.js";
|
|
7
|
+
class S extends s {
|
|
8
|
+
vocabSize;
|
|
9
|
+
embedDim;
|
|
10
|
+
initializer;
|
|
11
|
+
WEIGHTS;
|
|
12
|
+
constructor(i, e, t) {
|
|
13
|
+
super(i, t), this.WEIGHTS = e, this.vocabSize = i.vocabSize, this.embedDim = i.nEmbed, this.initializer = a({
|
|
14
|
+
mean: 0,
|
|
15
|
+
stddev: 0.02
|
|
16
|
+
}), this.addVariable(this.WEIGHTS, m(this.initializer.apply([this.vocabSize, this.embedDim]), !0));
|
|
17
|
+
}
|
|
18
|
+
embed(i) {
|
|
19
|
+
return o(this.getVariable(this.WEIGHTS), i, 0);
|
|
20
|
+
}
|
|
21
|
+
project(i) {
|
|
22
|
+
return r(i, this.getVariable(this.WEIGHTS).transpose());
|
|
23
|
+
}
|
|
24
|
+
// Dummy, should not be used.
|
|
25
|
+
forward(i, e) {
|
|
26
|
+
return this.project(e);
|
|
27
|
+
}
|
|
28
|
+
}
|
|
7
29
|
export {
|
|
8
|
-
|
|
30
|
+
S as default
|
|
9
31
|
};
|