@genai-fi/nanogpt 0.9.0 → 0.10.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +352 -14
- package/dist/Generator.js +69 -78
- package/dist/{RealDiv-D4EzDsC0.js → RealDiv-DgA3z9oO.js} +32 -206
- package/dist/Reshape-CF6odzV4.js +16 -0
- package/dist/Reshape-_kILl6tK.js +81 -0
- package/dist/TeachableLLM.js +28 -22
- package/dist/Trainer.d.ts +2 -0
- package/dist/Trainer.js +3 -2
- package/dist/{axis_util-TbGYJ208.js → axis_util-BvHEw88j.js} +7 -23
- package/dist/backend.d.ts +2 -1
- package/dist/backend.js +10 -4
- package/dist/backend_util-D-rUb2ty.js +474 -0
- package/dist/backend_webgpu-B0u2ndUn.js +547 -0
- package/dist/binary_op_util-pKXltfxI.js +192 -0
- package/dist/broadcast_to-CwF7XIeu.js +30 -0
- package/dist/checks/appendCache.js +2 -2
- package/dist/checks/attentionMask.js +3 -3
- package/dist/checks/check.d.ts +1 -1
- package/dist/checks/check.js +8 -8
- package/dist/checks/gelu.js +2 -2
- package/dist/checks/index.d.ts +2 -0
- package/dist/checks/index.js +7 -5
- package/dist/checks/matMulGelu.js +6 -6
- package/dist/checks/normRMS.js +7 -7
- package/dist/checks/normRMSGrad.js +3 -3
- package/dist/checks/packUnpack.d.ts +1 -0
- package/dist/checks/packUnpack.js +18 -0
- package/dist/checks/qkv.js +12 -27
- package/dist/checks/rope.js +2 -2
- package/dist/checks/weights.js +18 -16
- package/dist/complex-CSlYz-2T.js +13 -0
- package/dist/complex_util-Yc1A_gV1.js +55 -0
- package/dist/concat-BHlIJeyT.js +19 -0
- package/dist/concat_util-DcJk7YHS.js +22 -0
- package/dist/data/docx.js +1 -1
- package/dist/data/parquet.js +2 -2
- package/dist/data/pdf.js +1 -1
- package/dist/data/textLoader.js +1 -1
- package/dist/{dataset-DlZtKmBq.js → dataset-0xP8GjwI.js} +136 -236
- package/dist/dropout-C1pM3f11.js +99 -0
- package/dist/expand_dims-BPG4fwBP.js +13 -0
- package/dist/exports_initializers-xuidcwI4.js +7 -0
- package/dist/gather-DykLGqmW.js +10 -0
- package/dist/{gelu-Bp_-935b.js → gelu-CNLFZWea.js} +11 -10
- package/dist/{gpgpu_math-CDaYiyE_.js → gpgpu_math-DDVJCn6-.js} +90 -265
- package/dist/{index-C4L8Cm77.js → index-CieiGp4Y.js} +14 -14
- package/dist/index-CjOj7j-u.js +7308 -0
- package/dist/{index-Tf7vU29b.js → index-Cp39cXWe.js} +3 -10
- package/dist/{index-Dwqa6Zy2.js → index-DvYrXKkX.js} +2 -2
- package/dist/index-ZyQhjEPo.js +2157 -0
- package/dist/{jszip.min-CjP2V1VV.js → jszip.min-Bz5-11Bk.js} +56 -57
- package/dist/kernel_funcs_utils-Dg_-E44D.js +308 -0
- package/dist/layers/BaseLayer.d.ts +1 -0
- package/dist/layers/BaseLayer.js +7 -6
- package/dist/layers/CausalSelfAttention.d.ts +0 -1
- package/dist/layers/CausalSelfAttention.js +56 -55
- package/dist/layers/MLP.js +15 -16
- package/dist/layers/PositionEmbedding.js +5 -14
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.d.ts +2 -0
- package/dist/layers/RoPECache.js +22 -17
- package/dist/layers/TiedEmbedding.js +22 -17
- package/dist/layers/TransformerBlock.js +21 -20
- package/dist/loader/load.js +1 -1
- package/dist/loader/loadTransformers.js +1 -1
- package/dist/loader/oldZipLoad.js +39 -33
- package/dist/loader/save.js +1 -1
- package/dist/log_sum_exp-DWI-76TI.js +41 -0
- package/dist/main.d.ts +8 -0
- package/dist/main.js +63 -52
- package/dist/matMul16--R5hOwDG.js +77 -0
- package/dist/mat_mul-DeAh4uTH.js +12 -0
- package/dist/mod-Gt1rMB4n.js +12 -0
- package/dist/models/NanoGPTV1.js +40 -31
- package/dist/models/model.d.ts +2 -0
- package/dist/models/model.js +37 -29
- package/dist/{mulmat_packed_gpu-BT60jmzP.js → mulmat_packed_gpu-BMFhLwta.js} +1 -17
- package/dist/{non_max_suppression_impl-CsEgBuMA.js → non_max_suppression_impl-B2W7YjZB.js} +0 -32
- package/dist/ones-CAMiP4I2.js +15 -0
- package/dist/ops/adamAdjust.js +1 -1
- package/dist/ops/adamMoments.d.ts +1 -1
- package/dist/ops/adamMoments.js +4 -4
- package/dist/ops/add16.d.ts +2 -0
- package/dist/ops/add16.js +9 -0
- package/dist/ops/appendCache.js +16 -9
- package/dist/ops/attentionMask.js +4 -4
- package/dist/ops/concat16.d.ts +2 -0
- package/dist/ops/concat16.js +9 -0
- package/dist/ops/cpu/adamAdjust.js +14 -13
- package/dist/ops/cpu/adamMoments.js +10 -9
- package/dist/ops/cpu/appendCache.js +9 -8
- package/dist/ops/cpu/attentionMask.js +15 -14
- package/dist/ops/cpu/fusedSoftmax.js +13 -12
- package/dist/ops/cpu/gatherSub.js +9 -24
- package/dist/ops/cpu/gelu.js +13 -12
- package/dist/ops/cpu/matMul16.d.ts +1 -0
- package/dist/ops/cpu/matMul16.js +16 -0
- package/dist/ops/cpu/matMulGelu.js +18 -16
- package/dist/ops/cpu/matMulMul.js +8 -7
- package/dist/ops/cpu/mulDropout.js +4 -3
- package/dist/ops/cpu/normRMS.js +11 -10
- package/dist/ops/cpu/qkv.js +17 -13
- package/dist/ops/cpu/rope.js +23 -22
- package/dist/ops/cpu/scatterSub.js +16 -30
- package/dist/ops/dot16.d.ts +2 -0
- package/dist/ops/dot16.js +42 -0
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/add16.d.ts +1 -0
- package/dist/ops/grads/add16.js +27 -0
- package/dist/ops/grads/attentionMask.js +12 -19
- package/dist/ops/grads/gelu.js +4 -3
- package/dist/ops/grads/matMul16.d.ts +2 -0
- package/dist/ops/grads/matMul16.js +9 -0
- package/dist/ops/grads/matMulGelu.js +8 -7
- package/dist/ops/grads/normRMS.js +8 -7
- package/dist/ops/grads/{fusedSoftmax.d.ts → pack16.d.ts} +1 -1
- package/dist/ops/grads/pack16.js +7 -0
- package/dist/ops/grads/qkv.d.ts +3 -1
- package/dist/ops/grads/qkv.js +28 -22
- package/dist/ops/grads/rope.d.ts +2 -1
- package/dist/ops/grads/rope.js +6 -13
- package/dist/ops/grads/softmax16.d.ts +2 -0
- package/dist/ops/grads/softmax16.js +26 -0
- package/dist/ops/grads/unpack16.d.ts +2 -0
- package/dist/ops/grads/unpack16.js +6 -0
- package/dist/ops/grads/utils.d.ts +3 -0
- package/dist/ops/grads/utils.js +10 -0
- package/dist/ops/matMul16.d.ts +15 -0
- package/dist/ops/matMul16.js +13 -0
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mul16.d.ts +2 -0
- package/dist/ops/mul16.js +8 -0
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/pack16.d.ts +2 -0
- package/dist/ops/pack16.js +6 -0
- package/dist/ops/qkv.d.ts +1 -1
- package/dist/ops/qkv.js +8 -4
- package/dist/ops/reshape16.d.ts +2 -0
- package/dist/ops/reshape16.js +43 -0
- package/dist/ops/rope.d.ts +1 -1
- package/dist/ops/rope.js +8 -10
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/slice16.d.ts +2 -0
- package/dist/ops/slice16.js +9 -0
- package/dist/ops/softmax16.d.ts +2 -0
- package/dist/ops/softmax16.js +12 -0
- package/dist/ops/sub16.d.ts +2 -0
- package/dist/ops/sub16.js +8 -0
- package/dist/ops/sum16.d.ts +2 -0
- package/dist/ops/sum16.js +13 -0
- package/dist/ops/transpose16.d.ts +3 -0
- package/dist/ops/transpose16.js +41 -0
- package/dist/ops/unpack16.d.ts +2 -0
- package/dist/ops/unpack16.js +6 -0
- package/dist/ops/webgl/adamAdjust.js +3 -2
- package/dist/ops/webgl/adamMoments.js +2 -1
- package/dist/ops/webgl/appendCache.js +2 -1
- package/dist/ops/webgl/attentionMask.js +5 -4
- package/dist/ops/webgl/fusedSoftmax.js +6 -4
- package/dist/ops/webgl/gatherSub.js +7 -6
- package/dist/ops/webgl/gelu.js +3 -2
- package/dist/ops/webgl/log.js +12 -27
- package/dist/ops/webgl/matMul16.d.ts +1 -0
- package/dist/ops/webgl/matMul16.js +37 -0
- package/dist/ops/webgl/matMulGelu.js +17 -15
- package/dist/ops/webgl/matMulMul.js +13 -12
- package/dist/ops/webgl/mulDropout.js +9 -8
- package/dist/ops/webgl/normRMS.js +8 -7
- package/dist/ops/webgl/qkv.js +6 -5
- package/dist/ops/webgl/rope.js +11 -10
- package/dist/ops/webgl/scatterSub.js +6 -5
- package/dist/ops/webgpu/adamAdjust.js +12 -10
- package/dist/ops/webgpu/adamMoments.js +27 -22
- package/dist/ops/webgpu/add16.d.ts +1 -0
- package/dist/ops/webgpu/add16.js +14 -0
- package/dist/ops/webgpu/appendCache.js +64 -17
- package/dist/ops/webgpu/attentionMask.js +19 -62
- package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
- package/dist/ops/webgpu/attentionMask32_program.js +54 -0
- package/dist/ops/webgpu/concat16.d.ts +19 -0
- package/dist/ops/webgpu/concat16.js +128 -0
- package/dist/ops/webgpu/gatherSub.js +9 -7
- package/dist/ops/webgpu/gelu.js +78 -31
- package/dist/ops/webgpu/index.js +12 -0
- package/dist/ops/webgpu/matMul16.d.ts +1 -0
- package/dist/ops/webgpu/matMul16.js +58 -0
- package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
- package/dist/ops/webgpu/matMul16_program.js +336 -0
- package/dist/ops/webgpu/mul16.d.ts +1 -0
- package/dist/ops/webgpu/mul16.js +14 -0
- package/dist/ops/webgpu/normRMS.js +21 -40
- package/dist/ops/webgpu/normRMS16_program.d.ts +9 -0
- package/dist/ops/webgpu/normRMS16_program.js +24 -0
- package/dist/ops/webgpu/normRMS32_program.d.ts +9 -0
- package/dist/ops/webgpu/normRMS32_program.js +24 -0
- package/dist/ops/webgpu/normRMSGrad.js +113 -64
- package/dist/ops/webgpu/pack16.d.ts +1 -0
- package/dist/ops/webgpu/pack16.js +19 -0
- package/dist/ops/webgpu/pack16_program.d.ts +19 -0
- package/dist/ops/webgpu/pack16_program.js +92 -0
- package/dist/ops/webgpu/qkv.js +20 -55
- package/dist/ops/webgpu/rope.js +77 -22
- package/dist/ops/webgpu/scatterSub.js +9 -7
- package/dist/ops/webgpu/slice16.d.ts +7 -0
- package/dist/ops/webgpu/slice16.js +71 -0
- package/dist/{variable-Bm2OFwGI.js → ops/webgpu/softmax16.d.ts} +2 -8
- package/dist/ops/webgpu/softmax16.js +23 -0
- package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
- package/dist/ops/webgpu/softmax16_program.js +73 -0
- package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
- package/dist/ops/webgpu/softmax16_subgroup_program.js +75 -0
- package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
- package/dist/ops/webgpu/softmax16grad.js +38 -0
- package/dist/ops/webgpu/sub16.d.ts +1 -0
- package/dist/ops/webgpu/sub16.js +14 -0
- package/dist/ops/webgpu/sum16.d.ts +1 -0
- package/dist/ops/webgpu/sum16.js +40 -0
- package/dist/ops/webgpu/transpose16.d.ts +1 -0
- package/dist/ops/webgpu/transpose16.js +35 -0
- package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
- package/dist/ops/webgpu/transpose16_program.js +50 -0
- package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
- package/dist/ops/webgpu/transpose16_shared_program.js +71 -0
- package/dist/ops/webgpu/unpack16.d.ts +1 -0
- package/dist/ops/webgpu/unpack16.js +49 -0
- package/dist/ops/webgpu/utils/binary_op.d.ts +19 -0
- package/dist/ops/webgpu/utils/binary_op.js +79 -0
- package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
- package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
- package/dist/ops/webgpu/utils/reductions.d.ts +32 -4
- package/dist/ops/webgpu/utils/reductions.js +236 -45
- package/dist/ops-CNI3TwqM.js +645 -0
- package/dist/pack16-CFUqumar.js +41 -0
- package/dist/{papaparse.min-C8l2Kvo1.js → papaparse.min-C0cScC2i.js} +2 -8
- package/dist/{parquet-C0Tlmv9c.js → parquet-BE8MU_ge.js} +201 -278
- package/dist/patches/PackedTensor.d.ts +12 -0
- package/dist/patches/PackedTensor.js +11 -0
- package/dist/patches/engine.d.ts +261 -0
- package/dist/patches/engine.js +10 -0
- package/dist/patches/tape.d.ts +12 -0
- package/dist/patches/tape.js +5 -0
- package/dist/patches/webgpu_backend.d.ts +18 -0
- package/dist/patches/webgpu_backend.js +57 -0
- package/dist/{tensor-CZr4dh61.js → patches/webgpu_base.d.ts} +5 -8
- package/dist/patches/webgpu_base.js +34 -0
- package/dist/patches/webgpu_program.d.ts +36 -0
- package/dist/patches/webgpu_program.js +401 -0
- package/dist/{pdf-kJD-f258.js → pdf-NIhmP3sq.js} +424 -428
- package/dist/random_width-DY6Kk2Dl.js +10051 -0
- package/dist/range-BMS52eQi.js +11 -0
- package/dist/reciprocal-CTmshQ9J.js +10 -0
- package/dist/{register_all_kernels-DIGpEwcf.js → register_all_kernels-Bwu1PTuU.js} +719 -9766
- package/dist/relu-yZ2-7WxU.js +10 -0
- package/dist/reshape-DevtBWtf.js +10 -0
- package/dist/rope-B5UUMsPi.js +32 -0
- package/dist/{scatter_nd_util-BQdz--Gn.js → scatter_nd_util-5EL-8VAQ.js} +1 -1
- package/dist/selu_util-D1w6yyTO.js +303 -0
- package/dist/{shared-DuP7ue-R.js → shared-BRksrJb3.js} +1 -17
- package/dist/shared-BuAXb4CI.js +2145 -0
- package/dist/sin-BGfy2HZo.js +16 -0
- package/dist/slice-D_gkkqZK.js +13 -0
- package/dist/slice_util-DtEldBfK.js +261 -0
- package/dist/softmax-ZHVebtR1.js +13 -0
- package/dist/split-DrfihRpZ.js +10 -0
- package/dist/squeeze-DZEpeblb.js +11 -0
- package/dist/stack-yOIAalTq.js +13 -0
- package/dist/sum-_fzj5ZTB.js +12 -0
- package/dist/tensor-DdQUJZlz.js +909 -0
- package/dist/tensor-f35l8Odg.js +8 -0
- package/dist/tensor1d-CeZuc-Rv.js +12 -0
- package/dist/tensor2d-G4Ys2GxX.js +15 -0
- package/dist/tensor4d-B8roDgtc.js +15 -0
- package/dist/tensor_util-DV-FP5Q3.js +523 -0
- package/dist/tfjs_backend-kNyO5L2d.js +653 -0
- package/dist/tile-BzyEiF-F.js +13 -0
- package/dist/tokeniser/CharTokeniser.js +1 -1
- package/dist/tokeniser/bpe.js +1 -1
- package/dist/training/Adam.d.ts +2 -1
- package/dist/training/Adam.js +12 -28
- package/dist/training/AdamExt.d.ts +1 -0
- package/dist/training/AdamExt.js +2 -2
- package/dist/training/DatasetBuilder.js +3 -20
- package/dist/training/FullTrainer.js +82 -64
- package/dist/training/Trainer.d.ts +11 -6
- package/dist/training/Trainer.js +51 -39
- package/dist/training/sparseCrossEntropy.js +3 -3
- package/dist/transpose-DKELTqhe.js +38 -0
- package/dist/utilities/arrayClose.js +7 -7
- package/dist/utilities/dummy.js +35 -27
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/packed.d.ts +7 -0
- package/dist/utilities/packed.js +716 -0
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/sentences.d.ts +5 -0
- package/dist/utilities/sentences.js +41 -0
- package/dist/utilities/weights.js +2 -2
- package/dist/variable-Bhn5bHYv.js +7 -0
- package/dist/{webgpu_program-DkQJOJSd.js → webgpu_program-Cigz-7RF.js} +15 -44
- package/dist/webgpu_util-BBCnKm2X.js +65 -0
- package/dist/zeros-2gldETuK.js +14 -0
- package/package.json +4 -3
- package/dist/Reshape-Bowtk9BP.js +0 -127
- package/dist/Reshape-DUqYftGC.js +0 -30
- package/dist/backend_util-CJIiDoV1.js +0 -749
- package/dist/broadcast_to-DzlNweb8.js +0 -44
- package/dist/concat-B912vBbo.js +0 -33
- package/dist/dropout-C-csYCLj.js +0 -193
- package/dist/exports_initializers-B8iZMgQ0.js +0 -16
- package/dist/gather-Dnpgw-YQ.js +0 -25
- package/dist/index-BzFyqcy-.js +0 -4457
- package/dist/index-C1rx_Ajs.js +0 -12076
- package/dist/kernel_funcs_utils-DKLK0Mg3.js +0 -466
- package/dist/log_sum_exp-DO6z8tSE.js +0 -103
- package/dist/mat_mul-DzjTFx-u.js +0 -27
- package/dist/mod-Dobti4j4.js +0 -27
- package/dist/ones-tIJeHlq-.js +0 -29
- package/dist/ops/fusedSoftmax.d.ts +0 -2
- package/dist/ops/fusedSoftmax.js +0 -10
- package/dist/ops/grads/fusedSoftmax.js +0 -22
- package/dist/ops-LuCMAnmM.js +0 -1525
- package/dist/random_width-CXVRloNK.js +0 -13670
- package/dist/range-CWcz7xFA.js +0 -26
- package/dist/reciprocal-C4rNcM-S.js +0 -25
- package/dist/relu-BjCh_SYb.js +0 -25
- package/dist/reshape-CnIwVG1c.js +0 -25
- package/dist/selu_util-OtRzVwW5.js +0 -719
- package/dist/shared-DmRsFyaJ.js +0 -3134
- package/dist/sin-gpDNRxE0.js +0 -47
- package/dist/slice-d0Vo9XTN.js +0 -28
- package/dist/softmax-D7Jj3p_P.js +0 -28
- package/dist/split-DK2k5eHf.js +0 -25
- package/dist/stack-DFatutCx.js +0 -27
- package/dist/sum-CJ0ULhmt.js +0 -27
- package/dist/tensor1d-vML0r3q6.js +0 -27
- package/dist/tensor2d-D76QGjF3.js +0 -30
- package/dist/tensor4d-Df1WlVDY.js +0 -30
- package/dist/webgpu_util-pLEV9tks.js +0 -80
- package/dist/zeros-Bj5rMYA7.js +0 -52
package/README.md
CHANGED
|
@@ -1,28 +1,366 @@
|
|
|
1
1
|
# GenAI NanoGPT
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
A browser-native implementation of GPT language models built on TensorFlow.js, developed as part of the Finnish Generation AI research project. This library enables training, fine-tuning, and inference of transformer-based language models entirely in the browser, with support for explainable AI (XAI) features. Because it mostly targets tiny models, it is intended as an educational tool for learning about the model training process. In principle, it could also be adapted to load other pre-trained models from Hugging Face.
|
|
4
4
|
|
|
5
|
-
|
|
5
|
+
Live version available here: https://lm.gen-ai.fi
|
|
6
6
|
|
|
7
|
-
|
|
7
|
+
## Overview
|
|
8
8
|
|
|
9
|
-
|
|
9
|
+
GenAI NanoGPT is inspired by [Andrej Karpathy's NanoGPT](https://github.com/karpathy/nanoGPT) but reimagined for the browser using TensorFlow.js. It provides a complete pipeline for:
|
|
10
|
+
|
|
11
|
+
- **Training** language models from scratch in the browser
|
|
12
|
+
- **Loading** pre-trained models from various sources (Hugging Face, local files)
|
|
13
|
+
- **Generating** text efficiently on a wide range of devices
|
|
14
|
+
- **Analyzing** model behavior through attention visualization and embeddings
|
|
15
|
+
- **Optimizing** performance across CPU, WebGL, and WebGPU backends
|
|
16
|
+
|
|
17
|
+
### Key Features
|
|
18
|
+
|
|
19
|
+
- 🚀 **Browser-Native**: No server required - train and run models entirely client-side
|
|
20
|
+
- 📱 **Works on Small Devices**: Train models on iPads, phones, and Chromebooks - no powerful hardware needed
|
|
21
|
+
- 🎯 **Multiple Backends**: Automatic backend selection (CPU, WebGL, WebGPU) for optimal performance
|
|
22
|
+
- 🔧 **Flexible Tokenization**: Support for both character-level and BPE tokenizers
|
|
23
|
+
- 📊 **XAI Support**: Attention score visualization, gradient analysis, and embedding extraction
|
|
24
|
+
- 💾 **Model Persistence**: Save and load models in SafeTensors format
|
|
25
|
+
- ⚡ **Performance Optimizations**: Custom WebGPU kernels, gradient checkpointing, and mixed precision training
|
|
26
|
+
- 🎨 **Real-time Training**: Live training metrics and generation during training
|
|
27
|
+
|
|
28
|
+
## Installation
|
|
29
|
+
|
|
30
|
+
```bash
|
|
10
31
|
npm install @genai-fi/nanogpt
|
|
11
32
|
```
|
|
12
33
|
|
|
13
|
-
|
|
34
|
+
## Quick Start
|
|
35
|
+
|
|
36
|
+
### Creating and Training a Model
|
|
37
|
+
|
|
38
|
+
```javascript
|
|
39
|
+
import { TeachableLLM, selectBackend } from '@genai-fi/nanogpt';
|
|
40
|
+
|
|
41
|
+
// Select the best available backend
|
|
42
|
+
await selectBackend('webgpu'); // or 'webgl', 'cpu'
|
|
43
|
+
|
|
44
|
+
// Create a new model
|
|
45
|
+
const model = TeachableLLM.create('char', {
|
|
46
|
+
vocabSize: 200,
|
|
47
|
+
blockSize: 128, // Context window size
|
|
48
|
+
nLayer: 4, // Number of transformer layers
|
|
49
|
+
nHead: 4, // Number of attention heads
|
|
50
|
+
nEmbed: 192, // Embedding dimension
|
|
51
|
+
dropout: 0.1,
|
|
52
|
+
useRope: true, // Use Rotary Position Embeddings
|
|
53
|
+
});
|
|
54
|
+
|
|
55
|
+
// Training data
|
|
56
|
+
const trainingText = [
|
|
57
|
+
'The quick brown fox jumps over the lazy dog.',
|
|
58
|
+
'A journey of a thousand miles begins with a single step.',
|
|
59
|
+
// ... more text
|
|
60
|
+
];
|
|
61
|
+
|
|
62
|
+
// Train the model
|
|
63
|
+
await model.train(trainingText, {
|
|
64
|
+
batchSize: 16,
|
|
65
|
+
learningRate: 3e-4,
|
|
66
|
+
maxSteps: 1000,
|
|
67
|
+
logInterval: 10,
|
|
68
|
+
validationSplit: 0.1,
|
|
69
|
+
});
|
|
14
70
|
|
|
71
|
+
// Generate text
|
|
72
|
+
const output = await model.generateText('Once upon a time', {
|
|
73
|
+
maxLength: 100,
|
|
74
|
+
temperature: 0.8,
|
|
75
|
+
topP: 0.9,
|
|
76
|
+
});
|
|
77
|
+
|
|
78
|
+
console.log(output);
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
### Loading a Pre-trained Model
|
|
82
|
+
|
|
83
|
+
```javascript
|
|
84
|
+
import { TeachableLLM, waitForModel } from '@genai-fi/nanogpt';
|
|
85
|
+
|
|
86
|
+
// Load from Hugging Face
|
|
87
|
+
const model = TeachableLLM.loadModel('username/model-name');
|
|
88
|
+
|
|
89
|
+
// Or load from a file
|
|
90
|
+
const fileInput = document.getElementById('fileInput');
|
|
91
|
+
fileInput.addEventListener('change', async (event) => {
|
|
92
|
+
const file = event.target.files[0];
|
|
93
|
+
const model = TeachableLLM.loadModel(file);
|
|
94
|
+
await waitForModel(model);
|
|
95
|
+
|
|
96
|
+
const text = await model.generateText('Hello');
|
|
97
|
+
console.log(text);
|
|
98
|
+
});
|
|
15
99
|
```
|
|
16
|
-
import { TeachableLLM, CharTokeniser } from '@genai-fi/nanogpt';
|
|
17
|
-
import * as tf from '@tensorflow/tfjs';
|
|
18
100
|
|
|
19
|
-
|
|
20
|
-
|
|
101
|
+
## Event Handlers and Real-time Updates
|
|
102
|
+
|
|
103
|
+
### Monitoring Training Progress
|
|
104
|
+
|
|
105
|
+
Track training metrics in real time with event handlers:
|
|
106
|
+
|
|
107
|
+
```javascript
|
|
108
|
+
const model = TeachableLLM.create('char', config);
|
|
109
|
+
|
|
110
|
+
// Listen for training step updates
|
|
111
|
+
model.on('trainStep', (step, progress) => {
|
|
112
|
+
console.log(`Step ${step.step}/${progress.totalSteps}`);
|
|
113
|
+
console.log(`Loss: ${step.loss.toFixed(4)}`);
|
|
114
|
+
console.log(`Validation Loss: ${step.valLoss?.toFixed(4) || 'N/A'}`);
|
|
115
|
+
console.log(`Progress: ${(progress.progress * 100).toFixed(1)}%`);
|
|
116
|
+
console.log(`Time Remaining: ${progress.timeRemaining}s`);
|
|
117
|
+
|
|
118
|
+
// Update UI progress bar
|
|
119
|
+
updateProgressBar(progress.progress);
|
|
120
|
+
updateLossChart(step.loss, step.valLoss);
|
|
121
|
+
});
|
|
122
|
+
|
|
123
|
+
await model.train(trainingText, options);
|
|
124
|
+
```
|
|
125
|
+
|
|
126
|
+
### Real-time Token Generation
|
|
127
|
+
|
|
128
|
+
Stream generated tokens as they're produced:
|
|
129
|
+
|
|
130
|
+
```javascript
|
|
131
|
+
const generator = model.generator();
|
|
132
|
+
|
|
133
|
+
// Listen for generated tokens
|
|
134
|
+
generator.on('tokens', (tokens) => {
|
|
135
|
+
// tokens is an array of new token IDs
|
|
136
|
+
const text = model.tokeniser.decode(tokens);
|
|
137
|
+
console.log('New tokens:', text);
|
|
138
|
+
|
|
139
|
+
// Update UI incrementally
|
|
140
|
+
appendToOutput(text);
|
|
141
|
+
});
|
|
142
|
+
|
|
143
|
+
// Generation lifecycle events
|
|
144
|
+
generator.on('start', () => {
|
|
145
|
+
console.log('Generation started');
|
|
146
|
+
showSpinner();
|
|
147
|
+
});
|
|
148
|
+
|
|
149
|
+
generator.on('stop', () => {
|
|
150
|
+
console.log('Generation complete');
|
|
151
|
+
hideSpinner();
|
|
152
|
+
});
|
|
153
|
+
|
|
154
|
+
generator.on('error', (error) => {
|
|
155
|
+
console.error('Generation error:', error);
|
|
156
|
+
});
|
|
157
|
+
|
|
158
|
+
// Start generation
|
|
159
|
+
await generator.generate('Once upon a time', {
|
|
160
|
+
maxLength: 200,
|
|
161
|
+
temperature: 0.8,
|
|
162
|
+
});
|
|
163
|
+
```
|
|
164
|
+
|
|
165
|
+
## Training on Small Devices
|
|
166
|
+
|
|
167
|
+
GenAI NanoGPT is designed to work efficiently on resource-constrained devices like iPads, phones, and Chromebooks:
|
|
168
|
+
|
|
169
|
+
### Recommended Settings for Small Devices
|
|
170
|
+
|
|
171
|
+
```javascript
|
|
172
|
+
// Smaller model configuration for mobile devices
|
|
173
|
+
const mobileModel = TeachableLLM.create('char', {
|
|
21
174
|
vocabSize: 200,
|
|
22
|
-
blockSize: 128,
|
|
23
|
-
nLayer: 4,
|
|
24
|
-
nHead: 3,
|
|
25
|
-
nEmbed: 192,
|
|
26
|
-
|
|
175
|
+
blockSize: 128, // Smaller context window
|
|
176
|
+
nLayer: 4, // Fewer layers
|
|
177
|
+
nHead: 3, // Fewer attention heads
|
|
178
|
+
nEmbed: 192, // Smaller embeddings
|
|
179
|
+
});
|
|
180
|
+
|
|
181
|
+
// Training options optimized for limited memory
|
|
182
|
+
await mobileModel.train(trainingText, {
|
|
183
|
+
batchSize: 8, // Smaller batch size
|
|
184
|
+
learningRate: 3e-4,
|
|
185
|
+
maxSteps: 500,
|
|
186
|
+
validationSplit: 0.1,
|
|
187
|
+
logInterval: 50,
|
|
188
|
+
gradientCheckpointing: true,
|
|
189
|
+
mixedPrecision: true,
|
|
190
|
+
});
|
|
191
|
+
```
|
|
192
|
+
|
|
193
|
+
### Tips for Training on Mobile Devices
|
|
194
|
+
|
|
195
|
+
1. **Start Small**: Use smaller models (4 layers) and shorter context windows (128 tokens)
|
|
196
|
+
2. **Reduce Batch Size**: Use batch sizes of 8-16 depending on available memory
|
|
197
|
+
3. **Use Character Tokenization**: Character-level tokenizers use less memory than BPE
|
|
198
|
+
4. **Optimize Training Data**: Use smaller datasets or train in stages
|
|
199
|
+
|
|
200
|
+
## Advanced Usage
|
|
201
|
+
|
|
202
|
+
### Attention Visualization
|
|
203
|
+
|
|
204
|
+
```javascript
|
|
205
|
+
const generator = model.generator();
|
|
206
|
+
|
|
207
|
+
const text = await generator.generate('Prompt', {
|
|
208
|
+
attentionScores: true,
|
|
209
|
+
maxLength: 50,
|
|
27
210
|
});
|
|
211
|
+
|
|
212
|
+
// Get attention data for visualization
|
|
213
|
+
const attentionData = generator.getAttentionData();
|
|
214
|
+
// Shape: [num_tokens][num_layers][num_heads][seq_len][seq_len]
|
|
215
|
+
|
|
216
|
+
const probabilities = generator.getProbabilitiesData();
|
|
217
|
+
// Shape: [num_tokens][seq_len][vocab_size]
|
|
218
|
+
```
|
|
219
|
+
|
|
220
|
+
### Streaming Generation
|
|
221
|
+
|
|
222
|
+
```javascript
|
|
223
|
+
const generator = model.generator();
|
|
224
|
+
|
|
225
|
+
generator.on('tokens', (tokens) => {
|
|
226
|
+
// Update UI with new tokens in real-time
|
|
227
|
+
updateDisplay(tokens);
|
|
228
|
+
});
|
|
229
|
+
|
|
230
|
+
generator.on('start', () => console.log('Generation started'));
|
|
231
|
+
generator.on('stop', () => console.log('Generation complete'));
|
|
232
|
+
|
|
233
|
+
await generator.generate('Once upon a time', {
|
|
234
|
+
maxLength: 200,
|
|
235
|
+
});
|
|
236
|
+
```
|
|
237
|
+
|
|
238
|
+
### Memory Management
|
|
239
|
+
|
|
240
|
+
```javascript
|
|
241
|
+
// Enable profiling
|
|
242
|
+
model.enableProfiler = true;
|
|
243
|
+
|
|
244
|
+
// After training/generation
|
|
245
|
+
const profiler = model.getProfiler();
|
|
246
|
+
if (profiler) {
|
|
247
|
+
console.log('Memory stats:', profiler.getStats());
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
// Clean up
|
|
251
|
+
model.dispose();
|
|
252
|
+
```
|
|
253
|
+
|
|
254
|
+
## Examples
|
|
255
|
+
|
|
256
|
+
See the [`browser-tests`](browser-tests/) directory for complete examples:
|
|
257
|
+
|
|
258
|
+
- [`generate.html`](browser-tests/generate.html): Text generation with UI
|
|
259
|
+
- [`rope-train.html`](browser-tests/rope-train.html): Training a model with RoPE
|
|
260
|
+
- [`hf.html`](browser-tests/hf.html): Loading from Hugging Face
|
|
261
|
+
- [`loader.html`](browser-tests/loader.html): Loading different file formats
|
|
262
|
+
- [`perf.html`](browser-tests/perf.html): Performance testing
|
|
263
|
+
|
|
264
|
+
## Development
|
|
265
|
+
|
|
266
|
+
### Setup
|
|
267
|
+
|
|
268
|
+
```bash
|
|
269
|
+
git clone https://github.com/knicos/genai-nanogpt.git
|
|
270
|
+
cd genai-nanogpt
|
|
271
|
+
npm install
|
|
272
|
+
```
|
|
273
|
+
|
|
274
|
+
### Building
|
|
275
|
+
|
|
276
|
+
```bash
|
|
277
|
+
npm run build # Build for production
|
|
278
|
+
npm run dev # Development mode with watch
|
|
279
|
+
```
|
|
280
|
+
|
|
281
|
+
### Testing
|
|
282
|
+
|
|
283
|
+
```bash
|
|
284
|
+
npm test # Run all tests
|
|
285
|
+
```
|
|
286
|
+
|
|
287
|
+
### Browser Tests
|
|
288
|
+
|
|
289
|
+
```bash
|
|
290
|
+
npm run test:gl # Start dev server
|
|
291
|
+
```
|
|
292
|
+
|
|
293
|
+
### Project Structure
|
|
294
|
+
|
|
295
|
+
```
|
|
296
|
+
lib/
|
|
297
|
+
├── models/ # Model architectures (NanoGPT)
|
|
298
|
+
├── layers/ # Transformer layers (attention, MLP, etc.)
|
|
299
|
+
├── ops/ # Custom TensorFlow.js operations
|
|
300
|
+
│ ├── cpu/ # CPU kernels
|
|
301
|
+
│ ├── webgl/ # WebGL kernels
|
|
302
|
+
│ └── webgpu/ # WebGPU kernels
|
|
303
|
+
├── training/ # Training utilities and optimizers
|
|
304
|
+
├── tokeniser/ # Tokenization implementations
|
|
305
|
+
├── loader/ # Model loading/saving
|
|
306
|
+
├── utilities/ # Helper functions
|
|
307
|
+
└── TeachableLLM.ts # Main API
|
|
308
|
+
```
|
|
309
|
+
|
|
310
|
+
### Custom Operations
|
|
311
|
+
|
|
312
|
+
This library implements several custom TensorFlow.js operations optimized for transformer models:
|
|
313
|
+
|
|
314
|
+
- **RoPE**: Rotary Position Embeddings
|
|
315
|
+
- **Attention Mask**: Causal attention masking
|
|
316
|
+
- **RMS Norm**: Root Mean Square normalization
|
|
317
|
+
- **Adam Optimizer**: Extended Adam with weight decay
|
|
318
|
+
- **16-bit Operators**: Operators that enable mixed-precision training
|
|
319
|
+
|
|
320
|
+
See [`lib/ops`](lib/ops/) for implementations.
|
|
321
|
+
|
|
322
|
+
### Contributing
|
|
323
|
+
|
|
324
|
+
1. Fork the repository
|
|
325
|
+
2. Create a feature branch: `git checkout -b feature/amazing-feature`
|
|
326
|
+
3. Commit your changes: `git commit -m 'Add amazing feature'`
|
|
327
|
+
4. Push to the branch: `git push origin feature/amazing-feature`
|
|
328
|
+
5. Open a Pull Request
|
|
329
|
+
|
|
330
|
+
### Code Style
|
|
331
|
+
|
|
332
|
+
This project uses ESLint and Prettier for code formatting:
|
|
333
|
+
|
|
334
|
+
```bash
|
|
335
|
+
npm run lint # Check code style
|
|
336
|
+
```
|
|
337
|
+
|
|
338
|
+
## Performance Tips
|
|
339
|
+
|
|
340
|
+
1. **Use WebGPU**: Provides the best performance for training and inference
|
|
341
|
+
2. **Batch Size**: Larger batches improve GPU utilization but require more memory
|
|
342
|
+
3. **Mixed Precision**: Enable for faster training on supported hardware (coming soon)
|
|
343
|
+
4. **Gradient Checkpointing**: Reduces memory usage during training at the cost of slower training steps
|
|
344
|
+
5. **Use RoPE**: More efficient than absolute position embeddings
|
|
345
|
+
6. **Start Small on Mobile**: Use 2-4 layers and batch size 2-8 on phones/tablets
|
|
346
|
+
|
|
347
|
+
## Acknowledgments
|
|
348
|
+
|
|
349
|
+
- Inspired by [Andrej Karpathy's NanoGPT](https://github.com/karpathy/nanoGPT)
|
|
350
|
+
- Built with [TensorFlow.js](https://www.tensorflow.org/js)
|
|
351
|
+
- Developed as part of the Finnish [Generation AI research project](https://generation-ai-stn.fi)
|
|
352
|
+
|
|
353
|
+
## Citation
|
|
354
|
+
|
|
355
|
+
If you use this library in your research, please cite:
|
|
356
|
+
|
|
357
|
+
```bibtex
|
|
358
|
+
@inproceedings{10.1145/3769994.3770061,
|
|
359
|
+
author = {Pope, Nicolas and Tedre, Matti},
|
|
360
|
+
title = {A Teachable Machine for Transformers},
|
|
361
|
+
year = {2025},
|
|
362
|
+
publisher = {Association for Computing Machinery},
|
|
363
|
+
doi = {10.1145/3769994.3770061},
|
|
364
|
+
booktitle = {Proceedings of the 25th Koli Calling International Conference on Computing Education Research},
|
|
365
|
+
}
|
|
28
366
|
```
|
package/dist/Generator.js
CHANGED
|
@@ -1,82 +1,73 @@
|
|
|
1
|
-
import { E as C } from "./index-
|
|
2
|
-
import {
|
|
1
|
+
import { E as C } from "./index-DvYrXKkX.js";
|
|
2
|
+
import { A as _, B as I, E as O, t as R, k as q } from "./index-ZyQhjEPo.js";
|
|
3
|
+
import "./utilities/packed.js";
|
|
3
4
|
import "./ops/cpu/attentionMask.js";
|
|
4
5
|
import "./ops/webgl/attentionMask.js";
|
|
5
6
|
import "./ops/grads/attentionMask.js";
|
|
6
|
-
import "./
|
|
7
|
-
import "./
|
|
8
|
-
import "./
|
|
9
|
-
import
|
|
10
|
-
import { t as G } from "./register_all_kernels-DIGpEwcf.js";
|
|
11
|
-
import "./index-Tf7vU29b.js";
|
|
12
|
-
import "./dataset-DlZtKmBq.js";
|
|
7
|
+
import { p as K } from "./random_width-DY6Kk2Dl.js";
|
|
8
|
+
import { t as j } from "./register_all_kernels-Bwu1PTuU.js";
|
|
9
|
+
import "./index-Cp39cXWe.js";
|
|
10
|
+
import "./dataset-0xP8GjwI.js";
|
|
13
11
|
import "./ops/cpu/rope.js";
|
|
14
12
|
import "./ops/webgl/rope.js";
|
|
15
|
-
import "./
|
|
13
|
+
import "./rope-B5UUMsPi.js";
|
|
16
14
|
import "./ops/cpu/appendCache.js";
|
|
17
15
|
import "./ops/webgl/appendCache.js";
|
|
18
|
-
import "./ops/
|
|
19
|
-
import "./
|
|
20
|
-
import "./ops/
|
|
21
|
-
import "./ops/cpu/
|
|
22
|
-
import "./
|
|
23
|
-
import "./ops/
|
|
16
|
+
import "./ops/grads/softmax16.js";
|
|
17
|
+
import "./matMul16--R5hOwDG.js";
|
|
18
|
+
import "./ops/webgl/matMul16.js";
|
|
19
|
+
import "./ops/cpu/matMul16.js";
|
|
20
|
+
import "./pack16-CFUqumar.js";
|
|
21
|
+
import "./ops/transpose16.js";
|
|
22
|
+
import "./ops/reshape16.js";
|
|
23
|
+
import "./ops/cpu/qkv.js";
|
|
24
|
+
import "./ops/webgl/qkv.js";
|
|
25
|
+
import "./ops/grads/qkv.js";
|
|
24
26
|
import "./ops/cpu/normRMS.js";
|
|
25
27
|
import "./ops/webgl/normRMS.js";
|
|
26
28
|
import "./ops/grads/normRMS.js";
|
|
29
|
+
import "./ops/grads/add16.js";
|
|
27
30
|
import { sparseSoftmaxCrossEntropy as V } from "./training/sparseCrossEntropy.js";
|
|
28
|
-
import "./jszip.min-
|
|
31
|
+
import "./jszip.min-Bz5-11Bk.js";
|
|
29
32
|
import $ from "./tokeniser/CharTokeniser.js";
|
|
30
33
|
import "./ops/cpu/adamAdjust.js";
|
|
31
34
|
import "./ops/webgl/adamAdjust.js";
|
|
32
35
|
import "./ops/cpu/adamMoments.js";
|
|
33
36
|
import "./ops/webgl/adamMoments.js";
|
|
34
|
-
import "./papaparse.min-
|
|
35
|
-
import
|
|
37
|
+
import "./papaparse.min-C0cScC2i.js";
|
|
38
|
+
import G from "./utilities/topP.js";
|
|
36
39
|
import "./ops/cpu/scatterSub.js";
|
|
37
40
|
import "./ops/webgl/scatterSub.js";
|
|
38
41
|
import "./ops/cpu/gatherSub.js";
|
|
39
42
|
import "./ops/webgl/gatherSub.js";
|
|
43
|
+
import "./ops/cpu/matMulGelu.js";
|
|
44
|
+
import "./ops/webgl/matMulGelu.js";
|
|
45
|
+
import "./ops/grads/matMulGelu.js";
|
|
40
46
|
import "./ops/cpu/gelu.js";
|
|
41
47
|
import "./ops/webgl/gelu.js";
|
|
42
|
-
import "./gelu-
|
|
48
|
+
import "./gelu-CNLFZWea.js";
|
|
43
49
|
import "./ops/webgl/log.js";
|
|
44
50
|
import "./checks/normRMS.js";
|
|
45
51
|
import "./checks/normRMSGrad.js";
|
|
46
|
-
import
|
|
47
|
-
import {
|
|
48
|
-
import {
|
|
49
|
-
import {
|
|
50
|
-
import {
|
|
51
|
-
import {
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
56
|
-
* you may not use this file except in compliance with the License.
|
|
57
|
-
* You may obtain a copy of the License at
|
|
58
|
-
*
|
|
59
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
60
|
-
*
|
|
61
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
62
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
63
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
64
|
-
* See the License for the specific language governing permissions and
|
|
65
|
-
* limitations under the License.
|
|
66
|
-
* =============================================================================
|
|
67
|
-
*/
|
|
68
|
-
function U(p, t, s, e = !1) {
|
|
69
|
-
const o = I(p, "logits", "multinomial"), i = o.size, c = o.rank;
|
|
52
|
+
import M from "./utilities/multinomialCPU.js";
|
|
53
|
+
import { i as N } from "./tensor_util-DV-FP5Q3.js";
|
|
54
|
+
import { r as E } from "./reshape-DevtBWtf.js";
|
|
55
|
+
import { t as P } from "./tensor2d-G4Ys2GxX.js";
|
|
56
|
+
import { s as S } from "./softmax-ZHVebtR1.js";
|
|
57
|
+
import { g as B } from "./gather-DykLGqmW.js";
|
|
58
|
+
import { c as H } from "./concat-BHlIJeyT.js";
|
|
59
|
+
function U(l, t, s, e = !1) {
|
|
60
|
+
const o = I(l, "logits", "multinomial"), i = o.size, c = o.rank;
|
|
70
61
|
if (i < 2)
|
|
71
62
|
throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${i}.`);
|
|
72
63
|
if (c > 2)
|
|
73
64
|
throw new Error(`Rank of probabilities must be 1 or 2, but is ${c}`);
|
|
74
65
|
s = s || Math.random();
|
|
75
|
-
const n = { logits: c === 1 ? E(o, [1, -1]) : o },
|
|
66
|
+
const n = { logits: c === 1 ? E(o, [1, -1]) : o }, p = { numSamples: t, seed: s, normalized: e }, d = O.runKernel(N, n, p);
|
|
76
67
|
return c === 1 ? E(d, [d.size]) : d;
|
|
77
68
|
}
|
|
78
69
|
const z = /* @__PURE__ */ _({ multinomial_: U }), W = [
|
|
79
|
-
...Array.from({ length: 95 }, (
|
|
70
|
+
...Array.from({ length: 95 }, (l, t) => String.fromCharCode(t + 32)),
|
|
80
71
|
// ASCII
|
|
81
72
|
// Spanish accented letters and punctuation
|
|
82
73
|
..."áéíóúüñ¿¡",
|
|
@@ -87,10 +78,10 @@ const z = /* @__PURE__ */ _({ multinomial_: U }), W = [
|
|
|
87
78
|
// Cyrillic letters
|
|
88
79
|
..."абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ"
|
|
89
80
|
];
|
|
90
|
-
function
|
|
91
|
-
return
|
|
81
|
+
function F(l, t) {
|
|
82
|
+
return l.length === t ? l : l.length > t ? l.slice(0, t) : l.concat(Array(t - l.length).fill(""));
|
|
92
83
|
}
|
|
93
|
-
class
|
|
84
|
+
class te extends C {
|
|
94
85
|
constructor(t, s) {
|
|
95
86
|
super(), this.model = t, this.tokeniser = s, this.actualTokeniser = s;
|
|
96
87
|
}
|
|
@@ -116,7 +107,7 @@ class Wt extends C {
|
|
|
116
107
|
const c = await t.decode([i]);
|
|
117
108
|
if (e) {
|
|
118
109
|
const T = await Promise.all(
|
|
119
|
-
e.map((n) => n.array().then((
|
|
110
|
+
e.map((n) => n.array().then((p) => p))
|
|
120
111
|
);
|
|
121
112
|
e.forEach((n) => n.dispose()), this.attentionData.push(T);
|
|
122
113
|
}
|
|
@@ -131,14 +122,14 @@ class Wt extends C {
|
|
|
131
122
|
} : void 0,
|
|
132
123
|
cache: s,
|
|
133
124
|
outputEmbeddings: !!e?.embeddings
|
|
134
|
-
}, [
|
|
135
|
-
const
|
|
125
|
+
}, [p, d] = R(() => {
|
|
126
|
+
const r = t, m = r.shape[1], h = m <= this.model.config.blockSize ? r : r.slice(
|
|
136
127
|
[0, m - this.model.config.blockSize],
|
|
137
|
-
[
|
|
138
|
-
),
|
|
128
|
+
[r.shape[0], this.model.config.blockSize]
|
|
129
|
+
), a = T ? this.model.config.blockSize - h.shape[1] : 0, v = a > 0 ? K(h, [
|
|
139
130
|
[0, 0],
|
|
140
|
-
[0,
|
|
141
|
-
]) : h, [g] = this.model.forward(n, v), u = g.shape[1] - 1 -
|
|
131
|
+
[0, a]
|
|
132
|
+
]) : h, [g] = this.model.forward(n, v), u = g.shape[1] - 1 - a, f = g.slice([0, u, 0], [g.shape[0], 1, g.shape[2]]);
|
|
142
133
|
let y;
|
|
143
134
|
if (e?.targets) {
|
|
144
135
|
const k = e.targets.shift();
|
|
@@ -148,46 +139,46 @@ class Wt extends C {
|
|
|
148
139
|
}
|
|
149
140
|
}
|
|
150
141
|
return n.attentionScores?.attentionOut && n.attentionScores.attentionOut.forEach((k, w) => {
|
|
151
|
-
k.shape[1] !== 1 && (n.attentionScores.attentionOut[w] =
|
|
142
|
+
k.shape[1] !== 1 && (n.attentionScores.attentionOut[w] = q(
|
|
152
143
|
k.slice([0, u, 0], [k.shape[0], 1, k.shape[2]])
|
|
153
144
|
), k.dispose());
|
|
154
145
|
}), g.dispose(), [f.div(o).squeeze([1]), y];
|
|
155
146
|
});
|
|
156
147
|
let b, x;
|
|
157
148
|
if (c) {
|
|
158
|
-
const
|
|
159
|
-
|
|
160
|
-
const h =
|
|
161
|
-
e?.includeProbabilities && (x = m), b =
|
|
149
|
+
const r = S(p), m = await r.array();
|
|
150
|
+
r.dispose();
|
|
151
|
+
const h = G(m, c);
|
|
152
|
+
e?.includeProbabilities && (x = m), b = M(h);
|
|
162
153
|
} else if (i) {
|
|
163
|
-
const { values:
|
|
164
|
-
b =
|
|
165
|
-
} else if (b = z(
|
|
166
|
-
const
|
|
167
|
-
x = await
|
|
154
|
+
const { values: r, indices: m } = j(p, i), h = z(r, 1);
|
|
155
|
+
b = B(m, h, 1), r.dispose(), m.dispose(), h.dispose();
|
|
156
|
+
} else if (b = z(p, 1), e?.includeProbabilities) {
|
|
157
|
+
const r = S(p);
|
|
158
|
+
x = await r.array(), r.dispose();
|
|
168
159
|
}
|
|
169
160
|
if (n.embeddings) {
|
|
170
|
-
const m = (e?.embeddings === "all" ? n.embeddings : n.embeddings.filter((
|
|
171
|
-
const v =
|
|
172
|
-
|
|
161
|
+
const m = (e?.embeddings === "all" ? n.embeddings : n.embeddings.filter((a) => a.name.startsWith("block_output_"))).map(async (a) => {
|
|
162
|
+
const v = a.tensor.shape[1], g = a.tensor.slice([0, v - 1, 0], [a.tensor.shape[0], 1, a.tensor.shape[2]]);
|
|
163
|
+
a.tensor.dispose();
|
|
173
164
|
const u = g.squeeze([1]);
|
|
174
165
|
if (g.dispose(), e?.embeddings === "softmax") {
|
|
175
166
|
const f = this.model.project(u);
|
|
176
167
|
u.dispose();
|
|
177
168
|
const y = S(f, -1);
|
|
178
|
-
return f.dispose(), { name:
|
|
169
|
+
return f.dispose(), { name: a.name, tensor: await y.array() };
|
|
179
170
|
} else if (e?.embeddings === "logits") {
|
|
180
171
|
const f = this.model.project(u);
|
|
181
|
-
return u.dispose(), { name:
|
|
172
|
+
return u.dispose(), { name: a.name, tensor: await f.array() };
|
|
182
173
|
} else {
|
|
183
174
|
const f = await u.array();
|
|
184
|
-
return u.dispose(), { name:
|
|
175
|
+
return u.dispose(), { name: a.name, tensor: f };
|
|
185
176
|
}
|
|
186
177
|
}), h = await Promise.all(m);
|
|
187
178
|
this.embeddingsData.push(h);
|
|
188
179
|
}
|
|
189
180
|
const A = b.reshape([1, 1]);
|
|
190
|
-
b.dispose(), b = A,
|
|
181
|
+
b.dispose(), b = A, p.dispose();
|
|
191
182
|
let L;
|
|
192
183
|
return d && (L = await d.array(), d.dispose()), { output: b, probabilities: x, attention: n.attentionScores?.attentionOut, loss: L };
|
|
193
184
|
}
|
|
@@ -211,10 +202,10 @@ class Wt extends C {
|
|
|
211
202
|
const d = s;
|
|
212
203
|
s = H([s, i], 1), d.dispose();
|
|
213
204
|
}
|
|
214
|
-
const
|
|
215
|
-
if (this.cache || i.dispose(),
|
|
205
|
+
const p = await this.processResponse(this.actualTokeniser, i, T, c);
|
|
206
|
+
if (this.cache || i.dispose(), p === null)
|
|
216
207
|
break;
|
|
217
|
-
this.outputText +=
|
|
208
|
+
this.outputText += p;
|
|
218
209
|
}
|
|
219
210
|
return s.dispose(), this.outputText;
|
|
220
211
|
}
|
|
@@ -233,7 +224,7 @@ class Wt extends C {
|
|
|
233
224
|
o[i] = { k: void 0, v: void 0, length: 0, cumulativeLength: 0 };
|
|
234
225
|
this.cache = o, this.lastToken = -1;
|
|
235
226
|
}
|
|
236
|
-
const e = this.tokeniser.trained ? this.tokeniser : new $(
|
|
227
|
+
const e = this.tokeniser.trained ? this.tokeniser : new $(F(W, this.tokeniser.vocabSize));
|
|
237
228
|
this.actualTokeniser = e;
|
|
238
229
|
}
|
|
239
230
|
async step(t, s) {
|
|
@@ -268,5 +259,5 @@ class Wt extends C {
|
|
|
268
259
|
}
|
|
269
260
|
}
|
|
270
261
|
export {
|
|
271
|
-
|
|
262
|
+
te as default
|
|
272
263
|
};
|