@genai-fi/nanogpt 0.20.0 → 0.20.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/BaseTokeniser-DSg9zcYq.js +221 -0
- package/dist/DatasetBuilder-DgURD85T.js +712 -0
- package/dist/Generator.d.ts +82 -0
- package/dist/Generator.js +2 -0
- package/dist/RealDiv-DBu0FQqT.js +362 -0
- package/dist/Reshape-CABOPB9d.js +94 -0
- package/dist/Reshape-DqO3r8BC.js +17 -0
- package/dist/TeachableLLM.d.ts +70 -0
- package/dist/TeachableLLM.js +2 -0
- package/dist/Trainer.d.ts +43 -0
- package/dist/Trainer.js +2 -0
- package/dist/backend.d.ts +2 -0
- package/dist/backend.js +13 -0
- package/dist/backend_util-Cg-roD1p.js +399 -0
- package/dist/binary_op_util-CrYk9LXL.js +103 -0
- package/dist/checks/appendCache.d.ts +1 -0
- package/dist/checks/appendCache.js +55 -0
- package/dist/checks/attentionMask.d.ts +1 -0
- package/dist/checks/attentionMask.js +56 -0
- package/dist/checks/check.d.ts +9 -0
- package/dist/checks/check.js +32 -0
- package/dist/checks/gelu.d.ts +1 -0
- package/dist/checks/gelu.js +46 -0
- package/dist/checks/index.d.ts +26 -0
- package/dist/checks/index.js +28 -0
- package/dist/checks/matMulGelu.d.ts +1 -0
- package/dist/checks/matMulGelu.js +84 -0
- package/dist/checks/normRMS.d.ts +1 -0
- package/dist/checks/normRMS.js +28 -0
- package/dist/checks/normRMSGrad.d.ts +1 -0
- package/dist/checks/normRMSGrad.js +22 -0
- package/dist/checks/packUnpack.d.ts +1 -0
- package/dist/checks/packUnpack.js +46 -0
- package/dist/checks/qkv.d.ts +1 -0
- package/dist/checks/qkv.js +34 -0
- package/dist/checks/rope.d.ts +1 -0
- package/dist/checks/rope.js +30 -0
- package/dist/checks/weights.d.ts +14 -0
- package/dist/checks/weights.js +27 -0
- package/dist/chunk-BPntVaq0.js +23 -0
- package/dist/complex_util-CkazZsaH.js +60 -0
- package/dist/concat_util-CWDZCBlA.js +19 -0
- package/dist/data/docx.d.ts +2 -0
- package/dist/data/docx.js +3046 -0
- package/dist/data/pdf.d.ts +2 -0
- package/dist/data/pdf.js +17 -0
- package/dist/data/textLoader.d.ts +7 -0
- package/dist/data/textLoader.js +613 -0
- package/dist/dist-BewPQWjc.js +7572 -0
- package/dist/dist-DVmq73nz.js +8775 -0
- package/dist/dist-DXwIvKxl.js +896 -0
- package/dist/dist-VEU5mfO0.js +7545 -0
- package/dist/gelu-Bf1HW1RY.js +27 -0
- package/dist/gpgpu_math-DvLcCH6u.js +1612 -0
- package/dist/inference/types.d.ts +16 -0
- package/dist/inference/types.js +0 -0
- package/dist/kernel_funcs_utils-HiXOOx3f.js +229 -0
- package/dist/layers/BaseLayer.d.ts +44 -0
- package/dist/layers/BaseLayer.js +76 -0
- package/dist/layers/CausalSelfAttention.d.ts +39 -0
- package/dist/layers/CausalSelfAttention.js +99 -0
- package/dist/layers/LoRA.d.ts +14 -0
- package/dist/layers/LoRA.js +48 -0
- package/dist/layers/MLP.d.ts +17 -0
- package/dist/layers/MLP.js +34 -0
- package/dist/layers/PositionEmbedding.d.ts +8 -0
- package/dist/layers/PositionEmbedding.js +27 -0
- package/dist/layers/RMSNorm.d.ts +12 -0
- package/dist/layers/RMSNorm.js +20 -0
- package/dist/layers/RoPECache.d.ts +18 -0
- package/dist/layers/RoPECache.js +337 -0
- package/dist/layers/TiedEmbedding.d.ts +13 -0
- package/dist/layers/TiedEmbedding.js +32 -0
- package/dist/layers/TransformerBlock.d.ts +27 -0
- package/dist/layers/TransformerBlock.js +51 -0
- package/dist/layers/WeightStore.d.ts +20 -0
- package/dist/layers/WeightStore.js +69 -0
- package/dist/loader/load.d.ts +6 -0
- package/dist/loader/load.js +2 -0
- package/dist/loader/loadHF.d.ts +8 -0
- package/dist/loader/loadHF.js +2 -0
- package/dist/loader/loadTransformers.d.ts +4 -0
- package/dist/loader/loadTransformers.js +2 -0
- package/dist/loader/loadZipMeta.d.ts +3 -0
- package/dist/loader/loadZipMeta.js +16 -0
- package/dist/loader/newZipLoad.d.ts +3 -0
- package/dist/loader/newZipLoad.js +2 -0
- package/dist/loader/oldZipLoad.d.ts +9 -0
- package/dist/loader/oldZipLoad.js +2 -0
- package/dist/loader/save.d.ts +16 -0
- package/dist/loader/save.js +2 -0
- package/dist/loader/types.d.ts +68 -0
- package/dist/loader/types.js +0 -0
- package/dist/main-D5CbfCiV.js +13500 -0
- package/dist/main.d.ts +50 -0
- package/dist/main.js +16 -0
- package/dist/matMul16-BNfZSnNM.js +81 -0
- package/dist/matMulGelu-CPTntosE.js +162 -0
- package/dist/models/NanoGPTV1.d.ts +16 -0
- package/dist/models/NanoGPTV1.js +2 -0
- package/dist/models/NanoGPTV2.d.ts +16 -0
- package/dist/models/NanoGPTV2.js +2 -0
- package/dist/models/config.d.ts +27 -0
- package/dist/models/config.js +37 -0
- package/dist/models/factory.d.ts +3 -0
- package/dist/models/factory.js +2 -0
- package/dist/models/model.d.ts +44 -0
- package/dist/models/model.js +2 -0
- package/dist/ops/adamAdjust.d.ts +2 -0
- package/dist/ops/adamAdjust.js +18 -0
- package/dist/ops/adamMoments.d.ts +2 -0
- package/dist/ops/adamMoments.js +16 -0
- package/dist/ops/add16.d.ts +2 -0
- package/dist/ops/add16.js +12 -0
- package/dist/ops/appendCache.d.ts +2 -0
- package/dist/ops/appendCache.js +25 -0
- package/dist/ops/attentionMask.d.ts +2 -0
- package/dist/ops/attentionMask.js +16 -0
- package/dist/ops/concat16.d.ts +2 -0
- package/dist/ops/concat16.js +8 -0
- package/dist/ops/cpu/adamAdjust.d.ts +1 -0
- package/dist/ops/cpu/adamAdjust.js +16 -0
- package/dist/ops/cpu/adamMoments.d.ts +1 -0
- package/dist/ops/cpu/adamMoments.js +16 -0
- package/dist/ops/cpu/appendCache.d.ts +1 -0
- package/dist/ops/cpu/appendCache.js +65 -0
- package/dist/ops/cpu/attentionMask.d.ts +1 -0
- package/dist/ops/cpu/attentionMask.js +16 -0
- package/dist/ops/cpu/fusedSoftmax.d.ts +9 -0
- package/dist/ops/cpu/fusedSoftmax.js +22 -0
- package/dist/ops/cpu/gatherSub.d.ts +1 -0
- package/dist/ops/cpu/gatherSub.js +12 -0
- package/dist/ops/cpu/gelu.d.ts +1 -0
- package/dist/ops/cpu/gelu.js +36 -0
- package/dist/ops/cpu/matMul16.d.ts +1 -0
- package/dist/ops/cpu/matMul16.js +14 -0
- package/dist/ops/cpu/matMulGelu.d.ts +1 -0
- package/dist/ops/cpu/matMulGelu.js +41 -0
- package/dist/ops/cpu/matMulMul.d.ts +1 -0
- package/dist/ops/cpu/matMulMul.js +20 -0
- package/dist/ops/cpu/mulDropout.d.ts +1 -0
- package/dist/ops/cpu/mulDropout.js +20 -0
- package/dist/ops/cpu/normRMS.d.ts +1 -0
- package/dist/ops/cpu/normRMS.js +35 -0
- package/dist/ops/cpu/qkv.d.ts +5 -0
- package/dist/ops/cpu/qkv.js +73 -0
- package/dist/ops/cpu/rope.d.ts +6 -0
- package/dist/ops/cpu/rope.js +81 -0
- package/dist/ops/cpu/scatterSub.d.ts +1 -0
- package/dist/ops/cpu/scatterSub.js +12 -0
- package/dist/ops/dot16.d.ts +2 -0
- package/dist/ops/dot16.js +29 -0
- package/dist/ops/dropout.d.ts +2 -0
- package/dist/ops/dropout.js +11 -0
- package/dist/ops/dropout16.d.ts +2 -0
- package/dist/ops/dropout16.js +22 -0
- package/dist/ops/gatherSub.d.ts +2 -0
- package/dist/ops/gatherSub.js +13 -0
- package/dist/ops/gelu.d.ts +3 -0
- package/dist/ops/gelu.js +2 -0
- package/dist/ops/globalNorm.d.ts +2 -0
- package/dist/ops/globalNorm.js +19 -0
- package/dist/ops/grads/add16.d.ts +1 -0
- package/dist/ops/grads/add16.js +27 -0
- package/dist/ops/grads/attentionMask.d.ts +1 -0
- package/dist/ops/grads/attentionMask.js +26 -0
- package/dist/ops/grads/dropout16.d.ts +1 -0
- package/dist/ops/grads/dropout16.js +1 -0
- package/dist/ops/grads/gelu.d.ts +2 -0
- package/dist/ops/grads/gelu.js +2 -0
- package/dist/ops/grads/matMul16.d.ts +2 -0
- package/dist/ops/grads/matMul16.js +2 -0
- package/dist/ops/grads/matMulGelu.d.ts +1 -0
- package/dist/ops/grads/matMulGelu.js +22 -0
- package/dist/ops/grads/mul16.d.ts +1 -0
- package/dist/ops/grads/mul16.js +1 -0
- package/dist/ops/grads/normRMS.d.ts +3 -0
- package/dist/ops/grads/normRMS.js +37 -0
- package/dist/ops/grads/pack16.d.ts +2 -0
- package/dist/ops/grads/pack16.js +2 -0
- package/dist/ops/grads/qkv.d.ts +3 -0
- package/dist/ops/grads/qkv.js +46 -0
- package/dist/ops/grads/rope.d.ts +2 -0
- package/dist/ops/grads/rope.js +2 -0
- package/dist/ops/grads/softmax16.d.ts +2 -0
- package/dist/ops/grads/softmax16.js +23 -0
- package/dist/ops/grads/unpack16.d.ts +2 -0
- package/dist/ops/grads/unpack16.js +2 -0
- package/dist/ops/grads/utils.d.ts +4 -0
- package/dist/ops/grads/utils.js +12 -0
- package/dist/ops/log.d.ts +0 -0
- package/dist/ops/log.js +1 -0
- package/dist/ops/matMul16.d.ts +15 -0
- package/dist/ops/matMul16.js +2 -0
- package/dist/ops/matMulGelu.d.ts +3 -0
- package/dist/ops/matMulGelu.js +20 -0
- package/dist/ops/matMulMul.d.ts +2 -0
- package/dist/ops/matMulMul.js +16 -0
- package/dist/ops/mul16.d.ts +2 -0
- package/dist/ops/mul16.js +43 -0
- package/dist/ops/mulDrop.d.ts +2 -0
- package/dist/ops/mulDrop.js +15 -0
- package/dist/ops/normRMS.d.ts +2 -0
- package/dist/ops/normRMS.js +22 -0
- package/dist/ops/pack16.d.ts +2 -0
- package/dist/ops/pack16.js +2 -0
- package/dist/ops/qkv.d.ts +2 -0
- package/dist/ops/qkv.js +16 -0
- package/dist/ops/reshape16.d.ts +2 -0
- package/dist/ops/reshape16.js +33 -0
- package/dist/ops/rope.d.ts +3 -0
- package/dist/ops/rope.js +2 -0
- package/dist/ops/scatterSub.d.ts +2 -0
- package/dist/ops/scatterSub.js +13 -0
- package/dist/ops/slice16.d.ts +2 -0
- package/dist/ops/slice16.js +11 -0
- package/dist/ops/softmax16.d.ts +2 -0
- package/dist/ops/softmax16.js +9 -0
- package/dist/ops/sub16.d.ts +2 -0
- package/dist/ops/sub16.js +11 -0
- package/dist/ops/sum16.d.ts +2 -0
- package/dist/ops/sum16.js +13 -0
- package/dist/ops/transpose16.d.ts +3 -0
- package/dist/ops/transpose16.js +32 -0
- package/dist/ops/unpack16.d.ts +2 -0
- package/dist/ops/unpack16.js +2 -0
- package/dist/ops/webgl/adamAdjust.d.ts +1 -0
- package/dist/ops/webgl/adamAdjust.js +82 -0
- package/dist/ops/webgl/adamMoments.d.ts +1 -0
- package/dist/ops/webgl/adamMoments.js +44 -0
- package/dist/ops/webgl/appendCache.d.ts +1 -0
- package/dist/ops/webgl/appendCache.js +53 -0
- package/dist/ops/webgl/attentionMask.d.ts +1 -0
- package/dist/ops/webgl/attentionMask.js +64 -0
- package/dist/ops/webgl/dropout16.d.ts +1 -0
- package/dist/ops/webgl/dropout16.js +12 -0
- package/dist/ops/webgl/fusedSoftmax.d.ts +11 -0
- package/dist/ops/webgl/fusedSoftmax.js +70 -0
- package/dist/ops/webgl/gatherSub.d.ts +1 -0
- package/dist/ops/webgl/gatherSub.js +28 -0
- package/dist/ops/webgl/gelu.d.ts +2 -0
- package/dist/ops/webgl/gelu.js +48 -0
- package/dist/ops/webgl/log.d.ts +17 -0
- package/dist/ops/webgl/log.js +14 -0
- package/dist/ops/webgl/matMul16.d.ts +1 -0
- package/dist/ops/webgl/matMul16.js +37 -0
- package/dist/ops/webgl/matMulGelu.d.ts +21 -0
- package/dist/ops/webgl/matMulGelu.js +2 -0
- package/dist/ops/webgl/matMulMul.d.ts +14 -0
- package/dist/ops/webgl/matMulMul.js +24 -0
- package/dist/ops/webgl/mulDropout.d.ts +1 -0
- package/dist/ops/webgl/mulDropout.js +32 -0
- package/dist/ops/webgl/normRMS.d.ts +1 -0
- package/dist/ops/webgl/normRMS.js +114 -0
- package/dist/ops/webgl/qkv.d.ts +1 -0
- package/dist/ops/webgl/qkv.js +54 -0
- package/dist/ops/webgl/rope.d.ts +1 -0
- package/dist/ops/webgl/rope.js +72 -0
- package/dist/ops/webgl/scatterSub.d.ts +1 -0
- package/dist/ops/webgl/scatterSub.js +28 -0
- package/dist/ops/webgpu/adamAdjust.d.ts +1 -0
- package/dist/ops/webgpu/adamAdjust.js +77 -0
- package/dist/ops/webgpu/adamMoments.d.ts +1 -0
- package/dist/ops/webgpu/adamMoments.js +76 -0
- package/dist/ops/webgpu/add16.d.ts +1 -0
- package/dist/ops/webgpu/add16.js +14 -0
- package/dist/ops/webgpu/appendCache.d.ts +1 -0
- package/dist/ops/webgpu/appendCache.js +130 -0
- package/dist/ops/webgpu/attentionMask.d.ts +1 -0
- package/dist/ops/webgpu/attentionMask.js +42 -0
- package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
- package/dist/ops/webgpu/attentionMask32_program.js +62 -0
- package/dist/ops/webgpu/clipScale.d.ts +1 -0
- package/dist/ops/webgpu/clipScale.js +45 -0
- package/dist/ops/webgpu/concat16.d.ts +19 -0
- package/dist/ops/webgpu/concat16.js +111 -0
- package/dist/ops/webgpu/dropout16.d.ts +1 -0
- package/dist/ops/webgpu/dropout16.js +59 -0
- package/dist/ops/webgpu/gatherSub.d.ts +1 -0
- package/dist/ops/webgpu/gatherSub.js +52 -0
- package/dist/ops/webgpu/gelu.d.ts +14 -0
- package/dist/ops/webgpu/gelu.js +147 -0
- package/dist/ops/webgpu/index.d.ts +0 -0
- package/dist/ops/webgpu/index.js +26 -0
- package/dist/ops/webgpu/matMul16.d.ts +1 -0
- package/dist/ops/webgpu/matMul16.js +70 -0
- package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
- package/dist/ops/webgpu/matMul16_program.js +303 -0
- package/dist/ops/webgpu/mul16.d.ts +1 -0
- package/dist/ops/webgpu/mul16.js +14 -0
- package/dist/ops/webgpu/norm2.d.ts +1 -0
- package/dist/ops/webgpu/norm2.js +46 -0
- package/dist/ops/webgpu/normRMS.d.ts +1 -0
- package/dist/ops/webgpu/normRMS.js +26 -0
- package/dist/ops/webgpu/normRMS16_program.d.ts +10 -0
- package/dist/ops/webgpu/normRMS16_program.js +28 -0
- package/dist/ops/webgpu/normRMS32_program.d.ts +10 -0
- package/dist/ops/webgpu/normRMS32_program.js +28 -0
- package/dist/ops/webgpu/normRMSGrad.d.ts +1 -0
- package/dist/ops/webgpu/normRMSGrad.js +225 -0
- package/dist/ops/webgpu/pack16.d.ts +1 -0
- package/dist/ops/webgpu/pack16.js +21 -0
- package/dist/ops/webgpu/pack16_program.d.ts +19 -0
- package/dist/ops/webgpu/pack16_program.js +93 -0
- package/dist/ops/webgpu/qkv.d.ts +1 -0
- package/dist/ops/webgpu/qkv.js +64 -0
- package/dist/ops/webgpu/rope.d.ts +1 -0
- package/dist/ops/webgpu/rope.js +163 -0
- package/dist/ops/webgpu/scatterSub.d.ts +1 -0
- package/dist/ops/webgpu/scatterSub.js +53 -0
- package/dist/ops/webgpu/slice16.d.ts +7 -0
- package/dist/ops/webgpu/slice16.js +74 -0
- package/dist/ops/webgpu/softmax16.d.ts +17 -0
- package/dist/ops/webgpu/softmax16.js +18 -0
- package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
- package/dist/ops/webgpu/softmax16_program.js +89 -0
- package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
- package/dist/ops/webgpu/softmax16_subgroup_program.js +70 -0
- package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
- package/dist/ops/webgpu/softmax16grad.js +31 -0
- package/dist/ops/webgpu/sub16.d.ts +1 -0
- package/dist/ops/webgpu/sub16.js +14 -0
- package/dist/ops/webgpu/sum16.d.ts +1 -0
- package/dist/ops/webgpu/sum16.js +29 -0
- package/dist/ops/webgpu/transpose16.d.ts +1 -0
- package/dist/ops/webgpu/transpose16.js +37 -0
- package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
- package/dist/ops/webgpu/transpose16_program.js +51 -0
- package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
- package/dist/ops/webgpu/transpose16_shared_program.js +79 -0
- package/dist/ops/webgpu/unpack16.d.ts +1 -0
- package/dist/ops/webgpu/unpack16.js +60 -0
- package/dist/ops/webgpu/utils/binary_op.d.ts +35 -0
- package/dist/ops/webgpu/utils/binary_op.js +141 -0
- package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
- package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
- package/dist/ops/webgpu/utils/reductions.d.ts +43 -0
- package/dist/ops/webgpu/utils/reductions.js +263 -0
- package/dist/pack16-Ck-spx_F.js +39 -0
- package/dist/patches/webgpu_backend.d.ts +18 -0
- package/dist/patches/webgpu_backend.js +43 -0
- package/dist/patches/webgpu_base.d.ts +21 -0
- package/dist/patches/webgpu_base.js +22 -0
- package/dist/patches/webgpu_program.d.ts +36 -0
- package/dist/patches/webgpu_program.js +293 -0
- package/dist/pdf-UoDqCYzz.js +16726 -0
- package/dist/picomatch-3tUnMMbd.js +1063 -0
- package/dist/rope-CbeGlsV8.js +25 -0
- package/dist/selu_util-zkAx5doH.js +24 -0
- package/dist/shared-D1coEFea.js +1314 -0
- package/dist/shared-DOgWaqvL.js +5 -0
- package/dist/slice_util-Dgb3ANWI.js +208 -0
- package/dist/tfjs_backend-BjuQ5FqB.js +614 -0
- package/dist/tokeniser/BaseTokeniser.d.ts +33 -0
- package/dist/tokeniser/BaseTokeniser.js +2 -0
- package/dist/tokeniser/CharTokeniser.d.ts +24 -0
- package/dist/tokeniser/CharTokeniser.js +92 -0
- package/dist/tokeniser/bpe.d.ts +28 -0
- package/dist/tokeniser/bpe.js +170 -0
- package/dist/tokeniser/messages.d.ts +61 -0
- package/dist/tokeniser/messages.js +0 -0
- package/dist/tokeniser/type.d.ts +34 -0
- package/dist/tokeniser/type.js +0 -0
- package/dist/training/AdamW.d.ts +36 -0
- package/dist/training/AdamW.js +128 -0
- package/dist/training/BasicTrainer.d.ts +63 -0
- package/dist/training/BasicTrainer.js +265 -0
- package/dist/training/DatasetBuilder.d.ts +26 -0
- package/dist/training/DatasetBuilder.js +2 -0
- package/dist/training/Evaluator.d.ts +19 -0
- package/dist/training/Evaluator.js +48 -0
- package/dist/training/LRScheduler.d.ts +12 -0
- package/dist/training/LRScheduler.js +38 -0
- package/dist/training/PreTrainer.d.ts +11 -0
- package/dist/training/PreTrainer.js +22 -0
- package/dist/training/SFTTrainer.d.ts +12 -0
- package/dist/training/SFTTrainer.js +24 -0
- package/dist/training/loss.d.ts +3 -0
- package/dist/training/loss.js +19 -0
- package/dist/training/orthoGrad.d.ts +2 -0
- package/dist/training/orthoGrad.js +10 -0
- package/dist/training/sparseCrossEntropy.d.ts +7 -0
- package/dist/training/sparseCrossEntropy.js +47 -0
- package/dist/training/tasks/ConversationTask.d.ts +18 -0
- package/dist/training/tasks/ConversationTask.js +38 -0
- package/dist/training/tasks/PretrainingTask.d.ts +17 -0
- package/dist/training/tasks/PretrainingTask.js +42 -0
- package/dist/training/tasks/StartSentenceTask.d.ts +18 -0
- package/dist/training/tasks/StartSentenceTask.js +45 -0
- package/dist/training/tasks/Task.d.ts +22 -0
- package/dist/training/tasks/Task.js +55 -0
- package/dist/training/tasks/splitter.d.ts +5 -0
- package/dist/training/tasks/splitter.js +18 -0
- package/dist/training/types.d.ts +78 -0
- package/dist/training/types.js +0 -0
- package/dist/training/validation.d.ts +17 -0
- package/dist/training/validation.js +2 -0
- package/dist/utilities/arrayClose.d.ts +1 -0
- package/dist/utilities/arrayClose.js +16 -0
- package/dist/utilities/datasetID.d.ts +2 -0
- package/dist/utilities/datasetID.js +18 -0
- package/dist/utilities/dummy.d.ts +9 -0
- package/dist/utilities/dummy.js +36 -0
- package/dist/utilities/multinomialCPU.d.ts +2 -0
- package/dist/utilities/multinomialCPU.js +9 -0
- package/dist/utilities/naming.d.ts +4 -0
- package/dist/utilities/naming.js +0 -0
- package/dist/utilities/packed.d.ts +4 -0
- package/dist/utilities/packed.js +13 -0
- package/dist/utilities/parameters.d.ts +11 -0
- package/dist/utilities/parameters.js +38 -0
- package/dist/utilities/performance.d.ts +2 -0
- package/dist/utilities/performance.js +16 -0
- package/dist/utilities/profile.d.ts +17 -0
- package/dist/utilities/profile.js +33 -0
- package/dist/utilities/safetensors.d.ts +3 -0
- package/dist/utilities/safetensors.js +53 -0
- package/dist/utilities/sentences.d.ts +5 -0
- package/dist/utilities/sentences.js +32 -0
- package/dist/utilities/tokenParse.d.ts +1 -0
- package/dist/utilities/tokenParse.js +17 -0
- package/dist/utilities/topP.d.ts +1 -0
- package/dist/utilities/topP.js +12 -0
- package/dist/utilities/waitForModel.d.ts +2 -0
- package/dist/utilities/waitForModel.js +12 -0
- package/dist/utilities/weights.d.ts +12 -0
- package/dist/utilities/weights.js +40 -0
- package/dist/utilities/yielder.d.ts +1 -0
- package/dist/utilities/yielder.js +7 -0
- package/dist/webgpu-Dt7BMzWz.js +525 -0
- package/dist/webgpu_program-WOyIVMlZ.js +392 -0
- package/dist/webgpu_util-B_F3SShA.js +106 -0
- package/package.json +1 -1
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
import { _n as e, di as t, kt as n, ni as r, oi as i, qt as a } from "../dist-BewPQWjc.js";
|
|
2
|
+
import { AdamWOptimizer as o } from "./AdamW.js";
|
|
3
|
+
import { calculateAccuracy as s, calculateLoss as c } from "./loss.js";
|
|
4
|
+
import l from "./Evaluator.js";
|
|
5
|
+
import u from "../utilities/profile.js";
|
|
6
|
+
import { createTensorStatistics as d } from "../checks/weights.js";
|
|
7
|
+
//#region lib/training/BasicTrainer.ts
|
|
8
|
+
var f = {
|
|
9
|
+
logInterval: 1,
|
|
10
|
+
maxEpochs: 100,
|
|
11
|
+
sftMode: "full",
|
|
12
|
+
batchSize: 32
|
|
13
|
+
}, p = {
|
|
14
|
+
learningRate: 3e-4,
|
|
15
|
+
beta1: .9,
|
|
16
|
+
beta2: .99,
|
|
17
|
+
epsilon: 1e-8,
|
|
18
|
+
weightDecay: .01,
|
|
19
|
+
warmupSteps: 100,
|
|
20
|
+
decayEpochs: 100,
|
|
21
|
+
epochSteps: 1e4,
|
|
22
|
+
minLearningRate: 1e-5,
|
|
23
|
+
lossScaling: 1
|
|
24
|
+
}, m = class {
|
|
25
|
+
tokenizer;
|
|
26
|
+
model;
|
|
27
|
+
optimizer;
|
|
28
|
+
running = !1;
|
|
29
|
+
lastState;
|
|
30
|
+
_gradientCheckpointing = !1;
|
|
31
|
+
_mixedPrecision = !1;
|
|
32
|
+
maskedLoss = !1;
|
|
33
|
+
optimizerConfig;
|
|
34
|
+
metrics = /* @__PURE__ */ new Set();
|
|
35
|
+
_labelSmoothing = 0;
|
|
36
|
+
_layerDrop = 0;
|
|
37
|
+
_dropout = 0;
|
|
38
|
+
constructor(e, t, n, r) {
|
|
39
|
+
this.tokenizer = t, this.model = e, this.optimizerConfig = {
|
|
40
|
+
...p,
|
|
41
|
+
...n,
|
|
42
|
+
lossScaling: e.lossScaling
|
|
43
|
+
};
|
|
44
|
+
let i = r || new o(this.optimizerConfig);
|
|
45
|
+
r && r.updateConfig(this.optimizerConfig), this.optimizer = i;
|
|
46
|
+
}
|
|
47
|
+
setLossMasking() {
|
|
48
|
+
this.maskedLoss = !0;
|
|
49
|
+
}
|
|
50
|
+
setGradientCheckpointing(e) {
|
|
51
|
+
this._gradientCheckpointing = e;
|
|
52
|
+
}
|
|
53
|
+
setMixedPrecision(e) {
|
|
54
|
+
this._mixedPrecision = e;
|
|
55
|
+
}
|
|
56
|
+
setLabelSmoothing(e) {
|
|
57
|
+
this._labelSmoothing = e;
|
|
58
|
+
}
|
|
59
|
+
setDropout(e) {
|
|
60
|
+
this._dropout = e;
|
|
61
|
+
}
|
|
62
|
+
setLayerDrop(e) {
|
|
63
|
+
this._layerDrop = e;
|
|
64
|
+
}
|
|
65
|
+
setLearningRate(e) {
|
|
66
|
+
this.optimizerConfig.learningRate = e, this.updateOptimizer();
|
|
67
|
+
}
|
|
68
|
+
setMetrics(e) {
|
|
69
|
+
this.metrics = new Set(e);
|
|
70
|
+
}
|
|
71
|
+
reset() {
|
|
72
|
+
this.lastState = void 0, this.running = !1;
|
|
73
|
+
}
|
|
74
|
+
stop() {
|
|
75
|
+
this.running = !1;
|
|
76
|
+
}
|
|
77
|
+
get isRunning() {
|
|
78
|
+
return this.running;
|
|
79
|
+
}
|
|
80
|
+
getOptimizer() {
|
|
81
|
+
return this.optimizer;
|
|
82
|
+
}
|
|
83
|
+
updateOptimizer(e) {
|
|
84
|
+
e && (this.optimizerConfig = {
|
|
85
|
+
...this.optimizerConfig,
|
|
86
|
+
...e
|
|
87
|
+
}), this.optimizer.updateConfig(this.optimizerConfig);
|
|
88
|
+
}
|
|
89
|
+
resumeFromLog(e) {
|
|
90
|
+
(!this.lastState || this.lastState.step === 0) && (this.lastState = {
|
|
91
|
+
losses: [],
|
|
92
|
+
validationLosses: [],
|
|
93
|
+
logStartTime: 0,
|
|
94
|
+
step: e.step,
|
|
95
|
+
lastLoss: e.trainingMetrics.loss,
|
|
96
|
+
totalSteps: e.step,
|
|
97
|
+
trainingDuration: e.duration
|
|
98
|
+
});
|
|
99
|
+
}
|
|
100
|
+
trainStep(n, o, l = !1, u = !1) {
|
|
101
|
+
return t(() => {
|
|
102
|
+
this.model.getProfiler()?.startMemory();
|
|
103
|
+
let { xs: t, ys: d } = o, { value: f, grads: p } = a(() => {
|
|
104
|
+
let r = this.model.forward({
|
|
105
|
+
training: !0,
|
|
106
|
+
checkpointing: this._gradientCheckpointing,
|
|
107
|
+
mixedPrecision: this._mixedPrecision,
|
|
108
|
+
dropout: this._dropout,
|
|
109
|
+
layerDrop: this._layerDrop,
|
|
110
|
+
ropePositionOffset: 0
|
|
111
|
+
}, t), a = c(r, d, this.maskedLoss, !1, this._labelSmoothing);
|
|
112
|
+
this.metrics.has("accuracy") && (n.accuracy = s(r, d), i(n.accuracy)), r.dispose();
|
|
113
|
+
let o = a.mul(e(this.optimizerConfig.lossScaling));
|
|
114
|
+
return a.dispose(), o;
|
|
115
|
+
});
|
|
116
|
+
if (l) this.model.getProfiler()?.endMemory("Training");
|
|
117
|
+
else {
|
|
118
|
+
let e = this.optimizer.applyGradients(p);
|
|
119
|
+
this.metrics.has("gradientNorm") ? (n.gradientNorm = e, i(e)) : (n.gradientNorm = void 0, e.dispose());
|
|
120
|
+
let t = Object.keys(p);
|
|
121
|
+
this.model.weightStore.touchVariables(t), this.model.getProfiler()?.endMemory("Training"), u ? (n.gradients = p, Object.values(p).forEach((e) => i(e))) : r(p);
|
|
122
|
+
}
|
|
123
|
+
return f.mul(e(1 / this.optimizerConfig.lossScaling));
|
|
124
|
+
});
|
|
125
|
+
}
|
|
126
|
+
async dummyPass() {
|
|
127
|
+
let e = n([1, this.model.config.blockSize], "int32"), t = n([1, this.model.config.blockSize], "int32");
|
|
128
|
+
try {
|
|
129
|
+
let n = this.trainStep({}, {
|
|
130
|
+
xs: e,
|
|
131
|
+
ys: t
|
|
132
|
+
}, !0);
|
|
133
|
+
await n.data(), n.dispose();
|
|
134
|
+
} catch (e) {
|
|
135
|
+
console.error("Error during dummy pass:", e);
|
|
136
|
+
} finally {
|
|
137
|
+
e.dispose(), t.dispose();
|
|
138
|
+
}
|
|
139
|
+
}
|
|
140
|
+
dispose() {
|
|
141
|
+
this.optimizer && this.optimizer.dispose();
|
|
142
|
+
}
|
|
143
|
+
createEmptyState() {
|
|
144
|
+
return {
|
|
145
|
+
step: 0,
|
|
146
|
+
lastLoss: 1e6,
|
|
147
|
+
totalSteps: 0,
|
|
148
|
+
losses: [],
|
|
149
|
+
validationLosses: [],
|
|
150
|
+
logStartTime: 0,
|
|
151
|
+
trainingDuration: 0,
|
|
152
|
+
...this.lastState || {}
|
|
153
|
+
};
|
|
154
|
+
}
|
|
155
|
+
async stepDataset(e, t, n) {
|
|
156
|
+
let { logInterval: i = 10 } = {
|
|
157
|
+
...f,
|
|
158
|
+
...t
|
|
159
|
+
};
|
|
160
|
+
t.metrics && this.setMetrics(t.metrics);
|
|
161
|
+
let a = Date.now(), o = this.createEmptyState();
|
|
162
|
+
this.lastState = o, await this.dummyPass(), this.metrics.has("memoryUsage") && (this.model.getProfiler() || this.model.setProfiler(new u())), this.running = !0, o.logStartTime = a;
|
|
163
|
+
let s = n ? new l(this.model, n, this.maskedLoss) : void 0, c = await e.iterator();
|
|
164
|
+
try {
|
|
165
|
+
for (; this.running;) {
|
|
166
|
+
let e = await c.next();
|
|
167
|
+
if (e.done) break;
|
|
168
|
+
let n = e.value, r = this.trainStep(o, n, !1);
|
|
169
|
+
if (t.debug) {
|
|
170
|
+
let e = (await r.data())[0];
|
|
171
|
+
if (isNaN(e) || !isFinite(e)) throw console.error("Invalid loss value:", e), console.error("Batch xs:", n.xs.toString()), console.error("Batch ys:", n.ys.toString()), console.error("State:", o), Error("Loss is NaN or Infinity");
|
|
172
|
+
console.log(`Step ${o.step}: Loss = ${e}`);
|
|
173
|
+
}
|
|
174
|
+
n.xs.dispose(), n.ys.dispose(), o.step++, o.totalSteps++, o.step % i === 0 ? await this.performLogging(r, n.xs.shape[0], t, s) : (o.gradientNorm &&= (o.gradientNorm.dispose(), void 0), o.accuracy &&= (o.accuracy.dispose(), void 0)), r.dispose();
|
|
175
|
+
}
|
|
176
|
+
} catch (e) {
|
|
177
|
+
throw console.error("Training error:", e), e;
|
|
178
|
+
}
|
|
179
|
+
throw this.model.trainingState = {
|
|
180
|
+
steps: o.totalSteps,
|
|
181
|
+
learningRate: this.optimizer.lr,
|
|
182
|
+
batchSize: t.batchSize || 32,
|
|
183
|
+
loss: o.lastLoss,
|
|
184
|
+
tokensProcessed: o.totalSteps * (t.batchSize || 32) * this.model.config.blockSize,
|
|
185
|
+
duration: o.trainingDuration
|
|
186
|
+
}, r(), this.running = !1, Error("No log returned before training stopped.");
|
|
187
|
+
}
|
|
188
|
+
async performLogging(e, t, n, r) {
|
|
189
|
+
let i = n?.onStep, a = this.metrics.has("gradientStatistics"), o = (await e.data())[0], s = this.lastState;
|
|
190
|
+
s.lastLoss = o, s.trainingDuration += Date.now() - s.logStartTime;
|
|
191
|
+
let c = s.totalSteps * t * this.model.config.blockSize, l = {
|
|
192
|
+
trainingMetrics: {
|
|
193
|
+
loss: s.lastLoss,
|
|
194
|
+
perplexity: this.metrics.has("perplexity") ? Math.exp(s.lastLoss) : void 0,
|
|
195
|
+
accuracy: s.accuracy ? (await s.accuracy.data())[0] : void 0
|
|
196
|
+
},
|
|
197
|
+
step: s.step,
|
|
198
|
+
time: Date.now() - s.logStartTime,
|
|
199
|
+
gradientNorm: s.gradientNorm ? (await s.gradientNorm.data())[1] : void 0,
|
|
200
|
+
batchSize: t,
|
|
201
|
+
learningRate: this.metrics.has("learningRate") ? this.optimizer.lr : void 0,
|
|
202
|
+
duration: s.trainingDuration,
|
|
203
|
+
totalTokens: c,
|
|
204
|
+
tokensPerSecond: c / (s.trainingDuration / 1e3),
|
|
205
|
+
memoryUsage: this.metrics.has("memoryUsage") ? this.model.getProfiler()?.getPeakMemory() || 0 : void 0
|
|
206
|
+
};
|
|
207
|
+
if (s.gradientNorm &&= (s.gradientNorm.dispose(), void 0), s.accuracy &&= (s.accuracy.dispose(), void 0), this.model.trainingState = {
|
|
208
|
+
steps: s.totalSteps,
|
|
209
|
+
learningRate: this.optimizer.lr,
|
|
210
|
+
batchSize: t,
|
|
211
|
+
loss: s.lastLoss,
|
|
212
|
+
tokensProcessed: c,
|
|
213
|
+
duration: s.trainingDuration
|
|
214
|
+
}, a && s.gradients) {
|
|
215
|
+
let e = /* @__PURE__ */ new Map();
|
|
216
|
+
for (let [t, n] of Object.entries(s.gradients)) e.set(t, await d(n)), n.dispose();
|
|
217
|
+
l.gradientMetrics = e;
|
|
218
|
+
}
|
|
219
|
+
if (r) try {
|
|
220
|
+
let e = await r.evaluate(5);
|
|
221
|
+
Array.isArray(e) ? l.validationMetrics = {
|
|
222
|
+
loss: e[0].loss,
|
|
223
|
+
accuracy: e[0].accuracy
|
|
224
|
+
} : (s.validationLosses.push(e.loss), l.validationMetrics = {
|
|
225
|
+
accuracy: e.accuracy,
|
|
226
|
+
loss: e.loss,
|
|
227
|
+
perplexity: this.metrics.has("perplexity") ? Math.exp(e.loss) : void 0
|
|
228
|
+
});
|
|
229
|
+
} catch (e) {
|
|
230
|
+
console.error("Validation error:", e);
|
|
231
|
+
}
|
|
232
|
+
i && await i(l), s.logStartTime = Date.now();
|
|
233
|
+
}
|
|
234
|
+
async trainOnDataset(e, t, n) {
|
|
235
|
+
let { logInterval: i = 10, maxEpochs: a = Infinity } = {
|
|
236
|
+
...f,
|
|
237
|
+
...t
|
|
238
|
+
}, o = a * (t?.epochSteps || 1e3);
|
|
239
|
+
t.metrics && this.setMetrics(t.metrics);
|
|
240
|
+
let s = Date.now(), c = this.createEmptyState();
|
|
241
|
+
this.lastState = c, await this.dummyPass(), t?.metrics?.includes("memoryUsage") && (this.model.getProfiler() || this.model.setProfiler(new u())), this.running = !0, c.logStartTime = s;
|
|
242
|
+
let d = n ? new l(this.model, n, this.maskedLoss) : void 0, p = await e.iterator();
|
|
243
|
+
try {
|
|
244
|
+
for (; this.running;) {
|
|
245
|
+
let e = await p.next();
|
|
246
|
+
if (e.done) break;
|
|
247
|
+
let n = e.value, r = c.step % i === 0, a = (t?.metrics?.includes("gradientStatistics") || !1) && r, s = this.trainStep(c, n, !1, a);
|
|
248
|
+
if (t.debug) {
|
|
249
|
+
let e = (await s.data())[0];
|
|
250
|
+
if (isNaN(e) || !isFinite(e)) throw console.error("Invalid loss value:", e), console.error("Batch xs:", await n.xs.array()), console.error("Batch ys:", await n.ys.array()), console.error("State:", c), Error("Loss is NaN or Infinity");
|
|
251
|
+
console.log(`Step ${c.step}: Loss = ${e}`);
|
|
252
|
+
}
|
|
253
|
+
n.xs.dispose(), n.ys.dispose(), c.step++, c.totalSteps++, r ? await this.performLogging(s, n.xs.shape[0], t, d) : (c.gradientNorm &&= (c.gradientNorm.dispose(), void 0), c.accuracy &&= (c.accuracy.dispose(), void 0)), s.dispose(), c.step >= o && this.stop();
|
|
254
|
+
}
|
|
255
|
+
} catch (e) {
|
|
256
|
+
throw console.error("Training error:", e), r(), e;
|
|
257
|
+
}
|
|
258
|
+
return r(), this.running = !1, {
|
|
259
|
+
losses: c.losses,
|
|
260
|
+
validationLosses: c.validationLosses
|
|
261
|
+
};
|
|
262
|
+
}
|
|
263
|
+
};
|
|
264
|
+
//#endregion
|
|
265
|
+
export { m as default };
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
import { Tensor } from '@tensorflow/tfjs-core';
|
|
2
|
+
import { Conversation, ITokeniser } from '../tokeniser/type';
|
|
3
|
+
import { Dataset } from '@tensorflow/tfjs-data';
|
|
4
|
+
/** Exported constant declared with the literal value 8. NOTE(review): its role is not visible in this declaration file — confirm against the implementation. */
export declare const PAGE_FACTOR = 8;
|
|
5
|
+
/** Declared to take nested conversation data plus a tokeniser and return a single Uint16Array. NOTE(review): presumably the tokenised conversations concatenated in order — confirm against the implementation. */
export declare function flattenTokens(textData: Conversation[][], tokenizer: ITokeniser): Uint16Array;
|
|
6
|
+
/** Same parameters as `flattenTokens`, but declared to return an object holding a `tokens` Uint16Array and a `mask` Uint8Array. NOTE(review): presumably one mask entry per token — confirm against the implementation. */
export declare function flattenTokensWithMask(textData: Conversation[][], tokenizer: ITokeniser): {
|
|
7
|
+
tokens: Uint16Array;
|
|
8
|
+
mask: Uint8Array;
|
|
9
|
+
};
|
|
10
|
+
/** Declared to take a Uint32Array and return a Uint32Array. NOTE(review): whether the input is shuffled in place or copied is not visible here — confirm. */
export declare function shuffle(array: Uint32Array): Uint32Array;
|
|
11
|
+
/** State object returned alongside the dataset by `DatasetBuilder.createTextDataset`: a Uint32Array of shuffled indexes and a numeric step. */
export interface DatasetState {
|
|
12
|
+
shuffledIndexes: Uint32Array;
|
|
13
|
+
step: number;
|
|
14
|
+
}
|
|
15
|
+
/**
 * Declaration of the dataset builder. It exposes the tokeniser and block size
 * it was constructed with (`blockSize` is optional in the constructor) and an
 * async `createTextDataset` that resolves to a `{ xs, ys }` tensor dataset
 * together with a `DatasetState`. All `createTextDataset` parameters except
 * `flatTokens` are optional.
 */
export declare class DatasetBuilder {
|
|
16
|
+
tokenizer: ITokeniser;
|
|
17
|
+
blockSize: number;
|
|
18
|
+
constructor(tokenizer: ITokeniser, blockSize?: number);
|
|
19
|
+
createTextDataset(flatTokens: Uint16Array, batchSize?: number, indexes?: Uint32Array, mask?: Uint8Array, ignoreIndex?: number): Promise<{
|
|
20
|
+
dataset: Dataset<{
|
|
21
|
+
xs: Tensor;
|
|
22
|
+
ys: Tensor;
|
|
23
|
+
}>;
|
|
24
|
+
state: DatasetState;
|
|
25
|
+
}>;
|
|
26
|
+
}
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
import { Dataset } from '@tensorflow/tfjs-data';
|
|
2
|
+
import { TensorContainer } from '@tensorflow/tfjs-core';
|
|
3
|
+
import { default as Model, ModelForwardAttributes } from '../../models/model';
|
|
4
|
+
/** Loss/accuracy pair produced by `Evaluator.evaluate`. */
interface Result {
|
|
5
|
+
/** Loss value read back from the loss tensor. */
loss: number;
|
|
6
|
+
/** Accuracy value read back from the accuracy tensor. */
accuracy: number;
|
|
7
|
+
}
|
|
8
|
+
/**
 * Evaluates a model on batches drawn from a dataset and reports loss and accuracy.
 * The implementation obtains `dataset.iterator()` once in the constructor and consumes
 * batches from it on every `evaluate` call.
 */
export default class Evaluator {
|
|
9
|
+
private model;
|
|
10
|
+
private iterator?;
|
|
11
|
+
private xs?;
|
|
12
|
+
private ys?;
|
|
13
|
+
private masked;
|
|
14
|
+
/**
 * @param model Model whose `forward` is run with `training: false`.
 * @param dataset Source of `{ xs, ys }` batches.
 * @param masked When truthy, the loss is computed in its masked variant.
 */
constructor(model: Model<ModelForwardAttributes>, dataset: Dataset<TensorContainer>, masked?: boolean);
|
|
15
|
+
/** Disposes the stored `xs`/`ys` tensors when they are set. */
dispose(): void;
|
|
16
|
+
private calculateBatchLoss;
|
|
17
|
+
/**
 * Averages loss and accuracy over up to `maxBatches` batches (default 100 in the
 * implementation). Each consumed batch's tensors are disposed after use.
 *
 * @param maxBatches Upper bound on the number of batches pulled from the iterator.
 * @returns The averaged result; an array of per-row results is only produced on the
 * fixed-batch path that is used when no iterator exists.
 * @throws Error "No data available for evaluation" when there is neither an iterator
 * nor a stored batch.
 */
evaluate(maxBatches?: number): Promise<Result | Result[]>;
|
|
18
|
+
}
|
|
19
|
+
export {};
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
import { di as e } from "../dist-BewPQWjc.js";
|
|
2
|
+
import { calculateAccuracy as t, calculateLoss as n } from "./loss.js";
|
|
3
|
+
//#region lib/training/Evaluator.ts
|
|
4
|
+
var r = class {
|
|
5
|
+
model;
|
|
6
|
+
iterator;
|
|
7
|
+
xs;
|
|
8
|
+
ys;
|
|
9
|
+
masked = !1;
|
|
10
|
+
constructor(e, t, n) {
|
|
11
|
+
this.model = e, this.masked = !!n, this.iterator = t.iterator();
|
|
12
|
+
}
|
|
13
|
+
dispose() {
|
|
14
|
+
this.xs && this.xs.dispose(), this.ys && this.ys.dispose();
|
|
15
|
+
}
|
|
16
|
+
async calculateBatchLoss(r, i, a, o) {
|
|
17
|
+
let [s, c] = e(() => {
|
|
18
|
+
let e = this.model.forward({ training: !1 }, r), s = n(e, i, o, a), c = t(e, i);
|
|
19
|
+
return e.dispose(), [s, c];
|
|
20
|
+
}), l = await s.array(), u = await c.array(), d = l, f = u;
|
|
21
|
+
return c.dispose(), s.dispose(), Array.isArray(d) ? d.map((e) => ({
|
|
22
|
+
loss: e,
|
|
23
|
+
accuracy: f
|
|
24
|
+
})) : {
|
|
25
|
+
loss: d,
|
|
26
|
+
accuracy: f
|
|
27
|
+
};
|
|
28
|
+
}
|
|
29
|
+
async evaluate(e = 100) {
|
|
30
|
+
let t = 0, n = 0, r = 0;
|
|
31
|
+
if (this.iterator) {
|
|
32
|
+
let i = await this.iterator;
|
|
33
|
+
for (let a = 0; a < e; a++) {
|
|
34
|
+
let e = await i.next();
|
|
35
|
+
if (e.done) break;
|
|
36
|
+
let { xs: a, ys: o } = e.value, s = await this.calculateBatchLoss(a, o, !1, this.masked);
|
|
37
|
+
a.dispose(), o.dispose(), t += s.loss, n += s.accuracy, r++;
|
|
38
|
+
}
|
|
39
|
+
return {
|
|
40
|
+
loss: t / r,
|
|
41
|
+
accuracy: n / r
|
|
42
|
+
};
|
|
43
|
+
} else if (this.xs && this.ys) return this.calculateBatchLoss(this.xs, this.ys, !0, !0);
|
|
44
|
+
throw Error("No data available for evaluation");
|
|
45
|
+
}
|
|
46
|
+
};
|
|
47
|
+
//#endregion
|
|
48
|
+
export { r as default };
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import { LRSchedulerConfig } from './types';
|
|
2
|
+
/**
 * Step-based learning-rate schedule: linear warm-up, then cosine decay down to a minimum
 * rate, then constant at that minimum. Each `getNextLR` call advances the schedule by one step.
 */
export default class LRScheduler {
|
|
3
|
+
/** Most recently produced learning rate (initially the constructor's `learningRate`). */
protected learningRate: number;
|
|
4
|
+
private config;
|
|
5
|
+
private step;
|
|
6
|
+
private startLearningRate;
|
|
7
|
+
/**
 * @param learningRate Peak learning rate; also the value reported by `lr` before the first step.
 * @param config Schedule settings; when `config.step` is defined the schedule resumes from it.
 */
constructor(learningRate: number, config: LRSchedulerConfig);
|
|
8
|
+
/** Returns a copy of the config with `step` set to the current step counter. */
serializeConfig(): LRSchedulerConfig;
|
|
9
|
+
/**
 * Shallow-merges `newConfig` into the current config. The internal step counter is not
 * changed by this call.
 *
 * @param newConfig Settings to override.
 * @param learningRate When provided, replaces the peak learning rate used by later steps.
 */
updateConfig(newConfig: Partial<LRSchedulerConfig>, learningRate?: number): void;
|
|
10
|
+
/** Learning rate produced by the latest `getNextLR` call. */
get lr(): number;
|
|
11
|
+
/** Computes the rate for the current step, stores it, increments the step and returns it. */
getNextLR(): number;
|
|
12
|
+
}
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
//#region lib/training/LRScheduler.ts
|
|
2
|
+
var e = class {
|
|
3
|
+
learningRate;
|
|
4
|
+
config;
|
|
5
|
+
step = 0;
|
|
6
|
+
startLearningRate;
|
|
7
|
+
constructor(e, t) {
|
|
8
|
+
this.learningRate = e, this.config = t, this.startLearningRate = e, t.step !== void 0 && (this.step = t.step);
|
|
9
|
+
}
|
|
10
|
+
serializeConfig() {
|
|
11
|
+
return {
|
|
12
|
+
...this.config,
|
|
13
|
+
step: this.step
|
|
14
|
+
};
|
|
15
|
+
}
|
|
16
|
+
updateConfig(e, t) {
|
|
17
|
+
this.config = {
|
|
18
|
+
...this.config,
|
|
19
|
+
...e
|
|
20
|
+
}, t !== void 0 && (this.startLearningRate = t);
|
|
21
|
+
}
|
|
22
|
+
get lr() {
|
|
23
|
+
return this.learningRate;
|
|
24
|
+
}
|
|
25
|
+
getNextLR() {
|
|
26
|
+
let e = this.step;
|
|
27
|
+
if (this.config.warmupSteps > 0 && e < this.config.warmupSteps) {
|
|
28
|
+
let t = (e + 1) / this.config.warmupSteps, n = this.startLearningRate * t;
|
|
29
|
+
return this.learningRate = n, this.step++, n;
|
|
30
|
+
}
|
|
31
|
+
let t = this.config.epochSteps * this.config.decayEpochs;
|
|
32
|
+
if (e >= t || t <= this.config.warmupSteps) return this.learningRate = this.config.minLearningRate, this.step++, this.config.minLearningRate;
|
|
33
|
+
let n = (e - this.config.warmupSteps) / (t - this.config.warmupSteps), r = .5 * (1 + Math.cos(Math.PI * n)), i = this.config.minLearningRate + r * (this.startLearningRate - this.config.minLearningRate);
|
|
34
|
+
return this.learningRate = i, this.step++, i;
|
|
35
|
+
}
|
|
36
|
+
};
|
|
37
|
+
//#endregion
|
|
38
|
+
export { e as default };
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
import { default as Model, ModelForwardAttributes } from '../../models/model';
|
|
2
|
+
import { default as BasicTrainer } from './BasicTrainer';
|
|
3
|
+
import { ITokeniser } from '../../tokeniser/type';
|
|
4
|
+
import { DatasetBuilder } from './DatasetBuilder';
|
|
5
|
+
import { AdamWOptimizer } from './AdamW';
|
|
6
|
+
import { AdamWOptimizerConfig } from './types';
|
|
7
|
+
/**
 * Trainer preset for pre-training. The implementation merges its own default optimizer
 * settings under the caller's `optConfig`, derives `minLearningRate` as
 * `optConfig.minLearningRate ?? learningRate / 20`, and creates a `DatasetBuilder` sized
 * by `model.config.blockSize`.
 */
export default class PreTrainer extends BasicTrainer {
|
|
8
|
+
/** Tokeniser passed to the constructor. */
tokenizer: ITokeniser;
|
|
9
|
+
/** Dataset builder created from the tokeniser and the model's block size. */
datasetBuilder: DatasetBuilder;
|
|
10
|
+
/**
 * @param model Model to train.
 * @param tokenizer Tokeniser used for dataset construction.
 * @param optConfig Optional overrides for the default optimizer settings.
 * @param optimizer Optional optimizer instance, forwarded to `BasicTrainer`.
 */
constructor(model: Model<ModelForwardAttributes>, tokenizer: ITokeniser, optConfig?: Partial<AdamWOptimizerConfig>, optimizer?: AdamWOptimizer);
|
|
11
|
+
}
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import { t as e } from "../DatasetBuilder-DgURD85T.js";
|
|
2
|
+
import t from "./BasicTrainer.js";
|
|
3
|
+
//#region lib/training/PreTrainer.ts
|
|
4
|
+
/** Default optimizer settings for pre-training; caller options are merged over these. */
var n = {
	decayEpochs: 100,
	epochSteps: 1e4,
	warmupSteps: 1e3,
	minLearningRate: 3e-5,
	weightDecay: .1,
	learningRate: 3e-4
};
/**
 * Pre-training trainer: a BasicTrainer (`t`) configured with the defaults above and
 * equipped with a DatasetBuilder (`e`).
 */
var r = class extends t {
	/** Tokeniser handed to the constructor. */
	tokenizer;
	/** Dataset builder sized by the model's block size. */
	datasetBuilder;

	/**
	 * @param model Model to train; `model.config.blockSize` sizes the dataset builder.
	 * @param tokenizer Tokeniser for the dataset builder.
	 * @param optConfig Optional overrides for the default optimizer settings.
	 * @param optimizer Optional optimizer instance, forwarded to the base class.
	 */
	constructor(model, tokenizer, optConfig, optimizer) {
		super(model, tokenizer, Object.assign({}, n, optConfig), optimizer);
		this.tokenizer = tokenizer;
		// The floor is the caller's explicit value, otherwise 1/20 of the effective peak rate.
		this.optimizerConfig.minLearningRate = optConfig?.minLearningRate ?? this.optimizerConfig.learningRate / 20;
		this.updateOptimizer();
		this.datasetBuilder = new e(tokenizer, model.config.blockSize);
	}
};
|
|
21
|
+
//#endregion
|
|
22
|
+
export { r as default };
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import { default as Model, ModelForwardAttributes } from '../../models/model';
|
|
2
|
+
import { default as BasicTrainer } from './BasicTrainer';
|
|
3
|
+
import { ITokeniser } from '../../tokeniser/type';
|
|
4
|
+
import { AdamWOptimizer } from './AdamW';
|
|
5
|
+
import { AdamWOptimizerConfig } from './types';
|
|
6
|
+
import { DatasetBuilder } from './DatasetBuilder';
|
|
7
|
+
/**
 * Trainer preset for supervised fine-tuning. The implementation merges its own default
 * optimizer settings under the caller's `optConfig`, derives `minLearningRate` as
 * `optConfig.minLearningRate ?? learningRate / 20`, creates a `DatasetBuilder` sized by
 * `model.config.blockSize`, and enables masked loss.
 */
export default class SFTTrainer extends BasicTrainer {
|
|
8
|
+
/** Tokeniser passed to the constructor. */
tokenizer: ITokeniser;
|
|
9
|
+
/** Dataset builder created from the tokeniser and the model's block size. */
datasetBuilder: DatasetBuilder;
|
|
10
|
+
/** Optional name; the constructor does not assign it. NOTE(review): presumably a LoRA adapter name — confirm. */
loraName?: string;
|
|
11
|
+
/**
 * @param model Model to train.
 * @param tokenizer Tokeniser used for dataset construction.
 * @param optConfig Optional overrides for the default optimizer settings.
 * @param optimizer Optional optimizer instance, forwarded to `BasicTrainer`.
 */
constructor(model: Model<ModelForwardAttributes>, tokenizer: ITokeniser, optConfig?: Partial<AdamWOptimizerConfig>, optimizer?: AdamWOptimizer);
|
|
12
|
+
}
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import { t as e } from "../DatasetBuilder-DgURD85T.js";
|
|
2
|
+
import t from "./BasicTrainer.js";
|
|
3
|
+
//#region lib/training/SFTTrainer.ts
|
|
4
|
+
/** Default optimizer settings for supervised fine-tuning; caller options are merged over these. */
var n = {
	decayEpochs: 100,
	epochSteps: 1e4,
	warmupSteps: 100,
	minLearningRate: 1e-5,
	weightDecay: .1,
	beta2: .95,
	learningRate: 3e-4
};
/**
 * Supervised fine-tuning trainer: a BasicTrainer (`t`) configured with the defaults
 * above, equipped with a DatasetBuilder (`e`) and with masked loss switched on.
 */
var r = class extends t {
	/** Tokeniser handed to the constructor. */
	tokenizer;
	/** Dataset builder sized by the model's block size. */
	datasetBuilder;
	/** Declared but not assigned by the constructor. */
	loraName;

	/**
	 * @param model Model to train; `model.config.blockSize` sizes the dataset builder.
	 * @param tokenizer Tokeniser for the dataset builder.
	 * @param optConfig Optional overrides for the default optimizer settings.
	 * @param optimizer Optional optimizer instance, forwarded to the base class.
	 */
	constructor(model, tokenizer, optConfig, optimizer) {
		super(model, tokenizer, Object.assign({}, n, optConfig), optimizer);
		this.tokenizer = tokenizer;
		// The floor is the caller's explicit value, otherwise 1/20 of the effective peak rate.
		this.optimizerConfig.minLearningRate = optConfig?.minLearningRate ?? this.optimizerConfig.learningRate / 20;
		this.updateOptimizer();
		this.datasetBuilder = new e(tokenizer, model.config.blockSize);
		this.maskedLoss = !0;
	}
};
|
|
23
|
+
//#endregion
|
|
24
|
+
export { r as default };
|
|
@@ -0,0 +1,3 @@
|
|
|
1
|
+
import { Tensor } from '@tensorflow/tfjs-core';
|
|
2
|
+
/**
 * Computes the loss for `logits` against `targets` by building a loss function with
 * `createSoftmaxCrossEntropyWithGrad(masked, keepBatch, labelSmoothing)` and applying it.
 * A `labelSmoothing` that is missing or not greater than 0 is passed on as `undefined`.
 *
 * @throws Error "Loss computation failed: ..." wrapping any error raised while computing.
 */
export declare function calculateLoss(logits: Tensor, targets: Tensor, masked?: boolean, keepBatch?: boolean, labelSmoothing?: number): Tensor;
|
|
3
|
+
/**
 * Returns the mean of `argMax(logits, -1) == targets` cast to float32, i.e. the fraction
 * of positions whose arg-max prediction equals the target.
 *
 * @throws Error "Accuracy computation failed: ..." wrapping any error raised while computing.
 */
export declare function calculateAccuracy(logits: Tensor, targets: Tensor): Tensor;
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
import { createSoftmaxCrossEntropyWithGrad as e } from "./sparseCrossEntropy.js";
|
|
2
|
+
//#region lib/training/loss.ts
|
|
3
|
+
function t(t, n, r, i, a) {
|
|
4
|
+
try {
|
|
5
|
+
return e(r, i, a && a > 0 ? a : void 0)(t, n);
|
|
6
|
+
} catch (e) {
|
|
7
|
+
throw console.error("Error computing loss:", e), Error(`Loss computation failed: ${e}`);
|
|
8
|
+
}
|
|
9
|
+
}
|
|
10
|
+
function n(e, t) {
|
|
11
|
+
try {
|
|
12
|
+
let n = e.argMax(-1), r = n.equal(t).cast("float32"), i = r.mean();
|
|
13
|
+
return n.dispose(), r.dispose(), i;
|
|
14
|
+
} catch (e) {
|
|
15
|
+
throw console.error("Error computing accuracy:", e), Error(`Accuracy computation failed: ${e}`);
|
|
16
|
+
}
|
|
17
|
+
}
|
|
18
|
+
//#endregion
|
|
19
|
+
export { n as calculateAccuracy, t as calculateLoss };
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
import { di as e } from "../dist-BewPQWjc.js";
|
|
2
|
+
//#region lib/training/orthoGrad.ts
|
|
3
|
+
function t(t, n, r) {
|
|
4
|
+
return e(() => {
|
|
5
|
+
let e = t.reshape([-1]), i = n.reshape([-1]), a = e.mul(e).sum().add(r), o = e.mul(i).sum().div(a), s = i.sub(e.mul(o)), c = i.norm(), l = s.norm().add(r);
|
|
6
|
+
return s.mul(c.div(l)).reshape(n.shape);
|
|
7
|
+
});
|
|
8
|
+
}
|
|
9
|
+
//#endregion
|
|
10
|
+
export { t as orthogonalizeGradient };
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs-core';
|
|
2
|
+
/**
|
|
3
|
+
* Numerically stable sparse cross-entropy with gradient support
|
|
4
|
+
* This version handles potential numerical issues better
*
* @param logits Scores whose last dimension is the class dimension; tensors of rank > 2 are
* flattened to two dimensions by the implementation.
* @param labels Class indices; cast to int32 by the implementation.
* @param validMask Optional multiplicative mask; when given, the loss is multiplied by it and
* normalised by a reduction of the mask rather than by the element count.
* @param keepBatch When truthy, the reduction is over the last axis of the original batch shape
* instead of over all elements.
* @param originalBatchShape Leading shape used when reshaping per-element values; defaults to
* the logits' shape without its last dimension.
* @param labelSmoothing Smoothing factor; only values greater than 0 have an effect (default 0).
|
|
5
|
+
*/
|
|
6
|
+
export declare function sparseSoftmaxCrossEntropy(logits: tf.Tensor, labels: tf.Tensor, validMask?: tf.Tensor, keepBatch?: boolean, originalBatchShape?: number[], labelSmoothing?: number): tf.Tensor;
|
|
7
|
+
/**
 * Creates a loss function that takes `(logits, labels)` and supplies its own gradient
 * function instead of relying on automatic differentiation of the forward pass.
 *
 * @param masked When truthy, label value 65535 is treated as a sentinel: those positions are
 * excluded through a validity mask built by the implementation.
 * @param keepBatch Forwarded to `sparseSoftmaxCrossEntropy`.
 * @param labelSmoothing Smoothing factor forwarded to `sparseSoftmaxCrossEntropy` (default 0).
 * @returns The loss function.
 */
export declare function createSoftmaxCrossEntropyWithGrad(masked?: boolean, keepBatch?: boolean, labelSmoothing?: number): (...args: tf.Tensor[]) => tf.Tensor<tf.Rank>;
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import { At as e, Bt as t, Gr as n, Nn as r, Pn as i, Rt as a, St as o, Wr as s, Wt as c, Y as l, _n as u, bn as d, di as f, mn as p, qr as m } from "../dist-BewPQWjc.js";
|
|
2
|
+
import { gatherSub as h } from "../ops/gatherSub.js";
|
|
3
|
+
import { scatterSub as g } from "../ops/scatterSub.js";
|
|
4
|
+
//#region lib/training/sparseCrossEntropy.ts
|
|
5
|
+
/**
 * sparseSoftmaxCrossEntropy(logits = r, labels = i, validMask = o, keepBatch = c,
 * originalBatchShape = l, labelSmoothing = u) — parameter names per the type declaration.
 *
 * Steps visible in the code:
 *  1. `g` is `originalBatchShape` or the logits' shape without the last dimension; its
 *     product is the flattened row count. Logits of rank > 2 are reshaped to
 *     [rows, classes]; labels are flattened (when rank > 1) and cast to int32.
 *  2. A per-row value `S` is derived from the shifted logits `b` via `a(b, -1)` and
 *     `h` (gatherSub).
 *  3. When `u > 0`, `S` is blended with a second term using weights `1 - u` and `u`.
 *  4. With a mask, the value is multiplied by the mask and normalised by a reduction of the
 *     mask (per batch row when `keepBatch` is truthy); without a mask it is reduced by
 *     `e(...)` over the last axis (`keepBatch`) or over all elements.
 *
 * NOTE(review): t, d, a, e, s, m, p, n and f are minified imports from the shared dist
 * chunk. They presumably map to sub, max, logSumExp, mean, mul, add, sum, div and tidy —
 * confirm against the chunk's export table before relying on this reading.
 */
function _(r, i, o, c, l, u = 0) {
|
|
6
|
+
return f(() => {
|
|
7
|
+
let f = r.shape[r.shape.length - 1], g = l || r.shape.slice(0, -1), _ = g.reduce((e, t) => e * t, 1), v = r.shape.length > 2 ? r.reshape([_, f]) : r, y = i.shape.length > 1 ? i.reshape([_]).cast("int32") : i.cast("int32"), b = t(v, d(v, -1, !0)), x = a(b, -1), S = h(x, y, b), C;
|
|
8
|
+
if (u > 0) {
|
|
9
|
+
let n = t(x, e(b, -1));
|
|
10
|
+
C = m(s(S, 1 - u), s(n, u));
|
|
11
|
+
} else C = S;
|
|
12
|
+
if (o) if (C = s(C, o), c) {
|
|
13
|
+
let e = p(o.reshape(g), -1);
|
|
14
|
+
C = n(p(C.reshape(g), -1), e);
|
|
15
|
+
} else {
|
|
16
|
+
let e = p(o);
|
|
17
|
+
C = n(p(C), e);
|
|
18
|
+
}
|
|
19
|
+
else C = c ? e(C.reshape(g), -1) : e(C);
|
|
20
|
+
return C;
|
|
21
|
+
});
|
|
22
|
+
}
|
|
23
|
+
/**
 * createSoftmaxCrossEntropyWithGrad(masked = e, keepBatch = n, labelSmoothing = a) —
 * parameter names per the type declaration. Returns `c(...)` wrapping a function of
 * (logits, labels, save) that yields `{ value, gradFunc }`.
 *
 * Forward pass: logits are reshaped to [rows, classes] and labels to int32 [rows]. When
 * `masked` is truthy, labels are compared with the sentinel 65535 to build a float32 mask
 * `C`, and `S` is built from the labels and `r(x)` by `i(...)`. The loss value comes from
 * `_` (sparseSoftmaxCrossEntropy). The reshaped logits, the labels and (when present) the
 * mask are handed to the save callback `m` before the local reshapes are disposed.
 *
 * Backward pass: starts from `l(savedLogits)`, scales the incoming gradient by
 * `1 / (mask reduction or row count)`, applies the mask when one exists, and combines the
 * two through `g` (scatterSub), with an extra adjustment when `labelSmoothing > 0`.
 * The gradient returned for the labels is `r(labels)`.
 *
 * NOTE(review): c, u, o, i, r, l, p, s, t and f are minified imports from the shared dist
 * chunk. They presumably map to customGrad, scalar, notEqual, where, zerosLike, softmax,
 * sum, mul, sub and tidy — confirm against the chunk's export table.
 * NOTE(review): inner bindings shadow outer ones (e.g. `n`, `m`, `f`, `v`, `b`, `x`, `S`
 * inside gradFunc), so read each scope separately.
 */
function v(e, n, a = 0) {
|
|
24
|
+
return c((c, d, m) => {
|
|
25
|
+
let h = c.shape[c.shape.length - 1], v = c.shape.slice(0, -1), y = v.reduce((e, t) => e * t, 1), b = c.reshape([y, h]), x = d.reshape([y]).cast("int32"), S, C = null;
|
|
26
|
+
if (e) {
|
|
27
|
+
let e = u(65535, "int32"), t = o(x, e);
|
|
28
|
+
C = t.cast("float32"), S = i(t, x, r(x)), e.dispose(), t.dispose();
|
|
29
|
+
} else S = x;
|
|
30
|
+
let w = _(b, S, C || void 0, n, v, a);
|
|
31
|
+
return m(C ? [
|
|
32
|
+
b,
|
|
33
|
+
S,
|
|
34
|
+
C
|
|
35
|
+
] : [b, S]), b.dispose(), x.dispose(), {
|
|
36
|
+
value: w,
|
|
37
|
+
gradFunc: (n, i) => f(() => {
|
|
38
|
+
let o = i[0], f = i[1], m = e ? i[2] : void 0, _ = l(o), v = m ? p(m) : u(o.shape[0], "float32"), y = n.div(v).broadcastTo([o.shape[0]]), b = m && e ? s(y, m) : y, x;
|
|
39
|
+
x = a > 0 ? g(t(_, a / h), f, s(b, 1 - a)) : g(_, f, b);
|
|
40
|
+
let S = r(d);
|
|
41
|
+
return [x.reshape(c.shape), S];
|
|
42
|
+
})
|
|
43
|
+
};
|
|
44
|
+
});
|
|
45
|
+
}
|
|
46
|
+
//#endregion
|
|
47
|
+
export { v as createSoftmaxCrossEntropyWithGrad, _ as sparseSoftmaxCrossEntropy };
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import { Conversation, ITokeniser } from '../../../main';
|
|
2
|
+
import { Task } from './Task';
|
|
3
|
+
/**
 * Task that serves a fixed list of conversations one at a time, either in their original
 * order or in a shuffled order once `shuffle()` has been called.
 */
export default class ConversationTask extends Task {
|
|
4
|
+
private rawConvo;
|
|
5
|
+
private shuffledIndices;
|
|
6
|
+
private index;
|
|
7
|
+
/** Number of conversations held by the task. */
get length(): number;
|
|
8
|
+
/** @param conversations Conversations to serve; the array is stored, not copied. */
constructor(conversations: Conversation[][]);
|
|
9
|
+
/** True while the read position is before the end of the list. */
hasMoreConversations(): boolean;
|
|
10
|
+
/** Returns the conversation at the read position and advances it, or `null` when exhausted. */
nextConversation(): Conversation[] | null;
|
|
11
|
+
/**
 * Encodes the next conversation with `tokeniser.encodeConversation`; `null` when exhausted.
 * The two-argument form forwards `masking` to the tokeniser.
 */
nextTokens(tokeniser: ITokeniser): number[] | null;
|
|
12
|
+
nextTokens(tokeniser: ITokeniser, masking: boolean): {
|
|
13
|
+
tokens: number[];
|
|
14
|
+
mask: boolean[];
|
|
15
|
+
} | null;
|
|
16
|
+
/** Shuffles the serving order (creating the index table on first use) and resets the read position to 0. */
shuffle(): void;
|
|
17
|
+
/** Estimate: token count of the first conversation multiplied by the number of conversations. */
estimateTokens(tokeniser: ITokeniser): Promise<number>;
|
|
18
|
+
}
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import { a as e } from "../../DatasetBuilder-DgURD85T.js";
|
|
2
|
+
import { Task as t } from "./Task.js";
|
|
3
|
+
//#region lib/training/tasks/ConversationTask.ts
|
|
4
|
+
var n = class extends t {
|
|
5
|
+
rawConvo;
|
|
6
|
+
shuffledIndices = null;
|
|
7
|
+
index = 0;
|
|
8
|
+
get length() {
|
|
9
|
+
return this.rawConvo.length;
|
|
10
|
+
}
|
|
11
|
+
constructor(e) {
|
|
12
|
+
super(), this.rawConvo = e;
|
|
13
|
+
}
|
|
14
|
+
hasMoreConversations() {
|
|
15
|
+
return this.index < this.rawConvo.length;
|
|
16
|
+
}
|
|
17
|
+
nextConversation() {
|
|
18
|
+
if (this.index >= this.rawConvo.length) return null;
|
|
19
|
+
let e = this.rawConvo[this.shuffledIndices ? this.shuffledIndices[this.index] : this.index];
|
|
20
|
+
return this.index++, e;
|
|
21
|
+
}
|
|
22
|
+
nextTokens(e, t) {
|
|
23
|
+
let n = this.nextConversation();
|
|
24
|
+
return n ? e.encodeConversation(n, !1, t) : null;
|
|
25
|
+
}
|
|
26
|
+
shuffle() {
|
|
27
|
+
if (!this.shuffledIndices) {
|
|
28
|
+
this.shuffledIndices = new Uint32Array(this.rawConvo.length);
|
|
29
|
+
for (let e = 0; e < this.rawConvo.length; e++) this.shuffledIndices[e] = e;
|
|
30
|
+
}
|
|
31
|
+
e(this.shuffledIndices), this.index = 0;
|
|
32
|
+
}
|
|
33
|
+
async estimateTokens(e) {
|
|
34
|
+
return e.encodeConversation(this.rawConvo[0]).length * this.length;
|
|
35
|
+
}
|
|
36
|
+
};
|
|
37
|
+
//#endregion
|
|
38
|
+
export { n as default };
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import { Conversation, ITokeniser } from '../../../main';
|
|
2
|
+
import { Task } from './Task';
|
|
3
|
+
/**
 * Task built from a list of raw text strings, exposing the same member set as the other
 * `Task` subclasses declared in this package.
 * NOTE(review): the implementation is not visible from this declaration; member comments
 * below restate the signatures only — confirm behaviour against PretrainingTask.js.
 */
export default class PretrainingTask extends Task {
|
|
4
|
+
private rawText;
|
|
5
|
+
private index;
|
|
6
|
+
/** Size of the task; presumably the number of texts — confirm. */
get length(): number;
|
|
7
|
+
/** @param texts Raw text entries backing the task. */
constructor(texts: string[]);
|
|
8
|
+
/** Whether another item can still be served. */
hasMoreConversations(): boolean;
|
|
9
|
+
/** Next item in `Conversation[]` form, or `null`; how raw text is wrapped is not visible here. */
nextConversation(): Conversation[] | null;
|
|
10
|
+
/** Next item as token ids, or `null`; the two-argument form also returns a boolean mask. */
nextTokens(tokeniser: ITokeniser): number[] | null;
|
|
11
|
+
nextTokens(tokeniser: ITokeniser, masking: boolean): {
|
|
12
|
+
tokens: number[];
|
|
13
|
+
mask: boolean[];
|
|
14
|
+
} | null;
|
|
15
|
+
/** Reorders the task's items; whether the read position is reset is not visible here. */
shuffle(): void;
|
|
16
|
+
/** Resolves to an estimated token count for the whole task. */
estimateTokens(tokeniser: ITokeniser): Promise<number>;
|
|
17
|
+
}
|