whispercpp 1.3.5 → 1.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/LICENSE +1 -1
- data/README.md +133 -3
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -7
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +56 -46
- data/ext/ruby_whisper.h +165 -2
- data/ext/ruby_whisper_context.c +297 -126
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -66
- data/ext/ruby_whisper_segment.c +6 -7
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +46 -16
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +24 -19
- data/ext/sources/examples/cli/cli.cpp +51 -9
- data/ext/sources/examples/common-ggml.cpp +4 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +213 -163
- data/ext/sources/ggml/CMakeLists.txt +29 -15
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +73 -11
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +8 -3
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +155 -16
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +25 -5
- data/ext/sources/ggml/src/ggml-alloc.c +9 -10
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
- data/ext/sources/ggml/src/ggml-common.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
- data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
- data/ext/sources/ggml/src/ggml-impl.h +68 -1
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +385 -119
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
- data/ext/sources/ggml/src/ggml.c +268 -52
- data/ext/sources/ggml/src/gguf.cpp +377 -47
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +62 -40
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +445 -55
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_context_params.rb +82 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +44 -6
- data/whispercpp.gemspec +2 -2
- metadata +426 -280
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
- data/ext/sources/examples/talk-llama/llama-context.h +0 -360
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
- data/ext/sources/examples/talk-llama/llama-model.h +0 -544
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
- data/ext/sources/examples/talk-llama/llama.h +0 -1540
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -569
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
|
@@ -9,12 +9,14 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
|
|
|
9
9
|
|
|
10
10
|
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
|
|
11
11
|
switch (type) {
|
|
12
|
+
case GGML_TYPE_Q1_0: return vec_dot_q1_0_q8_1;
|
|
12
13
|
case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1;
|
|
13
14
|
case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1;
|
|
14
15
|
case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
|
|
15
16
|
case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
|
|
16
17
|
case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
|
|
17
18
|
case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1;
|
|
19
|
+
case GGML_TYPE_NVFP4: return vec_dot_nvfp4_q8_1;
|
|
18
20
|
case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
|
|
19
21
|
case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
|
|
20
22
|
case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
|
|
@@ -33,14 +35,16 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
|
|
|
33
35
|
}
|
|
34
36
|
}
|
|
35
37
|
|
|
36
|
-
static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
|
|
38
|
+
static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
|
|
37
39
|
switch (type) {
|
|
40
|
+
case GGML_TYPE_Q1_0: return VDR_Q1_0_Q8_1_MMVQ;
|
|
38
41
|
case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
|
|
39
42
|
case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
|
|
40
43
|
case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
|
|
41
44
|
case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
|
|
42
45
|
case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
|
|
43
46
|
case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ;
|
|
47
|
+
case GGML_TYPE_NVFP4: return VDR_NVFP4_Q8_1_MMVQ;
|
|
44
48
|
case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
|
|
45
49
|
case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
|
|
46
50
|
case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
|
|
@@ -59,31 +63,290 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
|
|
|
59
63
|
|
|
60
64
|
enum mmvq_parameter_table_id {
|
|
61
65
|
MMVQ_PARAMETERS_GENERIC = 0,
|
|
66
|
+
MMVQ_PARAMETERS_TURING,
|
|
62
67
|
MMVQ_PARAMETERS_GCN,
|
|
63
|
-
MMVQ_PARAMETERS_RDNA2
|
|
68
|
+
MMVQ_PARAMETERS_RDNA2,
|
|
69
|
+
MMVQ_PARAMETERS_RDNA3_0,
|
|
70
|
+
MMVQ_PARAMETERS_RDNA4
|
|
64
71
|
};
|
|
65
72
|
|
|
66
73
|
static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
|
|
67
|
-
#if defined(
|
|
74
|
+
#if defined(RDNA4)
|
|
75
|
+
return MMVQ_PARAMETERS_RDNA4;
|
|
76
|
+
#elif defined(RDNA3_0)
|
|
77
|
+
return MMVQ_PARAMETERS_RDNA3_0;
|
|
78
|
+
#elif defined(RDNA2) || defined(RDNA3_5)
|
|
68
79
|
return MMVQ_PARAMETERS_RDNA2;
|
|
69
80
|
#elif defined(GCN) || defined(CDNA)
|
|
70
81
|
return MMVQ_PARAMETERS_GCN;
|
|
82
|
+
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING && __CUDA_ARCH__ < GGML_CUDA_CC_AMPERE
|
|
83
|
+
return MMVQ_PARAMETERS_TURING;
|
|
71
84
|
#else
|
|
72
85
|
return MMVQ_PARAMETERS_GENERIC;
|
|
73
86
|
#endif
|
|
74
87
|
}
|
|
75
88
|
|
|
76
89
|
static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
|
|
77
|
-
if (
|
|
90
|
+
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
|
|
91
|
+
return MMVQ_PARAMETERS_RDNA4;
|
|
92
|
+
}
|
|
93
|
+
if (GGML_CUDA_CC_IS_RDNA3_0(cc)) {
|
|
94
|
+
return MMVQ_PARAMETERS_RDNA3_0;
|
|
95
|
+
}
|
|
96
|
+
if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3_5(cc)) {
|
|
78
97
|
return MMVQ_PARAMETERS_RDNA2;
|
|
79
98
|
}
|
|
80
99
|
if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
|
|
81
100
|
return MMVQ_PARAMETERS_GCN;
|
|
82
101
|
}
|
|
102
|
+
if (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING && ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_AMPERE) {
|
|
103
|
+
return MMVQ_PARAMETERS_TURING;
|
|
104
|
+
}
|
|
83
105
|
return MMVQ_PARAMETERS_GENERIC;
|
|
84
106
|
}
|
|
85
107
|
|
|
86
|
-
|
|
108
|
+
// Per-architecture maximum batch size for which MMVQ should be used for MUL_MAT_ID.
|
|
109
|
+
// Returns a value <= MMVQ_MAX_BATCH_SIZE. Default is MMVQ_MAX_BATCH_SIZE.
|
|
110
|
+
// Check https://github.com/ggml-org/llama.cpp/pull/20905#issuecomment-4145835627 for details
|
|
111
|
+
|
|
112
|
+
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_pascal_older(ggml_type type) {
|
|
113
|
+
switch (type) {
|
|
114
|
+
case GGML_TYPE_IQ1_S: return 6;
|
|
115
|
+
case GGML_TYPE_IQ1_M: return 6;
|
|
116
|
+
case GGML_TYPE_IQ2_S: return 4;
|
|
117
|
+
case GGML_TYPE_IQ2_XS: return 5;
|
|
118
|
+
case GGML_TYPE_IQ2_XXS: return 5;
|
|
119
|
+
case GGML_TYPE_IQ3_S: return 4;
|
|
120
|
+
case GGML_TYPE_IQ3_XXS: return 4;
|
|
121
|
+
case GGML_TYPE_IQ4_NL: return 6;
|
|
122
|
+
case GGML_TYPE_IQ4_XS: return 5;
|
|
123
|
+
case GGML_TYPE_MXFP4: return 4;
|
|
124
|
+
case GGML_TYPE_NVFP4: return 4;
|
|
125
|
+
case GGML_TYPE_Q2_K: return 4;
|
|
126
|
+
case GGML_TYPE_Q3_K: return 4;
|
|
127
|
+
case GGML_TYPE_Q4_0: return 6;
|
|
128
|
+
case GGML_TYPE_Q4_1: return 6;
|
|
129
|
+
case GGML_TYPE_Q4_K: return 5;
|
|
130
|
+
case GGML_TYPE_Q5_0: return 6;
|
|
131
|
+
case GGML_TYPE_Q5_1: return 6;
|
|
132
|
+
case GGML_TYPE_Q5_K: return 5;
|
|
133
|
+
case GGML_TYPE_Q6_K: return 4;
|
|
134
|
+
case GGML_TYPE_Q8_0: return 4;
|
|
135
|
+
default: return MMVQ_MAX_BATCH_SIZE;
|
|
136
|
+
}
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_turing_plus(ggml_type type) {
|
|
140
|
+
switch (type) {
|
|
141
|
+
case GGML_TYPE_IQ2_S: return 7;
|
|
142
|
+
case GGML_TYPE_IQ3_S: return 6;
|
|
143
|
+
case GGML_TYPE_IQ3_XXS: return 7;
|
|
144
|
+
case GGML_TYPE_MXFP4: return 7;
|
|
145
|
+
case GGML_TYPE_NVFP4: return 8;
|
|
146
|
+
case GGML_TYPE_Q2_K: return 7;
|
|
147
|
+
case GGML_TYPE_Q3_K: return 5;
|
|
148
|
+
default: return MMVQ_MAX_BATCH_SIZE;
|
|
149
|
+
}
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_gcn(ggml_type type) {
|
|
153
|
+
switch (type) {
|
|
154
|
+
case GGML_TYPE_IQ1_S: return 5;
|
|
155
|
+
case GGML_TYPE_IQ1_M: return 5;
|
|
156
|
+
case GGML_TYPE_IQ2_S: return 4;
|
|
157
|
+
case GGML_TYPE_IQ2_XS: return 4;
|
|
158
|
+
case GGML_TYPE_IQ2_XXS: return 4;
|
|
159
|
+
case GGML_TYPE_IQ3_S: return 4;
|
|
160
|
+
case GGML_TYPE_IQ3_XXS: return 4;
|
|
161
|
+
case GGML_TYPE_IQ4_NL: return 6;
|
|
162
|
+
case GGML_TYPE_IQ4_XS: return 4;
|
|
163
|
+
case GGML_TYPE_Q2_K: return 4;
|
|
164
|
+
case GGML_TYPE_Q3_K: return 4;
|
|
165
|
+
case GGML_TYPE_Q4_0: return 5;
|
|
166
|
+
case GGML_TYPE_Q4_1: return 5;
|
|
167
|
+
case GGML_TYPE_Q4_K: return 4;
|
|
168
|
+
case GGML_TYPE_Q5_K: return 4;
|
|
169
|
+
case GGML_TYPE_Q6_K: return 4;
|
|
170
|
+
case GGML_TYPE_Q8_0: return 4;
|
|
171
|
+
default: return MMVQ_MAX_BATCH_SIZE;
|
|
172
|
+
}
|
|
173
|
+
}
|
|
174
|
+
|
|
175
|
+
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_cdna(ggml_type type) {
|
|
176
|
+
switch (type) {
|
|
177
|
+
case GGML_TYPE_IQ2_S: return 5;
|
|
178
|
+
case GGML_TYPE_IQ2_XS: return 5;
|
|
179
|
+
case GGML_TYPE_IQ2_XXS: return 5;
|
|
180
|
+
case GGML_TYPE_IQ3_S: return 4;
|
|
181
|
+
case GGML_TYPE_IQ3_XXS: return 5;
|
|
182
|
+
default: return MMVQ_MAX_BATCH_SIZE;
|
|
183
|
+
}
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna1_rdna2(ggml_type type) {
|
|
187
|
+
switch (type) {
|
|
188
|
+
case GGML_TYPE_IQ2_S: return 4;
|
|
189
|
+
case GGML_TYPE_IQ2_XS: return 4;
|
|
190
|
+
case GGML_TYPE_IQ2_XXS: return 4;
|
|
191
|
+
case GGML_TYPE_IQ3_S: return 4;
|
|
192
|
+
case GGML_TYPE_IQ3_XXS: return 4;
|
|
193
|
+
case GGML_TYPE_Q2_K: return 7;
|
|
194
|
+
case GGML_TYPE_Q3_K: return 4;
|
|
195
|
+
case GGML_TYPE_Q4_K: return 5;
|
|
196
|
+
case GGML_TYPE_Q5_K: return 6;
|
|
197
|
+
case GGML_TYPE_Q6_K: return 5;
|
|
198
|
+
default: return MMVQ_MAX_BATCH_SIZE;
|
|
199
|
+
}
|
|
200
|
+
}
|
|
201
|
+
|
|
202
|
+
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna3(ggml_type type) {
|
|
203
|
+
switch (type) {
|
|
204
|
+
case GGML_TYPE_IQ1_S: return 6;
|
|
205
|
+
case GGML_TYPE_IQ1_M: return 6;
|
|
206
|
+
case GGML_TYPE_IQ2_S: return 4;
|
|
207
|
+
case GGML_TYPE_IQ2_XS: return 4;
|
|
208
|
+
case GGML_TYPE_IQ2_XXS: return 4;
|
|
209
|
+
case GGML_TYPE_IQ3_S: return 4;
|
|
210
|
+
case GGML_TYPE_IQ3_XXS: return 4;
|
|
211
|
+
case GGML_TYPE_IQ4_NL: return 6;
|
|
212
|
+
case GGML_TYPE_IQ4_XS: return 6;
|
|
213
|
+
case GGML_TYPE_Q4_K: return 4;
|
|
214
|
+
case GGML_TYPE_Q5_K: return 4;
|
|
215
|
+
case GGML_TYPE_Q6_K: return 4;
|
|
216
|
+
default: return MMVQ_MAX_BATCH_SIZE;
|
|
217
|
+
}
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type type) {
|
|
221
|
+
switch (type) {
|
|
222
|
+
case GGML_TYPE_IQ1_S: return 7;
|
|
223
|
+
case GGML_TYPE_IQ1_M: return 7;
|
|
224
|
+
case GGML_TYPE_IQ2_S: return 4;
|
|
225
|
+
case GGML_TYPE_IQ2_XS: return 4;
|
|
226
|
+
case GGML_TYPE_IQ2_XXS: return 4;
|
|
227
|
+
case GGML_TYPE_IQ3_S: return 4;
|
|
228
|
+
case GGML_TYPE_IQ3_XXS: return 4;
|
|
229
|
+
case GGML_TYPE_IQ4_NL: return 7;
|
|
230
|
+
case GGML_TYPE_IQ4_XS: return 5;
|
|
231
|
+
case GGML_TYPE_MXFP4: return 5;
|
|
232
|
+
case GGML_TYPE_NVFP4: return 5;
|
|
233
|
+
case GGML_TYPE_Q3_K: return 4;
|
|
234
|
+
case GGML_TYPE_Q4_0: return 7;
|
|
235
|
+
case GGML_TYPE_Q4_1: return 7;
|
|
236
|
+
case GGML_TYPE_Q4_K: return 4;
|
|
237
|
+
case GGML_TYPE_Q5_0: return 7;
|
|
238
|
+
case GGML_TYPE_Q5_1: return 7;
|
|
239
|
+
case GGML_TYPE_Q5_K: return 5;
|
|
240
|
+
case GGML_TYPE_Q6_K: return 5;
|
|
241
|
+
case GGML_TYPE_Q8_0: return 7;
|
|
242
|
+
default: return MMVQ_MAX_BATCH_SIZE;
|
|
243
|
+
}
|
|
244
|
+
}
|
|
245
|
+
|
|
246
|
+
// Host function: returns the max batch size for the current arch+type at runtime.
|
|
247
|
+
int get_mmvq_mmid_max_batch(ggml_type type, int cc) {
|
|
248
|
+
// NVIDIA: Volta, Ada Lovelace, and Blackwell always use MMVQ for MUL_MAT_ID.
|
|
249
|
+
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
|
|
250
|
+
if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) {
|
|
251
|
+
return MMVQ_MAX_BATCH_SIZE;
|
|
252
|
+
}
|
|
253
|
+
if (cc >= GGML_CUDA_CC_TURING) {
|
|
254
|
+
return get_mmvq_mmid_max_batch_turing_plus(type);
|
|
255
|
+
}
|
|
256
|
+
return get_mmvq_mmid_max_batch_pascal_older(type);
|
|
257
|
+
}
|
|
258
|
+
|
|
259
|
+
// AMD
|
|
260
|
+
if (GGML_CUDA_CC_IS_AMD(cc)) {
|
|
261
|
+
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
|
|
262
|
+
return get_mmvq_mmid_max_batch_rdna4(type);
|
|
263
|
+
}
|
|
264
|
+
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
|
|
265
|
+
return get_mmvq_mmid_max_batch_rdna3(type);
|
|
266
|
+
}
|
|
267
|
+
if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) {
|
|
268
|
+
return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
|
|
269
|
+
}
|
|
270
|
+
if (GGML_CUDA_CC_IS_CDNA(cc)) {
|
|
271
|
+
return get_mmvq_mmid_max_batch_cdna(type);
|
|
272
|
+
}
|
|
273
|
+
if (GGML_CUDA_CC_IS_GCN(cc)) {
|
|
274
|
+
return get_mmvq_mmid_max_batch_gcn(type);
|
|
275
|
+
}
|
|
276
|
+
}
|
|
277
|
+
return MMVQ_MAX_BATCH_SIZE;
|
|
278
|
+
}
|
|
279
|
+
|
|
280
|
+
bool ggml_cuda_should_use_mmvq(enum ggml_type type, int cc, int64_t ne11) {
|
|
281
|
+
if (GGML_CUDA_CC_IS_CDNA(cc)) {
|
|
282
|
+
if (GGML_CUDA_CC_IS_CDNA1(cc)) {
|
|
283
|
+
switch (type) {
|
|
284
|
+
case GGML_TYPE_Q4_0:
|
|
285
|
+
case GGML_TYPE_Q4_1:
|
|
286
|
+
return ne11 <= 7;
|
|
287
|
+
case GGML_TYPE_Q5_1:
|
|
288
|
+
return ne11 <= 7;
|
|
289
|
+
case GGML_TYPE_Q8_0:
|
|
290
|
+
return ne11 <= 6;
|
|
291
|
+
case GGML_TYPE_Q2_K:
|
|
292
|
+
return ne11 <= 4;
|
|
293
|
+
case GGML_TYPE_Q3_K:
|
|
294
|
+
return ne11 <= 3;
|
|
295
|
+
case GGML_TYPE_Q4_K:
|
|
296
|
+
return ne11 <= 2;
|
|
297
|
+
case GGML_TYPE_Q5_K:
|
|
298
|
+
return ne11 <= 3;
|
|
299
|
+
case GGML_TYPE_Q6_K:
|
|
300
|
+
return ne11 <= 4;
|
|
301
|
+
case GGML_TYPE_IQ1_S:
|
|
302
|
+
return ne11 <= 5;
|
|
303
|
+
case GGML_TYPE_IQ2_XXS:
|
|
304
|
+
case GGML_TYPE_IQ3_S:
|
|
305
|
+
case GGML_TYPE_IQ4_XS:
|
|
306
|
+
return ne11 <= 6;
|
|
307
|
+
default:
|
|
308
|
+
return ne11 <= MMVQ_MAX_BATCH_SIZE;
|
|
309
|
+
}
|
|
310
|
+
}
|
|
311
|
+
switch (type) { // tuned for CDNA2
|
|
312
|
+
case GGML_TYPE_Q2_K:
|
|
313
|
+
return ne11 <= 5;
|
|
314
|
+
case GGML_TYPE_Q3_K:
|
|
315
|
+
case GGML_TYPE_Q4_K:
|
|
316
|
+
case GGML_TYPE_Q5_K:
|
|
317
|
+
return ne11 <= 3;
|
|
318
|
+
case GGML_TYPE_Q6_K:
|
|
319
|
+
return ne11 <= 5;
|
|
320
|
+
default:
|
|
321
|
+
return ne11 <= MMVQ_MAX_BATCH_SIZE;
|
|
322
|
+
}
|
|
323
|
+
}
|
|
324
|
+
return ne11 <= MMVQ_MAX_BATCH_SIZE;
|
|
325
|
+
}
|
|
326
|
+
|
|
327
|
+
// Device constexpr: returns the max batch size for the current arch+type at compile time.
|
|
328
|
+
template <ggml_type type>
|
|
329
|
+
static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() {
|
|
330
|
+
#if defined(RDNA4)
|
|
331
|
+
return get_mmvq_mmid_max_batch_rdna4(type);
|
|
332
|
+
#elif defined(RDNA3)
|
|
333
|
+
return get_mmvq_mmid_max_batch_rdna3(type);
|
|
334
|
+
#elif defined(RDNA2) || defined(RDNA1)
|
|
335
|
+
return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
|
|
336
|
+
#elif defined(CDNA)
|
|
337
|
+
return get_mmvq_mmid_max_batch_cdna(type);
|
|
338
|
+
#elif defined(GCN)
|
|
339
|
+
return get_mmvq_mmid_max_batch_gcn(type);
|
|
340
|
+
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || __CUDA_ARCH__ >= GGML_CUDA_CC_ADA_LOVELACE)
|
|
341
|
+
return MMVQ_MAX_BATCH_SIZE;
|
|
342
|
+
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
|
343
|
+
return get_mmvq_mmid_max_batch_turing_plus(type);
|
|
344
|
+
#else
|
|
345
|
+
return get_mmvq_mmid_max_batch_pascal_older(type);
|
|
346
|
+
#endif
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
|
|
87
350
|
if (table_id == MMVQ_PARAMETERS_GENERIC) {
|
|
88
351
|
switch (ncols_dst) {
|
|
89
352
|
case 1:
|
|
@@ -114,14 +377,86 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_paramet
|
|
|
114
377
|
return 1;
|
|
115
378
|
}
|
|
116
379
|
}
|
|
380
|
+
if (table_id == MMVQ_PARAMETERS_RDNA4) {
|
|
381
|
+
// nwarps=8 benefits types with simple vec_dot on RDNA4 (ncols_dst=1).
|
|
382
|
+
// Types with complex vec_dot (Q3_K, IQ2_*, IQ3_*) regress due to register
|
|
383
|
+
// pressure and lookup table contention at higher thread counts.
|
|
384
|
+
if (ncols_dst == 1) {
|
|
385
|
+
switch (type) {
|
|
386
|
+
case GGML_TYPE_Q4_0:
|
|
387
|
+
case GGML_TYPE_Q4_1:
|
|
388
|
+
case GGML_TYPE_Q5_0:
|
|
389
|
+
case GGML_TYPE_Q5_1:
|
|
390
|
+
case GGML_TYPE_Q8_0:
|
|
391
|
+
case GGML_TYPE_Q2_K:
|
|
392
|
+
case GGML_TYPE_Q4_K:
|
|
393
|
+
case GGML_TYPE_Q5_K:
|
|
394
|
+
case GGML_TYPE_Q6_K:
|
|
395
|
+
case GGML_TYPE_IQ4_NL:
|
|
396
|
+
case GGML_TYPE_IQ4_XS:
|
|
397
|
+
return 8;
|
|
398
|
+
default:
|
|
399
|
+
return 1;
|
|
400
|
+
}
|
|
401
|
+
}
|
|
402
|
+
return 1;
|
|
403
|
+
}
|
|
404
|
+
if (table_id == MMVQ_PARAMETERS_RDNA3_0) {
|
|
405
|
+
// RDNA3 (W7900): stricter whitelist than RDNA4.
|
|
406
|
+
// Q2_K / Q5_K / IQ4_XS regress in full quant sweeps.
|
|
407
|
+
if (ncols_dst == 1) {
|
|
408
|
+
switch (type) {
|
|
409
|
+
case GGML_TYPE_Q4_0:
|
|
410
|
+
case GGML_TYPE_Q4_1:
|
|
411
|
+
case GGML_TYPE_Q5_0:
|
|
412
|
+
case GGML_TYPE_Q5_1:
|
|
413
|
+
case GGML_TYPE_Q8_0:
|
|
414
|
+
return 8;
|
|
415
|
+
case GGML_TYPE_Q6_K:
|
|
416
|
+
return 2;
|
|
417
|
+
case GGML_TYPE_IQ4_NL:
|
|
418
|
+
return 8;
|
|
419
|
+
default:
|
|
420
|
+
return 1;
|
|
421
|
+
}
|
|
422
|
+
}
|
|
423
|
+
return 1;
|
|
424
|
+
}
|
|
425
|
+
if (table_id == MMVQ_PARAMETERS_TURING) {
|
|
426
|
+
if (ncols_dst == 1) {
|
|
427
|
+
switch (type) {
|
|
428
|
+
case GGML_TYPE_Q2_K:
|
|
429
|
+
case GGML_TYPE_Q3_K:
|
|
430
|
+
case GGML_TYPE_Q4_K:
|
|
431
|
+
case GGML_TYPE_Q5_K:
|
|
432
|
+
case GGML_TYPE_Q6_K:
|
|
433
|
+
return 2;
|
|
434
|
+
default:
|
|
435
|
+
return 4;
|
|
436
|
+
}
|
|
437
|
+
}
|
|
438
|
+
switch (ncols_dst) {
|
|
439
|
+
case 2:
|
|
440
|
+
case 3:
|
|
441
|
+
case 4:
|
|
442
|
+
return 4;
|
|
443
|
+
case 5:
|
|
444
|
+
case 6:
|
|
445
|
+
case 7:
|
|
446
|
+
case 8:
|
|
447
|
+
return 2;
|
|
448
|
+
default:
|
|
449
|
+
return 1;
|
|
450
|
+
}
|
|
451
|
+
}
|
|
117
452
|
return 1;
|
|
118
453
|
}
|
|
119
454
|
|
|
120
|
-
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) {
|
|
121
|
-
if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
|
|
455
|
+
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) {
|
|
456
|
+
if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN || table_id == MMVQ_PARAMETERS_TURING) {
|
|
122
457
|
switch (ncols_dst) {
|
|
123
458
|
case 1:
|
|
124
|
-
return 1;
|
|
459
|
+
return small_k ? nwarps : 1;
|
|
125
460
|
case 2:
|
|
126
461
|
case 3:
|
|
127
462
|
case 4:
|
|
@@ -137,22 +472,26 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
|
|
|
137
472
|
return 1;
|
|
138
473
|
}
|
|
139
474
|
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
|
|
475
|
+
template <ggml_type type, int ncols_dst, bool has_fusion, bool small_k = false>
|
|
476
|
+
__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
|
|
143
477
|
static __global__ void mul_mat_vec_q(
|
|
144
|
-
const void *
|
|
478
|
+
const void * vx_ptr, const void * vy_ptr, const int32_t * ids_ptr, const ggml_cuda_mm_fusion_args_device fusion, float * dst_ptr,
|
|
145
479
|
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
|
|
146
480
|
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
|
|
147
481
|
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
|
|
148
|
-
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst
|
|
482
|
+
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
|
|
483
|
+
const uint32_t ids_stride) {
|
|
484
|
+
const void * GGML_CUDA_RESTRICT vx = vx_ptr;
|
|
485
|
+
const void * GGML_CUDA_RESTRICT vy = vy_ptr;
|
|
486
|
+
const int32_t * GGML_CUDA_RESTRICT ids = ids_ptr;
|
|
487
|
+
float * GGML_CUDA_RESTRICT dst = dst_ptr;
|
|
149
488
|
|
|
150
489
|
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
|
151
490
|
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
|
152
491
|
constexpr int vdr = get_vdr_mmvq(type);
|
|
153
492
|
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
|
|
154
|
-
constexpr int nwarps = calc_nwarps(ncols_dst, table_id);
|
|
155
|
-
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
|
|
493
|
+
constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id);
|
|
494
|
+
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
|
|
156
495
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
157
496
|
|
|
158
497
|
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
|
@@ -162,18 +501,24 @@ static __global__ void mul_mat_vec_q(
|
|
|
162
501
|
const int blocks_per_row_x = ncols_x / qk;
|
|
163
502
|
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
|
|
164
503
|
|
|
165
|
-
// The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
|
|
166
504
|
const uint32_t channel_dst = blockIdx.y;
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
505
|
+
|
|
506
|
+
uint32_t channel_x;
|
|
507
|
+
uint32_t channel_y;
|
|
508
|
+
uint32_t sample_dst;
|
|
509
|
+
|
|
510
|
+
ggml_cuda_pdl_sync();
|
|
511
|
+
channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
|
|
512
|
+
channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
|
|
513
|
+
sample_dst = blockIdx.z;
|
|
514
|
+
|
|
170
515
|
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
|
|
171
516
|
const uint32_t sample_y = sample_dst;
|
|
172
517
|
|
|
173
518
|
bool use_gate = false;
|
|
174
519
|
bool use_bias = false;
|
|
175
520
|
bool use_gate_bias = false;
|
|
176
|
-
const void * vgate = nullptr;
|
|
521
|
+
[[maybe_unused]] const void * vgate = nullptr;
|
|
177
522
|
const float * x_bias = nullptr;
|
|
178
523
|
const float * gate_bias = nullptr;
|
|
179
524
|
ggml_glu_op active_glu;
|
|
@@ -188,11 +533,11 @@ static __global__ void mul_mat_vec_q(
|
|
|
188
533
|
active_glu = fusion.glu_op;
|
|
189
534
|
}
|
|
190
535
|
|
|
191
|
-
const uint32_t channel_bias = ids ? channel_x : channel_dst;
|
|
192
536
|
|
|
193
|
-
float x_biases[ncols_dst] = { 0.0f };
|
|
194
|
-
float gate_biases[ncols_dst] = { 0.0f };
|
|
537
|
+
[[maybe_unused]] float x_biases[ncols_dst] = { 0.0f };
|
|
538
|
+
[[maybe_unused]] float gate_biases[ncols_dst] = { 0.0f };
|
|
195
539
|
if constexpr (has_fusion) {
|
|
540
|
+
const uint32_t channel_bias = ids ? channel_x : channel_dst;
|
|
196
541
|
if (use_bias) {
|
|
197
542
|
x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
|
|
198
543
|
// 1. Hide latency by prefetching bias and gate here
|
|
@@ -247,12 +592,7 @@ static __global__ void mul_mat_vec_q(
|
|
|
247
592
|
}
|
|
248
593
|
|
|
249
594
|
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
|
|
250
|
-
__shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
|
|
251
|
-
if constexpr (!has_fusion) {
|
|
252
|
-
(void) tmp_shared_gate;
|
|
253
|
-
} else if (!use_gate) {
|
|
254
|
-
(void) tmp_shared_gate;
|
|
255
|
-
}
|
|
595
|
+
[[maybe_unused]] __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
|
|
256
596
|
|
|
257
597
|
if (threadIdx.y > 0) {
|
|
258
598
|
#pragma unroll
|
|
@@ -334,41 +674,139 @@ static __global__ void mul_mat_vec_q(
|
|
|
334
674
|
}
|
|
335
675
|
}
|
|
336
676
|
|
|
677
|
+
// Dedicated MoE multi-token kernel.
|
|
678
|
+
// Grid: (ceil(nrows_x / c_rows_per_block), nchannels_dst)
|
|
679
|
+
// Block: (warp_size, ncols_dst) - each warp handles one token independently.
|
|
680
|
+
// No shared memory reduction needed since each warp works alone.
|
|
681
|
+
template <ggml_type type, int c_rows_per_block>
|
|
682
|
+
__launch_bounds__(get_mmvq_mmid_max_batch_for_device<type>()*ggml_cuda_get_physical_warp_size(), 1)
|
|
683
|
+
static __global__ void mul_mat_vec_q_moe(
|
|
684
|
+
const void * vx_ptr, const void * vy_ptr, const int32_t * ids_ptr,
|
|
685
|
+
float * dst_ptr,
|
|
686
|
+
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x,
|
|
687
|
+
const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst,
|
|
688
|
+
const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst,
|
|
689
|
+
const uint32_t ncols_dst, const uint32_t ids_stride) {
|
|
690
|
+
const void * GGML_CUDA_RESTRICT vx = vx_ptr;
|
|
691
|
+
const void * GGML_CUDA_RESTRICT vy = vy_ptr;
|
|
692
|
+
const int32_t * GGML_CUDA_RESTRICT ids = ids_ptr;
|
|
693
|
+
float * GGML_CUDA_RESTRICT dst = dst_ptr;
|
|
694
|
+
|
|
695
|
+
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
|
696
|
+
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
|
697
|
+
constexpr int vdr = get_vdr_mmvq(type);
|
|
698
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
699
|
+
|
|
700
|
+
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
|
701
|
+
|
|
702
|
+
const uint32_t token_idx = threadIdx.y;
|
|
703
|
+
const int row0 = c_rows_per_block*blockIdx.x;
|
|
704
|
+
const int blocks_per_row_x = ncols_x / qk;
|
|
705
|
+
constexpr int blocks_per_iter = vdr * warp_size / qi;
|
|
706
|
+
|
|
707
|
+
const uint32_t channel_dst = blockIdx.y;
|
|
708
|
+
|
|
709
|
+
if (token_idx >= ncols_dst) {
|
|
710
|
+
return;
|
|
711
|
+
}
|
|
712
|
+
|
|
713
|
+
ggml_cuda_pdl_sync();
|
|
714
|
+
const uint32_t channel_x = ids[channel_dst + token_idx * ids_stride];
|
|
715
|
+
const uint32_t channel_y = fastmodulo(channel_dst, nchannels_y);
|
|
716
|
+
|
|
717
|
+
const block_q8_1 * y = ((const block_q8_1 *) vy) + channel_y*stride_channel_y + token_idx*stride_col_y;
|
|
718
|
+
const int kbx_offset = channel_x*stride_channel_x + row0*stride_row_x;
|
|
719
|
+
|
|
720
|
+
// partial sum for each thread
|
|
721
|
+
float tmp[c_rows_per_block] = {0.0f};
|
|
722
|
+
|
|
723
|
+
for (int kbx = threadIdx.x / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
|
|
724
|
+
const int kby = kbx * (qk/QK8_1);
|
|
725
|
+
const int kqs = vdr * (threadIdx.x % (qi/vdr));
|
|
726
|
+
|
|
727
|
+
#pragma unroll
|
|
728
|
+
for (int i = 0; i < c_rows_per_block; ++i) {
|
|
729
|
+
tmp[i] += vec_dot_q_cuda(vx, &y[kby], kbx_offset + i*stride_row_x + kbx, kqs);
|
|
730
|
+
}
|
|
731
|
+
}
|
|
732
|
+
|
|
733
|
+
ggml_cuda_pdl_lc();
|
|
734
|
+
|
|
735
|
+
// Warp-level reduction only - no shared memory needed
|
|
736
|
+
#pragma unroll
|
|
737
|
+
for (int i = 0; i < c_rows_per_block; ++i) {
|
|
738
|
+
tmp[i] = warp_reduce_sum<warp_size>(tmp[i]);
|
|
739
|
+
}
|
|
740
|
+
|
|
741
|
+
// Write results
|
|
742
|
+
if (threadIdx.x < c_rows_per_block && (c_rows_per_block == 1 || uint32_t(row0 + threadIdx.x) < nrows_x)) {
|
|
743
|
+
dst[channel_dst*stride_channel_dst + token_idx*stride_col_dst + row0 + threadIdx.x] = tmp[threadIdx.x];
|
|
744
|
+
}
|
|
745
|
+
}
|
|
746
|
+
|
|
747
|
+
template<ggml_type type>
|
|
337
748
|
static std::pair<dim3, dim3> calc_launch_params(
|
|
338
|
-
const int ncols_dst, const int nrows_x, const int
|
|
339
|
-
const int warp_size, const mmvq_parameter_table_id table_id) {
|
|
340
|
-
const
|
|
341
|
-
const
|
|
342
|
-
const
|
|
749
|
+
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
|
|
750
|
+
const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) {
|
|
751
|
+
const int nwarps = calc_nwarps(type, ncols_dst, table_id);
|
|
752
|
+
const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
|
|
753
|
+
const int64_t nblocks = (nrows_x + rpb - 1) / rpb;
|
|
754
|
+
const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
|
|
755
|
+
const dim3 block_dims(warp_size, nwarps, 1);
|
|
343
756
|
return {block_nums, block_dims};
|
|
344
757
|
}
|
|
345
758
|
|
|
346
|
-
template<ggml_type type, int c_ncols_dst>
|
|
759
|
+
template<ggml_type type, int c_ncols_dst, bool small_k = false>
|
|
347
760
|
static void mul_mat_vec_q_switch_fusion(
|
|
348
761
|
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
|
349
762
|
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
|
|
350
763
|
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
|
|
351
764
|
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
|
|
352
765
|
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
|
|
353
|
-
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared,
|
|
766
|
+
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared,
|
|
767
|
+
const uint32_t ids_stride, cudaStream_t stream) {
|
|
354
768
|
|
|
355
769
|
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
|
356
770
|
if constexpr (c_ncols_dst == 1) {
|
|
357
771
|
if (has_fusion) {
|
|
358
|
-
|
|
359
|
-
|
|
772
|
+
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, nbytes_shared, stream);
|
|
773
|
+
ggml_cuda_kernel_launch(mul_mat_vec_q<type, c_ncols_dst, true, small_k>, launch_params,
|
|
774
|
+
vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
|
360
775
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
361
|
-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
776
|
+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
|
|
362
777
|
return;
|
|
363
778
|
}
|
|
364
779
|
}
|
|
365
780
|
|
|
366
781
|
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
|
|
367
782
|
|
|
368
|
-
|
|
369
|
-
|
|
783
|
+
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, nbytes_shared, stream);
|
|
784
|
+
ggml_cuda_kernel_launch(mul_mat_vec_q<type, c_ncols_dst, false, small_k>, launch_params,
|
|
785
|
+
vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
|
370
786
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
371
|
-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
787
|
+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
|
|
788
|
+
}
|
|
789
|
+
|
|
790
|
+
template <ggml_type type>
|
|
791
|
+
static void mul_mat_vec_q_moe_launch(
|
|
792
|
+
const void * vx, const void * vy, const int32_t * ids, float * dst,
|
|
793
|
+
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x,
|
|
794
|
+
const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst,
|
|
795
|
+
const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst,
|
|
796
|
+
const uint32_t ncols_dst, const uint32_t ids_stride,
|
|
797
|
+
const int warp_size, const int nchannels_dst, cudaStream_t stream) {
|
|
798
|
+
|
|
799
|
+
constexpr int rows_per_block = 2; // 2 gives best perf based on tuning
|
|
800
|
+
const int64_t nblocks_rows = (nrows_x + rows_per_block - 1) / rows_per_block;
|
|
801
|
+
const dim3 block_nums(nblocks_rows, nchannels_dst);
|
|
802
|
+
const dim3 block_dims(warp_size, ncols_dst);
|
|
803
|
+
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
|
|
804
|
+
|
|
805
|
+
ggml_cuda_kernel_launch(mul_mat_vec_q_moe<type, rows_per_block>, launch_params,
|
|
806
|
+
vx, vy, ids, dst, ncols_x, nchannels_y, nrows_x,
|
|
807
|
+
stride_row_x, stride_col_y, stride_col_dst,
|
|
808
|
+
stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
809
|
+
ncols_dst, ids_stride);
|
|
372
810
|
}
|
|
373
811
|
|
|
374
812
|
template <ggml_type type>
|
|
@@ -379,7 +817,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
|
|
|
379
817
|
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
|
|
380
818
|
const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
|
381
819
|
const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
|
382
|
-
cudaStream_t stream) {
|
|
820
|
+
const int ids_stride, cudaStream_t stream) {
|
|
383
821
|
|
|
384
822
|
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
|
|
385
823
|
GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
|
|
@@ -389,76 +827,144 @@ static void mul_mat_vec_q_switch_ncols_dst(
|
|
|
389
827
|
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
|
|
390
828
|
|
|
391
829
|
const int device = ggml_cuda_get_device();
|
|
830
|
+
const int cc = ggml_cuda_info().devices[device].cc;
|
|
392
831
|
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
|
393
|
-
const mmvq_parameter_table_id table_id
|
|
832
|
+
const mmvq_parameter_table_id table_id = get_device_table_id(cc);
|
|
394
833
|
|
|
395
834
|
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
|
835
|
+
const bool has_ids = ids != nullptr;
|
|
836
|
+
|
|
837
|
+
const auto should_use_small_k = [&](int c_ncols_dst) {
|
|
838
|
+
// When K is small, increase rows_per_block to match nwarps so each warp has more work to do
|
|
839
|
+
// Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle.
|
|
840
|
+
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
|
841
|
+
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
|
842
|
+
constexpr int vdr = get_vdr_mmvq(type);
|
|
843
|
+
const int blocks_per_row_x = ncols_x / qk;
|
|
844
|
+
const int blocks_per_iter_1warp = vdr * warp_size / qi;
|
|
845
|
+
const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
|
|
846
|
+
bool use = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
|
|
847
|
+
|
|
848
|
+
constexpr std::array<ggml_type, 2> iq_slow_turing = {
|
|
849
|
+
GGML_TYPE_IQ3_XXS,
|
|
850
|
+
GGML_TYPE_IQ3_S,
|
|
851
|
+
};
|
|
852
|
+
constexpr std::array<ggml_type, 8> iq_slow_other = {
|
|
853
|
+
GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS,
|
|
854
|
+
GGML_TYPE_IQ2_S, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
|
|
855
|
+
};
|
|
856
|
+
constexpr std::array<ggml_type, 3> slow_pascal = {
|
|
857
|
+
GGML_TYPE_IQ3_S,
|
|
858
|
+
GGML_TYPE_Q2_K,
|
|
859
|
+
GGML_TYPE_Q3_K,
|
|
860
|
+
};
|
|
861
|
+
|
|
862
|
+
const bool is_nvidia_turing_plus = GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_TURING;
|
|
863
|
+
const bool is_nvidia_pascal_older = GGML_CUDA_CC_IS_NVIDIA(cc) && cc < GGML_CUDA_CC_VOLTA;
|
|
864
|
+
|
|
865
|
+
if (is_nvidia_turing_plus) {
|
|
866
|
+
if (ncols_dst == 1 &&
|
|
867
|
+
std::find(iq_slow_turing.begin(), iq_slow_turing.end(), type) != iq_slow_turing.end()) {
|
|
868
|
+
use = false;
|
|
869
|
+
}
|
|
870
|
+
} else if ((ncols_dst == 1 && std::find(iq_slow_other.begin(), iq_slow_other.end(), type) != iq_slow_other.end()) ||
|
|
871
|
+
(is_nvidia_pascal_older && std::find(slow_pascal.begin(), slow_pascal.end(), type) != slow_pascal.end()) ||
|
|
872
|
+
GGML_CUDA_CC_IS_RDNA(cc)) {
|
|
873
|
+
use = false;
|
|
874
|
+
}
|
|
875
|
+
|
|
876
|
+
return use;
|
|
877
|
+
};
|
|
878
|
+
|
|
879
|
+
if (has_ids && ncols_dst > 1) {
|
|
880
|
+
// Multi-token MUL_MAT_ID path - dedicated MoE kernel
|
|
881
|
+
mul_mat_vec_q_moe_launch<type>(
|
|
882
|
+
vx, vy, ids, dst, ncols_x, nchannels_y_fd, nrows_x,
|
|
883
|
+
stride_row_x, stride_col_y, stride_col_dst,
|
|
884
|
+
stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
885
|
+
ncols_dst, ids_stride, warp_size, nchannels_dst, stream);
|
|
886
|
+
return;
|
|
887
|
+
}
|
|
396
888
|
|
|
397
|
-
GGML_ASSERT(!ids || ncols_dst == 1);
|
|
398
889
|
switch (ncols_dst) {
|
|
399
890
|
case 1: {
|
|
400
891
|
constexpr int c_ncols_dst = 1;
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
892
|
+
|
|
893
|
+
bool use_small_k = should_use_small_k(c_ncols_dst);
|
|
894
|
+
|
|
895
|
+
if (use_small_k) {
|
|
896
|
+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst,
|
|
897
|
+
nsamples_dst, warp_size, table_id, true);
|
|
898
|
+
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(
|
|
899
|
+
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
|
900
|
+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd,
|
|
901
|
+
stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride,
|
|
902
|
+
stream);
|
|
903
|
+
} else {
|
|
904
|
+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst,
|
|
905
|
+
nsamples_dst, warp_size, table_id);
|
|
906
|
+
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(
|
|
907
|
+
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
|
908
|
+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd,
|
|
909
|
+
stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride,
|
|
910
|
+
stream);
|
|
911
|
+
}
|
|
406
912
|
} break;
|
|
407
913
|
case 2: {
|
|
408
914
|
constexpr int c_ncols_dst = 2;
|
|
409
|
-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
|
915
|
+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
|
410
916
|
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
|
411
917
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
412
918
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
|
413
|
-
dims.first, dims.second, 0, stream);
|
|
919
|
+
dims.first, dims.second, 0, ids_stride, stream);
|
|
414
920
|
} break;
|
|
415
921
|
case 3: {
|
|
416
922
|
constexpr int c_ncols_dst = 3;
|
|
417
|
-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
|
923
|
+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
|
418
924
|
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
|
419
925
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
420
926
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
|
421
|
-
dims.first, dims.second, 0, stream);
|
|
927
|
+
dims.first, dims.second, 0, ids_stride, stream);
|
|
422
928
|
} break;
|
|
423
929
|
case 4: {
|
|
424
930
|
constexpr int c_ncols_dst = 4;
|
|
425
|
-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
|
931
|
+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
|
426
932
|
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
|
427
933
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
428
934
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
|
429
|
-
dims.first, dims.second, 0, stream);
|
|
935
|
+
dims.first, dims.second, 0, ids_stride, stream);
|
|
430
936
|
} break;
|
|
431
937
|
case 5: {
|
|
432
938
|
constexpr int c_ncols_dst = 5;
|
|
433
|
-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
|
939
|
+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
|
434
940
|
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
|
435
941
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
436
942
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
|
437
|
-
dims.first, dims.second, 0, stream);
|
|
943
|
+
dims.first, dims.second, 0, ids_stride, stream);
|
|
438
944
|
} break;
|
|
439
945
|
case 6: {
|
|
440
946
|
constexpr int c_ncols_dst = 6;
|
|
441
|
-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
|
947
|
+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
|
442
948
|
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
|
443
949
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
444
950
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
|
445
|
-
dims.first, dims.second, 0, stream);
|
|
951
|
+
dims.first, dims.second, 0, ids_stride, stream);
|
|
446
952
|
} break;
|
|
447
953
|
case 7: {
|
|
448
954
|
constexpr int c_ncols_dst = 7;
|
|
449
|
-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
|
955
|
+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
|
450
956
|
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
|
451
957
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
452
958
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
|
453
|
-
dims.first, dims.second, 0, stream);
|
|
959
|
+
dims.first, dims.second, 0, ids_stride, stream);
|
|
454
960
|
} break;
|
|
455
961
|
case 8: {
|
|
456
962
|
constexpr int c_ncols_dst = 8;
|
|
457
|
-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
|
963
|
+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
|
458
964
|
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
|
459
965
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
460
966
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
|
461
|
-
dims.first, dims.second, 0, stream);
|
|
967
|
+
dims.first, dims.second, 0, ids_stride, stream);
|
|
462
968
|
} break;
|
|
463
969
|
default:
|
|
464
970
|
GGML_ABORT("fatal error");
|
|
@@ -474,127 +980,139 @@ static void mul_mat_vec_q_switch_type(
|
|
|
474
980
|
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
|
|
475
981
|
const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
|
476
982
|
const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
|
477
|
-
cudaStream_t stream) {
|
|
983
|
+
const int ids_stride, cudaStream_t stream) {
|
|
478
984
|
switch (type_x) {
|
|
985
|
+
case GGML_TYPE_Q1_0:
|
|
986
|
+
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q1_0>
|
|
987
|
+
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
988
|
+
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
989
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
990
|
+
break;
|
|
479
991
|
case GGML_TYPE_Q4_0:
|
|
480
992
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
|
|
481
993
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
482
994
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
483
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
995
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
484
996
|
break;
|
|
485
997
|
case GGML_TYPE_Q4_1:
|
|
486
998
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
|
|
487
999
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
488
1000
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
489
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
1001
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
490
1002
|
break;
|
|
491
1003
|
case GGML_TYPE_Q5_0:
|
|
492
1004
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
|
|
493
1005
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
494
1006
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
495
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
1007
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
496
1008
|
break;
|
|
497
1009
|
case GGML_TYPE_Q5_1:
|
|
498
1010
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
|
|
499
1011
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
500
1012
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
501
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
1013
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
502
1014
|
break;
|
|
503
1015
|
case GGML_TYPE_Q8_0:
|
|
504
1016
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
|
|
505
1017
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
506
1018
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
507
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
1019
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
508
1020
|
break;
|
|
509
1021
|
case GGML_TYPE_MXFP4:
|
|
510
1022
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
|
|
511
1023
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
512
1024
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
513
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
1025
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
1026
|
+
break;
|
|
1027
|
+
case GGML_TYPE_NVFP4:
|
|
1028
|
+
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_NVFP4>
|
|
1029
|
+
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
1030
|
+
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
1031
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
514
1032
|
break;
|
|
515
1033
|
case GGML_TYPE_Q2_K:
|
|
516
1034
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
|
|
517
1035
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
518
1036
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
519
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
1037
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
520
1038
|
break;
|
|
521
1039
|
case GGML_TYPE_Q3_K:
|
|
522
1040
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
|
|
523
1041
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
524
1042
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
525
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
1043
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
526
1044
|
break;
|
|
527
1045
|
case GGML_TYPE_Q4_K:
|
|
528
1046
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
|
|
529
1047
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
530
1048
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
531
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
1049
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
532
1050
|
break;
|
|
533
1051
|
case GGML_TYPE_Q5_K:
|
|
534
1052
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
|
|
535
1053
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
536
1054
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
537
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
1055
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
538
1056
|
break;
|
|
539
1057
|
case GGML_TYPE_Q6_K:
|
|
540
1058
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
|
|
541
1059
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
542
1060
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
543
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
1061
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
544
1062
|
break;
|
|
545
1063
|
case GGML_TYPE_IQ2_XXS:
|
|
546
1064
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
|
|
547
1065
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
548
1066
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
549
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
1067
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
550
1068
|
break;
|
|
551
1069
|
case GGML_TYPE_IQ2_XS:
|
|
552
1070
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
|
|
553
1071
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
554
1072
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
555
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
1073
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
556
1074
|
break;
|
|
557
1075
|
case GGML_TYPE_IQ2_S:
|
|
558
1076
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
|
|
559
1077
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
560
1078
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
561
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
1079
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
562
1080
|
break;
|
|
563
1081
|
case GGML_TYPE_IQ3_XXS:
|
|
564
1082
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
|
|
565
1083
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
566
1084
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
567
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
1085
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
568
1086
|
break;
|
|
569
1087
|
case GGML_TYPE_IQ1_S:
|
|
570
1088
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
|
|
571
1089
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
572
1090
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
573
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
1091
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
574
1092
|
break;
|
|
575
1093
|
case GGML_TYPE_IQ1_M:
|
|
576
1094
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
|
|
577
1095
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
578
1096
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
579
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
1097
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
580
1098
|
break;
|
|
581
1099
|
case GGML_TYPE_IQ4_NL:
|
|
582
1100
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
|
|
583
1101
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
584
1102
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
585
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
1103
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
586
1104
|
break;
|
|
587
1105
|
case GGML_TYPE_IQ4_XS:
|
|
588
1106
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
|
|
589
1107
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
590
1108
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
591
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
1109
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
592
1110
|
break;
|
|
593
1111
|
case GGML_TYPE_IQ3_S:
|
|
594
1112
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
|
|
595
1113
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
|
596
1114
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
597
|
-
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
1115
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
|
598
1116
|
break;
|
|
599
1117
|
default:
|
|
600
1118
|
GGML_ABORT("fatal error");
|
|
@@ -622,7 +1140,7 @@ void ggml_cuda_mul_mat_vec_q(
|
|
|
622
1140
|
GGML_ASSERT( nb0 == ts_dst);
|
|
623
1141
|
GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
|
|
624
1142
|
|
|
625
|
-
GGML_ASSERT(!ids || ne12
|
|
1143
|
+
GGML_ASSERT(!ids || ne12 <= MMVQ_MAX_BATCH_SIZE);
|
|
626
1144
|
|
|
627
1145
|
const float * src1_d = (const float *) src1->data;
|
|
628
1146
|
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
|
|
@@ -693,11 +1211,13 @@ void ggml_cuda_mul_mat_vec_q(
|
|
|
693
1211
|
const int64_t stride_channel_dst = ids ? s1 : s2;
|
|
694
1212
|
const int64_t stride_channel_y = ids ? s11 : s12;
|
|
695
1213
|
|
|
1214
|
+
const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
|
|
1215
|
+
|
|
696
1216
|
mul_mat_vec_q_switch_type(
|
|
697
1217
|
src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
|
|
698
1218
|
ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
|
|
699
1219
|
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
|
700
|
-
ne03, ne3, s03, s13, s3, stream);
|
|
1220
|
+
ne03, ne3, s03, s13, s3, ids_stride, stream);
|
|
701
1221
|
}
|
|
702
1222
|
|
|
703
1223
|
void ggml_cuda_op_mul_mat_vec_q(
|
|
@@ -726,7 +1246,7 @@ void ggml_cuda_op_mul_mat_vec_q(
|
|
|
726
1246
|
ggml_cuda_mm_fusion_args_device fusion_local{};
|
|
727
1247
|
mul_mat_vec_q_switch_type(
|
|
728
1248
|
src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
|
|
729
|
-
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
|
|
1249
|
+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, stream);
|
|
730
1250
|
|
|
731
1251
|
GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
|
|
732
1252
|
}
|