@genai-fi/nanogpt 0.20.0 → 0.20.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/BaseTokeniser-DSg9zcYq.js +221 -0
- package/dist/DatasetBuilder-DgURD85T.js +712 -0
- package/dist/Generator.d.ts +82 -0
- package/dist/Generator.js +2 -0
- package/dist/RealDiv-DBu0FQqT.js +362 -0
- package/dist/Reshape-CABOPB9d.js +94 -0
- package/dist/Reshape-DqO3r8BC.js +17 -0
- package/dist/TeachableLLM.d.ts +70 -0
- package/dist/TeachableLLM.js +2 -0
- package/dist/Trainer.d.ts +43 -0
- package/dist/Trainer.js +2 -0
- package/dist/backend.d.ts +2 -0
- package/dist/backend.js +13 -0
- package/dist/backend_util-Cg-roD1p.js +399 -0
- package/dist/binary_op_util-CrYk9LXL.js +103 -0
- package/dist/checks/appendCache.d.ts +1 -0
- package/dist/checks/appendCache.js +55 -0
- package/dist/checks/attentionMask.d.ts +1 -0
- package/dist/checks/attentionMask.js +56 -0
- package/dist/checks/check.d.ts +9 -0
- package/dist/checks/check.js +32 -0
- package/dist/checks/gelu.d.ts +1 -0
- package/dist/checks/gelu.js +46 -0
- package/dist/checks/index.d.ts +26 -0
- package/dist/checks/index.js +28 -0
- package/dist/checks/matMulGelu.d.ts +1 -0
- package/dist/checks/matMulGelu.js +84 -0
- package/dist/checks/normRMS.d.ts +1 -0
- package/dist/checks/normRMS.js +28 -0
- package/dist/checks/normRMSGrad.d.ts +1 -0
- package/dist/checks/normRMSGrad.js +22 -0
- package/dist/checks/packUnpack.d.ts +1 -0
- package/dist/checks/packUnpack.js +46 -0
- package/dist/checks/qkv.d.ts +1 -0
- package/dist/checks/qkv.js +34 -0
- package/dist/checks/rope.d.ts +1 -0
- package/dist/checks/rope.js +30 -0
- package/dist/checks/weights.d.ts +14 -0
- package/dist/checks/weights.js +27 -0
- package/dist/chunk-BPntVaq0.js +23 -0
- package/dist/complex_util-CkazZsaH.js +60 -0
- package/dist/concat_util-CWDZCBlA.js +19 -0
- package/dist/data/docx.d.ts +2 -0
- package/dist/data/docx.js +3046 -0
- package/dist/data/pdf.d.ts +2 -0
- package/dist/data/pdf.js +17 -0
- package/dist/data/textLoader.d.ts +7 -0
- package/dist/data/textLoader.js +613 -0
- package/dist/dist-BewPQWjc.js +7572 -0
- package/dist/dist-DVmq73nz.js +8775 -0
- package/dist/dist-DXwIvKxl.js +896 -0
- package/dist/dist-VEU5mfO0.js +7545 -0
- package/dist/gelu-Bf1HW1RY.js +27 -0
- package/dist/gpgpu_math-DvLcCH6u.js +1612 -0
- package/dist/inference/types.d.ts +16 -0
- package/dist/inference/types.js +0 -0
- package/dist/kernel_funcs_utils-HiXOOx3f.js +229 -0
- package/dist/layers/BaseLayer.d.ts +44 -0
- package/dist/layers/BaseLayer.js +76 -0
- package/dist/layers/CausalSelfAttention.d.ts +39 -0
- package/dist/layers/CausalSelfAttention.js +99 -0
- package/dist/layers/LoRA.d.ts +14 -0
- package/dist/layers/LoRA.js +48 -0
- package/dist/layers/MLP.d.ts +17 -0
- package/dist/layers/MLP.js +34 -0
- package/dist/layers/PositionEmbedding.d.ts +8 -0
- package/dist/layers/PositionEmbedding.js +27 -0
- package/dist/layers/RMSNorm.d.ts +12 -0
- package/dist/layers/RMSNorm.js +20 -0
- package/dist/layers/RoPECache.d.ts +18 -0
- package/dist/layers/RoPECache.js +337 -0
- package/dist/layers/TiedEmbedding.d.ts +13 -0
- package/dist/layers/TiedEmbedding.js +32 -0
- package/dist/layers/TransformerBlock.d.ts +27 -0
- package/dist/layers/TransformerBlock.js +51 -0
- package/dist/layers/WeightStore.d.ts +20 -0
- package/dist/layers/WeightStore.js +69 -0
- package/dist/loader/load.d.ts +6 -0
- package/dist/loader/load.js +2 -0
- package/dist/loader/loadHF.d.ts +8 -0
- package/dist/loader/loadHF.js +2 -0
- package/dist/loader/loadTransformers.d.ts +4 -0
- package/dist/loader/loadTransformers.js +2 -0
- package/dist/loader/loadZipMeta.d.ts +3 -0
- package/dist/loader/loadZipMeta.js +16 -0
- package/dist/loader/newZipLoad.d.ts +3 -0
- package/dist/loader/newZipLoad.js +2 -0
- package/dist/loader/oldZipLoad.d.ts +9 -0
- package/dist/loader/oldZipLoad.js +2 -0
- package/dist/loader/save.d.ts +16 -0
- package/dist/loader/save.js +2 -0
- package/dist/loader/types.d.ts +68 -0
- package/dist/loader/types.js +0 -0
- package/dist/main-D5CbfCiV.js +13500 -0
- package/dist/main.d.ts +50 -0
- package/dist/main.js +16 -0
- package/dist/matMul16-BNfZSnNM.js +81 -0
- package/dist/matMulGelu-CPTntosE.js +162 -0
- package/dist/models/NanoGPTV1.d.ts +16 -0
- package/dist/models/NanoGPTV1.js +2 -0
- package/dist/models/NanoGPTV2.d.ts +16 -0
- package/dist/models/NanoGPTV2.js +2 -0
- package/dist/models/config.d.ts +27 -0
- package/dist/models/config.js +37 -0
- package/dist/models/factory.d.ts +3 -0
- package/dist/models/factory.js +2 -0
- package/dist/models/model.d.ts +44 -0
- package/dist/models/model.js +2 -0
- package/dist/ops/adamAdjust.d.ts +2 -0
- package/dist/ops/adamAdjust.js +18 -0
- package/dist/ops/adamMoments.d.ts +2 -0
- package/dist/ops/adamMoments.js +16 -0
- package/dist/ops/add16.d.ts +2 -0
- package/dist/ops/add16.js +12 -0
- package/dist/ops/appendCache.d.ts +2 -0
- package/dist/ops/appendCache.js +25 -0
- package/dist/ops/attentionMask.d.ts +2 -0
- package/dist/ops/attentionMask.js +16 -0
- package/dist/ops/concat16.d.ts +2 -0
- package/dist/ops/concat16.js +8 -0
- package/dist/ops/cpu/adamAdjust.d.ts +1 -0
- package/dist/ops/cpu/adamAdjust.js +16 -0
- package/dist/ops/cpu/adamMoments.d.ts +1 -0
- package/dist/ops/cpu/adamMoments.js +16 -0
- package/dist/ops/cpu/appendCache.d.ts +1 -0
- package/dist/ops/cpu/appendCache.js +65 -0
- package/dist/ops/cpu/attentionMask.d.ts +1 -0
- package/dist/ops/cpu/attentionMask.js +16 -0
- package/dist/ops/cpu/fusedSoftmax.d.ts +9 -0
- package/dist/ops/cpu/fusedSoftmax.js +22 -0
- package/dist/ops/cpu/gatherSub.d.ts +1 -0
- package/dist/ops/cpu/gatherSub.js +12 -0
- package/dist/ops/cpu/gelu.d.ts +1 -0
- package/dist/ops/cpu/gelu.js +36 -0
- package/dist/ops/cpu/matMul16.d.ts +1 -0
- package/dist/ops/cpu/matMul16.js +14 -0
- package/dist/ops/cpu/matMulGelu.d.ts +1 -0
- package/dist/ops/cpu/matMulGelu.js +41 -0
- package/dist/ops/cpu/matMulMul.d.ts +1 -0
- package/dist/ops/cpu/matMulMul.js +20 -0
- package/dist/ops/cpu/mulDropout.d.ts +1 -0
- package/dist/ops/cpu/mulDropout.js +20 -0
- package/dist/ops/cpu/normRMS.d.ts +1 -0
- package/dist/ops/cpu/normRMS.js +35 -0
- package/dist/ops/cpu/qkv.d.ts +5 -0
- package/dist/ops/cpu/qkv.js +73 -0
- package/dist/ops/cpu/rope.d.ts +6 -0
- package/dist/ops/cpu/rope.js +81 -0
- package/dist/ops/cpu/scatterSub.d.ts +1 -0
- package/dist/ops/cpu/scatterSub.js +12 -0
- package/dist/ops/dot16.d.ts +2 -0
- package/dist/ops/dot16.js +29 -0
- package/dist/ops/dropout.d.ts +2 -0
- package/dist/ops/dropout.js +11 -0
- package/dist/ops/dropout16.d.ts +2 -0
- package/dist/ops/dropout16.js +22 -0
- package/dist/ops/gatherSub.d.ts +2 -0
- package/dist/ops/gatherSub.js +13 -0
- package/dist/ops/gelu.d.ts +3 -0
- package/dist/ops/gelu.js +2 -0
- package/dist/ops/globalNorm.d.ts +2 -0
- package/dist/ops/globalNorm.js +19 -0
- package/dist/ops/grads/add16.d.ts +1 -0
- package/dist/ops/grads/add16.js +27 -0
- package/dist/ops/grads/attentionMask.d.ts +1 -0
- package/dist/ops/grads/attentionMask.js +26 -0
- package/dist/ops/grads/dropout16.d.ts +1 -0
- package/dist/ops/grads/dropout16.js +1 -0
- package/dist/ops/grads/gelu.d.ts +2 -0
- package/dist/ops/grads/gelu.js +2 -0
- package/dist/ops/grads/matMul16.d.ts +2 -0
- package/dist/ops/grads/matMul16.js +2 -0
- package/dist/ops/grads/matMulGelu.d.ts +1 -0
- package/dist/ops/grads/matMulGelu.js +22 -0
- package/dist/ops/grads/mul16.d.ts +1 -0
- package/dist/ops/grads/mul16.js +1 -0
- package/dist/ops/grads/normRMS.d.ts +3 -0
- package/dist/ops/grads/normRMS.js +37 -0
- package/dist/ops/grads/pack16.d.ts +2 -0
- package/dist/ops/grads/pack16.js +2 -0
- package/dist/ops/grads/qkv.d.ts +3 -0
- package/dist/ops/grads/qkv.js +46 -0
- package/dist/ops/grads/rope.d.ts +2 -0
- package/dist/ops/grads/rope.js +2 -0
- package/dist/ops/grads/softmax16.d.ts +2 -0
- package/dist/ops/grads/softmax16.js +23 -0
- package/dist/ops/grads/unpack16.d.ts +2 -0
- package/dist/ops/grads/unpack16.js +2 -0
- package/dist/ops/grads/utils.d.ts +4 -0
- package/dist/ops/grads/utils.js +12 -0
- package/dist/ops/log.d.ts +0 -0
- package/dist/ops/log.js +1 -0
- package/dist/ops/matMul16.d.ts +15 -0
- package/dist/ops/matMul16.js +2 -0
- package/dist/ops/matMulGelu.d.ts +3 -0
- package/dist/ops/matMulGelu.js +20 -0
- package/dist/ops/matMulMul.d.ts +2 -0
- package/dist/ops/matMulMul.js +16 -0
- package/dist/ops/mul16.d.ts +2 -0
- package/dist/ops/mul16.js +43 -0
- package/dist/ops/mulDrop.d.ts +2 -0
- package/dist/ops/mulDrop.js +15 -0
- package/dist/ops/normRMS.d.ts +2 -0
- package/dist/ops/normRMS.js +22 -0
- package/dist/ops/pack16.d.ts +2 -0
- package/dist/ops/pack16.js +2 -0
- package/dist/ops/qkv.d.ts +2 -0
- package/dist/ops/qkv.js +16 -0
- package/dist/ops/reshape16.d.ts +2 -0
- package/dist/ops/reshape16.js +33 -0
- package/dist/ops/rope.d.ts +3 -0
- package/dist/ops/rope.js +2 -0
- package/dist/ops/scatterSub.d.ts +2 -0
- package/dist/ops/scatterSub.js +13 -0
- package/dist/ops/slice16.d.ts +2 -0
- package/dist/ops/slice16.js +11 -0
- package/dist/ops/softmax16.d.ts +2 -0
- package/dist/ops/softmax16.js +9 -0
- package/dist/ops/sub16.d.ts +2 -0
- package/dist/ops/sub16.js +11 -0
- package/dist/ops/sum16.d.ts +2 -0
- package/dist/ops/sum16.js +13 -0
- package/dist/ops/transpose16.d.ts +3 -0
- package/dist/ops/transpose16.js +32 -0
- package/dist/ops/unpack16.d.ts +2 -0
- package/dist/ops/unpack16.js +2 -0
- package/dist/ops/webgl/adamAdjust.d.ts +1 -0
- package/dist/ops/webgl/adamAdjust.js +82 -0
- package/dist/ops/webgl/adamMoments.d.ts +1 -0
- package/dist/ops/webgl/adamMoments.js +44 -0
- package/dist/ops/webgl/appendCache.d.ts +1 -0
- package/dist/ops/webgl/appendCache.js +53 -0
- package/dist/ops/webgl/attentionMask.d.ts +1 -0
- package/dist/ops/webgl/attentionMask.js +64 -0
- package/dist/ops/webgl/dropout16.d.ts +1 -0
- package/dist/ops/webgl/dropout16.js +12 -0
- package/dist/ops/webgl/fusedSoftmax.d.ts +11 -0
- package/dist/ops/webgl/fusedSoftmax.js +70 -0
- package/dist/ops/webgl/gatherSub.d.ts +1 -0
- package/dist/ops/webgl/gatherSub.js +28 -0
- package/dist/ops/webgl/gelu.d.ts +2 -0
- package/dist/ops/webgl/gelu.js +48 -0
- package/dist/ops/webgl/log.d.ts +17 -0
- package/dist/ops/webgl/log.js +14 -0
- package/dist/ops/webgl/matMul16.d.ts +1 -0
- package/dist/ops/webgl/matMul16.js +37 -0
- package/dist/ops/webgl/matMulGelu.d.ts +21 -0
- package/dist/ops/webgl/matMulGelu.js +2 -0
- package/dist/ops/webgl/matMulMul.d.ts +14 -0
- package/dist/ops/webgl/matMulMul.js +24 -0
- package/dist/ops/webgl/mulDropout.d.ts +1 -0
- package/dist/ops/webgl/mulDropout.js +32 -0
- package/dist/ops/webgl/normRMS.d.ts +1 -0
- package/dist/ops/webgl/normRMS.js +114 -0
- package/dist/ops/webgl/qkv.d.ts +1 -0
- package/dist/ops/webgl/qkv.js +54 -0
- package/dist/ops/webgl/rope.d.ts +1 -0
- package/dist/ops/webgl/rope.js +72 -0
- package/dist/ops/webgl/scatterSub.d.ts +1 -0
- package/dist/ops/webgl/scatterSub.js +28 -0
- package/dist/ops/webgpu/adamAdjust.d.ts +1 -0
- package/dist/ops/webgpu/adamAdjust.js +77 -0
- package/dist/ops/webgpu/adamMoments.d.ts +1 -0
- package/dist/ops/webgpu/adamMoments.js +76 -0
- package/dist/ops/webgpu/add16.d.ts +1 -0
- package/dist/ops/webgpu/add16.js +14 -0
- package/dist/ops/webgpu/appendCache.d.ts +1 -0
- package/dist/ops/webgpu/appendCache.js +130 -0
- package/dist/ops/webgpu/attentionMask.d.ts +1 -0
- package/dist/ops/webgpu/attentionMask.js +42 -0
- package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
- package/dist/ops/webgpu/attentionMask32_program.js +62 -0
- package/dist/ops/webgpu/clipScale.d.ts +1 -0
- package/dist/ops/webgpu/clipScale.js +45 -0
- package/dist/ops/webgpu/concat16.d.ts +19 -0
- package/dist/ops/webgpu/concat16.js +111 -0
- package/dist/ops/webgpu/dropout16.d.ts +1 -0
- package/dist/ops/webgpu/dropout16.js +59 -0
- package/dist/ops/webgpu/gatherSub.d.ts +1 -0
- package/dist/ops/webgpu/gatherSub.js +52 -0
- package/dist/ops/webgpu/gelu.d.ts +14 -0
- package/dist/ops/webgpu/gelu.js +147 -0
- package/dist/ops/webgpu/index.d.ts +0 -0
- package/dist/ops/webgpu/index.js +26 -0
- package/dist/ops/webgpu/matMul16.d.ts +1 -0
- package/dist/ops/webgpu/matMul16.js +70 -0
- package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
- package/dist/ops/webgpu/matMul16_program.js +303 -0
- package/dist/ops/webgpu/mul16.d.ts +1 -0
- package/dist/ops/webgpu/mul16.js +14 -0
- package/dist/ops/webgpu/norm2.d.ts +1 -0
- package/dist/ops/webgpu/norm2.js +46 -0
- package/dist/ops/webgpu/normRMS.d.ts +1 -0
- package/dist/ops/webgpu/normRMS.js +26 -0
- package/dist/ops/webgpu/normRMS16_program.d.ts +10 -0
- package/dist/ops/webgpu/normRMS16_program.js +28 -0
- package/dist/ops/webgpu/normRMS32_program.d.ts +10 -0
- package/dist/ops/webgpu/normRMS32_program.js +28 -0
- package/dist/ops/webgpu/normRMSGrad.d.ts +1 -0
- package/dist/ops/webgpu/normRMSGrad.js +225 -0
- package/dist/ops/webgpu/pack16.d.ts +1 -0
- package/dist/ops/webgpu/pack16.js +21 -0
- package/dist/ops/webgpu/pack16_program.d.ts +19 -0
- package/dist/ops/webgpu/pack16_program.js +93 -0
- package/dist/ops/webgpu/qkv.d.ts +1 -0
- package/dist/ops/webgpu/qkv.js +64 -0
- package/dist/ops/webgpu/rope.d.ts +1 -0
- package/dist/ops/webgpu/rope.js +163 -0
- package/dist/ops/webgpu/scatterSub.d.ts +1 -0
- package/dist/ops/webgpu/scatterSub.js +53 -0
- package/dist/ops/webgpu/slice16.d.ts +7 -0
- package/dist/ops/webgpu/slice16.js +74 -0
- package/dist/ops/webgpu/softmax16.d.ts +17 -0
- package/dist/ops/webgpu/softmax16.js +18 -0
- package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
- package/dist/ops/webgpu/softmax16_program.js +89 -0
- package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
- package/dist/ops/webgpu/softmax16_subgroup_program.js +70 -0
- package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
- package/dist/ops/webgpu/softmax16grad.js +31 -0
- package/dist/ops/webgpu/sub16.d.ts +1 -0
- package/dist/ops/webgpu/sub16.js +14 -0
- package/dist/ops/webgpu/sum16.d.ts +1 -0
- package/dist/ops/webgpu/sum16.js +29 -0
- package/dist/ops/webgpu/transpose16.d.ts +1 -0
- package/dist/ops/webgpu/transpose16.js +37 -0
- package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
- package/dist/ops/webgpu/transpose16_program.js +51 -0
- package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
- package/dist/ops/webgpu/transpose16_shared_program.js +79 -0
- package/dist/ops/webgpu/unpack16.d.ts +1 -0
- package/dist/ops/webgpu/unpack16.js +60 -0
- package/dist/ops/webgpu/utils/binary_op.d.ts +35 -0
- package/dist/ops/webgpu/utils/binary_op.js +141 -0
- package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
- package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
- package/dist/ops/webgpu/utils/reductions.d.ts +43 -0
- package/dist/ops/webgpu/utils/reductions.js +263 -0
- package/dist/pack16-Ck-spx_F.js +39 -0
- package/dist/patches/webgpu_backend.d.ts +18 -0
- package/dist/patches/webgpu_backend.js +43 -0
- package/dist/patches/webgpu_base.d.ts +21 -0
- package/dist/patches/webgpu_base.js +22 -0
- package/dist/patches/webgpu_program.d.ts +36 -0
- package/dist/patches/webgpu_program.js +293 -0
- package/dist/pdf-UoDqCYzz.js +16726 -0
- package/dist/picomatch-3tUnMMbd.js +1063 -0
- package/dist/rope-CbeGlsV8.js +25 -0
- package/dist/selu_util-zkAx5doH.js +24 -0
- package/dist/shared-D1coEFea.js +1314 -0
- package/dist/shared-DOgWaqvL.js +5 -0
- package/dist/slice_util-Dgb3ANWI.js +208 -0
- package/dist/tfjs_backend-BjuQ5FqB.js +614 -0
- package/dist/tokeniser/BaseTokeniser.d.ts +33 -0
- package/dist/tokeniser/BaseTokeniser.js +2 -0
- package/dist/tokeniser/CharTokeniser.d.ts +24 -0
- package/dist/tokeniser/CharTokeniser.js +92 -0
- package/dist/tokeniser/bpe.d.ts +28 -0
- package/dist/tokeniser/bpe.js +170 -0
- package/dist/tokeniser/messages.d.ts +61 -0
- package/dist/tokeniser/messages.js +0 -0
- package/dist/tokeniser/type.d.ts +34 -0
- package/dist/tokeniser/type.js +0 -0
- package/dist/training/AdamW.d.ts +36 -0
- package/dist/training/AdamW.js +128 -0
- package/dist/training/BasicTrainer.d.ts +63 -0
- package/dist/training/BasicTrainer.js +265 -0
- package/dist/training/DatasetBuilder.d.ts +26 -0
- package/dist/training/DatasetBuilder.js +2 -0
- package/dist/training/Evaluator.d.ts +19 -0
- package/dist/training/Evaluator.js +48 -0
- package/dist/training/LRScheduler.d.ts +12 -0
- package/dist/training/LRScheduler.js +38 -0
- package/dist/training/PreTrainer.d.ts +11 -0
- package/dist/training/PreTrainer.js +22 -0
- package/dist/training/SFTTrainer.d.ts +12 -0
- package/dist/training/SFTTrainer.js +24 -0
- package/dist/training/loss.d.ts +3 -0
- package/dist/training/loss.js +19 -0
- package/dist/training/orthoGrad.d.ts +2 -0
- package/dist/training/orthoGrad.js +10 -0
- package/dist/training/sparseCrossEntropy.d.ts +7 -0
- package/dist/training/sparseCrossEntropy.js +47 -0
- package/dist/training/tasks/ConversationTask.d.ts +18 -0
- package/dist/training/tasks/ConversationTask.js +38 -0
- package/dist/training/tasks/PretrainingTask.d.ts +17 -0
- package/dist/training/tasks/PretrainingTask.js +42 -0
- package/dist/training/tasks/StartSentenceTask.d.ts +18 -0
- package/dist/training/tasks/StartSentenceTask.js +45 -0
- package/dist/training/tasks/Task.d.ts +22 -0
- package/dist/training/tasks/Task.js +55 -0
- package/dist/training/tasks/splitter.d.ts +5 -0
- package/dist/training/tasks/splitter.js +18 -0
- package/dist/training/types.d.ts +78 -0
- package/dist/training/types.js +0 -0
- package/dist/training/validation.d.ts +17 -0
- package/dist/training/validation.js +2 -0
- package/dist/utilities/arrayClose.d.ts +1 -0
- package/dist/utilities/arrayClose.js +16 -0
- package/dist/utilities/datasetID.d.ts +2 -0
- package/dist/utilities/datasetID.js +18 -0
- package/dist/utilities/dummy.d.ts +9 -0
- package/dist/utilities/dummy.js +36 -0
- package/dist/utilities/multinomialCPU.d.ts +2 -0
- package/dist/utilities/multinomialCPU.js +9 -0
- package/dist/utilities/naming.d.ts +4 -0
- package/dist/utilities/naming.js +0 -0
- package/dist/utilities/packed.d.ts +4 -0
- package/dist/utilities/packed.js +13 -0
- package/dist/utilities/parameters.d.ts +11 -0
- package/dist/utilities/parameters.js +38 -0
- package/dist/utilities/performance.d.ts +2 -0
- package/dist/utilities/performance.js +16 -0
- package/dist/utilities/profile.d.ts +17 -0
- package/dist/utilities/profile.js +33 -0
- package/dist/utilities/safetensors.d.ts +3 -0
- package/dist/utilities/safetensors.js +53 -0
- package/dist/utilities/sentences.d.ts +5 -0
- package/dist/utilities/sentences.js +32 -0
- package/dist/utilities/tokenParse.d.ts +1 -0
- package/dist/utilities/tokenParse.js +17 -0
- package/dist/utilities/topP.d.ts +1 -0
- package/dist/utilities/topP.js +12 -0
- package/dist/utilities/waitForModel.d.ts +2 -0
- package/dist/utilities/waitForModel.js +12 -0
- package/dist/utilities/weights.d.ts +12 -0
- package/dist/utilities/weights.js +40 -0
- package/dist/utilities/yielder.d.ts +1 -0
- package/dist/utilities/yielder.js +7 -0
- package/dist/webgpu-Dt7BMzWz.js +525 -0
- package/dist/webgpu_program-WOyIVMlZ.js +392 -0
- package/dist/webgpu_util-B_F3SShA.js +106 -0
- package/package.json +1 -1
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
import { yieldIfNeeded as e } from "../utilities/yielder.js";
|
|
2
|
+
import { n as t, t as n } from "../BaseTokeniser-DSg9zcYq.js";
|
|
3
|
+
//#region lib/tokeniser/CharTokeniser.ts
|
|
4
|
+
/**
 * Character-level tokeniser. Each vocabulary slot holds one character or a
 * special token; the empty string "" marks an unassigned slot and is also the
 * representation used for the unknown token.
 *
 * `cache` mirrors `vocab` as a token -> index lookup.
 */
var i = class extends n {
	vocabSize = 0;
	eosToken = 0;
	bosToken = 0;
	unkToken = 0;
	vocab = [];
	cache = new Map();
	_trained = false;

	/**
	 * @param vocabOrSize Either an existing vocabulary (non-empty string[]),
	 *   which makes the tokeniser immediately usable, or the number of slots for
	 *   a fresh vocabulary that must be trained first.
	 * @throws Error("Vocab cannot be empty") for an empty array.
	 */
	constructor(vocabOrSize) {
		super();
		if (!Array.isArray(vocabOrSize)) {
			this.vocabSize = vocabOrSize;
			this.vocab = Array(this.vocabSize).fill("");
			this.addSpecialTokens();
			this.eosToken = this.getSpecialTokenIndex("<eos>");
			this.bosToken = this.getSpecialTokenIndex("<bos>") ?? this.eosToken;
			this.unkToken = this.getSpecialTokenIndex("");
			this.vocab.forEach((token, index) => {
				this.cache.set(token, index);
			});
			// Every placeholder slot is "", so pin "" to the unknown index.
			this.cache.set("", this.unkToken);
			return;
		}
		if (vocabOrSize.length === 0) throw new Error("Vocab cannot be empty");
		this.vocab = vocabOrSize;
		this.vocabSize = this.vocab.length;
		t.forEach((special) => {
			const index = this.vocab.indexOf(special);
			if (index !== -1) this.addSpecialToken(special, index);
		});
		this.eosToken = this.getSpecialTokenIndex("<eos>");
		this.bosToken = this.getSpecialTokenIndex("<bos>") ?? this.eosToken;
		// Pick an unknown token: the registered "" special, else the first of these
		// conventional fallbacks that exists, else <eos>.
		this.unkToken = this.getSpecialTokenIndex("") ?? -1;
		for (const fallback of ["<unk>", "<pad>", "_", " "]) {
			if (this.unkToken !== -1) break;
			this.unkToken = this.vocab.indexOf(fallback);
		}
		if (this.unkToken === -1) this.unkToken = this.eosToken;
		// "<pad>" entries become empty placeholders (copy; caller's array untouched).
		this.vocab = this.vocab.map((token) => (token === "<pad>" ? "" : token));
		this.vocab.forEach((token, index) => {
			this.cache.set(token, index);
		});
		this._trained = true;
	}

	/**
	 * Insert `token` and return its index (existing index if already present).
	 * Without `index`, the first "" slot after the unknown token is used.
	 * @throws Error("Vocab size exceeded") when no slot below `vocabSize` is available.
	 */
	addToken(token, index) {
		if (this.cache.has(token)) return this.cache.get(token);
		let slot;
		if (index === undefined) {
			slot = this.vocab.indexOf("", this.unkToken + 1);
			if (slot === -1) slot = this.vocabSize;
		} else {
			slot = index;
		}
		if (slot >= this.vocabSize) throw new Error("Vocab size exceeded");
		// When overwriting an assigned slot, drop the displaced token's lookup so it
		// no longer encodes to an index that now decodes as `token`.
		const displaced = this.vocab[slot];
		if (displaced && this.cache.get(displaced) === slot) this.cache.delete(displaced);
		this.vocab[slot] = token;
		this.cache.set(token, slot);
		return slot;
	}

	/** True once trained/loaded and the vocabulary array has exactly `vocabSize` entries. */
	get trained() {
		return this.vocab.length === this.vocabSize && this._trained;
	}

	/** Release the vocabulary and lookup table. */
	destroy() {
		this.cache.clear();
		this.vocab = [];
	}

	/**
	 * Fill the unassigned slots (the "" entries after the unknown token) with
	 * characters seen in `conversations`. When there are more new characters than
	 * free slots, the most frequent ones win.
	 * @param conversations Arrays of messages whose `content` strings are scanned.
	 * @param onProgress Forwarded to `yieldIfNeeded` (progress is always reported as 0).
	 * @param datasetID Stored on the tokeniser.
	 * @returns The vocabulary size.
	 */
	async train(conversations, onProgress, datasetID) {
		this.datasetID = datasetID;
		// Distinct characters in first-seen order. `for...of` walks strings by code
		// point, so `tokenise` splits with `Array.from` to match.
		const seen = new Set();
		let lastYield = performance.now();
		for (const conversation of conversations) {
			for (const message of conversation) for (const ch of message.content) seen.add(ch);
			lastYield = await e(lastYield, onProgress, 0);
		}
		const firstFree = this.vocab.indexOf("", this.unkToken + 1);
		if (firstFree === -1) {
			// No unassigned slots left: nothing can be learned.
			this.generateID();
			return this.vocabSize;
		}
		this._trained = true;
		const freeSlots = [];
		for (let slot = firstFree; slot !== -1; slot = this.vocab.indexOf("", slot + 1)) freeSlots.push(slot);
		const known = new Set(this.vocab);
		let fresh = Array.from(seen).filter((ch) => !known.has(ch));
		if (fresh.length > freeSlots.length) {
			const counts = new Map();
			for (const conversation of conversations)
				for (const message of conversation)
					for (const ch of message.content) counts.set(ch, (counts.get(ch) ?? 0) + 1);
			// Most frequent first (stable), then keep only what fits.
			fresh.sort((x, y) => (counts.get(y) ?? 0) - (counts.get(x) ?? 0));
			fresh = fresh.slice(0, freeSlots.length);
		}
		fresh.forEach((ch, k) => {
			this.vocab[freeSlots[k]] = ch;
		});
		this.cache.clear();
		this.vocab.forEach((token, index) => {
			this.cache.set(token, index);
		});
		this.generateID();
		this.emit("trainStatus", "trained");
		return this.vocabSize;
	}

	/**
	 * Split each string into characters (by code point).
	 * @param numeric When truthy, return indices (unknown -> `unkToken`);
	 *   otherwise return the characters, with out-of-vocabulary ones as "".
	 * @throws Error("Tokeniser not trained") when `trained` is false.
	 */
	tokenise(texts, numeric) {
		if (!this.trained) throw new Error("Tokeniser not trained");
		return texts.map((text) => {
			const chars = Array.from(text);
			if (numeric) return chars.map((ch) => this.cache.get(ch) ?? this.unkToken);
			return chars.map((ch) => {
				const id = this.cache.get(ch);
				return id === undefined ? "" : this.vocab[id];
			});
		});
	}

	/** Map index sequences back to strings; unknown indices contribute "". */
	detokenise(sequences) {
		return sequences.map((ids) => Array.from(ids).map((id) => this.vocab[id] || "").join(""));
	}

	encode(text) {
		return this.tokenise([text], true)[0];
	}

	decode(ids) {
		return this.detokenise([ids])[0];
	}

	getVocab() {
		return this.vocab;
	}

	/** Character tokenisers have no merges. */
	getMerges() {
		return [];
	}

	/**
	 * Build [inputs, targets]. For each text k (except the last `windowSize`),
	 * pushes its first `windowSize` ids to inputs and the first id of text k+1 to
	 * targets.
	 * NOTE(review): the loop bound mixes text count with window size; the
	 * intended contract is unclear from here. Behaviour is kept as-is.
	 */
	async createTrainingData(texts, windowSize = 5) {
		const tokenised = await this.tokenise(texts, true);
		const inputs = [];
		const targets = [];
		for (let k = 0; k < tokenised.length - windowSize; k++) {
			inputs.push(...tokenised[k].slice(0, windowSize));
			targets.push(tokenised[k + 1][0]);
		}
		return [inputs, targets];
	}
};
|
|
91
|
+
//#endregion
|
|
92
|
+
export { i as default };
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import { default as BaseTokeniser } from './BaseTokeniser';
|
|
2
|
+
import { Conversation } from './type';
|
|
3
|
+
/**
 * Byte-pair-encoding tokeniser. Text is split into pre-tokens (via
 * `tokenParse`), each pre-token into characters, and the learned merges are
 * then applied in training order.
 */
export default class BPETokeniser extends BaseTokeniser {
    private targetSize;
    private vocab;
    private vocabIndex;
    private merges;
    private pretokenMap;
    /** Create an untrained tokeniser whose vocabulary will grow to at most `vocabSize` entries. */
    constructor(vocabSize: number);
    /** Restore a tokeniser from its vocabulary (index = array position) and optional merge list. */
    constructor(vocab: string[], merges?: [string, string][]);
    /**
     * Add `token` and return its index. Returns the existing index if already present;
     * otherwise uses `index` or, by default, the token's insertion position.
     */
    addToken(token: string, index?: number): number;
    /** Clear vocabulary, index, merges and the pre-token cache. */
    destroy(): void;
    /** True when the vocabulary has more entries than the special tokens and does not exceed the target size. */
    get trained(): boolean;
    get vocabSize(): number;
    /** Index of `"<eos>"`, or 0 when absent. */
    get eosToken(): number;
    /** Index of `"<bos>"`, or 0 when absent. */
    get bosToken(): number;
    /** Index of the `""` (unknown) token, or 1 when absent. */
    get unkToken(): number;
    /**
     * Learn a vocabulary and merges from the message contents.
     * `cb` is forwarded to the yield helper (presumably reporting the current vocab size — confirm).
     * @returns The resulting vocabulary size.
     */
    train(text?: Conversation[][], cb?: (vocab: number) => void, datasetID?: string): Promise<number>;
    getVocab(): string[];
    getMerges(): [string, string][];
    private tokeniseWord;
    private tokeniseStrings;
    /** Token ids; pieces not in the vocabulary map to `unkToken`. */
    tokenise(text: string[], numeric: true): number[][];
    /** Token strings; pieces not in the vocabulary become `""`. */
    tokenise(text: string[]): string[][];
    /** Runtime-flag form: ids when `numeric` is truthy, otherwise token strings. */
    tokenise(text: string[], numeric?: boolean): string[][] | number[][];
    detokenise(tokens: number[][]): string[];
    encode(text: string): number[];
    decode(tokens: number[]): string;
}
|
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
import { yieldIfNeeded as e } from "../utilities/yielder.js";
|
|
2
|
+
import { n as t, t as n } from "../BaseTokeniser-DSg9zcYq.js";
|
|
3
|
+
import r from "../utilities/tokenParse.js";
|
|
4
|
+
//#region lib/tokeniser/bpe.ts
|
|
5
|
+
/**
 * Build the Map key identifying the adjacent token pair (left, right).
 *
 * The left token's length is prefixed, so the key decodes uniquely. The
 * previous `${left}-::-${right}` form collided whenever a token contained
 * "-::-", e.g. ("a-::-b", "c") vs ("a", "b-::-c").
 */
function i(left, right) {
	return `${left.length}:${left}${right}`;
}
|
|
8
|
+
/**
 * Count every adjacent token pair across `words` (arrays of token strings).
 * Each stat is { a, b, count, instances }, where `instances` is the set of
 * word indices containing the pair.
 * Returns { pairs, tokens } with `tokens` being the very same `words` array,
 * so later merges that rewrite `tokens` also rewrite the caller's array.
 */
function a(words) {
	const pairs = new Map();
	for (let wordIndex = 0; wordIndex < words.length; wordIndex++) {
		const word = words[wordIndex];
		for (let pos = 1; pos < word.length; pos++) {
			const left = word[pos - 1];
			const right = word[pos];
			const key = i(left, right);
			let stat = pairs.get(key);
			if (!stat) {
				stat = { a: left, b: right, count: 0, instances: new Set() };
				pairs.set(key, stat);
			}
			stat.count += 1;
			stat.instances.add(wordIndex);
		}
	}
	return { pairs, tokens: words };
}
|
|
27
|
+
/**
 * Adjust the count of pair (left, right) in `stats` by `delta` for word `wordIndex`.
 * A positive delta records the word as an instance. A non-positive delta either
 * removes the pair entirely (count dropped to <= 0) or removes the word from
 * its instances. An unknown pair is created with count `delta`.
 */
function o(stats, left, right, wordIndex, delta) {
	const key = i(left, right);
	const stat = stats.pairs.get(key);
	if (stat === undefined) {
		stats.pairs.set(key, {
			a: left,
			b: right,
			count: delta,
			instances: new Set([wordIndex])
		});
		return;
	}
	stat.count += delta;
	if (delta > 0) stat.instances.add(wordIndex);
	else if (stat.count <= 0) stats.pairs.delete(key);
	else stat.instances.delete(wordIndex);
}
|
|
39
|
+
/**
 * Return the pair stat with the strictly highest positive count (first one
 * wins ties, in Map order), or null when no pair has count > 0.
 */
function s(stats) {
	let winner = null;
	let highest = 0;
	stats.pairs.forEach((candidate) => {
		if (candidate.count > highest) {
			highest = candidate.count;
			winner = candidate;
		}
	});
	return winner;
}
|
|
44
|
+
/**
 * Apply one merge rule `merge` = [left, right] to each token array in `words`.
 * Scanning left to right, each non-overlapping occurrence of left followed by
 * right becomes the single token left + right. Returns new arrays.
 */
function c(words, merge) {
	const left = merge[0];
	const right = merge[1];
	const joined = left + right;
	return words.map((word) => {
		const out = [];
		let pos = 0;
		while (pos < word.length) {
			const isMatch = pos + 1 < word.length && word[pos] === left && word[pos + 1] === right;
			if (isMatch) {
				out.push(joined);
				pos += 2;
			} else {
				out.push(word[pos]);
				pos += 1;
			}
		}
		return out;
	});
}
|
|
51
|
+
/**
 * Apply the merge described by pair stat `t` ({ a, b, instances }) to every
 * word listed in `t.instances`. `stats.pairs` stays consistent, `stats.tokens`
 * is rewritten in place, and finally the merged pair is removed.
 *
 * Each affected word's pair counts are removed in full and re-added from the
 * merged sequence. The previous incremental update read the left neighbour
 * from the unmerged array. With repeated occurrences in one word (e.g. merging
 * (a,b) in [a,b,a,b]) it decremented (b,a) twice, creating a negative-count
 * pair, and recorded phantom pairs (ab,a) and (b,ab) instead of (ab,ab).
 */
function l(stats, t) {
	const merge = [t.a, t.b];
	// Snapshot: o() removes entries from t.instances while we process them.
	for (const wordIndex of Array.from(t.instances)) {
		const before = stats.tokens[wordIndex];
		const after = c([before], merge)[0];
		// Length unchanged means the pair does not occur in this word.
		if (after.length === before.length) continue;
		for (let k = 0; k + 1 < before.length; k++) o(stats, before[k], before[k + 1], wordIndex, -1);
		for (let k = 0; k + 1 < after.length; k++) o(stats, after[k], after[k + 1], wordIndex, 1);
		stats.tokens[wordIndex] = after;
	}
	stats.pairs.delete(i(t.a, t.b));
}
|
|
61
|
+
/**
 * Byte-pair-encoding tokeniser.
 *
 * `vocab` is an insertion-ordered Set of tokens, and `vocabIndex` maps token -> id.
 * `merges` is the ordered list of learned [left, right] merges.
 * `pretokenMap` caches the token pieces for each pre-token word.
 */
var u = class extends n {
	targetSize;
	vocab = new Set();
	vocabIndex = new Map();
	merges = [];
	pretokenMap = new Map();

	/**
	 * @param vocabOrSize Either a target vocabulary size (fresh tokeniser) or an
	 *   existing vocabulary whose array positions are the token ids.
	 * @param merges Optional merge list to restore together with a vocabulary.
	 */
	constructor(vocabOrSize, merges) {
		super();
		if (!Array.isArray(vocabOrSize)) {
			this.addSpecialTokens();
			this.targetSize = vocabOrSize;
			return;
		}
		vocabOrSize.forEach((token, index) => {
			this.vocab.add(token);
			this.vocabIndex.set(token, index);
		});
		if (merges) this.merges = merges;
		this.targetSize = vocabOrSize.length;
		t.forEach((special) => {
			const index = vocabOrSize.indexOf(special);
			if (index !== -1) this.addSpecialToken(special, index);
		});
	}

	/** Add `token` if missing; its id is `index` or, by default, its insertion position. */
	addToken(token, index) {
		if (this.vocab.has(token)) return this.vocabIndex.get(token);
		this.vocab.add(token);
		const assigned = index === undefined ? this.vocab.size - 1 : index;
		this.vocabIndex.set(token, assigned);
		return assigned;
	}

	destroy() {
		this.vocab.clear();
		this.vocabIndex.clear();
		this.merges = [];
		this.pretokenMap.clear();
	}

	/** More entries than the special-token list, but no more than the target size. */
	get trained() {
		const size = this.vocab.size;
		return size > t.length && size <= this.targetSize;
	}

	get vocabSize() {
		return this.vocab.size;
	}

	get eosToken() {
		return this.vocabIndex.get("<eos>") ?? 0;
	}

	get bosToken() {
		return this.vocabIndex.get("<bos>") ?? 0;
	}

	get unkToken() {
		return this.vocabIndex.get("") ?? 1;
	}

	/**
	 * Learn characters and merges from the message contents, up to `targetSize`.
	 * @returns The resulting vocabulary size.
	 */
	async train(conversations = [], onProgress, datasetID) {
		this.datasetID = datasetID;
		let lastYield = performance.now();
		const perConversation = Array(conversations.length);
		for (let ci = 0; ci < conversations.length; ci++) {
			const conversation = conversations[ci];
			const pretokens = Array(conversation.length);
			for (let mi = 0; mi < conversation.length; mi++) pretokens[mi] = r(conversation[mi].content);
			lastYield = await e(lastYield, onProgress, this.vocab.size);
			perConversation[ci] = pretokens;
		}
		const allWords = perConversation.flat(2);
		const uniqueWords = new Set(allWords);
		this.vocab = new Set();
		this.pretokenMap.clear();
		this.merges = [];
		this.addSpecialTokens();
		const words = Array.from(uniqueWords);
		// Character split of each unique word; registers every character in the vocab.
		const wordTokens = words.map((word) =>
			Array.from(word).map((ch) => {
				this.vocab.add(ch);
				return ch;
			})
		);
		// stats.tokens IS wordTokens: l() rewrites its entries as merges are applied.
		const stats = a(wordTokens);
		lastYield = await e(lastYield, onProgress, this.vocab.size);
		if (this.vocab.size >= this.targetSize) {
			console.warn("Initial vocab size is greater than or equal to target size. No merges will be performed.");
			const charCounts = new Map();
			allWords.forEach((word) => {
				Array.from(word).forEach((ch) => {
					charCounts.set(ch, (charCounts.get(ch) || 0) + 1);
				});
			});
			const ranked = Array.from(charCounts.entries()).sort((x, y) => y[1] - x[1]);
			this.vocab = new Set();
			this.addSpecialTokens();
			ranked
				.slice(0, this.targetSize - this.vocab.size)
				.map(([ch]) => ch)
				.forEach((ch) => this.vocab.add(ch));
			this.vocabIndex.clear();
			let position = 0;
			for (const token of this.vocab.keys()) this.vocabIndex.set(token, position++);
			this.generateID();
			this.emit("trainStatus", "trained");
			return this.vocab.size;
		}
		while (this.vocab.size < this.targetSize && this.merges.length < this.targetSize) {
			const best = s(stats);
			if (!best) break;
			this.merges.push([best.a, best.b]);
			this.vocab.add(best.a + best.b);
			l(stats, best);
			lastYield = await e(lastYield, onProgress, this.vocab.size);
		}
		// wordTokens now holds the merged pieces of each word (via stats.tokens).
		words.forEach((word, index) => {
			this.pretokenMap.set(word, wordTokens[index]);
		});
		this.vocabIndex.clear();
		let nextId = 0;
		for (const token of this.vocab.keys()) this.vocabIndex.set(token, nextId++);
		this.generateID();
		this.emit("trainStatus", "trained");
		return this.vocab.size;
	}

	getVocab() {
		return Array.from(this.vocab);
	}

	getMerges() {
		return this.merges;
	}

	/** Apply every merge in order to the characters of `word`, caching the result. */
	tokeniseWord(word) {
		let pieces = Array.from(word);
		for (const merge of this.merges) pieces = c([pieces], merge)[0];
		this.pretokenMap.set(word, pieces);
		return pieces;
	}

	/** Pre-tokenise each text and concatenate the (cached) pieces of its words. */
	tokeniseStrings(texts) {
		return texts.map((text) =>
			r(text).flatMap((word) => (this.pretokenMap.has(word) ? this.pretokenMap.get(word) : this.tokeniseWord(word)))
		);
	}

	/** Ids (unknown -> unkToken) when `numeric` is truthy, else pieces ("" when out of vocab). */
	tokenise(texts, numeric) {
		const pieces = this.tokeniseStrings(texts);
		if (numeric) return pieces.map((seq) => seq.map((tok) => this.vocabIndex.get(tok) ?? this.unkToken));
		return pieces.map((seq) => seq.map((tok) => (this.vocab.has(tok) ? tok : "")));
	}

	/** Look ids up by position in the vocabulary order and join them. */
	detokenise(sequences) {
		const table = this.getVocab();
		return sequences.map((ids) => ids.map((id) => table[id]).join(""));
	}

	encode(text) {
		return this.tokenise([text], true)[0];
	}

	decode(ids) {
		return this.detokenise([ids])[0];
	}
};
|
|
169
|
+
//#endregion
|
|
170
|
+
export { u as default };
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
/** Message with `type: 'train'`: input text plus a requested vocab size. */
type TrainMessage = {
    type: 'train';
    id: number;
    text: string[];
    vocabSize: number;
};

/** Reply with `type: 'trainResponse'`, carrying the resulting vocab size. */
type TrainResponse = {
    type: 'trainResponse';
    id: number;
    vocabSize: number;
};

/** Progress notification with `type: 'trainStatus'`. */
type TrainStatusMessage = {
    type: 'trainStatus';
    id: number;
    progress: number;
    vocabSize: number;
};

/** Message with `type: 'tokenise'`; `numeric` is optional. */
type TokeniseMessage = {
    type: 'tokenise';
    id: number;
    numeric?: boolean;
    text: string[];
};

/** Reply with `type: 'tokeniseResponse'`: string tokens or numeric ids. */
type TokeniseResponse = {
    type: 'tokeniseResponse';
    id: number;
    numeric: boolean;
    tokens: string[][] | number[][];
};

/** Message with `type: 'detokenise'`: id sequences to turn back into text. */
type DetokeniseMessage = {
    type: 'detokenise';
    id: number;
    tokens: number[][];
};

/** Reply with `type: 'detokeniseResponse'`: one string per input sequence. */
type DetokeniseResponse = {
    type: 'detokeniseResponse';
    id: number;
    text: string[];
};

/** Message with `type: 'tokens'` (no payload besides the id). */
type TokensMessage = {
    type: 'tokens';
    id: number;
};

/** Reply with `type: 'tokensResponse'`: a flat list of token strings. */
type TokensResponse = {
    type: 'tokensResponse';
    id: number;
    tokens: string[];
};

/** Message with `type: 'buildTrainingData'`: text plus a window size. */
type BuildTrainingDataMessage = {
    type: 'buildTrainingData';
    id: number;
    text: string[];
    windowSize: number;
};

/** Reply with `type: 'buildTrainingDataResponse'`: a pair of number arrays. */
type BuildTrainingDataResponse = {
    type: 'buildTrainingDataResponse';
    id: number;
    trainingData: [number[], number[]];
};

/** Every message shape, discriminated by the literal `type` field. */
export type TokeniserMessage =
    | TrainMessage
    | TrainResponse
    | TrainStatusMessage
    | TokeniseMessage
    | DetokeniseMessage
    | TokeniseResponse
    | DetokeniseResponse
    | TokensMessage
    | TokensResponse
    | BuildTrainingDataMessage
    | BuildTrainingDataResponse;

export {};
|
|
File without changes
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
import { default as EE } from 'eventemitter3';
|
|
2
|
+
/** Role attached to one conversation turn. */
export type Roles =
    | 'user'
    | 'assistant'
    | 'system'
    | 'text';

/** A single conversation turn: its role and its text. */
export interface Conversation {
    /** Which participant produced this turn. */
    role: Roles;
    /** Text of the turn. */
    content: string;
}
|
|
7
|
+
/**
 * Contract for a trainable tokeniser.
 * It is also an `eventemitter3` emitter whose only event name is 'trainStatus'.
 */
export interface ITokeniser extends EE<'trainStatus'> {
    // --- identity and state ---
    id: string;
    datasetID?: string;
    trained: boolean;
    vocabSize: number;
    bosToken: number;
    eosToken: number;

    // --- training and vocabulary ---
    /** Resolves to a number; `cb` receives a number while training runs. */
    train(text: Conversation[][], cb?: (vocab: number) => void, datasetID?: string): Promise<number>;
    getVocab(): string[];
    getMerges(): [string, string][];
    destroy(): void;

    // --- encoding ---
    encode(text: string): number[];
    encodeSequence(text: string): number[];
    encodeAsSequence(conversation: Conversation[], completion?: boolean): number[];
    // Overloads: the declaration order below is significant for resolution.
    encodeConversation(conversation: Conversation[], completion?: boolean): number[];
    encodeConversation(conversation: Conversation[], completion: boolean, masking: boolean): {
        tokens: number[];
        mask: boolean[];
    };
    encodeConversation(conversation: Conversation[], completion?: boolean, masking?: boolean): number[] | {
        tokens: number[];
        mask: boolean[];
    };

    // --- decoding (plain arrays or Uint16Array accepted) ---
    decode(tokens: number[] | Uint16Array): string;
    decodeConversation(tokens: number[] | Uint16Array): Conversation[];

    // --- special tokens ---
    getSpecialTokenIndex(token: string): number | undefined;
    isSpecialToken(index: number): boolean;
}
|
|
File without changes
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import { Optimizer, Tensor } from '@tensorflow/tfjs-core';
|
|
2
|
+
import { ConfigDict, Serializable, SerializableConstructor } from '@tensorflow/tfjs-core/dist/serialization';
|
|
3
|
+
import { NamedTensor, NamedVariableMap } from '@tensorflow/tfjs-core/dist/tensor_types';
|
|
4
|
+
import { default as LRScheduler } from './LRScheduler';
|
|
5
|
+
import { AdamWOptimizerConfig } from './types';
|
|
6
|
+
/**
 * Declarations for the AdamW optimizer class. It extends the tfjs-core
 * `Optimizer` and is configured with an `AdamWOptimizerConfig`.
 */
export declare class AdamWOptimizer extends Optimizer {
    readonly className = "AdamW";

    // --- hyper-parameters ---
    protected learningRate: number;
    protected beta1: number;
    protected beta2: number;
    protected epsilon: number | null;
    protected weightDecay: number;
    protected lossScaling: number;
    protected clipNorm?: number;
    protected orthGrad: boolean;
    protected orthGradEpsilon: number;
    protected lrScheduler: LRScheduler;

    // --- internal state ---
    private config;
    private accBeta1;
    private accBeta2;
    private accumulatedMoments;
    private orthogonalizeGradient;

    constructor(config: AdamWOptimizerConfig);
    get lr(): number;

    // --- configuration ---
    updateConfig(newConfig: Partial<AdamWOptimizerConfig>): void;
    serializeConfig(): AdamWOptimizerConfig;
    getConfig(): ConfigDict;
    /** @nocollapse */
    static fromConfig<T extends Serializable>(cls: SerializableConstructor<T>, config: ConfigDict): T;

    // --- optimisation step ---
    applyGradients(variableGradients: NamedVariableMap | NamedTensor[]): Tensor;

    // --- state persistence ---
    saveMoments(): Promise<ArrayBuffer>;
    loadMoments(momentData: ArrayBuffer): Promise<void>;
    getWeights(): Promise<NamedTensor[]>;
    setWeights(weightValues: NamedTensor[]): Promise<void>;

    dispose(): void;
}
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
import { _n as e, c as t, di as n, ii as r, kt as i, ni as a } from "../dist-BewPQWjc.js";
|
|
2
|
+
import { load_safetensors as o, save_safetensors as s } from "../utilities/safetensors.js";
|
|
3
|
+
import { adamAdjust as c } from "../ops/adamAdjust.js";
|
|
4
|
+
import { adamMoments as l } from "../ops/adamMoments.js";
|
|
5
|
+
import u from "./LRScheduler.js";
|
|
6
|
+
import { clipScale as d } from "../ops/globalNorm.js";
|
|
7
|
+
//#region lib/training/AdamW.ts
|
|
8
|
+
/**
 * AdamW optimizer (Adam with decoupled weight decay), built on the bundled
 * tfjs `Optimizer` base class (imported here as `t`).
 *
 * Behaviour beyond plain Adam, as implemented below:
 * - Gradients are scaled by `1 / lossScaling`. When `clipNorm` is set, the
 *   scale comes from `clipScale` (`d`) instead.
 * - Optional orthogonal-gradient mode (`orthoGrad`) removes, before the update,
 *   the component of each gradient that is parallel to its weight.
 * - Weight decay is only passed to `adamAdjust` for variables of rank > 1.
 *
 * Each variable's moment state is ONE variable of shape `[...shape, 2]`
 * (presumably first and second moments packed together). These are stored
 * positionally in `accumulatedMoments`, so gradients must arrive in a stable
 * order between steps.
 *
 * NOTE(review): the minified imports are inferred from usage; confirm against
 * dist-BewPQWjc.js. `n` ≈ tidy, `r` ≈ ENGINE getter, `i` ≈ zeros,
 * `a` ≈ dispose, `e` ≈ scalar.
 */
var f = class extends t {
	config;
	className = "AdamW";
	accBeta1 = 0;
	accBeta2 = 0;
	accumulatedMoments = [];
	learningRate;
	beta1;
	beta2;
	lossScaling;
	weightDecay;
	epsilon = null;
	lrScheduler;
	clipNorm;
	orthGradEpsilon = 1e-30;
	orthGrad;

	/**
	 * @param config AdamWOptimizerConfig.
	 *   - `accBeta1`/`accBeta2` default to `beta1`/`beta2`, the bias-correction
	 *     accumulators for the first step.
	 *   - A null or undefined `epsilon` falls back to the backend's epsilon.
	 */
	constructor(config) {
		super();
		this.config = config;
		this.accBeta1 = config.accBeta1 ?? config.beta1;
		this.accBeta2 = config.accBeta2 ?? config.beta2;
		this.learningRate = config.learningRate;
		this.beta1 = config.beta1;
		this.beta2 = config.beta2;
		this.weightDecay = config.weightDecay;
		this.lossScaling = config.lossScaling;
		this.clipNorm = config.clipNorm;
		this.orthGrad = config.orthoGrad ?? false;
		this.epsilon = config.epsilon ?? r().backend.epsilon();
		this.lrScheduler = new u(config.learningRate, config);
	}

	/** Learning rate of the most recent step; the scheduler overwrites it each step. */
	get lr() {
		return this.learningRate;
	}

	/** Serialises every moment variable, keyed by `originalName`, via save_safetensors. */
	saveMoments() {
		const byName = {};
		this.accumulatedMoments.forEach((moment) => {
			byName[moment.originalName] = moment.variable;
		});
		return s(byName);
	}

	/**
	 * Restores moment variables from a buffer produced by `saveMoments`.
	 * Previously held moments are replaced, not appended to, and then disposed.
	 * Order follows `Object.entries` of the loaded map. This must match the
	 * order gradients are later passed in.
	 */
	async loadMoments(momentData) {
		const loaded = await o(momentData);
		const previous = this.accumulatedMoments;
		this.accumulatedMoments = Object.entries(loaded).map(([name, tensor]) => ({
			originalName: name,
			variable: tensor.variable(false)
		}));
		// Dispose only after the new variables exist, so shared data stays alive.
		a(previous.map((moment) => moment.variable));
	}

	/** Full optimizer configuration, including the LR scheduler's settings. */
	serializeConfig() {
		return {
			learningRate: this.learningRate,
			beta1: this.beta1,
			beta2: this.beta2,
			accBeta1: this.accBeta1,
			accBeta2: this.accBeta2,
			epsilon: this.epsilon ?? void 0,
			weightDecay: this.weightDecay,
			lossScaling: this.lossScaling,
			clipNorm: this.clipNorm,
			orthoGrad: this.orthGrad,
			...this.lrScheduler.serializeConfig()
		};
	}

	/**
	 * Removes from `grad` its component parallel to `weight` (both flattened).
	 * The remainder is rescaled to `grad`'s L2 norm and reshaped back to
	 * `grad.shape`. `orthGradEpsilon` guards both divisions.
	 */
	orthogonalizeGradient(weight, grad) {
		return n(() => {
			const w = weight.reshape([-1]);
			const g = grad.reshape([-1]);
			const wNormSq = w.mul(w).sum().add(this.orthGradEpsilon);
			const projection = w.mul(g).sum().div(wNormSq);
			const gOrth = g.sub(w.mul(projection));
			const gNorm = g.norm();
			const gOrthNorm = gOrth.norm().add(this.orthGradEpsilon);
			return gOrth.mul(gNorm.div(gOrthNorm)).reshape(grad.shape);
		});
	}

	/**
	 * Applies a partial config update on top of the current config.
	 * The merged result is stored, so successive partial updates accumulate
	 * instead of reverting earlier ones to the constructor values.
	 */
	updateConfig(newConfig) {
		const merged = {
			...this.config,
			...newConfig
		};
		this.config = merged;
		this.learningRate = merged.learningRate;
		this.beta1 = merged.beta1;
		this.beta2 = merged.beta2;
		this.weightDecay = merged.weightDecay;
		this.lossScaling = merged.lossScaling;
		this.epsilon = merged.epsilon ?? this.epsilon;
		this.clipNorm = merged.clipNorm;
		this.orthGrad = merged.orthoGrad ?? this.orthGrad;
		this.lrScheduler.updateConfig(merged, merged.learningRate);
	}

	/**
	 * Performs one optimisation step over the given named gradients.
	 * The step runs inside `n` (tidy), so temporaries are released.
	 * @param variableGradients map of variable name to gradient, or an array
	 *   of `{ name, tensor }`. Null gradients are skipped.
	 * @returns the gradient-scale tensor computed for this step.
	 */
	applyGradients(variableGradients) {
		this.learningRate = this.lrScheduler.getNextLR();
		const isList = Array.isArray(variableGradients);
		const names = isList ? variableGradients.map((g) => g.name) : Object.keys(variableGradients);
		const gradientAt = (name, index) => (isList ? variableGradients[index].tensor : variableGradients[name]);
		const gradScale = n(() => {
			const biasCorrection1 = 1 - this.accBeta1;
			const biasCorrection2 = 1 - this.accBeta2;
			const scale = this.clipNorm === void 0 ? e(1 / this.lossScaling) : d(names.map(gradientAt), 1 / this.lossScaling, this.clipNorm);
			names.forEach((name, index) => {
				const variable = r().registeredVariables[name];
				if (this.accumulatedMoments[index] == null) {
					this.accumulatedMoments[index] = {
						originalName: `${name}/m`,
						variable: n(() => i([...variable.shape, 2]).variable(false))
					};
				}
				const gradient = gradientAt(name, index);
				if (gradient == null) return;
				const effectiveGrad = this.orthGrad ? this.orthogonalizeGradient(variable, gradient) : gradient;
				const moments = this.accumulatedMoments[index].variable;
				const newMoments = l(moments, effectiveGrad, this.beta1, this.beta2, scale);
				moments.assign(newMoments);
				if (this.orthGrad) effectiveGrad.dispose();
				const updated = c(newMoments, variable, biasCorrection1, biasCorrection2, this.epsilon ?? 1e-8, this.learningRate, variable.shape.length > 1 ? this.weightDecay : 0);
				variable.assign(updated);
			});
			this.accBeta1 *= this.beta1;
			this.accBeta2 *= this.beta2;
			return scale;
		});
		this.incrementIterations();
		return gradScale;
	}

	/** Disposes all moment variables and clears the list so no disposed variable is reused. */
	dispose() {
		if (this.accumulatedMoments != null) a(this.accumulatedMoments.map((moment) => moment.variable));
		this.accumulatedMoments = [];
	}

	/** Returns the iteration count followed by one packed moment tensor per variable. */
	async getWeights() {
		const moments = this.accumulatedMoments.map((moment) => ({
			name: moment.originalName,
			tensor: moment.variable
		}));
		return [await this.saveIterations()].concat(moments);
	}

	/**
	 * Restores state written by `getWeights`: the iteration count, then one
	 * packed moment tensor per variable. ALL moment entries are kept. Each
	 * variable has a single packed tensor, not separate m and v halves.
	 */
	async setWeights(weightValues) {
		const moments = await this.extractIterations(weightValues);
		this.accBeta1 = this.beta1 ** +(this.iterations_ + 1);
		this.accBeta2 = this.beta2 ** +(this.iterations_ + 1);
		const previous = this.accumulatedMoments;
		this.accumulatedMoments = moments.map((w) => ({
			originalName: w.name,
			variable: w.tensor.variable(false)
		}));
		// Dispose only after the new variables exist. Callers may pass back
		// our own variables from getWeights(), whose data the new ones share.
		a(previous.map((moment) => moment.variable));
	}

	/** Serialisable config; identical to `serializeConfig()` so `fromConfig` can rebuild it. */
	getConfig() {
		return this.serializeConfig();
	}

	/** Rebuilds an optimizer from `getConfig()` output. The constructor takes a single config object. */
	static fromConfig(cls, config) {
		return new cls(config);
	}
};
|
|
127
|
+
//#endregion
|
|
128
|
+
export { f as AdamWOptimizer };
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import { ITokeniser } from '../tokeniser/type';
|
|
2
|
+
import { Scalar, Tensor } from '@tensorflow/tfjs-core';
|
|
3
|
+
import { Dataset } from '@tensorflow/tfjs-data';
|
|
4
|
+
import { default as Model, ModelForwardAttributes } from '../../models/model';
|
|
5
|
+
import { AdamWOptimizerConfig, TrainingLogEntry, TrainingMetrics, TrainingOptions, TrainingState } from './types';
|
|
6
|
+
import { AdamWOptimizer } from './AdamW';
|
|
7
|
+
/**
 * Declarations for the basic training driver. It holds a model, its
 * tokenizer and an AdamW optimizer.
 */
export default class BasicTrainer {
    // --- collaborators ---
    tokenizer: ITokeniser;
    model: Model<ModelForwardAttributes>;
    optimizer: AdamWOptimizer;

    // --- run state ---
    protected running: boolean;
    protected lastState?: TrainingState;

    // --- training options ---
    protected optimizerConfig: AdamWOptimizerConfig;
    protected metrics: Set<TrainingMetrics>;
    protected maskedLoss: boolean;
    protected _gradientCheckpointing: boolean;
    protected _mixedPrecision: boolean;
    protected _labelSmoothing: number;
    protected _layerDrop: number;
    protected _dropout: number;

    constructor(model: Model<ModelForwardAttributes>, tokenizer: ITokeniser, optConfig?: Partial<AdamWOptimizerConfig>, optimizer?: AdamWOptimizer);

    // --- option setters ---
    setLossMasking(): void;
    setGradientCheckpointing(enabled: boolean): void;
    setMixedPrecision(enabled: boolean): void;
    setLabelSmoothing(smoothing: number): void;
    setDropout(dropout: number): void;
    setLayerDrop(layerDrop: number): void;
    setLearningRate(learningRate: number): void;
    setMetrics(metrics: TrainingMetrics[]): void;

    // --- optimizer access ---
    getOptimizer(): AdamWOptimizer;
    updateOptimizer(config?: Partial<AdamWOptimizerConfig>): void;
    resumeFromLog(log: TrainingLogEntry): void;

    // --- lifecycle ---
    get isRunning(): boolean;
    reset(): void;
    stop(): void;
    dispose(): void;

    // --- training ---
    protected trainStep(state: Partial<TrainingState>, batch: {
        xs: Tensor;
        ys: Tensor;
    }, dummy?: boolean, keepGrads?: boolean): Scalar;
    stepDataset(dataset: Dataset<{
        xs: Tensor;
        ys: Tensor;
    }>, options: Partial<TrainingOptions>, validationDataset?: Dataset<{
        xs: Tensor;
        ys: Tensor;
    }>): Promise<{
        log: TrainingLogEntry;
    }>;
    trainOnDataset(dataset: Dataset<{
        xs: Tensor;
        ys: Tensor;
    }>, options: Partial<TrainingOptions>, validationDataset?: Dataset<{
        xs: Tensor;
        ys: Tensor;
    }>): Promise<{
        losses: number[];
        validationLosses: number[];
    }>;

    // --- private helpers ---
    private dummyPass;
    private createEmptyState;
    private performLogging;
}
|