whispercpp 1.3.5 → 1.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/LICENSE +1 -1
- data/README.md +133 -3
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -7
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +56 -46
- data/ext/ruby_whisper.h +165 -2
- data/ext/ruby_whisper_context.c +297 -126
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -66
- data/ext/ruby_whisper_segment.c +6 -7
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +46 -16
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +24 -19
- data/ext/sources/examples/cli/cli.cpp +51 -9
- data/ext/sources/examples/common-ggml.cpp +4 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +213 -163
- data/ext/sources/ggml/CMakeLists.txt +29 -15
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +73 -11
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +8 -3
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +155 -16
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +25 -5
- data/ext/sources/ggml/src/ggml-alloc.c +9 -10
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
- data/ext/sources/ggml/src/ggml-common.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
- data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
- data/ext/sources/ggml/src/ggml-impl.h +68 -1
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +385 -119
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
- data/ext/sources/ggml/src/ggml.c +268 -52
- data/ext/sources/ggml/src/gguf.cpp +377 -47
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +62 -40
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +445 -55
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_context_params.rb +82 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +44 -6
- data/whispercpp.gemspec +2 -2
- metadata +426 -280
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
- data/ext/sources/examples/talk-llama/llama-context.h +0 -360
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
- data/ext/sources/examples/talk-llama/llama-model.h +0 -544
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
- data/ext/sources/examples/talk-llama/llama.h +0 -1540
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -569
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
|
@@ -2,28 +2,82 @@
|
|
|
2
2
|
#pragma clang diagnostic ignored "-Wunused-function"
|
|
3
3
|
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
|
4
4
|
|
|
5
|
-
#ifdef HTP_DEBUG
|
|
6
|
-
# define FARF_HIGH 1
|
|
7
|
-
#endif
|
|
8
|
-
|
|
9
5
|
#include <HAP_farf.h>
|
|
10
|
-
#include <HAP_mem.h>
|
|
11
6
|
#include <HAP_perf.h>
|
|
12
|
-
|
|
13
|
-
#include <hexagon_protos.h>
|
|
14
|
-
#include <hexagon_types.h>
|
|
7
|
+
|
|
15
8
|
#include <math.h>
|
|
16
|
-
#include <qurt_thread.h>
|
|
17
9
|
#include <string.h>
|
|
18
10
|
|
|
11
|
+
#include "hex-dma.h"
|
|
12
|
+
#include "hvx-exp.h"
|
|
13
|
+
#include "hvx-sigmoid.h"
|
|
14
|
+
#include "hvx-utils.h"
|
|
15
|
+
|
|
19
16
|
#define GGML_COMMON_DECL_C
|
|
20
17
|
#include "ggml-common.h"
|
|
21
18
|
#include "htp-ctx.h"
|
|
22
|
-
#include "htp-dma.h"
|
|
23
|
-
#include "htp-msg.h"
|
|
24
19
|
#include "htp-ops.h"
|
|
25
|
-
|
|
26
|
-
|
|
20
|
+
|
|
21
|
+
struct htp_unary_context {
|
|
22
|
+
struct htp_ops_context * octx;
|
|
23
|
+
|
|
24
|
+
// Precomputed values
|
|
25
|
+
const uint8_t * data_src0;
|
|
26
|
+
const uint8_t * data_src1; // weight/scale tensor for RMS_NORM_MUL
|
|
27
|
+
uint8_t * data_dst;
|
|
28
|
+
|
|
29
|
+
size_t src0_data_row_size; // actual data bytes per row
|
|
30
|
+
size_t src1_data_row_size;
|
|
31
|
+
size_t dst_data_row_size; // actual data bytes per row
|
|
32
|
+
|
|
33
|
+
size_t src0_row_size_aligned;
|
|
34
|
+
size_t src1_row_size_aligned;
|
|
35
|
+
size_t dst_row_size_aligned;
|
|
36
|
+
|
|
37
|
+
size_t src0_spad_half_size;
|
|
38
|
+
size_t src1_spad_half_size;
|
|
39
|
+
size_t dst_spad_half_size;
|
|
40
|
+
|
|
41
|
+
uint32_t block;
|
|
42
|
+
uint32_t src0_nrows;
|
|
43
|
+
uint32_t src0_nrows_per_thread;
|
|
44
|
+
uint32_t nc;
|
|
45
|
+
bool broadcast_weight;
|
|
46
|
+
};
|
|
47
|
+
|
|
48
|
+
// Convert flat row index to DDR byte offset using the tensor's actual strides.
|
|
49
|
+
// ir = i1 + ne1*(i2 + ne2*i3) => offset = i1*nb1 + i2*nb2 + i3*nb3
|
|
50
|
+
static inline size_t unary_row_offset(uint32_t ir,
|
|
51
|
+
uint32_t ne1, uint32_t ne2,
|
|
52
|
+
size_t nb1, size_t nb2, size_t nb3) {
|
|
53
|
+
const uint32_t i1 = ir % ne1;
|
|
54
|
+
const uint32_t i2 = (ir / ne1) % ne2;
|
|
55
|
+
const uint32_t i3 = ir / (ne1 * ne2);
|
|
56
|
+
return i1 * nb1 + i2 * nb2 + i3 * nb3;
|
|
57
|
+
}
|
|
58
|
+
// Safe DMA block size from row `ir`: clamp to the tighter dim-1 slice
|
|
59
|
+
// boundary of src and dst so the nb1 stride stays valid for all rows.
|
|
60
|
+
static inline uint32_t unary_block_size(uint32_t ir,
|
|
61
|
+
uint32_t end_row,
|
|
62
|
+
uint32_t block,
|
|
63
|
+
bool src_contig,
|
|
64
|
+
bool dst_contig,
|
|
65
|
+
uint32_t src_ne1,
|
|
66
|
+
uint32_t dst_ne1) {
|
|
67
|
+
uint32_t limit = MIN(block, end_row - ir);
|
|
68
|
+
|
|
69
|
+
if (!src_contig) {
|
|
70
|
+
const uint32_t src_slice_end = (ir / src_ne1 + 1) * src_ne1;
|
|
71
|
+
limit = MIN(limit, src_slice_end - ir);
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
if (!dst_contig) {
|
|
75
|
+
const uint32_t dst_slice_end = (ir / dst_ne1 + 1) * dst_ne1;
|
|
76
|
+
limit = MIN(limit, dst_slice_end - ir);
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
return limit;
|
|
80
|
+
}
|
|
27
81
|
|
|
28
82
|
#define htp_unary_preamble \
|
|
29
83
|
const uint32_t ne00 = src->ne[0]; \
|
|
@@ -51,110 +105,578 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
|
|
|
51
105
|
uint8_t * restrict pad,
|
|
52
106
|
const int num_elems,
|
|
53
107
|
float epsilon) {
|
|
108
|
+
(void)pad;
|
|
109
|
+
|
|
54
110
|
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
|
55
111
|
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
|
56
112
|
|
|
57
|
-
|
|
58
|
-
|
|
113
|
+
const int nvec = num_elems / VLEN_FP32; // number of full vectors
|
|
114
|
+
const int nloe = num_elems % VLEN_FP32; // leftover elements
|
|
115
|
+
|
|
116
|
+
// Compute sum of squares for full vectors
|
|
117
|
+
HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
|
|
118
|
+
HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
|
|
119
|
+
|
|
120
|
+
#pragma unroll(4)
|
|
121
|
+
for (int i = 0; i < nvec; i++) {
|
|
122
|
+
HVX_Vector v1 = v_src[i];
|
|
123
|
+
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
|
124
|
+
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
// Handle tail elements using vectorized ops with masking
|
|
128
|
+
if (nloe > 0) {
|
|
129
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
130
|
+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
|
131
|
+
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
|
132
|
+
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
// Reduce HVX sum
|
|
136
|
+
sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
|
|
137
|
+
|
|
138
|
+
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
|
|
139
|
+
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
|
|
140
|
+
HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
|
|
141
|
+
HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);
|
|
142
|
+
|
|
143
|
+
// Scale full vectors
|
|
144
|
+
HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));
|
|
59
145
|
|
|
60
|
-
int step_of_1 = num_elems >> 5;
|
|
61
146
|
#pragma unroll(4)
|
|
62
|
-
for (int i = 0; i <
|
|
147
|
+
for (int i = 0; i < nvec; i++) {
|
|
148
|
+
HVX_Vector v1 = v_src[i];
|
|
149
|
+
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
|
|
150
|
+
v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
|
|
151
|
+
}
|
|
152
|
+
|
|
153
|
+
// Handle tail elements using vectorized ops with masking
|
|
154
|
+
if (nloe > 0) {
|
|
155
|
+
|
|
156
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
157
|
+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
|
158
|
+
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
|
|
159
|
+
HVX_Vector result = Q6_Vsf_equals_Vqf32(v2);
|
|
160
|
+
|
|
161
|
+
// Store with masking to avoid overwriting memory beyond the tensor
|
|
162
|
+
hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
|
|
163
|
+
}
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
static void hvx_fast_rms_norm_mul_f32(const uint8_t * restrict src,
|
|
167
|
+
const uint8_t * restrict weight,
|
|
168
|
+
uint8_t * restrict dst,
|
|
169
|
+
const int num_elems,
|
|
170
|
+
float epsilon) {
|
|
171
|
+
const HVX_Vector * restrict v_src = (const HVX_Vector *) src;
|
|
172
|
+
const HVX_Vector * restrict v_weight = (const HVX_Vector *) weight;
|
|
173
|
+
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
|
174
|
+
|
|
175
|
+
const int nvec = num_elems / VLEN_FP32; // number of full vectors
|
|
176
|
+
const int nloe = num_elems % VLEN_FP32; // leftover elements
|
|
177
|
+
|
|
178
|
+
// Compute sum of squares for full vectors
|
|
179
|
+
HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
|
|
180
|
+
HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
|
|
181
|
+
|
|
182
|
+
#pragma unroll(4)
|
|
183
|
+
for (int i = 0; i < nvec; i++) {
|
|
63
184
|
HVX_Vector v1 = v_src[i];
|
|
64
185
|
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
|
65
|
-
sum_v
|
|
186
|
+
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
|
66
187
|
}
|
|
67
188
|
|
|
68
|
-
|
|
69
|
-
|
|
189
|
+
// Handle tail elements using vectorized ops with masking
|
|
190
|
+
if (nloe > 0) {
|
|
191
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
192
|
+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
|
193
|
+
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
|
194
|
+
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
|
195
|
+
}
|
|
70
196
|
|
|
71
|
-
|
|
72
|
-
|
|
197
|
+
// Reduce HVX sum
|
|
198
|
+
sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
|
|
199
|
+
|
|
200
|
+
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
|
|
201
|
+
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
|
|
73
202
|
HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
|
|
74
203
|
HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);
|
|
75
204
|
|
|
76
|
-
|
|
205
|
+
// Scale and multiply
|
|
206
|
+
HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));
|
|
77
207
|
|
|
78
208
|
#pragma unroll(4)
|
|
79
|
-
for (int i = 0; i <
|
|
209
|
+
for (int i = 0; i < nvec; i++) {
|
|
80
210
|
HVX_Vector v1 = v_src[i];
|
|
81
211
|
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
|
|
82
|
-
|
|
212
|
+
HVX_Vector v3 = Q6_Vsf_equals_Vqf32(v2);
|
|
213
|
+
HVX_Vector result = Q6_Vqf32_vmpy_VsfVsf(v3, v_weight[i]);
|
|
214
|
+
v_dst[i] = Q6_Vsf_equals_Vqf32(result);
|
|
215
|
+
}
|
|
216
|
+
|
|
217
|
+
// Handle tail elements using vectorized ops with masking
|
|
218
|
+
if (nloe > 0) {
|
|
219
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
220
|
+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
|
221
|
+
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
|
|
222
|
+
HVX_Vector v3 = Q6_Vsf_equals_Vqf32(v2);
|
|
223
|
+
HVX_Vector result = Q6_Vqf32_vmpy_VsfVsf(v3, v_weight[nvec]);
|
|
224
|
+
HVX_Vector res_v = Q6_Vsf_equals_Vqf32(result);
|
|
225
|
+
|
|
226
|
+
// Store with masking to avoid overwriting memory beyond the tensor
|
|
227
|
+
hvx_vec_store_a(&v_dst[nvec], nloe * 4, res_v);
|
|
83
228
|
}
|
|
84
229
|
}
|
|
85
230
|
|
|
86
|
-
static void
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
231
|
+
static void hvx_fast_norm_f32(const uint8_t * restrict src,
|
|
232
|
+
uint8_t * restrict dst,
|
|
233
|
+
uint8_t * restrict pad,
|
|
234
|
+
const int num_elems,
|
|
235
|
+
float epsilon) {
|
|
236
|
+
(void)pad;
|
|
237
|
+
|
|
238
|
+
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
|
239
|
+
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
|
240
|
+
|
|
241
|
+
const int nvec = num_elems / VLEN_FP32; // number of full vectors
|
|
242
|
+
const int nloe = num_elems % VLEN_FP32; // leftover elements
|
|
243
|
+
|
|
244
|
+
// Compute sum of squares and sum of values for full vectors
|
|
245
|
+
HVX_Vector sum_sq_v = Q6_V_vsplat_R(0x00000000);
|
|
246
|
+
HVX_Vector sum_x_v = Q6_V_vsplat_R(0x00000000);
|
|
247
|
+
HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
|
|
248
|
+
|
|
249
|
+
#pragma unroll(4)
|
|
250
|
+
for (int i = 0; i < nvec; i++) {
|
|
251
|
+
HVX_Vector v1 = v_src[i];
|
|
252
|
+
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
|
253
|
+
sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2);
|
|
254
|
+
sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero()));
|
|
255
|
+
}
|
|
256
|
+
|
|
257
|
+
// Handle tail elements using vectorized ops with masking
|
|
258
|
+
if (nloe > 0) {
|
|
259
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
260
|
+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
|
261
|
+
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
|
262
|
+
sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2);
|
|
263
|
+
sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero()));
|
|
264
|
+
}
|
|
265
|
+
|
|
266
|
+
// Reduce HVX sums
|
|
267
|
+
sum_sq_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_sq_v));
|
|
268
|
+
sum_x_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_x_v));
|
|
269
|
+
|
|
270
|
+
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
|
|
271
|
+
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
|
|
272
|
+
HVX_Vector mean_sq_v = Q6_Vqf32_vmpy_VsfVsf(sum_sq_v, denom_v);
|
|
273
|
+
HVX_Vector mean_x_v = Q6_Vqf32_vmpy_VsfVsf(sum_x_v, denom_v);
|
|
274
|
+
HVX_Vector mean_x_sq_v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(mean_x_v), Q6_Vsf_equals_Vqf32(mean_x_v));
|
|
275
|
+
HVX_Vector var_v = Q6_Vqf32_vsub_Vqf32Vqf32(mean_sq_v, mean_x_sq_v);
|
|
276
|
+
HVX_Vector var_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(var_v, epsilon_v);
|
|
277
|
+
|
|
278
|
+
// scale = rsqrt(variance + epsilon), mean_x broadcast for subtraction
|
|
279
|
+
HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(var_epsilon_v));
|
|
280
|
+
HVX_Vector mean_x_b = hvx_vec_repl_f32(Q6_Vsf_equals_Vqf32(mean_x_v));
|
|
281
|
+
|
|
282
|
+
#pragma unroll(4)
|
|
283
|
+
for (int i = 0; i < nvec; i++) {
|
|
284
|
+
HVX_Vector v1 = v_src[i];
|
|
285
|
+
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b);
|
|
286
|
+
HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v);
|
|
287
|
+
v_dst[i] = Q6_Vsf_equals_Vqf32(v3);
|
|
288
|
+
}
|
|
289
|
+
|
|
290
|
+
// Handle tail elements using vectorized ops with masking
|
|
291
|
+
if (nloe > 0) {
|
|
292
|
+
|
|
293
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
294
|
+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
|
295
|
+
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b);
|
|
296
|
+
HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v);
|
|
297
|
+
HVX_Vector result = Q6_Vsf_equals_Vqf32(v3);
|
|
298
|
+
|
|
299
|
+
// Store with masking to avoid overwriting memory beyond the tensor
|
|
300
|
+
hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
|
|
301
|
+
}
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
static void scale_f32(const float * restrict src,
|
|
305
|
+
float * restrict dst,
|
|
306
|
+
uint8_t * restrict spad,
|
|
307
|
+
const uint32_t num_rows,
|
|
308
|
+
const uint32_t row_elems,
|
|
309
|
+
const size_t row_size,
|
|
310
|
+
int32_t * op_params) {
|
|
94
311
|
float scale = 0.f;
|
|
95
312
|
float bias = 0.f;
|
|
96
313
|
memcpy(&scale, &op_params[0], sizeof(float));
|
|
97
314
|
memcpy(&bias, &op_params[1], sizeof(float));
|
|
98
315
|
|
|
99
316
|
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
|
100
|
-
const
|
|
101
|
-
|
|
317
|
+
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
|
318
|
+
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
|
102
319
|
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
320
|
+
hvx_scale_offset_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
|
|
321
|
+
}
|
|
322
|
+
}
|
|
323
|
+
|
|
324
|
+
static void rms_norm_f32(const float * restrict src,
|
|
325
|
+
float * restrict dst,
|
|
326
|
+
uint8_t * restrict spad,
|
|
327
|
+
const uint32_t num_rows,
|
|
328
|
+
const uint32_t row_elems,
|
|
329
|
+
const size_t row_size,
|
|
330
|
+
int32_t * op_params) {
|
|
331
|
+
float epsilon = 0.f;
|
|
332
|
+
memcpy(&epsilon, op_params, sizeof(float));
|
|
106
333
|
|
|
107
|
-
|
|
334
|
+
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
|
335
|
+
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
|
336
|
+
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
|
337
|
+
|
|
338
|
+
hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
|
|
108
339
|
}
|
|
109
340
|
}
|
|
110
341
|
|
|
111
|
-
static void
|
|
342
|
+
static void rms_norm_mul_f32(const float * restrict src,
|
|
343
|
+
const float * restrict weight,
|
|
112
344
|
float * restrict dst,
|
|
113
|
-
uint8_t * restrict spad,
|
|
114
345
|
const uint32_t num_rows,
|
|
115
346
|
const uint32_t row_elems,
|
|
116
347
|
const size_t row_size,
|
|
348
|
+
const size_t weight_row_size,
|
|
117
349
|
int32_t * op_params,
|
|
118
|
-
|
|
350
|
+
bool broadcast_weight) {
|
|
119
351
|
float epsilon = 0.f;
|
|
120
352
|
memcpy(&epsilon, op_params, sizeof(float));
|
|
121
353
|
|
|
122
354
|
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
|
123
|
-
const
|
|
124
|
-
|
|
355
|
+
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
|
356
|
+
const uint8_t * restrict w_local = (const uint8_t *)weight + (broadcast_weight ? 0 : ir * weight_row_size);
|
|
357
|
+
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
|
358
|
+
|
|
359
|
+
hvx_fast_rms_norm_mul_f32(src_local, w_local, dst_local, row_elems, epsilon);
|
|
360
|
+
}
|
|
361
|
+
}
|
|
362
|
+
|
|
363
|
+
static void norm_f32(const float * restrict src,
|
|
364
|
+
float * restrict dst,
|
|
365
|
+
uint8_t * restrict spad,
|
|
366
|
+
const uint32_t num_rows,
|
|
367
|
+
const uint32_t row_elems,
|
|
368
|
+
const size_t row_size,
|
|
369
|
+
int32_t * op_params) {
|
|
370
|
+
float epsilon = 0.f;
|
|
371
|
+
memcpy(&epsilon, op_params, sizeof(float));
|
|
372
|
+
|
|
373
|
+
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
|
374
|
+
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
|
375
|
+
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
|
376
|
+
|
|
377
|
+
hvx_fast_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
|
|
378
|
+
}
|
|
379
|
+
}
|
|
380
|
+
|
|
381
|
+
static void sqr_f32(const float * restrict src,
|
|
382
|
+
float * restrict dst,
|
|
383
|
+
uint8_t * restrict spad,
|
|
384
|
+
const uint32_t num_rows,
|
|
385
|
+
const uint32_t row_elems,
|
|
386
|
+
const size_t row_size,
|
|
387
|
+
int32_t * op_params) {
|
|
388
|
+
|
|
389
|
+
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
|
390
|
+
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
|
391
|
+
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
|
392
|
+
|
|
393
|
+
hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
|
394
|
+
}
|
|
395
|
+
}
|
|
396
|
+
|
|
397
|
+
static void sqrt_f32(const float * restrict src,
|
|
398
|
+
float * restrict dst,
|
|
399
|
+
uint8_t * restrict spad,
|
|
400
|
+
const uint32_t num_rows,
|
|
401
|
+
const uint32_t row_elems,
|
|
402
|
+
const size_t row_size,
|
|
403
|
+
int32_t * op_params) {
|
|
404
|
+
|
|
405
|
+
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
|
406
|
+
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
|
407
|
+
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
|
408
|
+
|
|
409
|
+
hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
|
410
|
+
}
|
|
411
|
+
}
|
|
412
|
+
|
|
413
|
+
static void neg_f32(const float * restrict src,
|
|
414
|
+
float * restrict dst,
|
|
415
|
+
uint8_t * restrict spad,
|
|
416
|
+
const uint32_t num_rows,
|
|
417
|
+
const uint32_t row_elems,
|
|
418
|
+
const size_t row_size,
|
|
419
|
+
int32_t * op_params) {
|
|
420
|
+
|
|
421
|
+
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
|
422
|
+
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
|
423
|
+
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
|
424
|
+
|
|
425
|
+
hvx_scale_f32_aa(dst_local, src_local, row_elems, -1.0f);
|
|
426
|
+
}
|
|
427
|
+
}
|
|
428
|
+
|
|
429
|
+
static void exp_f32(const float * restrict src,
|
|
430
|
+
float * restrict dst,
|
|
431
|
+
uint8_t * restrict spad,
|
|
432
|
+
const uint32_t num_rows,
|
|
433
|
+
const uint32_t row_elems,
|
|
434
|
+
const size_t row_size,
|
|
435
|
+
int32_t * op_params) {
|
|
436
|
+
|
|
437
|
+
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
|
438
|
+
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
|
439
|
+
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
|
440
|
+
|
|
441
|
+
hvx_exp_f32(dst_local, src_local, row_elems, false);
|
|
442
|
+
}
|
|
443
|
+
}
|
|
444
|
+
|
|
445
|
+
static void sigmoid_f32(const float * restrict src,
|
|
446
|
+
float * restrict dst,
|
|
447
|
+
uint8_t * restrict spad,
|
|
448
|
+
const uint32_t num_rows,
|
|
449
|
+
const uint32_t row_elems,
|
|
450
|
+
const size_t row_size,
|
|
451
|
+
int32_t * op_params) {
|
|
452
|
+
|
|
453
|
+
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
|
454
|
+
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
|
455
|
+
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
|
125
456
|
|
|
126
|
-
|
|
127
|
-
|
|
457
|
+
hvx_sigmoid_f32_aa(dst_local, src_local, row_elems);
|
|
458
|
+
}
|
|
459
|
+
}
|
|
460
|
+
|
|
461
|
+
static void tri_f32(const float * restrict src,
|
|
462
|
+
float * restrict dst,
|
|
463
|
+
uint8_t * restrict spad,
|
|
464
|
+
const uint32_t num_rows,
|
|
465
|
+
const uint32_t row_elems,
|
|
466
|
+
const size_t row_size,
|
|
467
|
+
int32_t * op_params,
|
|
468
|
+
const uint32_t ir,
|
|
469
|
+
const struct htp_unary_context * uctx) {
|
|
470
|
+
|
|
471
|
+
const int32_t ttype = op_params[0];
|
|
472
|
+
const HVX_Vector zero = hvx_vec_splat_f32(0.0f);
|
|
473
|
+
const uint32_t nvec = row_elems / VLEN_FP32;
|
|
474
|
+
const uint32_t nloe = row_elems % VLEN_FP32;
|
|
475
|
+
|
|
476
|
+
const uint32_t ne01 = uctx->octx->src[0]->ne[1];
|
|
477
|
+
|
|
478
|
+
for (uint32_t b = 0; b < num_rows; b++) {
|
|
479
|
+
const uint32_t abs_row = ir + b;
|
|
480
|
+
const uint32_t i01 = abs_row % ne01;
|
|
481
|
+
|
|
482
|
+
const HVX_Vector * restrict v_src = (const HVX_Vector *) ((const uint8_t *) src + b * row_size);
|
|
483
|
+
HVX_Vector * restrict v_dst = (HVX_Vector *) ((uint8_t *) dst + b * row_size);
|
|
484
|
+
|
|
485
|
+
uint32_t boundary;
|
|
486
|
+
int keep_left;
|
|
487
|
+
switch (ttype) {
|
|
488
|
+
case 0: boundary = i01; keep_left = 0; break; // keep col >= row
|
|
489
|
+
case 1: boundary = i01 + 1; keep_left = 0; break; // keep col > row
|
|
490
|
+
case 2: boundary = i01 + 1; keep_left = 1; break; // keep col <= row
|
|
491
|
+
case 3: boundary = i01; keep_left = 1; break; // keep col < row
|
|
492
|
+
default: boundary = 0; keep_left = 0; break;
|
|
128
493
|
}
|
|
494
|
+
if (boundary > row_elems) boundary = row_elems;
|
|
129
495
|
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
496
|
+
// Full HVX vectors — each starts at a 128-byte aligned offset
|
|
497
|
+
for (uint32_t i = 0; i < nvec; i++) {
|
|
498
|
+
const uint32_t vec_start = i * VLEN_FP32;
|
|
499
|
+
const uint32_t vec_end = vec_start + VLEN_FP32;
|
|
500
|
+
if (keep_left) {
|
|
501
|
+
if (vec_end <= boundary) {
|
|
502
|
+
v_dst[i] = v_src[i];
|
|
503
|
+
} else if (vec_start >= boundary) {
|
|
504
|
+
v_dst[i] = zero;
|
|
505
|
+
} else {
|
|
506
|
+
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
|
|
507
|
+
v_dst[i] = Q6_V_vmux_QVV(mask, v_src[i], zero);
|
|
508
|
+
}
|
|
509
|
+
} else {
|
|
510
|
+
if (vec_end <= boundary) {
|
|
511
|
+
v_dst[i] = zero;
|
|
512
|
+
} else if (vec_start >= boundary) {
|
|
513
|
+
v_dst[i] = v_src[i];
|
|
514
|
+
} else {
|
|
515
|
+
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
|
|
516
|
+
v_dst[i] = Q6_V_vmux_QVV(mask, zero, v_src[i]);
|
|
517
|
+
}
|
|
518
|
+
}
|
|
519
|
+
}
|
|
520
|
+
|
|
521
|
+
// Tail elements (row_elems not a multiple of VLEN_FP32)
|
|
522
|
+
if (nloe > 0) {
|
|
523
|
+
const uint32_t vec_start = nvec * VLEN_FP32;
|
|
524
|
+
const uint32_t vec_end = vec_start + nloe;
|
|
525
|
+
HVX_Vector tail_val;
|
|
526
|
+
if (keep_left) {
|
|
527
|
+
if (vec_end <= boundary) {
|
|
528
|
+
tail_val = v_src[nvec];
|
|
529
|
+
} else if (vec_start >= boundary) {
|
|
530
|
+
tail_val = zero;
|
|
531
|
+
} else {
|
|
532
|
+
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
|
|
533
|
+
tail_val = Q6_V_vmux_QVV(mask, v_src[nvec], zero);
|
|
534
|
+
}
|
|
535
|
+
} else {
|
|
536
|
+
if (vec_end <= boundary) {
|
|
537
|
+
tail_val = zero;
|
|
538
|
+
} else if (vec_start >= boundary) {
|
|
539
|
+
tail_val = v_src[nvec];
|
|
540
|
+
} else {
|
|
541
|
+
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
|
|
542
|
+
tail_val = Q6_V_vmux_QVV(mask, zero, v_src[nvec]);
|
|
543
|
+
}
|
|
544
|
+
}
|
|
545
|
+
hvx_vec_store_a(&v_dst[nvec], nloe * sizeof(float), tail_val);
|
|
546
|
+
}
|
|
547
|
+
}
|
|
548
|
+
}
|
|
134
549
|
|
|
135
|
-
|
|
136
|
-
|
|
550
|
+
static void softplus_f32(const float * restrict src,
|
|
551
|
+
float * restrict dst,
|
|
552
|
+
uint8_t * restrict spad,
|
|
553
|
+
const uint32_t num_rows,
|
|
554
|
+
const uint32_t row_elems,
|
|
555
|
+
const size_t row_size,
|
|
556
|
+
int32_t * op_params) {
|
|
557
|
+
// softplus(x) = log(1 + exp(x))
|
|
558
|
+
// Match CPU reference: ggml_compute_softplus_f32() in ggml-impl.h
|
|
559
|
+
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
|
560
|
+
const float * restrict src_f = (const float *)((const uint8_t *)src + (ir * row_size));
|
|
561
|
+
float * restrict dst_f = (float *)((uint8_t *)dst + (ir * row_size));
|
|
137
562
|
|
|
138
|
-
|
|
563
|
+
for (uint32_t i = 0; i < row_elems; i++) {
|
|
564
|
+
float x = src_f[i];
|
|
565
|
+
// For x > 20: softplus(x) ≈ x (avoids exp overflow)
|
|
566
|
+
dst_f[i] = (x > 20.0f) ? x : logf(1.0f + expf(x));
|
|
139
567
|
}
|
|
140
568
|
}
|
|
141
569
|
}
|
|
142
570
|
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
571
|
+
// --- L2_NORM HVX kernel ---
|
|
572
|
+
// Computes y[i] = x[i] / fmax(sqrt(sum(x[j]^2)), epsilon) for each row.
|
|
573
|
+
// scale = 1/fmax(sqrt(sum), epsilon) is computed entirely in HVX registers
|
|
574
|
+
// using rsqrt + inverse to avoid scalar extraction.
|
|
575
|
+
static void hvx_fast_l2_norm_f32(const uint8_t * restrict src,
|
|
576
|
+
uint8_t * restrict dst,
|
|
577
|
+
uint8_t * restrict pad,
|
|
578
|
+
const int num_elems,
|
|
579
|
+
float epsilon) {
|
|
580
|
+
(void)pad;
|
|
581
|
+
|
|
582
|
+
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
|
583
|
+
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
|
584
|
+
|
|
585
|
+
HVX_Vector sum_v = hvx_vec_splat_f32(0.0f);
|
|
586
|
+
|
|
587
|
+
const int nvec = num_elems / VLEN_FP32;
|
|
588
|
+
const int nloe = num_elems % VLEN_FP32;
|
|
589
|
+
|
|
590
|
+
#pragma unroll(4)
|
|
591
|
+
for (int i = 0; i < nvec; i++) {
|
|
592
|
+
HVX_Vector v1 = v_src[i];
|
|
593
|
+
HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
|
594
|
+
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq);
|
|
595
|
+
}
|
|
596
|
+
|
|
597
|
+
// Include tail elements in the sum-of-squares using a predicate mask
|
|
598
|
+
if (nloe > 0) {
|
|
599
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
600
|
+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
|
601
|
+
HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
|
602
|
+
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq);
|
|
603
|
+
}
|
|
604
|
+
|
|
605
|
+
// Compute scale = 1/fmax(sqrt(sum), epsilon) entirely in HVX registers.
|
|
606
|
+
// hvx_vec_rsqrt_f32 + hvx_vec_inverse_f32 avoids scalar extraction.
|
|
607
|
+
HVX_Vector sum_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
|
|
608
|
+
HVX_Vector rsqrt_v = hvx_vec_rsqrt_f32(sum_sf); // 1/sqrt(sum)
|
|
609
|
+
HVX_Vector sqrt_v = hvx_vec_inverse_f32(rsqrt_v); // sqrt(sum)
|
|
610
|
+
HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
|
|
611
|
+
HVX_Vector denom_v = Q6_Vsf_vmax_VsfVsf(sqrt_v, epsilon_v); // fmax(sqrt(sum), epsilon)
|
|
612
|
+
HVX_Vector scale_v = hvx_vec_inverse_f32(denom_v); // 1/fmax(sqrt(sum), epsilon)
|
|
613
|
+
|
|
614
|
+
#pragma unroll(4)
|
|
615
|
+
for (int i = 0; i < nvec; i++) {
|
|
616
|
+
HVX_Vector v1 = v_src[i];
|
|
617
|
+
v_dst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v));
|
|
618
|
+
}
|
|
619
|
+
|
|
620
|
+
if (nloe > 0) {
|
|
621
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
622
|
+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
|
623
|
+
HVX_Vector result = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v));
|
|
624
|
+
hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
|
|
625
|
+
}
|
|
626
|
+
}
|
|
627
|
+
|
|
628
|
+
// L2-normalize num_rows rows of row_elems f32 values each.
//
// Each row is handed to hvx_fast_l2_norm_f32 together with an epsilon that
// bounds the denominator from below (the helper scales by 1/fmax(sqrt(sum), eps)).
//
//   src, dst   : base of the input/output rows; both use the same byte stride
//   spad       : forwarded unchanged to the per-row helper
//   num_rows   : number of rows to process
//   row_elems  : number of f32 elements per row
//   row_size   : byte distance between consecutive rows (shared by src and dst)
//   op_params  : op_params[0] carries the epsilon as the raw bits of a float
static void l2_norm_f32(const float * restrict src,
                        float * restrict       dst,
                        uint8_t * restrict     spad,
                        const uint32_t         num_rows,
                        const uint32_t         row_elems,
                        const size_t           row_size,
                        int32_t *              op_params) {
    // op_params is an int32 array; copy the bytes out instead of casting the
    // pointer so the float reinterpretation does not violate strict aliasing.
    float eps = 0.f;
    memcpy(&eps, op_params, sizeof(eps));

    const uint8_t * in_base  = (const uint8_t *) src;
    uint8_t *       out_base = (uint8_t *) dst;

    for (uint32_t row = 0; row < num_rows; ++row) {
        const size_t byte_off = (size_t) row * row_size;
        hvx_fast_l2_norm_f32(in_base + byte_off, out_base + byte_off, spad, row_elems, eps);
    }
}
|
|
645
|
+
|
|
646
|
+
// Apply tanh element-wise to num_rows rows of row_elems f32 values each.
//
// Rows are row_size bytes apart in both src and dst. Each row is processed by
// hvx_tanh_f32_aa (output first, then input, then element count).
// spad and op_params are not read here; they exist so this kernel has the same
// parameter list as the other per-op kernels dispatched from the unary job.
static void tanh_f32(const float * restrict src,
                     float * restrict       dst,
                     uint8_t * restrict     spad,
                     const uint32_t         num_rows,
                     const uint32_t         row_elems,
                     const size_t           row_size,
                     int32_t *              op_params) {
    (void) spad;
    (void) op_params;

    const uint8_t * restrict in_base  = (const uint8_t *) src;
    uint8_t * restrict       out_base = (uint8_t *) dst;

    uint32_t row = 0;
    while (row < num_rows) {
        const size_t byte_off = (size_t) row * row_size;
        hvx_tanh_f32_aa(out_base + byte_off, in_base + byte_off, row_elems);
        row++;
    }
}
|
|
660
|
+
|
|
661
|
+
static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
|
662
|
+
const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
|
|
663
|
+
struct htp_ops_context * octx = uctx->octx;
|
|
664
|
+
const struct htp_tensor * src = octx->src[0];
|
|
665
|
+
const struct htp_tensor * dst = octx->dst;
|
|
666
|
+
|
|
151
667
|
htp_unary_preamble;
|
|
152
668
|
|
|
153
|
-
|
|
154
|
-
|
|
669
|
+
int htp_op = octx->op;
|
|
670
|
+
int32_t * op_params = octx->op_params;
|
|
671
|
+
uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread;
|
|
155
672
|
|
|
156
|
-
const
|
|
673
|
+
const size_t src0_data_row_size = uctx->src0_data_row_size;
|
|
674
|
+
const size_t dst_data_row_size = uctx->dst_data_row_size;
|
|
157
675
|
|
|
676
|
+
const size_t src0_row_size_aligned = uctx->src0_row_size_aligned;
|
|
677
|
+
const size_t dst_row_size_aligned = uctx->dst_row_size_aligned;
|
|
678
|
+
|
|
679
|
+
const uint32_t src0_nrows = uctx->src0_nrows;
|
|
158
680
|
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
|
159
681
|
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
|
160
682
|
|
|
@@ -166,66 +688,212 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
|
|
|
166
688
|
uint64_t t1, t2;
|
|
167
689
|
t1 = HAP_perf_get_qtimer_count();
|
|
168
690
|
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
691
|
+
const uint8_t * restrict data_src = uctx->data_src0;
|
|
692
|
+
const uint8_t * restrict data_src1 = uctx->data_src1;
|
|
693
|
+
uint8_t * restrict data_dst = uctx->data_dst;
|
|
694
|
+
|
|
695
|
+
const struct htp_tensor * src1 = (htp_op == HTP_OP_RMS_NORM_MUL) ? octx->src[1] : NULL;
|
|
696
|
+
const uint32_t nb11 = src1 ? src1->nb[1] : 0;
|
|
697
|
+
const uint32_t nb12 = src1 ? src1->nb[2] : 0;
|
|
698
|
+
const uint32_t nb13 = src1 ? src1->nb[3] : 0;
|
|
699
|
+
|
|
700
|
+
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
|
|
701
|
+
uint8_t * src1_spad_data = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
|
|
702
|
+
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
|
|
703
|
+
|
|
704
|
+
size_t src0_spad_half_size = uctx->src0_spad_half_size;
|
|
705
|
+
size_t src1_spad_half_size = uctx->src1_spad_half_size;
|
|
706
|
+
size_t dst_spad_half_size = uctx->dst_spad_half_size;
|
|
707
|
+
|
|
708
|
+
// Non-contiguous tensors have gaps at dim-2/3 boundaries that a single-stride
|
|
709
|
+
// 2D DMA descriptor cannot span. Clamp BLOCK to ne1 (one dim-1 slice) so every
|
|
710
|
+
// transfer stays within a nb1-uniform region. Skipped for contiguous tensors.
|
|
711
|
+
const bool src0_contig = (nb02 == (size_t)ne01 * nb01) &&
|
|
712
|
+
(nb03 == (size_t)ne02 * nb02);
|
|
713
|
+
const bool dst_contig = (nb2 == (size_t)ne1 * nb1) &&
|
|
714
|
+
(nb3 == (size_t)ne2 * nb2);
|
|
715
|
+
const uint32_t src0_max_block = src0_contig ? uctx->block : MIN((uint32_t)uctx->block, ne01);
|
|
716
|
+
const uint32_t dst_max_block = dst_contig ? uctx->block : MIN((uint32_t)uctx->block, ne1);
|
|
717
|
+
const uint32_t BLOCK = MIN(src0_max_block, dst_max_block);
|
|
718
|
+
if (BLOCK == 0) {
|
|
719
|
+
FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
|
720
|
+
octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
|
721
|
+
return;
|
|
174
722
|
}
|
|
175
|
-
|
|
176
|
-
|
|
723
|
+
|
|
724
|
+
dma_queue * dma_queue = octx->ctx->dma[ith];
|
|
725
|
+
|
|
726
|
+
// If weight is broadcasted, load it once per thread at the beginning of execution
|
|
727
|
+
if (htp_op == HTP_OP_RMS_NORM_MUL && uctx->broadcast_weight) {
|
|
728
|
+
dma_queue_push(dma_queue, dma_make_ptr(src1_spad_data, data_src1), uctx->src1_row_size_aligned, 0, uctx->src1_data_row_size, 1);
|
|
729
|
+
dma_queue_flush(dma_queue);
|
|
177
730
|
}
|
|
178
731
|
|
|
179
|
-
|
|
180
|
-
|
|
732
|
+
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; spad_idx++) {
|
|
733
|
+
const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
|
181
734
|
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
735
|
+
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
|
|
736
|
+
dma_queue_push(dma_queue,
|
|
737
|
+
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
|
|
738
|
+
nb1, dst_row_size_aligned, dst_data_row_size, 0);
|
|
185
739
|
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
case HTP_OP_SCALE:
|
|
191
|
-
scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
|
192
|
-
break;
|
|
740
|
+
const size_t src0_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03);
|
|
741
|
+
dma_queue_push(dma_queue,
|
|
742
|
+
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + src0_off),
|
|
743
|
+
src0_row_size_aligned, nb01, src0_data_row_size, block_size);
|
|
193
744
|
|
|
194
|
-
|
|
195
|
-
|
|
745
|
+
if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) {
|
|
746
|
+
const size_t src1_off = unary_row_offset(ir, ne01, ne02, nb11, nb12, nb13);
|
|
747
|
+
dma_queue_push(dma_queue,
|
|
748
|
+
dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + src1_off),
|
|
749
|
+
uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, block_size);
|
|
750
|
+
}
|
|
751
|
+
|
|
752
|
+
ir += block_size;
|
|
196
753
|
}
|
|
197
754
|
|
|
755
|
+
for (uint32_t ir = src0_start_row; ir < src0_end_row; ) {
|
|
756
|
+
const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
|
757
|
+
|
|
758
|
+
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
|
|
759
|
+
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
|
|
760
|
+
float * src1_spad = NULL;
|
|
761
|
+
if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) {
|
|
762
|
+
src1_spad = (float *) dma_queue_pop(dma_queue).dst;
|
|
763
|
+
}
|
|
764
|
+
|
|
765
|
+
// Process block in VTCM
|
|
766
|
+
switch (htp_op) {
|
|
767
|
+
case HTP_OP_NORM:
|
|
768
|
+
norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
769
|
+
break;
|
|
770
|
+
case HTP_OP_RMS_NORM:
|
|
771
|
+
rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
772
|
+
break;
|
|
773
|
+
case HTP_OP_RMS_NORM_MUL:
|
|
774
|
+
{
|
|
775
|
+
const float * w_ptr = uctx->broadcast_weight ? (const float *) src1_spad_data : src1_spad;
|
|
776
|
+
rms_norm_mul_f32(src0_spad, w_ptr, dst_spad, block_size, ne0, src0_row_size_aligned, uctx->src1_row_size_aligned, op_params, uctx->broadcast_weight);
|
|
777
|
+
}
|
|
778
|
+
break;
|
|
779
|
+
case HTP_OP_SCALE:
|
|
780
|
+
scale_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
781
|
+
break;
|
|
782
|
+
case HTP_OP_SQR:
|
|
783
|
+
sqr_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
784
|
+
break;
|
|
785
|
+
case HTP_OP_SQRT:
|
|
786
|
+
sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
787
|
+
break;
|
|
788
|
+
case HTP_OP_UNARY_NEG:
|
|
789
|
+
neg_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
790
|
+
break;
|
|
791
|
+
case HTP_OP_UNARY_EXP:
|
|
792
|
+
exp_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
793
|
+
break;
|
|
794
|
+
case HTP_OP_UNARY_SIGMOID:
|
|
795
|
+
sigmoid_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
796
|
+
break;
|
|
797
|
+
case HTP_OP_UNARY_SOFTPLUS:
|
|
798
|
+
softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
799
|
+
break;
|
|
800
|
+
case HTP_OP_UNARY_TANH:
|
|
801
|
+
tanh_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
802
|
+
break;
|
|
803
|
+
case HTP_OP_L2_NORM:
|
|
804
|
+
l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
|
805
|
+
break;
|
|
806
|
+
case HTP_OP_TRI:
|
|
807
|
+
tri_f32(src0_spad, dst_spad, NULL, block_size, ne00, src0_row_size_aligned, op_params, ir, uctx);
|
|
808
|
+
break;
|
|
809
|
+
default:
|
|
810
|
+
break;
|
|
811
|
+
}
|
|
812
|
+
|
|
813
|
+
const size_t dst_off = unary_row_offset(ir, ne1, ne2, nb1, nb2, nb3);
|
|
814
|
+
dma_queue_push(dma_queue,
|
|
815
|
+
dma_make_ptr(data_dst + dst_off, dst_spad),
|
|
816
|
+
nb1, dst_row_size_aligned, dst_data_row_size, block_size);
|
|
817
|
+
|
|
818
|
+
// prefetch N+2 loop iteration if any
|
|
819
|
+
const uint32_t next_ir = ir + block_size;
|
|
820
|
+
if (next_ir < src0_end_row) {
|
|
821
|
+
const uint32_t next_block_size = unary_block_size(next_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
|
822
|
+
const uint32_t pref_ir = next_ir + next_block_size;
|
|
823
|
+
if (pref_ir < src0_end_row) {
|
|
824
|
+
const uint32_t pref_block_size = unary_block_size(pref_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
|
825
|
+
const size_t src0_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03);
|
|
826
|
+
dma_queue_push(dma_queue,
|
|
827
|
+
dma_make_ptr(src0_spad, data_src + src0_pref_off),
|
|
828
|
+
src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size);
|
|
829
|
+
|
|
830
|
+
if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) {
|
|
831
|
+
const size_t src1_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb11, nb12, nb13);
|
|
832
|
+
dma_queue_push(dma_queue,
|
|
833
|
+
dma_make_ptr(src1_spad, data_src1 + src1_pref_off),
|
|
834
|
+
uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, pref_block_size);
|
|
835
|
+
}
|
|
836
|
+
}
|
|
837
|
+
}
|
|
838
|
+
ir += block_size;
|
|
839
|
+
}
|
|
840
|
+
|
|
841
|
+
dma_queue_flush(dma_queue);
|
|
842
|
+
|
|
198
843
|
t2 = HAP_perf_get_qtimer_count();
|
|
199
844
|
|
|
200
|
-
FARF(HIGH, "unary-f32 %d/%d
|
|
845
|
+
FARF(HIGH, "unary-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, src->ne[0],
|
|
201
846
|
src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2],
|
|
202
847
|
dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
|
203
848
|
}
|
|
204
849
|
|
|
205
|
-
static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
|
|
206
|
-
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
|
207
|
-
|
|
208
|
-
unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i,
|
|
209
|
-
octx->src0_nrows_per_thread);
|
|
210
|
-
}
|
|
211
|
-
|
|
212
850
|
static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
|
213
851
|
int err = HTP_STATUS_OK;
|
|
214
852
|
|
|
215
|
-
const struct htp_tensor * src0 =
|
|
216
|
-
struct htp_tensor *
|
|
853
|
+
const struct htp_tensor * src0 = octx->src[0];
|
|
854
|
+
const struct htp_tensor * dst = octx->dst;
|
|
217
855
|
|
|
218
|
-
|
|
219
|
-
const char * op_type = NULL;
|
|
856
|
+
const char * op_type = NULL;
|
|
220
857
|
|
|
221
858
|
switch (octx->op) {
|
|
859
|
+
case HTP_OP_NORM:
|
|
860
|
+
op_type = "norm-f32";
|
|
861
|
+
break;
|
|
222
862
|
case HTP_OP_RMS_NORM:
|
|
223
|
-
|
|
224
|
-
|
|
863
|
+
op_type = "rmsnorm-f32";
|
|
864
|
+
break;
|
|
865
|
+
case HTP_OP_RMS_NORM_MUL:
|
|
866
|
+
op_type = "rmsnorm-mul-f32";
|
|
225
867
|
break;
|
|
226
868
|
case HTP_OP_SCALE:
|
|
227
|
-
|
|
228
|
-
|
|
869
|
+
op_type = "scale-f32";
|
|
870
|
+
break;
|
|
871
|
+
case HTP_OP_SQR:
|
|
872
|
+
op_type = "sqr-f32";
|
|
873
|
+
break;
|
|
874
|
+
case HTP_OP_SQRT:
|
|
875
|
+
op_type = "sqrt-f32";
|
|
876
|
+
break;
|
|
877
|
+
case HTP_OP_UNARY_NEG:
|
|
878
|
+
op_type = "neg-f32";
|
|
879
|
+
break;
|
|
880
|
+
case HTP_OP_UNARY_EXP:
|
|
881
|
+
op_type = "exp-f32";
|
|
882
|
+
break;
|
|
883
|
+
case HTP_OP_UNARY_SIGMOID:
|
|
884
|
+
op_type = "sigmoid-f32";
|
|
885
|
+
break;
|
|
886
|
+
case HTP_OP_UNARY_SOFTPLUS:
|
|
887
|
+
op_type = "softplus-f32";
|
|
888
|
+
break;
|
|
889
|
+
case HTP_OP_UNARY_TANH:
|
|
890
|
+
op_type = "tanh-f32";
|
|
891
|
+
break;
|
|
892
|
+
case HTP_OP_L2_NORM:
|
|
893
|
+
op_type = "l2norm-f32";
|
|
894
|
+
break;
|
|
895
|
+
case HTP_OP_TRI:
|
|
896
|
+
op_type = "tri-f32";
|
|
229
897
|
break;
|
|
230
898
|
|
|
231
899
|
default:
|
|
@@ -233,38 +901,139 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
|
|
233
901
|
return HTP_STATUS_NO_SUPPORT;
|
|
234
902
|
}
|
|
235
903
|
|
|
236
|
-
const int n_threads = octx->n_threads;
|
|
237
904
|
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
|
905
|
+
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
|
|
238
906
|
|
|
239
|
-
const size_t
|
|
240
|
-
const size_t
|
|
907
|
+
const size_t src0_data_row_size = src0->ne[0] * sizeof(float);
|
|
908
|
+
const size_t dst_data_row_size = dst->ne[0] * sizeof(float);
|
|
909
|
+
|
|
910
|
+
const size_t src0_row_size_aligned = hex_round_up(src0_data_row_size, VLEN);
|
|
911
|
+
const size_t dst_row_size_aligned = hex_round_up(dst_data_row_size, VLEN);
|
|
912
|
+
|
|
913
|
+
size_t src1_data_row_size = 0;
|
|
914
|
+
size_t src1_row_size_aligned = 0;
|
|
915
|
+
bool broadcast_weight = false;
|
|
916
|
+
const struct htp_tensor * src1 = NULL;
|
|
917
|
+
|
|
918
|
+
if (octx->op == HTP_OP_RMS_NORM_MUL) {
|
|
919
|
+
src1 = octx->src[1];
|
|
920
|
+
src1_data_row_size = src1->ne[0] * sizeof(float);
|
|
921
|
+
src1_row_size_aligned = hex_round_up(src1_data_row_size, VLEN);
|
|
922
|
+
broadcast_weight = (src1->ne[1] * src1->ne[2] * src1->ne[3] == 1);
|
|
923
|
+
}
|
|
241
924
|
|
|
242
925
|
// VTCM scratchpads for all tensors
|
|
243
|
-
|
|
244
|
-
|
|
926
|
+
// N rows per thread, padded to HVX vector size
|
|
927
|
+
// Double buffering requires 2x size per buffer
|
|
245
928
|
|
|
246
|
-
size_t
|
|
929
|
+
size_t spad_size_per_row = 0;
|
|
930
|
+
size_t vtcm_row_per_thread = 0;
|
|
247
931
|
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
932
|
+
if (octx->op == HTP_OP_RMS_NORM_MUL) {
|
|
933
|
+
if (broadcast_weight) {
|
|
934
|
+
size_t available_vtcm = octx->ctx->vtcm_size;
|
|
935
|
+
size_t src1_spad_total = n_threads * src1_row_size_aligned;
|
|
936
|
+
if (available_vtcm > src1_spad_total) {
|
|
937
|
+
available_vtcm -= src1_spad_total;
|
|
938
|
+
} else {
|
|
939
|
+
available_vtcm = 0;
|
|
940
|
+
}
|
|
941
|
+
spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned);
|
|
942
|
+
vtcm_row_per_thread = available_vtcm / (n_threads * spad_size_per_row);
|
|
943
|
+
} else {
|
|
944
|
+
spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned + src1_row_size_aligned);
|
|
945
|
+
vtcm_row_per_thread = (octx->ctx->vtcm_size) / (n_threads * spad_size_per_row);
|
|
946
|
+
}
|
|
947
|
+
} else {
|
|
948
|
+
spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned);
|
|
949
|
+
vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row);
|
|
950
|
+
}
|
|
251
951
|
|
|
252
952
|
// Make sure the reserved vtcm size is sufficient
|
|
253
|
-
if (
|
|
953
|
+
if (vtcm_row_per_thread == 0) {
|
|
254
954
|
FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
|
255
|
-
|
|
955
|
+
spad_size_per_row * n_threads);
|
|
256
956
|
return HTP_STATUS_VTCM_TOO_SMALL;
|
|
257
957
|
}
|
|
258
958
|
|
|
959
|
+
octx->src0_spad.size_per_thread = src0_row_size_aligned * vtcm_row_per_thread * 2;
|
|
960
|
+
octx->dst_spad.size_per_thread = dst_row_size_aligned * vtcm_row_per_thread * 2;
|
|
961
|
+
|
|
962
|
+
octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
|
|
963
|
+
octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
|
|
964
|
+
|
|
965
|
+
if (octx->op == HTP_OP_RMS_NORM_MUL) {
|
|
966
|
+
if (broadcast_weight) {
|
|
967
|
+
octx->src1_spad.size_per_thread = src1_row_size_aligned;
|
|
968
|
+
} else {
|
|
969
|
+
octx->src1_spad.size_per_thread = src1_row_size_aligned * vtcm_row_per_thread * 2;
|
|
970
|
+
}
|
|
971
|
+
octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread;
|
|
972
|
+
} else {
|
|
973
|
+
octx->src1_spad.size = 0;
|
|
974
|
+
octx->src1_spad.size_per_thread = 0;
|
|
975
|
+
}
|
|
976
|
+
|
|
259
977
|
octx->src0_spad.data = octx->ctx->vtcm_base;
|
|
260
|
-
|
|
978
|
+
if (octx->op == HTP_OP_RMS_NORM_MUL) {
|
|
979
|
+
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
|
980
|
+
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
|
981
|
+
} else {
|
|
982
|
+
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
|
983
|
+
}
|
|
984
|
+
|
|
985
|
+
octx->src0_spad.src = NULL;
|
|
986
|
+
octx->src1_spad.src = NULL;
|
|
987
|
+
octx->dst_spad.src = NULL;
|
|
988
|
+
|
|
989
|
+
FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
|
|
990
|
+
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
|
991
|
+
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
|
261
992
|
|
|
262
993
|
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
|
263
|
-
|
|
994
|
+
struct htp_unary_context uctx = {
|
|
995
|
+
.octx = octx,
|
|
996
|
+
.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads,
|
|
997
|
+
.src0_nrows = src0_nrows,
|
|
998
|
+
|
|
999
|
+
.data_src0 = (const uint8_t *)src0->data,
|
|
1000
|
+
.data_src1 = (octx->op == HTP_OP_RMS_NORM_MUL) ? (const uint8_t *)src1->data : NULL,
|
|
1001
|
+
.data_dst = (uint8_t *)dst->data,
|
|
1002
|
+
|
|
1003
|
+
.src0_data_row_size = src0_data_row_size,
|
|
1004
|
+
.src1_data_row_size = src1_data_row_size,
|
|
1005
|
+
.dst_data_row_size = dst_data_row_size,
|
|
264
1006
|
|
|
265
|
-
|
|
1007
|
+
.src0_row_size_aligned = src0_row_size_aligned,
|
|
1008
|
+
.src1_row_size_aligned = src1_row_size_aligned,
|
|
1009
|
+
.dst_row_size_aligned = dst_row_size_aligned,
|
|
266
1010
|
|
|
267
|
-
|
|
1011
|
+
.src0_spad_half_size = octx->src0_spad.size_per_thread / 2,
|
|
1012
|
+
.src1_spad_half_size = (octx->op == HTP_OP_RMS_NORM_MUL) ? (octx->src1_spad.size_per_thread / (broadcast_weight ? 1 : 2)) : 0,
|
|
1013
|
+
.dst_spad_half_size = octx->dst_spad.size_per_thread / 2,
|
|
1014
|
+
|
|
1015
|
+
.block = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned,
|
|
1016
|
+
.nc = src0->ne[0],
|
|
1017
|
+
.broadcast_weight = broadcast_weight,
|
|
1018
|
+
};
|
|
1019
|
+
|
|
1020
|
+
worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_threads);
|
|
1021
|
+
}
|
|
1022
|
+
|
|
1023
|
+
return err;
|
|
1024
|
+
}
|
|
1025
|
+
|
|
1026
|
+
int op_tri(struct htp_ops_context * octx) {
|
|
1027
|
+
int err = HTP_STATUS_OK;
|
|
1028
|
+
|
|
1029
|
+
switch (octx->src[0]->type) {
|
|
1030
|
+
case HTP_TYPE_F32:
|
|
1031
|
+
err = execute_op_unary_f32(octx);
|
|
1032
|
+
break;
|
|
1033
|
+
|
|
1034
|
+
default:
|
|
1035
|
+
err = HTP_STATUS_NO_SUPPORT;
|
|
1036
|
+
break;
|
|
268
1037
|
}
|
|
269
1038
|
|
|
270
1039
|
return err;
|
|
@@ -273,7 +1042,7 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
|
|
273
1042
|
int op_unary(struct htp_ops_context * octx) {
|
|
274
1043
|
int err = HTP_STATUS_OK;
|
|
275
1044
|
|
|
276
|
-
switch (octx->
|
|
1045
|
+
switch (octx->src[0]->type) {
|
|
277
1046
|
case HTP_TYPE_F32:
|
|
278
1047
|
err = execute_op_unary_f32(octx);
|
|
279
1048
|
break;
|