whispercpp 1.3.5 → 1.3.7
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/LICENSE +1 -1
- data/README.md +133 -3
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -7
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +56 -46
- data/ext/ruby_whisper.h +165 -2
- data/ext/ruby_whisper_context.c +297 -126
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -66
- data/ext/ruby_whisper_segment.c +6 -7
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +46 -16
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +24 -19
- data/ext/sources/examples/cli/cli.cpp +51 -9
- data/ext/sources/examples/common-ggml.cpp +4 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +213 -163
- data/ext/sources/ggml/CMakeLists.txt +29 -15
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +73 -11
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +8 -3
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +155 -16
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +25 -5
- data/ext/sources/ggml/src/ggml-alloc.c +9 -10
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
- data/ext/sources/ggml/src/ggml-common.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
- data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
- data/ext/sources/ggml/src/ggml-impl.h +68 -1
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +385 -119
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
- data/ext/sources/ggml/src/ggml.c +268 -52
- data/ext/sources/ggml/src/gguf.cpp +377 -47
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +62 -40
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +445 -55
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_context_params.rb +82 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +44 -6
- data/whispercpp.gemspec +2 -2
- metadata +426 -280
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
- data/ext/sources/examples/talk-llama/llama-context.h +0 -360
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
- data/ext/sources/examples/talk-llama/llama-model.h +0 -544
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
- data/ext/sources/examples/talk-llama/llama.h +0 -1540
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -569
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
|
@@ -56,6 +56,65 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
|
|
|
56
56
|
}
|
|
57
57
|
}
|
|
58
58
|
|
|
59
|
+
template <typename reorder_vec_dot_q_sycl, int ncols_dst>
|
|
60
|
+
static void mul_mat_vec_q_reorder_ncols(const void * __restrict__ vx, const void * __restrict__ vy,
|
|
61
|
+
float * __restrict__ dst, const int ncols, const int nrows,
|
|
62
|
+
const int stride_col_y_bytes, const int stride_col_dst,
|
|
63
|
+
const sycl::nd_item<3> & nd_item) {
|
|
64
|
+
using block_type = ggml_sycl_reordered::block_q_t<reorder_vec_dot_q_sycl::gtype>;
|
|
65
|
+
using block_traits = typename block_type::traits;
|
|
66
|
+
|
|
67
|
+
const auto sg = nd_item.get_sub_group();
|
|
68
|
+
const int sg_range = sg.get_group_linear_range();
|
|
69
|
+
const int workgroup_id = nd_item.get_group_linear_id();
|
|
70
|
+
const int sg_id = sg.get_group_linear_id();
|
|
71
|
+
const int row = workgroup_id * sg_range + sg_id;
|
|
72
|
+
|
|
73
|
+
if (row >= nrows) {
|
|
74
|
+
return;
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
const int blocks_per_row = ncols / block_traits::qk;
|
|
78
|
+
constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
|
|
79
|
+
constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq;
|
|
80
|
+
const int nblocks = nrows * (ncols / block_traits::qk);
|
|
81
|
+
|
|
82
|
+
static_assert(blocks_per_subgroup > 0);
|
|
83
|
+
static_assert(block_elements_per_subgroup > 0);
|
|
84
|
+
|
|
85
|
+
float partial_sum[ncols_dst] = {0.0f};
|
|
86
|
+
for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
|
|
87
|
+
const int ibx = row * blocks_per_row + i;
|
|
88
|
+
|
|
89
|
+
const auto bx_offset = block_type::get_block_offset(ibx, nblocks);
|
|
90
|
+
const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx);
|
|
91
|
+
const int iby = i * block_type::block_to_q8_1_ratio();
|
|
92
|
+
|
|
93
|
+
#pragma unroll
|
|
94
|
+
for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {
|
|
95
|
+
const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
|
|
96
|
+
|
|
97
|
+
#pragma unroll
|
|
98
|
+
for (int j = 0; j < ncols_dst; ++j) {
|
|
99
|
+
const char * vy_j = (const char *)vy + j * stride_col_y_bytes;
|
|
100
|
+
const int8_t * q8_1_quant_ptr = (const int8_t *)vy_j + iby * QK8_1;
|
|
101
|
+
const sycl::half2* q8_1_ds_ptr = (const sycl::half2 *)(vy_j + ncols + iby * sizeof(sycl::half2));
|
|
102
|
+
|
|
103
|
+
partial_sum[j] += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs);
|
|
104
|
+
}
|
|
105
|
+
}
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
#pragma unroll
|
|
109
|
+
for (int j = 0; j < ncols_dst; ++j) {
|
|
110
|
+
float sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum[j], std::plus<>());
|
|
111
|
+
|
|
112
|
+
if (sg.leader()) {
|
|
113
|
+
dst[j * stride_col_dst + row] = sum;
|
|
114
|
+
}
|
|
115
|
+
}
|
|
116
|
+
}
|
|
117
|
+
|
|
59
118
|
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
|
|
60
119
|
static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
|
61
120
|
const int ncols, const int nrows, const sycl::nd_item<3> & item_ct1) {
|
|
@@ -100,6 +159,70 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
|
|
|
100
159
|
}
|
|
101
160
|
}
|
|
102
161
|
|
|
162
|
+
template <int qk, int qi, typename block_q_t, int vdr,
|
|
163
|
+
vec_dot_q_sycl_t vec_dot_q_sycl, int ncols_dst>
|
|
164
|
+
static void mul_mat_vec_q_ncols(
|
|
165
|
+
const void * __restrict__ vx,
|
|
166
|
+
const void * __restrict__ vy,
|
|
167
|
+
float * __restrict__ dst,
|
|
168
|
+
const int ncols,
|
|
169
|
+
const int nrows,
|
|
170
|
+
const int stride_col_y,
|
|
171
|
+
const int stride_col_dst,
|
|
172
|
+
const sycl::nd_item<3> & item_ct1) {
|
|
173
|
+
|
|
174
|
+
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1)
|
|
175
|
+
+ item_ct1.get_local_id(1);
|
|
176
|
+
|
|
177
|
+
if (row >= nrows) {
|
|
178
|
+
return;
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
const int blocks_per_row = ncols / qk;
|
|
182
|
+
constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi;
|
|
183
|
+
|
|
184
|
+
// partial sums: one per output column
|
|
185
|
+
float tmp[ncols_dst] = {0.0f};
|
|
186
|
+
|
|
187
|
+
const block_q_t * x = (const block_q_t *) vx;
|
|
188
|
+
const block_q8_1 * y = (const block_q8_1 *) vy;
|
|
189
|
+
|
|
190
|
+
for (int i = item_ct1.get_local_id(2) / (qi / vdr);
|
|
191
|
+
i < blocks_per_row;
|
|
192
|
+
i += blocks_per_warp) {
|
|
193
|
+
|
|
194
|
+
const int ibx = row * blocks_per_row + i;
|
|
195
|
+
const int iby = i * (qk / QK8_1);
|
|
196
|
+
|
|
197
|
+
// read weight block once, dot against all columns
|
|
198
|
+
for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) {
|
|
199
|
+
const int iqs = elem + vdr * (item_ct1.get_local_id(2) % (qi / vdr));
|
|
200
|
+
|
|
201
|
+
#pragma unroll
|
|
202
|
+
for (int j = 0; j < ncols_dst; ++j) {
|
|
203
|
+
tmp[j] += vec_dot_q_sycl(&x[ibx], &y[j * stride_col_y + iby], iqs);
|
|
204
|
+
}
|
|
205
|
+
}
|
|
206
|
+
}
|
|
207
|
+
|
|
208
|
+
// reduce within subgroup
|
|
209
|
+
#pragma unroll
|
|
210
|
+
for (int j = 0; j < ncols_dst; ++j) {
|
|
211
|
+
#pragma unroll
|
|
212
|
+
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
|
213
|
+
tmp[j] += dpct::permute_sub_group_by_xor(
|
|
214
|
+
item_ct1.get_sub_group(), tmp[j], mask);
|
|
215
|
+
}
|
|
216
|
+
}
|
|
217
|
+
|
|
218
|
+
if (item_ct1.get_local_id(2) == 0) {
|
|
219
|
+
#pragma unroll
|
|
220
|
+
for (int j = 0; j < ncols_dst; ++j) {
|
|
221
|
+
dst[j * stride_col_dst + row] = tmp[j];
|
|
222
|
+
}
|
|
223
|
+
}
|
|
224
|
+
}
|
|
225
|
+
|
|
103
226
|
template <int qk, int qi, typename block_q_t, int vdr>
|
|
104
227
|
static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
|
|
105
228
|
const void *__restrict__ vy,
|
|
@@ -537,9 +660,9 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
|
|
|
537
660
|
static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
|
|
538
661
|
const int nrows, dpct::queue_ptr stream) {
|
|
539
662
|
GGML_ASSERT(ncols % QK4_0 == 0);
|
|
540
|
-
|
|
541
|
-
constexpr size_t num_subgroups =
|
|
542
|
-
|
|
663
|
+
// Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
|
|
664
|
+
constexpr size_t num_subgroups = WARP_SIZE;
|
|
665
|
+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
|
|
543
666
|
|
|
544
667
|
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
|
|
545
668
|
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
|
@@ -553,6 +676,45 @@ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy,
|
|
|
553
676
|
});
|
|
554
677
|
}
|
|
555
678
|
|
|
679
|
+
template <int ncols_dst>
|
|
680
|
+
static void reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols(
|
|
681
|
+
const void * vx, const void * vy, float * dst,
|
|
682
|
+
const int ncols, const int nrows,
|
|
683
|
+
const int stride_col_y_bytes, const int stride_col_dst,
|
|
684
|
+
dpct::queue_ptr stream) {
|
|
685
|
+
GGML_ASSERT(ncols % QK4_0 == 0);
|
|
686
|
+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
|
687
|
+
constexpr size_t num_subgroups = 16;
|
|
688
|
+
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
|
689
|
+
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
|
690
|
+
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
|
691
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
692
|
+
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
|
693
|
+
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
694
|
+
mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>, ncols_dst>(
|
|
695
|
+
vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
|
|
696
|
+
});
|
|
697
|
+
});
|
|
698
|
+
}
|
|
699
|
+
|
|
700
|
+
static void reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols(
|
|
701
|
+
const void * vx, const void * vy, float * dst,
|
|
702
|
+
const int ncols, const int nrows, const int ncols_dst,
|
|
703
|
+
const int stride_col_y_bytes, const int stride_col_dst,
|
|
704
|
+
dpct::queue_ptr stream) {
|
|
705
|
+
switch (ncols_dst) {
|
|
706
|
+
case 1: reorder_mul_mat_vec_q4_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
|
|
707
|
+
case 2: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
708
|
+
case 3: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
709
|
+
case 4: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
710
|
+
case 5: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
711
|
+
case 6: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
712
|
+
case 7: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
713
|
+
case 8: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
714
|
+
default: GGML_ABORT("unsupported ncols_dst=%d for Q4_0 reorder multi-col MMVQ", ncols_dst);
|
|
715
|
+
}
|
|
716
|
+
}
|
|
717
|
+
|
|
556
718
|
static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
|
|
557
719
|
dpct::queue_ptr stream) {
|
|
558
720
|
GGML_ASSERT(ncols % QK4_0 == 0);
|
|
@@ -571,6 +733,45 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float *
|
|
|
571
733
|
}
|
|
572
734
|
}
|
|
573
735
|
|
|
736
|
+
template <int ncols_dst>
|
|
737
|
+
static void mul_mat_vec_q4_0_q8_1_sycl_ncols(
|
|
738
|
+
const void * vx, const void * vy, float * dst,
|
|
739
|
+
const int ncols, const int nrows,
|
|
740
|
+
const int stride_col_y, const int stride_col_dst,
|
|
741
|
+
dpct::queue_ptr stream) {
|
|
742
|
+
GGML_ASSERT(ncols % QK4_0 == 0);
|
|
743
|
+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
744
|
+
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
745
|
+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
746
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
747
|
+
cgh.parallel_for(
|
|
748
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
749
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
750
|
+
mul_mat_vec_q_ncols<QK4_0, QI4_0, block_q4_0,
|
|
751
|
+
VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1, ncols_dst>(
|
|
752
|
+
vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
|
|
753
|
+
});
|
|
754
|
+
});
|
|
755
|
+
}
|
|
756
|
+
|
|
757
|
+
static void mul_mat_vec_q4_0_q8_1_sycl_switch_ncols(
|
|
758
|
+
const void * vx, const void * vy, float * dst,
|
|
759
|
+
const int ncols, const int nrows, const int ncols_dst,
|
|
760
|
+
const int stride_col_y, const int stride_col_dst,
|
|
761
|
+
dpct::queue_ptr stream) {
|
|
762
|
+
switch (ncols_dst) {
|
|
763
|
+
case 1: mul_mat_vec_q4_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
|
|
764
|
+
case 2: mul_mat_vec_q4_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
765
|
+
case 3: mul_mat_vec_q4_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
766
|
+
case 4: mul_mat_vec_q4_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
767
|
+
case 5: mul_mat_vec_q4_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
768
|
+
case 6: mul_mat_vec_q4_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
769
|
+
case 7: mul_mat_vec_q4_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
770
|
+
case 8: mul_mat_vec_q4_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
771
|
+
default: GGML_ABORT("unsupported ncols_dst=%d for Q4_0 multi-col MMVQ", ncols_dst);
|
|
772
|
+
}
|
|
773
|
+
}
|
|
774
|
+
|
|
574
775
|
static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
|
575
776
|
float *dst, const int ncols,
|
|
576
777
|
const int nrows,
|
|
@@ -595,6 +796,45 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
|
|
595
796
|
}
|
|
596
797
|
}
|
|
597
798
|
|
|
799
|
+
template <int ncols_dst>
|
|
800
|
+
static void mul_mat_vec_q4_1_q8_1_sycl_ncols(
|
|
801
|
+
const void * vx, const void * vy, float * dst,
|
|
802
|
+
const int ncols, const int nrows,
|
|
803
|
+
const int stride_col_y, const int stride_col_dst,
|
|
804
|
+
dpct::queue_ptr stream) {
|
|
805
|
+
GGML_ASSERT(ncols % QK4_1 == 0);
|
|
806
|
+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
807
|
+
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
808
|
+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
809
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
810
|
+
cgh.parallel_for(
|
|
811
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
812
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
813
|
+
mul_mat_vec_q_ncols<QK4_0, QI4_1, block_q4_1,
|
|
814
|
+
VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1, ncols_dst>(
|
|
815
|
+
vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
|
|
816
|
+
});
|
|
817
|
+
});
|
|
818
|
+
}
|
|
819
|
+
|
|
820
|
+
static void mul_mat_vec_q4_1_q8_1_sycl_switch_ncols(
|
|
821
|
+
const void * vx, const void * vy, float * dst,
|
|
822
|
+
const int ncols, const int nrows, const int ncols_dst,
|
|
823
|
+
const int stride_col_y, const int stride_col_dst,
|
|
824
|
+
dpct::queue_ptr stream) {
|
|
825
|
+
switch (ncols_dst) {
|
|
826
|
+
case 1: mul_mat_vec_q4_1_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
|
|
827
|
+
case 2: mul_mat_vec_q4_1_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
828
|
+
case 3: mul_mat_vec_q4_1_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
829
|
+
case 4: mul_mat_vec_q4_1_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
830
|
+
case 5: mul_mat_vec_q4_1_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
831
|
+
case 6: mul_mat_vec_q4_1_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
832
|
+
case 7: mul_mat_vec_q4_1_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
833
|
+
case 8: mul_mat_vec_q4_1_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
834
|
+
default: GGML_ABORT("unsupported ncols_dst=%d for Q4_1 multi-col MMVQ", ncols_dst);
|
|
835
|
+
}
|
|
836
|
+
}
|
|
837
|
+
|
|
598
838
|
static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
|
|
599
839
|
dpct::queue_ptr stream) {
|
|
600
840
|
GGML_ASSERT(ncols % QK_MXFP4 == 0);
|
|
@@ -613,6 +853,101 @@ static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float
|
|
|
613
853
|
}
|
|
614
854
|
}
|
|
615
855
|
|
|
856
|
+
template <int ncols_dst>
|
|
857
|
+
static void mul_mat_vec_mxfp4_q8_1_sycl_ncols(
|
|
858
|
+
const void * vx, const void * vy, float * dst,
|
|
859
|
+
const int ncols, const int nrows,
|
|
860
|
+
const int stride_col_y, const int stride_col_dst,
|
|
861
|
+
dpct::queue_ptr stream) {
|
|
862
|
+
GGML_ASSERT(ncols % QK_MXFP4 == 0);
|
|
863
|
+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
864
|
+
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
865
|
+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
866
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
867
|
+
cgh.parallel_for(
|
|
868
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
869
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
870
|
+
mul_mat_vec_q_ncols<QK_MXFP4, QI_MXFP4, block_mxfp4,
|
|
871
|
+
VDR_MXFP4_Q8_1_MMVQ, vec_dot_mxfp4_q8_1, ncols_dst>(
|
|
872
|
+
vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
|
|
873
|
+
});
|
|
874
|
+
});
|
|
875
|
+
}
|
|
876
|
+
|
|
877
|
+
static void mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols(
|
|
878
|
+
const void * vx, const void * vy, float * dst,
|
|
879
|
+
const int ncols, const int nrows, const int ncols_dst,
|
|
880
|
+
const int stride_col_y, const int stride_col_dst,
|
|
881
|
+
dpct::queue_ptr stream) {
|
|
882
|
+
switch (ncols_dst) {
|
|
883
|
+
case 1: mul_mat_vec_mxfp4_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
|
|
884
|
+
case 2: mul_mat_vec_mxfp4_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
885
|
+
case 3: mul_mat_vec_mxfp4_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
886
|
+
case 4: mul_mat_vec_mxfp4_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
887
|
+
case 5: mul_mat_vec_mxfp4_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
888
|
+
case 6: mul_mat_vec_mxfp4_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
889
|
+
case 7: mul_mat_vec_mxfp4_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
890
|
+
case 8: mul_mat_vec_mxfp4_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
891
|
+
default: GGML_ABORT("unsupported ncols_dst=%d for MXFP4 multi-col MMVQ", ncols_dst);
|
|
892
|
+
}
|
|
893
|
+
}
|
|
894
|
+
|
|
895
|
+
static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
|
|
896
|
+
dpct::queue_ptr stream) {
|
|
897
|
+
GGML_ASSERT(ncols % QK_NVFP4 == 0);
|
|
898
|
+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
899
|
+
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
900
|
+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
901
|
+
|
|
902
|
+
{
|
|
903
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
904
|
+
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
905
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
906
|
+
mul_mat_vec_q<QK_NVFP4, QI_NVFP4, block_nvfp4, VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1>(
|
|
907
|
+
vx, vy, dst, ncols, nrows, item_ct1);
|
|
908
|
+
});
|
|
909
|
+
});
|
|
910
|
+
}
|
|
911
|
+
}
|
|
912
|
+
|
|
913
|
+
template <int ncols_dst>
|
|
914
|
+
static void mul_mat_vec_nvfp4_q8_1_sycl_ncols(
|
|
915
|
+
const void * vx, const void * vy, float * dst,
|
|
916
|
+
const int ncols, const int nrows,
|
|
917
|
+
const int stride_col_y, const int stride_col_dst,
|
|
918
|
+
dpct::queue_ptr stream) {
|
|
919
|
+
GGML_ASSERT(ncols % QK_NVFP4 == 0);
|
|
920
|
+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
921
|
+
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
922
|
+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
923
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
924
|
+
cgh.parallel_for(
|
|
925
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
926
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
927
|
+
mul_mat_vec_q_ncols<QK_NVFP4, QI_NVFP4, block_nvfp4,
|
|
928
|
+
VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1, ncols_dst>(
|
|
929
|
+
vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
|
|
930
|
+
});
|
|
931
|
+
});
|
|
932
|
+
}
|
|
933
|
+
|
|
934
|
+
static void mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols(
|
|
935
|
+
const void * vx, const void * vy, float * dst,
|
|
936
|
+
const int ncols, const int nrows, const int ncols_dst,
|
|
937
|
+
const int stride_col_y, const int stride_col_dst,
|
|
938
|
+
dpct::queue_ptr stream) {
|
|
939
|
+
switch (ncols_dst) {
|
|
940
|
+
case 1: mul_mat_vec_nvfp4_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
|
|
941
|
+
case 2: mul_mat_vec_nvfp4_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
942
|
+
case 3: mul_mat_vec_nvfp4_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
943
|
+
case 4: mul_mat_vec_nvfp4_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
944
|
+
case 5: mul_mat_vec_nvfp4_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
945
|
+
case 6: mul_mat_vec_nvfp4_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
946
|
+
case 7: mul_mat_vec_nvfp4_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
947
|
+
case 8: mul_mat_vec_nvfp4_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
948
|
+
default: GGML_ABORT("unsupported ncols_dst=%d for NVFP4 multi-col MMVQ", ncols_dst);
|
|
949
|
+
}
|
|
950
|
+
}
|
|
616
951
|
|
|
617
952
|
static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
|
618
953
|
float *dst, const int ncols,
|
|
@@ -638,6 +973,45 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
|
|
638
973
|
}
|
|
639
974
|
}
|
|
640
975
|
|
|
976
|
+
template <int ncols_dst>
|
|
977
|
+
static void mul_mat_vec_q5_0_q8_1_sycl_ncols(
|
|
978
|
+
const void * vx, const void * vy, float * dst,
|
|
979
|
+
const int ncols, const int nrows,
|
|
980
|
+
const int stride_col_y, const int stride_col_dst,
|
|
981
|
+
dpct::queue_ptr stream) {
|
|
982
|
+
GGML_ASSERT(ncols % QK5_0 == 0);
|
|
983
|
+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
984
|
+
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
985
|
+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
986
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
987
|
+
cgh.parallel_for(
|
|
988
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
989
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
990
|
+
mul_mat_vec_q_ncols<QK5_0, QI5_0, block_q5_0,
|
|
991
|
+
VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1, ncols_dst>(
|
|
992
|
+
vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
|
|
993
|
+
});
|
|
994
|
+
});
|
|
995
|
+
}
|
|
996
|
+
|
|
997
|
+
static void mul_mat_vec_q5_0_q8_1_sycl_switch_ncols(
|
|
998
|
+
const void * vx, const void * vy, float * dst,
|
|
999
|
+
const int ncols, const int nrows, const int ncols_dst,
|
|
1000
|
+
const int stride_col_y, const int stride_col_dst,
|
|
1001
|
+
dpct::queue_ptr stream) {
|
|
1002
|
+
switch (ncols_dst) {
|
|
1003
|
+
case 1: mul_mat_vec_q5_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
|
|
1004
|
+
case 2: mul_mat_vec_q5_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1005
|
+
case 3: mul_mat_vec_q5_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1006
|
+
case 4: mul_mat_vec_q5_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1007
|
+
case 5: mul_mat_vec_q5_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1008
|
+
case 6: mul_mat_vec_q5_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1009
|
+
case 7: mul_mat_vec_q5_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1010
|
+
case 8: mul_mat_vec_q5_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1011
|
+
default: GGML_ABORT("unsupported ncols_dst=%d for Q5_0 multi-col MMVQ", ncols_dst);
|
|
1012
|
+
}
|
|
1013
|
+
}
|
|
1014
|
+
|
|
641
1015
|
static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
|
642
1016
|
float *dst, const int ncols,
|
|
643
1017
|
const int nrows,
|
|
@@ -662,6 +1036,103 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
|
|
662
1036
|
}
|
|
663
1037
|
}
|
|
664
1038
|
|
|
1039
|
+
template <int ncols_dst>
|
|
1040
|
+
static void mul_mat_vec_q5_1_q8_1_sycl_ncols(
|
|
1041
|
+
const void * vx, const void * vy, float * dst,
|
|
1042
|
+
const int ncols, const int nrows,
|
|
1043
|
+
const int stride_col_y, const int stride_col_dst,
|
|
1044
|
+
dpct::queue_ptr stream) {
|
|
1045
|
+
GGML_ASSERT(ncols % QK5_1 == 0);
|
|
1046
|
+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
1047
|
+
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
1048
|
+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
1049
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
1050
|
+
cgh.parallel_for(
|
|
1051
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
1052
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
1053
|
+
mul_mat_vec_q_ncols<QK5_1, QI5_1, block_q5_1,
|
|
1054
|
+
VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1, ncols_dst>(
|
|
1055
|
+
vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
|
|
1056
|
+
});
|
|
1057
|
+
});
|
|
1058
|
+
}
|
|
1059
|
+
|
|
1060
|
+
static void mul_mat_vec_q5_1_q8_1_sycl_switch_ncols(
|
|
1061
|
+
const void * vx, const void * vy, float * dst,
|
|
1062
|
+
const int ncols, const int nrows, const int ncols_dst,
|
|
1063
|
+
const int stride_col_y, const int stride_col_dst,
|
|
1064
|
+
dpct::queue_ptr stream) {
|
|
1065
|
+
switch (ncols_dst) {
|
|
1066
|
+
case 1: mul_mat_vec_q5_1_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
|
|
1067
|
+
case 2: mul_mat_vec_q5_1_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1068
|
+
case 3: mul_mat_vec_q5_1_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1069
|
+
case 4: mul_mat_vec_q5_1_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1070
|
+
case 5: mul_mat_vec_q5_1_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1071
|
+
case 6: mul_mat_vec_q5_1_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1072
|
+
case 7: mul_mat_vec_q5_1_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1073
|
+
case 8: mul_mat_vec_q5_1_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1074
|
+
default: GGML_ABORT("unsupported ncols_dst=%d for Q5_1 multi-col MMVQ", ncols_dst);
|
|
1075
|
+
}
|
|
1076
|
+
}
|
|
1077
|
+
|
|
1078
|
+
static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
|
|
1079
|
+
const int nrows, dpct::queue_ptr stream) {
|
|
1080
|
+
GGML_ASSERT(ncols % QK8_0 == 0);
|
|
1081
|
+
// Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
|
|
1082
|
+
constexpr size_t num_subgroups = WARP_SIZE;
|
|
1083
|
+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
|
|
1084
|
+
|
|
1085
|
+
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
|
|
1086
|
+
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
|
1087
|
+
|
|
1088
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
1089
|
+
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
|
1090
|
+
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
1091
|
+
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0>>(vx, vy, dst, ncols, nrows,
|
|
1092
|
+
nd_item);
|
|
1093
|
+
});
|
|
1094
|
+
});
|
|
1095
|
+
}
|
|
1096
|
+
|
|
1097
|
+
template <int ncols_dst>
|
|
1098
|
+
static void reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols(
|
|
1099
|
+
const void * vx, const void * vy, float * dst,
|
|
1100
|
+
const int ncols, const int nrows,
|
|
1101
|
+
const int stride_col_y_bytes, const int stride_col_dst,
|
|
1102
|
+
dpct::queue_ptr stream) {
|
|
1103
|
+
GGML_ASSERT(ncols % QK8_0 == 0);
|
|
1104
|
+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
|
1105
|
+
constexpr size_t num_subgroups = 16;
|
|
1106
|
+
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
|
1107
|
+
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
|
1108
|
+
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
|
1109
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
1110
|
+
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
|
1111
|
+
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
1112
|
+
mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0>, ncols_dst>(
|
|
1113
|
+
vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
|
|
1114
|
+
});
|
|
1115
|
+
});
|
|
1116
|
+
}
|
|
1117
|
+
|
|
1118
|
+
static void reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols(
|
|
1119
|
+
const void * vx, const void * vy, float * dst,
|
|
1120
|
+
const int ncols, const int nrows, const int ncols_dst,
|
|
1121
|
+
const int stride_col_y_bytes, const int stride_col_dst,
|
|
1122
|
+
dpct::queue_ptr stream) {
|
|
1123
|
+
switch (ncols_dst) {
|
|
1124
|
+
case 1: reorder_mul_mat_vec_q8_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
|
|
1125
|
+
case 2: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1126
|
+
case 3: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1127
|
+
case 4: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1128
|
+
case 5: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1129
|
+
case 6: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1130
|
+
case 7: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1131
|
+
case 8: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1132
|
+
default: GGML_ABORT("unsupported ncols_dst=%d for Q8_0 reorder multi-col MMVQ", ncols_dst);
|
|
1133
|
+
}
|
|
1134
|
+
}
|
|
1135
|
+
|
|
665
1136
|
static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
|
666
1137
|
float *dst, const int ncols,
|
|
667
1138
|
const int nrows,
|
|
@@ -686,6 +1157,45 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
|
|
686
1157
|
}
|
|
687
1158
|
}
|
|
688
1159
|
|
|
1160
|
+
template <int ncols_dst>
|
|
1161
|
+
static void mul_mat_vec_q8_0_q8_1_sycl_ncols(
|
|
1162
|
+
const void * vx, const void * vy, float * dst,
|
|
1163
|
+
const int ncols, const int nrows,
|
|
1164
|
+
const int stride_col_y, const int stride_col_dst,
|
|
1165
|
+
dpct::queue_ptr stream) {
|
|
1166
|
+
GGML_ASSERT(ncols % QK8_0 == 0);
|
|
1167
|
+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
1168
|
+
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
1169
|
+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
1170
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
1171
|
+
cgh.parallel_for(
|
|
1172
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
1173
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
1174
|
+
mul_mat_vec_q_ncols<QK8_0, QI8_0, block_q8_0,
|
|
1175
|
+
VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1, ncols_dst>(
|
|
1176
|
+
vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
|
|
1177
|
+
});
|
|
1178
|
+
});
|
|
1179
|
+
}
|
|
1180
|
+
|
|
1181
|
+
static void mul_mat_vec_q8_0_q8_1_sycl_switch_ncols(
|
|
1182
|
+
const void * vx, const void * vy, float * dst,
|
|
1183
|
+
const int ncols, const int nrows, const int ncols_dst,
|
|
1184
|
+
const int stride_col_y, const int stride_col_dst,
|
|
1185
|
+
dpct::queue_ptr stream) {
|
|
1186
|
+
switch (ncols_dst) {
|
|
1187
|
+
case 1: mul_mat_vec_q8_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
|
|
1188
|
+
case 2: mul_mat_vec_q8_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1189
|
+
case 3: mul_mat_vec_q8_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1190
|
+
case 4: mul_mat_vec_q8_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1191
|
+
case 5: mul_mat_vec_q8_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1192
|
+
case 6: mul_mat_vec_q8_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1193
|
+
case 7: mul_mat_vec_q8_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1194
|
+
case 8: mul_mat_vec_q8_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1195
|
+
default: GGML_ABORT("unsupported ncols_dst=%d for Q8_0 multi-col MMVQ", ncols_dst);
|
|
1196
|
+
}
|
|
1197
|
+
}
|
|
1198
|
+
|
|
689
1199
|
static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
|
690
1200
|
float *dst, const int ncols,
|
|
691
1201
|
const int nrows,
|
|
@@ -710,6 +1220,45 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
|
|
710
1220
|
}
|
|
711
1221
|
}
|
|
712
1222
|
|
|
1223
|
+
template <int ncols_dst>
|
|
1224
|
+
static void mul_mat_vec_q2_K_q8_1_sycl_ncols(
|
|
1225
|
+
const void * vx, const void * vy, float * dst,
|
|
1226
|
+
const int ncols, const int nrows,
|
|
1227
|
+
const int stride_col_y, const int stride_col_dst,
|
|
1228
|
+
dpct::queue_ptr stream) {
|
|
1229
|
+
GGML_ASSERT(ncols % QK_K == 0);
|
|
1230
|
+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
1231
|
+
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
1232
|
+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
1233
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
1234
|
+
cgh.parallel_for(
|
|
1235
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
1236
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
1237
|
+
mul_mat_vec_q_ncols<QK_K, QI2_K, block_q2_K,
|
|
1238
|
+
VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1, ncols_dst>(
|
|
1239
|
+
vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
|
|
1240
|
+
});
|
|
1241
|
+
});
|
|
1242
|
+
}
|
|
1243
|
+
|
|
1244
|
+
static void mul_mat_vec_q2_K_q8_1_sycl_switch_ncols(
|
|
1245
|
+
const void * vx, const void * vy, float * dst,
|
|
1246
|
+
const int ncols, const int nrows, const int ncols_dst,
|
|
1247
|
+
const int stride_col_y, const int stride_col_dst,
|
|
1248
|
+
dpct::queue_ptr stream) {
|
|
1249
|
+
switch (ncols_dst) {
|
|
1250
|
+
case 1: mul_mat_vec_q2_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
|
|
1251
|
+
case 2: mul_mat_vec_q2_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1252
|
+
case 3: mul_mat_vec_q2_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1253
|
+
case 4: mul_mat_vec_q2_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1254
|
+
case 5: mul_mat_vec_q2_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1255
|
+
case 6: mul_mat_vec_q2_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1256
|
+
case 7: mul_mat_vec_q2_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1257
|
+
case 8: mul_mat_vec_q2_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1258
|
+
default: GGML_ABORT("unsupported ncols_dst=%d for Q2_K multi-col MMVQ", ncols_dst);
|
|
1259
|
+
}
|
|
1260
|
+
}
|
|
1261
|
+
|
|
713
1262
|
static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
|
714
1263
|
float *dst, const int ncols,
|
|
715
1264
|
const int nrows,
|
|
@@ -734,6 +1283,105 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
|
|
734
1283
|
}
|
|
735
1284
|
}
|
|
736
1285
|
|
|
1286
|
+
static void reorder_mul_mat_vec_q3_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
|
|
1287
|
+
const int nrows, dpct::queue_ptr stream) {
|
|
1288
|
+
GGML_ASSERT(ncols % QK_K == 0);
|
|
1289
|
+
|
|
1290
|
+
// Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
|
|
1291
|
+
constexpr size_t num_subgroups = WARP_SIZE;
|
|
1292
|
+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
|
|
1293
|
+
|
|
1294
|
+
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
|
1295
|
+
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
|
1296
|
+
|
|
1297
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
1298
|
+
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
|
1299
|
+
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
1300
|
+
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K>>(vx, vy, dst, ncols, nrows,
|
|
1301
|
+
nd_item);
|
|
1302
|
+
});
|
|
1303
|
+
});
|
|
1304
|
+
}
|
|
1305
|
+
|
|
1306
|
+
template <int ncols_dst>
|
|
1307
|
+
static void reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols(
|
|
1308
|
+
const void * vx, const void * vy, float * dst,
|
|
1309
|
+
const int ncols, const int nrows,
|
|
1310
|
+
const int stride_col_y_bytes, const int stride_col_dst,
|
|
1311
|
+
dpct::queue_ptr stream) {
|
|
1312
|
+
GGML_ASSERT(ncols % QK_K == 0);
|
|
1313
|
+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
|
1314
|
+
constexpr size_t num_subgroups = 16;
|
|
1315
|
+
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
|
1316
|
+
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
|
1317
|
+
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
|
1318
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
1319
|
+
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
|
1320
|
+
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
1321
|
+
mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K>, ncols_dst>(
|
|
1322
|
+
vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
|
|
1323
|
+
});
|
|
1324
|
+
});
|
|
1325
|
+
}
|
|
1326
|
+
|
|
1327
|
+
static void reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols(
|
|
1328
|
+
const void * vx, const void * vy, float * dst,
|
|
1329
|
+
const int ncols, const int nrows, const int ncols_dst,
|
|
1330
|
+
const int stride_col_y_bytes, const int stride_col_dst,
|
|
1331
|
+
dpct::queue_ptr stream) {
|
|
1332
|
+
switch (ncols_dst) {
|
|
1333
|
+
case 1: reorder_mul_mat_vec_q3_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
|
|
1334
|
+
case 2: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1335
|
+
case 3: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1336
|
+
case 4: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1337
|
+
case 5: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1338
|
+
case 6: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1339
|
+
case 7: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1340
|
+
case 8: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1341
|
+
default: GGML_ABORT("unsupported ncols_dst=%d for Q3_K reorder multi-col MMVQ", ncols_dst);
|
|
1342
|
+
}
|
|
1343
|
+
}
|
|
1344
|
+
|
|
1345
|
+
template <int ncols_dst>
|
|
1346
|
+
static void mul_mat_vec_q3_K_q8_1_sycl_ncols(
|
|
1347
|
+
const void * vx, const void * vy, float * dst,
|
|
1348
|
+
const int ncols, const int nrows,
|
|
1349
|
+
const int stride_col_y, const int stride_col_dst,
|
|
1350
|
+
dpct::queue_ptr stream) {
|
|
1351
|
+
GGML_ASSERT(ncols % QK_K == 0);
|
|
1352
|
+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
1353
|
+
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
1354
|
+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
1355
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
1356
|
+
cgh.parallel_for(
|
|
1357
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
1358
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
1359
|
+
mul_mat_vec_q_ncols<QK_K, QI3_K, block_q3_K,
|
|
1360
|
+
VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1, ncols_dst>(
|
|
1361
|
+
vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
|
|
1362
|
+
});
|
|
1363
|
+
});
|
|
1364
|
+
}
|
|
1365
|
+
|
|
1366
|
+
static void mul_mat_vec_q3_K_q8_1_sycl_switch_ncols(
|
|
1367
|
+
const void * vx, const void * vy, float * dst,
|
|
1368
|
+
const int ncols, const int nrows, const int ncols_dst,
|
|
1369
|
+
const int stride_col_y, const int stride_col_dst,
|
|
1370
|
+
dpct::queue_ptr stream) {
|
|
1371
|
+
switch (ncols_dst) {
|
|
1372
|
+
case 1: mul_mat_vec_q3_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
|
|
1373
|
+
case 2: mul_mat_vec_q3_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1374
|
+
case 3: mul_mat_vec_q3_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1375
|
+
case 4: mul_mat_vec_q3_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1376
|
+
case 5: mul_mat_vec_q3_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1377
|
+
case 6: mul_mat_vec_q3_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1378
|
+
case 7: mul_mat_vec_q3_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1379
|
+
case 8: mul_mat_vec_q3_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1380
|
+
default: GGML_ABORT("unsupported ncols_dst=%d for Q3_K multi-col MMVQ", ncols_dst);
|
|
1381
|
+
}
|
|
1382
|
+
}
|
|
1383
|
+
|
|
1384
|
+
|
|
737
1385
|
static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
|
738
1386
|
float *dst, const int ncols,
|
|
739
1387
|
const int nrows,
|
|
@@ -758,13 +1406,58 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
|
|
758
1406
|
}
|
|
759
1407
|
}
|
|
760
1408
|
|
|
1409
|
+
template <int ncols_dst>
|
|
1410
|
+
static void mul_mat_vec_q4_K_q8_1_sycl_ncols(
|
|
1411
|
+
const void * vx, const void * vy, float * dst,
|
|
1412
|
+
const int ncols, const int nrows,
|
|
1413
|
+
const int stride_col_y, const int stride_col_dst,
|
|
1414
|
+
dpct::queue_ptr stream) {
|
|
1415
|
+
GGML_ASSERT(ncols % QK_K == 0);
|
|
1416
|
+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
1417
|
+
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
1418
|
+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
1419
|
+
|
|
1420
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
1421
|
+
cgh.parallel_for(
|
|
1422
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
1423
|
+
[=](sycl::nd_item<3> item_ct1)
|
|
1424
|
+
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
1425
|
+
mul_mat_vec_q_ncols<QK_K, QI4_K, block_q4_K,
|
|
1426
|
+
VDR_Q4_K_Q8_1_MMVQ,
|
|
1427
|
+
vec_dot_q4_K_q8_1,
|
|
1428
|
+
ncols_dst>(
|
|
1429
|
+
vx, vy, dst, ncols, nrows,
|
|
1430
|
+
stride_col_y, stride_col_dst, item_ct1);
|
|
1431
|
+
});
|
|
1432
|
+
});
|
|
1433
|
+
}
|
|
1434
|
+
|
|
1435
|
+
static void mul_mat_vec_q4_K_q8_1_sycl_switch_ncols(
|
|
1436
|
+
const void * vx, const void * vy, float * dst,
|
|
1437
|
+
const int ncols, const int nrows,
|
|
1438
|
+
const int ncols_dst,
|
|
1439
|
+
const int stride_col_y, const int stride_col_dst,
|
|
1440
|
+
dpct::queue_ptr stream) {
|
|
1441
|
+
switch (ncols_dst) {
|
|
1442
|
+
case 1: mul_mat_vec_q4_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
|
|
1443
|
+
case 2: mul_mat_vec_q4_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1444
|
+
case 3: mul_mat_vec_q4_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1445
|
+
case 4: mul_mat_vec_q4_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1446
|
+
case 5: mul_mat_vec_q4_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1447
|
+
case 6: mul_mat_vec_q4_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1448
|
+
case 7: mul_mat_vec_q4_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1449
|
+
case 8: mul_mat_vec_q4_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1450
|
+
default: GGML_ABORT("unsupported ncols_dst=%d for Q4_K multi-col MMVQ", ncols_dst);
|
|
1451
|
+
}
|
|
1452
|
+
}
|
|
1453
|
+
|
|
761
1454
|
static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
|
|
762
1455
|
const int nrows, dpct::queue_ptr stream) {
|
|
763
1456
|
GGML_ASSERT(ncols % QK_K == 0);
|
|
764
1457
|
|
|
765
|
-
|
|
766
|
-
constexpr size_t num_subgroups =
|
|
767
|
-
|
|
1458
|
+
// Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
|
|
1459
|
+
constexpr size_t num_subgroups = WARP_SIZE;
|
|
1460
|
+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
|
|
768
1461
|
|
|
769
1462
|
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
|
770
1463
|
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
|
@@ -778,6 +1471,44 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy,
|
|
|
778
1471
|
});
|
|
779
1472
|
}
|
|
780
1473
|
|
|
1474
|
+
template <int ncols_dst>
|
|
1475
|
+
static void reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols(
|
|
1476
|
+
const void * vx, const void * vy, float * dst,
|
|
1477
|
+
const int ncols, const int nrows,
|
|
1478
|
+
const int stride_col_y_bytes, const int stride_col_dst,
|
|
1479
|
+
dpct::queue_ptr stream) {
|
|
1480
|
+
GGML_ASSERT(ncols % QK_K == 0);
|
|
1481
|
+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
|
1482
|
+
constexpr size_t num_subgroups = 16;
|
|
1483
|
+
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
|
1484
|
+
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
|
1485
|
+
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
|
1486
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
1487
|
+
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
|
1488
|
+
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
1489
|
+
mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>, ncols_dst>(
|
|
1490
|
+
vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
|
|
1491
|
+
});
|
|
1492
|
+
});
|
|
1493
|
+
}
|
|
1494
|
+
|
|
1495
|
+
static void reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols(
|
|
1496
|
+
const void * vx, const void * vy, float * dst,
|
|
1497
|
+
const int ncols, const int nrows, const int ncols_dst,
|
|
1498
|
+
const int stride_col_y_bytes, const int stride_col_dst,
|
|
1499
|
+
dpct::queue_ptr stream) {
|
|
1500
|
+
switch (ncols_dst) {
|
|
1501
|
+
case 1: reorder_mul_mat_vec_q4_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
|
|
1502
|
+
case 2: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1503
|
+
case 3: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1504
|
+
case 4: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1505
|
+
case 5: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1506
|
+
case 6: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1507
|
+
case 7: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1508
|
+
case 8: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1509
|
+
default: GGML_ABORT("unsupported ncols_dst=%d for Q4_K reorder multi-col MMVQ", ncols_dst);
|
|
1510
|
+
}
|
|
1511
|
+
}
|
|
781
1512
|
|
|
782
1513
|
static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
|
783
1514
|
float *dst, const int ncols,
|
|
@@ -803,9 +1534,55 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
|
|
803
1534
|
}
|
|
804
1535
|
}
|
|
805
1536
|
|
|
806
|
-
|
|
1537
|
+
template <int ncols_dst>
|
|
1538
|
+
static void mul_mat_vec_q5_K_q8_1_sycl_ncols(
|
|
1539
|
+
const void * vx, const void * vy, float * dst,
|
|
1540
|
+
const int ncols, const int nrows,
|
|
1541
|
+
const int stride_col_y, const int stride_col_dst,
|
|
1542
|
+
dpct::queue_ptr stream) {
|
|
1543
|
+
GGML_ASSERT(ncols % QK_K == 0);
|
|
1544
|
+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
1545
|
+
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
1546
|
+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
1547
|
+
|
|
1548
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
1549
|
+
cgh.parallel_for(
|
|
1550
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
1551
|
+
[=](sycl::nd_item<3> item_ct1)
|
|
1552
|
+
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
1553
|
+
mul_mat_vec_q_ncols<QK_K, QI5_K, block_q5_K,
|
|
1554
|
+
VDR_Q5_K_Q8_1_MMVQ,
|
|
1555
|
+
vec_dot_q5_K_q8_1,
|
|
1556
|
+
ncols_dst>(
|
|
1557
|
+
vx, vy, dst, ncols, nrows,
|
|
1558
|
+
stride_col_y, stride_col_dst, item_ct1);
|
|
1559
|
+
});
|
|
1560
|
+
});
|
|
1561
|
+
}
|
|
1562
|
+
|
|
1563
|
+
static void mul_mat_vec_q5_K_q8_1_sycl_switch_ncols(
|
|
1564
|
+
const void * vx, const void * vy, float * dst,
|
|
1565
|
+
const int ncols, const int nrows,
|
|
1566
|
+
const int ncols_dst,
|
|
1567
|
+
const int stride_col_y, const int stride_col_dst,
|
|
1568
|
+
dpct::queue_ptr stream) {
|
|
1569
|
+
switch (ncols_dst) {
|
|
1570
|
+
case 1: mul_mat_vec_q5_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
|
|
1571
|
+
case 2: mul_mat_vec_q5_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1572
|
+
case 3: mul_mat_vec_q5_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1573
|
+
case 4: mul_mat_vec_q5_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1574
|
+
case 5: mul_mat_vec_q5_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1575
|
+
case 6: mul_mat_vec_q5_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1576
|
+
case 7: mul_mat_vec_q5_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1577
|
+
case 8: mul_mat_vec_q5_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1578
|
+
default: GGML_ABORT("unsupported ncols_dst=%d for Q5_K multi-col MMVQ", ncols_dst);
|
|
1579
|
+
}
|
|
1580
|
+
}
|
|
1581
|
+
|
|
1582
|
+
static void reorder_mul_mat_vec_q5_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
|
|
807
1583
|
const int nrows, dpct::queue_ptr stream) {
|
|
808
1584
|
GGML_ASSERT(ncols % QK_K == 0);
|
|
1585
|
+
|
|
809
1586
|
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
|
810
1587
|
constexpr size_t num_subgroups = 16;
|
|
811
1588
|
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
|
@@ -813,6 +1590,64 @@ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy,
|
|
|
813
1590
|
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
|
814
1591
|
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
|
815
1592
|
|
|
1593
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
1594
|
+
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
|
1595
|
+
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
1596
|
+
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q5_K>>(vx, vy, dst, ncols,
|
|
1597
|
+
nrows, nd_item);
|
|
1598
|
+
});
|
|
1599
|
+
});
|
|
1600
|
+
}
|
|
1601
|
+
|
|
1602
|
+
template <int ncols_dst>
|
|
1603
|
+
static void reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols(
|
|
1604
|
+
const void * vx, const void * vy, float * dst,
|
|
1605
|
+
const int ncols, const int nrows,
|
|
1606
|
+
const int stride_col_y_bytes, const int stride_col_dst,
|
|
1607
|
+
dpct::queue_ptr stream) {
|
|
1608
|
+
GGML_ASSERT(ncols % QK_K == 0);
|
|
1609
|
+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
|
1610
|
+
constexpr size_t num_subgroups = 16;
|
|
1611
|
+
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
|
1612
|
+
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
|
1613
|
+
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
|
1614
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
1615
|
+
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
|
1616
|
+
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
1617
|
+
mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q5_K>, ncols_dst>(
|
|
1618
|
+
vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
|
|
1619
|
+
});
|
|
1620
|
+
});
|
|
1621
|
+
}
|
|
1622
|
+
|
|
1623
|
+
static void reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols(
|
|
1624
|
+
const void * vx, const void * vy, float * dst,
|
|
1625
|
+
const int ncols, const int nrows, const int ncols_dst,
|
|
1626
|
+
const int stride_col_y_bytes, const int stride_col_dst,
|
|
1627
|
+
dpct::queue_ptr stream) {
|
|
1628
|
+
switch (ncols_dst) {
|
|
1629
|
+
case 1: reorder_mul_mat_vec_q5_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
|
|
1630
|
+
case 2: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1631
|
+
case 3: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1632
|
+
case 4: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1633
|
+
case 5: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1634
|
+
case 6: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1635
|
+
case 7: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1636
|
+
case 8: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1637
|
+
default: GGML_ABORT("unsupported ncols_dst=%d for Q5_K reorder multi-col MMVQ", ncols_dst);
|
|
1638
|
+
}
|
|
1639
|
+
}
|
|
1640
|
+
|
|
1641
|
+
static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
|
|
1642
|
+
const int nrows, dpct::queue_ptr stream) {
|
|
1643
|
+
GGML_ASSERT(ncols % QK_K == 0);
|
|
1644
|
+
// Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
|
|
1645
|
+
constexpr size_t num_subgroups = WARP_SIZE;
|
|
1646
|
+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
|
|
1647
|
+
|
|
1648
|
+
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
|
1649
|
+
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
|
1650
|
+
|
|
816
1651
|
stream->submit([&](sycl::handler & cgh) {
|
|
817
1652
|
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
|
818
1653
|
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
@@ -821,6 +1656,46 @@ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy,
|
|
|
821
1656
|
});
|
|
822
1657
|
});
|
|
823
1658
|
}
|
|
1659
|
+
|
|
1660
|
+
template <int ncols_dst>
|
|
1661
|
+
static void reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols(
|
|
1662
|
+
const void * vx, const void * vy, float * dst,
|
|
1663
|
+
const int ncols, const int nrows,
|
|
1664
|
+
const int stride_col_y_bytes, const int stride_col_dst,
|
|
1665
|
+
dpct::queue_ptr stream) {
|
|
1666
|
+
GGML_ASSERT(ncols % QK_K == 0);
|
|
1667
|
+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
|
1668
|
+
constexpr size_t num_subgroups = 16;
|
|
1669
|
+
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
|
1670
|
+
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
|
1671
|
+
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
|
1672
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
1673
|
+
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
|
1674
|
+
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
1675
|
+
mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>, ncols_dst>(
|
|
1676
|
+
vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
|
|
1677
|
+
});
|
|
1678
|
+
});
|
|
1679
|
+
}
|
|
1680
|
+
|
|
1681
|
+
static void reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols(
|
|
1682
|
+
const void * vx, const void * vy, float * dst,
|
|
1683
|
+
const int ncols, const int nrows, const int ncols_dst,
|
|
1684
|
+
const int stride_col_y_bytes, const int stride_col_dst,
|
|
1685
|
+
dpct::queue_ptr stream) {
|
|
1686
|
+
switch (ncols_dst) {
|
|
1687
|
+
case 1: reorder_mul_mat_vec_q6_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
|
|
1688
|
+
case 2: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1689
|
+
case 3: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1690
|
+
case 4: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1691
|
+
case 5: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1692
|
+
case 6: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1693
|
+
case 7: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1694
|
+
case 8: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
|
|
1695
|
+
default: GGML_ABORT("unsupported ncols_dst=%d for Q6_K reorder multi-col MMVQ", ncols_dst);
|
|
1696
|
+
}
|
|
1697
|
+
}
|
|
1698
|
+
|
|
824
1699
|
static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
|
825
1700
|
float *dst, const int ncols,
|
|
826
1701
|
const int nrows,
|
|
@@ -845,6 +1720,51 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
|
|
845
1720
|
}
|
|
846
1721
|
}
|
|
847
1722
|
|
|
1723
|
+
template <int ncols_dst>
|
|
1724
|
+
static void mul_mat_vec_q6_K_q8_1_sycl_ncols(
|
|
1725
|
+
const void * vx, const void * vy, float * dst,
|
|
1726
|
+
const int ncols, const int nrows,
|
|
1727
|
+
const int stride_col_y, const int stride_col_dst,
|
|
1728
|
+
dpct::queue_ptr stream) {
|
|
1729
|
+
GGML_ASSERT(ncols % QK_K == 0);
|
|
1730
|
+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
1731
|
+
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
1732
|
+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
1733
|
+
|
|
1734
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
1735
|
+
cgh.parallel_for(
|
|
1736
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
1737
|
+
[=](sycl::nd_item<3> item_ct1)
|
|
1738
|
+
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
1739
|
+
mul_mat_vec_q_ncols<QK_K, QI6_K, block_q6_K,
|
|
1740
|
+
VDR_Q6_K_Q8_1_MMVQ,
|
|
1741
|
+
vec_dot_q6_K_q8_1,
|
|
1742
|
+
ncols_dst>(
|
|
1743
|
+
vx, vy, dst, ncols, nrows,
|
|
1744
|
+
stride_col_y, stride_col_dst, item_ct1);
|
|
1745
|
+
});
|
|
1746
|
+
});
|
|
1747
|
+
}
|
|
1748
|
+
|
|
1749
|
+
static void mul_mat_vec_q6_K_q8_1_sycl_switch_ncols(
|
|
1750
|
+
const void * vx, const void * vy, float * dst,
|
|
1751
|
+
const int ncols, const int nrows,
|
|
1752
|
+
const int ncols_dst,
|
|
1753
|
+
const int stride_col_y, const int stride_col_dst,
|
|
1754
|
+
dpct::queue_ptr stream) {
|
|
1755
|
+
switch (ncols_dst) {
|
|
1756
|
+
case 1: mul_mat_vec_q6_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
|
|
1757
|
+
case 2: mul_mat_vec_q6_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1758
|
+
case 3: mul_mat_vec_q6_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1759
|
+
case 4: mul_mat_vec_q6_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1760
|
+
case 5: mul_mat_vec_q6_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1761
|
+
case 6: mul_mat_vec_q6_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1762
|
+
case 7: mul_mat_vec_q6_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1763
|
+
case 8: mul_mat_vec_q6_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1764
|
+
default: GGML_ABORT("unsupported ncols_dst=%d for Q6_K multi-col MMVQ", ncols_dst);
|
|
1765
|
+
}
|
|
1766
|
+
}
|
|
1767
|
+
|
|
848
1768
|
|
|
849
1769
|
static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
|
|
850
1770
|
float *dst, const int ncols,
|
|
@@ -1041,6 +1961,51 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
|
|
|
1041
1961
|
}
|
|
1042
1962
|
}
|
|
1043
1963
|
|
|
1964
|
+
template <int ncols_dst>
|
|
1965
|
+
static void mul_mat_vec_iq4_xs_q8_1_sycl_ncols(
|
|
1966
|
+
const void * vx, const void * vy, float * dst,
|
|
1967
|
+
const int ncols, const int nrows,
|
|
1968
|
+
const int stride_col_y, const int stride_col_dst,
|
|
1969
|
+
dpct::queue_ptr stream) {
|
|
1970
|
+
GGML_ASSERT(ncols % QK_K == 0);
|
|
1971
|
+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
1972
|
+
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
1973
|
+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
1974
|
+
|
|
1975
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
1976
|
+
cgh.parallel_for(
|
|
1977
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
1978
|
+
[=](sycl::nd_item<3> item_ct1)
|
|
1979
|
+
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
1980
|
+
mul_mat_vec_q_ncols<QK_K, QI4_XS/4, block_iq4_xs,
|
|
1981
|
+
1,
|
|
1982
|
+
vec_dot_iq4_xs_q8_1,
|
|
1983
|
+
ncols_dst>(
|
|
1984
|
+
vx, vy, dst, ncols, nrows,
|
|
1985
|
+
stride_col_y, stride_col_dst, item_ct1);
|
|
1986
|
+
});
|
|
1987
|
+
});
|
|
1988
|
+
}
|
|
1989
|
+
|
|
1990
|
+
static void mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols(
|
|
1991
|
+
const void * vx, const void * vy, float * dst,
|
|
1992
|
+
const int ncols, const int nrows,
|
|
1993
|
+
const int ncols_dst,
|
|
1994
|
+
const int stride_col_y, const int stride_col_dst,
|
|
1995
|
+
dpct::queue_ptr stream) {
|
|
1996
|
+
switch (ncols_dst) {
|
|
1997
|
+
case 1: mul_mat_vec_iq4_xs_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
|
|
1998
|
+
case 2: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
1999
|
+
case 3: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
2000
|
+
case 4: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
2001
|
+
case 5: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
2002
|
+
case 6: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
2003
|
+
case 7: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
2004
|
+
case 8: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
|
|
2005
|
+
default: GGML_ABORT("unsupported ncols_dst=%d for IQ4_XS multi-col MMVQ", ncols_dst);
|
|
2006
|
+
}
|
|
2007
|
+
}
|
|
2008
|
+
|
|
1044
2009
|
void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
|
|
1045
2010
|
ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
|
1046
2011
|
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low,
|
|
@@ -1067,50 +2032,219 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
|
|
|
1067
2032
|
case GGML_TYPE_Q4_0:
|
|
1068
2033
|
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
|
1069
2034
|
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
|
|
2035
|
+
if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
|
|
2036
|
+
const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs;
|
|
2037
|
+
const int stride_col_dst = dst->ne[0];
|
|
2038
|
+
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
|
|
2039
|
+
reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols(
|
|
2040
|
+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
|
|
2041
|
+
src1_ncols, stride_col_y_bytes, stride_col_dst, stream);
|
|
2042
|
+
return;
|
|
2043
|
+
} else {
|
|
2044
|
+
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n");
|
|
2045
|
+
reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
2046
|
+
}
|
|
2047
|
+
} else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
|
|
2048
|
+
const int stride_col_y = src1_padded_col_size / QK8_1;
|
|
2049
|
+
const int stride_col_dst = dst->ne[0];
|
|
2050
|
+
GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
|
|
2051
|
+
mul_mat_vec_q4_0_q8_1_sycl_switch_ncols(
|
|
2052
|
+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
|
|
2053
|
+
src1_ncols, stride_col_y, stride_col_dst, stream);
|
|
2054
|
+
return;
|
|
2055
|
+
} else if (i == 0 || src1_ncols == 1) {
|
|
1073
2056
|
GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n");
|
|
1074
2057
|
mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1075
2058
|
}
|
|
1076
2059
|
break;
|
|
1077
2060
|
case GGML_TYPE_Q4_1:
|
|
1078
|
-
|
|
2061
|
+
if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
|
|
2062
|
+
const int stride_col_y = src1_padded_col_size / QK8_1;
|
|
2063
|
+
const int stride_col_dst = dst->ne[0];
|
|
2064
|
+
GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_1_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
|
|
2065
|
+
mul_mat_vec_q4_1_q8_1_sycl_switch_ncols(
|
|
2066
|
+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
|
|
2067
|
+
src1_ncols, stride_col_y, stride_col_dst, stream);
|
|
2068
|
+
return;
|
|
2069
|
+
} else if (i == 0 || src1_ncols == 1) {
|
|
2070
|
+
mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
2071
|
+
}
|
|
1079
2072
|
break;
|
|
1080
2073
|
case GGML_TYPE_Q5_0:
|
|
1081
|
-
|
|
2074
|
+
if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
|
|
2075
|
+
const int stride_col_y = src1_padded_col_size / QK8_1;
|
|
2076
|
+
const int stride_col_dst = dst->ne[0];
|
|
2077
|
+
GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
|
|
2078
|
+
mul_mat_vec_q5_0_q8_1_sycl_switch_ncols(
|
|
2079
|
+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
|
|
2080
|
+
src1_ncols, stride_col_y, stride_col_dst, stream);
|
|
2081
|
+
return;
|
|
2082
|
+
} else if (i == 0 || src1_ncols == 1) {
|
|
2083
|
+
mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
2084
|
+
}
|
|
1082
2085
|
break;
|
|
1083
2086
|
case GGML_TYPE_Q5_1:
|
|
1084
|
-
|
|
2087
|
+
if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
|
|
2088
|
+
const int stride_col_y = src1_padded_col_size / QK8_1;
|
|
2089
|
+
const int stride_col_dst = dst->ne[0];
|
|
2090
|
+
GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_1_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
|
|
2091
|
+
mul_mat_vec_q5_1_q8_1_sycl_switch_ncols(
|
|
2092
|
+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
|
|
2093
|
+
src1_ncols, stride_col_y, stride_col_dst, stream);
|
|
2094
|
+
return;
|
|
2095
|
+
} else if (i == 0 || src1_ncols == 1) {
|
|
2096
|
+
mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
2097
|
+
}
|
|
1085
2098
|
break;
|
|
1086
2099
|
case GGML_TYPE_Q8_0:
|
|
1087
|
-
|
|
2100
|
+
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
|
2101
|
+
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
|
2102
|
+
if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
|
|
2103
|
+
const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs;
|
|
2104
|
+
const int stride_col_dst = dst->ne[0];
|
|
2105
|
+
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
|
|
2106
|
+
reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols(
|
|
2107
|
+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
|
|
2108
|
+
src1_ncols, stride_col_y_bytes, stride_col_dst, stream);
|
|
2109
|
+
return;
|
|
2110
|
+
} else {
|
|
2111
|
+
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n");
|
|
2112
|
+
reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
2113
|
+
}
|
|
2114
|
+
} else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
|
|
2115
|
+
const int stride_col_y = src1_padded_col_size / QK8_1;
|
|
2116
|
+
const int stride_col_dst = dst->ne[0];
|
|
2117
|
+
GGML_SYCL_DEBUG("Calling mul_mat_vec_q8_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
|
|
2118
|
+
mul_mat_vec_q8_0_q8_1_sycl_switch_ncols(
|
|
2119
|
+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
|
|
2120
|
+
src1_ncols, stride_col_y, stride_col_dst, stream);
|
|
2121
|
+
return;
|
|
2122
|
+
} else if (i == 0 || src1_ncols == 1) {
|
|
2123
|
+
GGML_SYCL_DEBUG("Calling mul_mat_vec_q8_0_q8_1_sycl\n");
|
|
2124
|
+
mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
2125
|
+
}
|
|
1088
2126
|
break;
|
|
1089
2127
|
case GGML_TYPE_Q2_K:
|
|
1090
|
-
|
|
2128
|
+
if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
|
|
2129
|
+
const int stride_col_y = src1_padded_col_size / QK8_1;
|
|
2130
|
+
const int stride_col_dst = dst->ne[0];
|
|
2131
|
+
GGML_SYCL_DEBUG("Calling mul_mat_vec_q2_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
|
|
2132
|
+
mul_mat_vec_q2_K_q8_1_sycl_switch_ncols(
|
|
2133
|
+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
|
|
2134
|
+
src1_ncols, stride_col_y, stride_col_dst, stream);
|
|
2135
|
+
return;
|
|
2136
|
+
} else if (i == 0 || src1_ncols == 1) {
|
|
2137
|
+
mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
2138
|
+
}
|
|
1091
2139
|
break;
|
|
1092
2140
|
case GGML_TYPE_Q3_K:
|
|
1093
|
-
|
|
2141
|
+
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
|
2142
|
+
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
|
2143
|
+
if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
|
|
2144
|
+
const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs;
|
|
2145
|
+
const int stride_col_dst = dst->ne[0];
|
|
2146
|
+
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
|
|
2147
|
+
reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols(
|
|
2148
|
+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
|
|
2149
|
+
src1_ncols, stride_col_y_bytes, stride_col_dst, stream);
|
|
2150
|
+
return;
|
|
2151
|
+
} else {
|
|
2152
|
+
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q3_k_q8_1_sycl\n");
|
|
2153
|
+
reorder_mul_mat_vec_q3_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
2154
|
+
}
|
|
2155
|
+
} else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
|
|
2156
|
+
const int stride_col_y = src1_padded_col_size / QK8_1;
|
|
2157
|
+
const int stride_col_dst = dst->ne[0];
|
|
2158
|
+
GGML_SYCL_DEBUG("Calling mul_mat_vec_q3_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
|
|
2159
|
+
mul_mat_vec_q3_K_q8_1_sycl_switch_ncols(
|
|
2160
|
+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
|
|
2161
|
+
src1_ncols, stride_col_y, stride_col_dst, stream);
|
|
2162
|
+
return;
|
|
2163
|
+
} else if (i == 0 || src1_ncols == 1) {
|
|
2164
|
+
GGML_SYCL_DEBUG("Calling mul_mat_vec_q3_K_q8_1_sycl\n");
|
|
2165
|
+
mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
2166
|
+
}
|
|
1094
2167
|
break;
|
|
1095
2168
|
case GGML_TYPE_Q4_K:
|
|
1096
2169
|
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
|
1097
2170
|
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
|
1098
|
-
|
|
1099
|
-
|
|
1100
|
-
|
|
2171
|
+
if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
|
|
2172
|
+
const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs;
|
|
2173
|
+
const int stride_col_dst = dst->ne[0];
|
|
2174
|
+
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
|
|
2175
|
+
reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols(
|
|
2176
|
+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
|
|
2177
|
+
src1_ncols, stride_col_y_bytes, stride_col_dst, stream);
|
|
2178
|
+
return;
|
|
2179
|
+
} else {
|
|
2180
|
+
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl\n");
|
|
2181
|
+
reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
2182
|
+
}
|
|
2183
|
+
} else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
|
|
2184
|
+
const int stride_col_y = src1_padded_col_size / QK8_1;
|
|
2185
|
+
const int stride_col_dst = dst->ne[0];
|
|
2186
|
+
GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
|
|
2187
|
+
mul_mat_vec_q4_K_q8_1_sycl_switch_ncols(
|
|
2188
|
+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
|
|
2189
|
+
src1_ncols, stride_col_y, stride_col_dst, stream);
|
|
2190
|
+
return;
|
|
2191
|
+
} else if (i == 0 || src1_ncols == 1) {
|
|
1101
2192
|
GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_K_q8_1_sycl\n");
|
|
1102
2193
|
mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1103
2194
|
}
|
|
1104
2195
|
break;
|
|
1105
2196
|
case GGML_TYPE_Q5_K:
|
|
1106
|
-
|
|
2197
|
+
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
|
2198
|
+
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
|
2199
|
+
if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
|
|
2200
|
+
const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs;
|
|
2201
|
+
const int stride_col_dst = dst->ne[0];
|
|
2202
|
+
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
|
|
2203
|
+
reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols(
|
|
2204
|
+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
|
|
2205
|
+
src1_ncols, stride_col_y_bytes, stride_col_dst, stream);
|
|
2206
|
+
return;
|
|
2207
|
+
} else {
|
|
2208
|
+
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl\n");
|
|
2209
|
+
reorder_mul_mat_vec_q5_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
2210
|
+
}
|
|
2211
|
+
} else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
|
|
2212
|
+
const int stride_col_y = src1_padded_col_size / QK8_1;
|
|
2213
|
+
const int stride_col_dst = dst->ne[0];
|
|
2214
|
+
GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
|
|
2215
|
+
mul_mat_vec_q5_K_q8_1_sycl_switch_ncols(
|
|
2216
|
+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
|
|
2217
|
+
src1_ncols, stride_col_y, stride_col_dst, stream);
|
|
2218
|
+
return;
|
|
2219
|
+
} else if (i == 0 || src1_ncols == 1) {
|
|
2220
|
+
GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_K_q8_1_sycl\n");
|
|
2221
|
+
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
2222
|
+
}
|
|
1107
2223
|
break;
|
|
1108
2224
|
case GGML_TYPE_Q6_K:
|
|
1109
2225
|
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
|
1110
2226
|
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
|
1111
|
-
|
|
1112
|
-
|
|
1113
|
-
|
|
2227
|
+
if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
|
|
2228
|
+
const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs;
|
|
2229
|
+
const int stride_col_dst = dst->ne[0];
|
|
2230
|
+
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
|
|
2231
|
+
reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols(
|
|
2232
|
+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
|
|
2233
|
+
src1_ncols, stride_col_y_bytes, stride_col_dst, stream);
|
|
2234
|
+
return;
|
|
2235
|
+
} else {
|
|
2236
|
+
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n");
|
|
2237
|
+
reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
2238
|
+
}
|
|
2239
|
+
} else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
|
|
2240
|
+
const int stride_col_y = src1_padded_col_size / QK8_1;
|
|
2241
|
+
const int stride_col_dst = dst->ne[0];
|
|
2242
|
+
GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
|
|
2243
|
+
mul_mat_vec_q6_K_q8_1_sycl_switch_ncols(
|
|
2244
|
+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
|
|
2245
|
+
src1_ncols, stride_col_y, stride_col_dst, stream);
|
|
2246
|
+
return;
|
|
2247
|
+
} else if (i == 0 || src1_ncols == 1) {
|
|
1114
2248
|
GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n");
|
|
1115
2249
|
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1116
2250
|
}
|
|
@@ -1140,13 +2274,46 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
|
|
|
1140
2274
|
mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1141
2275
|
break;
|
|
1142
2276
|
case GGML_TYPE_IQ4_XS:
|
|
1143
|
-
|
|
2277
|
+
if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
|
|
2278
|
+
const int stride_col_y = src1_padded_col_size / QK8_1;
|
|
2279
|
+
const int stride_col_dst = dst->ne[0];
|
|
2280
|
+
GGML_SYCL_DEBUG("Calling mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
|
|
2281
|
+
mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols(
|
|
2282
|
+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
|
|
2283
|
+
src1_ncols, stride_col_y, stride_col_dst, stream);
|
|
2284
|
+
return;
|
|
2285
|
+
} else if (i == 0 || src1_ncols == 1) {
|
|
2286
|
+
mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
2287
|
+
}
|
|
1144
2288
|
break;
|
|
1145
2289
|
case GGML_TYPE_MXFP4:
|
|
1146
|
-
|
|
2290
|
+
if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
|
|
2291
|
+
const int stride_col_y = src1_padded_col_size / QK8_1;
|
|
2292
|
+
const int stride_col_dst = dst->ne[0];
|
|
2293
|
+
GGML_SYCL_DEBUG("Calling mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
|
|
2294
|
+
mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols(
|
|
2295
|
+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
|
|
2296
|
+
src1_ncols, stride_col_y, stride_col_dst, stream);
|
|
2297
|
+
return;
|
|
2298
|
+
} else if (i == 0 || src1_ncols == 1) {
|
|
2299
|
+
mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
2300
|
+
}
|
|
2301
|
+
break;
|
|
2302
|
+
case GGML_TYPE_NVFP4:
|
|
2303
|
+
if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
|
|
2304
|
+
const int stride_col_y = src1_padded_col_size / QK8_1;
|
|
2305
|
+
const int stride_col_dst = dst->ne[0];
|
|
2306
|
+
GGML_SYCL_DEBUG("Calling mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
|
|
2307
|
+
mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols(
|
|
2308
|
+
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
|
|
2309
|
+
src1_ncols, stride_col_y, stride_col_dst, stream);
|
|
2310
|
+
return;
|
|
2311
|
+
} else if (i == 0 || src1_ncols == 1) {
|
|
2312
|
+
mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
2313
|
+
}
|
|
1147
2314
|
break;
|
|
1148
2315
|
default:
|
|
1149
|
-
GGML_ABORT("fatal error");
|
|
2316
|
+
GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(src0->type));
|
|
1150
2317
|
}
|
|
1151
2318
|
}
|
|
1152
2319
|
GGML_UNUSED(src1);
|
|
@@ -1154,3 +2321,154 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
|
|
|
1154
2321
|
GGML_UNUSED(src1_ddf_i);
|
|
1155
2322
|
GGML_UNUSED(ctx);
|
|
1156
2323
|
}
|
|
2324
|
+
|
|
2325
|
+
// src1_row_stride: 0 for shared src1 (gate/up proj), else per-expert stride (down proj).
|
|
2326
|
+
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
|
|
2327
|
+
static void mul_mat_vec_q_moe(
|
|
2328
|
+
const void * __restrict__ vx_base, const void * __restrict__ vy_base,
|
|
2329
|
+
float * __restrict__ dst_base, const int32_t * __restrict__ ids_dev,
|
|
2330
|
+
const int ncols, const int nrows,
|
|
2331
|
+
const size_t expert_weight_stride, const size_t dst_row_stride,
|
|
2332
|
+
const size_t src1_row_stride,
|
|
2333
|
+
const sycl::nd_item<3> & item_ct1) {
|
|
2334
|
+
|
|
2335
|
+
const int expert_idx = item_ct1.get_group(1);
|
|
2336
|
+
const int i02 = ids_dev[expert_idx];
|
|
2337
|
+
|
|
2338
|
+
const char * vx = (const char *) vx_base + (size_t) i02 * expert_weight_stride;
|
|
2339
|
+
const char * vy = (const char *) vy_base + (size_t) expert_idx * src1_row_stride;
|
|
2340
|
+
float * dst = (float *) ((char *) dst_base + (size_t) expert_idx * dst_row_stride);
|
|
2341
|
+
|
|
2342
|
+
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
|
|
2343
|
+
|
|
2344
|
+
if (row >= nrows) {
|
|
2345
|
+
return;
|
|
2346
|
+
}
|
|
2347
|
+
|
|
2348
|
+
const int blocks_per_row = ncols / qk;
|
|
2349
|
+
constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi;
|
|
2350
|
+
|
|
2351
|
+
float tmp = 0.0f;
|
|
2352
|
+
|
|
2353
|
+
const block_q_t * x = (const block_q_t *) vx;
|
|
2354
|
+
const block_q8_1 * y = (const block_q8_1 *) vy;
|
|
2355
|
+
|
|
2356
|
+
for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) {
|
|
2357
|
+
const int ibx = row * blocks_per_row + i;
|
|
2358
|
+
const int iby = i * (qk / QK8_1);
|
|
2359
|
+
|
|
2360
|
+
for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) {
|
|
2361
|
+
const int iqs = elem + vdr * (item_ct1.get_local_id(2) % (qi / vdr));
|
|
2362
|
+
tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
|
|
2363
|
+
}
|
|
2364
|
+
}
|
|
2365
|
+
|
|
2366
|
+
#pragma unroll
|
|
2367
|
+
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
|
2368
|
+
tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
|
2369
|
+
}
|
|
2370
|
+
|
|
2371
|
+
if (item_ct1.get_local_id(2) == 0) {
|
|
2372
|
+
dst[row] = tmp;
|
|
2373
|
+
}
|
|
2374
|
+
}
|
|
2375
|
+
|
|
2376
|
+
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
|
|
2377
|
+
static void launch_mul_mat_vec_q_moe(
|
|
2378
|
+
const void * vx_base, const void * vy, const int32_t * ids_dev,
|
|
2379
|
+
float * dst_base, const int ncols, const int nrows, const int n_experts_used,
|
|
2380
|
+
const size_t expert_weight_stride, const size_t dst_row_stride,
|
|
2381
|
+
const size_t src1_row_stride,
|
|
2382
|
+
dpct::queue_ptr stream) {
|
|
2383
|
+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
2384
|
+
const sycl::range<3> block_nums(1, (unsigned) n_experts_used, (unsigned) block_num_y);
|
|
2385
|
+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
2386
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
2387
|
+
cgh.parallel_for(
|
|
2388
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
2389
|
+
[=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
2390
|
+
mul_mat_vec_q_moe<qk, qi, block_q_t, vdr, vec_dot_q_sycl>(
|
|
2391
|
+
vx_base, vy, dst_base, ids_dev, ncols, nrows,
|
|
2392
|
+
expert_weight_stride, dst_row_stride, src1_row_stride, item);
|
|
2393
|
+
});
|
|
2394
|
+
});
|
|
2395
|
+
}
|
|
2396
|
+
|
|
2397
|
+
bool ggml_sycl_mul_mat_vec_q_id(
|
|
2398
|
+
enum ggml_type src0_type,
|
|
2399
|
+
const void * vx_base,
|
|
2400
|
+
const void * vy,
|
|
2401
|
+
const int32_t * ids_dev,
|
|
2402
|
+
float * dst_base,
|
|
2403
|
+
int ncols,
|
|
2404
|
+
int nrows,
|
|
2405
|
+
int n_experts_used,
|
|
2406
|
+
size_t expert_weight_stride,
|
|
2407
|
+
size_t dst_row_stride,
|
|
2408
|
+
size_t src1_row_stride,
|
|
2409
|
+
dpct::queue_ptr stream) {
|
|
2410
|
+
switch (src0_type) {
|
|
2411
|
+
case GGML_TYPE_Q4_0:
|
|
2412
|
+
launch_mul_mat_vec_q_moe<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
|
|
2413
|
+
vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
|
|
2414
|
+
expert_weight_stride, dst_row_stride, src1_row_stride, stream);
|
|
2415
|
+
return true;
|
|
2416
|
+
case GGML_TYPE_Q4_1:
|
|
2417
|
+
launch_mul_mat_vec_q_moe<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
|
|
2418
|
+
vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
|
|
2419
|
+
expert_weight_stride, dst_row_stride, src1_row_stride, stream);
|
|
2420
|
+
return true;
|
|
2421
|
+
case GGML_TYPE_Q5_0:
|
|
2422
|
+
launch_mul_mat_vec_q_moe<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
|
|
2423
|
+
vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
|
|
2424
|
+
expert_weight_stride, dst_row_stride, src1_row_stride, stream);
|
|
2425
|
+
return true;
|
|
2426
|
+
case GGML_TYPE_Q5_1:
|
|
2427
|
+
launch_mul_mat_vec_q_moe<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
|
|
2428
|
+
vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
|
|
2429
|
+
expert_weight_stride, dst_row_stride, src1_row_stride, stream);
|
|
2430
|
+
return true;
|
|
2431
|
+
case GGML_TYPE_Q8_0:
|
|
2432
|
+
launch_mul_mat_vec_q_moe<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
|
|
2433
|
+
vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
|
|
2434
|
+
expert_weight_stride, dst_row_stride, src1_row_stride, stream);
|
|
2435
|
+
return true;
|
|
2436
|
+
case GGML_TYPE_Q2_K:
|
|
2437
|
+
launch_mul_mat_vec_q_moe<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
|
|
2438
|
+
vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
|
|
2439
|
+
expert_weight_stride, dst_row_stride, src1_row_stride, stream);
|
|
2440
|
+
return true;
|
|
2441
|
+
case GGML_TYPE_Q3_K:
|
|
2442
|
+
launch_mul_mat_vec_q_moe<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
|
|
2443
|
+
vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
|
|
2444
|
+
expert_weight_stride, dst_row_stride, src1_row_stride, stream);
|
|
2445
|
+
return true;
|
|
2446
|
+
case GGML_TYPE_Q4_K:
|
|
2447
|
+
launch_mul_mat_vec_q_moe<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
|
|
2448
|
+
vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
|
|
2449
|
+
expert_weight_stride, dst_row_stride, src1_row_stride, stream);
|
|
2450
|
+
return true;
|
|
2451
|
+
case GGML_TYPE_Q5_K:
|
|
2452
|
+
launch_mul_mat_vec_q_moe<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
|
|
2453
|
+
vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
|
|
2454
|
+
expert_weight_stride, dst_row_stride, src1_row_stride, stream);
|
|
2455
|
+
return true;
|
|
2456
|
+
case GGML_TYPE_Q6_K:
|
|
2457
|
+
launch_mul_mat_vec_q_moe<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
|
|
2458
|
+
vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
|
|
2459
|
+
expert_weight_stride, dst_row_stride, src1_row_stride, stream);
|
|
2460
|
+
return true;
|
|
2461
|
+
case GGML_TYPE_MXFP4:
|
|
2462
|
+
launch_mul_mat_vec_q_moe<QK_MXFP4, QI_MXFP4, block_mxfp4, VDR_MXFP4_Q8_1_MMVQ, vec_dot_mxfp4_q8_1>(
|
|
2463
|
+
vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
|
|
2464
|
+
expert_weight_stride, dst_row_stride, src1_row_stride, stream);
|
|
2465
|
+
return true;
|
|
2466
|
+
case GGML_TYPE_NVFP4:
|
|
2467
|
+
launch_mul_mat_vec_q_moe<QK_NVFP4, QI_NVFP4, block_nvfp4, VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1>(
|
|
2468
|
+
vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
|
|
2469
|
+
expert_weight_stride, dst_row_stride, src1_row_stride, stream);
|
|
2470
|
+
return true;
|
|
2471
|
+
default:
|
|
2472
|
+
return false;
|
|
2473
|
+
}
|
|
2474
|
+
}
|