whispercpp 1.3.6 → 1.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/README.md +38 -5
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -8
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +36 -42
- data/ext/ruby_whisper.h +135 -0
- data/ext/ruby_whisper_context.c +107 -28
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -65
- data/ext/ruby_whisper_segment.c +6 -6
- data/ext/ruby_whisper_transcribe.cpp +42 -15
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +1 -1
- data/ext/sources/examples/cli/cli.cpp +43 -9
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +199 -163
- data/ext/sources/ggml/CMakeLists.txt +21 -13
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +72 -10
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-rpc.h +3 -3
- data/ext/sources/ggml/include/ggml.h +101 -9
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +22 -5
- data/ext/sources/ggml/src/ggml-alloc.c +5 -1
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
- data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
- data/ext/sources/ggml/src/ggml-impl.h +6 -1
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
- data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +289 -114
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
- data/ext/sources/ggml/src/ggml.c +110 -28
- data/ext/sources/ggml/src/gguf.cpp +173 -28
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +56 -12
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +411 -62
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +24 -6
- data/whispercpp.gemspec +2 -2
- metadata +215 -281
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
- data/ext/sources/examples/talk-llama/llama-context.h +0 -359
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
- data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
- data/ext/sources/examples/talk-llama/llama-model.h +0 -597
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
- data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
- data/ext/sources/examples/talk-llama/llama.h +0 -1573
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -704
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
- /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
|
@@ -9,12 +9,13 @@
|
|
|
9
9
|
#include <string.h>
|
|
10
10
|
|
|
11
11
|
#include "hex-dma.h"
|
|
12
|
+
#include "hvx-exp.h"
|
|
13
|
+
#include "hvx-sigmoid.h"
|
|
12
14
|
#include "hvx-utils.h"
|
|
13
15
|
|
|
14
16
|
#define GGML_COMMON_DECL_C
|
|
15
17
|
#include "ggml-common.h"
|
|
16
18
|
#include "htp-ctx.h"
|
|
17
|
-
#include "htp-msg.h"
|
|
18
19
|
#include "htp-ops.h"
|
|
19
20
|
|
|
20
21
|
struct htp_unary_context {
|
|
@@ -22,23 +23,62 @@ struct htp_unary_context {
|
|
|
22
23
|
|
|
23
24
|
// Precomputed values
|
|
24
25
|
const uint8_t * data_src0;
|
|
26
|
+
const uint8_t * data_src1; // weight/scale tensor for RMS_NORM_MUL
|
|
25
27
|
uint8_t * data_dst;
|
|
26
28
|
|
|
27
|
-
size_t
|
|
28
|
-
size_t
|
|
29
|
+
size_t src0_data_row_size; // actual data bytes per row
|
|
30
|
+
size_t src1_data_row_size;
|
|
31
|
+
size_t dst_data_row_size; // actual data bytes per row
|
|
29
32
|
|
|
30
33
|
size_t src0_row_size_aligned;
|
|
34
|
+
size_t src1_row_size_aligned;
|
|
31
35
|
size_t dst_row_size_aligned;
|
|
32
36
|
|
|
33
37
|
size_t src0_spad_half_size;
|
|
38
|
+
size_t src1_spad_half_size;
|
|
34
39
|
size_t dst_spad_half_size;
|
|
35
40
|
|
|
36
41
|
uint32_t block;
|
|
37
42
|
uint32_t src0_nrows;
|
|
38
43
|
uint32_t src0_nrows_per_thread;
|
|
39
44
|
uint32_t nc;
|
|
45
|
+
bool broadcast_weight;
|
|
40
46
|
};
|
|
41
47
|
|
|
48
|
+
// Convert flat row index to DDR byte offset using the tensor's actual strides.
|
|
49
|
+
// ir = i1 + ne1*(i2 + ne2*i3) => offset = i1*nb1 + i2*nb2 + i3*nb3
|
|
50
|
+
static inline size_t unary_row_offset(uint32_t ir,
|
|
51
|
+
uint32_t ne1, uint32_t ne2,
|
|
52
|
+
size_t nb1, size_t nb2, size_t nb3) {
|
|
53
|
+
const uint32_t i1 = ir % ne1;
|
|
54
|
+
const uint32_t i2 = (ir / ne1) % ne2;
|
|
55
|
+
const uint32_t i3 = ir / (ne1 * ne2);
|
|
56
|
+
return i1 * nb1 + i2 * nb2 + i3 * nb3;
|
|
57
|
+
}
|
|
58
|
+
// Safe DMA block size from row `ir`: clamp to the tighter dim-1 slice
|
|
59
|
+
// boundary of src and dst so the nb1 stride stays valid for all rows.
|
|
60
|
+
static inline uint32_t unary_block_size(uint32_t ir,
|
|
61
|
+
uint32_t end_row,
|
|
62
|
+
uint32_t block,
|
|
63
|
+
bool src_contig,
|
|
64
|
+
bool dst_contig,
|
|
65
|
+
uint32_t src_ne1,
|
|
66
|
+
uint32_t dst_ne1) {
|
|
67
|
+
uint32_t limit = MIN(block, end_row - ir);
|
|
68
|
+
|
|
69
|
+
if (!src_contig) {
|
|
70
|
+
const uint32_t src_slice_end = (ir / src_ne1 + 1) * src_ne1;
|
|
71
|
+
limit = MIN(limit, src_slice_end - ir);
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
if (!dst_contig) {
|
|
75
|
+
const uint32_t dst_slice_end = (ir / dst_ne1 + 1) * dst_ne1;
|
|
76
|
+
limit = MIN(limit, dst_slice_end - ir);
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
return limit;
|
|
80
|
+
}
|
|
81
|
+
|
|
42
82
|
#define htp_unary_preamble \
|
|
43
83
|
const uint32_t ne00 = src->ne[0]; \
|
|
44
84
|
const uint32_t ne01 = src->ne[1]; \
|
|
@@ -65,34 +105,199 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
|
|
|
65
105
|
uint8_t * restrict pad,
|
|
66
106
|
const int num_elems,
|
|
67
107
|
float epsilon) {
|
|
108
|
+
(void)pad;
|
|
109
|
+
|
|
68
110
|
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
|
69
111
|
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
|
70
112
|
|
|
71
|
-
|
|
113
|
+
const int nvec = num_elems / VLEN_FP32; // number of full vectors
|
|
114
|
+
const int nloe = num_elems % VLEN_FP32; // leftover elements
|
|
115
|
+
|
|
116
|
+
// Compute sum of squares for full vectors
|
|
117
|
+
HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
|
|
118
|
+
HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
|
|
119
|
+
|
|
120
|
+
#pragma unroll(4)
|
|
121
|
+
for (int i = 0; i < nvec; i++) {
|
|
122
|
+
HVX_Vector v1 = v_src[i];
|
|
123
|
+
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
|
124
|
+
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
// Handle tail elements using vectorized ops with masking
|
|
128
|
+
if (nloe > 0) {
|
|
129
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
130
|
+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
|
131
|
+
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
|
132
|
+
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
// Reduce HVX sum
|
|
136
|
+
sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
|
|
137
|
+
|
|
138
|
+
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
|
|
139
|
+
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
|
|
140
|
+
HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
|
|
141
|
+
HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);
|
|
142
|
+
|
|
143
|
+
// Scale full vectors
|
|
144
|
+
HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));
|
|
145
|
+
|
|
146
|
+
#pragma unroll(4)
|
|
147
|
+
for (int i = 0; i < nvec; i++) {
|
|
148
|
+
HVX_Vector v1 = v_src[i];
|
|
149
|
+
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
|
|
150
|
+
v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
|
|
151
|
+
}
|
|
152
|
+
|
|
153
|
+
// Handle tail elements using vectorized ops with masking
|
|
154
|
+
if (nloe > 0) {
|
|
155
|
+
|
|
156
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
157
|
+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
|
158
|
+
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
|
|
159
|
+
HVX_Vector result = Q6_Vsf_equals_Vqf32(v2);
|
|
160
|
+
|
|
161
|
+
// Store with masking to avoid overwriting memory beyond the tensor
|
|
162
|
+
hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
|
|
163
|
+
}
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
static void hvx_fast_rms_norm_mul_f32(const uint8_t * restrict src,
|
|
167
|
+
const uint8_t * restrict weight,
|
|
168
|
+
uint8_t * restrict dst,
|
|
169
|
+
const int num_elems,
|
|
170
|
+
float epsilon) {
|
|
171
|
+
const HVX_Vector * restrict v_src = (const HVX_Vector *) src;
|
|
172
|
+
const HVX_Vector * restrict v_weight = (const HVX_Vector *) weight;
|
|
173
|
+
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
|
174
|
+
|
|
175
|
+
const int nvec = num_elems / VLEN_FP32; // number of full vectors
|
|
176
|
+
const int nloe = num_elems % VLEN_FP32; // leftover elements
|
|
177
|
+
|
|
178
|
+
// Compute sum of squares for full vectors
|
|
179
|
+
HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
|
|
72
180
|
HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
|
|
73
181
|
|
|
74
|
-
int step_of_1 = num_elems >> 5;
|
|
75
182
|
#pragma unroll(4)
|
|
76
|
-
for (int i = 0; i <
|
|
183
|
+
for (int i = 0; i < nvec; i++) {
|
|
77
184
|
HVX_Vector v1 = v_src[i];
|
|
78
185
|
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
|
79
|
-
sum_v
|
|
186
|
+
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
|
80
187
|
}
|
|
81
188
|
|
|
82
|
-
|
|
189
|
+
// Handle tail elements using vectorized ops with masking
|
|
190
|
+
if (nloe > 0) {
|
|
191
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
192
|
+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
|
193
|
+
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
|
194
|
+
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
|
195
|
+
}
|
|
196
|
+
|
|
197
|
+
// Reduce HVX sum
|
|
198
|
+
sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
|
|
83
199
|
|
|
84
200
|
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
|
|
85
201
|
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
|
|
86
202
|
HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
|
|
87
203
|
HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);
|
|
88
204
|
|
|
205
|
+
// Scale and multiply
|
|
89
206
|
HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));
|
|
90
207
|
|
|
91
208
|
#pragma unroll(4)
|
|
92
|
-
for (int i = 0; i <
|
|
209
|
+
for (int i = 0; i < nvec; i++) {
|
|
93
210
|
HVX_Vector v1 = v_src[i];
|
|
94
211
|
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
|
|
95
|
-
|
|
212
|
+
HVX_Vector v3 = Q6_Vsf_equals_Vqf32(v2);
|
|
213
|
+
HVX_Vector result = Q6_Vqf32_vmpy_VsfVsf(v3, v_weight[i]);
|
|
214
|
+
v_dst[i] = Q6_Vsf_equals_Vqf32(result);
|
|
215
|
+
}
|
|
216
|
+
|
|
217
|
+
// Handle tail elements using vectorized ops with masking
|
|
218
|
+
if (nloe > 0) {
|
|
219
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
220
|
+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
|
221
|
+
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
|
|
222
|
+
HVX_Vector v3 = Q6_Vsf_equals_Vqf32(v2);
|
|
223
|
+
HVX_Vector result = Q6_Vqf32_vmpy_VsfVsf(v3, v_weight[nvec]);
|
|
224
|
+
HVX_Vector res_v = Q6_Vsf_equals_Vqf32(result);
|
|
225
|
+
|
|
226
|
+
// Store with masking to avoid overwriting memory beyond the tensor
|
|
227
|
+
hvx_vec_store_a(&v_dst[nvec], nloe * 4, res_v);
|
|
228
|
+
}
|
|
229
|
+
}
|
|
230
|
+
|
|
231
|
+
static void hvx_fast_norm_f32(const uint8_t * restrict src,
|
|
232
|
+
uint8_t * restrict dst,
|
|
233
|
+
uint8_t * restrict pad,
|
|
234
|
+
const int num_elems,
|
|
235
|
+
float epsilon) {
|
|
236
|
+
(void)pad;
|
|
237
|
+
|
|
238
|
+
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
|
239
|
+
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
|
240
|
+
|
|
241
|
+
const int nvec = num_elems / VLEN_FP32; // number of full vectors
|
|
242
|
+
const int nloe = num_elems % VLEN_FP32; // leftover elements
|
|
243
|
+
|
|
244
|
+
// Compute sum of squares and sum of values for full vectors
|
|
245
|
+
HVX_Vector sum_sq_v = Q6_V_vsplat_R(0x00000000);
|
|
246
|
+
HVX_Vector sum_x_v = Q6_V_vsplat_R(0x00000000);
|
|
247
|
+
HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
|
|
248
|
+
|
|
249
|
+
#pragma unroll(4)
|
|
250
|
+
for (int i = 0; i < nvec; i++) {
|
|
251
|
+
HVX_Vector v1 = v_src[i];
|
|
252
|
+
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
|
253
|
+
sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2);
|
|
254
|
+
sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero()));
|
|
255
|
+
}
|
|
256
|
+
|
|
257
|
+
// Handle tail elements using vectorized ops with masking
|
|
258
|
+
if (nloe > 0) {
|
|
259
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
260
|
+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
|
261
|
+
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
|
262
|
+
sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2);
|
|
263
|
+
sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero()));
|
|
264
|
+
}
|
|
265
|
+
|
|
266
|
+
// Reduce HVX sums
|
|
267
|
+
sum_sq_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_sq_v));
|
|
268
|
+
sum_x_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_x_v));
|
|
269
|
+
|
|
270
|
+
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
|
|
271
|
+
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
|
|
272
|
+
HVX_Vector mean_sq_v = Q6_Vqf32_vmpy_VsfVsf(sum_sq_v, denom_v);
|
|
273
|
+
HVX_Vector mean_x_v = Q6_Vqf32_vmpy_VsfVsf(sum_x_v, denom_v);
|
|
274
|
+
HVX_Vector mean_x_sq_v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(mean_x_v), Q6_Vsf_equals_Vqf32(mean_x_v));
|
|
275
|
+
HVX_Vector var_v = Q6_Vqf32_vsub_Vqf32Vqf32(mean_sq_v, mean_x_sq_v);
|
|
276
|
+
HVX_Vector var_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(var_v, epsilon_v);
|
|
277
|
+
|
|
278
|
+
// scale = rsqrt(variance + epsilon), mean_x broadcast for subtraction
|
|
279
|
+
HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(var_epsilon_v));
|
|
280
|
+
HVX_Vector mean_x_b = hvx_vec_repl_f32(Q6_Vsf_equals_Vqf32(mean_x_v));
|
|
281
|
+
|
|
282
|
+
#pragma unroll(4)
|
|
283
|
+
for (int i = 0; i < nvec; i++) {
|
|
284
|
+
HVX_Vector v1 = v_src[i];
|
|
285
|
+
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b);
|
|
286
|
+
HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v);
|
|
287
|
+
v_dst[i] = Q6_Vsf_equals_Vqf32(v3);
|
|
288
|
+
}
|
|
289
|
+
|
|
290
|
+
// Handle tail elements using vectorized ops with masking
|
|
291
|
+
if (nloe > 0) {
|
|
292
|
+
|
|
293
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
294
|
+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
|
295
|
+
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b);
|
|
296
|
+
HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v);
|
|
297
|
+
HVX_Vector result = Q6_Vsf_equals_Vqf32(v3);
|
|
298
|
+
|
|
299
|
+
// Store with masking to avoid overwriting memory beyond the tensor
|
|
300
|
+
hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
|
|
96
301
|
}
|
|
97
302
|
}
|
|
98
303
|
|
|
@@ -134,6 +339,45 @@ static void rms_norm_f32(const float * restrict src,
|
|
|
134
339
|
}
|
|
135
340
|
}
|
|
136
341
|
|
|
342
|
+
/*
 * Fused RMS-norm followed by an elementwise multiply with a weight row,
 * applied to num_rows rows via hvx_fast_rms_norm_mul_f32.
 *
 * epsilon is bit-copied from the first sizeof(float) bytes of op_params.
 * src and dst rows are row_size bytes apart. When broadcast_weight is true,
 * every row uses the single weight row at `weight`. Otherwise row r uses the
 * weight row starting at weight + r * weight_row_size.
 */
static void rms_norm_mul_f32(const float * restrict src,
                             const float * restrict weight,
                             float * restrict       dst,
                             const uint32_t         num_rows,
                             const uint32_t         row_elems,
                             const size_t           row_size,
                             const size_t           weight_row_size,
                             int32_t *              op_params,
                             bool                   broadcast_weight) {
    float eps = 0.0f;
    memcpy(&eps, op_params, sizeof(float));

    const uint8_t * restrict in_base  = (const uint8_t *) src;
    const uint8_t * restrict w_base   = (const uint8_t *) weight;
    uint8_t * restrict       out_base = (uint8_t *) dst;

    for (uint32_t r = 0; r < num_rows; ++r) {
        const size_t row_off = r * row_size;
        size_t       w_off   = 0;
        if (!broadcast_weight) {
            w_off = r * weight_row_size;
        }

        hvx_fast_rms_norm_mul_f32(in_base + row_off, w_base + w_off, out_base + row_off, row_elems, eps);
    }
}
|
|
362
|
+
|
|
363
|
+
/*
 * Mean/variance normalisation of num_rows rows, one hvx_fast_norm_f32 call per row.
 *
 * epsilon is bit-copied from the first sizeof(float) bytes of op_params.
 * src and dst rows are row_size bytes apart; row_elems floats per row are processed.
 * spad is forwarded unchanged to the per-row kernel.
 */
static void norm_f32(const float * restrict src,
                     float * restrict       dst,
                     uint8_t * restrict     spad,
                     const uint32_t         num_rows,
                     const uint32_t         row_elems,
                     const size_t           row_size,
                     int32_t *              op_params) {
    float eps = 0.0f;
    memcpy(&eps, op_params, sizeof(float));

    const uint8_t * restrict in_base  = (const uint8_t *) src;
    uint8_t * restrict       out_base = (uint8_t *) dst;

    for (uint32_t r = 0; r < num_rows; ++r) {
        const size_t row_off = r * row_size;
        hvx_fast_norm_f32(in_base + row_off, out_base + row_off, spad, row_elems, eps);
    }
}
|
|
380
|
+
|
|
137
381
|
static void sqr_f32(const float * restrict src,
|
|
138
382
|
float * restrict dst,
|
|
139
383
|
uint8_t * restrict spad,
|
|
@@ -166,11 +410,259 @@ static void sqrt_f32(const float * restrict src,
|
|
|
166
410
|
}
|
|
167
411
|
}
|
|
168
412
|
|
|
413
|
+
/*
 * Elementwise negation of num_rows rows: each row is scaled by -1.0f through
 * hvx_scale_f32_aa. Rows are row_size bytes apart in both src and dst.
 * spad and op_params are unused.
 */
static void neg_f32(const float * restrict src,
                    float * restrict       dst,
                    uint8_t * restrict     spad,
                    const uint32_t         num_rows,
                    const uint32_t         row_elems,
                    const size_t           row_size,
                    int32_t *              op_params) {
    (void) spad;
    (void) op_params;

    const uint8_t * restrict in_base  = (const uint8_t *) src;
    uint8_t * restrict       out_base = (uint8_t *) dst;

    for (uint32_t r = 0; r < num_rows; ++r) {
        const size_t row_off = r * row_size;
        hvx_scale_f32_aa(out_base + row_off, in_base + row_off, row_elems, -1.0f);
    }
}
|
|
428
|
+
|
|
429
|
+
/*
 * Elementwise exp of num_rows rows via hvx_exp_f32. Rows are row_size bytes
 * apart in both src and dst. spad and op_params are unused.
 * NOTE(review): the trailing `false` flag is interpreted by hvx_exp_f32
 * (defined elsewhere); confirm its meaning there.
 */
static void exp_f32(const float * restrict src,
                    float * restrict       dst,
                    uint8_t * restrict     spad,
                    const uint32_t         num_rows,
                    const uint32_t         row_elems,
                    const size_t           row_size,
                    int32_t *              op_params) {
    (void) spad;
    (void) op_params;

    const uint8_t * restrict in_base  = (const uint8_t *) src;
    uint8_t * restrict       out_base = (uint8_t *) dst;

    for (uint32_t r = 0; r < num_rows; ++r) {
        const size_t row_off = r * row_size;
        hvx_exp_f32(out_base + row_off, in_base + row_off, row_elems, false);
    }
}
|
|
444
|
+
|
|
445
|
+
/*
 * Elementwise sigmoid of num_rows rows via hvx_sigmoid_f32_aa. Rows are
 * row_size bytes apart in both src and dst. spad and op_params are unused.
 */
static void sigmoid_f32(const float * restrict src,
                        float * restrict       dst,
                        uint8_t * restrict     spad,
                        const uint32_t         num_rows,
                        const uint32_t         row_elems,
                        const size_t           row_size,
                        int32_t *              op_params) {
    (void) spad;
    (void) op_params;

    const uint8_t * restrict in_base  = (const uint8_t *) src;
    uint8_t * restrict       out_base = (uint8_t *) dst;

    for (uint32_t r = 0; r < num_rows; ++r) {
        const size_t row_off = r * row_size;
        hvx_sigmoid_f32_aa(out_base + row_off, in_base + row_off, row_elems);
    }
}
|
|
460
|
+
|
|
461
|
+
static void tri_f32(const float * restrict src,
|
|
462
|
+
float * restrict dst,
|
|
463
|
+
uint8_t * restrict spad,
|
|
464
|
+
const uint32_t num_rows,
|
|
465
|
+
const uint32_t row_elems,
|
|
466
|
+
const size_t row_size,
|
|
467
|
+
int32_t * op_params,
|
|
468
|
+
const uint32_t ir,
|
|
469
|
+
const struct htp_unary_context * uctx) {
|
|
470
|
+
|
|
471
|
+
const int32_t ttype = op_params[0];
|
|
472
|
+
const HVX_Vector zero = hvx_vec_splat_f32(0.0f);
|
|
473
|
+
const uint32_t nvec = row_elems / VLEN_FP32;
|
|
474
|
+
const uint32_t nloe = row_elems % VLEN_FP32;
|
|
475
|
+
|
|
476
|
+
const uint32_t ne01 = uctx->octx->src[0]->ne[1];
|
|
477
|
+
|
|
478
|
+
for (uint32_t b = 0; b < num_rows; b++) {
|
|
479
|
+
const uint32_t abs_row = ir + b;
|
|
480
|
+
const uint32_t i01 = abs_row % ne01;
|
|
481
|
+
|
|
482
|
+
const HVX_Vector * restrict v_src = (const HVX_Vector *) ((const uint8_t *) src + b * row_size);
|
|
483
|
+
HVX_Vector * restrict v_dst = (HVX_Vector *) ((uint8_t *) dst + b * row_size);
|
|
484
|
+
|
|
485
|
+
uint32_t boundary;
|
|
486
|
+
int keep_left;
|
|
487
|
+
switch (ttype) {
|
|
488
|
+
case 0: boundary = i01; keep_left = 0; break; // keep col >= row
|
|
489
|
+
case 1: boundary = i01 + 1; keep_left = 0; break; // keep col > row
|
|
490
|
+
case 2: boundary = i01 + 1; keep_left = 1; break; // keep col <= row
|
|
491
|
+
case 3: boundary = i01; keep_left = 1; break; // keep col < row
|
|
492
|
+
default: boundary = 0; keep_left = 0; break;
|
|
493
|
+
}
|
|
494
|
+
if (boundary > row_elems) boundary = row_elems;
|
|
495
|
+
|
|
496
|
+
// Full HVX vectors — each starts at a 128-byte aligned offset
|
|
497
|
+
for (uint32_t i = 0; i < nvec; i++) {
|
|
498
|
+
const uint32_t vec_start = i * VLEN_FP32;
|
|
499
|
+
const uint32_t vec_end = vec_start + VLEN_FP32;
|
|
500
|
+
if (keep_left) {
|
|
501
|
+
if (vec_end <= boundary) {
|
|
502
|
+
v_dst[i] = v_src[i];
|
|
503
|
+
} else if (vec_start >= boundary) {
|
|
504
|
+
v_dst[i] = zero;
|
|
505
|
+
} else {
|
|
506
|
+
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
|
|
507
|
+
v_dst[i] = Q6_V_vmux_QVV(mask, v_src[i], zero);
|
|
508
|
+
}
|
|
509
|
+
} else {
|
|
510
|
+
if (vec_end <= boundary) {
|
|
511
|
+
v_dst[i] = zero;
|
|
512
|
+
} else if (vec_start >= boundary) {
|
|
513
|
+
v_dst[i] = v_src[i];
|
|
514
|
+
} else {
|
|
515
|
+
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
|
|
516
|
+
v_dst[i] = Q6_V_vmux_QVV(mask, zero, v_src[i]);
|
|
517
|
+
}
|
|
518
|
+
}
|
|
519
|
+
}
|
|
520
|
+
|
|
521
|
+
// Tail elements (row_elems not a multiple of VLEN_FP32)
|
|
522
|
+
if (nloe > 0) {
|
|
523
|
+
const uint32_t vec_start = nvec * VLEN_FP32;
|
|
524
|
+
const uint32_t vec_end = vec_start + nloe;
|
|
525
|
+
HVX_Vector tail_val;
|
|
526
|
+
if (keep_left) {
|
|
527
|
+
if (vec_end <= boundary) {
|
|
528
|
+
tail_val = v_src[nvec];
|
|
529
|
+
} else if (vec_start >= boundary) {
|
|
530
|
+
tail_val = zero;
|
|
531
|
+
} else {
|
|
532
|
+
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
|
|
533
|
+
tail_val = Q6_V_vmux_QVV(mask, v_src[nvec], zero);
|
|
534
|
+
}
|
|
535
|
+
} else {
|
|
536
|
+
if (vec_end <= boundary) {
|
|
537
|
+
tail_val = zero;
|
|
538
|
+
} else if (vec_start >= boundary) {
|
|
539
|
+
tail_val = v_src[nvec];
|
|
540
|
+
} else {
|
|
541
|
+
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
|
|
542
|
+
tail_val = Q6_V_vmux_QVV(mask, zero, v_src[nvec]);
|
|
543
|
+
}
|
|
544
|
+
}
|
|
545
|
+
hvx_vec_store_a(&v_dst[nvec], nloe * sizeof(float), tail_val);
|
|
546
|
+
}
|
|
547
|
+
}
|
|
548
|
+
}
|
|
549
|
+
|
|
550
|
+
static void softplus_f32(const float * restrict src,
|
|
551
|
+
float * restrict dst,
|
|
552
|
+
uint8_t * restrict spad,
|
|
553
|
+
const uint32_t num_rows,
|
|
554
|
+
const uint32_t row_elems,
|
|
555
|
+
const size_t row_size,
|
|
556
|
+
int32_t * op_params) {
|
|
557
|
+
// softplus(x) = log(1 + exp(x))
|
|
558
|
+
// Match CPU reference: ggml_compute_softplus_f32() in ggml-impl.h
|
|
559
|
+
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
|
560
|
+
const float * restrict src_f = (const float *)((const uint8_t *)src + (ir * row_size));
|
|
561
|
+
float * restrict dst_f = (float *)((uint8_t *)dst + (ir * row_size));
|
|
562
|
+
|
|
563
|
+
for (uint32_t i = 0; i < row_elems; i++) {
|
|
564
|
+
float x = src_f[i];
|
|
565
|
+
// For x > 20: softplus(x) ≈ x (avoids exp overflow)
|
|
566
|
+
dst_f[i] = (x > 20.0f) ? x : logf(1.0f + expf(x));
|
|
567
|
+
}
|
|
568
|
+
}
|
|
569
|
+
}
|
|
570
|
+
|
|
571
|
+
// --- L2_NORM HVX kernel ---
|
|
572
|
+
// Computes y[i] = x[i] / fmax(sqrt(sum(x[j]^2)), epsilon) for each row.
|
|
573
|
+
// scale = 1/fmax(sqrt(sum), epsilon) is computed entirely in HVX registers
|
|
574
|
+
// using rsqrt + inverse to avoid scalar extraction.
|
|
575
|
+
static void hvx_fast_l2_norm_f32(const uint8_t * restrict src,
|
|
576
|
+
uint8_t * restrict dst,
|
|
577
|
+
uint8_t * restrict pad,
|
|
578
|
+
const int num_elems,
|
|
579
|
+
float epsilon) {
|
|
580
|
+
(void)pad;
|
|
581
|
+
|
|
582
|
+
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
|
583
|
+
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
|
584
|
+
|
|
585
|
+
HVX_Vector sum_v = hvx_vec_splat_f32(0.0f);
|
|
586
|
+
|
|
587
|
+
const int nvec = num_elems / VLEN_FP32;
|
|
588
|
+
const int nloe = num_elems % VLEN_FP32;
|
|
589
|
+
|
|
590
|
+
#pragma unroll(4)
|
|
591
|
+
for (int i = 0; i < nvec; i++) {
|
|
592
|
+
HVX_Vector v1 = v_src[i];
|
|
593
|
+
HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
|
594
|
+
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq);
|
|
595
|
+
}
|
|
596
|
+
|
|
597
|
+
// Include tail elements in the sum-of-squares using a predicate mask
|
|
598
|
+
if (nloe > 0) {
|
|
599
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
600
|
+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
|
601
|
+
HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
|
602
|
+
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq);
|
|
603
|
+
}
|
|
604
|
+
|
|
605
|
+
// Compute scale = 1/fmax(sqrt(sum), epsilon) entirely in HVX registers.
|
|
606
|
+
// hvx_vec_rsqrt_f32 + hvx_vec_inverse_f32 avoids scalar extraction.
|
|
607
|
+
HVX_Vector sum_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
|
|
608
|
+
HVX_Vector rsqrt_v = hvx_vec_rsqrt_f32(sum_sf); // 1/sqrt(sum)
|
|
609
|
+
HVX_Vector sqrt_v = hvx_vec_inverse_f32(rsqrt_v); // sqrt(sum)
|
|
610
|
+
HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
|
|
611
|
+
HVX_Vector denom_v = Q6_Vsf_vmax_VsfVsf(sqrt_v, epsilon_v); // fmax(sqrt(sum), epsilon)
|
|
612
|
+
HVX_Vector scale_v = hvx_vec_inverse_f32(denom_v); // 1/fmax(sqrt(sum), epsilon)
|
|
613
|
+
|
|
614
|
+
#pragma unroll(4)
|
|
615
|
+
for (int i = 0; i < nvec; i++) {
|
|
616
|
+
HVX_Vector v1 = v_src[i];
|
|
617
|
+
v_dst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v));
|
|
618
|
+
}
|
|
619
|
+
|
|
620
|
+
if (nloe > 0) {
|
|
621
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
622
|
+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
|
623
|
+
HVX_Vector result = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v));
|
|
624
|
+
hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
|
|
625
|
+
}
|
|
626
|
+
}
|
|
627
|
+
|
|
628
|
+
/*
 * L2-normalise num_rows rows, one hvx_fast_l2_norm_f32 call per row.
 * epsilon is bit-copied from the first sizeof(float) bytes of op_params.
 * Rows are row_size bytes apart in both src and dst. spad is forwarded to the kernel.
 */
static void l2_norm_f32(const float * restrict src,
                        float * restrict       dst,
                        uint8_t * restrict     spad,
                        const uint32_t         num_rows,
                        const uint32_t         row_elems,
                        const size_t           row_size,
                        int32_t *              op_params) {
    float eps = 0.0f;
    memcpy(&eps, op_params, sizeof(float));

    const uint8_t * restrict in_base  = (const uint8_t *) src;
    uint8_t * restrict       out_base = (uint8_t *) dst;

    for (uint32_t r = 0; r < num_rows; ++r) {
        const size_t row_off = r * row_size;
        hvx_fast_l2_norm_f32(in_base + row_off, out_base + row_off, spad, row_elems, eps);
    }
}
|
|
645
|
+
|
|
646
|
+
/*
 * Elementwise tanh of num_rows rows via hvx_tanh_f32_aa. Rows are row_size
 * bytes apart in both src and dst. spad and op_params are unused.
 */
static void tanh_f32(const float * restrict src,
                     float * restrict       dst,
                     uint8_t * restrict     spad,
                     const uint32_t         num_rows,
                     const uint32_t         row_elems,
                     const size_t           row_size,
                     int32_t *              op_params) {
    (void) spad;
    (void) op_params;

    const uint8_t * restrict in_base  = (const uint8_t *) src;
    uint8_t * restrict       out_base = (uint8_t *) dst;

    for (uint32_t r = 0; r < num_rows; ++r) {
        const size_t row_off = r * row_size;
        hvx_tanh_f32_aa(out_base + row_off, in_base + row_off, row_elems);
    }
}
|
|
660
|
+
|
|
169
661
|
static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
|
170
662
|
const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
|
|
171
663
|
struct htp_ops_context * octx = uctx->octx;
|
|
172
|
-
const struct htp_tensor * src =
|
|
173
|
-
const struct htp_tensor * dst =
|
|
664
|
+
const struct htp_tensor * src = octx->src[0];
|
|
665
|
+
const struct htp_tensor * dst = octx->dst;
|
|
174
666
|
|
|
175
667
|
htp_unary_preamble;
|
|
176
668
|
|
|
@@ -178,8 +670,8 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
|
|
178
670
|
int32_t * op_params = octx->op_params;
|
|
179
671
|
uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread;
|
|
180
672
|
|
|
181
|
-
const size_t
|
|
182
|
-
const size_t
|
|
673
|
+
const size_t src0_data_row_size = uctx->src0_data_row_size;
|
|
674
|
+
const size_t dst_data_row_size = uctx->dst_data_row_size;
|
|
183
675
|
|
|
184
676
|
const size_t src0_row_size_aligned = uctx->src0_row_size_aligned;
|
|
185
677
|
const size_t dst_row_size_aligned = uctx->dst_row_size_aligned;
|
|
@@ -197,15 +689,32 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
|
|
197
689
|
t1 = HAP_perf_get_qtimer_count();
|
|
198
690
|
|
|
199
691
|
const uint8_t * restrict data_src = uctx->data_src0;
|
|
692
|
+
const uint8_t * restrict data_src1 = uctx->data_src1;
|
|
200
693
|
uint8_t * restrict data_dst = uctx->data_dst;
|
|
201
694
|
|
|
695
|
+
const struct htp_tensor * src1 = (htp_op == HTP_OP_RMS_NORM_MUL) ? octx->src[1] : NULL;
|
|
696
|
+
const uint32_t nb11 = src1 ? src1->nb[1] : 0;
|
|
697
|
+
const uint32_t nb12 = src1 ? src1->nb[2] : 0;
|
|
698
|
+
const uint32_t nb13 = src1 ? src1->nb[3] : 0;
|
|
699
|
+
|
|
202
700
|
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
|
|
701
|
+
uint8_t * src1_spad_data = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
|
|
203
702
|
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
|
|
204
703
|
|
|
205
704
|
size_t src0_spad_half_size = uctx->src0_spad_half_size;
|
|
705
|
+
size_t src1_spad_half_size = uctx->src1_spad_half_size;
|
|
206
706
|
size_t dst_spad_half_size = uctx->dst_spad_half_size;
|
|
207
707
|
|
|
208
|
-
|
|
708
|
+
// Non-contiguous tensors have gaps at dim-2/3 boundaries that a single-stride
|
|
709
|
+
// 2D DMA descriptor cannot span. Clamp BLOCK to ne1 (one dim-1 slice) so every
|
|
710
|
+
// transfer stays within a nb1-uniform region. Skipped for contiguous tensors.
|
|
711
|
+
const bool src0_contig = (nb02 == (size_t)ne01 * nb01) &&
|
|
712
|
+
(nb03 == (size_t)ne02 * nb02);
|
|
713
|
+
const bool dst_contig = (nb2 == (size_t)ne1 * nb1) &&
|
|
714
|
+
(nb3 == (size_t)ne2 * nb2);
|
|
715
|
+
const uint32_t src0_max_block = src0_contig ? uctx->block : MIN((uint32_t)uctx->block, ne01);
|
|
716
|
+
const uint32_t dst_max_block = dst_contig ? uctx->block : MIN((uint32_t)uctx->block, ne1);
|
|
717
|
+
const uint32_t BLOCK = MIN(src0_max_block, dst_max_block);
|
|
209
718
|
if (BLOCK == 0) {
|
|
210
719
|
FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
|
211
720
|
octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
|
@@ -214,30 +723,59 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
|
|
214
723
|
|
|
215
724
|
dma_queue * dma_queue = octx->ctx->dma[ith];
|
|
216
725
|
|
|
217
|
-
|
|
218
|
-
|
|
726
|
+
// If weight is broadcasted, load it once per thread at the beginning of execution
|
|
727
|
+
if (htp_op == HTP_OP_RMS_NORM_MUL && uctx->broadcast_weight) {
|
|
728
|
+
dma_queue_push(dma_queue, dma_make_ptr(src1_spad_data, data_src1), uctx->src1_row_size_aligned, 0, uctx->src1_data_row_size, 1);
|
|
729
|
+
dma_queue_flush(dma_queue);
|
|
730
|
+
}
|
|
731
|
+
|
|
732
|
+
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; spad_idx++) {
|
|
733
|
+
const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
|
219
734
|
|
|
220
735
|
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
|
|
221
|
-
|
|
736
|
+
dma_queue_push(dma_queue,
|
|
222
737
|
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
|
|
223
|
-
|
|
738
|
+
nb1, dst_row_size_aligned, dst_data_row_size, 0);
|
|
739
|
+
|
|
740
|
+
const size_t src0_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03);
|
|
741
|
+
dma_queue_push(dma_queue,
|
|
742
|
+
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + src0_off),
|
|
743
|
+
src0_row_size_aligned, nb01, src0_data_row_size, block_size);
|
|
744
|
+
|
|
745
|
+
if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) {
|
|
746
|
+
const size_t src1_off = unary_row_offset(ir, ne01, ne02, nb11, nb12, nb13);
|
|
747
|
+
dma_queue_push(dma_queue,
|
|
748
|
+
dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + src1_off),
|
|
749
|
+
uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, block_size);
|
|
750
|
+
}
|
|
224
751
|
|
|
225
|
-
|
|
226
|
-
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)),
|
|
227
|
-
src0_row_size_aligned, src0_row_size, block_size);
|
|
752
|
+
ir += block_size;
|
|
228
753
|
}
|
|
229
754
|
|
|
230
|
-
for (uint32_t ir = src0_start_row; ir < src0_end_row;
|
|
231
|
-
const uint32_t block_size =
|
|
755
|
+
for (uint32_t ir = src0_start_row; ir < src0_end_row; ) {
|
|
756
|
+
const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
|
232
757
|
|
|
233
758
|
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
|
|
234
759
|
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
|
|
760
|
+
float * src1_spad = NULL;
|
|
761
|
+
if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) {
|
|
762
|
+
src1_spad = (float *) dma_queue_pop(dma_queue).dst;
|
|
763
|
+
}
|
|
235
764
|
|
|
236
765
|
// Process block in VTCM
|
|
237
766
|
switch (htp_op) {
|
|
767
|
+
case HTP_OP_NORM:
|
|
768
|
+
norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
769
|
+
break;
|
|
238
770
|
case HTP_OP_RMS_NORM:
|
|
239
771
|
rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
240
772
|
break;
|
|
773
|
+
case HTP_OP_RMS_NORM_MUL:
|
|
774
|
+
{
|
|
775
|
+
const float * w_ptr = uctx->broadcast_weight ? (const float *) src1_spad_data : src1_spad;
|
|
776
|
+
rms_norm_mul_f32(src0_spad, w_ptr, dst_spad, block_size, ne0, src0_row_size_aligned, uctx->src1_row_size_aligned, op_params, uctx->broadcast_weight);
|
|
777
|
+
}
|
|
778
|
+
break;
|
|
241
779
|
case HTP_OP_SCALE:
|
|
242
780
|
scale_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
243
781
|
break;
|
|
@@ -247,22 +785,57 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
|
|
247
785
|
case HTP_OP_SQRT:
|
|
248
786
|
sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
249
787
|
break;
|
|
788
|
+
case HTP_OP_UNARY_NEG:
|
|
789
|
+
neg_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
790
|
+
break;
|
|
791
|
+
case HTP_OP_UNARY_EXP:
|
|
792
|
+
exp_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
793
|
+
break;
|
|
794
|
+
case HTP_OP_UNARY_SIGMOID:
|
|
795
|
+
sigmoid_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
796
|
+
break;
|
|
797
|
+
case HTP_OP_UNARY_SOFTPLUS:
|
|
798
|
+
softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
799
|
+
break;
|
|
800
|
+
case HTP_OP_UNARY_TANH:
|
|
801
|
+
tanh_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
802
|
+
break;
|
|
803
|
+
case HTP_OP_L2_NORM:
|
|
804
|
+
l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
805
|
+
break;
|
|
806
|
+
case HTP_OP_TRI:
|
|
807
|
+
tri_f32(src0_spad, dst_spad, NULL, block_size, ne00, src0_row_size_aligned, op_params, ir, uctx);
|
|
808
|
+
break;
|
|
250
809
|
default:
|
|
251
810
|
break;
|
|
252
811
|
}
|
|
253
812
|
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
813
|
+
const size_t dst_off = unary_row_offset(ir, ne1, ne2, nb1, nb2, nb3);
|
|
814
|
+
dma_queue_push(dma_queue,
|
|
815
|
+
dma_make_ptr(data_dst + dst_off, dst_spad),
|
|
816
|
+
nb1, dst_row_size_aligned, dst_data_row_size, block_size);
|
|
257
817
|
|
|
258
818
|
// prefetch N+2 loop iteration if any
|
|
259
|
-
const uint32_t
|
|
260
|
-
if (
|
|
261
|
-
const uint32_t
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
819
|
+
const uint32_t next_ir = ir + block_size;
|
|
820
|
+
if (next_ir < src0_end_row) {
|
|
821
|
+
const uint32_t next_block_size = unary_block_size(next_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
|
822
|
+
const uint32_t pref_ir = next_ir + next_block_size;
|
|
823
|
+
if (pref_ir < src0_end_row) {
|
|
824
|
+
const uint32_t pref_block_size = unary_block_size(pref_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
|
825
|
+
const size_t src0_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03);
|
|
826
|
+
dma_queue_push(dma_queue,
|
|
827
|
+
dma_make_ptr(src0_spad, data_src + src0_pref_off),
|
|
828
|
+
src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size);
|
|
829
|
+
|
|
830
|
+
if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) {
|
|
831
|
+
const size_t src1_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb11, nb12, nb13);
|
|
832
|
+
dma_queue_push(dma_queue,
|
|
833
|
+
dma_make_ptr(src1_spad, data_src1 + src1_pref_off),
|
|
834
|
+
uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, pref_block_size);
|
|
835
|
+
}
|
|
836
|
+
}
|
|
265
837
|
}
|
|
838
|
+
ir += block_size;
|
|
266
839
|
}
|
|
267
840
|
|
|
268
841
|
dma_queue_flush(dma_queue);
|
|
@@ -277,15 +850,21 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
|
|
277
850
|
static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
|
278
851
|
int err = HTP_STATUS_OK;
|
|
279
852
|
|
|
280
|
-
const struct htp_tensor * src0 =
|
|
281
|
-
struct htp_tensor *
|
|
853
|
+
const struct htp_tensor * src0 = octx->src[0];
|
|
854
|
+
const struct htp_tensor * dst = octx->dst;
|
|
282
855
|
|
|
283
856
|
const char * op_type = NULL;
|
|
284
857
|
|
|
285
858
|
switch (octx->op) {
|
|
859
|
+
case HTP_OP_NORM:
|
|
860
|
+
op_type = "norm-f32";
|
|
861
|
+
break;
|
|
286
862
|
case HTP_OP_RMS_NORM:
|
|
287
863
|
op_type = "rmsnorm-f32";
|
|
288
864
|
break;
|
|
865
|
+
case HTP_OP_RMS_NORM_MUL:
|
|
866
|
+
op_type = "rmsnorm-mul-f32";
|
|
867
|
+
break;
|
|
289
868
|
case HTP_OP_SCALE:
|
|
290
869
|
op_type = "scale-f32";
|
|
291
870
|
break;
|
|
@@ -295,6 +874,27 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
|
|
295
874
|
case HTP_OP_SQRT:
|
|
296
875
|
op_type = "sqrt-f32";
|
|
297
876
|
break;
|
|
877
|
+
case HTP_OP_UNARY_NEG:
|
|
878
|
+
op_type = "neg-f32";
|
|
879
|
+
break;
|
|
880
|
+
case HTP_OP_UNARY_EXP:
|
|
881
|
+
op_type = "exp-f32";
|
|
882
|
+
break;
|
|
883
|
+
case HTP_OP_UNARY_SIGMOID:
|
|
884
|
+
op_type = "sigmoid-f32";
|
|
885
|
+
break;
|
|
886
|
+
case HTP_OP_UNARY_SOFTPLUS:
|
|
887
|
+
op_type = "softplus-f32";
|
|
888
|
+
break;
|
|
889
|
+
case HTP_OP_UNARY_TANH:
|
|
890
|
+
op_type = "tanh-f32";
|
|
891
|
+
break;
|
|
892
|
+
case HTP_OP_L2_NORM:
|
|
893
|
+
op_type = "l2norm-f32";
|
|
894
|
+
break;
|
|
895
|
+
case HTP_OP_TRI:
|
|
896
|
+
op_type = "tri-f32";
|
|
897
|
+
break;
|
|
298
898
|
|
|
299
899
|
default:
|
|
300
900
|
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
|
|
@@ -304,18 +904,50 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
|
|
304
904
|
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
|
305
905
|
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
|
|
306
906
|
|
|
307
|
-
const size_t
|
|
308
|
-
const size_t
|
|
907
|
+
const size_t src0_data_row_size = src0->ne[0] * sizeof(float);
|
|
908
|
+
const size_t dst_data_row_size = dst->ne[0] * sizeof(float);
|
|
909
|
+
|
|
910
|
+
const size_t src0_row_size_aligned = hex_round_up(src0_data_row_size, VLEN);
|
|
911
|
+
const size_t dst_row_size_aligned = hex_round_up(dst_data_row_size, VLEN);
|
|
309
912
|
|
|
310
|
-
|
|
311
|
-
|
|
913
|
+
size_t src1_data_row_size = 0;
|
|
914
|
+
size_t src1_row_size_aligned = 0;
|
|
915
|
+
bool broadcast_weight = false;
|
|
916
|
+
const struct htp_tensor * src1 = NULL;
|
|
917
|
+
|
|
918
|
+
if (octx->op == HTP_OP_RMS_NORM_MUL) {
|
|
919
|
+
src1 = octx->src[1];
|
|
920
|
+
src1_data_row_size = src1->ne[0] * sizeof(float);
|
|
921
|
+
src1_row_size_aligned = hex_round_up(src1_data_row_size, VLEN);
|
|
922
|
+
broadcast_weight = (src1->ne[1] * src1->ne[2] * src1->ne[3] == 1);
|
|
923
|
+
}
|
|
312
924
|
|
|
313
925
|
// VTCM scratchpads for all tensors
|
|
314
926
|
// N rows per thread, padded to HVX vector size
|
|
315
927
|
// Double buffering requires 2x size per buffer
|
|
316
928
|
|
|
317
|
-
size_t spad_size_per_row
|
|
318
|
-
size_t vtcm_row_per_thread =
|
|
929
|
+
size_t spad_size_per_row = 0;
|
|
930
|
+
size_t vtcm_row_per_thread = 0;
|
|
931
|
+
|
|
932
|
+
if (octx->op == HTP_OP_RMS_NORM_MUL) {
|
|
933
|
+
if (broadcast_weight) {
|
|
934
|
+
size_t available_vtcm = octx->ctx->vtcm_size;
|
|
935
|
+
size_t src1_spad_total = n_threads * src1_row_size_aligned;
|
|
936
|
+
if (available_vtcm > src1_spad_total) {
|
|
937
|
+
available_vtcm -= src1_spad_total;
|
|
938
|
+
} else {
|
|
939
|
+
available_vtcm = 0;
|
|
940
|
+
}
|
|
941
|
+
spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned);
|
|
942
|
+
vtcm_row_per_thread = available_vtcm / (n_threads * spad_size_per_row);
|
|
943
|
+
} else {
|
|
944
|
+
spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned + src1_row_size_aligned);
|
|
945
|
+
vtcm_row_per_thread = (octx->ctx->vtcm_size) / (n_threads * spad_size_per_row);
|
|
946
|
+
}
|
|
947
|
+
} else {
|
|
948
|
+
spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned);
|
|
949
|
+
vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row);
|
|
950
|
+
}
|
|
319
951
|
|
|
320
952
|
// Make sure the reserved vtcm size is sufficient
|
|
321
953
|
if (vtcm_row_per_thread == 0) {
|
|
@@ -330,8 +962,29 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
|
|
330
962
|
octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
|
|
331
963
|
octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
|
|
332
964
|
|
|
965
|
+
if (octx->op == HTP_OP_RMS_NORM_MUL) {
|
|
966
|
+
if (broadcast_weight) {
|
|
967
|
+
octx->src1_spad.size_per_thread = src1_row_size_aligned;
|
|
968
|
+
} else {
|
|
969
|
+
octx->src1_spad.size_per_thread = src1_row_size_aligned * vtcm_row_per_thread * 2;
|
|
970
|
+
}
|
|
971
|
+
octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread;
|
|
972
|
+
} else {
|
|
973
|
+
octx->src1_spad.size = 0;
|
|
974
|
+
octx->src1_spad.size_per_thread = 0;
|
|
975
|
+
}
|
|
976
|
+
|
|
333
977
|
octx->src0_spad.data = octx->ctx->vtcm_base;
|
|
334
|
-
|
|
978
|
+
if (octx->op == HTP_OP_RMS_NORM_MUL) {
|
|
979
|
+
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
|
980
|
+
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
|
981
|
+
} else {
|
|
982
|
+
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
|
983
|
+
}
|
|
984
|
+
|
|
985
|
+
octx->src0_spad.src = NULL;
|
|
986
|
+
octx->src1_spad.src = NULL;
|
|
987
|
+
octx->dst_spad.src = NULL;
|
|
335
988
|
|
|
336
989
|
FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
|
|
337
990
|
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
|
@@ -344,19 +997,24 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
|
|
344
997
|
.src0_nrows = src0_nrows,
|
|
345
998
|
|
|
346
999
|
.data_src0 = (const uint8_t *)src0->data,
|
|
1000
|
+
.data_src1 = (octx->op == HTP_OP_RMS_NORM_MUL) ? (const uint8_t *)src1->data : NULL,
|
|
347
1001
|
.data_dst = (uint8_t *)dst->data,
|
|
348
1002
|
|
|
349
|
-
.
|
|
350
|
-
.
|
|
1003
|
+
.src0_data_row_size = src0_data_row_size,
|
|
1004
|
+
.src1_data_row_size = src1_data_row_size,
|
|
1005
|
+
.dst_data_row_size = dst_data_row_size,
|
|
351
1006
|
|
|
352
1007
|
.src0_row_size_aligned = src0_row_size_aligned,
|
|
1008
|
+
.src1_row_size_aligned = src1_row_size_aligned,
|
|
353
1009
|
.dst_row_size_aligned = dst_row_size_aligned,
|
|
354
1010
|
|
|
355
1011
|
.src0_spad_half_size = octx->src0_spad.size_per_thread / 2,
|
|
1012
|
+
.src1_spad_half_size = (octx->op == HTP_OP_RMS_NORM_MUL) ? (octx->src1_spad.size_per_thread / (broadcast_weight ? 1 : 2)) : 0,
|
|
356
1013
|
.dst_spad_half_size = octx->dst_spad.size_per_thread / 2,
|
|
357
1014
|
|
|
358
1015
|
.block = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned,
|
|
359
1016
|
.nc = src0->ne[0],
|
|
1017
|
+
.broadcast_weight = broadcast_weight,
|
|
360
1018
|
};
|
|
361
1019
|
|
|
362
1020
|
worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_threads);
|
|
@@ -365,10 +1023,26 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
|
|
365
1023
|
return err;
|
|
366
1024
|
}
|
|
367
1025
|
|
|
1026
|
+
int op_tri(struct htp_ops_context * octx) {
|
|
1027
|
+
int err = HTP_STATUS_OK;
|
|
1028
|
+
|
|
1029
|
+
switch (octx->src[0]->type) {
|
|
1030
|
+
case HTP_TYPE_F32:
|
|
1031
|
+
err = execute_op_unary_f32(octx);
|
|
1032
|
+
break;
|
|
1033
|
+
|
|
1034
|
+
default:
|
|
1035
|
+
err = HTP_STATUS_NO_SUPPORT;
|
|
1036
|
+
break;
|
|
1037
|
+
}
|
|
1038
|
+
|
|
1039
|
+
return err;
|
|
1040
|
+
}
|
|
1041
|
+
|
|
368
1042
|
int op_unary(struct htp_ops_context * octx) {
|
|
369
1043
|
int err = HTP_STATUS_OK;
|
|
370
1044
|
|
|
371
|
-
switch (octx->
|
|
1045
|
+
switch (octx->src[0]->type) {
|
|
372
1046
|
case HTP_TYPE_F32:
|
|
373
1047
|
err = execute_op_unary_f32(octx);
|
|
374
1048
|
break;
|