@genai-fi/nanogpt 0.19.1 → 0.20.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +9 -10
- package/dist/Generator.d.ts +0 -82
- package/dist/Generator.js +0 -11941
- package/dist/RealDiv-CGwv0liw.js +0 -365
- package/dist/Reshape-BW__R4mZ.js +0 -79
- package/dist/Reshape-CPBkTIH2.js +0 -14
- package/dist/TeachableLLM.d.ts +0 -70
- package/dist/TeachableLLM.js +0 -273
- package/dist/Trainer.d.ts +0 -43
- package/dist/Trainer.js +0 -244
- package/dist/_commonjsHelpers-ByX85dGu.js +0 -33
- package/dist/axis_util-GTVlo58H.js +0 -55
- package/dist/backend.d.ts +0 -2
- package/dist/backend.js +0 -13
- package/dist/backend_util-GaFarB78.js +0 -425
- package/dist/backend_webgpu-BqASlsbV.js +0 -545
- package/dist/binary_op_util-pKXltfxI.js +0 -192
- package/dist/broadcast_to-eS93CCN_.js +0 -28
- package/dist/checks/appendCache.d.ts +0 -1
- package/dist/checks/appendCache.js +0 -22
- package/dist/checks/attentionMask.d.ts +0 -1
- package/dist/checks/attentionMask.js +0 -37
- package/dist/checks/check.d.ts +0 -9
- package/dist/checks/check.js +0 -20
- package/dist/checks/gelu.d.ts +0 -1
- package/dist/checks/gelu.js +0 -18
- package/dist/checks/index.d.ts +0 -26
- package/dist/checks/index.js +0 -28
- package/dist/checks/matMulGelu.d.ts +0 -1
- package/dist/checks/matMulGelu.js +0 -28
- package/dist/checks/normRMS.d.ts +0 -1
- package/dist/checks/normRMS.js +0 -16
- package/dist/checks/normRMSGrad.d.ts +0 -1
- package/dist/checks/normRMSGrad.js +0 -12
- package/dist/checks/packUnpack.d.ts +0 -1
- package/dist/checks/packUnpack.js +0 -18
- package/dist/checks/qkv.d.ts +0 -1
- package/dist/checks/qkv.js +0 -34
- package/dist/checks/rope.d.ts +0 -1
- package/dist/checks/rope.js +0 -36
- package/dist/checks/weights.d.ts +0 -14
- package/dist/checks/weights.js +0 -31
- package/dist/clip_by_value-DDA7rrcT.js +0 -12
- package/dist/complex-DI35Q-gW.js +0 -11
- package/dist/complex_util-Yc1A_gV1.js +0 -55
- package/dist/concat-CAQpCret.js +0 -17
- package/dist/concat_util-D18dJ4fD.js +0 -22
- package/dist/data/docx.d.ts +0 -2
- package/dist/data/docx.js +0 -15
- package/dist/data/parquet.d.ts +0 -2
- package/dist/data/parquet.js +0 -17
- package/dist/data/pdf.d.ts +0 -2
- package/dist/data/pdf.js +0 -14
- package/dist/data/textLoader.d.ts +0 -7
- package/dist/data/textLoader.js +0 -118
- package/dist/dataset-CGGp1z9P.js +0 -1124
- package/dist/dropout_util--NxWuYg2.js +0 -27
- package/dist/expand_dims-Bkd1YD5x.js +0 -11
- package/dist/exports_initializers-CYzKLjN7.js +0 -7
- package/dist/floor-BQtb-Azg.js +0 -9
- package/dist/gather-qIqEqaGn.js +0 -9
- package/dist/gelu-B220X1Go.js +0 -26
- package/dist/gpgpu_math-BwvV12df.js +0 -2022
- package/dist/index-CUXkjxiT.js +0 -3516
- package/dist/index-CieiGp4Y.js +0 -349
- package/dist/index-CjOWnMXP.js +0 -7308
- package/dist/index-Cp39cXWe.js +0 -1016
- package/dist/index-D5v913EJ.js +0 -4
- package/dist/index-DmeWGGmS.js +0 -1074
- package/dist/index-DvYrXKkX.js +0 -113
- package/dist/index-Ksja3su6.js +0 -151
- package/dist/index-xuotMAFm.js +0 -118
- package/dist/inference/types.d.ts +0 -16
- package/dist/inference/types.js +0 -1
- package/dist/jszip.min-BZhlzntC.js +0 -2313
- package/dist/kernel_funcs_utils-pq0CK9co.js +0 -306
- package/dist/layers/BaseLayer.d.ts +0 -44
- package/dist/layers/BaseLayer.js +0 -74
- package/dist/layers/CausalSelfAttention.d.ts +0 -39
- package/dist/layers/CausalSelfAttention.js +0 -86
- package/dist/layers/LoRA.d.ts +0 -14
- package/dist/layers/LoRA.js +0 -58
- package/dist/layers/MLP.d.ts +0 -17
- package/dist/layers/MLP.js +0 -44
- package/dist/layers/PositionEmbedding.d.ts +0 -8
- package/dist/layers/PositionEmbedding.js +0 -31
- package/dist/layers/RMSNorm.d.ts +0 -12
- package/dist/layers/RMSNorm.js +0 -22
- package/dist/layers/RoPECache.d.ts +0 -18
- package/dist/layers/RoPECache.js +0 -50
- package/dist/layers/TiedEmbedding.d.ts +0 -13
- package/dist/layers/TiedEmbedding.js +0 -36
- package/dist/layers/TransformerBlock.d.ts +0 -27
- package/dist/layers/TransformerBlock.js +0 -40
- package/dist/layers/WeightStore.d.ts +0 -20
- package/dist/layers/WeightStore.js +0 -76
- package/dist/loader/load.d.ts +0 -6
- package/dist/loader/load.js +0 -68
- package/dist/loader/loadHF.d.ts +0 -8
- package/dist/loader/loadHF.js +0 -22
- package/dist/loader/loadTransformers.d.ts +0 -4
- package/dist/loader/loadTransformers.js +0 -44
- package/dist/loader/loadZipMeta.d.ts +0 -3
- package/dist/loader/loadZipMeta.js +0 -16
- package/dist/loader/newZipLoad.d.ts +0 -3
- package/dist/loader/newZipLoad.js +0 -31
- package/dist/loader/oldZipLoad.d.ts +0 -9
- package/dist/loader/oldZipLoad.js +0 -80
- package/dist/loader/save.d.ts +0 -16
- package/dist/loader/save.js +0 -90
- package/dist/loader/types.d.ts +0 -67
- package/dist/loader/types.js +0 -1
- package/dist/main.d.ts +0 -50
- package/dist/main.js +0 -109
- package/dist/matMul16-BcVC_E62.js +0 -80
- package/dist/matMulGelu-JNLZqKQp.js +0 -163
- package/dist/mat_mul-DhG0Newp.js +0 -11
- package/dist/mod-CSdCpRjf.js +0 -11
- package/dist/models/NanoGPTV1.d.ts +0 -16
- package/dist/models/NanoGPTV1.js +0 -99
- package/dist/models/NanoGPTV2.d.ts +0 -16
- package/dist/models/NanoGPTV2.js +0 -90
- package/dist/models/config.d.ts +0 -27
- package/dist/models/config.js +0 -50
- package/dist/models/factory.d.ts +0 -3
- package/dist/models/factory.js +0 -16
- package/dist/models/model.d.ts +0 -44
- package/dist/models/model.js +0 -134
- package/dist/non_max_suppression_impl-B2W7YjZB.js +0 -102
- package/dist/not_equal-hurPF26l.js +0 -64
- package/dist/ones-BytntneX.js +0 -14
- package/dist/ops/adamAdjust.d.ts +0 -2
- package/dist/ops/adamAdjust.js +0 -9
- package/dist/ops/adamMoments.d.ts +0 -2
- package/dist/ops/adamMoments.js +0 -9
- package/dist/ops/add16.d.ts +0 -2
- package/dist/ops/add16.js +0 -9
- package/dist/ops/appendCache.d.ts +0 -2
- package/dist/ops/appendCache.js +0 -22
- package/dist/ops/attentionMask.d.ts +0 -2
- package/dist/ops/attentionMask.js +0 -10
- package/dist/ops/concat16.d.ts +0 -2
- package/dist/ops/concat16.js +0 -9
- package/dist/ops/cpu/adamAdjust.d.ts +0 -1
- package/dist/ops/cpu/adamAdjust.js +0 -18
- package/dist/ops/cpu/adamMoments.d.ts +0 -1
- package/dist/ops/cpu/adamMoments.js +0 -16
- package/dist/ops/cpu/appendCache.d.ts +0 -1
- package/dist/ops/cpu/appendCache.js +0 -23
- package/dist/ops/cpu/attentionMask.d.ts +0 -1
- package/dist/ops/cpu/attentionMask.js +0 -22
- package/dist/ops/cpu/fusedSoftmax.d.ts +0 -9
- package/dist/ops/cpu/fusedSoftmax.js +0 -29
- package/dist/ops/cpu/gatherSub.d.ts +0 -1
- package/dist/ops/cpu/gatherSub.js +0 -18
- package/dist/ops/cpu/gelu.d.ts +0 -1
- package/dist/ops/cpu/gelu.js +0 -40
- package/dist/ops/cpu/matMul16.d.ts +0 -1
- package/dist/ops/cpu/matMul16.js +0 -15
- package/dist/ops/cpu/matMulGelu.d.ts +0 -1
- package/dist/ops/cpu/matMulGelu.js +0 -53
- package/dist/ops/cpu/matMulMul.d.ts +0 -1
- package/dist/ops/cpu/matMulMul.js +0 -23
- package/dist/ops/cpu/mulDropout.d.ts +0 -1
- package/dist/ops/cpu/mulDropout.js +0 -23
- package/dist/ops/cpu/normRMS.d.ts +0 -1
- package/dist/ops/cpu/normRMS.js +0 -39
- package/dist/ops/cpu/qkv.d.ts +0 -5
- package/dist/ops/cpu/qkv.js +0 -41
- package/dist/ops/cpu/rope.d.ts +0 -6
- package/dist/ops/cpu/rope.js +0 -38
- package/dist/ops/cpu/scatterSub.d.ts +0 -1
- package/dist/ops/cpu/scatterSub.js +0 -23
- package/dist/ops/dot16.d.ts +0 -2
- package/dist/ops/dot16.js +0 -42
- package/dist/ops/dropout.d.ts +0 -2
- package/dist/ops/dropout.js +0 -14
- package/dist/ops/dropout16.d.ts +0 -2
- package/dist/ops/dropout16.js +0 -25
- package/dist/ops/gatherSub.d.ts +0 -2
- package/dist/ops/gatherSub.js +0 -9
- package/dist/ops/gelu.d.ts +0 -3
- package/dist/ops/gelu.js +0 -8
- package/dist/ops/globalNorm.d.ts +0 -2
- package/dist/ops/globalNorm.js +0 -13
- package/dist/ops/grads/add16.d.ts +0 -1
- package/dist/ops/grads/add16.js +0 -26
- package/dist/ops/grads/attentionMask.d.ts +0 -1
- package/dist/ops/grads/attentionMask.js +0 -21
- package/dist/ops/grads/dropout16.d.ts +0 -1
- package/dist/ops/grads/dropout16.js +0 -2
- package/dist/ops/grads/gelu.d.ts +0 -2
- package/dist/ops/grads/gelu.js +0 -5
- package/dist/ops/grads/matMul16.d.ts +0 -2
- package/dist/ops/grads/matMul16.js +0 -9
- package/dist/ops/grads/matMulGelu.d.ts +0 -1
- package/dist/ops/grads/matMulGelu.js +0 -17
- package/dist/ops/grads/mul16.d.ts +0 -1
- package/dist/ops/grads/mul16.js +0 -4
- package/dist/ops/grads/normRMS.d.ts +0 -3
- package/dist/ops/grads/normRMS.js +0 -33
- package/dist/ops/grads/pack16.d.ts +0 -2
- package/dist/ops/grads/pack16.js +0 -6
- package/dist/ops/grads/qkv.d.ts +0 -3
- package/dist/ops/grads/qkv.js +0 -34
- package/dist/ops/grads/rope.d.ts +0 -2
- package/dist/ops/grads/rope.js +0 -5
- package/dist/ops/grads/softmax16.d.ts +0 -2
- package/dist/ops/grads/softmax16.js +0 -25
- package/dist/ops/grads/unpack16.d.ts +0 -2
- package/dist/ops/grads/unpack16.js +0 -5
- package/dist/ops/grads/utils.d.ts +0 -4
- package/dist/ops/grads/utils.js +0 -14
- package/dist/ops/log.d.ts +0 -0
- package/dist/ops/log.js +0 -1
- package/dist/ops/matMul16.d.ts +0 -15
- package/dist/ops/matMul16.js +0 -13
- package/dist/ops/matMulGelu.d.ts +0 -3
- package/dist/ops/matMulGelu.js +0 -14
- package/dist/ops/matMulMul.d.ts +0 -2
- package/dist/ops/matMulMul.js +0 -9
- package/dist/ops/mul16.d.ts +0 -2
- package/dist/ops/mul16.js +0 -39
- package/dist/ops/mulDrop.d.ts +0 -2
- package/dist/ops/mulDrop.js +0 -9
- package/dist/ops/normRMS.d.ts +0 -2
- package/dist/ops/normRMS.js +0 -19
- package/dist/ops/pack16.d.ts +0 -2
- package/dist/ops/pack16.js +0 -5
- package/dist/ops/qkv.d.ts +0 -2
- package/dist/ops/qkv.js +0 -10
- package/dist/ops/reshape16.d.ts +0 -2
- package/dist/ops/reshape16.js +0 -41
- package/dist/ops/rope.d.ts +0 -3
- package/dist/ops/rope.js +0 -7
- package/dist/ops/scatterSub.d.ts +0 -2
- package/dist/ops/scatterSub.js +0 -9
- package/dist/ops/slice16.d.ts +0 -2
- package/dist/ops/slice16.js +0 -9
- package/dist/ops/softmax16.d.ts +0 -2
- package/dist/ops/softmax16.js +0 -9
- package/dist/ops/sub16.d.ts +0 -2
- package/dist/ops/sub16.js +0 -8
- package/dist/ops/sum16.d.ts +0 -2
- package/dist/ops/sum16.js +0 -13
- package/dist/ops/transpose16.d.ts +0 -3
- package/dist/ops/transpose16.js +0 -40
- package/dist/ops/unpack16.d.ts +0 -2
- package/dist/ops/unpack16.js +0 -6
- package/dist/ops/webgl/adamAdjust.d.ts +0 -1
- package/dist/ops/webgl/adamAdjust.js +0 -49
- package/dist/ops/webgl/adamMoments.d.ts +0 -1
- package/dist/ops/webgl/adamMoments.js +0 -40
- package/dist/ops/webgl/appendCache.d.ts +0 -1
- package/dist/ops/webgl/appendCache.js +0 -44
- package/dist/ops/webgl/attentionMask.d.ts +0 -1
- package/dist/ops/webgl/attentionMask.js +0 -45
- package/dist/ops/webgl/dropout16.d.ts +0 -1
- package/dist/ops/webgl/dropout16.js +0 -11
- package/dist/ops/webgl/fusedSoftmax.d.ts +0 -11
- package/dist/ops/webgl/fusedSoftmax.js +0 -80
- package/dist/ops/webgl/gatherSub.d.ts +0 -1
- package/dist/ops/webgl/gatherSub.js +0 -27
- package/dist/ops/webgl/gelu.d.ts +0 -2
- package/dist/ops/webgl/gelu.js +0 -50
- package/dist/ops/webgl/log.d.ts +0 -17
- package/dist/ops/webgl/log.js +0 -23
- package/dist/ops/webgl/matMul16.d.ts +0 -1
- package/dist/ops/webgl/matMul16.js +0 -45
- package/dist/ops/webgl/matMulGelu.d.ts +0 -21
- package/dist/ops/webgl/matMulGelu.js +0 -9
- package/dist/ops/webgl/matMulMul.d.ts +0 -14
- package/dist/ops/webgl/matMulMul.js +0 -28
- package/dist/ops/webgl/mulDropout.d.ts +0 -1
- package/dist/ops/webgl/mulDropout.js +0 -41
- package/dist/ops/webgl/normRMS.d.ts +0 -1
- package/dist/ops/webgl/normRMS.js +0 -93
- package/dist/ops/webgl/qkv.d.ts +0 -1
- package/dist/ops/webgl/qkv.js +0 -46
- package/dist/ops/webgl/rope.d.ts +0 -1
- package/dist/ops/webgl/rope.js +0 -56
- package/dist/ops/webgl/scatterSub.d.ts +0 -1
- package/dist/ops/webgl/scatterSub.js +0 -27
- package/dist/ops/webgpu/adamAdjust.d.ts +0 -1
- package/dist/ops/webgpu/adamAdjust.js +0 -57
- package/dist/ops/webgpu/adamMoments.d.ts +0 -1
- package/dist/ops/webgpu/adamMoments.js +0 -60
- package/dist/ops/webgpu/add16.d.ts +0 -1
- package/dist/ops/webgpu/add16.js +0 -13
- package/dist/ops/webgpu/appendCache.d.ts +0 -1
- package/dist/ops/webgpu/appendCache.js +0 -105
- package/dist/ops/webgpu/attentionMask.d.ts +0 -1
- package/dist/ops/webgpu/attentionMask.js +0 -26
- package/dist/ops/webgpu/attentionMask32_program.d.ts +0 -19
- package/dist/ops/webgpu/attentionMask32_program.js +0 -54
- package/dist/ops/webgpu/clipScale.d.ts +0 -1
- package/dist/ops/webgpu/clipScale.js +0 -58
- package/dist/ops/webgpu/concat16.d.ts +0 -19
- package/dist/ops/webgpu/concat16.js +0 -126
- package/dist/ops/webgpu/dropout16.d.ts +0 -1
- package/dist/ops/webgpu/dropout16.js +0 -51
- package/dist/ops/webgpu/gatherSub.d.ts +0 -1
- package/dist/ops/webgpu/gatherSub.js +0 -39
- package/dist/ops/webgpu/gelu.d.ts +0 -14
- package/dist/ops/webgpu/gelu.js +0 -141
- package/dist/ops/webgpu/index.d.ts +0 -0
- package/dist/ops/webgpu/index.js +0 -26
- package/dist/ops/webgpu/matMul16.d.ts +0 -1
- package/dist/ops/webgpu/matMul16.js +0 -65
- package/dist/ops/webgpu/matMul16_program.d.ts +0 -42
- package/dist/ops/webgpu/matMul16_program.js +0 -343
- package/dist/ops/webgpu/mul16.d.ts +0 -1
- package/dist/ops/webgpu/mul16.js +0 -13
- package/dist/ops/webgpu/norm2.d.ts +0 -1
- package/dist/ops/webgpu/norm2.js +0 -76
- package/dist/ops/webgpu/normRMS.d.ts +0 -1
- package/dist/ops/webgpu/normRMS.js +0 -34
- package/dist/ops/webgpu/normRMS16_program.d.ts +0 -10
- package/dist/ops/webgpu/normRMS16_program.js +0 -25
- package/dist/ops/webgpu/normRMS32_program.d.ts +0 -10
- package/dist/ops/webgpu/normRMS32_program.js +0 -25
- package/dist/ops/webgpu/normRMSGrad.d.ts +0 -1
- package/dist/ops/webgpu/normRMSGrad.js +0 -284
- package/dist/ops/webgpu/pack16.d.ts +0 -1
- package/dist/ops/webgpu/pack16.js +0 -18
- package/dist/ops/webgpu/pack16_program.d.ts +0 -19
- package/dist/ops/webgpu/pack16_program.js +0 -92
- package/dist/ops/webgpu/qkv.d.ts +0 -1
- package/dist/ops/webgpu/qkv.js +0 -24
- package/dist/ops/webgpu/rope.d.ts +0 -1
- package/dist/ops/webgpu/rope.js +0 -135
- package/dist/ops/webgpu/scatterSub.d.ts +0 -1
- package/dist/ops/webgpu/scatterSub.js +0 -40
- package/dist/ops/webgpu/slice16.d.ts +0 -7
- package/dist/ops/webgpu/slice16.js +0 -69
- package/dist/ops/webgpu/softmax16.d.ts +0 -17
- package/dist/ops/webgpu/softmax16.js +0 -21
- package/dist/ops/webgpu/softmax16_program.d.ts +0 -13
- package/dist/ops/webgpu/softmax16_program.js +0 -73
- package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +0 -17
- package/dist/ops/webgpu/softmax16_subgroup_program.js +0 -75
- package/dist/ops/webgpu/softmax16grad.d.ts +0 -1
- package/dist/ops/webgpu/softmax16grad.js +0 -37
- package/dist/ops/webgpu/sub16.d.ts +0 -1
- package/dist/ops/webgpu/sub16.js +0 -13
- package/dist/ops/webgpu/sum16.d.ts +0 -1
- package/dist/ops/webgpu/sum16.js +0 -38
- package/dist/ops/webgpu/transpose16.d.ts +0 -1
- package/dist/ops/webgpu/transpose16.js +0 -34
- package/dist/ops/webgpu/transpose16_program.d.ts +0 -16
- package/dist/ops/webgpu/transpose16_program.js +0 -50
- package/dist/ops/webgpu/transpose16_shared_program.d.ts +0 -15
- package/dist/ops/webgpu/transpose16_shared_program.js +0 -70
- package/dist/ops/webgpu/unpack16.d.ts +0 -1
- package/dist/ops/webgpu/unpack16.js +0 -48
- package/dist/ops/webgpu/utils/binary_op.d.ts +0 -35
- package/dist/ops/webgpu/utils/binary_op.js +0 -139
- package/dist/ops/webgpu/utils/deviceInfo.d.ts +0 -7
- package/dist/ops/webgpu/utils/deviceInfo.js +0 -11
- package/dist/ops/webgpu/utils/reductions.d.ts +0 -43
- package/dist/ops/webgpu/utils/reductions.js +0 -275
- package/dist/ops-CsXeTq1P.js +0 -476
- package/dist/pack16-bqltoUlR.js +0 -39
- package/dist/papaparse.min-C0cScC2i.js +0 -418
- package/dist/parquet-Bqjmp2vo.js +0 -44231
- package/dist/patches/webgpu_backend.d.ts +0 -18
- package/dist/patches/webgpu_backend.js +0 -56
- package/dist/patches/webgpu_base.d.ts +0 -21
- package/dist/patches/webgpu_base.js +0 -34
- package/dist/patches/webgpu_program.d.ts +0 -36
- package/dist/patches/webgpu_program.js +0 -400
- package/dist/pdf-NIhmP3sq.js +0 -19477
- package/dist/rand_util-CZ7yLoUm.js +0 -50
- package/dist/random_normal-IBRrha8a.js +0 -14
- package/dist/random_width-DN5ZtQkM.js +0 -9796
- package/dist/range-C-CjF-LI.js +0 -10
- package/dist/relu-J_X6MUzx.js +0 -9
- package/dist/reshape-BDOuCSNW.js +0 -9
- package/dist/resize_nearest_neighbor-BojqlfRe.js +0 -150
- package/dist/rope-DcrZM_e6.js +0 -24
- package/dist/scatter_nd_util-ByNJaL6I.js +0 -46
- package/dist/segment_util-Dasb2Zaf.js +0 -43
- package/dist/selu_util-BLhIqRkw.js +0 -44
- package/dist/shared-3agzAqQ_.js +0 -53
- package/dist/shared-CagdqkLh.js +0 -2143
- package/dist/slice-BzS11Qh0.js +0 -12
- package/dist/slice_util-CC35pLmT.js +0 -153
- package/dist/softmax-D4q1LJN7.js +0 -12
- package/dist/split-C2Sj255c.js +0 -9
- package/dist/squeeze-ho4wLUek.js +0 -10
- package/dist/stack-DudVrtmG.js +0 -11
- package/dist/step-BTxPtq1r.js +0 -261
- package/dist/sum-BpiwSWvg.js +0 -11
- package/dist/tensor-BWFldCso.js +0 -8
- package/dist/tensor1d-LMGMIUlr.js +0 -11
- package/dist/tensor2d-BnXMKScO.js +0 -14
- package/dist/tensor4d-C6UCG_u8.js +0 -14
- package/dist/tfjs_backend-BGnG-ppu.js +0 -654
- package/dist/tile-CFy-xTO6.js +0 -11
- package/dist/tokeniser/BaseTokeniser.d.ts +0 -33
- package/dist/tokeniser/BaseTokeniser.js +0 -124
- package/dist/tokeniser/CharTokeniser.d.ts +0 -24
- package/dist/tokeniser/CharTokeniser.js +0 -107
- package/dist/tokeniser/bpe.d.ts +0 -28
- package/dist/tokeniser/bpe.js +0 -173
- package/dist/tokeniser/messages.d.ts +0 -61
- package/dist/tokeniser/messages.js +0 -1
- package/dist/tokeniser/type.d.ts +0 -34
- package/dist/tokeniser/type.js +0 -1
- package/dist/training/AdamW.d.ts +0 -36
- package/dist/training/AdamW.js +0 -138
- package/dist/training/BasicTrainer.d.ts +0 -63
- package/dist/training/BasicTrainer.js +0 -265
- package/dist/training/DatasetBuilder.d.ts +0 -26
- package/dist/training/DatasetBuilder.js +0 -86
- package/dist/training/Evaluator.d.ts +0 -19
- package/dist/training/Evaluator.js +0 -39
- package/dist/training/LRScheduler.d.ts +0 -12
- package/dist/training/LRScheduler.js +0 -34
- package/dist/training/PreTrainer.d.ts +0 -11
- package/dist/training/PreTrainer.js +0 -20
- package/dist/training/SFTTrainer.d.ts +0 -12
- package/dist/training/SFTTrainer.js +0 -22
- package/dist/training/loss.d.ts +0 -3
- package/dist/training/loss.js +0 -24
- package/dist/training/orthoGrad.d.ts +0 -2
- package/dist/training/orthoGrad.js +0 -10
- package/dist/training/sparseCrossEntropy.d.ts +0 -7
- package/dist/training/sparseCrossEntropy.js +0 -69
- package/dist/training/tasks/ConversationTask.d.ts +0 -18
- package/dist/training/tasks/ConversationTask.js +0 -40
- package/dist/training/tasks/PretrainingTask.d.ts +0 -17
- package/dist/training/tasks/PretrainingTask.js +0 -47
- package/dist/training/tasks/StartSentenceTask.d.ts +0 -18
- package/dist/training/tasks/StartSentenceTask.js +0 -49
- package/dist/training/tasks/Task.d.ts +0 -22
- package/dist/training/tasks/Task.js +0 -68
- package/dist/training/tasks/splitter.d.ts +0 -5
- package/dist/training/tasks/splitter.js +0 -21
- package/dist/training/types.d.ts +0 -78
- package/dist/training/types.js +0 -1
- package/dist/training/validation.d.ts +0 -17
- package/dist/training/validation.js +0 -84
- package/dist/transpose-9kRxIXWR.js +0 -36
- package/dist/unsorted_segment_sum-DJvk5xnh.js +0 -277
- package/dist/utilities/arrayClose.d.ts +0 -1
- package/dist/utilities/arrayClose.js +0 -20
- package/dist/utilities/datasetID.d.ts +0 -2
- package/dist/utilities/datasetID.js +0 -21
- package/dist/utilities/dummy.d.ts +0 -9
- package/dist/utilities/dummy.js +0 -43
- package/dist/utilities/multinomialCPU.d.ts +0 -2
- package/dist/utilities/multinomialCPU.js +0 -13
- package/dist/utilities/naming.d.ts +0 -4
- package/dist/utilities/naming.js +0 -1
- package/dist/utilities/packed.d.ts +0 -4
- package/dist/utilities/packed.js +0 -15
- package/dist/utilities/parameters.d.ts +0 -11
- package/dist/utilities/parameters.js +0 -57
- package/dist/utilities/performance.d.ts +0 -2
- package/dist/utilities/performance.js +0 -16
- package/dist/utilities/profile.d.ts +0 -17
- package/dist/utilities/profile.js +0 -38
- package/dist/utilities/safetensors.d.ts +0 -3
- package/dist/utilities/safetensors.js +0 -83
- package/dist/utilities/sentences.d.ts +0 -5
- package/dist/utilities/sentences.js +0 -41
- package/dist/utilities/tokenParse.d.ts +0 -1
- package/dist/utilities/tokenParse.js +0 -21
- package/dist/utilities/topP.d.ts +0 -1
- package/dist/utilities/topP.js +0 -13
- package/dist/utilities/waitForModel.d.ts +0 -2
- package/dist/utilities/waitForModel.js +0 -12
- package/dist/utilities/weights.d.ts +0 -12
- package/dist/utilities/weights.js +0 -45
- package/dist/utilities/yielder.d.ts +0 -1
- package/dist/utilities/yielder.js +0 -7
- package/dist/variable-Ck482e3n.js +0 -7
- package/dist/webgpu_program-B4HmApL1.js +0 -525
- package/dist/webgpu_util-DYlGSwOJ.js +0 -64
- package/dist/zeros-DvZpK8s6.js +0 -13
- package/dist/zeros_like-CWjDdwr-.js +0 -721
|
@@ -1,284 +0,0 @@
|
|
|
1
|
-
import { c as $, a6 as y, h as L } from "../../index-CUXkjxiT.js";
|
|
2
|
-
import { createReduceInfo as M } from "./utils/reductions.js";
|
|
3
|
-
import { f as w } from "../../webgpu_util-DYlGSwOJ.js";
|
|
4
|
-
import { e as _ } from "../../webgpu_program-B4HmApL1.js";
|
|
5
|
-
import { p as x, u as R } from "../../pack16-bqltoUlR.js";
|
|
6
|
-
import { isPackedTensor as h } from "../../utilities/packed.js";
|
|
7
|
-
import { reshape16 as z } from "../reshape16.js";
|
|
8
|
-
import { sum16 as N } from "../sum16.js";
|
|
9
|
-
import { slice16 as v } from "../slice16.js";
|
|
10
|
-
class Y {
|
|
11
|
-
outputShape;
|
|
12
|
-
shaderKey = "RMSNormGrad";
|
|
13
|
-
dispatchLayout;
|
|
14
|
-
dispatch;
|
|
15
|
-
workgroupSize = [64, 1, 1];
|
|
16
|
-
variableNames = ["x", "gamma", "dy"];
|
|
17
|
-
uniforms = "reduceSize : i32, batchSize: i32";
|
|
18
|
-
inputShape;
|
|
19
|
-
size = !1;
|
|
20
|
-
rowsPerWorkgroup;
|
|
21
|
-
packed = !1;
|
|
22
|
-
outputComponent;
|
|
23
|
-
constructor(e, a = 4, t = !1) {
|
|
24
|
-
if (this.packed = t, this.shaderKey = `RMSNormGrad_${a}`, this.rowsPerWorkgroup = a, this.inputShape = [e.batchSize, e.inSize], this.outputShape = [e.batchSize + e.batchSize / this.rowsPerWorkgroup, e.inSize], this.dispatchLayout = w(this.outputShape), this.dispatch = [e.batchSize / this.rowsPerWorkgroup, 1, 1], e.batchSize % this.rowsPerWorkgroup !== 0)
|
|
25
|
-
throw new Error(
|
|
26
|
-
`RMSNormGradProgram: batch size ${e.batchSize} must be divisible by rowsPerWorkgroup ${this.rowsPerWorkgroup}`
|
|
27
|
-
);
|
|
28
|
-
if (e.inSize > 1024)
|
|
29
|
-
throw new Error(`RMSNormGradProgram: inSize ${e.inSize} exceeds max of 1024`);
|
|
30
|
-
}
|
|
31
|
-
getUserCode() {
|
|
32
|
-
const e = this.workgroupSize[0], a = this.rowsPerWorkgroup, t = `
|
|
33
|
-
var<workgroup> partials : array<vec2<f32>, ${e}>;
|
|
34
|
-
var<workgroup> accumulation: array<${this.packed ? "vec2<f32>" : "f32"}, 1024>;
|
|
35
|
-
`, u = this.packed ? `
|
|
36
|
-
let X = unpack2x16float(u32(x[offset + k]));
|
|
37
|
-
let DY = unpack2x16float(u32(dy[offset + k]));
|
|
38
|
-
let G = unpack2x16float(u32(gamma[k]));
|
|
39
|
-
sum_x2 = fma(X.x, X.x, sum_x2);
|
|
40
|
-
sum_x2 = fma(X.y, X.y, sum_x2);
|
|
41
|
-
sum_dygx = fma(DY.x * G.x, X.x, sum_dygx);
|
|
42
|
-
sum_dygx = fma(DY.y * G.y, X.y, sum_dygx);
|
|
43
|
-
` : `
|
|
44
|
-
let X = f32(x[offset + k]);
|
|
45
|
-
let DY = f32(dy[offset + k]);
|
|
46
|
-
let G = f32(gamma[k]);
|
|
47
|
-
sum_x2 = fma(X, X, sum_x2);
|
|
48
|
-
sum_dygx = fma(DY * G, X, sum_dygx);
|
|
49
|
-
`, m = this.packed ? `
|
|
50
|
-
let X = unpack2x16float(u32(x[offset + k]));
|
|
51
|
-
let DY = unpack2x16float(u32(dy[offset + k]));
|
|
52
|
-
let G = unpack2x16float(u32(gamma[k]));
|
|
53
|
-
|
|
54
|
-
let dyGamma = DY * G;
|
|
55
|
-
let dx = vec2<f32>(
|
|
56
|
-
fma(dyGamma.x, invRMS, -X.x * scale),
|
|
57
|
-
fma(dyGamma.y, invRMS, -X.y * scale)
|
|
58
|
-
);
|
|
59
|
-
|
|
60
|
-
result[offset + k] = i32(pack2x16float(dx));
|
|
61
|
-
|
|
62
|
-
// dGamma
|
|
63
|
-
accumulation[k] = fma(DY, X * invRMS, accumulation[k]);
|
|
64
|
-
` : `
|
|
65
|
-
let X = f32(x[offset + k]);
|
|
66
|
-
let DY = f32(dy[offset + k]);
|
|
67
|
-
let G = f32(gamma[k]);
|
|
68
|
-
|
|
69
|
-
let dyGamma = DY * G;
|
|
70
|
-
let dx = fma(dyGamma, invRMS, -X * scale);
|
|
71
|
-
|
|
72
|
-
result[offset + k] = dx;
|
|
73
|
-
|
|
74
|
-
// dGamma
|
|
75
|
-
accumulation[k] = fma(DY, X * invRMS, accumulation[k]);
|
|
76
|
-
`, n = this.packed ? `
|
|
77
|
-
result[outDgBase + k] = i32(pack2x16float(accumulation[k]));
|
|
78
|
-
` : `
|
|
79
|
-
result[outDgBase + k] = accumulation[k];
|
|
80
|
-
`;
|
|
81
|
-
return `
|
|
82
|
-
fn DIV_CEIL(a : u32, b : u32) -> u32 {
|
|
83
|
-
return ((a - 1u) / b + 1u);
|
|
84
|
-
}
|
|
85
|
-
|
|
86
|
-
${t}
|
|
87
|
-
|
|
88
|
-
${_("index")} {
|
|
89
|
-
// One workgroup per row (batch).
|
|
90
|
-
let Length = uniforms.reduceSize;
|
|
91
|
-
let BatchSize = uniforms.batchSize;
|
|
92
|
-
for (var k = i32(localId.x); k < Length; k = k + ${e}) {
|
|
93
|
-
accumulation[k] = ${this.packed ? "vec2<f32>(0.0f)" : "0.0f"};
|
|
94
|
-
}
|
|
95
|
-
|
|
96
|
-
for (var rowOff = 0; rowOff < ${a}; rowOff = rowOff + 1) {
|
|
97
|
-
let row = i32(workgroupId.x) * ${a} + rowOff;
|
|
98
|
-
let offset = row * Length;
|
|
99
|
-
|
|
100
|
-
var sum_x2 = 0.0f;
|
|
101
|
-
var sum_dygx = 0.0f;
|
|
102
|
-
|
|
103
|
-
for (var k = i32(localId.x); k < Length; k = k + ${e}) {
|
|
104
|
-
${u}
|
|
105
|
-
}
|
|
106
|
-
|
|
107
|
-
partials[localId.x] = vec2<f32>(sum_x2, sum_dygx);
|
|
108
|
-
workgroupBarrier();
|
|
109
|
-
|
|
110
|
-
var reduceSize = min(u32(Length), ${e}u);
|
|
111
|
-
for (var currentSize = reduceSize / 2u; reduceSize > 1u; currentSize = reduceSize / 2u) {
|
|
112
|
-
let interval = DIV_CEIL(reduceSize, 2u);
|
|
113
|
-
if (localId.x < currentSize) {
|
|
114
|
-
partials[localId.x] = partials[localId.x] + partials[localId.x + interval];
|
|
115
|
-
}
|
|
116
|
-
reduceSize = interval;
|
|
117
|
-
workgroupBarrier();
|
|
118
|
-
}
|
|
119
|
-
|
|
120
|
-
let invN = 1.0f / f32(${this.packed ? "Length * 2" : "Length"});
|
|
121
|
-
let mean_x2 = fma(partials[0].x, invN, 1e-8);
|
|
122
|
-
let mean_dygx = partials[0].y * invN;
|
|
123
|
-
|
|
124
|
-
let invRMS = inverseSqrt(mean_x2);
|
|
125
|
-
let scale = (mean_dygx / (mean_x2)) * invRMS;
|
|
126
|
-
|
|
127
|
-
// write dx and dGamma.
|
|
128
|
-
for (var k = i32(localId.x); k < Length; k = k + ${e}) {
|
|
129
|
-
${m}
|
|
130
|
-
}
|
|
131
|
-
|
|
132
|
-
workgroupBarrier();
|
|
133
|
-
}
|
|
134
|
-
|
|
135
|
-
// Write out the partially accumulated dGamma
|
|
136
|
-
let outDgBase = BatchSize * Length + i32(workgroupId.x) * Length;
|
|
137
|
-
for (var k = i32(localId.x); k < Length; k = k + ${e}) {
|
|
138
|
-
${n}
|
|
139
|
-
}
|
|
140
|
-
}
|
|
141
|
-
`;
|
|
142
|
-
}
|
|
143
|
-
}
|
|
144
|
-
class I {
|
|
145
|
-
outputShape;
|
|
146
|
-
shaderKey = "RMSNormGrad";
|
|
147
|
-
dispatchLayout;
|
|
148
|
-
dispatch;
|
|
149
|
-
workgroupSize = [64, 1, 1];
|
|
150
|
-
variableNames = ["x", "dy"];
|
|
151
|
-
uniforms = "reduceSize : i32, batchSize: i32";
|
|
152
|
-
inputShape;
|
|
153
|
-
size = !1;
|
|
154
|
-
packed = !1;
|
|
155
|
-
outputComponent;
|
|
156
|
-
constructor(e, a = !1) {
|
|
157
|
-
this.packed = a, this.shaderKey = "RMSNormGrad_NoGamma", this.inputShape = [e.batchSize, e.inSize], this.outputShape = [e.batchSize, e.inSize], this.dispatchLayout = w(this.outputShape), this.dispatch = [e.batchSize, 1, 1];
|
|
158
|
-
}
|
|
159
|
-
getUserCode() {
|
|
160
|
-
const e = this.workgroupSize[0], a = `
|
|
161
|
-
var<workgroup> partials : array<vec2<f32>, ${e}>;
|
|
162
|
-
`, t = this.packed ? `
|
|
163
|
-
let X = unpack2x16float(u32(x[offset + k]));
|
|
164
|
-
let DY = unpack2x16float(u32(dy[offset + k]));
|
|
165
|
-
sum_x2 = fma(X.x, X.x, sum_x2);
|
|
166
|
-
sum_x2 = fma(X.y, X.y, sum_x2);
|
|
167
|
-
sum_dygx = fma(DY.x, X.x, sum_dygx);
|
|
168
|
-
sum_dygx = fma(DY.y, X.y, sum_dygx);
|
|
169
|
-
` : `
|
|
170
|
-
let X = f32(x[offset + k]);
|
|
171
|
-
let DY = f32(dy[offset + k]);
|
|
172
|
-
sum_x2 = fma(X, X, sum_x2);
|
|
173
|
-
sum_dygx = fma(DY, X, sum_dygx);
|
|
174
|
-
`, u = this.packed ? `
|
|
175
|
-
let X = unpack2x16float(u32(x[offset + k]));
|
|
176
|
-
let DY = unpack2x16float(u32(dy[offset + k]));
|
|
177
|
-
|
|
178
|
-
let dx = vec2<f32>(
|
|
179
|
-
fma(DY.x, invRMS, -X.x * scale),
|
|
180
|
-
fma(DY.y, invRMS, -X.y * scale)
|
|
181
|
-
);
|
|
182
|
-
|
|
183
|
-
result[offset + k] = i32(pack2x16float(dx));
|
|
184
|
-
` : `
|
|
185
|
-
let X = f32(x[offset + k]);
|
|
186
|
-
let DY = f32(dy[offset + k]);
|
|
187
|
-
|
|
188
|
-
let dx = fma(DY, invRMS, -X * scale);
|
|
189
|
-
|
|
190
|
-
result[offset + k] = dx;
|
|
191
|
-
`;
|
|
192
|
-
return `
|
|
193
|
-
fn DIV_CEIL(a : u32, b : u32) -> u32 {
|
|
194
|
-
return ((a - 1u) / b + 1u);
|
|
195
|
-
}
|
|
196
|
-
|
|
197
|
-
${a}
|
|
198
|
-
|
|
199
|
-
${_("index")} {
|
|
200
|
-
// One workgroup per row (batch).
|
|
201
|
-
let Length = uniforms.reduceSize;
|
|
202
|
-
let BatchSize = uniforms.batchSize;
|
|
203
|
-
|
|
204
|
-
let row = i32(workgroupId.x);
|
|
205
|
-
let offset = row * Length;
|
|
206
|
-
|
|
207
|
-
var sum_x2 = 0.0f;
|
|
208
|
-
var sum_dygx = 0.0f;
|
|
209
|
-
|
|
210
|
-
for (var k = i32(localId.x); k < Length; k = k + ${e}) {
|
|
211
|
-
${t}
|
|
212
|
-
}
|
|
213
|
-
|
|
214
|
-
partials[localId.x] = vec2<f32>(sum_x2, sum_dygx);
|
|
215
|
-
workgroupBarrier();
|
|
216
|
-
|
|
217
|
-
var reduceSize = min(u32(Length), ${e}u);
|
|
218
|
-
for (var currentSize = reduceSize / 2u; reduceSize > 1u; currentSize = reduceSize / 2u) {
|
|
219
|
-
let interval = DIV_CEIL(reduceSize, 2u);
|
|
220
|
-
if (localId.x < currentSize) {
|
|
221
|
-
partials[localId.x] = partials[localId.x] + partials[localId.x + interval];
|
|
222
|
-
}
|
|
223
|
-
reduceSize = interval;
|
|
224
|
-
workgroupBarrier();
|
|
225
|
-
}
|
|
226
|
-
|
|
227
|
-
let invN = 1.0f / f32(${this.packed ? "Length * 2" : "Length"});
|
|
228
|
-
let mean_x2 = fma(partials[0].x, invN, 1e-8);
|
|
229
|
-
let mean_dygx = partials[0].y * invN;
|
|
230
|
-
|
|
231
|
-
let invRMS = inverseSqrt(mean_x2);
|
|
232
|
-
let scale = (mean_dygx / (mean_x2)) * invRMS;
|
|
233
|
-
|
|
234
|
-
// write dx and dGamma.
|
|
235
|
-
for (var k = i32(localId.x); k < Length; k = k + ${e}) {
|
|
236
|
-
${u}
|
|
237
|
-
}
|
|
238
|
-
}
|
|
239
|
-
`;
|
|
240
|
-
}
|
|
241
|
-
}
|
|
242
|
-
function P(d) {
|
|
243
|
-
const { dy: e, x: a, gamma: t } = d.inputs, u = 4;
|
|
244
|
-
y(a.shape, e.shape, "Error in RMSNormGrad dy: ");
|
|
245
|
-
const m = h(a), n = t ? h(t) : !1, f = h(e), r = m || n || f, i = !r || m ? a : x(a), s = !r || n || !t ? t : x(t), c = !r || f ? e : x(e);
|
|
246
|
-
s && y(s.shape, [i.shape[i.shape.length - 1]], "Error in RMSNormGrad gamma: ");
|
|
247
|
-
const G = d.backend, o = M(s ? [i, s, c] : [i, c], -1), k = t ? new Y(o, u, r) : new I(o, r), X = [
|
|
248
|
-
{ type: "int32", data: [k.inputShape[1]] },
|
|
249
|
-
// Reduce size
|
|
250
|
-
{ type: "int32", data: [k.inputShape[0]] }
|
|
251
|
-
// Batch size
|
|
252
|
-
];
|
|
253
|
-
if (o.inSize > 1024)
|
|
254
|
-
throw new Error(`rmsNormGradGPU: inSize ${o.inSize} exceeds max of 1024`);
|
|
255
|
-
const D = G.runWebGPUProgram(
|
|
256
|
-
k,
|
|
257
|
-
s ? [i, s, c] : [i, c],
|
|
258
|
-
r ? "packedF16" : "float32",
|
|
259
|
-
X
|
|
260
|
-
);
|
|
261
|
-
r && !m && i.dispose(), r && !n && s && s.dispose(), r && !f && c.dispose();
|
|
262
|
-
const l = L().makeTensorFromTensorInfo(D);
|
|
263
|
-
if (t) {
|
|
264
|
-
const p = v(l, [0, 0], [o.batchSize, o.inSize]), S = v(
|
|
265
|
-
l,
|
|
266
|
-
[o.batchSize, 0],
|
|
267
|
-
[o.batchSize / u, o.inSize]
|
|
268
|
-
);
|
|
269
|
-
l.dispose();
|
|
270
|
-
const b = z(p, a.shape);
|
|
271
|
-
p.dispose();
|
|
272
|
-
const g = N(S, [0]);
|
|
273
|
-
return S.dispose(), [b, !r || n ? g : R(g)];
|
|
274
|
-
} else {
|
|
275
|
-
const p = z(l, a.shape);
|
|
276
|
-
return l.dispose(), [p];
|
|
277
|
-
}
|
|
278
|
-
}
|
|
279
|
-
// Kernel configuration that binds the "RMSNormGrad" kernel name to the
// WebGPU implementation `P`; it is handed to `$` immediately below.
const C = {
  kernelName: "RMSNormGrad",
  backendName: "webgpu",
  kernelFunc: P
};

$(C);
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
export {};
|
|
@@ -1,18 +0,0 @@
|
|
|
1
|
-
import { c as i } from "../../index-CUXkjxiT.js";
|
|
2
|
-
import p from "./pack16_program.js";
|
|
3
|
-
function m(e) {
|
|
4
|
-
const { x: n } = e.inputs, { scaling: a, padding: r } = e.attrs, c = e.backend;
|
|
5
|
-
if (n.shape[n.shape.length - 1] % 2 !== 0)
|
|
6
|
-
throw new Error("Last dimension of input tensor must be even to use Pack16.");
|
|
7
|
-
e.attrs && (e.attrs.originalShape = n.shape);
|
|
8
|
-
const t = new p(n.shape, r), o = a !== 1;
|
|
9
|
-
o && t.useScaling();
|
|
10
|
-
const s = [{ type: "float32", data: [a] }];
|
|
11
|
-
return c.runWebGPUProgram(t, [n], "packedF16", o ? s : void 0);
|
|
12
|
-
}
|
|
13
|
-
// Kernel configuration that binds the "Pack16" kernel name to the WebGPU
// implementation `m`; it is handed to `i` immediately below.
const k = {
  kernelName: "Pack16",
  backendName: "webgpu",
  kernelFunc: m
};

i(k);
|
|
@@ -1,19 +0,0 @@
|
|
|
1
|
-
import { WebGPUProgram } from '@tensorflow/tfjs-backend-webgpu';
|
|
2
|
-
/**
 * Type declaration for the Pack16 WebGPU program (default export of the
 * `pack16_program` module).
 */
export default class PackProgram implements WebGPUProgram {
  /** Shape of the program output. */
  outputShape: number[];
  /** Key identifying the shader variant. */
  shaderKey: string;
  /** Mapping of output dimensions onto the dispatch x axis. */
  dispatchLayout: {
    x: number[];
  };
  /** Workgroup counts for the dispatch. */
  dispatch: [number, number, number];
  /** Threads per workgroup. */
  workgroupSize: [number, number, number];
  /** Names of the shader input variables. */
  variableNames: string[];
  /** Extra uniform declarations, when any are needed. */
  uniforms?: string;
  /** Whether the program uses the `size` uniform. */
  size: boolean;
  /** Number of components written per output element. */
  outputComponent: number;
  /** Whether the scaled shader variant is enabled. */
  scaling: boolean;
  /** Padding granularity; declared as optional on the constructor. */
  padding: number;

  /**
   * @param outShape Shape the program is built for.
   * @param padding Optional padding granularity.
   */
  constructor(outShape: number[], padding?: number);

  /** Switches the program to its scaled variant. */
  useScaling(): void;

  /** Returns the shader source for the current configuration. */
  getUserCode(): string;
}
|
|
@@ -1,92 +0,0 @@
|
|
|
1
|
-
import { f as o, c as a } from "../../webgpu_util-DYlGSwOJ.js";
|
|
2
|
-
import { e as s } from "../../webgpu_program-B4HmApL1.js";
|
|
3
|
-
class h {
|
|
4
|
-
outputShape;
|
|
5
|
-
shaderKey = "Pack16";
|
|
6
|
-
dispatchLayout;
|
|
7
|
-
dispatch;
|
|
8
|
-
workgroupSize = [64, 1, 1];
|
|
9
|
-
variableNames = ["x"];
|
|
10
|
-
uniforms;
|
|
11
|
-
size = !0;
|
|
12
|
-
outputComponent = 4;
|
|
13
|
-
scaling = !1;
|
|
14
|
-
padding = 0;
|
|
15
|
-
constructor(t, e = 0) {
|
|
16
|
-
if (t[t.length - 1] % 2 !== 0 && e === 0)
|
|
17
|
-
throw new Error("Last dimension of output shape must be even to use Pack16.");
|
|
18
|
-
if (e % 4 !== 0)
|
|
19
|
-
throw new Error("Padding must be a multiple of 4 to use Pack16.");
|
|
20
|
-
if (this.outputShape = [...t.slice(0, -1), t[t.length - 1]], e > 0) {
|
|
21
|
-
this.shaderKey += `_Padded${e}`, this.padding = e;
|
|
22
|
-
for (let i = this.outputShape.length - 2; i < this.outputShape.length; i++)
|
|
23
|
-
this.outputShape[i] % this.padding !== 0 && (this.outputShape[i] += this.padding - this.outputShape[i] % this.padding);
|
|
24
|
-
this.outputComponent = 1;
|
|
25
|
-
}
|
|
26
|
-
this.outputShape[this.outputShape.length - 1] /= 2, this.outputShape[this.outputShape.length - 1] % this.outputComponent !== 0 && (this.outputComponent = 1), this.dispatchLayout = o(this.outputShape), this.dispatch = a(this.dispatchLayout, this.outputShape, this.workgroupSize, [
|
|
27
|
-
this.outputComponent,
|
|
28
|
-
1,
|
|
29
|
-
1
|
|
30
|
-
]);
|
|
31
|
-
}
|
|
32
|
-
useScaling() {
|
|
33
|
-
this.shaderKey += "_Scaled", this.uniforms = "scaling : f32,", this.scaling = !0;
|
|
34
|
-
}
|
|
35
|
-
getUserCode() {
|
|
36
|
-
if (this.padding > 0 && this.outputComponent === 1) {
|
|
37
|
-
const t = this.outputShape.length;
|
|
38
|
-
return `
|
|
39
|
-
${s("index")} {
|
|
40
|
-
if (index < uniforms.size) {
|
|
41
|
-
var coords = getCoordsFromIndex(index);
|
|
42
|
-
coords[${t} - 1] = coords[${t} - 1] * 2;
|
|
43
|
-
let row = coords[${t} - 2];
|
|
44
|
-
let col = coords[${t} - 1];
|
|
45
|
-
let width = uniforms.xShape[${t} - 1];
|
|
46
|
-
let height = uniforms.xShape[${t} - 2];
|
|
47
|
-
|
|
48
|
-
var value1 = 0.0f;
|
|
49
|
-
if (col < width && row < height) {
|
|
50
|
-
let baseInputIndex = getIndexFromCoords${t}D(coords, uniforms.xShape);
|
|
51
|
-
value1 = x[baseInputIndex] ${this.scaling ? "* uniforms.scaling" : ""};
|
|
52
|
-
}
|
|
53
|
-
var value2 = 0.0f;
|
|
54
|
-
if (col + 1 < width && row < height) {
|
|
55
|
-
coords[${t} - 1] = coords[${t} - 1] + 1;
|
|
56
|
-
let baseInputIndex = getIndexFromCoords${t}D(coords, uniforms.xShape);
|
|
57
|
-
value2 = x[baseInputIndex] ${this.scaling ? "* uniforms.scaling" : ""};
|
|
58
|
-
}
|
|
59
|
-
let packed = i32(pack2x16float(vec2<f32>(value1, value2)));
|
|
60
|
-
result[index] = packed;
|
|
61
|
-
}
|
|
62
|
-
}`;
|
|
63
|
-
}
|
|
64
|
-
return this.outputComponent === 1 ? `
|
|
65
|
-
${s("index")} {
|
|
66
|
-
if (index < uniforms.size) {
|
|
67
|
-
let baseInputIndex = index * 2;
|
|
68
|
-
let x1 = x[baseInputIndex] ${this.scaling ? "* uniforms.scaling" : ""};
|
|
69
|
-
let x2 = x[baseInputIndex + 1] ${this.scaling ? "* uniforms.scaling" : ""};
|
|
70
|
-
let packed = i32(pack2x16float(vec2<f32>(x1, x2)));
|
|
71
|
-
result[index] = packed;
|
|
72
|
-
}
|
|
73
|
-
}` : `
|
|
74
|
-
${s("index")} {
|
|
75
|
-
if (index < uniforms.size) {
|
|
76
|
-
let baseInputIndex = index * 2;
|
|
77
|
-
let x1 = x[baseInputIndex] ${this.scaling ? "* uniforms.scaling" : ""};
|
|
78
|
-
let x2 = x[baseInputIndex + 1] ${this.scaling ? "* uniforms.scaling" : ""};
|
|
79
|
-
let packed = vec4<i32>(
|
|
80
|
-
i32(pack2x16float(vec2<f32>(x1.x, x1.y))),
|
|
81
|
-
i32(pack2x16float(vec2<f32>(x1.z, x1.w))),
|
|
82
|
-
i32(pack2x16float(vec2<f32>(x2.x, x2.y))),
|
|
83
|
-
i32(pack2x16float(vec2<f32>(x2.z, x2.w)))
|
|
84
|
-
);
|
|
85
|
-
result[index] = packed;
|
|
86
|
-
}
|
|
87
|
-
}`;
|
|
88
|
-
}
|
|
89
|
-
}
|
|
90
|
-
export {
|
|
91
|
-
h as default
|
|
92
|
-
};
|
package/dist/ops/webgpu/qkv.d.ts
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
export {};
|
package/dist/ops/webgpu/qkv.js
DELETED
|
@@ -1,24 +0,0 @@
|
|
|
1
|
-
import { c as h, a6 as l } from "../../index-CUXkjxiT.js";
|
|
2
|
-
import { b as f } from "../../matMul16-BcVC_E62.js";
|
|
3
|
-
import { slice16 as a } from "../slice16.js";
|
|
4
|
-
import { isPackedTensor as u } from "../../utilities/packed.js";
|
|
5
|
-
/**
 * WebGPU kernel function registered under the kernel name "QKV".
 *
 * Multiplies `x` by `kernel` in one call (forcing the output shape to
 * [batch, time, 3 * heads, channels / heads] with permutation [0, 2, 1, 3]),
 * then cuts the result into three tensors of `heads` heads each.
 *
 * @param i Kernel arguments: `inputs` ({ x, kernel }) and `attrs.heads`.
 * @returns Three tensors of shape [batch, heads, time, channels / heads],
 *   taken at head offsets 0, heads and 2 * heads (presumably Q, K and V).
 * @throws {Error} When the channel dimension is not divisible by `heads`.
 */
function k(i) {
  const { x: n, kernel: c } = i.inputs;
  const { heads: e } = i.attrs;
  const [r, t, s] = n.shape;
  const packedInput = u(n);

  // For a packed input the kernel is expected to have twice as many rows
  // as x has channels.
  l(c.shape, [packedInput ? s * 2 : s, 3 * s], "Error in QKV: ");
  if (s % e !== 0) {
    throw new Error(`Channel dimension ${s} must be divisible by number of heads ${e} in QKV.`);
  }

  const headSize = s / e;
  const o = f(n, c, false, false, {
    forceOutputShape: [r, t, 3 * e, headSize],
    perm: [0, 2, 1, 3]
  });

  const m = [0, e, 2 * e].map((headOffset) =>
    a(o, [0, headOffset, 0, 0], [r, e, t, headSize])
  );

  // The combined result is only needed for slicing.
  o.dispose();
  return m;
}
|
|
19
|
-
// Kernel configuration that binds the "QKV" kernel name to the WebGPU
// implementation `k`; it is handed to `h` immediately below.
const b = {
  kernelName: "QKV",
  backendName: "webgpu",
  kernelFunc: k
};

h(b);
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
export {};
|
package/dist/ops/webgpu/rope.js
DELETED
|
@@ -1,135 +0,0 @@
|
|
|
1
|
-
import { isPackedTensor as y } from "../../utilities/packed.js";
|
|
2
|
-
import { e as c } from "../../webgpu_program-B4HmApL1.js";
|
|
3
|
-
import { f as x, c as l } from "../../webgpu_util-DYlGSwOJ.js";
|
|
4
|
-
import { c as w, a6 as b } from "../../index-CUXkjxiT.js";
|
|
5
|
-
/**
 * WebGPU program applying rotary position embedding to an unpacked tensor.
 *
 * Each invocation handles one output element at coords [b, h, t, d]:
 *   out = x[d] * cos + partner * sin
 * where `partner` is -x[d + 1] for even d and x[d - 1] for odd d, and
 * cos/sin are read at (t + pastLen) * cosShape[1] + d / 2.
 */
class v {
  variableNames = ["x", "sin", "cos"];
  outputShape;
  shaderKey = "Rope";
  dispatchLayout;
  dispatch;
  workgroupSize = [64, 1, 1];
  size = true;
  uniforms = "pastLen: i32";

  /**
   * @param e First output dimension (batch, per the shader's coords comment).
   * @param o Second output dimension (heads).
   * @param a Third output dimension (time).
   * @param t Fourth output dimension (feature size); also part of the shader key.
   */
  constructor(e, o, a, t) {
    this.shaderKey = `Rope_${t}`;
    this.outputShape = [e, o, a, t];
    this.dispatchLayout = x(this.outputShape);
    this.dispatch = l(this.dispatchLayout, this.outputShape, this.workgroupSize);
  }

  /** @returns WGSL source; the rotary dimension is baked in from outputShape[3]. */
  getUserCode() {
    const rotaryDim = this.outputShape[3];
    return `
      ${c("index")} {
        if (index < uniforms.size) {
          let coords = getCoordsFromIndex(index); // [b, h, t, d]
          let b = coords[0];
          let h = coords[1];
          let t = coords[2];
          let d = coords[3];

          let rotaryDim = ${rotaryDim};

          var outVal = 0.0;

          let xIdx = b * uniforms.outShapeStrides[0] +
                     h * uniforms.outShapeStrides[1] +
                     t * uniforms.outShapeStrides[2] +
                     d;

          if (d < rotaryDim) {
            let idx = (t + uniforms.pastLen) * uniforms.cosShape[1] + d / 2;
            let cos = cos[idx];
            let sin = sin[idx];

            let ownX = x[xIdx] * cos;
            var evenOdd = 0.0;

            if (d % 2 == 0) {
              // even index
              evenOdd = -x[xIdx + 1];
            } else {
              // odd index
              evenOdd = x[xIdx - 1];
            }

            outVal = fma(evenOdd, sin, ownX);
          } else {
            // pass through for non-rotary dims
            outVal = x[xIdx];
          }

          setOutputAtIndex(index, outVal);
        }
      }
    `;
  }
}
|
|
65
|
-
/**
 * WebGPU program applying rotary position embedding to a packed tensor.
 *
 * Every element of `x` is unpacked with `unpack2x16float` into a pair
 * (x0, x1); the output pair is (x0 * cos - x1 * sin, x1 * cos + x0 * sin),
 * re-packed with `pack2x16float`. cos/sin are read at
 * (t + pastLen) * cosShape[1] + d.
 *
 * NOTE(review): this program uses the same `Rope_<dim>` shader key pattern
 * as the unpacked program. Presumably the backend's cache key also covers
 * the input dtypes; confirm.
 */
class k {
  variableNames = ["x", "sin", "cos"];
  outputShape;
  shaderKey = "Rope";
  dispatchLayout;
  dispatch;
  workgroupSize = [64, 1, 1];
  size = true;
  uniforms = "pastLen: i32";

  /**
   * @param e First output dimension (batch, per the shader's coords comment).
   * @param o Second output dimension (heads).
   * @param a Third output dimension (time).
   * @param t Feature size; the output's last dimension is `t / 2`.
   */
  constructor(e, o, a, t) {
    this.shaderKey = `Rope_${t}`;
    this.outputShape = [e, o, a, t / 2];
    this.dispatchLayout = x(this.outputShape);
    this.dispatch = l(this.dispatchLayout, this.outputShape, this.workgroupSize);
  }

  /** @returns WGSL source for the packed rotation. */
  getUserCode() {
    const mainHeader = c("index");
    return `
      ${mainHeader} {
        if (index < uniforms.size) {
          let coords = getCoordsFromIndex(index); // [b, h, t, d]
          let b = coords[0];
          let h = coords[1];
          let t = coords[2];
          let d = coords[3];

          var outVal = vec2<f32>(0.0, 0.0);

          let xIdx = b * uniforms.outShapeStrides[0] +
                     h * uniforms.outShapeStrides[1] +
                     t * uniforms.outShapeStrides[2] +
                     d;

          let idx = (t + uniforms.pastLen) * uniforms.cosShape[1] + d;
          let cos = cos[idx];
          let sin = sin[idx];

          let xPair = unpack2x16float(u32(x[xIdx]));
          let ownX = vec2<f32>(xPair.x * cos, xPair.y * cos);

          let evenOdd = vec2<f32>(
            -xPair.y,
            xPair.x
          );

          outVal = vec2<f32>(
            fma(evenOdd.x, sin, ownX.x),
            fma(evenOdd.y, sin, ownX.y)
          );

          result[index] = i32(pack2x16float(outVal));
        }
      }
    `;
  }
}
|
|
117
|
-
function L(i) {
|
|
118
|
-
const { x: e } = i.inputs, { pastLen: o, negSin: a, ropeCache: t } = i.attrs, m = i.backend, d = y(e), p = e.shape[0], h = e.shape[1], r = e.shape[2], n = d ? e.shape[3] * 2 : e.shape[3], s = a ? t.getNegSin() : t.getSin(), u = t.getCos();
|
|
119
|
-
if (b(s.shape, u.shape, "Error in Rope: "), s.shape[0] < r + o)
|
|
120
|
-
throw new Error(
|
|
121
|
-
`Sin tensor shape ${s.shape} is not compatible with seqLength ${r} and pastLen ${o}.`
|
|
122
|
-
);
|
|
123
|
-
if (s.shape[1] * 2 < n)
|
|
124
|
-
throw new Error(`Sin tensor shape ${s.shape} is not compatible with feature dimension ${n}.`);
|
|
125
|
-
if (s.shape.length !== 3)
|
|
126
|
-
throw new Error(`Sin tensor must be 3-dimensional, but got shape ${s.shape}.`);
|
|
127
|
-
const f = d ? new k(p, h, r, n) : new v(p, h, r, n), S = [{ type: "int32", data: [o] }], g = d ? "packedF16" : e.dtype;
|
|
128
|
-
return m.runWebGPUProgram(f, [e, s, u], g, S);
|
|
129
|
-
}
|
|
130
|
-
// Kernel configuration that binds the "Rope" kernel name to the WebGPU
// implementation `L`; it is handed to `w` immediately below.
const P = {
  kernelName: "Rope",
  backendName: "webgpu",
  kernelFunc: L
};

w(P);
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
export {};
|
|
@@ -1,40 +0,0 @@
|
|
|
1
|
-
import { e as p } from "../../webgpu_program-B4HmApL1.js";
|
|
2
|
-
import { f as u, c as d } from "../../webgpu_util-DYlGSwOJ.js";
|
|
3
|
-
import { c as h, a6 as o } from "../../index-CUXkjxiT.js";
|
|
4
|
-
/**
 * WebGPU program computing, for every output element at [batch, depth]:
 *   (softmaxProbs - (labels[batch] == depth ? 1 : 0)) * dy[batch]
 */
class b {
  variableNames = ["labels", "softmaxProbs", "dy"];
  outputShape;
  shaderKey = "ScatterSub";
  dispatchLayout;
  dispatch;
  workgroupSize = [64, 1, 1];
  size = true;

  /**
   * @param t First output dimension (batch, per the shader's coords comment).
   * @param e Second output dimension (depth).
   */
  constructor(t, e) {
    this.outputShape = [t, e];
    this.dispatchLayout = u(this.outputShape);
    this.dispatch = d(this.dispatchLayout, this.outputShape, this.workgroupSize);
  }

  /** @returns WGSL source; one invocation per output element. */
  getUserCode() {
    const mainHeader = p("index");
    return `
      ${mainHeader} {
        if (index < uniforms.size) {
          let coords = getCoordsFromIndex(index); // [batch, depth]
          let idx = i32(labels[coords[0]]);
          let prob = softmaxProbs[index];
          let dy = dy[coords[0]];
          setOutputAtIndex(index, select(prob, prob - 1.0, idx == coords[1]) * dy);
        }
      }
    `;
  }
}
|
|
29
|
-
/**
 * WebGPU kernel function registered under the kernel name
 * "EfficientScatterSub".
 *
 * Checks that `dy` and `labels` are both of shape [batch] and that `logits`
 * is [batch, depth], then dispatches the ScatterSub program with the inputs
 * ordered as (labels, logits, dy) and a "float32" output.
 *
 * @param a Kernel arguments: `inputs` ({ logits, labels, dy }) and `backend`.
 * @returns Whatever `backend.runWebGPUProgram` returns for the dispatch.
 */
function f(a) {
  const { logits: t, labels: e, dy: s } = a.inputs;
  // batch comes from the labels, depth from the logits' second dimension.
  const r = e.shape[0];
  const i = t.shape[1];

  o(s.shape, [r], "Error in EfficientScatterSub dy: ");
  o(t.shape, [r, i], "Error in EfficientScatterSub logits: ");
  o(e.shape, [r], "Error in EfficientScatterSub labels: ");

  const program = new b(r, i);
  return a.backend.runWebGPUProgram(program, [e, t, s], "float32");
}
|
|
35
|
-
// Kernel configuration that binds the "EfficientScatterSub" kernel name to
// the WebGPU implementation `f`; it is handed to `h` immediately below.
const l = {
  kernelName: "EfficientScatterSub",
  backendName: "webgpu",
  kernelFunc: f
};

h(l);
|
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
import { WebGPUBackend } from '@tensorflow/tfjs-backend-webgpu';
|
|
2
|
-
import { SliceAttrs, SliceInputs, TensorInfo } from '@tensorflow/tfjs-core';
|
|
3
|
-
/**
 * Type declaration for a slice kernel function on the WebGPU backend.
 *
 * @param args Kernel arguments: the slice `inputs`, the WebGPU `backend`
 *   and the slice `attrs`.
 * @returns Tensor info describing the result.
 */
export declare function slice(args: {
  inputs: SliceInputs;
  backend: WebGPUBackend;
  attrs: SliceAttrs;
}): TensorInfo;
|