whispercpp 1.3.6 → 1.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/README.md +38 -5
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -8
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +36 -42
- data/ext/ruby_whisper.h +135 -0
- data/ext/ruby_whisper_context.c +107 -28
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -65
- data/ext/ruby_whisper_segment.c +6 -6
- data/ext/ruby_whisper_transcribe.cpp +42 -15
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +1 -1
- data/ext/sources/examples/cli/cli.cpp +43 -9
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +199 -163
- data/ext/sources/ggml/CMakeLists.txt +21 -13
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +72 -10
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-rpc.h +3 -3
- data/ext/sources/ggml/include/ggml.h +101 -9
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +22 -5
- data/ext/sources/ggml/src/ggml-alloc.c +5 -1
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
- data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
- data/ext/sources/ggml/src/ggml-impl.h +6 -1
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
- data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +289 -114
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
- data/ext/sources/ggml/src/ggml.c +110 -28
- data/ext/sources/ggml/src/gguf.cpp +173 -28
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +56 -12
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +411 -62
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +24 -6
- data/whispercpp.gemspec +2 -2
- metadata +215 -281
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
- data/ext/sources/examples/talk-llama/llama-context.h +0 -359
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
- data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
- data/ext/sources/examples/talk-llama/llama-model.h +0 -597
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
- data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
- data/ext/sources/examples/talk-llama/llama.h +0 -1573
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -704
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
- /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
|
@@ -0,0 +1,3838 @@
|
|
|
1
|
+
#include "parakeet.h"
|
|
2
|
+
#include "parakeet-arch.h"
|
|
3
|
+
|
|
4
|
+
#include "ggml.h"
|
|
5
|
+
#include "ggml-cpp.h"
|
|
6
|
+
#include "ggml-alloc.h"
|
|
7
|
+
#include "ggml-backend.h"
|
|
8
|
+
|
|
9
|
+
#include <atomic>
|
|
10
|
+
#include <algorithm>
|
|
11
|
+
#include <cassert>
|
|
12
|
+
#include <cfloat>
|
|
13
|
+
#define _USE_MATH_DEFINES
|
|
14
|
+
#include <cmath>
|
|
15
|
+
#include <climits>
|
|
16
|
+
#include <cstdarg>
|
|
17
|
+
#include <cstdio>
|
|
18
|
+
#include <cstring>
|
|
19
|
+
#include <fstream>
|
|
20
|
+
#include <functional>
|
|
21
|
+
#include <cctype>
|
|
22
|
+
#include <map>
|
|
23
|
+
#include <random>
|
|
24
|
+
#include <set>
|
|
25
|
+
#include <string>
|
|
26
|
+
#include <thread>
|
|
27
|
+
#include <vector>
|
|
28
|
+
|
|
29
|
+
#ifdef _MSC_VER
|
|
30
|
+
#include <codecvt>
|
|
31
|
+
#endif
|
|
32
|
+
|
|
33
|
+
#if defined(PARAKEET_BIG_ENDIAN)
|
|
34
|
+
// Returns a copy of `value` whose object representation has its bytes in
// reverse order (the first byte becomes the last, and so on).
template<typename T>
static T byteswap(T value) {
    T reversed;
    const char * first = reinterpret_cast<const char *>(&value);
    char *       out   = reinterpret_cast<char *>(&reversed);
    std::reverse_copy(first, first + sizeof(T), out);
    return reversed;
}
|
|
45
|
+
|
|
46
|
+
template<typename T>
|
|
47
|
+
static void byteswap_tensor_data(ggml_tensor * tensor) {
|
|
48
|
+
T * datum = reinterpret_cast<T *>(tensor->data);
|
|
49
|
+
for (int i = 0; i < ggml_nelements(tensor); i++) {
|
|
50
|
+
datum[i] = byteswap(datum[i]);
|
|
51
|
+
}
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
// Byte-swaps the contents of `tensor` in place, picking the element width from
// tensor->type. Only I16, F16, I32 and F32 tensors are swapped; every other
// type (e.g. GGML_TYPE_I8) is left untouched.
static void byteswap_tensor(ggml_tensor * tensor) {
    const auto type = tensor->type;

    if (type == GGML_TYPE_I16) {
        byteswap_tensor_data<int16_t>(tensor);
    } else if (type == GGML_TYPE_F16) {
        byteswap_tensor_data<ggml_fp16_t>(tensor);
    } else if (type == GGML_TYPE_I32) {
        byteswap_tensor_data<int32_t>(tensor);
    } else if (type == GGML_TYPE_F32) {
        byteswap_tensor_data<float>(tensor);
    }
    // anything else: nothing to swap
}
|
|
77
|
+
|
|
78
|
+
// Byte-swap hooks for builds where PARAKEET_BIG_ENDIAN is defined. Each macro
// expands to a single do { ... } while (0) statement, matching the no-op
// variants used when PARAKEET_BIG_ENDIAN is not defined, and parenthesizes its
// argument where it is used as an lvalue or object expression.
#define BYTESWAP_VALUE(d) \
    do { \
        (d) = byteswap(d); \
    } while (0)
#define BYTESWAP_FILTERS(f) \
    do { \
        for (auto & datum : (f).data) { \
            datum = byteswap(datum); \
        } \
    } while (0)
#define BYTESWAP_TENSOR(t) \
    do { \
        byteswap_tensor(t); \
    } while (0)
|
|
89
|
+
#else
|
|
90
|
+
// PARAKEET_BIG_ENDIAN is not defined on this branch: the byte-swap hooks
// expand to empty statements and their arguments are not evaluated.
#define BYTESWAP_VALUE(d) do {} while (0)
|
|
91
|
+
#define BYTESWAP_FILTERS(f) do {} while (0)
|
|
92
|
+
#define BYTESWAP_TENSOR(t) do {} while (0)
|
|
93
|
+
#endif
|
|
94
|
+
|
|
95
|
+
// Format-string checking attribute for printf-style functions.
// GCC-compatible compilers get __attribute__((format(...))): MinGW uses the
// gnu_printf archetype, everything else plain printf. Compilers that do not
// define __GNUC__ get an empty expansion.
#if defined(__GNUC__) && defined(__MINGW32__)
#define PARAKEET_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
#elif defined(__GNUC__)
#define PARAKEET_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#else
#define PARAKEET_ATTRIBUTE_FORMAT(...)
#endif
|
|
104
|
+
|
|
105
|
+
//
|
|
106
|
+
// logging
|
|
107
|
+
//
|
|
108
|
+
|
|
109
|
+
PARAKEET_ATTRIBUTE_FORMAT(2, 3)
|
|
110
|
+
static void parakeet_log_internal (ggml_log_level level, const char * format, ...);
|
|
111
|
+
static void parakeet_log_callback_default(ggml_log_level level, const char * text, void * user_data);
|
|
112
|
+
|
|
113
|
+
#define PARAKEET_LOG_ERROR(...) parakeet_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
|
114
|
+
#define PARAKEET_LOG_WARN(...) parakeet_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
|
|
115
|
+
#define PARAKEET_LOG_INFO(...) parakeet_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
|
|
116
|
+
|
|
117
|
+
// define this to enable verbose trace logging - useful for debugging purposes
|
|
118
|
+
//#define PARAKEET_DEBUG
|
|
119
|
+
|
|
120
|
+
#if defined(PARAKEET_DEBUG)
|
|
121
|
+
#define PARAKEET_LOG_DEBUG(...) parakeet_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
|
|
122
|
+
#else
|
|
123
|
+
#define PARAKEET_LOG_DEBUG(...)
|
|
124
|
+
#endif
|
|
125
|
+
|
|
126
|
+
// Runtime assertion with no build-type guard: when `x` evaluates to false it
// logs the source file, line and the expression text at error level, then
// calls abort().
#define PARAKEET_ASSERT(x)                                                              \
    do {                                                                                \
        if (!(x)) {                                                                     \
            PARAKEET_LOG_ERROR("PARAKEET_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
            abort();                                                                    \
        }                                                                               \
    } while (0)
|
|
133
|
+
|
|
134
|
+
#define PARAKEET_MAX_NODES 8192
|
|
135
|
+
|
|
136
|
+
// Threshold for when local attention should be used.
|
|
137
|
+
// 8192 frames x 80ms = 655 s (about 10.9 mins)
|
|
138
|
+
static constexpr int PARAKEET_LOCAL_ATTN_THRESHOLD = 8192;
|
|
139
|
+
// Window of context in each direction of the current token.
|
|
140
|
+
// 128 frames * 80ms = 10.24 s
|
|
141
|
+
static constexpr int PARAKEET_LOCAL_ATTN_WINDOW = 128;
|
|
142
|
+
|
|
143
|
+
// printf-style formatting into a std::string.
// The arguments are formatted twice: once with a null buffer to measure the
// required length, then again (from a va_copy) into storage of exactly that
// size. Aborts via GGML_ASSERT if vsnprintf reports an error or the two passes
// disagree on the length.
static std::string format(const char * fmt, ...) {
    va_list args;
    va_list args_copy;
    va_start(args, fmt);
    va_copy(args_copy, args);

    // measuring pass
    const int len = vsnprintf(NULL, 0, fmt, args);
    GGML_ASSERT(len >= 0 && len < INT_MAX); // NOLINT

    // writing pass: one extra byte for the terminating NUL
    std::string out(len + 1, '\0');
    const int written = vsnprintf(&out[0], len + 1, fmt, args_copy);
    GGML_ASSERT(written == len);

    va_end(args_copy);
    va_end(args);

    out.resize(len);
    return out;
}
|
|
157
|
+
|
|
158
|
+
//
|
|
159
|
+
// ggml helpers
|
|
160
|
+
//
|
|
161
|
+
|
|
162
|
+
static bool ggml_graph_compute_helper(
|
|
163
|
+
struct ggml_cgraph * graph,
|
|
164
|
+
int n_threads,
|
|
165
|
+
ggml_abort_callback abort_callback,
|
|
166
|
+
void * abort_callback_data) {
|
|
167
|
+
ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
|
|
168
|
+
|
|
169
|
+
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
|
|
170
|
+
|
|
171
|
+
auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
|
|
172
|
+
if (set_abort_callback_fn) {
|
|
173
|
+
set_abort_callback_fn(backend.get(), abort_callback, abort_callback_data);
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
|
|
177
|
+
if (ggml_backend_set_n_threads_fn) {
|
|
178
|
+
ggml_backend_set_n_threads_fn(backend.get(), n_threads);
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
return ggml_backend_graph_compute(backend.get(), graph) == GGML_STATUS_SUCCESS;
|
|
182
|
+
}
|
|
183
|
+
|
|
184
|
+
// Computes `graph` through the backend scheduler `sched`.
//
//   sched       - scheduler whose backends are configured and used
//   graph       - graph to evaluate
//   n_threads   - thread count applied to every backend that exposes
//                 "ggml_backend_set_n_threads"
//   sched_reset - when true the scheduler is reset after the computation
//
// Returns true when the computation finishes with GGML_STATUS_SUCCESS. The
// scheduler is always reset on failure, regardless of `sched_reset`.
static bool ggml_graph_compute_helper(
        ggml_backend_sched_t sched,
        struct ggml_cgraph * graph,
        int n_threads,
        bool sched_reset = true) {
    for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
        ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
        ggml_backend_dev_t dev = ggml_backend_get_device(backend);
        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
        if (!reg) {
            // a backend without a device has no registry to query; passing a
            // null registry to the lookup below is not valid
            continue;
        }

        auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
        if (fn_set_n_threads) {
            fn_set_n_threads(backend, n_threads);
        }
    }

    const bool t = (ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS);

    if (!t || sched_reset) {
        ggml_backend_sched_reset(sched);
    }

    return t;
}
|
|
208
|
+
|
|
209
|
+
// TODO: move these functions to ggml-base with support for ggml-backend?
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
// Mel data container: three integer dimensions (all defaulting to 0) and a
// flat float buffer.
// NOTE(review): n_len_org presumably keeps the length before padding - confirm
// against the code that fills this struct.
struct parakeet_mel {
    int n_len     = 0;
    int n_len_org = 0;
    int n_mel     = 0;

    std::vector<float> data;
};
|
|
219
|
+
|
|
220
|
+
// Filter bank container: its two dimensions and a flat float buffer.
struct parakeet_filters {
    int32_t n_mel = 0;
    int32_t n_fb  = 0; // number of frequency bins

    std::vector<float> data;
};
|
|
226
|
+
|
|
227
|
+
// Vocabulary: bidirectional token <-> id lookup tables plus the ids of the
// special tokens.
struct parakeet_vocab {
    using id    = int32_t;
    using token = std::string;

    int    n_vocab          = 8192;
    size_t max_token_length = 0;

    std::map<token, id> token_to_id;
    std::map<id, token> id_to_token;

    // NOTE(review): the special-token ids have no default initializer, so they
    // hold indeterminate values until assigned - confirm every user sets them
    // before reading.
    id token_unk;
    id token_bos;
    id token_blank;
    id token_eos;
};
|
|
242
|
+
|
|
243
|
+
struct parakeet_segment {
|
|
244
|
+
int64_t t0;
|
|
245
|
+
int64_t t1;
|
|
246
|
+
|
|
247
|
+
std::string text;
|
|
248
|
+
|
|
249
|
+
std::vector<parakeet_token_data> tokens;
|
|
250
|
+
};
|
|
251
|
+
|
|
252
|
+
struct parakeet_batch {
|
|
253
|
+
int32_t n_tokens;
|
|
254
|
+
|
|
255
|
+
parakeet_token * token;
|
|
256
|
+
int32_t * i_time; // index of the audio frame
|
|
257
|
+
parakeet_pos * pos;
|
|
258
|
+
int32_t * n_seq_id; // always 1, here for consistency with llama.cpp
|
|
259
|
+
parakeet_seq_id ** seq_id; // null terminated
|
|
260
|
+
int8_t * logits;
|
|
261
|
+
};
|
|
262
|
+
|
|
263
|
+
// ggml_backend_sched wrapper for parakeet usage
|
|
264
|
+
struct parakeet_sched {
|
|
265
|
+
ggml_backend_sched_t sched = nullptr;
|
|
266
|
+
|
|
267
|
+
std::vector<uint8_t> meta;
|
|
268
|
+
};
|
|
269
|
+
|
|
270
|
+
// TODO: find out whether there are multiple architecture/version types.
//       This is not yet clear at this point.
enum parakeet_arch {
    PARAKEET_ARCH_UNKNOWN = 0,
    PARAKEET_ARCH_TDT     = 1, // NVIDIA Parakeet TDT (RNN-T)
};
|
|
276
|
+
|
|
277
|
+
struct parakeet_hparams {
|
|
278
|
+
int32_t n_vocab = 8192;
|
|
279
|
+
int32_t n_audio_ctx = 0; // 0 = unlimited, will be set based on input
|
|
280
|
+
int32_t n_audio_state = 1024;
|
|
281
|
+
int32_t n_audio_head = 8;
|
|
282
|
+
int32_t n_audio_layer = 24;
|
|
283
|
+
int32_t n_mels = 128;
|
|
284
|
+
int32_t ftype = 1;
|
|
285
|
+
int32_t n_fft = 512; // FFT size for mel spectrogram
|
|
286
|
+
float eps = 1e-5f;
|
|
287
|
+
int32_t subsampling_factor = 8;
|
|
288
|
+
int32_t n_subsampling_channels = 256;
|
|
289
|
+
int32_t n_conv_kernel = 9;
|
|
290
|
+
int32_t n_pred_dim = 640;
|
|
291
|
+
int32_t n_pred_layers = 2;
|
|
292
|
+
int32_t n_tdt_durations = 5;
|
|
293
|
+
int32_t n_max_tokens = 10;
|
|
294
|
+
|
|
295
|
+
parakeet_arch arch = PARAKEET_ARCH_TDT;
|
|
296
|
+
};
|
|
297
|
+
|
|
298
|
+
// Weight tensors of a single encoder layer. Every pointer starts out null.
struct parakeet_layer_encoder {
    // first feed-forward module
    struct ggml_tensor * norm_ff1_w          = nullptr;
    struct ggml_tensor * norm_ff1_b          = nullptr;

    struct ggml_tensor * ff1_linear1_w       = nullptr;
    struct ggml_tensor * ff1_linear2_w       = nullptr;

    // convolution module
    struct ggml_tensor * norm_conv_w         = nullptr;
    struct ggml_tensor * norm_conv_b         = nullptr;

    struct ggml_tensor * conv_pw1_w          = nullptr; // pointwise_conv1
    struct ggml_tensor * conv_dw_w           = nullptr; // depthwise_conv
    struct ggml_tensor * conv_bn_w           = nullptr; // batch_norm weight
    struct ggml_tensor * conv_bn_b           = nullptr; // batch_norm bias
    struct ggml_tensor * conv_bn_mean        = nullptr; // batch_norm running_mean
    struct ggml_tensor * conv_bn_var         = nullptr; // batch_norm running_var
    struct ggml_tensor * conv_bn_num_batches = nullptr; // batch_norm num_batches_tracked
    struct ggml_tensor * conv_pw2_w          = nullptr; // pointwise_conv2

    // attention module
    struct ggml_tensor * norm_attn_w         = nullptr;
    struct ggml_tensor * norm_attn_b         = nullptr;

    struct ggml_tensor * attn_pos_bias_u     = nullptr;
    struct ggml_tensor * attn_pos_bias_v     = nullptr;
    struct ggml_tensor * attn_q_w            = nullptr;
    struct ggml_tensor * attn_k_w            = nullptr;
    struct ggml_tensor * attn_v_w            = nullptr;
    struct ggml_tensor * attn_out_w          = nullptr;
    struct ggml_tensor * attn_pos_w          = nullptr;

    // second feed-forward module
    struct ggml_tensor * norm_ff2_w          = nullptr;
    struct ggml_tensor * norm_ff2_b          = nullptr;

    struct ggml_tensor * ff2_linear1_w       = nullptr;
    struct ggml_tensor * ff2_linear2_w       = nullptr;

    // output norm
    struct ggml_tensor * norm_out_w          = nullptr;
    struct ggml_tensor * norm_out_b          = nullptr;
};
|
|
337
|
+
|
|
338
|
+
// Weight tensors of one LSTM layer of the prediction network.
struct parakeet_lsmt_layer {
    struct ggml_tensor * ih_w = nullptr; // input-to-hidden weight
    struct ggml_tensor * hh_w = nullptr; // hidden-to-hidden weight
    struct ggml_tensor * b_h  = nullptr; // bias (ih folded into hh at conversion time)
};
|
|
343
|
+
|
|
344
|
+
struct parakeet_prediction_network {
|
|
345
|
+
struct ggml_tensor * embed_w = nullptr;
|
|
346
|
+
|
|
347
|
+
std::vector<parakeet_lsmt_layer> lstm_layer;
|
|
348
|
+
};
|
|
349
|
+
|
|
350
|
+
// Joint network weights: weight/bias pairs for the "pred", "enc" and "net"
// parts. Every pointer starts out null.
struct parakeet_joint_network {
    struct ggml_tensor * pred_w = nullptr;
    struct ggml_tensor * pred_b = nullptr;

    struct ggml_tensor * enc_w  = nullptr;
    struct ggml_tensor * enc_b  = nullptr;

    struct ggml_tensor * net_w  = nullptr;
    struct ggml_tensor * net_b  = nullptr;
};
|
|
358
|
+
|
|
359
|
+
struct parakeet_model {
|
|
360
|
+
parakeet_filters filters;
|
|
361
|
+
parakeet_hparams hparams;
|
|
362
|
+
|
|
363
|
+
struct ggml_tensor * enc_pre_out_w = nullptr;
|
|
364
|
+
struct ggml_tensor * enc_pre_out_b = nullptr;
|
|
365
|
+
struct ggml_tensor * enc_pre_conv_0_w = nullptr;
|
|
366
|
+
struct ggml_tensor * enc_pre_conv_0_b = nullptr;
|
|
367
|
+
struct ggml_tensor * enc_pre_conv_2_w = nullptr;
|
|
368
|
+
struct ggml_tensor * enc_pre_conv_2_b = nullptr;
|
|
369
|
+
struct ggml_tensor * enc_pre_conv_3_w = nullptr;
|
|
370
|
+
struct ggml_tensor * enc_pre_conv_3_b = nullptr;
|
|
371
|
+
struct ggml_tensor * enc_pre_conv_5_w = nullptr;
|
|
372
|
+
struct ggml_tensor * enc_pre_conv_5_b = nullptr;
|
|
373
|
+
struct ggml_tensor * enc_pre_conv_6_w = nullptr;
|
|
374
|
+
struct ggml_tensor * enc_pre_conv_6_b = nullptr;
|
|
375
|
+
|
|
376
|
+
std::vector<parakeet_layer_encoder> layers;
|
|
377
|
+
|
|
378
|
+
parakeet_prediction_network prediction;
|
|
379
|
+
|
|
380
|
+
parakeet_joint_network joint;
|
|
381
|
+
|
|
382
|
+
std::vector<uint32_t> tdt_durations;
|
|
383
|
+
|
|
384
|
+
std::vector<ggml_context *> ctxs;
|
|
385
|
+
|
|
386
|
+
std::vector<ggml_backend_buffer_t> buffers;
|
|
387
|
+
|
|
388
|
+
int n_loaded = 0;
|
|
389
|
+
std::map<std::string, struct ggml_tensor *> tensors;
|
|
390
|
+
};
|
|
391
|
+
|
|
392
|
+
// State tensors of one LSTM layer (the h and c states).
struct parakeet_lstm_state_layer {
    struct ggml_tensor * h_state = nullptr;
    struct ggml_tensor * c_state = nullptr;
};
|
|
396
|
+
|
|
397
|
+
struct parakeet_lstm_state {
|
|
398
|
+
std::vector<parakeet_lstm_state_layer> layer;
|
|
399
|
+
|
|
400
|
+
std::vector<uint8_t> ctx_buf;
|
|
401
|
+
|
|
402
|
+
ggml_backend_buffer_t buffer = nullptr;
|
|
403
|
+
};
|
|
404
|
+
|
|
405
|
+
struct parakeet_state {
|
|
406
|
+
int64_t t_sample_us = 0;
|
|
407
|
+
int64_t t_encode_us = 0;
|
|
408
|
+
int64_t t_decode_us = 0;
|
|
409
|
+
int64_t t_predict_us = 0;
|
|
410
|
+
int64_t t_predict_build_us = 0; // time spent building the prediction graph
|
|
411
|
+
int64_t t_predict_alloc_us = 0; // time spent in ggml_backend_sched_alloc_graph
|
|
412
|
+
int64_t t_predict_compute_us = 0; // time spent in ggml_graph_compute_helper
|
|
413
|
+
int64_t t_mel_us = 0;
|
|
414
|
+
|
|
415
|
+
int32_t n_sample = 0; // number of tokens sampled
|
|
416
|
+
int32_t n_encode = 0; // number of encoder calls
|
|
417
|
+
int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
|
|
418
|
+
int32_t n_predict = 0; // number of prediction network calls
|
|
419
|
+
int32_t n_fail_p = 0; // number of logprob threshold failures
|
|
420
|
+
int32_t n_fail_h = 0; // number of entropy threshold failures
|
|
421
|
+
|
|
422
|
+
parakeet_mel mel;
|
|
423
|
+
|
|
424
|
+
parakeet_batch batch;
|
|
425
|
+
|
|
426
|
+
int n_frames = 0;
|
|
427
|
+
|
|
428
|
+
std::vector<ggml_backend_t> backends;
|
|
429
|
+
|
|
430
|
+
parakeet_sched sched_encode;
|
|
431
|
+
parakeet_sched sched_decode;
|
|
432
|
+
|
|
433
|
+
// outputs from encoder stages
|
|
434
|
+
struct ggml_tensor * enc_out = nullptr;
|
|
435
|
+
struct ggml_tensor * pred_out = nullptr;
|
|
436
|
+
|
|
437
|
+
std::vector<uint8_t> enc_out_buf;
|
|
438
|
+
ggml_backend_buffer_t enc_out_buffer = nullptr;
|
|
439
|
+
|
|
440
|
+
std::vector<uint8_t> pred_out_buf;
|
|
441
|
+
ggml_backend_buffer_t pred_out_buffer = nullptr;
|
|
442
|
+
|
|
443
|
+
struct ggml_tensor * attn_mask = nullptr;
|
|
444
|
+
|
|
445
|
+
std::vector<float> inp_mel;
|
|
446
|
+
std::vector<float> inp_mask;
|
|
447
|
+
|
|
448
|
+
std::vector<float> logits;
|
|
449
|
+
|
|
450
|
+
std::vector<parakeet_segment> result_all;
|
|
451
|
+
|
|
452
|
+
std::vector<parakeet_token> decoded_tokens;
|
|
453
|
+
std::vector<parakeet_token_data> decoded_token_data;
|
|
454
|
+
|
|
455
|
+
std::string path_model;
|
|
456
|
+
|
|
457
|
+
int32_t n_audio_ctx = 0;
|
|
458
|
+
int32_t sched_encode_n_audio_ctx = 0;
|
|
459
|
+
|
|
460
|
+
parakeet_lstm_state lstm_state;
|
|
461
|
+
};
|
|
462
|
+
|
|
463
|
+
// FFT cache for mel spectrogram computation:
// precomputed sine/cosine tables and window functions for a given FFT size.
struct parakeet_mel_cache {
    int n_fft = 0;

    // In FFT, we frequently use sine and cosine operations with the same values.
    // We can use precalculated values to speed up the process.
    std::vector<float> sin_vals;
    std::vector<float> cos_vals;

    // Hann window (Use cosf to eliminate difference)
    // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
    // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
    std::vector<float> hann_window;

    // Window function from model (Parakeet uses actual window from training)
    // Not touched by init().
    std::vector<float> window;

    // Size all tables for `fft_size` and fill them: the sin/cos tables and a
    // periodic Hann window of length `fft_size`.
    void init(int fft_size) {
        n_fft = fft_size;

        sin_vals   .resize(n_fft);
        cos_vals   .resize(n_fft);
        hann_window.resize(n_fft);

        fill_sin_cos_table();
        fill_hann_window(n_fft, true, hann_window.data());
    }

    // sin_vals[k] = sin(2*pi*k/n_fft), cos_vals[k] = cos(2*pi*k/n_fft)
    void fill_sin_cos_table() {
        for (int k = 0; k < n_fft; k++) {
            const double theta = (2 * M_PI * k) / n_fft;
            sin_vals[k] = sinf(theta);
            cos_vals[k] = cosf(theta);
        }
    }

    // Write a Hann window of `length` samples to `output`.
    // A periodic window divides by `length`, a symmetric one by `length - 1`.
    void fill_hann_window(int length, bool periodic, float * output) {
        const int denom = periodic ? length : length - 1;
        for (int k = 0; k < length; k++) {
            output[k] = 0.5 * (1.0 - cosf((2.0 * M_PI * k) / denom));
        }
    }
};
|
|
508
|
+
|
|
509
|
+
struct parakeet_context {
|
|
510
|
+
int64_t t_load_us = 0;
|
|
511
|
+
int64_t t_start_us = 0;
|
|
512
|
+
|
|
513
|
+
ggml_type wtype = ggml_type::GGML_TYPE_F16;
|
|
514
|
+
ggml_type itype = ggml_type::GGML_TYPE_F16;
|
|
515
|
+
|
|
516
|
+
parakeet_context_params params;
|
|
517
|
+
|
|
518
|
+
parakeet_model model;
|
|
519
|
+
parakeet_vocab vocab;
|
|
520
|
+
|
|
521
|
+
parakeet_state * state = nullptr;
|
|
522
|
+
|
|
523
|
+
parakeet_mel_cache mel_cache;
|
|
524
|
+
|
|
525
|
+
std::string path_model;
|
|
526
|
+
};
|
|
527
|
+
|
|
528
|
+
struct parakeet_global {
|
|
529
|
+
// We save the log callback globally
|
|
530
|
+
ggml_log_callback log_callback = parakeet_log_callback_default;
|
|
531
|
+
void * log_callback_user_data = nullptr;
|
|
532
|
+
};
|
|
533
|
+
|
|
534
|
+
static parakeet_global g_state;
|
|
535
|
+
|
|
536
|
+
// UTF-8 encoding of U+2581, the marker used in place of whitespace in
// SentencePiece pieces.
static const std::string PARAKEET_SPM_SPACE = "\xE2\x96\x81";
|
|
537
|
+
|
|
538
|
+
// Length in bytes of the UTF-8 sequence introduced by lead byte `c`.
//
// Returns 2, 3 or 4 for the corresponding multi-byte lead bytes and 1 for
// everything else (ASCII, continuation bytes and invalid lead bytes).
static inline int utf8_codepoint_len(unsigned char c) {
    if (c >= 0xF0 && c <= 0xF7) {
        return 4; // 11110xxx
    }
    if (c >= 0xE0 && c <= 0xEF) {
        return 3; // 1110xxxx
    }
    if (c >= 0xC0 && c <= 0xDF) {
        return 2; // 110xxxxx
    }
    return 1;
}
|
|
545
|
+
|
|
546
|
+
// True when `piece` is one of the control pieces
// "<unk>", "<s>", "</s>" or "[BLANK]".
static bool is_sentencepiece_control(const std::string & piece) {
    static const char * const k_control_pieces[] = { "<unk>", "<s>", "</s>", "[BLANK]" };

    for (const char * control : k_control_pieces) {
        if (piece == control) {
            return true;
        }
    }

    return false;
}
|
|
549
|
+
|
|
550
|
+
// Convert `text` to its SentencePiece-style form: the result starts with the
// space marker (the SentencePiece dummy prefix) and every whitespace byte of
// `text` is replaced by the marker. All other bytes are copied unchanged.
static std::string sentencepiece_normalize(const std::string & text) {
    std::string out;
    out.reserve(text.size() + PARAKEET_SPM_SPACE.size());

    out.append(PARAKEET_SPM_SPACE); // SentencePiece dummy prefix

    for (size_t i = 0; i < text.size(); ++i) {
        // std::isspace requires a value representable as unsigned char
        const unsigned char byte = static_cast<unsigned char>(text[i]);

        if (std::isspace(byte)) {
            out.append(PARAKEET_SPM_SPACE);
        } else {
            out.push_back(static_cast<char>(byte));
        }
    }

    return out;
}
|
|
565
|
+
|
|
566
|
+
static std::string sentencepiece_piece_to_text(const std::string & piece, bool is_first_piece) {
|
|
567
|
+
if (is_sentencepiece_control(piece)) {
|
|
568
|
+
return "";
|
|
569
|
+
}
|
|
570
|
+
|
|
571
|
+
std::string text;
|
|
572
|
+
text.reserve(piece.size());
|
|
573
|
+
|
|
574
|
+
size_t pos = 0;
|
|
575
|
+
while (pos < piece.size()) {
|
|
576
|
+
if (piece.compare(pos, PARAKEET_SPM_SPACE.size(), PARAKEET_SPM_SPACE) == 0) {
|
|
577
|
+
if (!is_first_piece || !text.empty()) {
|
|
578
|
+
text += ' ';
|
|
579
|
+
}
|
|
580
|
+
pos += PARAKEET_SPM_SPACE.size();
|
|
581
|
+
continue;
|
|
582
|
+
}
|
|
583
|
+
|
|
584
|
+
text += piece[pos];
|
|
585
|
+
++pos;
|
|
586
|
+
}
|
|
587
|
+
|
|
588
|
+
return text;
|
|
589
|
+
}
|
|
590
|
+
|
|
591
|
+
|
|
592
|
+
static struct parakeet_batch parakeet_batch_init(int32_t n_tokens) {
|
|
593
|
+
parakeet_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, };
|
|
594
|
+
|
|
595
|
+
batch.token = (parakeet_token * ) malloc(sizeof(parakeet_token) * (n_tokens));
|
|
596
|
+
batch.i_time = (int32_t *) malloc(sizeof(int32_t) * (n_tokens));
|
|
597
|
+
batch.pos = (parakeet_pos *) malloc(sizeof(parakeet_pos) * (n_tokens));
|
|
598
|
+
batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * (n_tokens));
|
|
599
|
+
batch.seq_id = (parakeet_seq_id **) malloc(sizeof(parakeet_seq_id *) * (n_tokens + 1));
|
|
600
|
+
for (int i = 0; i < n_tokens; ++i) {
|
|
601
|
+
batch.seq_id[i] = (parakeet_seq_id *) malloc(sizeof(parakeet_seq_id));
|
|
602
|
+
}
|
|
603
|
+
batch.seq_id[n_tokens] = nullptr;
|
|
604
|
+
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
|
|
605
|
+
|
|
606
|
+
return batch;
|
|
607
|
+
}
|
|
608
|
+
|
|
609
|
+
static void parakeet_batch_free(struct parakeet_batch batch) {
|
|
610
|
+
if (batch.token) free(batch.token);
|
|
611
|
+
if (batch.i_time) free(batch.i_time);
|
|
612
|
+
if (batch.pos) free(batch.pos);
|
|
613
|
+
if (batch.n_seq_id) free(batch.n_seq_id);
|
|
614
|
+
if (batch.seq_id) {
|
|
615
|
+
for (int i = 0; batch.seq_id[i]; ++i) {
|
|
616
|
+
free(batch.seq_id[i]);
|
|
617
|
+
}
|
|
618
|
+
free(batch.seq_id);
|
|
619
|
+
}
|
|
620
|
+
if (batch.logits) free(batch.logits);
|
|
621
|
+
}
|
|
622
|
+
|
|
623
|
+
// Fill `batch` with `n_tokens` consecutive tokens of a single sequence.
//
//   tokens  - token ids to copy; may be null, in which case batch.token is
//             left untouched
//   n_past  - position of the first token; token i gets position n_past + i
//   seq_id  - sequence id assigned to every token
//
// Only the last token is flagged in `batch.logits`. An empty batch
// (n_tokens <= 0) has no last token, so no flag is written.
static void parakeet_batch_prep_legacy(parakeet_batch & batch, const parakeet_token * tokens, int n_tokens, int n_past, int seq_id) {
    batch.n_tokens = n_tokens;

    for (int i = 0; i < n_tokens; ++i) {
        if (tokens) {
            batch.token[i] = tokens[i];
        }
        batch.pos     [i]    = n_past + i;
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = seq_id;
        batch.logits  [i]    = 0;
    }

    // guard against writing logits[-1] when the batch is empty
    if (n_tokens > 0) {
        batch.logits[n_tokens - 1] = 1;
    }
}
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
static size_t parakeet_sched_size(struct parakeet_sched & allocr) {
|
|
639
|
+
size_t size = allocr.meta.size();
|
|
640
|
+
for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) {
|
|
641
|
+
ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i);
|
|
642
|
+
size += ggml_backend_sched_get_buffer_size(allocr.sched, backend);
|
|
643
|
+
}
|
|
644
|
+
return size;
|
|
645
|
+
}
|
|
646
|
+
|
|
647
|
+
static bool parakeet_sched_graph_init(struct parakeet_sched & allocr, std::vector<ggml_backend_t> backends, std::function<struct ggml_cgraph *()> && get_graph) {
|
|
648
|
+
auto & sched = allocr.sched;
|
|
649
|
+
auto & meta = allocr.meta;
|
|
650
|
+
|
|
651
|
+
sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), PARAKEET_MAX_NODES, false, true);
|
|
652
|
+
|
|
653
|
+
if (!sched) {
|
|
654
|
+
PARAKEET_LOG_ERROR("%s: failed to create scheduler\n", __func__);
|
|
655
|
+
return false;
|
|
656
|
+
}
|
|
657
|
+
|
|
658
|
+
meta.resize(ggml_tensor_overhead()*PARAKEET_MAX_NODES + ggml_graph_overhead());
|
|
659
|
+
|
|
660
|
+
if (!ggml_backend_sched_alloc_graph(sched, get_graph())) {
|
|
661
|
+
PARAKEET_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
|
|
662
|
+
ggml_backend_sched_free(sched);
|
|
663
|
+
sched = nullptr;
|
|
664
|
+
return false;
|
|
665
|
+
}
|
|
666
|
+
|
|
667
|
+
ggml_backend_sched_reset(sched);
|
|
668
|
+
|
|
669
|
+
return true;
|
|
670
|
+
}
|
|
671
|
+
|
|
672
|
+
static void parakeet_sched_free(struct parakeet_sched & sched) {
|
|
673
|
+
if (sched.sched) {
|
|
674
|
+
ggml_backend_sched_free(sched.sched);
|
|
675
|
+
sched.sched = nullptr;
|
|
676
|
+
}
|
|
677
|
+
|
|
678
|
+
sched.meta.clear();
|
|
679
|
+
}
|
|
680
|
+
|
|
681
|
+
|
|
682
|
+
template<typename T>
|
|
683
|
+
static void read_safe(parakeet_model_loader * loader, T & dest) {
|
|
684
|
+
loader->read(loader->context, &dest, sizeof(T));
|
|
685
|
+
BYTESWAP_VALUE(dest);
|
|
686
|
+
}
|
|
687
|
+
|
|
688
|
+
static bool parakeet_lstm_state_init(
|
|
689
|
+
struct parakeet_state & pstate,
|
|
690
|
+
ggml_backend_t backend,
|
|
691
|
+
int n_layer,
|
|
692
|
+
int n_pred_dim) {
|
|
693
|
+
parakeet_lstm_state & lstm_state = pstate.lstm_state;
|
|
694
|
+
|
|
695
|
+
lstm_state.ctx_buf.resize(ggml_tensor_overhead() * n_layer * 2);
|
|
696
|
+
lstm_state.layer.resize(n_layer);
|
|
697
|
+
|
|
698
|
+
struct ggml_init_params params = {
|
|
699
|
+
/*.mem_size =*/ lstm_state.ctx_buf.size(),
|
|
700
|
+
/*.mem_buffer =*/ lstm_state.ctx_buf.data(),
|
|
701
|
+
/*.no_alloc =*/ true,
|
|
702
|
+
};
|
|
703
|
+
|
|
704
|
+
struct ggml_context * ctx = ggml_init(params);
|
|
705
|
+
|
|
706
|
+
if (!ctx) {
|
|
707
|
+
PARAKEET_LOG_ERROR("%s: failed to allocate memory for the lstm states context\n", __func__);
|
|
708
|
+
return false;
|
|
709
|
+
}
|
|
710
|
+
|
|
711
|
+
|
|
712
|
+
for (int il = 0; il < n_layer; ++il) {
|
|
713
|
+
lstm_state.layer[il].h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim);
|
|
714
|
+
lstm_state.layer[il].c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim);
|
|
715
|
+
}
|
|
716
|
+
|
|
717
|
+
lstm_state.buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
|
|
718
|
+
if (!lstm_state.buffer) {
|
|
719
|
+
PARAKEET_LOG_ERROR("%s: failed to allocate memory for the lstm states\n", __func__);
|
|
720
|
+
return false;
|
|
721
|
+
}
|
|
722
|
+
|
|
723
|
+
ggml_backend_buffer_clear(lstm_state.buffer, 0);
|
|
724
|
+
|
|
725
|
+
ggml_free(ctx);
|
|
726
|
+
|
|
727
|
+
return true;
|
|
728
|
+
}
|
|
729
|
+
|
|
730
|
+
static bool parakeet_pred_state_init(
|
|
731
|
+
struct parakeet_state & pstate,
|
|
732
|
+
ggml_backend_t backend,
|
|
733
|
+
int n_pred_dim) {
|
|
734
|
+
pstate.pred_out_buf.resize(ggml_tensor_overhead());
|
|
735
|
+
|
|
736
|
+
struct ggml_init_params params = {
|
|
737
|
+
/*.mem_size =*/ pstate.pred_out_buf.size(),
|
|
738
|
+
/*.mem_buffer =*/ pstate.pred_out_buf.data(),
|
|
739
|
+
/*.no_alloc =*/ true,
|
|
740
|
+
};
|
|
741
|
+
|
|
742
|
+
struct ggml_context * ctx = ggml_init(params);
|
|
743
|
+
if (!ctx) {
|
|
744
|
+
PARAKEET_LOG_ERROR("%s: failed to allocate memory for pred tensor context\n", __func__);
|
|
745
|
+
return false;
|
|
746
|
+
}
|
|
747
|
+
|
|
748
|
+
pstate.pred_out = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim);
|
|
749
|
+
pstate.pred_out_buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
|
|
750
|
+
if (!pstate.pred_out_buffer) {
|
|
751
|
+
PARAKEET_LOG_ERROR("%s: failed to allocate memory for pred tensor\n", __func__);
|
|
752
|
+
ggml_free(ctx);
|
|
753
|
+
return false;
|
|
754
|
+
}
|
|
755
|
+
|
|
756
|
+
ggml_free(ctx);
|
|
757
|
+
|
|
758
|
+
return true;
|
|
759
|
+
}
|
|
760
|
+
|
|
761
|
+
static bool parakeet_enc_state_init(
|
|
762
|
+
struct parakeet_state & pstate,
|
|
763
|
+
ggml_backend_t backend,
|
|
764
|
+
int n_audio_state,
|
|
765
|
+
int n_frames_max) {
|
|
766
|
+
pstate.enc_out_buf.resize(ggml_tensor_overhead());
|
|
767
|
+
|
|
768
|
+
struct ggml_init_params params = {
|
|
769
|
+
/*.mem_size =*/ pstate.enc_out_buf.size(),
|
|
770
|
+
/*.mem_buffer =*/ pstate.enc_out_buf.data(),
|
|
771
|
+
/*.no_alloc =*/ true,
|
|
772
|
+
};
|
|
773
|
+
|
|
774
|
+
struct ggml_context * ctx = ggml_init(params);
|
|
775
|
+
if (!ctx) {
|
|
776
|
+
PARAKEET_LOG_ERROR("%s: failed to allocate memory for enc_out tensor context\n", __func__);
|
|
777
|
+
return false;
|
|
778
|
+
}
|
|
779
|
+
|
|
780
|
+
pstate.enc_out = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_frames_max);
|
|
781
|
+
pstate.enc_out_buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
|
|
782
|
+
if (!pstate.enc_out_buffer) {
|
|
783
|
+
PARAKEET_LOG_ERROR("%s: failed to allocate memory for enc_out tensor\n", __func__);
|
|
784
|
+
ggml_free(ctx);
|
|
785
|
+
return false;
|
|
786
|
+
}
|
|
787
|
+
|
|
788
|
+
ggml_free(ctx);
|
|
789
|
+
|
|
790
|
+
return true;
|
|
791
|
+
}
|
|
792
|
+
|
|
793
|
+
// Initialize the GPU backend selected by `params`.
//
// When `params.use_gpu` is set, the registered devices are scanned and the
// GPU/IGPU device whose running index equals `params.gpu_device` is chosen.
// Returns nullptr when no such device exists, when GPU use is disabled, or
// when the backend fails to initialize.
static ggml_backend_t parakeet_backend_init_gpu(const parakeet_context_params & params) {
    ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);

    ggml_backend_dev_t selected = nullptr;

    // with GPU use disabled there is nothing to scan
    const size_t n_dev = params.use_gpu ? ggml_backend_dev_count() : 0;

    int n_gpu_seen = 0; // number of GPU/IGPU devices encountered so far
    for (size_t i = 0; i < n_dev; ++i) {
        ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i);
        enum ggml_backend_dev_type dev_type = ggml_backend_dev_type(dev_cur);
        const char * dev_name = ggml_backend_dev_name(dev_cur);
        PARAKEET_LOG_INFO("%s: device %zu: %s (type: %d)\n", __func__, i, dev_name, dev_type);

        const bool is_gpu = dev_type == GGML_BACKEND_DEVICE_TYPE_GPU || dev_type == GGML_BACKEND_DEVICE_TYPE_IGPU;
        if (!is_gpu) {
            continue;
        }

        PARAKEET_LOG_INFO("%s: found GPU device %zu: %s (type: %d, cnt: %d)\n", __func__, i, dev_name, dev_type, n_gpu_seen);
        if (n_gpu_seen == params.gpu_device) {
            selected = dev_cur;
        }

        // stop once the requested index has been passed
        n_gpu_seen += 1;
        if (n_gpu_seen > params.gpu_device) {
            break;
        }
    }

    if (selected == nullptr) {
        PARAKEET_LOG_INFO("%s: no GPU found\n", __func__);
        return nullptr;
    }

    PARAKEET_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(selected));

    ggml_backend_t backend = ggml_backend_dev_init(selected, nullptr);
    if (backend == nullptr) {
        PARAKEET_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(selected));
    }

    return backend;
}
|
|
831
|
+
|
|
832
|
+
static std::vector<ggml_backend_t> parakeet_backend_init(const parakeet_context_params & params) {
|
|
833
|
+
std::vector<ggml_backend_t> result;
|
|
834
|
+
|
|
835
|
+
ggml_backend_t backend_gpu = parakeet_backend_init_gpu(params);
|
|
836
|
+
|
|
837
|
+
if (backend_gpu) {
|
|
838
|
+
result.push_back(backend_gpu);
|
|
839
|
+
}
|
|
840
|
+
|
|
841
|
+
// ACCEL backends
|
|
842
|
+
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
|
843
|
+
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
|
844
|
+
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
|
|
845
|
+
PARAKEET_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
|
|
846
|
+
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
|
|
847
|
+
if (!backend) {
|
|
848
|
+
PARAKEET_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
|
|
849
|
+
continue;
|
|
850
|
+
}
|
|
851
|
+
result.push_back(backend);
|
|
852
|
+
}
|
|
853
|
+
}
|
|
854
|
+
|
|
855
|
+
ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
|
|
856
|
+
if (backend_cpu == nullptr) {
|
|
857
|
+
throw std::runtime_error("failed to initialize CPU backend");
|
|
858
|
+
}
|
|
859
|
+
result.push_back(backend_cpu);
|
|
860
|
+
|
|
861
|
+
return result;
|
|
862
|
+
}
|
|
863
|
+
|
|
864
|
+
using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;
|
|
865
|
+
|
|
866
|
+
static buft_list_t make_buft_list(parakeet_context_params & params) {
|
|
867
|
+
// Prio order: GPU -> CPU Extra -> CPU
|
|
868
|
+
buft_list_t buft_list;
|
|
869
|
+
|
|
870
|
+
// GPU
|
|
871
|
+
if (params.use_gpu) {
|
|
872
|
+
int cnt = 0;
|
|
873
|
+
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
|
874
|
+
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
|
875
|
+
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_IGPU) {
|
|
876
|
+
if (cnt == params.gpu_device) {
|
|
877
|
+
auto * buft = ggml_backend_dev_buffer_type(dev);
|
|
878
|
+
if (buft) {
|
|
879
|
+
buft_list.emplace_back(dev, buft);
|
|
880
|
+
}
|
|
881
|
+
}
|
|
882
|
+
|
|
883
|
+
if (++cnt > params.gpu_device) {
|
|
884
|
+
break;
|
|
885
|
+
}
|
|
886
|
+
}
|
|
887
|
+
}
|
|
888
|
+
}
|
|
889
|
+
|
|
890
|
+
// CPU Extra
|
|
891
|
+
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
|
892
|
+
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
|
|
893
|
+
auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
|
|
894
|
+
ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
|
|
895
|
+
if (get_extra_bufts_fn) {
|
|
896
|
+
ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev);
|
|
897
|
+
while (extra_bufts && *extra_bufts) {
|
|
898
|
+
buft_list.emplace_back(cpu_dev, *extra_bufts);
|
|
899
|
+
++extra_bufts;
|
|
900
|
+
}
|
|
901
|
+
}
|
|
902
|
+
|
|
903
|
+
// CPU
|
|
904
|
+
buft_list.emplace_back(cpu_dev, ggml_backend_cpu_buffer_type());
|
|
905
|
+
|
|
906
|
+
return buft_list;
|
|
907
|
+
}
|
|
908
|
+
|
|
909
|
+
static bool weight_buft_supported(const parakeet_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
|
|
910
|
+
bool op_supported = true;
|
|
911
|
+
|
|
912
|
+
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU ||
|
|
913
|
+
ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_IGPU ||
|
|
914
|
+
(ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) {
|
|
915
|
+
// GPU and default CPU backend support all operators
|
|
916
|
+
op_supported = true;
|
|
917
|
+
} else {
|
|
918
|
+
switch (op) {
|
|
919
|
+
// The current extra_buffer_type implementations only support GGML_OP_MUL_MAT and GGML_OP_GET_ROWS
|
|
920
|
+
case GGML_OP_GET_ROWS:
|
|
921
|
+
case GGML_OP_MUL_MAT: {
|
|
922
|
+
ggml_init_params params = {
|
|
923
|
+
/*.mem_size =*/ 2 * ggml_tensor_overhead(),
|
|
924
|
+
/*.mem_buffer =*/ nullptr,
|
|
925
|
+
/*.no_alloc =*/ true,
|
|
926
|
+
};
|
|
927
|
+
|
|
928
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
|
929
|
+
if (!ctx_ptr) {
|
|
930
|
+
throw std::runtime_error("failed to create ggml context");
|
|
931
|
+
}
|
|
932
|
+
ggml_context * ctx = ctx_ptr.get();
|
|
933
|
+
|
|
934
|
+
ggml_tensor * op_tensor = nullptr;
|
|
935
|
+
|
|
936
|
+
if (op == GGML_OP_MUL_MAT) {
|
|
937
|
+
int64_t n_ctx = hparams.n_audio_ctx;
|
|
938
|
+
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
|
|
939
|
+
op_tensor = ggml_mul_mat(ctx, w, b);
|
|
940
|
+
} else if (op == GGML_OP_GET_ROWS) {
|
|
941
|
+
int64_t num_indices = 8;
|
|
942
|
+
ggml_tensor * indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices);
|
|
943
|
+
op_tensor = ggml_get_rows(ctx, w, indices);
|
|
944
|
+
}
|
|
945
|
+
|
|
946
|
+
// create a temporary dummy buffer for the weight so that supports_op can check the buffer type
|
|
947
|
+
GGML_ASSERT(w->buffer == nullptr);
|
|
948
|
+
w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
|
|
949
|
+
op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
|
|
950
|
+
ggml_backend_buffer_free(w->buffer);
|
|
951
|
+
w->buffer = nullptr;
|
|
952
|
+
break;
|
|
953
|
+
}
|
|
954
|
+
default: {
|
|
955
|
+
op_supported = false;
|
|
956
|
+
break;
|
|
957
|
+
}
|
|
958
|
+
};
|
|
959
|
+
}
|
|
960
|
+
|
|
961
|
+
return op_supported;
|
|
962
|
+
}
|
|
963
|
+
|
|
964
|
+
// Return the first buffer type of `buft_list` (which must not be empty) that
// supports weight `w` as operand of `op`, or nullptr if none does.
static ggml_backend_buffer_type_t select_weight_buft(const parakeet_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
    GGML_ASSERT(!buft_list.empty());

    // entries are ordered by priority, so the first match wins
    for (size_t i = 0; i < buft_list.size(); ++i) {
        ggml_backend_dev_t         dev  = buft_list[i].first;
        ggml_backend_buffer_type_t buft = buft_list[i].second;

        if (!weight_buft_supported(hparams, w, op, buft, dev)) {
            continue;
        }
        return buft;
    }

    return nullptr;
}
|
|
976
|
+
|
|
977
|
+
|
|
978
|
+
// load the model from a ggml file
|
|
979
|
+
//
|
|
980
|
+
|
|
981
|
+
// see the convert-parakeet-to-ggml.py script for details
|
|
982
|
+
//
|
|
983
|
+
static bool parakeet_model_load(struct parakeet_model_loader * loader, parakeet_context & wctx) {
|
|
984
|
+
PARAKEET_LOG_INFO("%s: loading model\n", __func__);
|
|
985
|
+
|
|
986
|
+
const int64_t t_start_us = ggml_time_us();
|
|
987
|
+
|
|
988
|
+
wctx.t_start_us = t_start_us;
|
|
989
|
+
|
|
990
|
+
auto & model = wctx.model;
|
|
991
|
+
auto & vocab = wctx.vocab;
|
|
992
|
+
|
|
993
|
+
// verify magic
|
|
994
|
+
{
|
|
995
|
+
uint32_t magic;
|
|
996
|
+
read_safe(loader, magic);
|
|
997
|
+
if (magic != GGML_FILE_MAGIC) {
|
|
998
|
+
PARAKEET_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
|
|
999
|
+
return false;
|
|
1000
|
+
}
|
|
1001
|
+
}
|
|
1002
|
+
|
|
1003
|
+
//load hparams
|
|
1004
|
+
parakeet_hparams hparams;
|
|
1005
|
+
{
|
|
1006
|
+
read_safe(loader, hparams.n_vocab);
|
|
1007
|
+
read_safe(loader, hparams.n_audio_ctx);
|
|
1008
|
+
read_safe(loader, hparams.n_audio_state);
|
|
1009
|
+
read_safe(loader, hparams.n_audio_head);
|
|
1010
|
+
read_safe(loader, hparams.n_audio_layer);
|
|
1011
|
+
read_safe(loader, hparams.n_mels);
|
|
1012
|
+
read_safe(loader, hparams.ftype);
|
|
1013
|
+
read_safe(loader, hparams.n_fft);
|
|
1014
|
+
read_safe(loader, hparams.subsampling_factor);
|
|
1015
|
+
read_safe(loader, hparams.n_subsampling_channels);
|
|
1016
|
+
read_safe(loader, hparams.n_conv_kernel);
|
|
1017
|
+
read_safe(loader, hparams.n_pred_dim);
|
|
1018
|
+
read_safe(loader, hparams.n_pred_layers);
|
|
1019
|
+
read_safe(loader, hparams.n_tdt_durations);
|
|
1020
|
+
read_safe(loader, hparams.n_max_tokens);
|
|
1021
|
+
|
|
1022
|
+
hparams.arch = PARAKEET_ARCH_TDT;
|
|
1023
|
+
wctx.model.hparams = hparams;
|
|
1024
|
+
|
|
1025
|
+
const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;
|
|
1026
|
+
|
|
1027
|
+
hparams.ftype %= GGML_QNT_VERSION_FACTOR;
|
|
1028
|
+
|
|
1029
|
+
// for the big tensors, we have the option to store the data in 16-bit floats or quantized
|
|
1030
|
+
// in order to save memory and also to speed up the computation
|
|
1031
|
+
wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) hparams.ftype);
|
|
1032
|
+
if (wctx.wtype == GGML_TYPE_COUNT) {
|
|
1033
|
+
PARAKEET_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, hparams.ftype);
|
|
1034
|
+
return false;
|
|
1035
|
+
}
|
|
1036
|
+
|
|
1037
|
+
const char* arch_name = hparams.arch == PARAKEET_ARCH_TDT ? "Parakeet TDT" : "unknown";
|
|
1038
|
+
PARAKEET_LOG_INFO("%s: arch = %s\n", __func__, arch_name);
|
|
1039
|
+
PARAKEET_LOG_INFO("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
|
|
1040
|
+
PARAKEET_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
|
|
1041
|
+
PARAKEET_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
|
|
1042
|
+
PARAKEET_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
|
|
1043
|
+
PARAKEET_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
|
|
1044
|
+
PARAKEET_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels);
|
|
1045
|
+
PARAKEET_LOG_INFO("%s: n_fft = %d\n", __func__, hparams.n_fft);
|
|
1046
|
+
PARAKEET_LOG_INFO("%s: eps = %f\n", __func__, hparams.eps);
|
|
1047
|
+
PARAKEET_LOG_INFO("%s: ftype = %d\n", __func__, hparams.ftype);
|
|
1048
|
+
PARAKEET_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr);
|
|
1049
|
+
PARAKEET_LOG_INFO("%s: subsampling_factor = %d\n", __func__, hparams.subsampling_factor);
|
|
1050
|
+
PARAKEET_LOG_INFO("%s: n_subsampling_channels = %d\n", __func__, hparams.n_subsampling_channels);
|
|
1051
|
+
PARAKEET_LOG_INFO("%s: n_conv_kernel = %d\n", __func__, hparams.n_conv_kernel);
|
|
1052
|
+
PARAKEET_LOG_INFO("%s: n_pred_dim = %d\n", __func__, hparams.n_pred_dim);
|
|
1053
|
+
PARAKEET_LOG_INFO("%s: n_pred_layers = %d\n", __func__, hparams.n_pred_layers);
|
|
1054
|
+
PARAKEET_LOG_INFO("%s: n_tdt_durations = %d\n", __func__, hparams.n_tdt_durations);
|
|
1055
|
+
PARAKEET_LOG_INFO("%s: n_max_tokens = %d\n", __func__, hparams.n_max_tokens);
|
|
1056
|
+
}
|
|
1057
|
+
|
|
1058
|
+
// load mel filters
|
|
1059
|
+
{
|
|
1060
|
+
auto & filters = wctx.model.filters;
|
|
1061
|
+
|
|
1062
|
+
read_safe(loader, filters.n_mel);
|
|
1063
|
+
read_safe(loader, filters.n_fb);
|
|
1064
|
+
|
|
1065
|
+
filters.data.resize(filters.n_mel * filters.n_fb);
|
|
1066
|
+
loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float));
|
|
1067
|
+
BYTESWAP_FILTERS(filters);
|
|
1068
|
+
}
|
|
1069
|
+
|
|
1070
|
+
// load window function
|
|
1071
|
+
{
|
|
1072
|
+
int32_t n_window = 0;
|
|
1073
|
+
read_safe(loader, n_window);
|
|
1074
|
+
|
|
1075
|
+
wctx.mel_cache.window.resize(n_window);
|
|
1076
|
+
loader->read(loader->context, wctx.mel_cache.window.data(), n_window * sizeof(float));
|
|
1077
|
+
|
|
1078
|
+
#ifdef GGML_BIG_ENDIAN
|
|
1079
|
+
for (auto & datum : wctx.mel_cache.window) {
|
|
1080
|
+
datum = byteswap(datum);
|
|
1081
|
+
}
|
|
1082
|
+
#endif
|
|
1083
|
+
|
|
1084
|
+
PARAKEET_LOG_INFO("%s: loaded window function with %d samples\n", __func__, n_window);
|
|
1085
|
+
}
|
|
1086
|
+
|
|
1087
|
+
// load TDT (Token and Duration Transducer) values
|
|
1088
|
+
{
|
|
1089
|
+
auto & tdt_durations = wctx.model.tdt_durations;
|
|
1090
|
+
tdt_durations.resize(hparams.n_tdt_durations);
|
|
1091
|
+
loader->read(loader->context, tdt_durations.data(), hparams.n_tdt_durations * sizeof(uint32_t));
|
|
1092
|
+
|
|
1093
|
+
PARAKEET_LOG_INFO("%s: loaded tdt_durations: [", __func__);
|
|
1094
|
+
for (const auto value : tdt_durations) {
|
|
1095
|
+
PARAKEET_LOG_INFO("%u ", value);
|
|
1096
|
+
}
|
|
1097
|
+
PARAKEET_LOG_INFO("]\n");
|
|
1098
|
+
}
|
|
1099
|
+
|
|
1100
|
+
// load vocab
|
|
1101
|
+
{
|
|
1102
|
+
int32_t n_vocab = 0;
|
|
1103
|
+
read_safe(loader, n_vocab);
|
|
1104
|
+
|
|
1105
|
+
std::string word;
|
|
1106
|
+
std::vector<char> tmp;
|
|
1107
|
+
|
|
1108
|
+
tmp.reserve(128);
|
|
1109
|
+
|
|
1110
|
+
for (int i = 0; i < n_vocab; i++) {
|
|
1111
|
+
uint32_t len;
|
|
1112
|
+
read_safe(loader, len);
|
|
1113
|
+
|
|
1114
|
+
if (len > 0) {
|
|
1115
|
+
tmp.resize(len);
|
|
1116
|
+
loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
|
|
1117
|
+
word.assign(&tmp[0], tmp.size());
|
|
1118
|
+
} else {
|
|
1119
|
+
PARAKEET_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
|
|
1120
|
+
word = "";
|
|
1121
|
+
}
|
|
1122
|
+
|
|
1123
|
+
vocab.token_to_id[word] = i;
|
|
1124
|
+
vocab.id_to_token[i] = word;
|
|
1125
|
+
vocab.max_token_length = std::max(vocab.max_token_length, word.size());
|
|
1126
|
+
}
|
|
1127
|
+
// Blank token for transducer is at index n_vocab (e.g. 8192), one past the last vocabulary entry
|
|
1128
|
+
int blank_id = n_vocab;
|
|
1129
|
+
vocab.token_blank = blank_id;
|
|
1130
|
+
vocab.id_to_token[blank_id] = "[BLANK]";
|
|
1131
|
+
vocab.token_to_id["[BLANK]"] = blank_id;
|
|
1132
|
+
|
|
1133
|
+
// Set special token IDs by looking them up in the loaded vocabulary
|
|
1134
|
+
// These are from the SentencePiece vocab file loaded above
|
|
1135
|
+
if (vocab.token_to_id.find("<unk>") != vocab.token_to_id.end()) {
|
|
1136
|
+
vocab.token_unk = vocab.token_to_id.at("<unk>");
|
|
1137
|
+
} else {
|
|
1138
|
+
vocab.token_unk = 0; // Fallback
|
|
1139
|
+
}
|
|
1140
|
+
|
|
1141
|
+
if (vocab.token_to_id.find("<s>") != vocab.token_to_id.end()) {
|
|
1142
|
+
vocab.token_bos = vocab.token_to_id.at("<s>");
|
|
1143
|
+
} else if (vocab.token_to_id.find("<|startoftranscript|>") != vocab.token_to_id.end()) {
|
|
1144
|
+
vocab.token_bos = vocab.token_to_id.at("<|startoftranscript|>");
|
|
1145
|
+
} else {
|
|
1146
|
+
vocab.token_bos = 0; // Fallback
|
|
1147
|
+
}
|
|
1148
|
+
|
|
1149
|
+
if (vocab.token_to_id.find("</s>") != vocab.token_to_id.end()) {
|
|
1150
|
+
vocab.token_eos = vocab.token_to_id.at("</s>");
|
|
1151
|
+
} else if (vocab.token_to_id.find("<|endoftext|>") != vocab.token_to_id.end()) {
|
|
1152
|
+
vocab.token_eos = vocab.token_to_id.at("<|endoftext|>");
|
|
1153
|
+
} else {
|
|
1154
|
+
vocab.token_eos = 0; // Fallback
|
|
1155
|
+
}
|
|
1156
|
+
|
|
1157
|
+
vocab.n_vocab = model.hparams.n_vocab;
|
|
1158
|
+
|
|
1159
|
+
PARAKEET_LOG_INFO("%s: loaded vocab with %d tokens (blank_id=%d, unk=%d, bos=%d, eos=%d)\n",
|
|
1160
|
+
__func__, n_vocab, blank_id, vocab.token_unk, vocab.token_bos, vocab.token_eos);
|
|
1161
|
+
}
|
|
1162
|
+
|
|
1163
|
+
const ggml_type wtype = wctx.wtype;
|
|
1164
|
+
|
|
1165
|
+
|
|
1166
|
+
const int n_audio_layer = hparams.n_audio_layer;
|
|
1167
|
+
|
|
1168
|
+
// Calculate tensor count: pre_encode (12) + encoder layers (29 per layer) + prediction (9) + joint (6)
|
|
1169
|
+
size_t n_tensors = 12 + (29 * n_audio_layer) + 9 + 6;
|
|
1170
|
+
|
|
1171
|
+
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
|
1172
|
+
auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
|
1173
|
+
auto it = ctx_map.find(buft);
|
|
1174
|
+
if (it == ctx_map.end()) {
|
|
1175
|
+
ggml_init_params params = {
|
|
1176
|
+
/*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
|
|
1177
|
+
/*.mem_buffer =*/ nullptr,
|
|
1178
|
+
/*.no_alloc =*/ true,
|
|
1179
|
+
};
|
|
1180
|
+
|
|
1181
|
+
ggml_context * ctx = ggml_init(params);
|
|
1182
|
+
if (!ctx) {
|
|
1183
|
+
throw std::runtime_error("failed to create ggml context");
|
|
1184
|
+
}
|
|
1185
|
+
|
|
1186
|
+
ctx_map[buft] = ctx;
|
|
1187
|
+
wctx.model.ctxs.emplace_back(ctx);
|
|
1188
|
+
|
|
1189
|
+
return ctx;
|
|
1190
|
+
}
|
|
1191
|
+
|
|
1192
|
+
return it->second;
|
|
1193
|
+
};
|
|
1194
|
+
|
|
1195
|
+
// Create a list of available bufts, in priority order
|
|
1196
|
+
buft_list_t buft_list = make_buft_list(wctx.params);
|
|
1197
|
+
|
|
1198
|
+
auto create_tensor = [&](parakeet_tensor type, ggml_tensor * meta, int layer = -1) -> ggml_tensor * {
|
|
1199
|
+
ggml_op op = PARAKEET_TENSOR_INFO.at(type);
|
|
1200
|
+
ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
|
|
1201
|
+
if (!buft) {
|
|
1202
|
+
throw std::runtime_error(format("failed to find a compatible buffer type for parakeet tensor %s",
|
|
1203
|
+
PARAKEET_TENSOR_NAMES.at(type)));
|
|
1204
|
+
}
|
|
1205
|
+
|
|
1206
|
+
ggml_context * ctx = get_ctx(buft);
|
|
1207
|
+
ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
|
|
1208
|
+
|
|
1209
|
+
std::string tensor_name;
|
|
1210
|
+
if (layer >= 0) {
|
|
1211
|
+
tensor_name = format(PARAKEET_TENSOR_NAMES.at(type), layer);
|
|
1212
|
+
} else {
|
|
1213
|
+
tensor_name = PARAKEET_TENSOR_NAMES.at(type);
|
|
1214
|
+
}
|
|
1215
|
+
|
|
1216
|
+
wctx.model.tensors[tensor_name] = tensor;
|
|
1217
|
+
|
|
1218
|
+
return tensor;
|
|
1219
|
+
};
|
|
1220
|
+
|
|
1221
|
+
// prepare tensors for the weights
|
|
1222
|
+
|
|
1223
|
+
ggml_init_params params = {
|
|
1224
|
+
/*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
|
|
1225
|
+
/*.mem_buffer =*/ nullptr,
|
|
1226
|
+
/*.no_alloc =*/ true,
|
|
1227
|
+
};
|
|
1228
|
+
|
|
1229
|
+
ggml_context * ctx = ggml_init(params);
|
|
1230
|
+
|
|
1231
|
+
const int n_audio_state = hparams.n_audio_state;
|
|
1232
|
+
|
|
1233
|
+
model.layers.resize(n_audio_layer);
|
|
1234
|
+
|
|
1235
|
+
// Encoder pre_encode
|
|
1236
|
+
const int n_subsampling_channels = hparams.n_subsampling_channels;
|
|
1237
|
+
const int n_pre_enc_features = (hparams.n_mels / hparams.subsampling_factor) * n_subsampling_channels;
|
|
1238
|
+
model.enc_pre_out_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_pre_enc_features, n_audio_state));
|
|
1239
|
+
ggml_set_name(model.enc_pre_out_w, "enc_pre_out_w");
|
|
1240
|
+
model.enc_pre_out_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));
|
|
1241
|
+
ggml_set_name(model.enc_pre_out_b, "enc_pre_out_b");
|
|
1242
|
+
|
|
1243
|
+
model.enc_pre_conv_0_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, n_subsampling_channels));
|
|
1244
|
+
ggml_set_name(model.enc_pre_conv_0_w, "enc_pre_conv_0_w");
|
|
1245
|
+
model.enc_pre_conv_0_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
|
|
1246
|
+
ggml_set_name(model.enc_pre_conv_0_b, "enc_pre_conv_0_b");
|
|
1247
|
+
|
|
1248
|
+
model.enc_pre_conv_2_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, n_subsampling_channels));
|
|
1249
|
+
ggml_set_name(model.enc_pre_conv_2_w, "enc_pre_conv_2_w");
|
|
1250
|
+
model.enc_pre_conv_2_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
|
|
1251
|
+
ggml_set_name(model.enc_pre_conv_2_b, "enc_pre_conv_2_b");
|
|
1252
|
+
|
|
1253
|
+
model.enc_pre_conv_3_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, n_subsampling_channels));
|
|
1254
|
+
ggml_set_name(model.enc_pre_conv_3_w, "enc_pre_conv_3_w");
|
|
1255
|
+
model.enc_pre_conv_3_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
|
|
1256
|
+
ggml_set_name(model.enc_pre_conv_3_b, "enc_pre_conv_3_b");
|
|
1257
|
+
|
|
1258
|
+
model.enc_pre_conv_5_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, n_subsampling_channels));
|
|
1259
|
+
ggml_set_name(model.enc_pre_conv_5_w, "enc_pre_conv_5_w");
|
|
1260
|
+
model.enc_pre_conv_5_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
|
|
1261
|
+
ggml_set_name(model.enc_pre_conv_5_b, "enc_pre_conv_5_b");
|
|
1262
|
+
|
|
1263
|
+
model.enc_pre_conv_6_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, n_subsampling_channels));
|
|
1264
|
+
ggml_set_name(model.enc_pre_conv_6_w, "enc_pre_conv_6_w");
|
|
1265
|
+
model.enc_pre_conv_6_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
|
|
1266
|
+
ggml_set_name(model.enc_pre_conv_6_b, "enc_pre_conv_6_b");
|
|
1267
|
+
|
|
1268
|
+
// Encoder layers
|
|
1269
|
+
for (int i = 0; i < n_audio_layer; ++i) {
|
|
1270
|
+
auto & layer = model.layers[i];
|
|
1271
|
+
|
|
1272
|
+
// Feed forward 1
|
|
1273
|
+
layer.norm_ff1_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
|
1274
|
+
layer.norm_ff1_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
|
1275
|
+
layer.ff1_linear1_w = create_tensor(PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i);
|
|
1276
|
+
ggml_format_name(layer.ff1_linear1_w, "enc_%d_ff1_linear1_w", i);
|
|
1277
|
+
layer.ff1_linear2_w = create_tensor(PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i);
|
|
1278
|
+
ggml_format_name(layer.ff1_linear2_w, "enc_%d_ff1_linear2_w", i);
|
|
1279
|
+
|
|
1280
|
+
// Convolution module
|
|
1281
|
+
layer.norm_conv_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
|
1282
|
+
ggml_format_name(layer.norm_conv_w, "enc_%d_norm_conv_w", i);
|
|
1283
|
+
layer.norm_conv_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
|
1284
|
+
ggml_format_name(layer.norm_conv_b, "enc_%d_norm_conv_b", i);
|
|
1285
|
+
layer.conv_pw1_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 2*n_audio_state), i);
|
|
1286
|
+
ggml_format_name(layer.conv_pw1_w, "enc_%d_conv_pw1_w", i);
|
|
1287
|
+
layer.conv_dw_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_conv_kernel, n_audio_state), i);
|
|
1288
|
+
ggml_format_name(layer.conv_dw_w, "enc_%d_conv_dw_w", i);
|
|
1289
|
+
layer.conv_bn_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
|
1290
|
+
ggml_format_name(layer.conv_bn_w, "enc_%d_conv_bn_w", i);
|
|
1291
|
+
layer.conv_bn_b = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
|
1292
|
+
ggml_format_name(layer.conv_bn_b, "enc_%d_conv_bn_b", i);
|
|
1293
|
+
layer.conv_bn_mean = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_MEAN, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
|
1294
|
+
layer.conv_bn_var = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_VAR, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
|
1295
|
+
ggml_format_name(layer.conv_bn_var, "enc_%d_conv_bn_var", i);
|
|
1296
|
+
layer.conv_bn_num_batches = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), i);
|
|
1297
|
+
layer.conv_pw2_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
|
|
1298
|
+
ggml_format_name(layer.conv_pw2_w, "enc_%d_conv_pw2_w", i);
|
|
1299
|
+
|
|
1300
|
+
// Self attention
|
|
1301
|
+
layer.norm_attn_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
|
1302
|
+
layer.norm_attn_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
|
1303
|
+
layer.attn_pos_bias_u = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_audio_state / hparams.n_audio_head, hparams.n_audio_head), i);
|
|
1304
|
+
layer.attn_pos_bias_v = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_audio_state / hparams.n_audio_head, hparams.n_audio_head), i);
|
|
1305
|
+
layer.attn_q_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
|
|
1306
|
+
layer.attn_k_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
|
|
1307
|
+
layer.attn_v_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
|
|
1308
|
+
layer.attn_out_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
|
|
1309
|
+
layer.attn_pos_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
|
|
1310
|
+
ggml_format_name(layer.attn_pos_w, "enc_%d_attn_pos_w", i);
|
|
1311
|
+
|
|
1312
|
+
// Feed forward 2
|
|
1313
|
+
layer.norm_ff2_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
|
1314
|
+
layer.norm_ff2_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
|
1315
|
+
layer.ff2_linear1_w = create_tensor(PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i);
|
|
1316
|
+
layer.ff2_linear2_w = create_tensor(PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i);
|
|
1317
|
+
|
|
1318
|
+
// Output norm
|
|
1319
|
+
layer.norm_out_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
|
1320
|
+
layer.norm_out_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
|
1321
|
+
}
|
|
1322
|
+
|
|
1323
|
+
// Prediction network (decoder)
|
|
1324
|
+
const int dec_hidden = hparams.n_pred_dim;
|
|
1325
|
+
const int n_pred_embed = hparams.n_vocab + 1; // vocab + blank token
|
|
1326
|
+
const int n_lstm_gates = 4 * dec_hidden; // 4 LSTM gates
|
|
1327
|
+
const int n_joint_out = hparams.n_vocab + hparams.n_tdt_durations + 1; // vocab + durations + blank
|
|
1328
|
+
|
|
1329
|
+
// The prediction/joint hidden dimension is 640, which is not a multiple of the
|
|
1330
|
+
// K-quant block size (256). For K-quant models, we keep these tensors at F32.
|
|
1331
|
+
const int blck = ggml_blck_size(wtype);
|
|
1332
|
+
const ggml_type pred_wtype = (blck > 1 && dec_hidden % blck != 0) ? GGML_TYPE_F32 : wtype;
|
|
1333
|
+
const ggml_type join_wtype = pred_wtype;
|
|
1334
|
+
|
|
1335
|
+
model.prediction.embed_w = create_tensor(PARAKEET_TENSOR_PRED_EMBED_WEIGHT, ggml_new_tensor_2d(ctx, pred_wtype, dec_hidden, n_pred_embed));
|
|
1336
|
+
model.prediction.lstm_layer.resize(hparams.n_pred_layers);
|
|
1337
|
+
for (int i = 0; i < hparams.n_pred_layers; ++i) {
|
|
1338
|
+
auto & layer = model.prediction.lstm_layer[i];
|
|
1339
|
+
layer.ih_w = create_tensor(PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, ggml_new_tensor_2d(ctx, pred_wtype, dec_hidden, n_lstm_gates), i);
|
|
1340
|
+
ggml_format_name(layer.ih_w, "pred_%d_ih_w", i);
|
|
1341
|
+
|
|
1342
|
+
layer.hh_w = create_tensor(PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, ggml_new_tensor_2d(ctx, pred_wtype, dec_hidden, n_lstm_gates), i);
|
|
1343
|
+
ggml_format_name(layer.hh_w, "pred_%d_hh_w", i);
|
|
1344
|
+
|
|
1345
|
+
layer.b_h = create_tensor(PARAKEET_TENSOR_PRED_LSTM_BIAS_H, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_lstm_gates), i);
|
|
1346
|
+
ggml_format_name(layer.b_h, "pred_%d_b_h", i);
|
|
1347
|
+
}
|
|
1348
|
+
|
|
1349
|
+
// Joint network
|
|
1350
|
+
model.joint.pred_w = create_tensor(PARAKEET_TENSOR_JOINT_PRED_WEIGHT, ggml_new_tensor_2d(ctx, join_wtype, dec_hidden, dec_hidden));
|
|
1351
|
+
ggml_set_name(model.joint.pred_w, "pred_w");
|
|
1352
|
+
model.joint.pred_b = create_tensor(PARAKEET_TENSOR_JOINT_PRED_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dec_hidden));
|
|
1353
|
+
ggml_set_name(model.joint.pred_b, "pred_b");
|
|
1354
|
+
model.joint.enc_w = create_tensor(PARAKEET_TENSOR_JOINT_ENC_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, dec_hidden));
|
|
1355
|
+
ggml_set_name(model.joint.enc_w, "enc_w");
|
|
1356
|
+
model.joint.enc_b = create_tensor(PARAKEET_TENSOR_JOINT_ENC_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dec_hidden));
|
|
1357
|
+
ggml_set_name(model.joint.enc_b, "enc_b");
|
|
1358
|
+
model.joint.net_w = create_tensor(PARAKEET_TENSOR_JOINT_NET_WEIGHT, ggml_new_tensor_2d(ctx, join_wtype, dec_hidden, n_joint_out));
|
|
1359
|
+
ggml_set_name(model.joint.net_w, "net_w");
|
|
1360
|
+
model.joint.net_b = create_tensor(PARAKEET_TENSOR_JOINT_NET_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_joint_out));
|
|
1361
|
+
ggml_set_name(model.joint.net_b, "net_b");
|
|
1362
|
+
|
|
1363
|
+
ggml_free(ctx);
|
|
1364
|
+
|
|
1365
|
+
// allocate tensors in the backend buffers
|
|
1366
|
+
for (auto & p : ctx_map) {
|
|
1367
|
+
ggml_backend_buffer_type_t buft = p.first;
|
|
1368
|
+
ggml_context * ctx = p.second;
|
|
1369
|
+
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
|
1370
|
+
if (buf) {
|
|
1371
|
+
wctx.model.buffers.emplace_back(buf);
|
|
1372
|
+
|
|
1373
|
+
size_t size_main = ggml_backend_buffer_get_size(buf);
|
|
1374
|
+
PARAKEET_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
|
|
1375
|
+
}
|
|
1376
|
+
}
|
|
1377
|
+
|
|
1378
|
+
// load weights
|
|
1379
|
+
{
|
|
1380
|
+
size_t total_size = 0;
|
|
1381
|
+
|
|
1382
|
+
auto & tensors_map = wctx.model.tensors;
|
|
1383
|
+
int & n_loaded = wctx.model.n_loaded;
|
|
1384
|
+
|
|
1385
|
+
n_loaded = 0;
|
|
1386
|
+
|
|
1387
|
+
std::vector<char> read_buf;
|
|
1388
|
+
|
|
1389
|
+
while (true) {
|
|
1390
|
+
int32_t n_dims;
|
|
1391
|
+
int32_t length;
|
|
1392
|
+
int32_t ttype;
|
|
1393
|
+
|
|
1394
|
+
read_safe(loader, n_dims);
|
|
1395
|
+
read_safe(loader, length);
|
|
1396
|
+
read_safe(loader, ttype);
|
|
1397
|
+
|
|
1398
|
+
if (loader->eof(loader->context)) {
|
|
1399
|
+
break;
|
|
1400
|
+
}
|
|
1401
|
+
|
|
1402
|
+
int32_t nelements = 1;
|
|
1403
|
+
int32_t ne[4] = { 1, 1, 1, 1 };
|
|
1404
|
+
for (int i = 0; i < n_dims; ++i) {
|
|
1405
|
+
read_safe(loader, ne[i]);
|
|
1406
|
+
nelements *= ne[i];
|
|
1407
|
+
}
|
|
1408
|
+
|
|
1409
|
+
std::string name;
|
|
1410
|
+
std::vector<char> tmp(length); // create a buffer
|
|
1411
|
+
loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
|
|
1412
|
+
name.assign(&tmp[0], tmp.size());
|
|
1413
|
+
|
|
1414
|
+
if (tensors_map.find(name) == tensors_map.end()) {
|
|
1415
|
+
PARAKEET_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
|
1416
|
+
return false;
|
|
1417
|
+
}
|
|
1418
|
+
|
|
1419
|
+
auto tensor = tensors_map[name.data()];
|
|
1420
|
+
|
|
1421
|
+
if (ggml_nelements(tensor) != nelements) {
|
|
1422
|
+
PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
|
1423
|
+
PARAKEET_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
|
|
1424
|
+
__func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
|
|
1425
|
+
return false;
|
|
1426
|
+
}
|
|
1427
|
+
|
|
1428
|
+
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2] || tensor->ne[3] != ne[3]) {
|
|
1429
|
+
PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d, %d], expected [%d, %d, %d, %d]\n",
|
|
1430
|
+
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], (int) tensor->ne[3], ne[0], ne[1], ne[2], ne[3]);
|
|
1431
|
+
return false;
|
|
1432
|
+
}
|
|
1433
|
+
|
|
1434
|
+
const size_t bpe = ggml_type_size(ggml_type(ttype));
|
|
1435
|
+
|
|
1436
|
+
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
|
|
1437
|
+
PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
|
1438
|
+
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
|
1439
|
+
return false;
|
|
1440
|
+
}
|
|
1441
|
+
|
|
1442
|
+
if (ggml_backend_buffer_is_host(tensor->buffer)) {
|
|
1443
|
+
// for the CPU and Metal backend, we can read directly into the tensor
|
|
1444
|
+
loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
|
|
1445
|
+
BYTESWAP_TENSOR(tensor);
|
|
1446
|
+
} else {
|
|
1447
|
+
// read into a temporary buffer first, then copy to device memory
|
|
1448
|
+
read_buf.resize(ggml_nbytes(tensor));
|
|
1449
|
+
|
|
1450
|
+
loader->read(loader->context, read_buf.data(), read_buf.size());
|
|
1451
|
+
|
|
1452
|
+
ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
|
|
1453
|
+
}
|
|
1454
|
+
|
|
1455
|
+
total_size += ggml_nbytes(tensor);
|
|
1456
|
+
n_loaded++;
|
|
1457
|
+
}
|
|
1458
|
+
|
|
1459
|
+
PARAKEET_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6);
|
|
1460
|
+
|
|
1461
|
+
if (n_loaded == 0) {
|
|
1462
|
+
PARAKEET_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
|
|
1463
|
+
} else if (n_loaded != (int) tensors_map.size()) {
|
|
1464
|
+
PARAKEET_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, tensors_map.size(), n_loaded);
|
|
1465
|
+
return false;
|
|
1466
|
+
}
|
|
1467
|
+
}
|
|
1468
|
+
|
|
1469
|
+
auto & buffers = wctx.model.buffers;
|
|
1470
|
+
for (auto & buf : buffers) {
|
|
1471
|
+
ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
|
1472
|
+
}
|
|
1473
|
+
|
|
1474
|
+
wctx.t_load_us = ggml_time_us() - t_start_us;
|
|
1475
|
+
|
|
1476
|
+
return true;
|
|
1477
|
+
}
|
|
1478
|
+
|
|
1479
|
+
// conv subsampling + conformer encoder
|
|
1480
|
+
static struct ggml_cgraph * parakeet_build_graph_encode(parakeet_context & pctx, parakeet_state & pstate) {
|
|
1481
|
+
const auto & model = pctx.model;
|
|
1482
|
+
const auto & hparams = model.hparams;
|
|
1483
|
+
const int n_mel_time = pstate.n_audio_ctx > 0 ? pstate.n_audio_ctx : hparams.n_audio_ctx;
|
|
1484
|
+
const int n_mels = hparams.n_mels;
|
|
1485
|
+
const int n_layer = hparams.n_audio_layer;
|
|
1486
|
+
const int n_state = hparams.n_audio_state;
|
|
1487
|
+
const float fc_factor = 0.5f;
|
|
1488
|
+
|
|
1489
|
+
struct ggml_init_params params = {
|
|
1490
|
+
/*.mem_size =*/ pstate.sched_encode.meta.size(),
|
|
1491
|
+
/*.mem_buffer =*/ pstate.sched_encode.meta.data(),
|
|
1492
|
+
/*.no_alloc =*/ true,
|
|
1493
|
+
};
|
|
1494
|
+
|
|
1495
|
+
struct ggml_context * ctx0 = ggml_init(params);
|
|
1496
|
+
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false);
|
|
1497
|
+
|
|
1498
|
+
// Conv subsampling
|
|
1499
|
+
|
|
1500
|
+
// [freq, time]
|
|
1501
|
+
struct ggml_tensor * mel = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_mels, n_mel_time, 1, 1);
|
|
1502
|
+
ggml_set_name(mel, "mel");
|
|
1503
|
+
ggml_set_input(mel);
|
|
1504
|
+
|
|
1505
|
+
// [freq, time, channels, batch]
|
|
1506
|
+
struct ggml_tensor * cur = ggml_conv_2d(ctx0, model.enc_pre_conv_0_w, mel, 2, 2, 1, 1, 1, 1);
|
|
1507
|
+
cur = ggml_add(ctx0, cur, model.enc_pre_conv_0_b);
|
|
1508
|
+
ggml_set_name(cur, "pre_conv_0");
|
|
1509
|
+
|
|
1510
|
+
cur = ggml_relu(ctx0, cur);
|
|
1511
|
+
ggml_set_name(cur, "pre_conv_0_relu");
|
|
1512
|
+
|
|
1513
|
+
// [freq, time, channels, batch]
|
|
1514
|
+
cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_2_w, cur, 2, 2, 1, 1, 1, 1);
|
|
1515
|
+
cur = ggml_add(ctx0, cur, model.enc_pre_conv_2_b);
|
|
1516
|
+
ggml_set_name(cur, "pre_conv_2");
|
|
1517
|
+
|
|
1518
|
+
// [freq, time, channels, batch]
|
|
1519
|
+
cur = ggml_conv_2d(ctx0, model.enc_pre_conv_3_w, cur, 1, 1, 0, 0, 1, 1);
|
|
1520
|
+
cur = ggml_add(ctx0, cur, model.enc_pre_conv_3_b);
|
|
1521
|
+
ggml_set_name(cur, "pre_conv_3");
|
|
1522
|
+
|
|
1523
|
+
cur = ggml_relu(ctx0, cur);
|
|
1524
|
+
ggml_set_name(cur, "pre_conv_3_relu");
|
|
1525
|
+
|
|
1526
|
+
// [freq, time, channels, batch]
|
|
1527
|
+
cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_5_w, cur, 2, 2, 1, 1, 1, 1);
|
|
1528
|
+
ggml_set_name(cur, "pre_conv_5_direct");
|
|
1529
|
+
cur = ggml_add(ctx0, cur, model.enc_pre_conv_5_b);
|
|
1530
|
+
ggml_set_name(cur, "pre_conv_5");
|
|
1531
|
+
|
|
1532
|
+
// [freq, time, channels, batch]
|
|
1533
|
+
cur = ggml_conv_2d(ctx0, model.enc_pre_conv_6_w, cur, 1, 1, 0, 0, 1, 1);
|
|
1534
|
+
cur = ggml_add(ctx0, cur, model.enc_pre_conv_6_b);
|
|
1535
|
+
ggml_set_name(cur, "pre_conv_6");
|
|
1536
|
+
|
|
1537
|
+
cur = ggml_relu(ctx0, cur);
|
|
1538
|
+
ggml_set_name(cur, "pre_conv_6_relu");
|
|
1539
|
+
|
|
1540
|
+
// [freq, time, chan]
|
|
1541
|
+
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
|
1542
|
+
// [freq, chan, time]
|
|
1543
|
+
cur = ggml_cont(ctx0, cur);
|
|
1544
|
+
|
|
1545
|
+
const int n_freq = cur->ne[0]; // 16
|
|
1546
|
+
const int n_chan = cur->ne[1]; // 256
|
|
1547
|
+
const int n_frames = cur->ne[2]; // time
|
|
1548
|
+
|
|
1549
|
+
// [freq, chan, time, batch] -> [(freq * chan), time]
|
|
1550
|
+
cur = ggml_reshape_2d(ctx0, cur, n_freq * n_chan, n_frames);
|
|
1551
|
+
|
|
1552
|
+
cur = ggml_mul_mat(ctx0, model.enc_pre_out_w, cur);
|
|
1553
|
+
cur = ggml_add(ctx0, cur, model.enc_pre_out_b);
|
|
1554
|
+
|
|
1555
|
+
ggml_set_name(cur, "pre_enc_out");
|
|
1556
|
+
|
|
1557
|
+
// Encoder
|
|
1558
|
+
// cur: [n_state, n_enc_time]
|
|
1559
|
+
|
|
1560
|
+
const int n_time = cur->ne[1];
|
|
1561
|
+
const bool local_attn = n_time > PARAKEET_LOCAL_ATTN_THRESHOLD;
|
|
1562
|
+
const int att_left = local_attn ? PARAKEET_LOCAL_ATTN_WINDOW : n_time - 1;
|
|
1563
|
+
const int att_right = local_attn ? PARAKEET_LOCAL_ATTN_WINDOW : n_time - 1;
|
|
1564
|
+
const int window_size = local_attn ? att_left + att_right + 1 : 2 * n_time - 1;
|
|
1565
|
+
const int d_half = n_state / 2;
|
|
1566
|
+
const int mask_dim = local_attn ? window_size : n_time;
|
|
1567
|
+
|
|
1568
|
+
// mask [key, n_time]
|
|
1569
|
+
struct ggml_tensor * attn_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, mask_dim, n_time);
|
|
1570
|
+
ggml_set_name(attn_mask, "attn_mask");
|
|
1571
|
+
ggml_set_input(attn_mask);
|
|
1572
|
+
|
|
1573
|
+
struct ggml_tensor * local_mask = nullptr;
|
|
1574
|
+
if (local_attn) {
|
|
1575
|
+
const int chunk = att_left + att_right;
|
|
1576
|
+
local_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, chunk + window_size - 1, chunk);
|
|
1577
|
+
ggml_set_name(local_mask, "local_mask");
|
|
1578
|
+
ggml_set_input(local_mask);
|
|
1579
|
+
}
|
|
1580
|
+
|
|
1581
|
+
struct ggml_tensor * pos_freqs = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, d_half);
|
|
1582
|
+
ggml_set_name(pos_freqs, "pos_freqs");
|
|
1583
|
+
ggml_set_input(pos_freqs);
|
|
1584
|
+
|
|
1585
|
+
struct ggml_tensor * rel_positions = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, window_size);
|
|
1586
|
+
ggml_set_name(rel_positions, "rel_positions");
|
|
1587
|
+
ggml_set_input(rel_positions);
|
|
1588
|
+
|
|
1589
|
+
struct ggml_tensor * freqs = ggml_repeat_4d(ctx0, pos_freqs, d_half, window_size, 1, 1);
|
|
1590
|
+
struct ggml_tensor * theta = ggml_mul(ctx0, freqs, rel_positions);
|
|
1591
|
+
|
|
1592
|
+
struct ggml_tensor * sin_t = ggml_reshape_3d(ctx0, ggml_sin(ctx0, theta), 1, d_half, window_size);
|
|
1593
|
+
struct ggml_tensor * cos_t = ggml_reshape_3d(ctx0, ggml_cos(ctx0, theta), 1, d_half, window_size);
|
|
1594
|
+
// [n_state, window_size]
|
|
1595
|
+
struct ggml_tensor * pos_emb = ggml_reshape_2d(ctx0, ggml_cont(ctx0, ggml_concat(ctx0, sin_t, cos_t, 0)), n_state, window_size);
|
|
1596
|
+
ggml_set_name(pos_emb, "pos_emb");
|
|
1597
|
+
|
|
1598
|
+
for (int il = 0; il < n_layer; ++il) {
|
|
1599
|
+
const auto & layer = model.layers[il];
|
|
1600
|
+
|
|
1601
|
+
// FFN1
|
|
1602
|
+
{
|
|
1603
|
+
struct ggml_tensor * residual = cur;
|
|
1604
|
+
ggml_format_name(cur, "enc_%d_res", il);
|
|
1605
|
+
|
|
1606
|
+
// norm
|
|
1607
|
+
cur = ggml_norm(ctx0, cur, hparams.eps);
|
|
1608
|
+
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_ff1_w), layer.norm_ff1_b);
|
|
1609
|
+
ggml_format_name(cur, "enc_%d_ffn_norm_1", il);
|
|
1610
|
+
|
|
1611
|
+
// ffn_1
|
|
1612
|
+
cur = ggml_mul_mat(ctx0, layer.ff1_linear1_w, cur);
|
|
1613
|
+
cur = ggml_silu(ctx0, cur);
|
|
1614
|
+
ggml_format_name(cur, "enc_%d_silu", il);
|
|
1615
|
+
|
|
1616
|
+
cur = ggml_mul_mat(ctx0, layer.ff1_linear2_w, cur);
|
|
1617
|
+
ggml_format_name(cur, "enc_%d_ffn_1", il);
|
|
1618
|
+
|
|
1619
|
+
cur = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, fc_factor));
|
|
1620
|
+
ggml_format_name(cur, "enc_%d_res_ffn", il);
|
|
1621
|
+
}
|
|
1622
|
+
|
|
1623
|
+
// self attention block using relative positional encoding computed in graph.
|
|
1624
|
+
{
|
|
1625
|
+
// [feat, time_frames, 1, 1]
|
|
1626
|
+
struct ggml_tensor * residual = cur;
|
|
1627
|
+
|
|
1628
|
+
cur = ggml_norm(ctx0, cur, hparams.eps);
|
|
1629
|
+
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_attn_w), layer.norm_attn_b);
|
|
1630
|
+
ggml_format_name(cur, "enc_%d_attn_norm", il);
|
|
1631
|
+
|
|
1632
|
+
const int n_head = hparams.n_audio_head;
|
|
1633
|
+
const int d_head = n_state / n_head;
|
|
1634
|
+
|
|
1635
|
+
// [feat, time_frames, 1, 1]
|
|
1636
|
+
struct ggml_tensor * Q_cur = ggml_mul_mat(ctx0, layer.attn_q_w, cur);
|
|
1637
|
+
struct ggml_tensor * K_cur = ggml_mul_mat(ctx0, layer.attn_k_w, cur);
|
|
1638
|
+
struct ggml_tensor * V_cur = ggml_mul_mat(ctx0, layer.attn_v_w, cur);
|
|
1639
|
+
|
|
1640
|
+
Q_cur = ggml_reshape_3d(ctx0, Q_cur, d_head, n_head, n_time);
|
|
1641
|
+
K_cur = ggml_reshape_3d(ctx0, K_cur, d_head, n_head, n_time);
|
|
1642
|
+
V_cur = ggml_reshape_3d(ctx0, V_cur, d_head, n_head, n_time);
|
|
1643
|
+
|
|
1644
|
+
struct ggml_tensor * pos = ggml_mul_mat(ctx0, layer.attn_pos_w, pos_emb);
|
|
1645
|
+
pos = ggml_reshape_3d(ctx0, pos, d_head, n_head, window_size);
|
|
1646
|
+
pos = ggml_cont(ctx0, ggml_permute(ctx0, pos, 0, 2, 1, 3));
|
|
1647
|
+
|
|
1648
|
+
if (local_attn) {
|
|
1649
|
+
const int chunk = att_left + att_right;
|
|
1650
|
+
const int n_group = (n_time + chunk - 1) / chunk;
|
|
1651
|
+
const int n_time_padded = n_group * chunk;
|
|
1652
|
+
const int n_kv_chunk = chunk + window_size - 1;
|
|
1653
|
+
const int n_kv_dense = n_kv_chunk * n_group;
|
|
1654
|
+
const bool need_padding = n_time_padded > n_time;
|
|
1655
|
+
|
|
1656
|
+
Q_cur = ggml_cont(ctx0, ggml_permute(ctx0, Q_cur, 0, 2, 1, 3));
|
|
1657
|
+
K_cur = ggml_cont(ctx0, ggml_permute(ctx0, K_cur, 0, 2, 1, 3));
|
|
1658
|
+
V_cur = ggml_cont(ctx0, ggml_permute(ctx0, V_cur, 0, 2, 1, 3));
|
|
1659
|
+
|
|
1660
|
+
// content bias
|
|
1661
|
+
struct ggml_tensor * bias_u = ggml_reshape_3d(ctx0, layer.attn_pos_bias_u, d_head, 1, n_head);
|
|
1662
|
+
struct ggml_tensor * Q_u = ggml_add(ctx0, Q_cur, bias_u);
|
|
1663
|
+
|
|
1664
|
+
// position bias
|
|
1665
|
+
struct ggml_tensor * bias_v = ggml_reshape_3d(ctx0, layer.attn_pos_bias_v, d_head, 1, n_head);
|
|
1666
|
+
struct ggml_tensor * Q_v = ggml_add(ctx0, Q_cur, bias_v);
|
|
1667
|
+
|
|
1668
|
+
// right pad the time_frame.
|
|
1669
|
+
struct ggml_tensor * Q_u_padded = need_padding ?
|
|
1670
|
+
ggml_pad_ext(ctx0, Q_u, 0, 0, 0, n_time_padded - n_time, 0, 0, 0, 0) : Q_u;
|
|
1671
|
+
Q_u_padded = ggml_reshape_4d(ctx0, Q_u_padded, d_head, chunk, n_group, n_head);
|
|
1672
|
+
|
|
1673
|
+
// Add padding to front and back (for the first timeframe and the last timeframe).
|
|
1674
|
+
struct ggml_tensor * K_padded = ggml_pad_ext(ctx0, K_cur, 0, 0, att_left, att_right, 0, 0, 0, 0);
|
|
1675
|
+
|
|
1676
|
+
// pad time axis to match n_kv_dense if needed.
|
|
1677
|
+
if (n_kv_dense > K_padded->ne[1]) {
|
|
1678
|
+
K_padded = ggml_pad_ext(ctx0, K_padded, 0, 0, 0, n_kv_dense - K_padded->ne[1], 0, 0, 0, 0);
|
|
1679
|
+
}
|
|
1680
|
+
|
|
1681
|
+
// Create a 4d tensor where each group spans a wide window of
|
|
1682
|
+
// 512 keys (n_kv_chunk), but moving to the next group (nb[2])
|
|
1683
|
+
// only jumps forward by 256 frames (chunk * nb[1]). This creates
|
|
1684
|
+
// a 256 frame overlap, shared keys in RAM without copies.
|
|
1685
|
+
struct ggml_tensor * K_chunk = ggml_view_4d(ctx0, K_padded,
|
|
1686
|
+
d_head, n_kv_chunk, n_group, n_head,
|
|
1687
|
+
K_padded->nb[1],
|
|
1688
|
+
(size_t) chunk * K_padded->nb[1],
|
|
1689
|
+
K_padded->nb[2],
|
|
1690
|
+
0);
|
|
1691
|
+
K_chunk = ggml_cont(ctx0, K_chunk);
|
|
1692
|
+
|
|
1693
|
+
struct ggml_tensor * content_scores = ggml_mul_mat(ctx0, K_chunk, Q_u_padded);
|
|
1694
|
+
|
|
1695
|
+
// The above mul_mat operation, combined with K_chunk's overlapping
|
|
1696
|
+
// frames, produces a dense matrix. But some of the results in
|
|
1697
|
+
// this matrix were computed for keys that aren't part of that
|
|
1698
|
+
// query's window. So we shift each row to keep only the results
|
|
1699
|
+
// that we want.
|
|
1700
|
+
content_scores = ggml_view_4d(ctx0, content_scores,
|
|
1701
|
+
window_size, chunk, n_group, n_head,
|
|
1702
|
+
(size_t) (chunk + window_size) * content_scores->nb[0],
|
|
1703
|
+
content_scores->nb[2],
|
|
1704
|
+
content_scores->nb[3],
|
|
1705
|
+
0);
|
|
1706
|
+
content_scores = ggml_cont(ctx0, content_scores);
|
|
1707
|
+
|
|
1708
|
+
// ungrouping.
|
|
1709
|
+
content_scores = ggml_reshape_3d(ctx0, content_scores, window_size, n_time_padded, n_head);
|
|
1710
|
+
|
|
1711
|
+
// remove padding if padding was applied (truncating to n_time).
|
|
1712
|
+
if (need_padding) {
|
|
1713
|
+
content_scores = ggml_view_3d(ctx0, content_scores,
|
|
1714
|
+
window_size, n_time, n_head,
|
|
1715
|
+
content_scores->nb[1],
|
|
1716
|
+
content_scores->nb[2],
|
|
1717
|
+
0);
|
|
1718
|
+
}
|
|
1719
|
+
|
|
1720
|
+
struct ggml_tensor * rel_pos_scores = ggml_mul_mat(ctx0, pos, Q_v);
|
|
1721
|
+
|
|
1722
|
+
// attention_score = content similarity + relative position scores
|
|
1723
|
+
struct ggml_tensor * attn_scores = ggml_add(ctx0, content_scores, rel_pos_scores);
|
|
1724
|
+
|
|
1725
|
+
attn_scores = ggml_soft_max_ext(ctx0, attn_scores, attn_mask, 1.0f / std::sqrt(d_head), 0.0f);
|
|
1726
|
+
|
|
1727
|
+
// right pad the probabilities.
|
|
1728
|
+
struct ggml_tensor * probs_padded = need_padding ?
|
|
1729
|
+
ggml_pad_ext(ctx0, attn_scores, 0, 0, 0, n_time_padded - n_time, 0, 0, 0, 0) : attn_scores;
|
|
1730
|
+
|
|
1731
|
+
probs_padded = ggml_reshape_4d(ctx0, probs_padded, window_size, chunk, n_group, n_head);
|
|
1732
|
+
probs_padded = ggml_pad_ext(ctx0, probs_padded, 0, chunk, 0, 0, 0, 0, 0, 0);
|
|
1733
|
+
probs_padded = ggml_view_4d(ctx0, probs_padded,
|
|
1734
|
+
n_kv_chunk, chunk, n_group, n_head,
|
|
1735
|
+
(size_t) n_kv_chunk * probs_padded->nb[0],
|
|
1736
|
+
probs_padded->nb[2],
|
|
1737
|
+
probs_padded->nb[3],
|
|
1738
|
+
0);
|
|
1739
|
+
probs_padded = ggml_cont(ctx0, probs_padded);
|
|
1740
|
+
probs_padded = ggml_mul(ctx0, probs_padded, local_mask);
|
|
1741
|
+
|
|
1742
|
+
// Add padding to front and back (for the first timeframe and the last timeframe).
|
|
1743
|
+
struct ggml_tensor * V_padded = ggml_pad_ext(ctx0, V_cur, 0, 0, att_left, att_right, 0, 0, 0, 0);
|
|
1744
|
+
|
|
1745
|
+
// pad time axis to match n_kv_dense if needed.
|
|
1746
|
+
if (n_kv_dense > V_padded->ne[1]) {
|
|
1747
|
+
V_padded = ggml_pad_ext(ctx0, V_padded, 0, 0, 0, n_kv_dense - V_padded->ne[1], 0, 0, 0, 0);
|
|
1748
|
+
}
|
|
1749
|
+
|
|
1750
|
+
V_padded = ggml_cont(ctx0, ggml_transpose(ctx0, V_padded));
|
|
1751
|
+
|
|
1752
|
+
struct ggml_tensor * V_chunk = ggml_view_4d(ctx0, V_padded,
|
|
1753
|
+
n_kv_chunk, d_head, n_group, n_head,
|
|
1754
|
+
V_padded->nb[1],
|
|
1755
|
+
(size_t) chunk * V_padded->nb[0],
|
|
1756
|
+
V_padded->nb[2],
|
|
1757
|
+
0);
|
|
1758
|
+
V_chunk = ggml_cont(ctx0, V_chunk);
|
|
1759
|
+
|
|
1760
|
+
cur = ggml_mul_mat(ctx0, V_chunk, probs_padded);
|
|
1761
|
+
// ungroup.
|
|
1762
|
+
cur = ggml_reshape_3d(ctx0, cur, d_head, n_time_padded, n_head);
|
|
1763
|
+
// unpad
|
|
1764
|
+
if (need_padding) {
|
|
1765
|
+
cur = ggml_view_3d(ctx0, cur, d_head, n_time, n_head, cur->nb[1], cur->nb[2], 0);
|
|
1766
|
+
}
|
|
1767
|
+
|
|
1768
|
+
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3));
|
|
1769
|
+
cur = ggml_reshape_2d(ctx0, cur, n_state, n_time);
|
|
1770
|
+
cur = ggml_mul_mat(ctx0, layer.attn_out_w, cur);
|
|
1771
|
+
} else {
|
|
1772
|
+
struct ggml_tensor * Q_u = ggml_add(ctx0, Q_cur, layer.attn_pos_bias_u);
|
|
1773
|
+
ggml_format_name(Q_u, "enc_%d_attn_q_u", il);
|
|
1774
|
+
|
|
1775
|
+
struct ggml_tensor * K_prep = ggml_permute(ctx0, K_cur, 0, 2, 1, 3);
|
|
1776
|
+
struct ggml_tensor * Q_prep = ggml_permute(ctx0, Q_u, 0, 2, 1, 3);
|
|
1777
|
+
struct ggml_tensor * content_scores = ggml_mul_mat(ctx0, K_prep, Q_prep);
|
|
1778
|
+
ggml_format_name(content_scores, "enc_%d_attn_content_scores", il);
|
|
1779
|
+
|
|
1780
|
+
struct ggml_tensor * Q_v = ggml_add(ctx0, Q_cur, layer.attn_pos_bias_v);
|
|
1781
|
+
ggml_format_name(Q_v, "enc_%d_attn_q_v", il);
|
|
1782
|
+
|
|
1783
|
+
Q_v = ggml_permute(ctx0, Q_v, 0, 2, 1, 3);
|
|
1784
|
+
Q_v = ggml_cont(ctx0, Q_v);
|
|
1785
|
+
ggml_format_name(Q_v, "enc_%d_attn_q_v_perm", il);
|
|
1786
|
+
|
|
1787
|
+
struct ggml_tensor * rel_pos_scores = ggml_mul_mat(ctx0, pos, Q_v);
|
|
1788
|
+
ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos", il);
|
|
1789
|
+
|
|
1790
|
+
// Relative position shifting is performed in the following block.
|
|
1791
|
+
// Some more details on the operations performed below can be found here:
|
|
1792
|
+
// https://github.com/danbev/learning-ai/blob/main/notes/whisper/parakeet.md#relative-position-shift
|
|
1793
|
+
{
|
|
1794
|
+
const auto pos_window = rel_pos_scores->ne[0];
|
|
1795
|
+
const auto n_frame = rel_pos_scores->ne[1];
|
|
1796
|
+
const auto n_head_cur = rel_pos_scores->ne[2];
|
|
1797
|
+
|
|
1798
|
+
rel_pos_scores = ggml_pad(ctx0, rel_pos_scores, 1, 0, 0, 0);
|
|
1799
|
+
rel_pos_scores = ggml_roll(ctx0, rel_pos_scores, 1, 0, 0, 0);
|
|
1800
|
+
|
|
1801
|
+
rel_pos_scores = ggml_reshape_3d(ctx0, rel_pos_scores, n_frame, pos_window + 1, n_head_cur);
|
|
1802
|
+
ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_reshaped", il);
|
|
1803
|
+
|
|
1804
|
+
int center = pos_window / 2;
|
|
1805
|
+
size_t offset = rel_pos_scores->nb[0] * (center+1);
|
|
1806
|
+
|
|
1807
|
+
rel_pos_scores = ggml_view_3d(ctx0, rel_pos_scores,
|
|
1808
|
+
n_frame, pos_window, n_head_cur,
|
|
1809
|
+
(pos_window) * 4,
|
|
1810
|
+
rel_pos_scores->nb[2],
|
|
1811
|
+
offset);
|
|
1812
|
+
|
|
1813
|
+
ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_shifted", il);
|
|
1814
|
+
|
|
1815
|
+
rel_pos_scores = ggml_view_3d(ctx0, rel_pos_scores,
|
|
1816
|
+
content_scores->ne[0],
|
|
1817
|
+
content_scores->ne[1],
|
|
1818
|
+
rel_pos_scores->ne[2],
|
|
1819
|
+
rel_pos_scores->nb[1],
|
|
1820
|
+
rel_pos_scores->nb[2],
|
|
1821
|
+
0);
|
|
1822
|
+
rel_pos_scores = ggml_cont(ctx0, rel_pos_scores);
|
|
1823
|
+
ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_shifted_view", il);
|
|
1824
|
+
}
|
|
1825
|
+
|
|
1826
|
+
struct ggml_tensor * attn_scores = ggml_add(ctx0, content_scores, rel_pos_scores);
|
|
1827
|
+
ggml_format_name(attn_scores, "enc_%d_attn_scores", il);
|
|
1828
|
+
attn_scores = ggml_scale(ctx0, attn_scores, 1.0f / std::sqrt(d_head));
|
|
1829
|
+
attn_scores = ggml_add(ctx0, attn_scores, attn_mask);
|
|
1830
|
+
ggml_format_name(attn_scores, "enc_%d_attn_scores_scaled", il);
|
|
1831
|
+
|
|
1832
|
+
struct ggml_tensor * probs = ggml_soft_max(ctx0, attn_scores);
|
|
1833
|
+
ggml_format_name(probs, "enc_%d_attn_probs", il);
|
|
1834
|
+
|
|
1835
|
+
V_cur = ggml_cont(ctx0, ggml_permute(ctx0, V_cur, 1, 2, 0, 3));
|
|
1836
|
+
ggml_format_name(V_cur, "enc_%d_attn_v_cur", il);
|
|
1837
|
+
cur = ggml_mul_mat(ctx0, probs, V_cur);
|
|
1838
|
+
ggml_format_name(cur, "enc_%d_attn_inp", il);
|
|
1839
|
+
|
|
1840
|
+
cur = ggml_permute(ctx0, cur, 2, 0, 1, 3);
|
|
1841
|
+
cur = ggml_cont_2d(ctx0, cur, n_state, n_time);
|
|
1842
|
+
cur = ggml_mul_mat(ctx0, layer.attn_out_w, cur);
|
|
1843
|
+
}
|
|
1844
|
+
ggml_format_name(cur, "enc_%d_attn_out", il);
|
|
1845
|
+
|
|
1846
|
+
cur = ggml_add(ctx0, residual, cur);
|
|
1847
|
+
ggml_format_name(cur, "enc_%d_attn_res", il);
|
|
1848
|
+
}
|
|
1849
|
+
|
|
1850
|
+
// Convolution
|
|
1851
|
+
{
|
|
1852
|
+
struct ggml_tensor * residual = cur;
|
|
1853
|
+
ggml_format_name(cur, "enc_%d_residual_conv", il);
|
|
1854
|
+
|
|
1855
|
+
cur = ggml_norm(ctx0, cur, hparams.eps);
|
|
1856
|
+
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_conv_w), layer.norm_conv_b);
|
|
1857
|
+
ggml_format_name(cur, "enc_%d_norm_conv", il);
|
|
1858
|
+
|
|
1859
|
+
// pointwise 1d convolution: [1024, 138] -> [2048, 138]
|
|
1860
|
+
cur = ggml_mul_mat(ctx0, layer.conv_pw1_w, cur);
|
|
1861
|
+
ggml_format_name(cur, "enc_%d_conv_pw1", il);
|
|
1862
|
+
|
|
1863
|
+
{
|
|
1864
|
+
int64_t d = cur->ne[0] / 2;
|
|
1865
|
+
struct ggml_tensor * signal = ggml_view_2d(ctx0, cur, d, cur->ne[1], cur->nb[1], 0);
|
|
1866
|
+
struct ggml_tensor * gate = ggml_view_2d(ctx0, cur, d, cur->ne[1], cur->nb[1], d * cur->nb[0]);
|
|
1867
|
+
|
|
1868
|
+
cur = ggml_mul(ctx0, signal, ggml_sigmoid(ctx0, gate));
|
|
1869
|
+
ggml_format_name(cur, "enc_%d_conv_glu", il);
|
|
1870
|
+
}
|
|
1871
|
+
|
|
1872
|
+
cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
|
|
1873
|
+
|
|
1874
|
+
// use ggml_ssm_conv for f32 precision
|
|
1875
|
+
const int dw_pad = (hparams.n_conv_kernel - 1) / 2;
|
|
1876
|
+
cur = ggml_pad(ctx0, cur, dw_pad, 0, 0, 0);
|
|
1877
|
+
cur = ggml_roll(ctx0, cur, dw_pad, 0, 0, 0);
|
|
1878
|
+
cur = ggml_pad(ctx0, cur, dw_pad, 0, 0, 0);
|
|
1879
|
+
ggml_format_name(cur, "enc_%d_conv_dw_pad", il);
|
|
1880
|
+
|
|
1881
|
+
cur = ggml_ssm_conv(ctx0, cur, layer.conv_dw_w);
|
|
1882
|
+
ggml_format_name(cur, "enc_%d_conv_1d_dw", il);
|
|
1883
|
+
|
|
1884
|
+
cur = ggml_sub(ctx0, cur, layer.conv_bn_mean);
|
|
1885
|
+
struct ggml_tensor * std = ggml_sqrt(ctx0, layer.conv_bn_var);
|
|
1886
|
+
cur = ggml_div(ctx0, cur, std);
|
|
1887
|
+
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.conv_bn_w), layer.conv_bn_b);
|
|
1888
|
+
ggml_format_name(cur, "enc_%d_conv_bn", il);
|
|
1889
|
+
|
|
1890
|
+
cur = ggml_silu(ctx0, cur);
|
|
1891
|
+
ggml_format_name(cur, "enc_%d_conv_silu", il);
|
|
1892
|
+
|
|
1893
|
+
cur = ggml_mul_mat(ctx0, layer.conv_pw2_w, cur);
|
|
1894
|
+
ggml_format_name(cur, "enc_%d_conv_pw2", il);
|
|
1895
|
+
|
|
1896
|
+
cur = ggml_add(ctx0, residual, cur);
|
|
1897
|
+
ggml_format_name(cur, "enc_%d_conv_res", il);
|
|
1898
|
+
}
|
|
1899
|
+
|
|
1900
|
+
// FFN2
|
|
1901
|
+
{
|
|
1902
|
+
struct ggml_tensor * residual = cur;
|
|
1903
|
+
cur = ggml_norm(ctx0, cur, hparams.eps);
|
|
1904
|
+
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_ff2_w), layer.norm_ff2_b);
|
|
1905
|
+
ggml_format_name(cur, "enc_%d_ffn_norm_2", il);
|
|
1906
|
+
|
|
1907
|
+
cur = ggml_mul_mat(ctx0, layer.ff2_linear1_w, cur);
|
|
1908
|
+
cur = ggml_silu(ctx0, cur);
|
|
1909
|
+
cur = ggml_mul_mat(ctx0, layer.ff2_linear2_w, cur);
|
|
1910
|
+
cur = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, 0.5));
|
|
1911
|
+
ggml_format_name(cur, "enc_%d_ffn_res", il);
|
|
1912
|
+
}
|
|
1913
|
+
|
|
1914
|
+
cur = ggml_norm(ctx0, cur, hparams.eps);
|
|
1915
|
+
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_out_w), layer.norm_out_b);
|
|
1916
|
+
}
|
|
1917
|
+
|
|
1918
|
+
ggml_set_name(cur, "encoder_out");
|
|
1919
|
+
pstate.n_frames = cur->ne[1];
|
|
1920
|
+
|
|
1921
|
+
struct ggml_tensor * enc_out_view = ggml_view_2d(ctx0, pstate.enc_out, n_state, pstate.n_frames, pstate.enc_out->nb[1], 0);
|
|
1922
|
+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, cur, enc_out_view));
|
|
1923
|
+
|
|
1924
|
+
ggml_free(ctx0);
|
|
1925
|
+
|
|
1926
|
+
return gf;
|
|
1927
|
+
}
|
|
1928
|
+
|
|
1929
|
+
// Runs the encoder on one window of the mel spectrogram held in pstate.mel.
//
// The encoder graph is rebuilt, its inputs are resolved by name, the graph is
// allocated on pstate.sched_encode and, once every input has been uploaded,
// computed with n_threads threads.
//
// Inputs uploaded here:
//   "mel"           - frames [mel_offset, mel_offset + n_ctx) of pstate.mel, zero padded
//   "attn_mask"     - additive mask: 0 for attendable keys, -1e30 for masked ones
//   "local_mask"    - optional 0/1 band mask, only present in the local attention graph
//   "pos_freqs"     - sinusoid frequencies exp(-2k * ln(10000) / n_audio_state)
//   "rel_positions" - relative offsets, counting down from the window centre
//
// Returns false if a mandatory input is missing from the graph, if graph
// allocation or computation fails, or if abort_callback asks to stop (it is
// polled once, after the compute step). pstate.t_encode_us and pstate.n_encode
// are only updated when the compute step succeeded.
static bool parakeet_encode_internal(
        parakeet_context & pctx,
        parakeet_state & pstate,
        const int mel_offset,
        const int n_threads,
        ggml_abort_callback abort_callback,
        void * abort_callback_data) {
    const int64_t t_start_us = ggml_time_us();

    auto & sched = pstate.sched_encode.sched;

    ggml_cgraph * gf = parakeet_build_graph_encode(pctx, pstate);

    // Resolve the mandatory inputs before allocating: ggml_graph_get_tensor
    // yields a null pointer for a name that is not in the graph, and bailing
    // out here leaves the scheduler untouched.
    struct ggml_tensor * mel         = ggml_graph_get_tensor(gf, "mel");
    struct ggml_tensor * attn_mask   = ggml_graph_get_tensor(gf, "attn_mask");
    struct ggml_tensor * pos_freqs_t = ggml_graph_get_tensor(gf, "pos_freqs");
    struct ggml_tensor * rel_pos_t   = ggml_graph_get_tensor(gf, "rel_positions");
    if (mel == nullptr || attn_mask == nullptr || pos_freqs_t == nullptr || rel_pos_t == nullptr) {
        return false;
    }

    if (!ggml_backend_sched_alloc_graph(sched, gf)) {
        // should never happen as we pre-allocate the memory
        return false;
    }

    // set mel input
    {
        const auto & mel_inp = pstate.mel;
        const int n_ctx = pstate.n_audio_ctx > 0 ? pstate.n_audio_ctx : pctx.model.hparams.n_audio_ctx;

        assert(mel->type == GGML_TYPE_F32);
        assert(mel_inp.n_mel == pctx.model.hparams.n_mels);

        // Size and zero the staging buffer in one step; every frame that is
        // not overwritten by the copy below stays zero.
        pstate.inp_mel.assign((size_t) ggml_nelements(mel), 0.0f);

        // Clamp the window to [0, n_len] so that a negative or too large
        // offset can never address memory outside mel_inp.data.
        const int i0 = std::max(0, std::min(mel_offset, mel_inp.n_len));
        const int i1 = std::max(i0, std::min(mel_offset + n_ctx, mel_inp.n_len));

        // Never copy more than the staging buffer can hold.
        const size_t n_copy = std::min((size_t) (i1 - i0) * (size_t) mel_inp.n_mel, pstate.inp_mel.size());
        if (n_copy > 0) {
            memcpy(pstate.inp_mel.data(), mel_inp.data.data() + (size_t) i0 * mel_inp.n_mel, n_copy * sizeof(float));
        }

        ggml_backend_tensor_set(mel, pstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float));
    }

    // set attention mask
    {
        const int n_q = attn_mask->ne[1];
        const int n_k = attn_mask->ne[0];

        const int32_t subsampl_factor = pctx.model.hparams.subsampling_factor;
        // number of encoder frames backed by real (unpadded) audio
        const int n_tokens_real = (pstate.mel.n_len_org + subsampl_factor - 1) / subsampl_factor;

        // size_t arithmetic: n_q * n_k must not be evaluated in int
        std::vector<float> mask_data((size_t) n_q * (size_t) n_k);
        const float mask_value = -1e30f;

        // NOTE(review): a square mask is taken to mean full attention; a local
        // attention graph whose frame count equals its window size would be
        // classified as full attention as well - confirm this cannot happen.
        if (n_k == n_q) { // full attention
            for (int q = 0; q < n_q; ++q) {
                for (int k = 0; k < n_k; ++k) {
                    mask_data[(size_t) q * n_k + k] = (k >= n_tokens_real) ? mask_value : 0.0f;
                }
            }
        } else { // local attention
            const int att_left = n_k / 2;
            for (int q = 0; q < n_q; ++q) {
                for (int k = 0; k < n_k; ++k) {
                    // absolute key index addressed by window slot k of query q
                    const int key = q - att_left + k;
                    mask_data[(size_t) q * n_k + k] = (key >= 0 && key < n_tokens_real) ? 0.0f : mask_value;
                }
            }
        }
        ggml_backend_tensor_set(attn_mask, mask_data.data(), 0, mask_data.size() * sizeof(float));
    }

    // set local attention skew mask (the tensor only exists in the local attention graph)
    if (struct ggml_tensor * local_mask = ggml_graph_get_tensor(gf, "local_mask")) {
        const int n_k = local_mask->ne[0];
        const int n_q = local_mask->ne[1];

        std::vector<float> mask_data((size_t) n_q * (size_t) n_k);
        const int window_size = n_k - n_q + 1;
        for (int q = 0; q < n_q; ++q) {
            for (int k = 0; k < n_k; ++k) {
                // keep the band of window_size keys that starts at column q
                const int rel = k - q;
                mask_data[(size_t) q * n_k + k] = (rel >= 0 && rel < window_size) ? 1.0f : 0.0f;
            }
        }
        ggml_backend_tensor_set(local_mask, mask_data.data(), 0, mask_data.size() * sizeof(float));
    }

    // set positional frequency
    {
        const int d_half = pos_freqs_t->ne[0];
        const int n_state = pctx.model.hparams.n_audio_state;
        const float log_10000 = logf(10000.0f);
        std::vector<float> freqs(d_half);
        for (int k = 0; k < d_half; ++k) {
            freqs[k] = expf(-(float(k * 2) * log_10000 / float(n_state)));
        }
        ggml_backend_tensor_set(pos_freqs_t, freqs.data(), 0, freqs.size() * sizeof(float));
    }

    // set relative position offsets
    {
        const int window_size = rel_pos_t->ne[1];
        std::vector<float> pos(window_size);

        // The offsets count down from the centre of the window. A window of
        // 2 * PARAKEET_LOCAL_ATTN_WINDOW + 1 entries has its centre at
        // PARAKEET_LOCAL_ATTN_WINDOW, and any other window has it at
        // (window_size + 1) / 2 - 1; both equal (window_size - 1) / 2, so a
        // single formula covers the local and the full attention case.
        const int center = (window_size - 1) / 2;
        for (int t = 0; t < window_size; ++t) {
            pos[t] = float(center - t);
        }
        ggml_backend_tensor_set(rel_pos_t, pos.data(), 0, pos.size() * sizeof(float));
    }

    if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
        return false;
    }

    pstate.t_encode_us += ggml_time_us() - t_start_us;
    pstate.n_encode++;

    return !(abort_callback && abort_callback(abort_callback_data));
}
|
|
2056
|
+
|
|
2057
|
+
// Makes sure pstate.sched_encode holds an encoder scheduler built for
// n_audio_ctx mel frames.
//
// Nothing is done when a scheduler already exists for the requested size.
// Otherwise the old scheduler is freed, pstate.n_audio_ctx is switched to the
// new value, the encoder output tensor is re-created if it is missing or has
// fewer than ceil(n_audio_ctx / subsampling_factor) frames, and a new
// scheduler is initialised from a freshly built encoder graph.
//
// Returns true on success. On failure pstate.n_audio_ctx is restored,
// pstate.sched_encode_n_audio_ctx is reset to 0 and false is returned.
static bool parakeet_ensure_encode_sched(
        parakeet_context & pctx,
        parakeet_state & pstate,
        int n_audio_ctx) {
    if (pstate.sched_encode.sched && pstate.sched_encode_n_audio_ctx == n_audio_ctx) {
        return true;
    }

    parakeet_sched_free(pstate.sched_encode);

    const int32_t prev_n_audio_ctx = pstate.n_audio_ctx;
    pstate.n_audio_ctx = n_audio_ctx;

    // shared failure path: forget the cached size and undo the context switch
    const auto fail = [&]() {
        pstate.sched_encode_n_audio_ctx = 0;
        pstate.n_audio_ctx = prev_n_audio_ctx;
        return false;
    };

    const int subsampl_factor = pctx.model.hparams.subsampling_factor;
    const int n_frames_max = (n_audio_ctx + subsampl_factor - 1) / subsampl_factor;

    // enc_out is cleared below before it is re-created, so it may still be
    // null if an earlier re-initialisation failed; treat that like "too small"
    // instead of dereferencing it.
    if (pstate.enc_out == nullptr || n_frames_max > pstate.enc_out->ne[1]) {
        if (pstate.enc_out_buffer != nullptr) {
            ggml_backend_buffer_free(pstate.enc_out_buffer);
        }
        pstate.enc_out_buffer = nullptr;
        pstate.enc_out = nullptr;

        if (!parakeet_enc_state_init(pstate, pstate.backends[0], pctx.model.hparams.n_audio_state, n_frames_max)) {
            return fail();
        }
    }

    const bool ok = parakeet_sched_graph_init(pstate.sched_encode, pstate.backends,
            [&]() {
                return parakeet_build_graph_encode(pctx, pstate);
            });

    if (!ok) {
        return fail();
    }

    pstate.sched_encode_n_audio_ctx = n_audio_ctx;
    return true;
}
|
|
2098
|
+
|
|
2099
|
+
static struct ggml_tensor * parakeet_build_graph_lstm_layer(
|
|
2100
|
+
struct ggml_context * ctx0,
|
|
2101
|
+
struct ggml_cgraph * gf,
|
|
2102
|
+
struct ggml_tensor * x_t, // the current input token embedding
|
|
2103
|
+
struct ggml_tensor * w_ih, // input to hidden weights (4 weight tensors packed)
|
|
2104
|
+
struct ggml_tensor * w_hh, // hidden to hidden weights (4 weight tensors packed)
|
|
2105
|
+
struct ggml_tensor * b_h, // folded ih+hh bias (4 bias tensors packed)
|
|
2106
|
+
struct ggml_tensor * h_state, // this layers hidden state
|
|
2107
|
+
struct ggml_tensor * c_state, // this layers cell state
|
|
2108
|
+
int li) { // layer index (for tensor naming)
|
|
2109
|
+
|
|
2110
|
+
ggml_format_name(x_t, "lstm_layer_%d_x_t", li);
|
|
2111
|
+
ggml_format_name(h_state, "lstm_layer_%d_h_state", li);
|
|
2112
|
+
ggml_format_name(c_state, "lstm_layer_%d_c_state", li);
|
|
2113
|
+
|
|
2114
|
+
// The 4 gates (i, f, o, c) are packed in the same weight tensor.
|
|
2115
|
+
struct ggml_tensor * inp_gates = ggml_mul_mat(ctx0, w_ih, x_t);
|
|
2116
|
+
|
|
2117
|
+
// Hidden-to-Hidden Projections are also packed in the same weight tensor.
|
|
2118
|
+
// b_h holds the folded ih+hh bias (see parakeet_model_load), so it is
|
|
2119
|
+
// the only bias that needs to be added here.
|
|
2120
|
+
struct ggml_tensor * hid_gates = ggml_mul_mat(ctx0, w_hh, h_state);
|
|
2121
|
+
hid_gates = ggml_add(ctx0, hid_gates, b_h);
|
|
2122
|
+
|
|
2123
|
+
// Combine the input and hidden contributions of the gates.
|
|
2124
|
+
struct ggml_tensor * gates = ggml_add(ctx0, inp_gates, hid_gates);
|
|
2125
|
+
ggml_format_name(gates, "lstm_layer_%d_gates", li);
|
|
2126
|
+
|
|
2127
|
+
const int h_dim = h_state->ne[0];
|
|
2128
|
+
const size_t row_size = ggml_row_size(gates->type, h_dim);
|
|
2129
|
+
|
|
2130
|
+
// The gates are packed as [i, f, o, c] (reordered at convert time, see
|
|
2131
|
+
// parakeet_model_load), so the three sigmoid-gated outputs (i, f, o) are
|
|
2132
|
+
// contiguous and can be computed with a single ggml_sigmoid call.
|
|
2133
|
+
struct ggml_tensor * ifo = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, gates, 3 * h_dim, 0));
|
|
2134
|
+
ggml_format_name(ifo, "lstm_layer_%d_ifo", li);
|
|
2135
|
+
|
|
2136
|
+
// 1. Input Gate at time t.
|
|
2137
|
+
struct ggml_tensor * i_t = ggml_view_1d(ctx0, ifo, h_dim, 0 * row_size);
|
|
2138
|
+
ggml_format_name(i_t, "lstm_layer_%d_i_t", li);
|
|
2139
|
+
|
|
2140
|
+
// Forget gate.
|
|
2141
|
+
struct ggml_tensor * f_t = ggml_view_1d(ctx0, ifo, h_dim, 1 * row_size);
|
|
2142
|
+
ggml_format_name(f_t, "lstm_layer_%d_f_t", li);
|
|
2143
|
+
|
|
2144
|
+
// Output gate.
|
|
2145
|
+
struct ggml_tensor * o_t = ggml_view_1d(ctx0, ifo, h_dim, 2 * row_size);
|
|
2146
|
+
ggml_format_name(o_t, "lstm_layer_%d_o_t", li);
|
|
2147
|
+
|
|
2148
|
+
// Cell gate.
|
|
2149
|
+
struct ggml_tensor * c_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, gates, h_dim, 3 * row_size));
|
|
2150
|
+
ggml_format_name(c_t, "lstm_layer_%d_c_t", li);
|
|
2151
|
+
|
|
2152
|
+
// Calculate the new cell state.
|
|
2153
|
+
struct ggml_tensor * c_new = ggml_add(ctx0,
|
|
2154
|
+
ggml_mul(ctx0, f_t, c_state), // apply forget gate to cell state.
|
|
2155
|
+
ggml_mul(ctx0, i_t, c_t)); // apply input gate to cell gate.
|
|
2156
|
+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_new, c_state));
|
|
2157
|
+
|
|
2158
|
+
// Calculate the new hidden state.
|
|
2159
|
+
struct ggml_tensor * h_new = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_new));
|
|
2160
|
+
ggml_set_output(h_new);
|
|
2161
|
+
ggml_format_name(h_new, "lstm_layer_%d_h_new", li);
|
|
2162
|
+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_new, h_state));
|
|
2163
|
+
|
|
2164
|
+
return h_new;
|
|
2165
|
+
}
|
|
2166
|
+
|
|
2167
|
+
// Builds the prediction-network graph: token embedding lookup, the stacked
// LSTM layers and the projection into the joint network's hidden dimension.
// The projected result is copied into pstate.pred_out.
//
// The graph metadata lives in pstate.sched_decode.meta, which is why the
// temporary ggml context can be freed before the graph is returned.
// worst_case is accepted but not used.
static struct ggml_cgraph * parakeet_build_graph_prediction(
        parakeet_context & pctx,
        parakeet_state & pstate,
        const parakeet_batch & batch,
        bool worst_case) {
    GGML_UNUSED(worst_case);

    const auto & model   = pctx.model;
    const auto & hparams = model.hparams;

    struct ggml_init_params params = {
        /*.mem_size   =*/ pstate.sched_decode.meta.size(),
        /*.mem_buffer =*/ pstate.sched_decode.meta.data(),
        /*.no_alloc   =*/ true,
    };

    struct ggml_context * ctx0 = ggml_init(params);
    ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false);

    // token ids, filled in by the caller after graph allocation
    const int n_tokens = batch.n_tokens;
    struct ggml_tensor * token = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
    ggml_set_name(token, "token_inp");
    ggml_set_input(token);

    // embedding lookup feeds the first LSTM layer
    struct ggml_tensor * layer_io = ggml_get_rows(ctx0, model.prediction.embed_w, token);

    // each layer consumes the hidden state produced by the previous one
    for (int il = 0; il < hparams.n_pred_layers; ++il) {
        const auto & weights = model.prediction.lstm_layer[il];
        const auto & state   = pstate.lstm_state.layer[il];

        layer_io = parakeet_build_graph_lstm_layer(ctx0, gf, layer_io,
                weights.ih_w,
                weights.hh_w,
                weights.b_h,
                state.h_state,
                state.c_state,
                il);
    }

    struct ggml_tensor * pred_out = layer_io;
    ggml_format_name(pred_out, "lstm_pred_out");

    // project the LSTM output to the joint network hidden dimension
    struct ggml_tensor * pred = ggml_add(ctx0,
            ggml_mul_mat(ctx0, model.joint.pred_w, pred_out),
            model.joint.pred_b);
    ggml_set_name(pred, "h_pred");

    ggml_build_forward_expand(gf, ggml_cpy(ctx0, pred, pstate.pred_out));

    ggml_free(ctx0);

    return gf;
}
|
|
2219
|
+
|
|
2220
|
+
// Builds the joint-network graph for a single encoder frame.
//
// The frame batch.i_time[0] of pstate.enc_out is projected, added to the
// already projected prediction output in pstate.pred_out, passed through ReLU
// and the output layer, and finally turned into log-probabilities
// (log of softmax). Both "logits" and "log_probs" are marked as graph outputs.
// worst_case is accepted but not used.
static struct ggml_cgraph * parakeet_build_graph_joint(
        parakeet_context & pctx,
        parakeet_state & pstate,
        const parakeet_batch & batch,
        bool worst_case) {
    GGML_UNUSED(worst_case);

    const auto & model   = pctx.model;
    const auto & hparams = model.hparams;

    struct ggml_init_params params = {
        /*.mem_size   =*/ pstate.sched_decode.meta.size(),
        /*.mem_buffer =*/ pstate.sched_decode.meta.data(),
        /*.no_alloc   =*/ true,
    };

    struct ggml_context * ctx0 = ggml_init(params);
    ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false);

    // prediction-network output written by the prediction graph
    struct ggml_tensor * pred = pstate.pred_out;
    ggml_format_name(pred, "pred");

    // select one encoder frame: a 1d view starting at row t_idx of enc_out
    const int    t_idx        = batch.i_time[0];
    const size_t frame_offset = (size_t) t_idx * pstate.enc_out->nb[1];
    struct ggml_tensor * enc_out = ggml_view_1d(ctx0, pstate.enc_out, hparams.n_audio_state, frame_offset);
    ggml_format_name(enc_out, "enc_out_view");

    // project the encoder frame to the joint network hidden dimension
    struct ggml_tensor * enc = ggml_mul_mat(ctx0, model.joint.enc_w, enc_out);
    enc = ggml_add(ctx0, enc, model.joint.enc_b);
    ggml_set_name(enc, "enc");

    // combine both projections, then apply the non-linearity
    struct ggml_tensor * joint = ggml_add(ctx0, enc, pred);
    ggml_set_name(joint, "joint");
    joint = ggml_relu(ctx0, joint);

    // output layer
    struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.joint.net_w, joint);
    logits = ggml_add(ctx0, logits, model.joint.net_b);
    ggml_set_output(logits);
    ggml_set_name(logits, "logits");

    // log-probabilities
    struct ggml_tensor * log_probs = ggml_log(ctx0, ggml_soft_max(ctx0, logits));
    ggml_set_output(log_probs);
    ggml_format_name(log_probs, "log_probs");

    ggml_build_forward_expand(gf, log_probs);

    ggml_free(ctx0);

    return gf;
}
|
|
2271
|
+
|
|
2272
|
+
// Evaluates the prediction network for the tokens in batch.
//
// Builds the prediction graph, allocates it on pstate.sched_decode, uploads
// the token ids and computes the graph. Time spent building, allocating and
// computing is accumulated separately in pstate.t_predict_build_us,
// t_predict_alloc_us and t_predict_compute_us; the total goes to
// pstate.t_predict_us and pstate.n_predict is incremented.
//
// Returns false when allocation or computation fails, or when abort_callback
// (polled once at the end, if set) requests an abort; true otherwise.
static bool parakeet_predict(
        parakeet_context & pctx,
        parakeet_state & pstate,
        const parakeet_batch & batch,
        const int n_threads,
        ggml_abort_callback abort_callback,
        void * abort_callback_data) {

    const int n_tokens = batch.n_tokens;

    const int64_t t_total_begin = ggml_time_us();

    auto & sched = pstate.sched_decode.sched;

    // 1. build
    const int64_t t_build_begin = ggml_time_us();
    ggml_cgraph * gf = parakeet_build_graph_prediction(pctx, pstate, batch, false);
    pstate.t_predict_build_us += ggml_time_us() - t_build_begin;

    // 2. allocate
    const int64_t t_alloc_begin = ggml_time_us();
    const bool allocated = ggml_backend_sched_alloc_graph(sched, gf);
    if (!allocated) {
        // should never happen as we pre-allocate the memory
        return false;
    }
    pstate.t_predict_alloc_us += ggml_time_us() - t_alloc_begin;

    // 3. upload the token ids
    struct ggml_tensor * token_inp = ggml_graph_get_tensor(gf, "token_inp");
    ggml_backend_tensor_set(token_inp, batch.token, 0, n_tokens * ggml_element_size(token_inp));

    // 4. compute
    const int64_t t_compute_begin = ggml_time_us();
    const bool computed = ggml_graph_compute_helper(sched, gf, n_threads);
    if (!computed) {
        return false;
    }
    pstate.t_predict_compute_us += ggml_time_us() - t_compute_begin;

    pstate.t_predict_us += ggml_time_us() - t_total_begin;
    pstate.n_predict++;

    // an abort request turns an otherwise successful run into a failure
    if (abort_callback && abort_callback(abort_callback_data)) {
        return false;
    }
    return true;
}
|
|
2316
|
+
|
|
2317
|
+
// Run the joint network for `batch` and copy its output into pstate.logits.
//
// The joint graph is built, allocated on the decode scheduler and computed;
// the last node of that graph is then read back, one row of
// (n_vocab + n_tdt_durations + 1) floats per batch entry whose batch.logits
// flag is non-zero. Timing is accounted in pstate.t_decode_us / n_decode only
// for single-token batches.
//
// Returns false when graph allocation or computation fails, or when the
// abort callback (if any) asks to stop; true otherwise.
static bool parakeet_joint(
        parakeet_context & pctx,
        parakeet_state & pstate,
        const parakeet_batch & batch,
        const int n_threads,
        ggml_abort_callback abort_callback,
        void * abort_callback_data) {
    const int64_t t_begin_us = ggml_time_us();

    const auto & hp = pctx.model.hparams;
    const int batch_tokens = batch.n_tokens;

    auto & scheduler = pstate.sched_decode.sched;

    ggml_cgraph * graph = parakeet_build_graph_joint(pctx, pstate, batch, false);

    if (!ggml_backend_sched_alloc_graph(scheduler, graph)) {
        // should never happen as we pre-allocate the memory
        return false;
    }

    // the result tensor is the final node of the graph
    struct ggml_tensor * result = ggml_graph_node(graph, -1);

    if (!ggml_graph_compute_helper(scheduler, graph, n_threads)) {
        return false;
    }

    // vocabulary entries + duration entries + one for the blank token
    const int row_len = hp.n_vocab + hp.n_tdt_durations + 1;

    auto & dst = pstate.logits;
    dst.resize(batch_tokens * row_len);

    for (int row = 0; row < batch_tokens; row++) {
        if (batch.logits[row] != 0) {
            const int row_start = row_len * row;
            ggml_backend_tensor_get(result, dst.data() + row_start, sizeof(float)*row_start, sizeof(float)*row_len);
        }
    }

    if (batch.n_tokens == 1) {
        pstate.t_decode_us += ggml_time_us() - t_begin_us;
        pstate.n_decode++;
    }

    if (abort_callback && abort_callback(abort_callback_data)) {
        return false;
    }
    return true;
}
|
|
2368
|
+
|
|
2369
|
+
// Returns true when the text of `token_id` begins a new word, i.e. it starts
// with the SentencePiece meta-space "▁" (U+2581, UTF-8 bytes 0xE2 0x96 0x81)
// or with a plain ASCII underscore. Empty token strings are never word starts.
static bool is_word_start_token(parakeet_vocab & vocab, parakeet_token token_id) {
    const std::string & piece = vocab.id_to_token[token_id];
    if (piece.empty()) {
        return false;
    }

    const bool has_meta_space = piece.compare(0, 3, "\xE2\x96\x81") == 0;
    return has_meta_space || piece[0] == '_';
}
|
|
2379
|
+
|
|
2380
|
+
// Returns true when the text of `token_id`, after dropping one leading
// word-start marker (the 3-byte SentencePiece meta-space "▁" or an ASCII '_'),
// is exactly one character taken from the punctuation set below.
static bool is_punctuation_token(parakeet_vocab & vocab, parakeet_token token_id) {
    static const std::string punct_chars = ".,!?;:'\"-()[]{}";

    const std::string & piece = vocab.id_to_token[token_id];
    if (piece.empty()) {
        return false;
    }

    // number of leading bytes occupied by the word-start marker, if any
    size_t marker_len = 0;
    if (piece.compare(0, 3, "\xE2\x96\x81") == 0) {
        marker_len = 3; // the 3-byte UTF-8 meta-space
    } else if (piece[0] == '_') {
        marker_len = 1;
    }

    if (piece.size() - marker_len != 1) {
        return false;
    }
    return punct_chars.find(piece[marker_len]) != std::string::npos;
}
|
|
2397
|
+
|
|
2398
|
+
// Collapse punctuation timestamps to match the original Parakeet model.
|
|
2399
|
+
// Punctuations symbols like ',', '.' and others are not spoken words but the
|
|
2400
|
+
// model will still produce a duration for these tokens. But since these are
|
|
2401
|
+
// non-spoken we collapse the timestamps so that they don't have an time duration.
|
|
2402
|
+
static void refine_timestamps_tdt(parakeet_vocab & vocab, std::vector<parakeet_token_data> & tokens) {
|
|
2403
|
+
if (tokens.empty()) {
|
|
2404
|
+
return;
|
|
2405
|
+
}
|
|
2406
|
+
|
|
2407
|
+
int64_t last_non_punct_t1 = -1;
|
|
2408
|
+
|
|
2409
|
+
for (size_t i = 0; i < tokens.size(); ++i) {
|
|
2410
|
+
if (is_punctuation_token(vocab, tokens[i].id)) {
|
|
2411
|
+
if (last_non_punct_t1 >= 0) {
|
|
2412
|
+
tokens[i].t0 = last_non_punct_t1;
|
|
2413
|
+
tokens[i].t1 = last_non_punct_t1;
|
|
2414
|
+
}
|
|
2415
|
+
} else {
|
|
2416
|
+
last_non_punct_t1 = tokens[i].t1;
|
|
2417
|
+
}
|
|
2418
|
+
}
|
|
2419
|
+
}
|
|
2420
|
+
|
|
2421
|
+
// Assemble the parakeet_token_data record for a token emitted at `frame_index`.
//
// The probability `p` is exp(token_logit) divided by the sum of exp() over the
// first n_vocab_logits entries of pstate.logits; `plog` stores token_logit
// as-is. t0 / t1 are the start frame and (start + duration_value) frame, each
// multiplied by the model's subsampling factor.
static parakeet_token_data create_token_data(
        parakeet_context & pctx,
        parakeet_state & pstate,
        parakeet_token token_id,
        int duration_idx,
        int duration_value,
        int frame_index,
        float token_logit,
        int n_vocab_logits) {
    // normalisation term over the token part of the logits
    float denom = 0.0f;
    for (int v = 0; v < n_vocab_logits; ++v) {
        denom += expf(pstate.logits[v]);
    }

    const auto & hp = pctx.model.hparams;
    const int frame_end = frame_index + duration_value;

    parakeet_token_data out;
    out.id             = token_id;
    out.duration_idx   = duration_idx;
    out.duration_value = duration_value;
    out.frame_index    = frame_index;
    out.p              = expf(token_logit) / denom;
    out.plog           = token_logit;
    out.t0             = frame_index * hp.subsampling_factor;
    out.t1             = frame_end * hp.subsampling_factor;
    out.is_word_start  = is_word_start_token(pctx.vocab, token_id);

    return out;
}
|
|
2450
|
+
|
|
2451
|
+
// Greedy decoding loop over the encoder frames held in `pstate`.
//
// For every visited frame the joint network is evaluated and the arg-max token
// (over the first blank_id + 1 logits) and arg-max duration (over the following
// n_tdt_durations logits) are picked. A blank token only advances time; a
// non-blank token is appended to pstate.decoded_tokens / decoded_token_data,
// reported through params->new_token_callback (when set) and fed to the
// prediction network before time is advanced by the chosen duration.
//
// `batch` is used as scratch: entry 0 is overwritten for every step.
// `params` may be nullptr, in which case no abort or token callbacks are used.
//
// Returns false as soon as parakeet_predict() or parakeet_joint() returns
// false (failure or abort), true once all frames have been consumed.
static bool parakeet_decode(
        parakeet_context & pctx,
        parakeet_state & pstate,
        parakeet_batch & batch,
        const int n_threads,
        const parakeet_full_params * params = nullptr) {
    const auto & hparams = pctx.model.hparams;
    const auto & tdt_durations = pctx.model.tdt_durations;

    const int n_tdt_durations = hparams.n_tdt_durations;
    const int n_frames = pstate.n_frames;
    const int blank_id = pctx.vocab.token_blank;
    // token logits occupy indices [0, blank_id]; duration logits follow them
    const int n_vocab_logits = blank_id + 1;
    const int max_tokens_per_timestep = hparams.n_max_tokens;

    // time index into the encoder frame (current time frame)
    int t = 0;
    // number of symbols emitted for the current time frame
    int tokens_emitted = 0;

    // Start with the blank token (8192)
    parakeet_token last_token = blank_id;

    PARAKEET_LOG_DEBUG("parakeet_decode: starting decode with n_frames=%d\n", n_frames);

    batch.n_tokens = 1;
    batch.token[0] = last_token;
    batch.logits[0] = 1;
    batch.i_time[0] = 0;

    // run the prediction network for the initial blank token. This will
    // initialize the LSTM state and produce an initial hidden state that can
    // be used in the joint network below.
    if (!parakeet_predict(pctx, pstate, batch, n_threads,
                params ? params->abort_callback : nullptr,
                params ? params->abort_callback_user_data : nullptr)) {
        return false;
    }

    // process all time frames of the encoder output
    while (t < n_frames) {
        batch.n_tokens = 1;
        batch.i_time[0] = t;
        batch.logits[0] = 1;

        // Use the current encoder frame (t) and the output of the prediction to
        // generate probabilities for the next token and duration. batch.i_time
        // is used to select the correct frame from the encoder output.
        // The joint network outputs logits for all the tokens in the vocabulary
        // plus the blank token, and also n_duration logits for the duration
        // tokens which contain information about how many frames to skip/advance forward.
        if (!parakeet_joint(pctx, pstate, batch, n_threads,
                    params ? params->abort_callback : nullptr,
                    params ? params->abort_callback_user_data : nullptr)) {
            return false;
        }

        const int64_t t_start_sample_us = ggml_time_us();

        // find the best token (greedy).
        // TODO: implement beam search?
        int best_token = 0;
        float max_logit = -1e10f;
        for (int i = 0; i < n_vocab_logits; ++i) {
            if (pstate.logits[i] > max_logit) {
                max_logit = pstate.logits[i];
                best_token = i;
            }
        }

        // find the max index of the duration logits, and look up that index
        // value in the tdt_durations array to get the actual duration value.
        int best_duration_idx = 0;
        float best_duration_logit = -1e10f;
        for (int i = 0; i < n_tdt_durations; ++i) {
            if (pstate.logits[n_vocab_logits + i] > best_duration_logit) {
                best_duration_logit = pstate.logits[n_vocab_logits + i];
                best_duration_idx = i;
            }
        }
        // look up that max duration index value in the tdt_durations array to
        // get the actual duration value.
        int duration = tdt_durations[best_duration_idx];

        if (best_token == blank_id) {
            // a blank must always move time forward, otherwise the loop would stall
            if (duration == 0) {
                duration = 1;
            }
            // skip forward by duration time frames.
            t += duration;
            // reset symbols emitted counter
            tokens_emitted = 0;
            // continue without predicting.
            // (the sampling timers below are only updated for emitted tokens)
            continue;
        }

        // Emit non-blank token at current frame t.
        pstate.decoded_tokens.push_back(best_token);
        pstate.t_sample_us += ggml_time_us() - t_start_sample_us;
        pstate.n_sample++;

        parakeet_token_data token_data = create_token_data(
                pctx, pstate, best_token, best_duration_idx, duration, t,
                max_logit, n_vocab_logits);

        pstate.decoded_token_data.push_back(token_data);

        // Call token callback if registered (for real-time streaming)
        if (params && params->new_token_callback) {
            params->new_token_callback(&pctx, &pstate, &token_data, params->new_token_callback_user_data);
        }

        last_token = best_token;

        // advance predictor for the non-blank token.
        batch.token[0] = last_token;
        if (!parakeet_predict(pctx, pstate, batch, n_threads,
                    params ? params->abort_callback : nullptr,
                    params ? params->abort_callback_user_data : nullptr)) {
            return false;
        }

        // if duration greater than 0, continue looping over the encoder frames
        // and skip to the updated time frame (t).
        if (duration > 0) {
            t += duration;
            tokens_emitted = 0;
            continue;
        }

        // if duration is zero we stay on the current time frame.
        // The per-frame emission cap guarantees that time eventually advances.
        tokens_emitted++;
        if (tokens_emitted >= max_tokens_per_timestep) {
            t += 1; // forced blank/time advance behavior
            tokens_emitted = 0;
        }
    }

    return true;
}
|
|
2591
|
+
|
|
2592
|
+
// 500 -> 00:05.000
|
|
2593
|
+
// 6000 -> 01:00.000
|
|
2594
|
+
// naive Discrete Fourier Transform
|
|
2595
|
+
// input is real-valued
|
|
2596
|
+
// output is complex-valued
|
|
2597
|
+
static void dft(const float* in, int N, float* out, const parakeet_mel_cache & cache) {
|
|
2598
|
+
const int sin_cos_step = cache.n_fft / N;
|
|
2599
|
+
|
|
2600
|
+
for (int k = 0; k < N; k++) {
|
|
2601
|
+
float re = 0;
|
|
2602
|
+
float im = 0;
|
|
2603
|
+
|
|
2604
|
+
for (int n = 0; n < N; n++) {
|
|
2605
|
+
int idx = (k * n * sin_cos_step) % cache.n_fft; // t = 2*M_PI*k*n/N
|
|
2606
|
+
re += in[n]*cache.cos_vals[idx]; // cos(t)
|
|
2607
|
+
im -= in[n]*cache.sin_vals[idx]; // sin(t)
|
|
2608
|
+
}
|
|
2609
|
+
|
|
2610
|
+
out[k*2 + 0] = re;
|
|
2611
|
+
out[k*2 + 1] = im;
|
|
2612
|
+
}
|
|
2613
|
+
}
|
|
2614
|
+
|
|
2615
|
+
// Cooley-Tukey FFT
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
//
// Recursive radix-2 split; an odd N falls back to the naive dft().
//
// Both buffers double as scratch space for the recursion:
// - `in`  is written at in[N .. N + N/2) (the even/odd halves are staged there)
// - `out` is written at out[2*N .. 4*N) (the two half-size transforms)
// so callers must provide buffers that are larger than N / 2*N floats.
// The result is stored in out[0 .. 2*N) as interleaved (re, im) pairs.
static void fft(float* in, int N, float* out, const parakeet_mel_cache & cache) {
    // a single sample is its own transform
    if (N == 1) {
        out[0] = in[0];
        out[1] = 0;
        return;
    }

    const int half_N = N / 2;
    // odd length: cannot be halved, use the direct transform
    if (N - half_N*2 == 1) {
        dft(in, N, out, cache);
        return;
    }

    // stage the even-indexed samples right after the input and transform them
    float* even = in + N;
    for (int i = 0; i < half_N; ++i) {
        even[i]= in[2*i];
    }
    float* even_fft = out + 2 * N;
    fft(even, half_N, even_fft, cache);

    // the odd-indexed samples reuse the same staging area; this is fine because
    // the even half has already been transformed into even_fft
    float* odd = even;
    for (int i = 0; i < half_N; ++i) {
        odd[i] = in[2*i + 1];
    }
    float* odd_fft = even_fft + N;
    fft(odd, half_N, odd_fft, cache);

    // butterfly: combine the two half-size transforms with the twiddle factors
    const int sin_cos_step = cache.n_fft / N;
    for (int k = 0; k < half_N; k++) {
        int idx = k * sin_cos_step; // t = 2*M_PI*k/N
        float re = cache.cos_vals[idx]; // cos(t)
        float im = -cache.sin_vals[idx]; // -sin(t)

        float re_odd = odd_fft[2*k + 0];
        float im_odd = odd_fft[2*k + 1];

        out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
        out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;

        out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
        out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
    }
}
|
|
2662
|
+
|
|
2663
|
+
// Per-thread parameters for log_mel_spectrogram_worker_thread().
struct mel_worker_params {
    int ith;         // index of the first frame handled by this worker
    int window_size; // number of window coefficients applied to each frame
    int n_samples;   // number of samples in the (padded) input buffer
    int frame_size;  // FFT length of one frame
    int frame_step;  // hop between the starts of consecutive frames, in samples
    int n_threads;   // total number of workers; also the frame stride per worker
};
|
|
2671
|
+
|
|
2672
|
+
// Compute the natural-log mel spectrogram for a strided subset of frames.
//
// The worker handles frames params.ith, params.ith + n_threads, ... For each
// frame the window is applied to the samples, centred inside a zero-padded
// FFT buffer of frame_size floats, the power spectrum is computed and
// projected through the mel filter bank; log(sum + eps) is stored in
// mel.data[frame * mel.n_mel + mel_bin]. Frames past the available samples
// are filled with log(eps).
//
// mel.n_mel, mel.n_len and mel.data must already be sized by the caller.
static void log_mel_spectrogram_worker_thread(
        mel_worker_params params,
        const float * window_func,
        const std::vector<float> & samples,
        const parakeet_filters & filters,
        parakeet_mel & mel,
        const parakeet_mel_cache & cache) {
    // fft() uses the memory past the nominal input/output sizes as scratch
    std::vector<float> fft_in(params.frame_size * 2, 0.0);
    std::vector<float> fft_out(params.frame_size * 2 * 2 * 2);

    int n_fb = filters.n_fb; // number of frequency bins
    int i = params.ith;

    // make sure n_fb == 1 + (frame_size / 2), bin_0 to bin_nyquist
    assert(n_fb == 1 + (params.frame_size / 2));

    const double eps = 5.960464477539063e-08;

    // calculate FFT only when fft_in are not all zero
    for (; i < std::min(params.n_samples / params.frame_step + 1, mel.n_len); i += params.n_threads) {
        const int offset = i * params.frame_step;

        const int window_pad_left = (params.frame_size - params.window_size) / 2;

        // Zero-pad left
        std::fill(fft_in.begin(), fft_in.begin() + window_pad_left, 0.0f);

        // Apply windowed samples in the center
        // NOTE(review): the read below reaches samples[offset + window_pad_left + j],
        // while n_to_process is only limited by n_samples - offset. It stays in
        // bounds as long as offset + frame_size <= n_samples for every frame
        // below mel.n_len - confirm that this holds for all callers.
        const int n_to_process = std::min({params.window_size, params.n_samples - offset});
        for (int j = 0; j < n_to_process; j++) {
            fft_in[window_pad_left + j] = window_func[j] * samples[offset + window_pad_left + j];
        }

        // Zero-pad right (and any samples we didn't have)
        std::fill(fft_in.begin() + window_pad_left + n_to_process, fft_in.begin() + params.frame_size, 0.0f);

        // FFT
        fft(fft_in.data(), params.frame_size, fft_out.data(), cache);

        // Calculate modulus^2 of complex numbers
        // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
        // The result is written in place: fft_out[j] overwrites already-consumed pairs.
        for (int j = 0; j < n_fb; j++) {
            fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
        }

        // mel spectrogram
        for (int j = 0; j < mel.n_mel; j++) {
            double sum = 0.0;
            // unroll loop (suggested by GH user @lunixbochs)
            int k = 0;
            for (k = 0; k < n_fb - 3; k += 4) {
                sum +=
                        fft_out[k + 0] * filters.data[j * n_fb + k + 0] +
                        fft_out[k + 1] * filters.data[j * n_fb + k + 1] +
                        fft_out[k + 2] * filters.data[j * n_fb + k + 2] +
                        fft_out[k + 3] * filters.data[j * n_fb + k + 3];
            }
            // handle n_fb remainder
            for (; k < n_fb; k++) {
                sum += fft_out[k] * filters.data[j * n_fb + k];
            }

            mel.data[i * mel.n_mel + j] = std::log(sum + eps);
        }
    }

    // Otherwise fft_out are all zero - use log(eps) for consistency
    const double empty_sum = std::log(eps);
    for (; i < mel.n_len; i += params.n_threads) {
        for (int j = 0; j < mel.n_mel; j++) {
            mel.data[i * mel.n_mel + j] = empty_sum;
        }
    }
}
|
|
2746
|
+
|
|
2747
|
+
// Compute the per-feature normalized log-mel spectrogram of `samples`.
//
// Steps:
//   1. pre-emphasis filter x[i] = x[i] - 0.97 * x[i-1]
//   2. centered constant (zero) padding of frame_size / 2 samples on each side
//   3. STFT + mel filter bank + natural log, spread over worker threads
//   4. per mel bin: subtract the mean and divide by (std + 1e-5), where mean
//      and std are measured on the first n_samples / frame_step frames only
//
// Parameters:
//   wstate      - receives the elapsed time in t_mel_us
//   samples     - n_samples input samples
//   frame_size  - FFT length, frame_step - hop size, n_mel - number of mel bins
//   n_threads   - number of worker threads; values below 1 are treated as 1
//   filters     - mel filter bank, cache - window / sin / cos tables
//   debug       - when true the result is dumped to "log_mel_spectrogram.json"
//   mel         - output: n_mel, n_len, n_len_org and data are overwritten
//
// Always returns true.
//
// Robustness notes:
//   - n_threads <= 0 used to produce a negative worker count and a
//     non-advancing frame stride; it is now clamped to a single thread.
//   - With fewer than two valid frames the sample standard deviation is
//     undefined (division by zero); it is now taken as 0 so that the output
//     stays finite, and at least one frame is always used for the mean.
//   - The debug dump no longer underflows when mel.data is empty.
static bool log_mel_spectrogram(
        parakeet_state & wstate,
        const float * samples,
        const int n_samples,
        const int /*sample_rate*/,
        const int frame_size,
        const int frame_step,
        const int n_mel,
        const int n_threads,
        const parakeet_filters & filters,
        const bool debug,
        parakeet_mel & mel,
        const parakeet_mel_cache & cache) {
    const int64_t t_start_us = ggml_time_us();

    // use the model-provided window when available, the Hann window otherwise
    const float * window_func = cache.window.empty() ? cache.hann_window.data() : cache.window.data();
    const int window_size = cache.window.empty() ? cache.n_fft : cache.window.size();

    std::vector<float> samples_preprocessed(samples, samples + n_samples);

    // Apply preemphasis filter (high-pass): x[i] = x[i] - 0.97 * x[i-1]
    // Iterating backwards keeps x[i-1] unmodified while x[i] is updated.
    {
        const float preemph = 0.97f;
        for (int i = n_samples - 1; i > 0; i--) {
            samples_preprocessed[i] = samples_preprocessed[i] - preemph * samples_preprocessed[i - 1];
        }
    }

    // Parakeet Pytorch implementation uses centered constant padding.
    const size_t pad = (size_t)(frame_size / 2);
    std::vector<float> samples_padded(n_samples + 2 * pad, 0.0f);
    std::copy(samples_preprocessed.begin(), samples_preprocessed.end(), samples_padded.begin() + pad);

    mel.n_mel = n_mel;
    mel.n_len = (samples_padded.size() - frame_size) / frame_step + 1;
    mel.n_len_org = mel.n_len;
    mel.data.resize(mel.n_mel * mel.n_len);

    // Worker Threads (STFT + Mel + Natural Log)
    {
        // at least the calling thread always takes part
        const int n_workers = n_threads > 1 ? n_threads : 1;

        std::vector<std::thread> workers(n_workers - 1);
        const mel_worker_params mel_params { 0, window_size, (int)samples_padded.size(), frame_size, frame_step, n_workers };

        for (int iw = 0; iw < n_workers - 1; ++iw) {
            mel_worker_params params = mel_params;
            params.ith = iw + 1;
            workers[iw] = std::thread(log_mel_spectrogram_worker_thread,
                    params,
                    window_func,
                    std::cref(samples_padded),
                    std::cref(filters),
                    std::ref(mel),
                    std::cref(cache));
        }

        // the calling thread processes the frames of worker 0
        log_mel_spectrogram_worker_thread(
                mel_params,
                window_func,
                samples_padded,
                filters,
                mel,
                cache);

        for (int iw = 0; iw < n_workers - 1; ++iw) {
            workers[iw].join();
        }
    }

    // Per-feature normalization
    {
        const double eps = 1e-5;

        // never measure the statistics on zero frames (would divide by zero)
        const int n_full_frames = n_samples / frame_step;
        const int valid_frames = n_full_frames > 1 ? n_full_frames : 1;

        for (int j = 0; j < mel.n_mel; j++) {
            double sum = 0.0;
            double sq_diff_sum = 0.0;

            // Calculate Mean ONLY on valid audio frames
            for (int i = 0; i < valid_frames; i++) {
                sum += (double)mel.data[i * mel.n_mel + j];
            }
            double mean = sum / valid_frames;

            // Calculate Variance ONLY on valid audio frames
            for (int i = 0; i < valid_frames; i++) {
                double diff = (double)mel.data[i * mel.n_mel + j] - mean;
                sq_diff_sum += diff * diff;
            }

            // sample standard deviation; undefined for a single frame, use 0 then
            double std_dev = valid_frames > 1 ? std::sqrt(sq_diff_sum / (valid_frames - 1.0)) : 0.0;
            double denominator = std_dev + eps;

            // Apply to ALL frames (including the padded ones)
            for (int i = 0; i < mel.n_len; i++) {
                mel.data[i * mel.n_mel + j] = (float)((mel.data[i * mel.n_mel + j] - mean) / denominator);
            }
        }
    }

    wstate.t_mel_us += ggml_time_us() - t_start_us;

    if (debug) {
        // dump the values as a flat JSON array
        std::ofstream outFile("log_mel_spectrogram.json");
        outFile << "[";
        for (size_t i = 0; i < mel.data.size(); i++) {
            if (i > 0) {
                outFile << ", ";
            }
            outFile << mel.data[i];
        }
        outFile << "]";
        outFile.close();
    }

    return true;
}
|
|
2859
|
+
|
|
2860
|
+
static std::vector<parakeet_vocab::id> tokenize(const parakeet_vocab & vocab, const std::string & text) {
|
|
2861
|
+
std::vector<parakeet_vocab::id> tokens;
|
|
2862
|
+
const std::string normalized = sentencepiece_normalize(text);
|
|
2863
|
+
|
|
2864
|
+
size_t i = 0;
|
|
2865
|
+
while (i < normalized.size()) {
|
|
2866
|
+
const size_t remaining = normalized.size() - i;
|
|
2867
|
+
const size_t max_len = std::min(vocab.max_token_length, remaining);
|
|
2868
|
+
|
|
2869
|
+
bool found = false;
|
|
2870
|
+
for (size_t len = max_len; len > 0; --len) {
|
|
2871
|
+
const auto it = vocab.token_to_id.find(normalized.substr(i, len));
|
|
2872
|
+
if (it != vocab.token_to_id.end() && !is_sentencepiece_control(it->first)) {
|
|
2873
|
+
tokens.push_back(it->second);
|
|
2874
|
+
i += len;
|
|
2875
|
+
found = true;
|
|
2876
|
+
break;
|
|
2877
|
+
}
|
|
2878
|
+
}
|
|
2879
|
+
|
|
2880
|
+
if (!found) {
|
|
2881
|
+
if (vocab.token_unk >= 0) {
|
|
2882
|
+
tokens.push_back(vocab.token_unk);
|
|
2883
|
+
}
|
|
2884
|
+
|
|
2885
|
+
const unsigned char c = static_cast<unsigned char>(normalized[i]);
|
|
2886
|
+
i += utf8_codepoint_len(c);
|
|
2887
|
+
}
|
|
2888
|
+
}
|
|
2889
|
+
|
|
2890
|
+
return tokens;
|
|
2891
|
+
}
|
|
2892
|
+
|
|
2893
|
+
|
|
2894
|
+
//
|
|
2895
|
+
// interface implementation
|
|
2896
|
+
//
|
|
2897
|
+
|
|
2898
|
+
// Allocate and initialize a decoding state for `ctx`.
//
// Sets up, in order: the compute backends, the logits / batch buffers, the
// encoder output state, the encoder graph scheduler, the LSTM state, the
// prediction output state and the decoder graph scheduler.
//
// Returns the new state, or nullptr when any of the steps fails; in that case
// the partially initialized state is released with parakeet_free_state().
// On success the caller owns the returned pointer.
struct parakeet_state * parakeet_init_state(parakeet_context * ctx) {
    parakeet_state * state = new parakeet_state;

    state->backends = parakeet_backend_init(ctx->params);
    if (state->backends.empty()) {
        PARAKEET_LOG_ERROR("%s: parakeet_backend_init() failed\n", __func__);
        parakeet_free_state(state);
        return nullptr;
    }

    const int batch_size = ctx->model.hparams.n_audio_ctx;

    state->logits.reserve(ctx->vocab.n_vocab * batch_size);

    state->batch = parakeet_batch_init(batch_size);

    // encoder output state
    {
        const int n_audio_state = ctx->model.hparams.n_audio_state;
        const int subsampl_factor = ctx->model.hparams.subsampling_factor;
        // number of encoder frames for a full audio context, rounded up
        const int n_frames_max = (batch_size + subsampl_factor - 1) / subsampl_factor;

        if (!parakeet_enc_state_init(*state, state->backends[0], n_audio_state, n_frames_max)) {
            PARAKEET_LOG_ERROR("%s: parakeet_enc_state_init() failed\n", __func__);
            parakeet_free_state(state);
            return nullptr;
        }

        const size_t mem_enc_ctx = state->enc_out_buf.size();
        const size_t mem_enc_out_buf = ggml_backend_buffer_get_size(state->enc_out_buffer);
        PARAKEET_LOG_INFO("%s: enc_out state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__,
                mem_enc_ctx / 1024.0 / 1024.0, mem_enc_out_buf / 1024.0 / 1024.0);
    }

    // conv/encoder allocator
    bool ok = parakeet_sched_graph_init(state->sched_encode, state->backends,
            [&]() {
                return parakeet_build_graph_encode(*ctx, *state);
            });

    if (!ok) {
        PARAKEET_LOG_ERROR("%s: failed to init encode allocator\n", __func__);
        parakeet_free_state(state);
        return nullptr;
    }
    state->sched_encode_n_audio_ctx = state->n_audio_ctx > 0 ? state->n_audio_ctx : ctx->model.hparams.n_audio_ctx;

    if (!parakeet_lstm_state_init(*state, state->backends[0], ctx->model.hparams.n_pred_layers, ctx->model.hparams.n_pred_dim)) {
        // the message names the function that was actually called
        PARAKEET_LOG_ERROR("%s: parakeet_lstm_state_init() failed\n", __func__);
        parakeet_free_state(state);
        return nullptr;
    }

    {
        const size_t mem_lstm_ctx = state->lstm_state.ctx_buf.size();
        const size_t mem_lstm_buf = ggml_backend_buffer_get_size(state->lstm_state.buffer);
        PARAKEET_LOG_INFO("%s: lstm state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__,
                mem_lstm_ctx / 1024.0 / 1024.0, mem_lstm_buf / 1024.0 / 1024.0);
    }

    if (!parakeet_pred_state_init(*state, state->backends[0], ctx->model.hparams.n_pred_dim)) {
        PARAKEET_LOG_ERROR("%s: parakeet_pred_state_init() failed\n", __func__);
        parakeet_free_state(state);
        return nullptr;
    }

    {
        const size_t mem_pred_ctx = state->pred_out_buf.size();
        const size_t mem_pred_out_buf = ggml_backend_buffer_get_size(state->pred_out_buffer);
        PARAKEET_LOG_INFO("%s: pred state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__,
                mem_pred_ctx / 1024.0 / 1024.0, mem_pred_out_buf / 1024.0 / 1024.0);
    }

    PARAKEET_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, parakeet_sched_size(state->sched_encode) / 1e6);

    // decoder allocator
    {
        // reuse the outer flag instead of declaring a second, shadowing `ok`
        ok = parakeet_sched_graph_init(state->sched_decode, state->backends,
                [&]() {
                    const auto & hparams = ctx->model.hparams;
                    const int n_tokens = hparams.n_audio_ctx; // Use audio ctx for Parakeet

                    parakeet_batch_prep_legacy(state->batch, nullptr, n_tokens, 0, 0);

                    return parakeet_build_graph_prediction(*ctx, *state, state->batch, true);
                });

        if (!ok) {
            PARAKEET_LOG_ERROR("%s: failed to init decoder allocator\n", __func__);
            parakeet_free_state(state);
            return nullptr;
        }

        PARAKEET_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, parakeet_sched_size(state->sched_decode) / 1e6);
    }

    return state;
}
|
|
2994
|
+
|
|
2995
|
+
struct parakeet_context_params parakeet_context_default_params() {
|
|
2996
|
+
struct parakeet_context_params result = {
|
|
2997
|
+
/*.use_gpu =*/ true,
|
|
2998
|
+
/*.gpu_device =*/ 0,
|
|
2999
|
+
};
|
|
3000
|
+
return result;
|
|
3001
|
+
}
|
|
3002
|
+
|
|
3003
|
+
struct parakeet_context * parakeet_init_from_file_with_params_no_state(const char * path_model, struct parakeet_context_params params) {
|
|
3004
|
+
PARAKEET_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
|
|
3005
|
+
#ifdef _MSC_VER
|
|
3006
|
+
// Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues.
|
|
3007
|
+
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
|
3008
|
+
std::wstring path_model_wide = converter.from_bytes(path_model);
|
|
3009
|
+
auto fin = std::ifstream(path_model_wide, std::ios::binary);
|
|
3010
|
+
#else
|
|
3011
|
+
auto fin = std::ifstream(path_model, std::ios::binary);
|
|
3012
|
+
#endif
|
|
3013
|
+
if (!fin) {
|
|
3014
|
+
PARAKEET_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
|
|
3015
|
+
return nullptr;
|
|
3016
|
+
}
|
|
3017
|
+
|
|
3018
|
+
parakeet_model_loader loader = {};
|
|
3019
|
+
|
|
3020
|
+
loader.context = &fin;
|
|
3021
|
+
|
|
3022
|
+
loader.read = [](void * ctx, void * output, size_t read_size) {
|
|
3023
|
+
std::ifstream * fin = (std::ifstream*)ctx;
|
|
3024
|
+
fin->read((char *)output, read_size);
|
|
3025
|
+
return read_size;
|
|
3026
|
+
};
|
|
3027
|
+
|
|
3028
|
+
loader.eof = [](void * ctx) {
|
|
3029
|
+
std::ifstream * fin = (std::ifstream*)ctx;
|
|
3030
|
+
return fin->eof();
|
|
3031
|
+
};
|
|
3032
|
+
|
|
3033
|
+
loader.close = [](void * ctx) {
|
|
3034
|
+
std::ifstream * fin = (std::ifstream*)ctx;
|
|
3035
|
+
fin->close();
|
|
3036
|
+
};
|
|
3037
|
+
|
|
3038
|
+
auto ctx = parakeet_init_with_params_no_state(&loader, params);
|
|
3039
|
+
|
|
3040
|
+
if (ctx) {
|
|
3041
|
+
ctx->path_model = path_model;
|
|
3042
|
+
}
|
|
3043
|
+
|
|
3044
|
+
return ctx;
|
|
3045
|
+
}
|
|
3046
|
+
|
|
3047
|
+
struct parakeet_context * parakeet_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct parakeet_context_params params) {
|
|
3048
|
+
struct buf_context {
|
|
3049
|
+
uint8_t* buffer;
|
|
3050
|
+
size_t size;
|
|
3051
|
+
size_t current_offset;
|
|
3052
|
+
};
|
|
3053
|
+
|
|
3054
|
+
buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
|
|
3055
|
+
|
|
3056
|
+
PARAKEET_LOG_INFO("%s: loading model from buffer\n", __func__);
|
|
3057
|
+
|
|
3058
|
+
parakeet_model_loader loader = {};
|
|
3059
|
+
|
|
3060
|
+
loader.context = &ctx;
|
|
3061
|
+
|
|
3062
|
+
loader.read = [](void * ctx, void * output, size_t read_size) {
|
|
3063
|
+
buf_context * buf = reinterpret_cast<buf_context *>(ctx);
|
|
3064
|
+
|
|
3065
|
+
size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset;
|
|
3066
|
+
|
|
3067
|
+
memcpy(output, buf->buffer + buf->current_offset, size_to_copy);
|
|
3068
|
+
buf->current_offset += size_to_copy;
|
|
3069
|
+
|
|
3070
|
+
return size_to_copy;
|
|
3071
|
+
};
|
|
3072
|
+
|
|
3073
|
+
loader.eof = [](void * ctx) {
|
|
3074
|
+
buf_context * buf = reinterpret_cast<buf_context *>(ctx);
|
|
3075
|
+
|
|
3076
|
+
return buf->current_offset >= buf->size;
|
|
3077
|
+
};
|
|
3078
|
+
|
|
3079
|
+
loader.close = [](void * /*ctx*/) { };
|
|
3080
|
+
|
|
3081
|
+
return parakeet_init_with_params_no_state(&loader, params);
|
|
3082
|
+
}
|
|
3083
|
+
|
|
3084
|
+
struct parakeet_context * parakeet_init_with_params_no_state(struct parakeet_model_loader * loader, struct parakeet_context_params params) {
|
|
3085
|
+
ggml_time_init();
|
|
3086
|
+
|
|
3087
|
+
PARAKEET_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
|
|
3088
|
+
PARAKEET_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
|
|
3089
|
+
PARAKEET_LOG_INFO("%s: devices = %zu\n", __func__, ggml_backend_dev_count());
|
|
3090
|
+
PARAKEET_LOG_INFO("%s: backends = %zu\n", __func__, ggml_backend_reg_count());
|
|
3091
|
+
|
|
3092
|
+
parakeet_context * ctx = new parakeet_context;
|
|
3093
|
+
ctx->params = params;
|
|
3094
|
+
|
|
3095
|
+
bool model_loaded = false;
|
|
3096
|
+
try {
|
|
3097
|
+
model_loaded = parakeet_model_load(loader, *ctx);
|
|
3098
|
+
} catch (const std::exception & e) {
|
|
3099
|
+
PARAKEET_LOG_ERROR("%s: exception during model load: %s\n", __func__, e.what());
|
|
3100
|
+
} catch (...) {
|
|
3101
|
+
PARAKEET_LOG_ERROR("%s: unknown exception during model load\n", __func__);
|
|
3102
|
+
}
|
|
3103
|
+
|
|
3104
|
+
if (!model_loaded) {
|
|
3105
|
+
loader->close(loader->context);
|
|
3106
|
+
PARAKEET_LOG_ERROR("%s: failed to load model\n", __func__);
|
|
3107
|
+
delete ctx;
|
|
3108
|
+
return nullptr;
|
|
3109
|
+
}
|
|
3110
|
+
|
|
3111
|
+
loader->close(loader->context);
|
|
3112
|
+
|
|
3113
|
+
// Initialize mel cache with model's FFT size
|
|
3114
|
+
ctx->mel_cache.init(ctx->model.hparams.n_fft);
|
|
3115
|
+
PARAKEET_LOG_INFO("%s: initialized mel cache with n_fft = %d\n", __func__, ctx->model.hparams.n_fft);
|
|
3116
|
+
|
|
3117
|
+
return ctx;
|
|
3118
|
+
}
|
|
3119
|
+
|
|
3120
|
+
struct parakeet_context * parakeet_init_from_file_with_params(const char * path_model, struct parakeet_context_params params) {
|
|
3121
|
+
parakeet_context * ctx = parakeet_init_from_file_with_params_no_state(path_model, params);
|
|
3122
|
+
if (!ctx) {
|
|
3123
|
+
return nullptr;
|
|
3124
|
+
}
|
|
3125
|
+
|
|
3126
|
+
ctx->state = parakeet_init_state(ctx);
|
|
3127
|
+
if (!ctx->state) {
|
|
3128
|
+
parakeet_free(ctx);
|
|
3129
|
+
return nullptr;
|
|
3130
|
+
}
|
|
3131
|
+
|
|
3132
|
+
return ctx;
|
|
3133
|
+
}
|
|
3134
|
+
|
|
3135
|
+
struct parakeet_context * parakeet_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct parakeet_context_params params) {
|
|
3136
|
+
parakeet_context * ctx = parakeet_init_from_buffer_with_params_no_state(buffer, buffer_size, params);
|
|
3137
|
+
if (!ctx) {
|
|
3138
|
+
return nullptr;
|
|
3139
|
+
}
|
|
3140
|
+
|
|
3141
|
+
ctx->state = parakeet_init_state(ctx);
|
|
3142
|
+
if (!ctx->state) {
|
|
3143
|
+
parakeet_free(ctx);
|
|
3144
|
+
return nullptr;
|
|
3145
|
+
}
|
|
3146
|
+
|
|
3147
|
+
return ctx;
|
|
3148
|
+
}
|
|
3149
|
+
|
|
3150
|
+
struct parakeet_context * parakeet_init_with_params(struct parakeet_model_loader * loader, struct parakeet_context_params params) {
|
|
3151
|
+
parakeet_context * ctx = parakeet_init_with_params_no_state(loader, params);
|
|
3152
|
+
if (!ctx) {
|
|
3153
|
+
return nullptr;
|
|
3154
|
+
}
|
|
3155
|
+
|
|
3156
|
+
ctx->state = parakeet_init_state(ctx);
|
|
3157
|
+
if (!ctx->state) {
|
|
3158
|
+
parakeet_free(ctx);
|
|
3159
|
+
return nullptr;
|
|
3160
|
+
}
|
|
3161
|
+
|
|
3162
|
+
return ctx;
|
|
3163
|
+
}
|
|
3164
|
+
|
|
3165
|
+
void parakeet_free_state(struct parakeet_state * state) {
|
|
3166
|
+
if (state) {
|
|
3167
|
+
ggml_backend_buffer_free(state->lstm_state.buffer);
|
|
3168
|
+
ggml_backend_buffer_free(state->pred_out_buffer);
|
|
3169
|
+
ggml_backend_buffer_free(state->enc_out_buffer);
|
|
3170
|
+
|
|
3171
|
+
parakeet_batch_free(state->batch);
|
|
3172
|
+
|
|
3173
|
+
parakeet_sched_free(state->sched_encode);
|
|
3174
|
+
parakeet_sched_free(state->sched_decode);
|
|
3175
|
+
|
|
3176
|
+
for (auto & backend : state->backends) {
|
|
3177
|
+
ggml_backend_free(backend);
|
|
3178
|
+
}
|
|
3179
|
+
|
|
3180
|
+
delete state;
|
|
3181
|
+
}
|
|
3182
|
+
}
|
|
3183
|
+
|
|
3184
|
+
void parakeet_free(struct parakeet_context * ctx) {
|
|
3185
|
+
if (ctx) {
|
|
3186
|
+
for (ggml_context * context : ctx->model.ctxs) {
|
|
3187
|
+
ggml_free(context);
|
|
3188
|
+
}
|
|
3189
|
+
|
|
3190
|
+
for (ggml_backend_buffer_t buf : ctx->model.buffers) {
|
|
3191
|
+
ggml_backend_buffer_free(buf);
|
|
3192
|
+
}
|
|
3193
|
+
|
|
3194
|
+
parakeet_free_state(ctx->state);
|
|
3195
|
+
|
|
3196
|
+
delete ctx;
|
|
3197
|
+
}
|
|
3198
|
+
}
|
|
3199
|
+
|
|
3200
|
+
// Delete a context-params object that was allocated with new.
// Passing nullptr is allowed: deleting a null pointer has no effect.
void parakeet_free_context_params(struct parakeet_context_params * params) {
    delete params;
}
|
|
3205
|
+
|
|
3206
|
+
// Delete a full-params object that was allocated with new.
// Passing nullptr is allowed: deleting a null pointer has no effect.
void parakeet_free_params(struct parakeet_full_params * params) {
    delete params;
}
|
|
3211
|
+
|
|
3212
|
+
int parakeet_pcm_to_mel_with_state(struct parakeet_context * ctx, struct parakeet_state * state, const float * samples, int n_samples, int n_threads) {
|
|
3213
|
+
if (!log_mel_spectrogram(*state,
|
|
3214
|
+
samples,
|
|
3215
|
+
n_samples,
|
|
3216
|
+
PARAKEET_SAMPLE_RATE,
|
|
3217
|
+
ctx->model.hparams.n_fft,
|
|
3218
|
+
PARAKEET_HOP_LENGTH,
|
|
3219
|
+
ctx->model.filters.n_mel,
|
|
3220
|
+
n_threads,
|
|
3221
|
+
ctx->model.filters,
|
|
3222
|
+
false, // debug
|
|
3223
|
+
state->mel,
|
|
3224
|
+
ctx->mel_cache)) {
|
|
3225
|
+
PARAKEET_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
|
|
3226
|
+
return -1;
|
|
3227
|
+
}
|
|
3228
|
+
|
|
3229
|
+
return 0;
|
|
3230
|
+
}
|
|
3231
|
+
|
|
3232
|
+
int parakeet_pcm_to_mel(struct parakeet_context * ctx, const float * samples, int n_samples, int n_threads) {
|
|
3233
|
+
return parakeet_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
|
|
3234
|
+
}
|
|
3235
|
+
|
|
3236
|
+
int parakeet_set_mel_with_state(
|
|
3237
|
+
struct parakeet_context * ctx,
|
|
3238
|
+
struct parakeet_state * state,
|
|
3239
|
+
const float * data,
|
|
3240
|
+
int n_len,
|
|
3241
|
+
int n_mel) {
|
|
3242
|
+
if (n_mel != ctx->model.filters.n_mel) {
|
|
3243
|
+
PARAKEET_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel);
|
|
3244
|
+
return -1;
|
|
3245
|
+
}
|
|
3246
|
+
|
|
3247
|
+
state->mel.n_len = n_len;
|
|
3248
|
+
state->mel.n_len_org = n_len;
|
|
3249
|
+
state->mel.n_mel = n_mel;
|
|
3250
|
+
|
|
3251
|
+
state->mel.data.resize(n_len*n_mel);
|
|
3252
|
+
memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float));
|
|
3253
|
+
|
|
3254
|
+
return 0;
|
|
3255
|
+
}
|
|
3256
|
+
|
|
3257
|
+
int parakeet_set_mel(
|
|
3258
|
+
struct parakeet_context * ctx,
|
|
3259
|
+
const float * data,
|
|
3260
|
+
int n_len,
|
|
3261
|
+
int n_mel) {
|
|
3262
|
+
return parakeet_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel);
|
|
3263
|
+
}
|
|
3264
|
+
|
|
3265
|
+
int parakeet_encode_with_state(struct parakeet_context * ctx, struct parakeet_state * state, int offset, int n_threads) {
|
|
3266
|
+
if (!parakeet_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
|
|
3267
|
+
PARAKEET_LOG_ERROR("%s: failed to eval\n", __func__);
|
|
3268
|
+
return -1;
|
|
3269
|
+
}
|
|
3270
|
+
|
|
3271
|
+
return 0;
|
|
3272
|
+
}
|
|
3273
|
+
|
|
3274
|
+
int parakeet_encode(struct parakeet_context * ctx, int offset, int n_threads) {
|
|
3275
|
+
if (!parakeet_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
|
|
3276
|
+
PARAKEET_LOG_ERROR("%s: failed to eval\n", __func__);
|
|
3277
|
+
return -1;
|
|
3278
|
+
}
|
|
3279
|
+
|
|
3280
|
+
return 0;
|
|
3281
|
+
}
|
|
3282
|
+
|
|
3283
|
+
int parakeet_tokenize(struct parakeet_context * ctx, const char * text, parakeet_token * tokens, int n_max_tokens) {
|
|
3284
|
+
const auto res = tokenize(ctx->vocab, text);
|
|
3285
|
+
|
|
3286
|
+
if (n_max_tokens < (int) res.size()) {
|
|
3287
|
+
PARAKEET_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
|
|
3288
|
+
return -(int) res.size();
|
|
3289
|
+
}
|
|
3290
|
+
|
|
3291
|
+
for (int i = 0; i < (int) res.size(); i++) {
|
|
3292
|
+
tokens[i] = res[i];
|
|
3293
|
+
}
|
|
3294
|
+
|
|
3295
|
+
return res.size();
|
|
3296
|
+
}
|
|
3297
|
+
|
|
3298
|
+
int parakeet_token_count(struct parakeet_context * ctx, const char * text) {
|
|
3299
|
+
return -parakeet_tokenize(ctx, text, NULL, 0);
|
|
3300
|
+
}
|
|
3301
|
+
|
|
3302
|
+
int parakeet_model_n_vocab(struct parakeet_context * ctx) {
|
|
3303
|
+
return ctx->model.hparams.n_vocab;
|
|
3304
|
+
}
|
|
3305
|
+
|
|
3306
|
+
int parakeet_model_n_audio_ctx(struct parakeet_context * ctx) {
|
|
3307
|
+
return ctx->model.hparams.n_audio_ctx;
|
|
3308
|
+
}
|
|
3309
|
+
|
|
3310
|
+
int parakeet_model_n_audio_state(struct parakeet_context * ctx) {
|
|
3311
|
+
return ctx->model.hparams.n_audio_state;
|
|
3312
|
+
}
|
|
3313
|
+
|
|
3314
|
+
int parakeet_model_n_audio_head(struct parakeet_context * ctx) {
|
|
3315
|
+
return ctx->model.hparams.n_audio_head;
|
|
3316
|
+
}
|
|
3317
|
+
|
|
3318
|
+
int parakeet_model_n_audio_layer(struct parakeet_context * ctx) {
|
|
3319
|
+
return ctx->model.hparams.n_audio_layer;
|
|
3320
|
+
}
|
|
3321
|
+
|
|
3322
|
+
int parakeet_model_n_mels(struct parakeet_context * ctx) {
|
|
3323
|
+
return ctx->model.hparams.n_mels;
|
|
3324
|
+
}
|
|
3325
|
+
|
|
3326
|
+
int parakeet_model_ftype(struct parakeet_context * ctx) {
|
|
3327
|
+
return ctx->model.hparams.ftype;
|
|
3328
|
+
}
|
|
3329
|
+
|
|
3330
|
+
int parakeet_n_len_from_state(struct parakeet_state * state) {
|
|
3331
|
+
return state->mel.n_len_org;
|
|
3332
|
+
}
|
|
3333
|
+
|
|
3334
|
+
int parakeet_n_len(struct parakeet_context * ctx) {
|
|
3335
|
+
return ctx->state->mel.n_len_org;
|
|
3336
|
+
}
|
|
3337
|
+
|
|
3338
|
+
int parakeet_n_vocab(struct parakeet_context * ctx) {
|
|
3339
|
+
return ctx->vocab.n_vocab;
|
|
3340
|
+
}
|
|
3341
|
+
|
|
3342
|
+
int parakeet_n_audio_ctx(struct parakeet_context * ctx) {
|
|
3343
|
+
return ctx->model.hparams.n_audio_ctx;
|
|
3344
|
+
}
|
|
3345
|
+
|
|
3346
|
+
float * parakeet_get_logits(struct parakeet_context * ctx) {
|
|
3347
|
+
return ctx->state->logits.data();
|
|
3348
|
+
}
|
|
3349
|
+
|
|
3350
|
+
float * parakeet_get_logits_from_state(struct parakeet_state * state) {
|
|
3351
|
+
return state->logits.data();
|
|
3352
|
+
}
|
|
3353
|
+
|
|
3354
|
+
const char * parakeet_token_to_str(struct parakeet_context * ctx, parakeet_token token) {
|
|
3355
|
+
return ctx->vocab.id_to_token.at(token).c_str();
|
|
3356
|
+
}
|
|
3357
|
+
|
|
3358
|
+
int parakeet_token_to_text(const char * token_str, bool is_first, char * output, int max_len) {
|
|
3359
|
+
std::string text = sentencepiece_piece_to_text(token_str, is_first);
|
|
3360
|
+
|
|
3361
|
+
if (output == nullptr) {
|
|
3362
|
+
return text.size();
|
|
3363
|
+
}
|
|
3364
|
+
|
|
3365
|
+
int bytes_to_copy = std::min((int)text.size(), max_len - 1);
|
|
3366
|
+
if (bytes_to_copy > 0) {
|
|
3367
|
+
memcpy(output, text.c_str(), bytes_to_copy);
|
|
3368
|
+
output[bytes_to_copy] = '\0';
|
|
3369
|
+
} else if (max_len > 0) {
|
|
3370
|
+
output[0] = '\0';
|
|
3371
|
+
}
|
|
3372
|
+
|
|
3373
|
+
return text.size();
|
|
3374
|
+
}
|
|
3375
|
+
|
|
3376
|
+
// Returns the token_bos id recorded in the context's vocabulary.
parakeet_token parakeet_token_bos(struct parakeet_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_bos;
}
|
|
3379
|
+
|
|
3380
|
+
// Returns the token_unk id recorded in the context's vocabulary.
parakeet_token parakeet_token_unk(struct parakeet_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_unk;
}
|
|
3383
|
+
|
|
3384
|
+
// Returns the token_blank id recorded in the context's vocabulary.
parakeet_token parakeet_token_blank(struct parakeet_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_blank;
}
|
|
3387
|
+
|
|
3388
|
+
struct parakeet_timings * parakeet_get_timings(struct parakeet_context * ctx) {
|
|
3389
|
+
if (ctx->state == nullptr) {
|
|
3390
|
+
return nullptr;
|
|
3391
|
+
}
|
|
3392
|
+
parakeet_timings * timings = new parakeet_timings;
|
|
3393
|
+
timings->sample_ms = 1e-3f * ctx->state->t_sample_us / std::max(1, ctx->state->n_sample);
|
|
3394
|
+
timings->encode_ms = 1e-3f * ctx->state->t_encode_us / std::max(1, ctx->state->n_encode);
|
|
3395
|
+
timings->decode_ms = 1e-3f * ctx->state->t_decode_us / std::max(1, ctx->state->n_decode);
|
|
3396
|
+
return timings;
|
|
3397
|
+
}
|
|
3398
|
+
|
|
3399
|
+
void parakeet_print_timings(struct parakeet_context * ctx) {
|
|
3400
|
+
const int64_t t_end_us = ggml_time_us();
|
|
3401
|
+
|
|
3402
|
+
PARAKEET_LOG_INFO("\n");
|
|
3403
|
+
PARAKEET_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
|
|
3404
|
+
if (ctx->state != nullptr) {
|
|
3405
|
+
|
|
3406
|
+
const int32_t n_sample = std::max(1, ctx->state->n_sample);
|
|
3407
|
+
const int32_t n_encode = std::max(1, ctx->state->n_encode);
|
|
3408
|
+
const int32_t n_decode = std::max(1, ctx->state->n_decode);
|
|
3409
|
+
const int32_t n_predict = std::max(1, ctx->state->n_predict);
|
|
3410
|
+
|
|
3411
|
+
PARAKEET_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
|
|
3412
|
+
PARAKEET_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
|
|
3413
|
+
PARAKEET_LOG_INFO("%s: sample time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
|
|
3414
|
+
PARAKEET_LOG_INFO("%s: encode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
|
|
3415
|
+
PARAKEET_LOG_INFO("%s: decode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
|
|
3416
|
+
PARAKEET_LOG_INFO("%s: predict time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_us, n_predict, 1e-3f * ctx->state->t_predict_us / n_predict);
|
|
3417
|
+
PARAKEET_LOG_INFO("%s: - build = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_build_us, n_predict, 1e-3f * ctx->state->t_predict_build_us / n_predict);
|
|
3418
|
+
PARAKEET_LOG_INFO("%s: - alloc = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_alloc_us, n_predict, 1e-3f * ctx->state->t_predict_alloc_us / n_predict);
|
|
3419
|
+
PARAKEET_LOG_INFO("%s: - compute = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_compute_us, n_predict, 1e-3f * ctx->state->t_predict_compute_us / n_predict);
|
|
3420
|
+
|
|
3421
|
+
}
|
|
3422
|
+
PARAKEET_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
|
|
3423
|
+
}
|
|
3424
|
+
|
|
3425
|
+
void parakeet_reset_timings(struct parakeet_context * ctx) {
|
|
3426
|
+
ctx->t_start_us = ggml_time_us();
|
|
3427
|
+
if (ctx->state != nullptr) {
|
|
3428
|
+
ctx->state->t_mel_us = 0;
|
|
3429
|
+
ctx->state->t_sample_us = 0;
|
|
3430
|
+
ctx->state->t_encode_us = 0;
|
|
3431
|
+
ctx->state->t_decode_us = 0;
|
|
3432
|
+
ctx->state->t_predict_us = 0;
|
|
3433
|
+
ctx->state->t_predict_build_us = 0;
|
|
3434
|
+
ctx->state->t_predict_alloc_us = 0;
|
|
3435
|
+
ctx->state->t_predict_compute_us = 0;
|
|
3436
|
+
|
|
3437
|
+
ctx->state->n_sample = 0;
|
|
3438
|
+
ctx->state->n_encode = 0;
|
|
3439
|
+
ctx->state->n_decode = 0;
|
|
3440
|
+
ctx->state->n_predict = 0;
|
|
3441
|
+
}
|
|
3442
|
+
}
|
|
3443
|
+
|
|
3444
|
+
const char * parakeet_print_system_info(void) {
|
|
3445
|
+
static std::string s;
|
|
3446
|
+
|
|
3447
|
+
s = "";
|
|
3448
|
+
s += "PARAKEET : ";
|
|
3449
|
+
|
|
3450
|
+
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
|
|
3451
|
+
auto * reg = ggml_backend_reg_get(i);
|
|
3452
|
+
auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features");
|
|
3453
|
+
if (get_features_fn) {
|
|
3454
|
+
ggml_backend_feature * features = get_features_fn(reg);
|
|
3455
|
+
s += ggml_backend_reg_name(reg);
|
|
3456
|
+
s += " : ";
|
|
3457
|
+
for (; features->name; features++) {
|
|
3458
|
+
s += features->name;
|
|
3459
|
+
s += " = ";
|
|
3460
|
+
s += features->value;
|
|
3461
|
+
s += " | ";
|
|
3462
|
+
}
|
|
3463
|
+
}
|
|
3464
|
+
}
|
|
3465
|
+
return s.c_str();
|
|
3466
|
+
}
|
|
3467
|
+
|
|
3468
|
+
struct parakeet_context_params * parakeet_context_default_params_by_ref(void) {
|
|
3469
|
+
struct parakeet_context_params params = parakeet_context_default_params();
|
|
3470
|
+
|
|
3471
|
+
struct parakeet_context_params* result = new parakeet_context_params();
|
|
3472
|
+
*result = params;
|
|
3473
|
+
return result;
|
|
3474
|
+
}
|
|
3475
|
+
|
|
3476
|
+
struct parakeet_full_params * parakeet_full_default_params_by_ref(enum parakeet_sampling_strategy strategy) {
|
|
3477
|
+
struct parakeet_full_params params = parakeet_full_default_params(strategy);
|
|
3478
|
+
|
|
3479
|
+
struct parakeet_full_params* result = new parakeet_full_params();
|
|
3480
|
+
*result = params;
|
|
3481
|
+
return result;
|
|
3482
|
+
}
|
|
3483
|
+
|
|
3484
|
+
struct parakeet_full_params parakeet_full_default_params(enum parakeet_sampling_strategy strategy) {
|
|
3485
|
+
struct parakeet_full_params result = {
|
|
3486
|
+
/*.strategy =*/ strategy,
|
|
3487
|
+
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
|
|
3488
|
+
/*.offset_ms =*/ 0,
|
|
3489
|
+
/*.duration_ms =*/ 0,
|
|
3490
|
+
/*.no_context =*/ true,
|
|
3491
|
+
/*.audio_ctx =*/ 0,
|
|
3492
|
+
/*.new_token_callback =*/ nullptr,
|
|
3493
|
+
/*.new_token_callback_user_data =*/ nullptr,
|
|
3494
|
+
/*.new_segment_callback =*/ nullptr,
|
|
3495
|
+
/*.new_segment_callback_user_data =*/ nullptr,
|
|
3496
|
+
/*.progress_callback =*/ nullptr,
|
|
3497
|
+
/*.progress_callback_user_data =*/ nullptr,
|
|
3498
|
+
/*.encoder_begin_callback =*/ nullptr,
|
|
3499
|
+
/*.encoder_begin_callback_user_data =*/ nullptr,
|
|
3500
|
+
/*.abort_callback =*/ nullptr,
|
|
3501
|
+
/*.abort_callback_user_data =*/ nullptr,
|
|
3502
|
+
};
|
|
3503
|
+
|
|
3504
|
+
return result;
|
|
3505
|
+
}
|
|
3506
|
+
|
|
3507
|
+
static void parakeet_reset_state(struct parakeet_state * state) {
|
|
3508
|
+
state->decoded_tokens.clear();
|
|
3509
|
+
state->decoded_token_data.clear();
|
|
3510
|
+
|
|
3511
|
+
if (state->lstm_state.buffer) {
|
|
3512
|
+
ggml_backend_buffer_clear(state->lstm_state.buffer, 0);
|
|
3513
|
+
}
|
|
3514
|
+
|
|
3515
|
+
}
|
|
3516
|
+
|
|
3517
|
+
// Encode and decode the mel spectrogram already in state, without recomputing it.
|
|
3518
|
+
static int parakeet_chunk_with_state(
|
|
3519
|
+
struct parakeet_context * ctx,
|
|
3520
|
+
struct parakeet_state * state,
|
|
3521
|
+
struct parakeet_full_params params) {
|
|
3522
|
+
return parakeet_chunk(ctx, state, params, nullptr, 0);
|
|
3523
|
+
}
|
|
3524
|
+
|
|
3525
|
+
int parakeet_full_with_state(
|
|
3526
|
+
struct parakeet_context * ctx,
|
|
3527
|
+
struct parakeet_state * state,
|
|
3528
|
+
struct parakeet_full_params params,
|
|
3529
|
+
const float * samples,
|
|
3530
|
+
int n_samples) {
|
|
3531
|
+
state->result_all.clear();
|
|
3532
|
+
|
|
3533
|
+
if (params.no_context) {
|
|
3534
|
+
parakeet_reset_state(state);
|
|
3535
|
+
}
|
|
3536
|
+
|
|
3537
|
+
if (n_samples > 0) {
|
|
3538
|
+
if (parakeet_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
|
|
3539
|
+
PARAKEET_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
|
|
3540
|
+
return -2;
|
|
3541
|
+
}
|
|
3542
|
+
}
|
|
3543
|
+
|
|
3544
|
+
const int n_mel_total = state->mel.n_len;
|
|
3545
|
+
const int n_audio_ctx = ctx->model.hparams.n_audio_ctx;
|
|
3546
|
+
|
|
3547
|
+
if (n_mel_total <= n_audio_ctx) {
|
|
3548
|
+
if (params.progress_callback) {
|
|
3549
|
+
params.progress_callback(ctx, state, 0, params.progress_callback_user_data);
|
|
3550
|
+
}
|
|
3551
|
+
return parakeet_chunk_with_state(ctx, state, params);
|
|
3552
|
+
}
|
|
3553
|
+
|
|
3554
|
+
PARAKEET_LOG_DEBUG("%s: audio too long (%d mel > n_audio_ctx=%d), using dynamic encoder graph\n",
|
|
3555
|
+
__func__, n_mel_total, n_audio_ctx);
|
|
3556
|
+
|
|
3557
|
+
if (params.encoder_begin_callback) {
|
|
3558
|
+
if (!params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data)) {
|
|
3559
|
+
PARAKEET_LOG_ERROR("%s: encoder_begin_callback returned false\n", __func__);
|
|
3560
|
+
return -6;
|
|
3561
|
+
}
|
|
3562
|
+
}
|
|
3563
|
+
|
|
3564
|
+
if (params.progress_callback) {
|
|
3565
|
+
params.progress_callback(ctx, state, 0, params.progress_callback_user_data);
|
|
3566
|
+
}
|
|
3567
|
+
|
|
3568
|
+
if (!parakeet_ensure_encode_sched(*ctx, *state, n_mel_total)) {
|
|
3569
|
+
PARAKEET_LOG_ERROR("%s: failed to allocate dynamic encoder graph for %d mel frames\n",
|
|
3570
|
+
__func__, n_mel_total);
|
|
3571
|
+
return -6;
|
|
3572
|
+
}
|
|
3573
|
+
|
|
3574
|
+
state->n_audio_ctx = n_mel_total;
|
|
3575
|
+
|
|
3576
|
+
if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads,
|
|
3577
|
+
params.abort_callback, params.abort_callback_user_data)) {
|
|
3578
|
+
PARAKEET_LOG_ERROR("%s: failed to encode\n", __func__);
|
|
3579
|
+
return -6;
|
|
3580
|
+
}
|
|
3581
|
+
|
|
3582
|
+
if (params.progress_callback) {
|
|
3583
|
+
params.progress_callback(ctx, state, 100, params.progress_callback_user_data);
|
|
3584
|
+
}
|
|
3585
|
+
|
|
3586
|
+
const size_t tokens_before = state->decoded_tokens.size();
|
|
3587
|
+
|
|
3588
|
+
if (!parakeet_decode(*ctx, *state, state->batch, params.n_threads, ¶ms)) {
|
|
3589
|
+
PARAKEET_LOG_ERROR("%s: failed to decode\n", __func__);
|
|
3590
|
+
return -7;
|
|
3591
|
+
}
|
|
3592
|
+
|
|
3593
|
+
const size_t tokens_after = state->decoded_tokens.size();
|
|
3594
|
+
const size_t new_token_count = tokens_after - tokens_before;
|
|
3595
|
+
|
|
3596
|
+
if (new_token_count > 0) {
|
|
3597
|
+
std::string text;
|
|
3598
|
+
std::vector<parakeet_token_data> result_tokens;
|
|
3599
|
+
|
|
3600
|
+
for (size_t i = tokens_before; i < tokens_after; i++) {
|
|
3601
|
+
const auto token_id = state->decoded_tokens[i];
|
|
3602
|
+
const char * tok_str = parakeet_token_to_str(ctx, token_id);
|
|
3603
|
+
if (tok_str) {
|
|
3604
|
+
const bool is_first = (tokens_before == 0) && text.empty();
|
|
3605
|
+
text += sentencepiece_piece_to_text(tok_str, is_first);
|
|
3606
|
+
}
|
|
3607
|
+
result_tokens.push_back(state->decoded_token_data[i]);
|
|
3608
|
+
}
|
|
3609
|
+
|
|
3610
|
+
refine_timestamps_tdt(ctx->vocab, result_tokens);
|
|
3611
|
+
|
|
3612
|
+
if (!text.empty()) {
|
|
3613
|
+
parakeet_segment seg;
|
|
3614
|
+
seg.t0 = 0;
|
|
3615
|
+
seg.t1 = state->n_frames;
|
|
3616
|
+
seg.text = text;
|
|
3617
|
+
seg.tokens = result_tokens;
|
|
3618
|
+
state->result_all.push_back(std::move(seg));
|
|
3619
|
+
|
|
3620
|
+
if (params.new_segment_callback) {
|
|
3621
|
+
params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data);
|
|
3622
|
+
}
|
|
3623
|
+
}
|
|
3624
|
+
}
|
|
3625
|
+
|
|
3626
|
+
return 0;
|
|
3627
|
+
}
|
|
3628
|
+
|
|
3629
|
+
int parakeet_full(
|
|
3630
|
+
struct parakeet_context * ctx,
|
|
3631
|
+
struct parakeet_full_params params,
|
|
3632
|
+
const float * samples,
|
|
3633
|
+
int n_samples) {
|
|
3634
|
+
return parakeet_full_with_state(ctx, ctx->state, params, samples, n_samples);
|
|
3635
|
+
}
|
|
3636
|
+
|
|
3637
|
+
int parakeet_chunk(
|
|
3638
|
+
struct parakeet_context * ctx,
|
|
3639
|
+
struct parakeet_state * state,
|
|
3640
|
+
struct parakeet_full_params params,
|
|
3641
|
+
const float * samples,
|
|
3642
|
+
int n_samples) {
|
|
3643
|
+
|
|
3644
|
+
if (params.no_context) {
|
|
3645
|
+
parakeet_reset_state(state);
|
|
3646
|
+
}
|
|
3647
|
+
|
|
3648
|
+
if (n_samples > 0) {
|
|
3649
|
+
if (parakeet_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
|
|
3650
|
+
PARAKEET_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
|
|
3651
|
+
return -2;
|
|
3652
|
+
}
|
|
3653
|
+
}
|
|
3654
|
+
|
|
3655
|
+
if (params.audio_ctx == 0) {
|
|
3656
|
+
const int total_len = parakeet_n_len_from_state(state);
|
|
3657
|
+
const int model_max_ctx = parakeet_n_audio_ctx(ctx);
|
|
3658
|
+
params.audio_ctx = std::min(total_len, model_max_ctx);
|
|
3659
|
+
PARAKEET_LOG_DEBUG("Processing audio: total_frames=%d, chunk_size=%d\n", total_len, params.audio_ctx);
|
|
3660
|
+
}
|
|
3661
|
+
state->n_audio_ctx = params.audio_ctx;
|
|
3662
|
+
|
|
3663
|
+
const int n_frames = parakeet_n_len_from_state(state);
|
|
3664
|
+
|
|
3665
|
+
if (!parakeet_ensure_encode_sched(*ctx, *state, state->n_audio_ctx)) {
|
|
3666
|
+
PARAKEET_LOG_ERROR("%s: failed to allocate encoder graph for %d mel frames\n",
|
|
3667
|
+
__func__, state->n_audio_ctx);
|
|
3668
|
+
return -6;
|
|
3669
|
+
}
|
|
3670
|
+
|
|
3671
|
+
if (params.encoder_begin_callback) {
|
|
3672
|
+
if (!params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data)) {
|
|
3673
|
+
PARAKEET_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__);
|
|
3674
|
+
return -6;
|
|
3675
|
+
}
|
|
3676
|
+
}
|
|
3677
|
+
if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
|
3678
|
+
PARAKEET_LOG_ERROR("%s: failed to encode\n", __func__);
|
|
3679
|
+
return -6;
|
|
3680
|
+
}
|
|
3681
|
+
|
|
3682
|
+
const size_t tokens_before = state->decoded_tokens.size();
|
|
3683
|
+
|
|
3684
|
+
if (!parakeet_decode(*ctx, *state, state->batch, params.n_threads, ¶ms)) {
|
|
3685
|
+
PARAKEET_LOG_ERROR("%s: failed to decode\n", __func__);
|
|
3686
|
+
return -7;
|
|
3687
|
+
}
|
|
3688
|
+
|
|
3689
|
+
const size_t tokens_after = state->decoded_tokens.size();
|
|
3690
|
+
const size_t new_token_count = tokens_after - tokens_before;
|
|
3691
|
+
|
|
3692
|
+
if (new_token_count > 0) {
|
|
3693
|
+
std::string text;
|
|
3694
|
+
std::vector<parakeet_token_data> result_tokens;
|
|
3695
|
+
|
|
3696
|
+
for (size_t i = tokens_before; i < tokens_after; i++) {
|
|
3697
|
+
const auto token_id = state->decoded_tokens[i];
|
|
3698
|
+
const char * token_str = parakeet_token_to_str(ctx, token_id);
|
|
3699
|
+
if (token_str) {
|
|
3700
|
+
const bool is_first_piece = (tokens_before == 0) && text.empty();
|
|
3701
|
+
text += sentencepiece_piece_to_text(token_str, is_first_piece);
|
|
3702
|
+
}
|
|
3703
|
+
|
|
3704
|
+
// Use the stored token data from parakeet_decode
|
|
3705
|
+
result_tokens.push_back(state->decoded_token_data[i]);
|
|
3706
|
+
}
|
|
3707
|
+
|
|
3708
|
+
refine_timestamps_tdt(ctx->vocab, result_tokens);
|
|
3709
|
+
|
|
3710
|
+
if (!text.empty()) {
|
|
3711
|
+
parakeet_segment segment;
|
|
3712
|
+
segment.t0 = 0; // Caller tracks timing
|
|
3713
|
+
segment.t1 = n_frames;
|
|
3714
|
+
segment.text = text;
|
|
3715
|
+
segment.tokens = result_tokens;
|
|
3716
|
+
|
|
3717
|
+
state->result_all.push_back(std::move(segment));
|
|
3718
|
+
|
|
3719
|
+
if (params.new_segment_callback) {
|
|
3720
|
+
params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data);
|
|
3721
|
+
}
|
|
3722
|
+
}
|
|
3723
|
+
}
|
|
3724
|
+
|
|
3725
|
+
return 0;
|
|
3726
|
+
}
|
|
3727
|
+
|
|
3728
|
+
int parakeet_full_n_segments_from_state(struct parakeet_state * state) {
|
|
3729
|
+
return state->result_all.size();
|
|
3730
|
+
}
|
|
3731
|
+
|
|
3732
|
+
int parakeet_full_n_segments(struct parakeet_context * ctx) {
|
|
3733
|
+
return ctx->state->result_all.size();
|
|
3734
|
+
}
|
|
3735
|
+
|
|
3736
|
+
int64_t parakeet_full_get_segment_t0_from_state(struct parakeet_state * state, int i_segment) {
|
|
3737
|
+
return state->result_all[i_segment].t0;
|
|
3738
|
+
}
|
|
3739
|
+
|
|
3740
|
+
int64_t parakeet_full_get_segment_t1_from_state(struct parakeet_state * state, int i_segment) {
|
|
3741
|
+
return state->result_all[i_segment].t1;
|
|
3742
|
+
}
|
|
3743
|
+
|
|
3744
|
+
int64_t parakeet_full_get_segment_t0(struct parakeet_context * ctx, int i_segment) {
|
|
3745
|
+
return parakeet_full_get_segment_t0_from_state(ctx->state, i_segment);
|
|
3746
|
+
}
|
|
3747
|
+
|
|
3748
|
+
int64_t parakeet_full_get_segment_t1(struct parakeet_context * ctx, int i_segment) {
|
|
3749
|
+
return parakeet_full_get_segment_t1_from_state(ctx->state, i_segment);
|
|
3750
|
+
}
|
|
3751
|
+
|
|
3752
|
+
const char * parakeet_full_get_segment_text_from_state(struct parakeet_state * state, int i_segment) {
|
|
3753
|
+
return state->result_all[i_segment].text.c_str();
|
|
3754
|
+
}
|
|
3755
|
+
|
|
3756
|
+
const char * parakeet_full_get_segment_text(struct parakeet_context * ctx, int i_segment) {
|
|
3757
|
+
return ctx->state->result_all[i_segment].text.c_str();
|
|
3758
|
+
}
|
|
3759
|
+
|
|
3760
|
+
int parakeet_full_n_tokens_from_state(struct parakeet_state * state, int i_segment) {
|
|
3761
|
+
return state->result_all[i_segment].tokens.size();
|
|
3762
|
+
}
|
|
3763
|
+
|
|
3764
|
+
int parakeet_full_n_tokens(struct parakeet_context * ctx, int i_segment) {
|
|
3765
|
+
return ctx->state->result_all[i_segment].tokens.size();
|
|
3766
|
+
}
|
|
3767
|
+
|
|
3768
|
+
const char * parakeet_full_get_token_text_from_state(struct parakeet_context * ctx, struct parakeet_state * state, int i_segment, int i_token) {
|
|
3769
|
+
return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str();
|
|
3770
|
+
}
|
|
3771
|
+
|
|
3772
|
+
const char* parakeet_full_get_token_text(struct parakeet_context * ctx, int i_segment, int i_token) {
|
|
3773
|
+
return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str();
|
|
3774
|
+
}
|
|
3775
|
+
|
|
3776
|
+
// Returns the id of token i_token of segment i_segment in the given state.
parakeet_token parakeet_full_get_token_id_from_state(struct parakeet_state * state, int i_segment, int i_token) {
    const auto & seg = state->result_all[i_segment];
    return seg.tokens[i_token].id;
}
|
|
3779
|
+
|
|
3780
|
+
// Returns the id of token i_token of segment i_segment in the state attached to ctx.
parakeet_token parakeet_full_get_token_id(struct parakeet_context * ctx, int i_segment, int i_token) {
    const auto & seg = ctx->state->result_all[i_segment];
    return seg.tokens[i_token].id;
}
|
|
3783
|
+
|
|
3784
|
+
struct parakeet_token_data parakeet_full_get_token_data_from_state(struct parakeet_state * state, int i_segment, int i_token) {
|
|
3785
|
+
return state->result_all[i_segment].tokens[i_token];
|
|
3786
|
+
}
|
|
3787
|
+
|
|
3788
|
+
struct parakeet_token_data parakeet_full_get_token_data(struct parakeet_context * ctx, int i_segment, int i_token) {
|
|
3789
|
+
return ctx->state->result_all[i_segment].tokens[i_token];
|
|
3790
|
+
}
|
|
3791
|
+
|
|
3792
|
+
float parakeet_full_get_token_p_from_state(struct parakeet_state * state, int i_segment, int i_token) {
|
|
3793
|
+
return state->result_all[i_segment].tokens[i_token].p;
|
|
3794
|
+
}
|
|
3795
|
+
|
|
3796
|
+
float parakeet_full_get_token_p(struct parakeet_context * ctx, int i_segment, int i_token) {
|
|
3797
|
+
return ctx->state->result_all[i_segment].tokens[i_token].p;
|
|
3798
|
+
}
|
|
3799
|
+
|
|
3800
|
+
// Installs the log callback used by this module and forwards the same
// callback/user-data pair to ggml via ggml_log_set().
// A null `log_callback` selects parakeet_log_callback_default.
void parakeet_log_set(ggml_log_callback log_callback, void * user_data) {
    if (log_callback) {
        g_state.log_callback = log_callback;
    } else {
        g_state.log_callback = parakeet_log_callback_default;
    }
    g_state.log_callback_user_data = user_data;

    ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
}
|
|
3805
|
+
|
|
3806
|
+
const char * parakeet_version(void) {
|
|
3807
|
+
return PARAKEET_VERSION;
|
|
3808
|
+
}
|
|
3809
|
+
|
|
3810
|
+
// Formats a printf-style message and passes it to the currently installed
// log callback (g_state.log_callback) together with its user data.
//
// level  - severity forwarded unchanged to the callback
// format - printf-style format string, followed by its arguments
//
// Messages that fit in a 1024-byte stack buffer are formatted in place;
// longer ones are formatted again into a heap buffer of the exact size.
// If vsnprintf reports an error (negative return) the message is dropped,
// because the buffer contents are not reliable in that case.
GGML_ATTRIBUTE_FORMAT(2, 3)
static void parakeet_log_internal(ggml_log_level level, const char * format, ...) {
    va_list args;
    va_start(args, format);

    // The first vsnprintf consumes `args`; reusing it for a second pass is
    // undefined behaviour, so keep an untouched copy for the long-message path.
    va_list args_copy;
    va_copy(args_copy, args);

    char buffer[1024];
    const int len = vsnprintf(buffer, sizeof(buffer), format, args);

    if (len < 0) {
        // Formatting failed: nothing trustworthy to emit.
    } else if ((size_t) len < sizeof(buffer)) {
        g_state.log_callback(level, buffer, g_state.log_callback_user_data);
    } else {
        // Output was truncated; `len` is the full length excluding the NUL.
        char * buffer2 = new char[len + 1];
        vsnprintf(buffer2, len + 1, format, args_copy);
        buffer2[len] = 0;
        g_state.log_callback(level, buffer2, g_state.log_callback_user_data);
        delete[] buffer2;
    }

    va_end(args_copy);
    va_end(args);
}
|
|
3827
|
+
|
|
3828
|
+
// Default log sink: writes `text` to stderr and flushes it immediately.
// Unless PARAKEET_DEBUG is defined, messages at GGML_LOG_LEVEL_DEBUG are
// discarded. `user_data` is not used.
static void parakeet_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
    (void) user_data;
#ifdef PARAKEET_DEBUG
    // Every level is printed in debug builds, so `level` is not inspected.
    (void) level;
#else
    if (level == GGML_LOG_LEVEL_DEBUG) {
        return;
    }
#endif
    fputs(text, stderr);
    fflush(stderr);
}
|