@genai-fi/nanogpt 0.19.0 → 0.20.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +9 -10
- package/dist/Generator.d.ts +0 -82
- package/dist/Generator.js +0 -11941
- package/dist/RealDiv-CGwv0liw.js +0 -365
- package/dist/Reshape-BW__R4mZ.js +0 -79
- package/dist/Reshape-CPBkTIH2.js +0 -14
- package/dist/TeachableLLM.d.ts +0 -70
- package/dist/TeachableLLM.js +0 -273
- package/dist/Trainer.d.ts +0 -43
- package/dist/Trainer.js +0 -244
- package/dist/_commonjsHelpers-ByX85dGu.js +0 -33
- package/dist/axis_util-GTVlo58H.js +0 -55
- package/dist/backend.d.ts +0 -2
- package/dist/backend.js +0 -13
- package/dist/backend_util-GaFarB78.js +0 -425
- package/dist/backend_webgpu-BqASlsbV.js +0 -545
- package/dist/binary_op_util-pKXltfxI.js +0 -192
- package/dist/broadcast_to-eS93CCN_.js +0 -28
- package/dist/checks/appendCache.d.ts +0 -1
- package/dist/checks/appendCache.js +0 -22
- package/dist/checks/attentionMask.d.ts +0 -1
- package/dist/checks/attentionMask.js +0 -37
- package/dist/checks/check.d.ts +0 -9
- package/dist/checks/check.js +0 -20
- package/dist/checks/gelu.d.ts +0 -1
- package/dist/checks/gelu.js +0 -18
- package/dist/checks/index.d.ts +0 -26
- package/dist/checks/index.js +0 -28
- package/dist/checks/matMulGelu.d.ts +0 -1
- package/dist/checks/matMulGelu.js +0 -28
- package/dist/checks/normRMS.d.ts +0 -1
- package/dist/checks/normRMS.js +0 -16
- package/dist/checks/normRMSGrad.d.ts +0 -1
- package/dist/checks/normRMSGrad.js +0 -12
- package/dist/checks/packUnpack.d.ts +0 -1
- package/dist/checks/packUnpack.js +0 -18
- package/dist/checks/qkv.d.ts +0 -1
- package/dist/checks/qkv.js +0 -34
- package/dist/checks/rope.d.ts +0 -1
- package/dist/checks/rope.js +0 -36
- package/dist/checks/weights.d.ts +0 -14
- package/dist/checks/weights.js +0 -31
- package/dist/clip_by_value-DDA7rrcT.js +0 -12
- package/dist/complex-DI35Q-gW.js +0 -11
- package/dist/complex_util-Yc1A_gV1.js +0 -55
- package/dist/concat-CAQpCret.js +0 -17
- package/dist/concat_util-D18dJ4fD.js +0 -22
- package/dist/data/docx.d.ts +0 -2
- package/dist/data/docx.js +0 -15
- package/dist/data/parquet.d.ts +0 -2
- package/dist/data/parquet.js +0 -17
- package/dist/data/pdf.d.ts +0 -2
- package/dist/data/pdf.js +0 -14
- package/dist/data/textLoader.d.ts +0 -7
- package/dist/data/textLoader.js +0 -108
- package/dist/dataset-CGGp1z9P.js +0 -1124
- package/dist/dropout_util--NxWuYg2.js +0 -27
- package/dist/expand_dims-Bkd1YD5x.js +0 -11
- package/dist/exports_initializers-CYzKLjN7.js +0 -7
- package/dist/floor-BQtb-Azg.js +0 -9
- package/dist/gather-qIqEqaGn.js +0 -9
- package/dist/gelu-B220X1Go.js +0 -26
- package/dist/gpgpu_math-BwvV12df.js +0 -2022
- package/dist/index-CUXkjxiT.js +0 -3516
- package/dist/index-CieiGp4Y.js +0 -349
- package/dist/index-CjOWnMXP.js +0 -7308
- package/dist/index-Cp39cXWe.js +0 -1016
- package/dist/index-D5v913EJ.js +0 -4
- package/dist/index-DmeWGGmS.js +0 -1074
- package/dist/index-DvYrXKkX.js +0 -113
- package/dist/index-Ksja3su6.js +0 -151
- package/dist/index-xuotMAFm.js +0 -118
- package/dist/inference/types.d.ts +0 -16
- package/dist/inference/types.js +0 -1
- package/dist/jszip.min-BZhlzntC.js +0 -2313
- package/dist/kernel_funcs_utils-pq0CK9co.js +0 -306
- package/dist/layers/BaseLayer.d.ts +0 -44
- package/dist/layers/BaseLayer.js +0 -74
- package/dist/layers/CausalSelfAttention.d.ts +0 -39
- package/dist/layers/CausalSelfAttention.js +0 -86
- package/dist/layers/LoRA.d.ts +0 -14
- package/dist/layers/LoRA.js +0 -58
- package/dist/layers/MLP.d.ts +0 -17
- package/dist/layers/MLP.js +0 -44
- package/dist/layers/PositionEmbedding.d.ts +0 -8
- package/dist/layers/PositionEmbedding.js +0 -31
- package/dist/layers/RMSNorm.d.ts +0 -12
- package/dist/layers/RMSNorm.js +0 -22
- package/dist/layers/RoPECache.d.ts +0 -18
- package/dist/layers/RoPECache.js +0 -50
- package/dist/layers/TiedEmbedding.d.ts +0 -13
- package/dist/layers/TiedEmbedding.js +0 -36
- package/dist/layers/TransformerBlock.d.ts +0 -27
- package/dist/layers/TransformerBlock.js +0 -40
- package/dist/layers/WeightStore.d.ts +0 -20
- package/dist/layers/WeightStore.js +0 -76
- package/dist/loader/load.d.ts +0 -6
- package/dist/loader/load.js +0 -68
- package/dist/loader/loadHF.d.ts +0 -8
- package/dist/loader/loadHF.js +0 -22
- package/dist/loader/loadTransformers.d.ts +0 -4
- package/dist/loader/loadTransformers.js +0 -44
- package/dist/loader/loadZipMeta.d.ts +0 -3
- package/dist/loader/loadZipMeta.js +0 -16
- package/dist/loader/newZipLoad.d.ts +0 -3
- package/dist/loader/newZipLoad.js +0 -31
- package/dist/loader/oldZipLoad.d.ts +0 -9
- package/dist/loader/oldZipLoad.js +0 -80
- package/dist/loader/save.d.ts +0 -16
- package/dist/loader/save.js +0 -90
- package/dist/loader/types.d.ts +0 -67
- package/dist/loader/types.js +0 -1
- package/dist/main.d.ts +0 -50
- package/dist/main.js +0 -109
- package/dist/matMul16-BcVC_E62.js +0 -80
- package/dist/matMulGelu-JNLZqKQp.js +0 -163
- package/dist/mat_mul-DhG0Newp.js +0 -11
- package/dist/mod-CSdCpRjf.js +0 -11
- package/dist/models/NanoGPTV1.d.ts +0 -16
- package/dist/models/NanoGPTV1.js +0 -99
- package/dist/models/NanoGPTV2.d.ts +0 -16
- package/dist/models/NanoGPTV2.js +0 -90
- package/dist/models/config.d.ts +0 -27
- package/dist/models/config.js +0 -50
- package/dist/models/factory.d.ts +0 -3
- package/dist/models/factory.js +0 -16
- package/dist/models/model.d.ts +0 -44
- package/dist/models/model.js +0 -134
- package/dist/non_max_suppression_impl-B2W7YjZB.js +0 -102
- package/dist/not_equal-hurPF26l.js +0 -64
- package/dist/ones-BytntneX.js +0 -14
- package/dist/ops/adamAdjust.d.ts +0 -2
- package/dist/ops/adamAdjust.js +0 -9
- package/dist/ops/adamMoments.d.ts +0 -2
- package/dist/ops/adamMoments.js +0 -9
- package/dist/ops/add16.d.ts +0 -2
- package/dist/ops/add16.js +0 -9
- package/dist/ops/appendCache.d.ts +0 -2
- package/dist/ops/appendCache.js +0 -22
- package/dist/ops/attentionMask.d.ts +0 -2
- package/dist/ops/attentionMask.js +0 -10
- package/dist/ops/concat16.d.ts +0 -2
- package/dist/ops/concat16.js +0 -9
- package/dist/ops/cpu/adamAdjust.d.ts +0 -1
- package/dist/ops/cpu/adamAdjust.js +0 -18
- package/dist/ops/cpu/adamMoments.d.ts +0 -1
- package/dist/ops/cpu/adamMoments.js +0 -16
- package/dist/ops/cpu/appendCache.d.ts +0 -1
- package/dist/ops/cpu/appendCache.js +0 -23
- package/dist/ops/cpu/attentionMask.d.ts +0 -1
- package/dist/ops/cpu/attentionMask.js +0 -22
- package/dist/ops/cpu/fusedSoftmax.d.ts +0 -9
- package/dist/ops/cpu/fusedSoftmax.js +0 -29
- package/dist/ops/cpu/gatherSub.d.ts +0 -1
- package/dist/ops/cpu/gatherSub.js +0 -18
- package/dist/ops/cpu/gelu.d.ts +0 -1
- package/dist/ops/cpu/gelu.js +0 -40
- package/dist/ops/cpu/matMul16.d.ts +0 -1
- package/dist/ops/cpu/matMul16.js +0 -15
- package/dist/ops/cpu/matMulGelu.d.ts +0 -1
- package/dist/ops/cpu/matMulGelu.js +0 -53
- package/dist/ops/cpu/matMulMul.d.ts +0 -1
- package/dist/ops/cpu/matMulMul.js +0 -23
- package/dist/ops/cpu/mulDropout.d.ts +0 -1
- package/dist/ops/cpu/mulDropout.js +0 -23
- package/dist/ops/cpu/normRMS.d.ts +0 -1
- package/dist/ops/cpu/normRMS.js +0 -39
- package/dist/ops/cpu/qkv.d.ts +0 -5
- package/dist/ops/cpu/qkv.js +0 -41
- package/dist/ops/cpu/rope.d.ts +0 -6
- package/dist/ops/cpu/rope.js +0 -38
- package/dist/ops/cpu/scatterSub.d.ts +0 -1
- package/dist/ops/cpu/scatterSub.js +0 -23
- package/dist/ops/dot16.d.ts +0 -2
- package/dist/ops/dot16.js +0 -42
- package/dist/ops/dropout.d.ts +0 -2
- package/dist/ops/dropout.js +0 -14
- package/dist/ops/dropout16.d.ts +0 -2
- package/dist/ops/dropout16.js +0 -25
- package/dist/ops/gatherSub.d.ts +0 -2
- package/dist/ops/gatherSub.js +0 -9
- package/dist/ops/gelu.d.ts +0 -3
- package/dist/ops/gelu.js +0 -8
- package/dist/ops/globalNorm.d.ts +0 -2
- package/dist/ops/globalNorm.js +0 -13
- package/dist/ops/grads/add16.d.ts +0 -1
- package/dist/ops/grads/add16.js +0 -26
- package/dist/ops/grads/attentionMask.d.ts +0 -1
- package/dist/ops/grads/attentionMask.js +0 -21
- package/dist/ops/grads/dropout16.d.ts +0 -1
- package/dist/ops/grads/dropout16.js +0 -2
- package/dist/ops/grads/gelu.d.ts +0 -2
- package/dist/ops/grads/gelu.js +0 -5
- package/dist/ops/grads/matMul16.d.ts +0 -2
- package/dist/ops/grads/matMul16.js +0 -9
- package/dist/ops/grads/matMulGelu.d.ts +0 -1
- package/dist/ops/grads/matMulGelu.js +0 -17
- package/dist/ops/grads/mul16.d.ts +0 -1
- package/dist/ops/grads/mul16.js +0 -4
- package/dist/ops/grads/normRMS.d.ts +0 -3
- package/dist/ops/grads/normRMS.js +0 -33
- package/dist/ops/grads/pack16.d.ts +0 -2
- package/dist/ops/grads/pack16.js +0 -6
- package/dist/ops/grads/qkv.d.ts +0 -3
- package/dist/ops/grads/qkv.js +0 -34
- package/dist/ops/grads/rope.d.ts +0 -2
- package/dist/ops/grads/rope.js +0 -5
- package/dist/ops/grads/softmax16.d.ts +0 -2
- package/dist/ops/grads/softmax16.js +0 -25
- package/dist/ops/grads/unpack16.d.ts +0 -2
- package/dist/ops/grads/unpack16.js +0 -5
- package/dist/ops/grads/utils.d.ts +0 -4
- package/dist/ops/grads/utils.js +0 -14
- package/dist/ops/log.d.ts +0 -0
- package/dist/ops/log.js +0 -1
- package/dist/ops/matMul16.d.ts +0 -15
- package/dist/ops/matMul16.js +0 -13
- package/dist/ops/matMulGelu.d.ts +0 -3
- package/dist/ops/matMulGelu.js +0 -14
- package/dist/ops/matMulMul.d.ts +0 -2
- package/dist/ops/matMulMul.js +0 -9
- package/dist/ops/mul16.d.ts +0 -2
- package/dist/ops/mul16.js +0 -39
- package/dist/ops/mulDrop.d.ts +0 -2
- package/dist/ops/mulDrop.js +0 -9
- package/dist/ops/normRMS.d.ts +0 -2
- package/dist/ops/normRMS.js +0 -19
- package/dist/ops/pack16.d.ts +0 -2
- package/dist/ops/pack16.js +0 -5
- package/dist/ops/qkv.d.ts +0 -2
- package/dist/ops/qkv.js +0 -10
- package/dist/ops/reshape16.d.ts +0 -2
- package/dist/ops/reshape16.js +0 -41
- package/dist/ops/rope.d.ts +0 -3
- package/dist/ops/rope.js +0 -7
- package/dist/ops/scatterSub.d.ts +0 -2
- package/dist/ops/scatterSub.js +0 -9
- package/dist/ops/slice16.d.ts +0 -2
- package/dist/ops/slice16.js +0 -9
- package/dist/ops/softmax16.d.ts +0 -2
- package/dist/ops/softmax16.js +0 -9
- package/dist/ops/sub16.d.ts +0 -2
- package/dist/ops/sub16.js +0 -8
- package/dist/ops/sum16.d.ts +0 -2
- package/dist/ops/sum16.js +0 -13
- package/dist/ops/transpose16.d.ts +0 -3
- package/dist/ops/transpose16.js +0 -40
- package/dist/ops/unpack16.d.ts +0 -2
- package/dist/ops/unpack16.js +0 -6
- package/dist/ops/webgl/adamAdjust.d.ts +0 -1
- package/dist/ops/webgl/adamAdjust.js +0 -49
- package/dist/ops/webgl/adamMoments.d.ts +0 -1
- package/dist/ops/webgl/adamMoments.js +0 -40
- package/dist/ops/webgl/appendCache.d.ts +0 -1
- package/dist/ops/webgl/appendCache.js +0 -44
- package/dist/ops/webgl/attentionMask.d.ts +0 -1
- package/dist/ops/webgl/attentionMask.js +0 -45
- package/dist/ops/webgl/dropout16.d.ts +0 -1
- package/dist/ops/webgl/dropout16.js +0 -11
- package/dist/ops/webgl/fusedSoftmax.d.ts +0 -11
- package/dist/ops/webgl/fusedSoftmax.js +0 -80
- package/dist/ops/webgl/gatherSub.d.ts +0 -1
- package/dist/ops/webgl/gatherSub.js +0 -27
- package/dist/ops/webgl/gelu.d.ts +0 -2
- package/dist/ops/webgl/gelu.js +0 -50
- package/dist/ops/webgl/log.d.ts +0 -17
- package/dist/ops/webgl/log.js +0 -23
- package/dist/ops/webgl/matMul16.d.ts +0 -1
- package/dist/ops/webgl/matMul16.js +0 -45
- package/dist/ops/webgl/matMulGelu.d.ts +0 -21
- package/dist/ops/webgl/matMulGelu.js +0 -9
- package/dist/ops/webgl/matMulMul.d.ts +0 -14
- package/dist/ops/webgl/matMulMul.js +0 -28
- package/dist/ops/webgl/mulDropout.d.ts +0 -1
- package/dist/ops/webgl/mulDropout.js +0 -41
- package/dist/ops/webgl/normRMS.d.ts +0 -1
- package/dist/ops/webgl/normRMS.js +0 -93
- package/dist/ops/webgl/qkv.d.ts +0 -1
- package/dist/ops/webgl/qkv.js +0 -46
- package/dist/ops/webgl/rope.d.ts +0 -1
- package/dist/ops/webgl/rope.js +0 -56
- package/dist/ops/webgl/scatterSub.d.ts +0 -1
- package/dist/ops/webgl/scatterSub.js +0 -27
- package/dist/ops/webgpu/adamAdjust.d.ts +0 -1
- package/dist/ops/webgpu/adamAdjust.js +0 -57
- package/dist/ops/webgpu/adamMoments.d.ts +0 -1
- package/dist/ops/webgpu/adamMoments.js +0 -60
- package/dist/ops/webgpu/add16.d.ts +0 -1
- package/dist/ops/webgpu/add16.js +0 -13
- package/dist/ops/webgpu/appendCache.d.ts +0 -1
- package/dist/ops/webgpu/appendCache.js +0 -105
- package/dist/ops/webgpu/attentionMask.d.ts +0 -1
- package/dist/ops/webgpu/attentionMask.js +0 -26
- package/dist/ops/webgpu/attentionMask32_program.d.ts +0 -19
- package/dist/ops/webgpu/attentionMask32_program.js +0 -54
- package/dist/ops/webgpu/clipScale.d.ts +0 -1
- package/dist/ops/webgpu/clipScale.js +0 -58
- package/dist/ops/webgpu/concat16.d.ts +0 -19
- package/dist/ops/webgpu/concat16.js +0 -126
- package/dist/ops/webgpu/dropout16.d.ts +0 -1
- package/dist/ops/webgpu/dropout16.js +0 -51
- package/dist/ops/webgpu/gatherSub.d.ts +0 -1
- package/dist/ops/webgpu/gatherSub.js +0 -39
- package/dist/ops/webgpu/gelu.d.ts +0 -14
- package/dist/ops/webgpu/gelu.js +0 -141
- package/dist/ops/webgpu/index.d.ts +0 -0
- package/dist/ops/webgpu/index.js +0 -26
- package/dist/ops/webgpu/matMul16.d.ts +0 -1
- package/dist/ops/webgpu/matMul16.js +0 -65
- package/dist/ops/webgpu/matMul16_program.d.ts +0 -42
- package/dist/ops/webgpu/matMul16_program.js +0 -343
- package/dist/ops/webgpu/mul16.d.ts +0 -1
- package/dist/ops/webgpu/mul16.js +0 -13
- package/dist/ops/webgpu/norm2.d.ts +0 -1
- package/dist/ops/webgpu/norm2.js +0 -76
- package/dist/ops/webgpu/normRMS.d.ts +0 -1
- package/dist/ops/webgpu/normRMS.js +0 -34
- package/dist/ops/webgpu/normRMS16_program.d.ts +0 -10
- package/dist/ops/webgpu/normRMS16_program.js +0 -25
- package/dist/ops/webgpu/normRMS32_program.d.ts +0 -10
- package/dist/ops/webgpu/normRMS32_program.js +0 -25
- package/dist/ops/webgpu/normRMSGrad.d.ts +0 -1
- package/dist/ops/webgpu/normRMSGrad.js +0 -284
- package/dist/ops/webgpu/pack16.d.ts +0 -1
- package/dist/ops/webgpu/pack16.js +0 -18
- package/dist/ops/webgpu/pack16_program.d.ts +0 -19
- package/dist/ops/webgpu/pack16_program.js +0 -92
- package/dist/ops/webgpu/qkv.d.ts +0 -1
- package/dist/ops/webgpu/qkv.js +0 -24
- package/dist/ops/webgpu/rope.d.ts +0 -1
- package/dist/ops/webgpu/rope.js +0 -135
- package/dist/ops/webgpu/scatterSub.d.ts +0 -1
- package/dist/ops/webgpu/scatterSub.js +0 -40
- package/dist/ops/webgpu/slice16.d.ts +0 -7
- package/dist/ops/webgpu/slice16.js +0 -69
- package/dist/ops/webgpu/softmax16.d.ts +0 -17
- package/dist/ops/webgpu/softmax16.js +0 -21
- package/dist/ops/webgpu/softmax16_program.d.ts +0 -13
- package/dist/ops/webgpu/softmax16_program.js +0 -73
- package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +0 -17
- package/dist/ops/webgpu/softmax16_subgroup_program.js +0 -75
- package/dist/ops/webgpu/softmax16grad.d.ts +0 -1
- package/dist/ops/webgpu/softmax16grad.js +0 -37
- package/dist/ops/webgpu/sub16.d.ts +0 -1
- package/dist/ops/webgpu/sub16.js +0 -13
- package/dist/ops/webgpu/sum16.d.ts +0 -1
- package/dist/ops/webgpu/sum16.js +0 -38
- package/dist/ops/webgpu/transpose16.d.ts +0 -1
- package/dist/ops/webgpu/transpose16.js +0 -34
- package/dist/ops/webgpu/transpose16_program.d.ts +0 -16
- package/dist/ops/webgpu/transpose16_program.js +0 -50
- package/dist/ops/webgpu/transpose16_shared_program.d.ts +0 -15
- package/dist/ops/webgpu/transpose16_shared_program.js +0 -70
- package/dist/ops/webgpu/unpack16.d.ts +0 -1
- package/dist/ops/webgpu/unpack16.js +0 -48
- package/dist/ops/webgpu/utils/binary_op.d.ts +0 -35
- package/dist/ops/webgpu/utils/binary_op.js +0 -139
- package/dist/ops/webgpu/utils/deviceInfo.d.ts +0 -7
- package/dist/ops/webgpu/utils/deviceInfo.js +0 -11
- package/dist/ops/webgpu/utils/reductions.d.ts +0 -43
- package/dist/ops/webgpu/utils/reductions.js +0 -275
- package/dist/ops-CsXeTq1P.js +0 -476
- package/dist/pack16-bqltoUlR.js +0 -39
- package/dist/papaparse.min-C0cScC2i.js +0 -418
- package/dist/parquet-Bqjmp2vo.js +0 -44231
- package/dist/patches/webgpu_backend.d.ts +0 -18
- package/dist/patches/webgpu_backend.js +0 -56
- package/dist/patches/webgpu_base.d.ts +0 -21
- package/dist/patches/webgpu_base.js +0 -34
- package/dist/patches/webgpu_program.d.ts +0 -36
- package/dist/patches/webgpu_program.js +0 -400
- package/dist/pdf-NIhmP3sq.js +0 -19477
- package/dist/rand_util-CZ7yLoUm.js +0 -50
- package/dist/random_normal-IBRrha8a.js +0 -14
- package/dist/random_width-DN5ZtQkM.js +0 -9796
- package/dist/range-C-CjF-LI.js +0 -10
- package/dist/relu-J_X6MUzx.js +0 -9
- package/dist/reshape-BDOuCSNW.js +0 -9
- package/dist/resize_nearest_neighbor-BojqlfRe.js +0 -150
- package/dist/rope-DcrZM_e6.js +0 -24
- package/dist/scatter_nd_util-ByNJaL6I.js +0 -46
- package/dist/segment_util-Dasb2Zaf.js +0 -43
- package/dist/selu_util-BLhIqRkw.js +0 -44
- package/dist/shared-3agzAqQ_.js +0 -53
- package/dist/shared-CagdqkLh.js +0 -2143
- package/dist/slice-BzS11Qh0.js +0 -12
- package/dist/slice_util-CC35pLmT.js +0 -153
- package/dist/softmax-D4q1LJN7.js +0 -12
- package/dist/split-C2Sj255c.js +0 -9
- package/dist/squeeze-ho4wLUek.js +0 -10
- package/dist/stack-DudVrtmG.js +0 -11
- package/dist/step-BTxPtq1r.js +0 -261
- package/dist/sum-BpiwSWvg.js +0 -11
- package/dist/tensor-BWFldCso.js +0 -8
- package/dist/tensor1d-LMGMIUlr.js +0 -11
- package/dist/tensor2d-BnXMKScO.js +0 -14
- package/dist/tensor4d-C6UCG_u8.js +0 -14
- package/dist/tfjs_backend-BGnG-ppu.js +0 -654
- package/dist/tile-CFy-xTO6.js +0 -11
- package/dist/tokeniser/BaseTokeniser.d.ts +0 -33
- package/dist/tokeniser/BaseTokeniser.js +0 -124
- package/dist/tokeniser/CharTokeniser.d.ts +0 -24
- package/dist/tokeniser/CharTokeniser.js +0 -107
- package/dist/tokeniser/bpe.d.ts +0 -28
- package/dist/tokeniser/bpe.js +0 -173
- package/dist/tokeniser/messages.d.ts +0 -61
- package/dist/tokeniser/messages.js +0 -1
- package/dist/tokeniser/type.d.ts +0 -34
- package/dist/tokeniser/type.js +0 -1
- package/dist/training/AdamW.d.ts +0 -36
- package/dist/training/AdamW.js +0 -138
- package/dist/training/BasicTrainer.d.ts +0 -63
- package/dist/training/BasicTrainer.js +0 -265
- package/dist/training/DatasetBuilder.d.ts +0 -26
- package/dist/training/DatasetBuilder.js +0 -86
- package/dist/training/Evaluator.d.ts +0 -19
- package/dist/training/Evaluator.js +0 -39
- package/dist/training/LRScheduler.d.ts +0 -12
- package/dist/training/LRScheduler.js +0 -34
- package/dist/training/PreTrainer.d.ts +0 -11
- package/dist/training/PreTrainer.js +0 -20
- package/dist/training/SFTTrainer.d.ts +0 -12
- package/dist/training/SFTTrainer.js +0 -22
- package/dist/training/loss.d.ts +0 -3
- package/dist/training/loss.js +0 -24
- package/dist/training/orthoGrad.d.ts +0 -2
- package/dist/training/orthoGrad.js +0 -10
- package/dist/training/sparseCrossEntropy.d.ts +0 -7
- package/dist/training/sparseCrossEntropy.js +0 -69
- package/dist/training/tasks/ConversationTask.d.ts +0 -18
- package/dist/training/tasks/ConversationTask.js +0 -40
- package/dist/training/tasks/PretrainingTask.d.ts +0 -17
- package/dist/training/tasks/PretrainingTask.js +0 -47
- package/dist/training/tasks/StartSentenceTask.d.ts +0 -18
- package/dist/training/tasks/StartSentenceTask.js +0 -49
- package/dist/training/tasks/Task.d.ts +0 -22
- package/dist/training/tasks/Task.js +0 -68
- package/dist/training/tasks/splitter.d.ts +0 -5
- package/dist/training/tasks/splitter.js +0 -21
- package/dist/training/types.d.ts +0 -78
- package/dist/training/types.js +0 -1
- package/dist/training/validation.d.ts +0 -17
- package/dist/training/validation.js +0 -84
- package/dist/transpose-9kRxIXWR.js +0 -36
- package/dist/unsorted_segment_sum-DJvk5xnh.js +0 -277
- package/dist/utilities/arrayClose.d.ts +0 -1
- package/dist/utilities/arrayClose.js +0 -20
- package/dist/utilities/datasetID.d.ts +0 -2
- package/dist/utilities/datasetID.js +0 -21
- package/dist/utilities/dummy.d.ts +0 -9
- package/dist/utilities/dummy.js +0 -43
- package/dist/utilities/multinomialCPU.d.ts +0 -2
- package/dist/utilities/multinomialCPU.js +0 -13
- package/dist/utilities/naming.d.ts +0 -4
- package/dist/utilities/naming.js +0 -1
- package/dist/utilities/packed.d.ts +0 -4
- package/dist/utilities/packed.js +0 -15
- package/dist/utilities/parameters.d.ts +0 -11
- package/dist/utilities/parameters.js +0 -57
- package/dist/utilities/performance.d.ts +0 -2
- package/dist/utilities/performance.js +0 -16
- package/dist/utilities/profile.d.ts +0 -17
- package/dist/utilities/profile.js +0 -38
- package/dist/utilities/safetensors.d.ts +0 -3
- package/dist/utilities/safetensors.js +0 -83
- package/dist/utilities/sentences.d.ts +0 -5
- package/dist/utilities/sentences.js +0 -41
- package/dist/utilities/tokenParse.d.ts +0 -1
- package/dist/utilities/tokenParse.js +0 -21
- package/dist/utilities/topP.d.ts +0 -1
- package/dist/utilities/topP.js +0 -13
- package/dist/utilities/waitForModel.d.ts +0 -2
- package/dist/utilities/waitForModel.js +0 -12
- package/dist/utilities/weights.d.ts +0 -12
- package/dist/utilities/weights.js +0 -45
- package/dist/utilities/yielder.d.ts +0 -1
- package/dist/utilities/yielder.js +0 -7
- package/dist/variable-Ck482e3n.js +0 -7
- package/dist/webgpu_program-B4HmApL1.js +0 -525
- package/dist/webgpu_util-DYlGSwOJ.js +0 -64
- package/dist/zeros-DvZpK8s6.js +0 -13
- package/dist/zeros_like-CWjDdwr-.js +0 -721
|
@@ -1,84 +0,0 @@
|
|
|
1
|
-
import "../index-CUXkjxiT.js";
|
|
2
|
-
import "../random_width-DN5ZtQkM.js";
|
|
3
|
-
import "../zeros_like-CWjDdwr-.js";
|
|
4
|
-
import "../Generator.js";
|
|
5
|
-
import "../index-Cp39cXWe.js";
|
|
6
|
-
import "../dataset-CGGp1z9P.js";
|
|
7
|
-
import "../ops/cpu/attentionMask.js";
|
|
8
|
-
import "../ops/webgl/attentionMask.js";
|
|
9
|
-
import "../ops/grads/attentionMask.js";
|
|
10
|
-
import "../ops/cpu/rope.js";
|
|
11
|
-
import "../ops/webgl/rope.js";
|
|
12
|
-
import "../rope-DcrZM_e6.js";
|
|
13
|
-
import "../ops/cpu/appendCache.js";
|
|
14
|
-
import "../ops/webgl/appendCache.js";
|
|
15
|
-
import "../ops/grads/softmax16.js";
|
|
16
|
-
import "../matMul16-BcVC_E62.js";
|
|
17
|
-
import "../ops/webgl/matMul16.js";
|
|
18
|
-
import "../ops/cpu/matMul16.js";
|
|
19
|
-
import "../pack16-bqltoUlR.js";
|
|
20
|
-
import "../ops/transpose16.js";
|
|
21
|
-
import "../ops/reshape16.js";
|
|
22
|
-
import "../ops/cpu/qkv.js";
|
|
23
|
-
import "../ops/webgl/qkv.js";
|
|
24
|
-
import "../ops/grads/qkv.js";
|
|
25
|
-
import "../ops/cpu/normRMS.js";
|
|
26
|
-
import "../ops/webgl/normRMS.js";
|
|
27
|
-
import "../ops/grads/normRMS.js";
|
|
28
|
-
import "../ops/dropout16.js";
|
|
29
|
-
import "../ops/webgl/dropout16.js";
|
|
30
|
-
import "../ops/grads/add16.js";
|
|
31
|
-
import "../jszip.min-BZhlzntC.js";
|
|
32
|
-
import "../index-DvYrXKkX.js";
|
|
33
|
-
import "../ops/cpu/adamAdjust.js";
|
|
34
|
-
import "../ops/webgl/adamAdjust.js";
|
|
35
|
-
import "../ops/cpu/adamMoments.js";
|
|
36
|
-
import "../ops/webgl/adamMoments.js";
|
|
37
|
-
import "../ops/cpu/gatherSub.js";
|
|
38
|
-
import "../ops/webgl/gatherSub.js";
|
|
39
|
-
import "../ops/cpu/scatterSub.js";
|
|
40
|
-
import "../ops/webgl/scatterSub.js";
|
|
41
|
-
import { PAGE_FACTOR as a, shuffle as k } from "./DatasetBuilder.js";
|
|
42
|
-
import "../papaparse.min-C0cScC2i.js";
|
|
43
|
-
import { tokensFromTasks as y } from "./tasks/Task.js";
|
|
44
|
-
import "../ops/cpu/matMulGelu.js";
|
|
45
|
-
import "../matMulGelu-JNLZqKQp.js";
|
|
46
|
-
import "../ops/grads/matMulGelu.js";
|
|
47
|
-
import "../ops/cpu/gelu.js";
|
|
48
|
-
import "../ops/webgl/gelu.js";
|
|
49
|
-
import "../gelu-B220X1Go.js";
|
|
50
|
-
import "../ops/webgl/log.js";
|
|
51
|
-
import "../checks/normRMS.js";
|
|
52
|
-
import "../checks/normRMSGrad.js";
|
|
53
|
-
async function Mt(n, x, o, l, c = 0.1, d) {
|
|
54
|
-
const r = n instanceof Uint16Array ? n : await y(n, x, void 0, d), i = r instanceof Uint16Array ? r : r.tokens, f = r instanceof Uint16Array ? void 0 : r.mask, m = /* @__PURE__ */ new Set();
|
|
55
|
-
if (c > 0) {
|
|
56
|
-
const t = Math.floor(i.length / (o.blockSize * a)), s = Math.max(1, Math.floor(t * c));
|
|
57
|
-
for (; m.size < s; ) {
|
|
58
|
-
const M = Math.floor(Math.random() * t);
|
|
59
|
-
m.add(M);
|
|
60
|
-
}
|
|
61
|
-
}
|
|
62
|
-
const p = new Uint32Array(
|
|
63
|
-
i.length - m.size * o.blockSize * a
|
|
64
|
-
), e = new Uint32Array(m.size * o.blockSize * a);
|
|
65
|
-
let h = 0, g = 0;
|
|
66
|
-
for (let t = 0; t < i.length; t++) {
|
|
67
|
-
const s = Math.floor(t / (o.blockSize * a));
|
|
68
|
-
m.has(s) ? g < e.length && (e[g++] = t) : h < p.length && (p[h++] = t);
|
|
69
|
-
}
|
|
70
|
-
const { dataset: v, state: w } = await o.createTextDataset(
|
|
71
|
-
i,
|
|
72
|
-
l,
|
|
73
|
-
k(p),
|
|
74
|
-
f || void 0
|
|
75
|
-
), { dataset: z, state: A } = await o.createTextDataset(
|
|
76
|
-
i,
|
|
77
|
-
l,
|
|
78
|
-
k(e)
|
|
79
|
-
);
|
|
80
|
-
return { trainDataset: v, validationDataset: z, size: i.length, validationState: A, trainState: w };
|
|
81
|
-
}
|
|
82
|
-
export {
|
|
83
|
-
Mt as createTrainValidationSplit
|
|
84
|
-
};
|
|
@@ -1,36 +0,0 @@
|
|
|
1
|
-
import { o as u, q as i, E as o, ao as $, ap as g, aq as x, x as l, t as m, ar as p } from "./index-CUXkjxiT.js";
|
|
2
|
-
import { c as k } from "./complex-DI35Q-gW.js";
|
|
3
|
-
function K(r) {
|
|
4
|
-
const e = { input: i(r, "input", "imag") };
|
|
5
|
-
return o.runKernel($, e);
|
|
6
|
-
}
|
|
7
|
-
const h = /* @__PURE__ */ u({ imag_: K });
|
|
8
|
-
function E(r) {
|
|
9
|
-
const e = { x: i(r, "x", "neg") };
|
|
10
|
-
return o.runKernel(g, e);
|
|
11
|
-
}
|
|
12
|
-
const _ = /* @__PURE__ */ u({ neg_: E });
|
|
13
|
-
function b(r) {
|
|
14
|
-
const e = { input: i(r, "input", "real") };
|
|
15
|
-
return o.runKernel(x, e);
|
|
16
|
-
}
|
|
17
|
-
const d = /* @__PURE__ */ u({ real_: b });
|
|
18
|
-
function N(r, t, e) {
|
|
19
|
-
const n = i(r, "x", "transpose");
|
|
20
|
-
if (t == null && (t = n.shape.map((s, a) => a).reverse()), l(n.rank === t.length, () => `Error in transpose: rank of input ${n.rank} must match length of perm ${t}.`), t.forEach((s) => {
|
|
21
|
-
l(s >= 0 && s < n.rank, () => `All entries in 'perm' must be between 0 and ${n.rank - 1} but got ${t}`);
|
|
22
|
-
}), n.rank <= 1)
|
|
23
|
-
return n.clone();
|
|
24
|
-
const f = { x: n }, c = { perm: t };
|
|
25
|
-
return n.dtype === "complex64" ? m(() => {
|
|
26
|
-
let s = d(n), a = h(n);
|
|
27
|
-
return s = o.runKernel(p, { x: s }, c), a = o.runKernel(p, { x: a }, c), e && (a = _(a)), k(s, a);
|
|
28
|
-
}) : o.runKernel(p, f, c);
|
|
29
|
-
}
|
|
30
|
-
const v = /* @__PURE__ */ u({ transpose_: N });
|
|
31
|
-
export {
|
|
32
|
-
h as i,
|
|
33
|
-
_ as n,
|
|
34
|
-
d as r,
|
|
35
|
-
v as t
|
|
36
|
-
};
|
|
@@ -1,277 +0,0 @@
|
|
|
1
|
-
import { o as h, q as c, E as d, bm as T, bn as q, bo as H, x as l, bp as P, L as _, bq as y, br as B, bs as I, bt as W, bu as A, bv as L, bw as G, bx as O, by as z, bz as F, B as M, _ as j, bA as J, bB as U, bC as V, a2 as Q, a1 as N, m as X, bD as Y, bE as Z, bF as R, bG as nn, bH as tn, bI as sn, bJ as en, bK as rn, bL as on, bM as an, bN as un, aE as cn, bO as ln } from "./index-CUXkjxiT.js";
|
|
2
|
-
import { k as C, c as g, m as D } from "./step-BTxPtq1r.js";
|
|
3
|
-
import { r as b } from "./reshape-BDOuCSNW.js";
|
|
4
|
-
import { m as pn, b as hn, e as w } from "./not_equal-hurPF26l.js";
|
|
5
|
-
import { s as K } from "./sum-BpiwSWvg.js";
|
|
6
|
-
function fn(s, n = null, t = !1) {
|
|
7
|
-
const i = { x: c(s, "x", "all", "bool") }, o = { axis: n, keepDims: t };
|
|
8
|
-
return d.runKernel(T, i, o);
|
|
9
|
-
}
|
|
10
|
-
const nt = /* @__PURE__ */ h({ all_: fn });
|
|
11
|
-
function dn(s, n = null, t = !1) {
|
|
12
|
-
const i = { x: c(s, "x", "any", "bool") }, o = { axis: n, keepDims: t };
|
|
13
|
-
return d.runKernel(q, i, o);
|
|
14
|
-
}
|
|
15
|
-
const tt = /* @__PURE__ */ h({ any_: dn });
|
|
16
|
-
function mn(s, n = 0) {
|
|
17
|
-
const e = { x: c(s, "x", "argMax") }, i = { axis: n };
|
|
18
|
-
return d.runKernel(H, e, i);
|
|
19
|
-
}
|
|
20
|
-
const st = /* @__PURE__ */ h({ argMax_: mn });
|
|
21
|
-
function $n(s, n, t, e, i) {
|
|
22
|
-
const o = c(s, "x", "avgPool", "float32"), p = 1;
|
|
23
|
-
l(C(t, p), () => `Error in avgPool: Either strides or dilations must be 1. Got strides ${t} and dilations '${p}'`);
|
|
24
|
-
let r = o, a = !1;
|
|
25
|
-
o.rank === 3 && (a = !0, r = b(o, [1, o.shape[0], o.shape[1], o.shape[2]])), l(r.rank === 4, () => `Error in avgPool: x must be rank 4 but got rank ${r.rank}.`), g("avgPool", e, i);
|
|
26
|
-
const u = { x: r }, m = { filterSize: n, strides: t, pad: e, dimRoundingMode: i };
|
|
27
|
-
let f = d.runKernel(P, u, m);
|
|
28
|
-
return f = _(f, o.dtype), a ? b(f, [f.shape[1], f.shape[2], f.shape[3]]) : f;
|
|
29
|
-
}
|
|
30
|
-
const et = /* @__PURE__ */ h({ avgPool_: $n });
|
|
31
|
-
function bn(s) {
|
|
32
|
-
const t = { x: c(s, "x", "tanh", "float32") };
|
|
33
|
-
return d.runKernel(y, t);
|
|
34
|
-
}
|
|
35
|
-
const rt = /* @__PURE__ */ h({ tanh_: bn });
|
|
36
|
-
function xn(s, n, t) {
|
|
37
|
-
const e = c(s, "x", "batchToSpaceND"), i = n.reduce((r, a) => r * a);
|
|
38
|
-
l(e.rank >= 1 + n.length, () => `input rank is ${e.rank} but should be > than blockShape.length ${n.length}`), l(t.length === n.length, () => `crops.length is ${t.length} but should be equal to blockShape.length ${n.length}`), l(e.shape[0] % i === 0, () => `input tensor batch is ${e.shape[0]} but is not divisible by the product of the elements of blockShape ${n.join(" * ")} === ${i}`);
|
|
39
|
-
const o = { x: e }, p = { blockShape: n, crops: t };
|
|
40
|
-
return d.runKernel(B, o, p);
|
|
41
|
-
}
|
|
42
|
-
const ot = /* @__PURE__ */ h({ batchToSpaceND_: xn });
|
|
43
|
-
function kn(s) {
|
|
44
|
-
let n;
|
|
45
|
-
return s.rank === 0 || s.rank === 1 ? n = b(s, [1, 1, 1, s.size]) : s.rank === 2 ? n = b(s, [1, 1, s.shape[0], s.shape[1]]) : s.rank === 3 ? n = b(s, [1, s.shape[0], s.shape[1], s.shape[2]]) : n = s, n;
|
|
46
|
-
}
|
|
47
|
-
function vn(s, n, t, e, i, o) {
|
|
48
|
-
o == null && (o = 1e-3);
|
|
49
|
-
const p = c(s, "x", "batchNorm"), r = c(n, "mean", "batchNorm"), a = c(t, "variance", "batchNorm");
|
|
50
|
-
let u;
|
|
51
|
-
i != null && (u = c(i, "scale", "batchNorm"));
|
|
52
|
-
let m;
|
|
53
|
-
e != null && (m = c(e, "offset", "batchNorm")), l(r.rank === a.rank, () => "Batch normalization gradient requires mean and variance to have equal ranks."), l(m == null || r.rank === m.rank, () => "Batch normalization gradient requires mean and offset to have equal ranks."), l(u == null || r.rank === u.rank, () => "Batch normalization gradient requires mean and scale to have equal ranks.");
|
|
54
|
-
const x = {
|
|
55
|
-
x: kn(p),
|
|
56
|
-
scale: u,
|
|
57
|
-
offset: m,
|
|
58
|
-
mean: r,
|
|
59
|
-
variance: a
|
|
60
|
-
}, k = { varianceEpsilon: o }, $ = d.runKernel(I, x, k);
|
|
61
|
-
return b($, p.shape);
|
|
62
|
-
}
|
|
63
|
-
const at = /* @__PURE__ */ h({ batchNorm_: vn });
|
|
64
|
-
/**
 * Body of the `conv2d_` op (registered as `S` below).
 *
 * A rank-3 input is reshaped to rank 4 with a leading dimension of 1 and the
 * result is reshaped back to rank 3 before returning.
 *
 * @param s input, converted as float32 "x"
 * @param n filter, converted as float32 "filter"
 * @param t strides
 * @param e pad
 * @param i data format; channel depth is read from shape[3] for "NHWC", else shape[1]
 * @param o dilations
 * @param p dimRoundingMode
 */
function gn(s, n, t, e, i = "NHWC", o = [1, 1], p) {
  const input = c(s, "x", "conv2d", "float32");
  const filter = c(n, "filter", "conv2d", "float32");
  const wasRank3 = input.rank === 3;
  const batched = wasRank3 ? b(input, [1, input.shape[0], input.shape[1], input.shape[2]]) : input;
  l(batched.rank === 4, () => `Error in conv2d: input must be rank 4, but got rank ${batched.rank}.`);
  l(filter.rank === 4, () => `Error in conv2d: filter must be rank 4, but got rank ${filter.rank}.`);
  g("conv2d", e, p);
  const inDepth = i === "NHWC" ? batched.shape[3] : batched.shape[1];
  l(inDepth === filter.shape[2], () => `Error in conv2d: depth of input (${inDepth}) must match input depth for filter ${filter.shape[2]}.`);
  l(C(t, o), () => `Error in conv2D: Either strides or dilations must be 1. Got strides ${t} and dilations '${o}'`);
  l(D(o), () => "Error in conv2D: Dilated rates should be larger than 0.");
  l(D(t), () => "Error in conv2D: Strides should be larger than 0.");
  const inputs = { x: batched, filter: filter };
  const attrs = { strides: t, pad: e, dataFormat: i, dilations: o, dimRoundingMode: p };
  const result = d.runKernel(W, inputs, attrs);
  if (!wasRank3) {
    return result;
  }
  // Drop the leading dimension that was added for the rank-3 input.
  return b(result, [result.shape[1], result.shape[2], result.shape[3]]);
}
const S = /* @__PURE__ */ h({ conv2d_: gn });
|
|
74
|
-
/**
 * Body of the `conv1d_` op: validates the arguments, then delegates to the
 * 2D convolution `S` by inserting a unit dimension into filter and input.
 *
 * A rank-2 input is treated as a single example: it gains a leading
 * dimension of 1 and the result is returned as rank 2.
 *
 * @param s input ("x")
 * @param n filter
 * @param t stride
 * @param e pad
 * @param i data format; only "NWC" passes the assertion
 * @param o dilation
 * @param p dimRoundingMode
 */
function Dn(s, n, t, e, i = "NWC", o = 1, p) {
  const input = c(s, "x", "conv1d");
  const filter = c(n, "filter", "conv1d");
  const wasRank2 = input.rank === 2;
  const batched = wasRank2 ? b(input, [1, input.shape[0], input.shape[1]]) : input;
  l(batched.rank === 3, () => `Error in conv1d: input must be rank 3, but got rank ${batched.rank}.`);
  l(filter.rank === 3, () => `Error in conv1d: filter must be rank 3, but got rank ${filter.rank}.`);
  g("conv1d", e, p);
  l(batched.shape[2] === filter.shape[1], () => `Error in conv1d: depth of input (${batched.shape[2]}) must match input depth for filter ${filter.shape[1]}.`);
  l(C(t, o), () => `Error in conv1D: Either stride or dilation must be 1. Got stride ${t} and dilation '${o}'`);
  l(D(o), () => "Error in conv1D: Dilated rates should be larger than 0.");
  l(D(t), () => "Error in conv1D: Stride should be larger than 0.");
  l(i === "NWC", () => `Error in conv1d: got dataFormat of ${i} but only NWC is currently supported.`);
  // Filter is reshaped before the input, as in the original.
  const filter4d = b(filter, [1, filter.shape[0], filter.shape[1], filter.shape[2]]);
  const input4d = b(batched, [batched.shape[0], 1, batched.shape[1], batched.shape[2]]);
  const conv = S(input4d, filter4d, [1, t], e, "NHWC", [1, o], p);
  if (wasRank2) {
    return b(conv, [conv.shape[2], conv.shape[3]]);
  }
  return b(conv, [conv.shape[0], conv.shape[2], conv.shape[3]]);
}
const it = /* @__PURE__ */ h({ conv1d_: Dn });
|
|
82
|
-
/**
 * Body of the `conv2DBackpropInput_` op (registered as `En` below).
 *
 * A rank-3 `dy` is reshaped to rank 4 with a leading dimension of 1 (and the
 * input shape is prefixed with 1 to match); the kernel result is then
 * reshaped back to rank 3.
 *
 * @param s input shape (array); its length must equal the rank of `n`
 * @param n dy
 * @param t filter
 * @param e strides
 * @param i pad
 * @param o data format; depths are read from index 3 for "NHWC", else index 1
 * @param p dimRoundingMode
 */
function Cn(s, n, t, e, i, o = "NHWC", p) {
  l(s.length === n.rank, () => `Length of inShape (${s.length}) and rank of dy (${n.rank}) must match`);
  const wasRank3 = n.rank === 3;
  let inShape = s;
  let dy = n;
  if (wasRank3) {
    dy = b(n, [1, n.shape[0], n.shape[1], n.shape[2]]);
    inShape = [1, s[0], s[1], s[2]];
  }
  l(inShape.length === 4, () => `Error in conv2dDerInput: inShape must be length 4, but got length ${inShape.length}.`);
  l(dy.rank === 4, () => `Error in conv2dDerInput: dy must be rank 4, but got rank ${dy.rank}`);
  l(t.rank === 4, () => `Error in conv2dDerInput: filter must be rank 4, but got rank ${t.rank}`);
  const channelsLast = o === "NHWC";
  const inDepth = channelsLast ? inShape[3] : inShape[1];
  const outDepth = channelsLast ? dy.shape[3] : dy.shape[1];
  l(inDepth === t.shape[2], () => `Error in conv2dDerInput: depth of input (${inDepth}) must match input depth for filter ${t.shape[2]}.`);
  l(outDepth === t.shape[3], () => `Error in conv2dDerInput: depth of output (${outDepth}) must match output depth for filter ${t.shape[3]}.`);
  g("conv2dDerInput", i, p);
  const inputs = { dy: dy, filter: t };
  const attrs = { strides: e, pad: i, dataFormat: o, dimRoundingMode: p, inputShape: inShape };
  const result = d.runKernel(A, inputs, attrs);
  if (!wasRank3) {
    return result;
  }
  return b(result, [result.shape[1], result.shape[2], result.shape[3]]);
}
const En = /* @__PURE__ */ h({ conv2DBackpropInput_: Cn });
|
|
92
|
-
/**
 * Body of the `conv2dTranspose_` op: converts input and filter, then calls
 * `En` (the conv2DBackpropInput op) with the data format fixed to "NHWC".
 *
 * @param s input ("x"), forwarded to `En` as `dy`
 * @param n filter
 * @param t output shape, forwarded to `En` as the input shape
 * @param e strides
 * @param i pad
 * @param o dimRoundingMode
 */
function Nn(s, n, t, e, i, o) {
  const input = c(s, "x", "conv2dTranspose");
  const filter = c(n, "filter", "conv2dTranspose");
  return En(t, input, filter, e, i, "NHWC", o);
}
const ut = /* @__PURE__ */ h({ conv2dTranspose_: Nn });
|
|
97
|
-
/**
 * Body of the `cos_` op: converts `s` as float32 "x" and runs kernel `L`.
 */
function _n(s) {
  const x = c(s, "x", "cos", "float32");
  return d.runKernel(L, { x: x });
}
const ct = /* @__PURE__ */ h({ cos_: _n });
|
|
102
|
-
/**
 * Body of the `cosh_` op: converts `s` as float32 "x" and runs kernel `G`.
 */
function wn(s) {
  const x = c(s, "x", "cosh", "float32");
  return d.runKernel(G, { x: x });
}
const lt = /* @__PURE__ */ h({ cosh_: wn });
|
|
107
|
-
/**
 * Body of the `cumprod_` op: runs kernel `O` on the converted input.
 *
 * @param s input ("x")
 * @param n axis (default 0)
 * @param t exclusive flag (default false)
 * @param e reverse flag (default false)
 */
function Kn(s, n = 0, t = !1, e = !1) {
  const x = c(s, "x", "cumprod");
  const attrs = { axis: n, exclusive: t, reverse: e };
  return d.runKernel(O, { x: x }, attrs);
}
const pt = /* @__PURE__ */ h({ cumprod_: Kn });
|
|
112
|
-
/**
 * Body of the `cumsum_` op: runs kernel `z` on the converted input.
 *
 * @param s input ("x")
 * @param n axis (default 0)
 * @param t exclusive flag (default false)
 * @param e reverse flag (default false)
 */
function Sn(s, n = 0, t = !1, e = !1) {
  const x = c(s, "x", "cumsum");
  const attrs = { axis: n, exclusive: t, reverse: e };
  return d.runKernel(z, { x: x }, attrs);
}
const ht = /* @__PURE__ */ h({ cumsum_: Sn });
|
|
117
|
-
/**
 * Body of the `depthwiseConv2d_` op (registered as `qn` below).
 *
 * A rank-3 input is reshaped to rank 4 with a leading dimension of 1 and the
 * result is reshaped back to rank 3 before returning.
 *
 * @param s input, converted as float32 "x"
 * @param n filter, converted as float32 "filter"
 * @param t strides
 * @param e pad
 * @param i data format; channels are read from shape[3] for "NHWC", else shape[1]
 * @param o dilations
 * @param p dimRoundingMode
 */
function Tn(s, n, t, e, i = "NHWC", o = [1, 1], p) {
  const input = c(s, "x", "depthwiseConv2d", "float32");
  const filter = c(n, "filter", "depthwiseConv2d", "float32");
  const wasRank3 = input.rank === 3;
  const batched = wasRank3 ? b(input, [1, input.shape[0], input.shape[1], input.shape[2]]) : input;
  l(batched.rank === 4, () => `Error in depthwiseConv2d: input must be rank 4, but got rank ${batched.rank}.`);
  l(filter.rank === 4, () => `Error in depthwiseConv2d: filter must be rank 4, but got rank ${filter.rank}.`);
  const inChannels = i === "NHWC" ? batched.shape[3] : batched.shape[1];
  l(inChannels === filter.shape[2], () => `Error in depthwiseConv2d: number of input channels (${inChannels}) must match the inChannels dimension in filter ${filter.shape[2]}.`);
  g("depthwiseConv2d", e, p);
  const inputs = { x: batched, filter: filter };
  const attrs = { strides: t, pad: e, dataFormat: i, dilations: o, dimRoundingMode: p };
  const result = d.runKernel(F, inputs, attrs);
  if (!wasRank3) {
    return result;
  }
  return b(result, [result.shape[1], result.shape[2], result.shape[3]]);
}
const qn = /* @__PURE__ */ h({ depthwiseConv2d_: Tn });
|
|
127
|
-
/**
 * Body of the `equal_` op.
 *
 * Both operands are converted as "string_or_numeric", passed as a pair
 * through `M`, their shapes are handed to `j`, and kernel `J` is run.
 * (`M` and `j` are defined elsewhere — presumably dtype unification and a
 * broadcast-shape check; confirm against the imports.)
 */
function Hn(s, n) {
  const left = c(s, "a", "equal", "string_or_numeric");
  const right = c(n, "b", "equal", "string_or_numeric");
  const [a, bOperand] = M(left, right);
  j(a.shape, bOperand.shape);
  return d.runKernel(J, { a: a, b: bOperand });
}
const ft = /* @__PURE__ */ h({ equal_: Hn });
|
|
134
|
-
/**
 * Body of the `erf_` op.
 *
 * Asserts the converted input is int32 or float32; an int32 input is passed
 * through `_` with "float32" before kernel `U` is run.
 */
function Pn(s) {
  const converted = c(s, "x", "erf");
  l(converted.dtype === "int32" || converted.dtype === "float32", () => "Input dtype must be `int32` or `float32`.");
  const x = converted.dtype === "int32" ? _(converted, "float32") : converted;
  return d.runKernel(U, { x: x });
}
const dt = /* @__PURE__ */ h({ erf_: Pn });
|
|
141
|
-
/**
 * Body of the `softplus_` op: converts `s` as "x" and runs kernel `V`.
 */
function yn(s) {
  const x = c(s, "x", "softplus");
  return d.runKernel(V, { x: x });
}
const mt = /* @__PURE__ */ h({ softplus_: yn });
|
|
146
|
-
/**
 * Body of the `logSoftmax_` op.
 *
 * Only the last axis is supported: `n === -1` is resolved to `rank - 1`, and
 * any other axis throws. The computation is wrapped by `Q` as a
 * `{ value, gradFunc }` pair, with the value saved for the gradient.
 *
 * NOTE(review): the helpers `pn`, `N`, `_`, `hn`, `K`, `w`, `X` are defined
 * outside this view; this looks like the usual shifted log-sum-exp
 * formulation, but confirm against the imports.
 *
 * @param s logits
 * @param n axis (default -1, meaning the last axis)
 */
function Bn(s, n = -1) {
  const logits = c(s, "logits", "logSoftmax");
  if (n === -1) {
    n = logits.rank - 1;
  }
  if (n !== logits.rank - 1) {
    throw Error(`Log Softmax along a non-last dimension is not yet supported. Logits was rank ${logits.rank} and axis was ${n}`);
  }
  const keepDims = !0;
  const customOp = Q((input, save) => {
    const reduced = pn(input, n, keepDims);
    const shifted = N(input, reduced);
    const value = N(_(shifted, "float32"), hn(K(w(shifted), n, keepDims)));
    save([value]);
    const gradFunc = (dy, saved) => {
      const [savedValue] = saved;
      const transformed = w(savedValue);
      return N(dy, X(K(dy, n, keepDims), transformed));
    };
    return { value: value, gradFunc: gradFunc };
  });
  return customOp(logits);
}
const $t = /* @__PURE__ */ h({ logSoftmax_: Bn });
|
|
159
|
-
/**
 * Body of the `logicalNot_` op: converts `s` as bool "x" and runs kernel `Y`.
 */
function In(s) {
  const x = c(s, "x", "logicalNot", "bool");
  return d.runKernel(Y, { x: x });
}
const bt = /* @__PURE__ */ h({ logicalNot_: In });
|
|
164
|
-
/**
 * Body of the `maxPool_` op.
 *
 * A rank-3 input is reshaped to rank 4 with a leading dimension of 1 and the
 * result is reshaped back to rank 3. Dilations are fixed at 1 for the
 * strides/dilations assertion.
 *
 * @param s input ("x")
 * @param n filter size
 * @param t strides
 * @param e pad
 * @param i dimRoundingMode
 */
function Wn(s, n, t, e, i) {
  const input = c(s, "x", "maxPool");
  const dilations = 1;
  const wasRank3 = input.rank === 3;
  const batched = wasRank3 ? b(input, [1, input.shape[0], input.shape[1], input.shape[2]]) : input;
  l(batched.rank === 4, () => `Error in maxPool: input must be rank 4 but got rank ${batched.rank}.`);
  l(C(t, dilations), () => `Error in maxPool: Either strides or dilations must be 1. Got strides ${t} and dilations '${dilations}'`);
  g("maxPool", e, i);
  const inputs = { x: batched };
  const attrs = { filterSize: n, strides: t, pad: e, dimRoundingMode: i };
  const result = d.runKernel(Z, inputs, attrs);
  if (!wasRank3) {
    return result;
  }
  return b(result, [result.shape[1], result.shape[2], result.shape[3]]);
}
const xt = /* @__PURE__ */ h({ maxPool_: Wn });
|
|
172
|
-
/**
 * Body of the `oneHot_` op: runs kernel `R` on int32 indices.
 *
 * @param s indices, converted as int32
 * @param n depth; must be at least 2
 * @param t on-value (default 1)
 * @param e off-value (default 0)
 * @param i output dtype (default "int32")
 * @throws Error when `n < 2`
 */
function An(s, n, t = 1, e = 0, i = "int32") {
  if (n < 2) {
    throw new Error(`Error in oneHot: depth must be >=2, but it is ${n}`);
  }
  const indices = c(s, "indices", "oneHot", "int32");
  const attrs = { dtype: i, depth: n, onValue: t, offValue: e };
  return d.runKernel(R, { indices: indices }, attrs);
}
const kt = /* @__PURE__ */ h({ oneHot_: An });
|
|
179
|
-
/**
 * Body of the `onesLike_` op: converts `s` as "x" and runs kernel `nn`.
 */
function Ln(s) {
  const x = c(s, "x", "onesLike");
  return d.runKernel(nn, { x: x });
}
const vt = /* @__PURE__ */ h({ onesLike_: Ln });
|
|
184
|
-
/**
 * Body of the `pad_` op: runs kernel `tn` with the given paddings.
 *
 * @param s input ("x"); must not be rank 0
 * @param n paddings
 * @param t constant fill value (default 0)
 * @throws Error when the converted input has rank 0
 */
function Gn(s, n, t = 0) {
  const x = c(s, "x", "pad");
  if (x.rank === 0) {
    throw new Error("pad(scalar) is not defined. Pass non-scalar to pad");
  }
  const attrs = { paddings: n, constantValue: t };
  return d.runKernel(tn, { x: x }, attrs);
}
const gt = /* @__PURE__ */ h({ pad_: Gn });
|
|
192
|
-
/**
 * Body of the `spaceToBatchND_` op: validates the arguments and runs
 * kernel `sn`.
 *
 * Checks, in order: input rank is at least `1 + blockShape.length`;
 * `paddings.length === blockShape.length`; and for axes `1..blockShape.length`
 * the padded size `dim + before + after` is divisible by the block size.
 *
 * @param s input ("x")
 * @param n block shape
 * @param t paddings, indexed as `t[axis - 1][0]` / `t[axis - 1][1]`
 */
function On(s, n, t) {
  const input = c(s, "x", "spaceToBatchND");
  l(input.rank >= 1 + n.length, () => `input rank ${input.rank} should be > than [blockShape] ${n.length}`);
  l(t.length === n.length, () => `paddings.shape[0] ${t.length} must be equal to [blockShape] ${n.length}`);
  const paddedDimsDivisible = input.shape.reduce((ok, size, axis) => {
    // Axis 0 and axes beyond the block shape are not checked.
    if (axis <= 0 || axis > n.length) {
      return ok;
    }
    return ok && (size + t[axis - 1][0] + t[axis - 1][1]) % n[axis - 1] === 0;
  }, !0);
  l(paddedDimsDivisible, () => `input spatial dimensions ${input.shape.slice(1)} with paddings ${t.toString()} must be divisible by blockShapes ${n.toString()}`);
  const attrs = { blockShape: n, paddings: t };
  return d.runKernel(sn, { x: input }, attrs);
}
const Dt = /* @__PURE__ */ h({ spaceToBatchND_: On });
|
|
199
|
-
/**
 * Body of the `reverse_` op: runs kernel `en` with `dims` set to `n`.
 */
function zn(s, n) {
  const x = c(s, "x", "reverse");
  return d.runKernel(en, { x: x }, { dims: n });
}
const Ct = /* @__PURE__ */ h({ reverse_: zn });
|
|
204
|
-
/**
 * Body of the `rsqrt_` op: converts `s` as float32 "x" and runs kernel `rn`.
 */
function Fn(s) {
  const x = c(s, "x", "rsqrt", "float32");
  return d.runKernel(rn, { x: x });
}
const Et = /* @__PURE__ */ h({ rsqrt_: Fn });
|
|
209
|
-
/**
 * Body of the `selu_` op: converts `s` as "x" and runs kernel `on`.
 */
function Mn(s) {
  const x = c(s, "x", "selu");
  return d.runKernel(on, { x: x });
}
const Nt = /* @__PURE__ */ h({ selu_: Mn });
|
|
214
|
-
/**
 * Body of the `separableConv2d_` op: a depthwise convolution (`qn`) followed
 * by a pointwise convolution (`S`) with stride 1 and "valid" padding.
 *
 * A rank-3 input is reshaped to rank 4 with a leading dimension of 1 and the
 * result is reshaped back to rank 3.
 *
 * Fix: the pointwise-filter rank assertion previously interpolated the
 * depthwise filter's rank into its message; it now reports the pointwise
 * filter's own rank.
 *
 * @param s input ("x")
 * @param n depthwise filter (rank 4)
 * @param t pointwise filter (rank 4; first two dims must be 1, third must
 *          equal depthwise shape[2] * shape[3])
 * @param e strides
 * @param i pad
 * @param o dilations (default [1, 1])
 * @param p data format (default "NHWC")
 * @throws Error when `p === "NCHW"`
 */
function jn(s, n, t, e, i, o = [1, 1], p = "NHWC") {
  const input = c(s, "x", "separableConv2d");
  const depthwiseFilter = c(n, "depthwiseFilter", "separableConv2d");
  const pointwiseFilter = c(t, "pointwiseFilter", "separableConv2d");
  let batched = input;
  let wasRank3 = !1;
  if (input.rank === 3) {
    wasRank3 = !0;
    batched = b(input, [1, input.shape[0], input.shape[1], input.shape[2]]);
  }
  if (p === "NCHW") {
    throw new Error("separableConv2d currently does not support dataFormat NCHW; only NHWC is supported");
  }
  l(batched.rank === 4, () => `Error in separableConv2d: input must be rank 4, but got rank ${batched.rank}.`);
  l(depthwiseFilter.rank === 4, () => `Error in separableConv2d: depthwise filter must be rank 4, but got rank ${depthwiseFilter.rank}.`);
  l(pointwiseFilter.rank === 4, () => `Error in separableConv2d: pointwise filter must be rank 4, but got rank ${pointwiseFilter.rank}.`);
  l(pointwiseFilter.shape[0] === 1, () => `Error in separableConv2d: the first dimension of pointwise filter must be 1, but got ${pointwiseFilter.shape[0]}.`);
  l(pointwiseFilter.shape[1] === 1, () => `Error in separableConv2d: the second dimension of pointwise filter must be 1, but got ${pointwiseFilter.shape[1]}.`);
  const depthDim2 = depthwiseFilter.shape[2];
  const depthDim3 = depthwiseFilter.shape[3];
  l(pointwiseFilter.shape[2] === depthDim2 * depthDim3, () => `Error in separableConv2d: the third dimension of pointwise filter must be ${depthDim2 * depthDim3}, but got ${pointwiseFilter.shape[2]}.`);
  const depthwiseOut = qn(batched, depthwiseFilter, e, i, p, o);
  const result = S(depthwiseOut, pointwiseFilter, 1, "valid", p);
  if (!wasRank3) {
    return result;
  }
  return b(result, [result.shape[1], result.shape[2], result.shape[3]]);
}
const _t = /* @__PURE__ */ h({ separableConv2d_: jn });
|
|
226
|
-
/**
 * Body of the `sin_` op: converts `s` as float32 "x" and runs kernel `an`.
 */
function Jn(s) {
  const x = c(s, "x", "sin", "float32");
  return d.runKernel(an, { x: x });
}
const wt = /* @__PURE__ */ h({ sin_: Jn });
|
|
231
|
-
/**
 * Body of the `sinh_` op: converts `s` as "x" and runs kernel `un`.
 */
function Un(s) {
  const x = c(s, "x", "sinh");
  return d.runKernel(un, { x: x });
}
const Kt = /* @__PURE__ */ h({ sinh_: Un });
|
|
236
|
-
/**
 * Body of the `unsortedSegmentSum_` op: runs kernel `ln`.
 *
 * @param s input ("x")
 * @param n segment ids, converted as int32
 * @param t number of segments; asserted via `cn` with the message
 *          "numSegments must be of dtype int"
 */
function Vn(s, n, t) {
  const x = c(s, "x", "unsortedSegmentSum");
  const segmentIds = c(n, "segmentIds", "unsortedSegmentSum", "int32");
  l(cn(t), () => "numSegments must be of dtype int");
  const inputs = { x: x, segmentIds: segmentIds };
  const attrs = { numSegments: t };
  return d.runKernel(ln, inputs, attrs);
}
const St = /* @__PURE__ */ h({ unsortedSegmentSum_: Vn });
|
|
243
|
-
export {
|
|
244
|
-
Et as A,
|
|
245
|
-
Nt as B,
|
|
246
|
-
_t as C,
|
|
247
|
-
Kt as D,
|
|
248
|
-
rt as E,
|
|
249
|
-
St as F,
|
|
250
|
-
En as G,
|
|
251
|
-
mt as a,
|
|
252
|
-
Dt as b,
|
|
253
|
-
ct as c,
|
|
254
|
-
ot as d,
|
|
255
|
-
ft as e,
|
|
256
|
-
et as f,
|
|
257
|
-
nt as g,
|
|
258
|
-
tt as h,
|
|
259
|
-
st as i,
|
|
260
|
-
at as j,
|
|
261
|
-
it as k,
|
|
262
|
-
bt as l,
|
|
263
|
-
xt as m,
|
|
264
|
-
ut as n,
|
|
265
|
-
S as o,
|
|
266
|
-
lt as p,
|
|
267
|
-
pt as q,
|
|
268
|
-
Ct as r,
|
|
269
|
-
wt as s,
|
|
270
|
-
ht as t,
|
|
271
|
-
qn as u,
|
|
272
|
-
dt as v,
|
|
273
|
-
$t as w,
|
|
274
|
-
kt as x,
|
|
275
|
-
vt as y,
|
|
276
|
-
gt as z
|
|
277
|
-
};
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
/**
 * Measures how far apart two values are, element by element.
 *
 * Per the accompanying implementation: for two numbers this is the absolute
 * difference (0 when both are NaN or both are the same infinity); for two
 * arrays / Float32Arrays of equal length it is the largest difference over
 * all (possibly nested) elements; any other combination, including a length
 * mismatch, yields `Number.POSITIVE_INFINITY`.
 *
 * @returns the maximum absolute difference, not a boolean
 */
export declare function arraysClose(a: unknown, b: unknown): number;
|
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
/**
 * Returns the largest absolute difference between `r` and `e`
 * (exported as `arraysClose`).
 *
 * - Two arrays / Float32Arrays: recurse element-wise and keep the maximum;
 *   differing lengths give Infinity, two empty sequences give 0.
 * - Two numbers: 0 when both are NaN; when either is non-finite, 0 only if
 *   they are strictly equal, otherwise Infinity; else `|r - e|`.
 * - Anything else (mixed kinds, non-numeric leaves): Infinity.
 */
function i(r, e) {
  const isSequence = (value) => Array.isArray(value) || value instanceof Float32Array;
  if (isSequence(r) && isSequence(e)) {
    if (r.length !== e.length) {
      return Number.POSITIVE_INFINITY;
    }
    let worst = 0;
    for (let idx = 0; idx < r.length; ++idx) {
      worst = Math.max(worst, i(r[idx], e[idx]));
    }
    return worst;
  }
  if (typeof r != "number" || typeof e != "number") {
    return Number.POSITIVE_INFINITY;
  }
  if (isNaN(r) && isNaN(e)) {
    return 0;
  }
  if (!isFinite(r) || !isFinite(e)) {
    return r === e ? 0 : Number.POSITIVE_INFINITY;
  }
  return Math.max(0, Math.abs(r - e));
}
|
|
18
|
-
export {
|
|
19
|
-
i as arraysClose
|
|
20
|
-
};
|
|
@@ -1,21 +0,0 @@
|
|
|
1
|
-
/**
 * Derives a deterministic identifier string from the `id` fields of `s`.
 *
 * The ids are stringified and sorted, so the result does not depend on the
 * order of `s`. Two 32-bit accumulators are updated byte by byte with:
 * the number of ids (4 bytes, low byte first), then for each id its length
 * (4 bytes, low byte first) followed by the low and high byte of every
 * UTF-16 code unit.
 *
 * @param s array of objects carrying an `id` property
 * @returns "dataset__<a>_<b>" where `<a>`/`<b>` are the accumulators in base 36
 */
function h(s) {
  const ids = s.map((item) => String(item.id)).sort();
  let first = 2166136261;
  let second = 2654435769;
  // Mixes one byte (only the low 8 bits of `value`) into both accumulators.
  const feedByte = (value) => {
    const byte = value & 255;
    first = Math.imul(first ^ byte, 16777619);
    second = Math.imul(second ^ byte, 2246822507);
  };
  // Feeds a 32-bit unsigned value, least-significant byte first.
  const feedUint32 = (value) => {
    for (let shift = 0; shift < 32; shift += 8) {
      feedByte(value >>> shift);
    }
  };
  feedUint32(ids.length >>> 0);
  for (const id of ids) {
    feedUint32(id.length >>> 0);
    for (let pos = 0; pos < id.length; pos++) {
      const unit = id.charCodeAt(pos);
      feedByte(unit);
      feedByte(unit >>> 8);
    }
  }
  return "dataset__" + (first >>> 0).toString(36) + "_" + (second >>> 0).toString(36);
}
|
|
19
|
-
export {
|
|
20
|
-
h as default
|
|
21
|
-
};
|
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
import { default as Model, ModelForwardAttributes } from '../models/model';
|
|
2
|
-
/**
 * Runs one non-training forward pass on a `[1, blockSize]` int32 zero input
 * and resolves once the output data has been read back; the input and output
 * tensors are disposed afterwards (per the accompanying implementation).
 */
export declare function dummyPassAsync(model: Model<ModelForwardAttributes>): Promise<void>;
|
|
3
|
-
/**
 * Result of `dummyPassTrainAsync`. Per the accompanying implementation:
 * - `perBatch`: change in reported allocated bytes across the dummy training
 *   step, minus `gradients`.
 * - `tapeSize`: sum of `size * 4` over tensors saved on the active tape.
 * - `gradients`: `getNumParams() * 4`.
 * Values are presumably bytes (4 bytes per element) — confirm.
 */
export interface MemoryRequirements {
|
|
4
|
-
perBatch: number;
|
|
5
|
-
tapeSize: number;
|
|
6
|
-
gradients: number;
|
|
7
|
-
}
|
|
8
|
-
/**
 * Runs a warm-up forward pass followed by one dummy training step
 * (forward + gradients) and reports the measured memory figures.
 * Per the accompanying implementation, errors during the training step are
 * logged with `console.error` and the partially filled result is returned.
 */
export declare function dummyPassTrainAsync(model: Model<ModelForwardAttributes>): Promise<MemoryRequirements>;
|
|
9
|
-
/**
 * Synchronous variant of the dummy forward pass: runs one non-training
 * forward pass on a `[1, blockSize]` int32 zero input and disposes the
 * output and input without reading the result back.
 */
export declare function dummyPass(model: Model<ModelForwardAttributes>): void;
|
package/dist/utilities/dummy.js
DELETED
|
@@ -1,43 +0,0 @@
|
|
|
1
|
-
import { b as y, h as I, v as P } from "../index-CUXkjxiT.js";
|
|
2
|
-
import { z as c } from "../zeros-DvZpK8s6.js";
|
|
3
|
-
/**
 * Runs one non-training forward pass of model `s` on a `[1, blockSize]`
 * int32 zero tensor and waits for the output data (exported as
 * `dummyPassAsync`).
 *
 * Both tensors are released in `finally` blocks so that a throwing
 * `forward` or a rejected `data()` no longer leaks them; the error itself
 * still propagates to the caller.
 *
 * @param s model exposing `config.blockSize` and `forward(attrs, input)`
 */
async function w(s) {
  const input = c([1, s.config.blockSize], "int32");
  try {
    const output = s.forward({ training: !1 }, input);
    try {
      await output.data();
    } finally {
      output.dispose();
    }
  } finally {
    input.dispose();
  }
}
|
|
7
|
-
/**
 * Measures memory figures for one dummy training step of model `s`
 * (exported as `dummyPassTrainAsync`).
 *
 * Steps: read the allocated-bytes baseline from `y()`, run the warm-up pass
 * `w(s)`, then run forward + gradients through `P` on a `[1, blockSize]`
 * int32 zero tensor. While the forward pass is active, the sizes of tensors
 * saved on `I().state.activeTape` are summed into `tapeSize`.
 *
 * Errors raised during the training step are logged and swallowed (the
 * original best-effort behaviour), so the returned object may still hold
 * zeros for `perBatch`/`tapeSize`.
 *
 * Fix: the input tensor, the loss value and the gradients are now disposed
 * in `finally` blocks, so they are released on the error path too; the
 * result object is built before the input tensor is allocated so a throwing
 * `getNumParams()` cannot leak it.
 *
 * @param s model exposing `config.blockSize`, `forward` and `getNumParams`
 * @returns `{ perBatch, tapeSize, gradients }`
 */
async function b(s) {
  const readAllocatedBytes = () => {
    const info = y();
    return info.numBytesInGPUAllocated ?? info.numBytesAllocatedInGPU ?? info.numBytes;
  };
  const baseline = readAllocatedBytes();
  await w(s);
  const e = {
    perBatch: 0,
    tapeSize: 0,
    gradients: s.getNumParams() * 4
  };
  const input = c([1, s.config.blockSize], "int32");
  try {
    const lossFn = () => {
      const logits = s.forward({ training: !0 }, input);
      const tape = I().state.activeTape;
      let savedBytes = 0;
      if (tape) {
        for (const node of tape) {
          savedBytes += node.saved?.reduce((sum, tensor) => sum + tensor.size * 4, 0) || 0;
        }
      }
      e.tapeSize = savedBytes;
      const loss = logits.mean();
      logits.dispose();
      return loss;
    };
    const { value, grads } = P(lossFn);
    try {
      // Sample memory while the loss and gradients are still alive.
      e.perBatch = readAllocatedBytes() - baseline - e.gradients;
      await value.data();
    } finally {
      value.dispose();
      for (const name in grads) {
        grads[name].dispose();
      }
    }
  } catch (err) {
    console.error("Error during dummy training pass:", err);
  } finally {
    input.dispose();
  }
  return e;
}
|
|
35
|
-
/**
 * Synchronous dummy forward pass (exported as `dummyPass`): runs one
 * non-training forward pass on a `[1, blockSize]` int32 zero tensor and
 * disposes the output without reading it back.
 *
 * The input tensor is released in `finally`, so it is no longer leaked when
 * `forward` throws; the error still propagates.
 *
 * @param s model exposing `config.blockSize` and `forward(attrs, input)`
 */
function v(s) {
  const input = c([1, s.config.blockSize], "int32");
  try {
    s.forward({ training: !1 }, input).dispose();
  } finally {
    input.dispose();
  }
}
|
|
39
|
-
export {
|
|
40
|
-
v as dummyPass,
|
|
41
|
-
w as dummyPassAsync,
|
|
42
|
-
b as dummyPassTrainAsync
|
|
43
|
-
};
|
|
@@ -1,13 +0,0 @@
|
|
|
1
|
-
import "../index-CUXkjxiT.js";
|
|
2
|
-
import { t as e } from "../tensor2d-BnXMKScO.js";
|
|
3
|
-
/**
 * Picks an index from the weights in `n` by walking their running sum.
 *
 * The first index whose cumulative sum exceeds the threshold is chosen; if
 * none does, the last index (`n.length - 1`, i.e. -1 for an empty array) is
 * used. The choice is returned as a `[1, 1]` int32 tensor built with `e`.
 *
 * @param n array of weights (presumably probabilities summing to 1 — confirm)
 * @param i optional threshold; `Math.random()` is used when null/undefined
 */
function f(n, i) {
  const threshold = i ?? Math.random();
  let chosen = n.length - 1;
  let cumulative = 0;
  for (let idx = 0; idx < n.length; idx++) {
    cumulative += n[idx];
    if (threshold < cumulative) {
      chosen = idx;
      break;
    }
  }
  return e([[chosen]], [1, 1], "int32");
}
|
|
11
|
-
export {
|
|
12
|
-
f as default
|
|
13
|
-
};
|
package/dist/utilities/naming.js
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
|
package/dist/utilities/packed.js
DELETED
|
@@ -1,15 +0,0 @@
|
|
|
1
|
-
import { h as n } from "../index-CUXkjxiT.js";
|
|
2
|
-
/**
 * Reports whether packing is supported (exported as `packingSupported`):
 * true exactly when the object returned by `n()` has `backendName === "webgpu"`.
 */
function o() {
  const { backendName } = n();
  return backendName === "webgpu";
}
|
|
5
|
-
/**
 * Returns true when `e.dtype` is exactly "packedF16"
 * (exported as `isPackableTensor`).
 */
function r(e) {
  const { dtype } = e;
  return dtype === "packedF16";
}
|
|
8
|
-
/**
 * Exported as `isPackedTensor`; currently an alias that defers entirely to
 * `r` (`isPackableTensor`).
 */
function a(e) {
  const packed = r(e);
  return packed;
}
|
|
11
|
-
export {
|
|
12
|
-
r as isPackableTensor,
|
|
13
|
-
a as isPackedTensor,
|
|
14
|
-
o as packingSupported
|
|
15
|
-
};
|
|
@@ -1,11 +0,0 @@
|
|
|
1
|
-
import { GPTConfig } from '../models/config';
|
|
2
|
-
/**
 * Estimated parameter count of a single layer. Per the accompanying
 * implementation: `4 * nEmbed^2 + 2 * mlpFactor * nEmbed^2`.
 */
export declare function estimateLayerParameters(config: GPTConfig): number;
|
|
3
|
-
/**
 * Estimated parameter count of the whole model. Per the accompanying
 * implementation: `vocabSize * nEmbed` plus `nLayer` times the per-layer
 * terms `4 * nEmbed^2` and `2 * mlpFactor * nEmbed^2`.
 */
export declare function estimateParameterCount(config: GPTConfig): number;
|
|
4
|
-
/**
 * Estimated model memory: the estimated parameter count multiplied by 4
 * (presumably bytes at 4 bytes per parameter — confirm).
 */
export declare function estimateMemoryUsage(config: GPTConfig): number;
|
|
5
|
-
/**
 * Estimated training memory. Per the accompanying implementation: five times
 * `estimateMemoryUsage(config)` plus `batchSize * blockSize * nEmbed * 4`.
 */
export declare function estimateTrainingMemoryUsage(config: GPTConfig, batchSize: number): number;
|
|
6
|
-
/**
 * Bundles the estimates for a configuration. Per the accompanying
 * implementation, `numParams` is the estimated parameter count and the two
 * `*MB` fields are the model / training memory estimates divided by
 * `1024 * 1024`.
 */
export declare function estimateResources(config: GPTConfig, batchSize: number): {
|
|
7
|
-
numParams: number;
|
|
8
|
-
modelMemoryMB: number;
|
|
9
|
-
trainingMemoryMB: number;
|
|
10
|
-
};
|
|
11
|
-
/**
 * Validates a configuration and throws an `Error` whose message is a short
 * code (e.g. "nEmbed_divisible_nHead", "blockSize_positive", "headDim_even",
 * "nLayer_integer") for the first failed check; returns nothing on success.
 */
export declare function validateConfig(config: GPTConfig): void;
|
|
@@ -1,57 +0,0 @@
|
|
|
1
|
-
/**
 * Estimated parameter count of one layer (exported as
 * `estimateLayerParameters`): an attention term `4 * nEmbed^2` plus two MLP
 * terms of `mlpFactor * nEmbed^2` each.
 *
 * @param e config with `nEmbed` and `mlpFactor`
 */
function b(e) {
  const { nEmbed, mlpFactor } = e;
  const attention = 4 * nEmbed * nEmbed;
  const mlpIn = mlpFactor * nEmbed * nEmbed; // fc
  const mlpOut = nEmbed * mlpFactor * nEmbed;
  return attention + (mlpIn + mlpOut);
}
|
|
6
|
-
/**
 * Estimated parameter count of the whole model (exported as
 * `estimateParameterCount`): the embedding term `vocabSize * nEmbed`, plus
 * `nLayer` attention terms of `4 * nEmbed^2`, plus `nLayer` MLP terms of
 * `2 * mlpFactor * nEmbed^2`.
 *
 * @param e config with `vocabSize`, `nEmbed`, `nLayer` and `mlpFactor`
 */
function a(e) {
  const { vocabSize, nEmbed, nLayer, mlpFactor } = e;
  const embedding = vocabSize * nEmbed;
  const attention = nLayer * (4 * nEmbed * nEmbed);
  const mlpIn = mlpFactor * nEmbed * nEmbed; // fc
  const mlpOut = nEmbed * mlpFactor * nEmbed;
  const mlp = nLayer * (mlpIn + mlpOut);
  return embedding + attention + mlp;
}
|
|
11
|
-
/**
 * Estimated model memory (exported as `estimateMemoryUsage`): the estimated
 * parameter count from `a` multiplied by 4.
 */
function o(e) {
  const paramCount = a(e);
  return paramCount * 4;
}
|
|
14
|
-
/**
 * Estimated training memory (exported as `estimateTrainingMemoryUsage`).
 *
 * Sums the model memory from `o`, two further terms of twice the model
 * memory each, and a batch-dependent term `r * blockSize * nEmbed * 4`.
 * NOTE(review): the two doubled terms presumably stand for optimizer state
 * and gradients — confirm with the original source.
 *
 * @param e config with `blockSize` and `nEmbed` (plus whatever `o` reads)
 * @param r batch size
 */
function E(e, r) {
  const modelBytes = o(e);
  const doubledA = modelBytes * 2;
  const doubledB = modelBytes * 2;
  const batchBytes = r * e.blockSize * e.nEmbed * 4;
  return modelBytes + doubledA + doubledB + batchBytes;
}
|
|
18
|
-
/**
 * Bundles the estimates for a configuration (exported as
 * `estimateResources`).
 *
 * @param e config
 * @param r batch size
 * @returns `numParams` from `a`, and the results of `o` and `E` divided by
 *          `1024 * 1024` as `modelMemoryMB` / `trainingMemoryMB`
 */
function i(e, r) {
  const perMegabyte = 1024 * 1024;
  const numParams = a(e);
  const modelMemoryMB = o(e) / perMegabyte;
  const trainingMemoryMB = E(e, r) / perMegabyte;
  return {
    numParams: numParams,
    modelMemoryMB: modelMemoryMB,
    trainingMemoryMB: trainingMemoryMB
  };
}
|
|
26
|
-
/**
 * Validates a model configuration (exported as `validateConfig`).
 *
 * Checks run in a fixed order and the first failure throws an `Error` whose
 * message is a short code. The original checks and their order are kept, so
 * every configuration that was rejected before is rejected with the same
 * code.
 *
 * Fix: three checks are appended for values that previously slipped through
 * every test — a zero or negative `nEmbed` (`0 % n` and `-0 !== 0` both
 * pass), a negative `nHead`, and a NaN/undefined `mlpFactor` (for which
 * `<= 0` is false).
 *
 * @param e config with `nEmbed`, `nHead`, `nLayer`, `blockSize`, `vocabSize`
 *          and `mlpFactor`
 * @throws Error with one of the codes listed in the table below
 */
function d(e) {
  const checks = [
    { failed: () => e.nEmbed % e.nHead !== 0, code: "nEmbed_divisible_nHead" },
    { failed: () => e.blockSize <= 0, code: "blockSize_positive" },
    { failed: () => e.vocabSize <= 0, code: "vocabSize_positive" },
    { failed: () => e.nLayer <= 0, code: "nLayer_positive" },
    { failed: () => e.mlpFactor <= 0, code: "mlpFactor_positive" },
    { failed: () => e.nEmbed / e.nHead % 2 !== 0, code: "headDim_even" },
    { failed: () => !Number.isInteger(e.nEmbed), code: "nEmbed_integer" },
    { failed: () => !Number.isInteger(e.nHead), code: "nHead_integer" },
    { failed: () => !Number.isInteger(e.nLayer), code: "nLayer_integer" },
    { failed: () => !Number.isInteger(e.blockSize), code: "blockSize_integer" },
    { failed: () => !Number.isInteger(e.vocabSize), code: "vocabSize_integer" },
    // Appended checks: written as !(x > 0) so NaN and undefined fail too.
    { failed: () => !(e.nEmbed > 0), code: "nEmbed_positive" },
    { failed: () => !(e.nHead > 0), code: "nHead_positive" },
    { failed: () => !(e.mlpFactor > 0), code: "mlpFactor_positive" }
  ];
  for (const check of checks) {
    if (check.failed()) {
      throw new Error(check.code);
    }
  }
}
|
|
50
|
-
export {
|
|
51
|
-
b as estimateLayerParameters,
|
|
52
|
-
o as estimateMemoryUsage,
|
|
53
|
-
a as estimateParameterCount,
|
|
54
|
-
i as estimateResources,
|
|
55
|
-
E as estimateTrainingMemoryUsage,
|
|
56
|
-
d as validateConfig
|
|
57
|
-
};
|
|
@@ -1,16 +0,0 @@
|
|
|
1
|
-
import { t as s } from "../index-CUXkjxiT.js";
|
|
2
|
-
/**
 * Micro-benchmark helper: returns the mean wall-clock time per call of `e`
 * in milliseconds (as measured by `performance.now()`).
 *
 * Runs 100 untimed warm-up iterations, then `o` timed iterations. In each
 * phase only the final iteration awaits `data()` on the result before the
 * result is disposed. With `r` false, `e` is invoked through the helper `s`;
 * with `r` true, `e` is called directly and awaited.
 *
 * @param e function producing an object with `data()` and `dispose()`
 * @param o number of timed iterations (default 10)
 * @param r whether `e` is asynchronous (default false)
 */
async function f(e, o = 10, r = !1) {
  const warmupRuns = 100;
  let warmupIndex = 0;
  while (warmupIndex < warmupRuns) {
    const result = r ? await e() : s(e);
    if (warmupIndex === warmupRuns - 1) {
      await result.data();
    }
    result.dispose();
    warmupIndex++;
  }
  const startedAt = performance.now();
  let timedIndex = 0;
  while (timedIndex < o) {
    const result = r ? await e() : s(e);
    if (timedIndex === o - 1) {
      await result.data();
    }
    result.dispose();
    timedIndex++;
  }
  const elapsed = performance.now() - startedAt;
  return elapsed / o;
}
|
|
14
|
-
export {
|
|
15
|
-
f as default
|
|
16
|
-
};
|
|
@@ -1,17 +0,0 @@
|
|
|
1
|
-
import { MemoryInfo } from '@tensorflow/tfjs-core';
|
|
2
|
-
/**
 * `MemoryInfo` from `@tensorflow/tfjs-core` extended with two optional
 * GPU byte counters.
 * NOTE(review): presumably different backends report one or the other of
 * these spellings — confirm against the backends in use.
 */
export interface ExtendedMemoryInfo extends MemoryInfo {
|
|
3
|
-
numBytesInGPUAllocated?: number;
|
|
4
|
-
numBytesAllocatedInGPU?: number;
|
|
5
|
-
}
|
|
6
|
-
/**
 * Declaration of the memory profiler class.
 *
 * Only the type surface is visible here: `startMemory()` and
 * `endMemory(label)` take no result, `getPeakMemory()` and `getMaxMemory()`
 * return numbers, and `printSummary()` returns nothing.
 * NOTE(review): presumably `startMemory`/`endMemory` bracket a labelled
 * measurement and the getters report values in bytes — the implementation
 * is not part of this view; confirm before relying on it.
 */
export default class MemoryProfiler {
|
|
7
|
-
private log;
|
|
8
|
-
private maxMemory;
|
|
9
|
-
private maxLabel?;
|
|
10
|
-
private lastMemInfo;
|
|
11
|
-
private peakMemory;
|
|
12
|
-
startMemory(): void;
|
|
13
|
-
getPeakMemory(): number;
|
|
14
|
-
getMaxMemory(): number;
|
|
15
|
-
endMemory(label: string): void;
|
|
16
|
-
printSummary(): void;
|
|
17
|
-
}
|