whispercpp 1.3.5 → 1.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/LICENSE +1 -1
- data/README.md +133 -3
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -7
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +56 -46
- data/ext/ruby_whisper.h +165 -2
- data/ext/ruby_whisper_context.c +297 -126
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -66
- data/ext/ruby_whisper_segment.c +6 -7
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +46 -16
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +24 -19
- data/ext/sources/examples/cli/cli.cpp +51 -9
- data/ext/sources/examples/common-ggml.cpp +4 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +213 -163
- data/ext/sources/ggml/CMakeLists.txt +29 -15
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +73 -11
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +8 -3
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +155 -16
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +25 -5
- data/ext/sources/ggml/src/ggml-alloc.c +9 -10
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
- data/ext/sources/ggml/src/ggml-common.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
- data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
- data/ext/sources/ggml/src/ggml-impl.h +68 -1
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +385 -119
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
- data/ext/sources/ggml/src/ggml.c +268 -52
- data/ext/sources/ggml/src/gguf.cpp +377 -47
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +62 -40
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +445 -55
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_context_params.rb +82 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +44 -6
- data/whispercpp.gemspec +2 -2
- metadata +426 -280
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
- data/ext/sources/examples/talk-llama/llama-context.h +0 -360
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
- data/ext/sources/examples/talk-llama/llama-model.h +0 -544
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
- data/ext/sources/examples/talk-llama/llama.h +0 -1540
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -569
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
|
@@ -10,9 +10,9 @@
|
|
|
10
10
|
using namespace ggml_cuda_mma;
|
|
11
11
|
|
|
12
12
|
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
|
|
13
|
-
#define MMQ_ITER_K
|
|
14
|
-
#define
|
|
15
|
-
#define MMQ_NWARPS
|
|
13
|
+
#define MMQ_ITER_K 256
|
|
14
|
+
#define MMQ_ITER_K_FP4 512
|
|
15
|
+
#define MMQ_NWARPS 8
|
|
16
16
|
|
|
17
17
|
typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
|
|
18
18
|
typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00);
|
|
@@ -46,9 +46,12 @@ struct block_q8_1_mmq {
|
|
|
46
46
|
int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
|
|
47
47
|
};
|
|
48
48
|
|
|
49
|
+
// this struct is used for fp4 data types (currently only used for Blackwell)
|
|
50
|
+
// mxfp4 has block size 32, each int32 of d4 contains 2 e8m0 scales in the lower 16 bits
|
|
51
|
+
// nvfp4 has block size 16, each int32 of d4 contains 4 ue4m3 scales
|
|
49
52
|
struct block_fp4_mmq {
|
|
50
|
-
uint32_t d4[4];
|
|
51
|
-
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte)
|
|
53
|
+
uint32_t d4[4];
|
|
54
|
+
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte)
|
|
52
55
|
};
|
|
53
56
|
|
|
54
57
|
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
|
|
@@ -57,6 +60,8 @@ static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected b
|
|
|
57
60
|
|
|
58
61
|
static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
|
|
59
62
|
switch (type_x) {
|
|
63
|
+
case GGML_TYPE_Q1_0:
|
|
64
|
+
return MMQ_Q8_1_DS_LAYOUT_D4;
|
|
60
65
|
case GGML_TYPE_Q4_0:
|
|
61
66
|
case GGML_TYPE_Q4_1:
|
|
62
67
|
return MMQ_Q8_1_DS_LAYOUT_DS4;
|
|
@@ -68,6 +73,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
|
|
|
68
73
|
return MMQ_Q8_1_DS_LAYOUT_D4;
|
|
69
74
|
case GGML_TYPE_MXFP4:
|
|
70
75
|
return MMQ_Q8_1_DS_LAYOUT_D4;
|
|
76
|
+
case GGML_TYPE_NVFP4:
|
|
77
|
+
return MMQ_Q8_1_DS_LAYOUT_D4;
|
|
71
78
|
case GGML_TYPE_Q2_K:
|
|
72
79
|
return MMQ_Q8_1_DS_LAYOUT_D2S6;
|
|
73
80
|
case GGML_TYPE_Q3_K:
|
|
@@ -100,7 +107,7 @@ struct tile_x_sizes {
|
|
|
100
107
|
};
|
|
101
108
|
|
|
102
109
|
static int get_mmq_x_max_host(const int cc) {
|
|
103
|
-
return (
|
|
110
|
+
return (turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
|
|
104
111
|
GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
|
|
105
112
|
#ifdef GGML_CUDA_FORCE_MMQ
|
|
106
113
|
128 : 64;
|
|
@@ -110,9 +117,9 @@ static int get_mmq_x_max_host(const int cc) {
|
|
|
110
117
|
}
|
|
111
118
|
|
|
112
119
|
static constexpr __device__ int get_mmq_x_max_device() {
|
|
113
|
-
#if defined(
|
|
120
|
+
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
114
121
|
return 128;
|
|
115
|
-
#else // defined(
|
|
122
|
+
#else // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
116
123
|
|
|
117
124
|
#if defined(GGML_USE_HIP)
|
|
118
125
|
return 64;
|
|
@@ -139,10 +146,11 @@ static int get_mmq_y_host(const int cc) {
|
|
|
139
146
|
|
|
140
147
|
static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
|
|
141
148
|
#if defined(BLACKWELL_MMA_AVAILABLE)
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
149
|
+
if (type == GGML_TYPE_NVFP4 || type == GGML_TYPE_MXFP4) {
|
|
150
|
+
return MMQ_ITER_K_FP4;
|
|
151
|
+
}
|
|
145
152
|
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
|
153
|
+
return MMQ_ITER_K;
|
|
146
154
|
}
|
|
147
155
|
|
|
148
156
|
static constexpr __device__ int get_mmq_y_device() {
|
|
@@ -183,12 +191,14 @@ static constexpr __device__ int get_mmq_y_device() {
|
|
|
183
191
|
|
|
184
192
|
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
|
|
185
193
|
switch (type) {
|
|
194
|
+
case GGML_TYPE_Q1_0: return MMQ_DP4A_TXS_Q8_0;
|
|
186
195
|
case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
|
|
187
196
|
case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
|
|
188
197
|
case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
|
|
189
198
|
case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
|
|
190
199
|
case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
|
|
191
200
|
case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
|
|
201
|
+
case GGML_TYPE_NVFP4: return MMQ_DP4A_TXS_Q8_0_16;
|
|
192
202
|
case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
|
|
193
203
|
case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
|
|
194
204
|
case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
|
|
@@ -206,12 +216,13 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
|
|
|
206
216
|
}
|
|
207
217
|
}
|
|
208
218
|
|
|
209
|
-
#define MMQ_MMA_TILE_X_K_Q8_0
|
|
210
|
-
#define MMQ_MMA_TILE_X_K_FP4
|
|
211
|
-
#define
|
|
212
|
-
#define
|
|
213
|
-
#define
|
|
214
|
-
#define
|
|
219
|
+
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
|
220
|
+
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4 and NVFP4 Blackwell
|
|
221
|
+
#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4 Generic
|
|
222
|
+
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
|
223
|
+
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
|
|
224
|
+
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
|
|
225
|
+
#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
|
|
215
226
|
|
|
216
227
|
static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
|
|
217
228
|
static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
|
|
@@ -220,9 +231,12 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
|
|
|
220
231
|
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
|
|
221
232
|
static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
|
|
222
233
|
static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
|
|
234
|
+
static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
|
|
235
|
+
|
|
223
236
|
|
|
224
237
|
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
|
225
238
|
switch (type) {
|
|
239
|
+
case GGML_TYPE_Q1_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
226
240
|
case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
227
241
|
case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
|
|
228
242
|
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
@@ -230,6 +244,11 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
|
|
230
244
|
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
231
245
|
// tile sizes are the same for Q8_1 and FP4 for blackwell
|
|
232
246
|
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
|
|
247
|
+
#if defined(BLACKWELL_MMA_AVAILABLE)
|
|
248
|
+
case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_FP4;
|
|
249
|
+
#else
|
|
250
|
+
case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4;
|
|
251
|
+
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
|
233
252
|
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
|
|
234
253
|
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
|
|
235
254
|
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
|
|
@@ -295,6 +314,87 @@ static constexpr __device__ int mmq_get_nwarps_device() {
|
|
|
295
314
|
|
|
296
315
|
// ------------------------------------------------------------
|
|
297
316
|
|
|
317
|
+
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q1_0(
|
|
318
|
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
|
319
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
|
320
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
321
|
+
|
|
322
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
323
|
+
int * x_qs = (int *) x_tile;
|
|
324
|
+
float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
|
|
325
|
+
#else
|
|
326
|
+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
|
|
327
|
+
int * x_qs = (int *) x_tile;
|
|
328
|
+
float * x_df = (float *) (x_qs + txs.qs);
|
|
329
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
330
|
+
|
|
331
|
+
constexpr int blocks_per_iter = MMQ_ITER_K / QK1_0;
|
|
332
|
+
constexpr int threads_per_row = blocks_per_iter * QI1_0;
|
|
333
|
+
constexpr int nrows = warp_size / threads_per_row;
|
|
334
|
+
constexpr int scale_entries_per_block = QK1_0 / QK8_1;
|
|
335
|
+
constexpr int scale_entries_per_row = blocks_per_iter * scale_entries_per_block;
|
|
336
|
+
|
|
337
|
+
const int txi = threadIdx.x % threads_per_row;
|
|
338
|
+
const int kbx = txi / QI1_0;
|
|
339
|
+
const int kqsx = txi % QI1_0;
|
|
340
|
+
|
|
341
|
+
#pragma unroll
|
|
342
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
|
|
343
|
+
int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
|
|
344
|
+
|
|
345
|
+
if (need_check) {
|
|
346
|
+
i = min(i, i_max);
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + kbx;
|
|
350
|
+
const int qs_offset = 4*kqsx;
|
|
351
|
+
const int qs0 = bxi->qs[qs_offset + 0] | (bxi->qs[qs_offset + 1] << 8) |
|
|
352
|
+
(bxi->qs[qs_offset + 2] << 16) | (bxi->qs[qs_offset + 3] << 24);
|
|
353
|
+
|
|
354
|
+
int unpacked_bytes[8];
|
|
355
|
+
#pragma unroll
|
|
356
|
+
for (int j = 0; j < 8; ++j) {
|
|
357
|
+
const int shift = j * 4;
|
|
358
|
+
const int bits4 = (qs0 >> shift) & 0x0F;
|
|
359
|
+
const int b0 = (bits4 & 0x01) ? 1 : -1;
|
|
360
|
+
const int b1 = (bits4 & 0x02) ? 1 : -1;
|
|
361
|
+
const int b2 = (bits4 & 0x04) ? 1 : -1;
|
|
362
|
+
const int b3 = (bits4 & 0x08) ? 1 : -1;
|
|
363
|
+
unpacked_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
|
|
364
|
+
}
|
|
365
|
+
|
|
366
|
+
const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0;
|
|
367
|
+
#pragma unroll
|
|
368
|
+
for (int j = 0; j < 8; ++j) {
|
|
369
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
370
|
+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j];
|
|
371
|
+
#else
|
|
372
|
+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + dst_offset + j] = unpacked_bytes[j];
|
|
373
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
374
|
+
}
|
|
375
|
+
}
|
|
376
|
+
|
|
377
|
+
const int ksx = threadIdx.x % scale_entries_per_row;
|
|
378
|
+
const int scale_block = ksx / scale_entries_per_block;
|
|
379
|
+
|
|
380
|
+
#pragma unroll
|
|
381
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
|
382
|
+
int i = i0 + threadIdx.y;
|
|
383
|
+
|
|
384
|
+
if (need_check) {
|
|
385
|
+
i = min(i, i_max);
|
|
386
|
+
}
|
|
387
|
+
|
|
388
|
+
const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + scale_block;
|
|
389
|
+
|
|
390
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
391
|
+
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d;
|
|
392
|
+
#else
|
|
393
|
+
x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + ksx] = bxi->d;
|
|
394
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
395
|
+
}
|
|
396
|
+
}
|
|
397
|
+
|
|
298
398
|
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
|
|
299
399
|
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
|
300
400
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
@@ -379,17 +479,25 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
|
|
|
379
479
|
#pragma unroll
|
|
380
480
|
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
|
381
481
|
const int i = i0 + threadIdx.x;
|
|
382
|
-
|
|
383
482
|
const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
|
|
384
483
|
|
|
385
484
|
int u[2*VDR_Q4_0_Q8_1_MMQ];
|
|
386
485
|
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
486
|
+
constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes();
|
|
487
|
+
constexpr int mcpy_int = max_cpy / sizeof(int);
|
|
488
|
+
static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ");
|
|
489
|
+
|
|
490
|
+
int tmp0[4], tmp1[4];
|
|
491
|
+
|
|
492
|
+
#pragma unroll
|
|
493
|
+
for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) {
|
|
494
|
+
ggml_cuda_memcpy_1<max_cpy>(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] );
|
|
495
|
+
ggml_cuda_memcpy_1<max_cpy>(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0 + l0 * mcpy_int]);
|
|
391
496
|
}
|
|
392
497
|
|
|
498
|
+
u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3];
|
|
499
|
+
u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3];
|
|
500
|
+
|
|
393
501
|
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
|
|
394
502
|
(&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
|
|
395
503
|
x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
|
@@ -482,17 +590,25 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
|
|
|
482
590
|
#pragma unroll
|
|
483
591
|
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
|
484
592
|
const int i = i0 + threadIdx.x;
|
|
485
|
-
|
|
486
593
|
const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
|
|
487
594
|
|
|
488
595
|
int u[2*VDR_Q4_1_Q8_1_MMQ];
|
|
489
596
|
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
597
|
+
constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes();
|
|
598
|
+
constexpr int mcpy_int = max_cpy / sizeof(int);
|
|
599
|
+
static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ");
|
|
600
|
+
|
|
601
|
+
int tmp0[4], tmp1[4];
|
|
602
|
+
|
|
603
|
+
#pragma unroll
|
|
604
|
+
for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) {
|
|
605
|
+
ggml_cuda_memcpy_1<max_cpy>(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] );
|
|
606
|
+
ggml_cuda_memcpy_1<max_cpy>(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_1 + l0 * mcpy_int]);
|
|
494
607
|
}
|
|
495
608
|
|
|
609
|
+
u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3];
|
|
610
|
+
u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3];
|
|
611
|
+
|
|
496
612
|
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
|
|
497
613
|
(&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,
|
|
498
614
|
x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
|
@@ -826,6 +942,187 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
|
|
|
826
942
|
}
|
|
827
943
|
}
|
|
828
944
|
|
|
945
|
+
#ifdef BLACKWELL_MMA_AVAILABLE
|
|
946
|
+
template <int mmq_y, bool need_check>
|
|
947
|
+
static __device__ __forceinline__ void load_tiles_nvfp4_nvfp4(const char * __restrict__ x,
|
|
948
|
+
int * __restrict__ x_tile,
|
|
949
|
+
const int kbx0,
|
|
950
|
+
const int i_max,
|
|
951
|
+
const int stride) {
|
|
952
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
|
953
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
954
|
+
constexpr int iter_k = get_iter_k(GGML_TYPE_NVFP4);
|
|
955
|
+
constexpr int threads_per_row = iter_k / QK_NVFP4; // each thread processes 1 block
|
|
956
|
+
constexpr int rows_per_warp = warp_size / threads_per_row;
|
|
957
|
+
|
|
958
|
+
uint32_t * x_u32 = (uint32_t *) x_tile;
|
|
959
|
+
|
|
960
|
+
const int txi = threadIdx.x;
|
|
961
|
+
const int kbx = txi % threads_per_row;
|
|
962
|
+
const int row_in_warp = txi / threads_per_row;
|
|
963
|
+
|
|
964
|
+
const block_nvfp4 * bxi_base = (const block_nvfp4 *) x + kbx0 + kbx;
|
|
965
|
+
uint32_t * x_u32_scale = x_u32 + 64 + kbx;
|
|
966
|
+
|
|
967
|
+
#pragma unroll
|
|
968
|
+
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
|
|
969
|
+
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
|
|
970
|
+
|
|
971
|
+
if constexpr (need_check) {
|
|
972
|
+
i = min(i, i_max);
|
|
973
|
+
}
|
|
974
|
+
|
|
975
|
+
const block_nvfp4 * bxi = bxi_base + i * stride;
|
|
976
|
+
const int row_base = i * MMQ_MMA_TILE_X_K_FP4;
|
|
977
|
+
const int q_base = row_base + 8 * kbx;
|
|
978
|
+
|
|
979
|
+
const uint32_t * src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
|
|
980
|
+
|
|
981
|
+
#pragma unroll
|
|
982
|
+
for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
|
|
983
|
+
x_u32[q_base + 2 * sub + 0] = src_qs[2 * sub + 0];
|
|
984
|
+
x_u32[q_base + 2 * sub + 1] = src_qs[2 * sub + 1];
|
|
985
|
+
}
|
|
986
|
+
|
|
987
|
+
x_u32_scale[row_base] = get_int_b4(bxi->d, 0);
|
|
988
|
+
}
|
|
989
|
+
}
|
|
990
|
+
|
|
991
|
+
// Shared MMA kernel for MXFP4 and NVFP4 on Blackwell.
|
|
992
|
+
// Both quantizations encode values as e2m1 (FP4) and produce one uint32 scale per
|
|
993
|
+
// m16n8k64 MMA call; only the PTX kind (scale_vec::2X ue8m0 vs scale_vec::4X ue4m3)
|
|
994
|
+
// and the per-type stride constant differ.
|
|
995
|
+
template <int mmq_x, int mmq_y, ggml_type type>
|
|
996
|
+
static __device__ __forceinline__ void vec_dot_fp4_fp4_mma(const int * __restrict__ x,
|
|
997
|
+
const int * __restrict__ y,
|
|
998
|
+
float * __restrict__ sum,
|
|
999
|
+
const int k00) {
|
|
1000
|
+
static_assert(type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4,
|
|
1001
|
+
"vec_dot_fp4_fp4_mma: type must be MXFP4 or NVFP4");
|
|
1002
|
+
|
|
1003
|
+
typedef tile<16, 8, int> tile_A;
|
|
1004
|
+
typedef tile<8, 8, int> tile_B;
|
|
1005
|
+
typedef tile<16, 8, float> tile_C;
|
|
1006
|
+
|
|
1007
|
+
constexpr int stride = MMQ_MMA_TILE_X_K_FP4;
|
|
1008
|
+
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
1009
|
+
constexpr int rows_per_warp = 2 * granularity;
|
|
1010
|
+
constexpr int ntx = rows_per_warp / tile_C::I;
|
|
1011
|
+
constexpr int nfrags = MMQ_TILE_NE_K / tile_A::J;
|
|
1012
|
+
|
|
1013
|
+
y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_K);
|
|
1014
|
+
|
|
1015
|
+
const int * x_qs = (const int *) x;
|
|
1016
|
+
const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
|
|
1017
|
+
const int * y_qs = (const int *) y + 4;
|
|
1018
|
+
const uint32_t * y_sc = (const uint32_t *) y;
|
|
1019
|
+
|
|
1020
|
+
// 2 threads per quad supply the packed scale register to the block_scale MMA,
|
|
1021
|
+
// see https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
|
|
1022
|
+
const int tidx_A = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
|
|
1023
|
+
const int tidx_B = threadIdx.x / 4;
|
|
1024
|
+
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
1025
|
+
|
|
1026
|
+
tile_A A[ntx][nfrags];
|
|
1027
|
+
uint32_t scaleA[ntx][nfrags];
|
|
1028
|
+
|
|
1029
|
+
#pragma unroll
|
|
1030
|
+
for (int n = 0; n < ntx; ++n) {
|
|
1031
|
+
#pragma unroll
|
|
1032
|
+
for (int frag = 0; frag < nfrags; ++frag) {
|
|
1033
|
+
const int k0 = k00 + frag * tile_A::J;
|
|
1034
|
+
load_ldmatrix(A[n][frag], x_qs + (i0 + n * tile_A::I) * stride + k0, stride);
|
|
1035
|
+
scaleA[n][frag] = x_sc[(i0 + n * tile_A::I + tidx_A) * stride + k0 / tile_A::J];
|
|
1036
|
+
}
|
|
1037
|
+
}
|
|
1038
|
+
|
|
1039
|
+
#pragma unroll
|
|
1040
|
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
|
|
1041
|
+
tile_B B[nfrags];
|
|
1042
|
+
uint32_t scaleB[nfrags];
|
|
1043
|
+
|
|
1044
|
+
#pragma unroll
|
|
1045
|
+
for (int frag = 0; frag < nfrags; ++frag) {
|
|
1046
|
+
const int k0 = frag * tile_B::J;
|
|
1047
|
+
load_generic(B[frag], y_qs + j0 * MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K);
|
|
1048
|
+
scaleB[frag] = y_sc[(j0 + tidx_B) * MMQ_TILE_Y_K + frag];
|
|
1049
|
+
}
|
|
1050
|
+
|
|
1051
|
+
#pragma unroll
|
|
1052
|
+
for (int n = 0; n < ntx; ++n) {
|
|
1053
|
+
#pragma unroll
|
|
1054
|
+
for (int frag = 0; frag < nfrags; ++frag) {
|
|
1055
|
+
tile_C C = {};
|
|
1056
|
+
mma_block_scaled_fp4<type>(C, A[n][frag], B[frag], scaleA[n][frag], scaleB[frag]);
|
|
1057
|
+
#pragma unroll
|
|
1058
|
+
for (int l = 0; l < tile_C::ne; ++l) {
|
|
1059
|
+
sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
|
|
1060
|
+
}
|
|
1061
|
+
}
|
|
1062
|
+
}
|
|
1063
|
+
}
|
|
1064
|
+
}
|
|
1065
|
+
#endif // BLACKWELL_MMA_AVAILABLE
|
|
1066
|
+
|
|
1067
|
+
|
|
1068
|
+
template <int mmq_y, bool need_check>
|
|
1069
|
+
static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x,
|
|
1070
|
+
int * __restrict__ x_tile,
|
|
1071
|
+
const int kb0,
|
|
1072
|
+
const int i_max,
|
|
1073
|
+
const int stride) {
|
|
1074
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
|
1075
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
1076
|
+
|
|
1077
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1078
|
+
int * x_qs = (int *) x_tile;
|
|
1079
|
+
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
|
1080
|
+
#else
|
|
1081
|
+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_NVFP4, mmq_y);
|
|
1082
|
+
int * x_qs = (int *) x_tile;
|
|
1083
|
+
float * x_df = (float *) (x_qs + txs.qs);
|
|
1084
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1085
|
+
|
|
1086
|
+
constexpr int threads_per_row = MMQ_ITER_K / QK_NVFP4;
|
|
1087
|
+
constexpr int rows_per_warp = warp_size / threads_per_row;
|
|
1088
|
+
const int kbx = threadIdx.x % threads_per_row;
|
|
1089
|
+
const int row_in_warp = threadIdx.x / threads_per_row;
|
|
1090
|
+
|
|
1091
|
+
#pragma unroll
|
|
1092
|
+
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
|
|
1093
|
+
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
|
|
1094
|
+
|
|
1095
|
+
if constexpr (need_check) {
|
|
1096
|
+
i = min(i, i_max);
|
|
1097
|
+
}
|
|
1098
|
+
|
|
1099
|
+
const block_nvfp4 * bxi = (const block_nvfp4 *) x + kb0 + i * stride + kbx;
|
|
1100
|
+
const uint32_t * __restrict__ src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
|
|
1101
|
+
const int kqs = 16 * kbx;
|
|
1102
|
+
const int ksc = 4 * kbx;
|
|
1103
|
+
|
|
1104
|
+
#pragma unroll
|
|
1105
|
+
for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
|
|
1106
|
+
const int2 q0 = get_int_from_table_16(src_qs[2 * sub + 0], kvalues_mxfp4);
|
|
1107
|
+
const int2 q1 = get_int_from_table_16(src_qs[2 * sub + 1], kvalues_mxfp4);
|
|
1108
|
+
|
|
1109
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1110
|
+
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 0] = q0.x;
|
|
1111
|
+
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 1] = q1.x;
|
|
1112
|
+
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 2] = q0.y;
|
|
1113
|
+
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 3] = q1.y;
|
|
1114
|
+
x_df[i * MMQ_MMA_TILE_X_K_NVFP4 + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
|
|
1115
|
+
#else
|
|
1116
|
+
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 0] = q0.x;
|
|
1117
|
+
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 1] = q1.x;
|
|
1118
|
+
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 2] = q0.y;
|
|
1119
|
+
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 3] = q1.y;
|
|
1120
|
+
x_df[i * (2 * MMQ_TILE_NE_K * 2 / QI_NVFP4) + i / (QK_NVFP4_SUB / QI_NVFP4) + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
|
|
1121
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1122
|
+
}
|
|
1123
|
+
}
|
|
1124
|
+
}
|
|
1125
|
+
|
|
829
1126
|
template <int mmq_x, int mmq_y>
|
|
830
1127
|
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
|
|
831
1128
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
@@ -887,13 +1184,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
|
|
887
1184
|
tile_A A[ntx];
|
|
888
1185
|
#pragma unroll
|
|
889
1186
|
for (int n = 0; n < ntx; ++n) {
|
|
890
|
-
|
|
1187
|
+
load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
|
|
891
1188
|
}
|
|
892
1189
|
|
|
893
1190
|
#pragma unroll
|
|
894
1191
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
895
1192
|
tile_B B;
|
|
896
|
-
|
|
1193
|
+
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
897
1194
|
|
|
898
1195
|
float dB;
|
|
899
1196
|
const int j = j0 + tile_C::get_j(0);
|
|
@@ -996,77 +1293,6 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
|
|
996
1293
|
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
997
1294
|
}
|
|
998
1295
|
|
|
999
|
-
template <int mmq_x, int mmq_y>
|
|
1000
|
-
static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
|
|
1001
|
-
const int * __restrict__ y,
|
|
1002
|
-
float * __restrict__ sum,
|
|
1003
|
-
const int k00) {
|
|
1004
|
-
typedef tile<16, 8, int> tile_A;
|
|
1005
|
-
typedef tile<8, 8, int> tile_B;
|
|
1006
|
-
typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA
|
|
1007
|
-
|
|
1008
|
-
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
1009
|
-
constexpr int rows_per_warp = 2 * granularity;
|
|
1010
|
-
constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
|
|
1011
|
-
|
|
1012
|
-
y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
|
|
1013
|
-
|
|
1014
|
-
// Match layout from load_tiles_mxfp4_fp4
|
|
1015
|
-
const int * x_qs = (const int *) x;
|
|
1016
|
-
const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
|
|
1017
|
-
const int * y_qs = (const int *) y + 4;
|
|
1018
|
-
const uint32_t * y_sc = (const uint32_t *) y;
|
|
1019
|
-
|
|
1020
|
-
// tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
|
|
1021
|
-
tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
|
|
1022
|
-
uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
|
|
1023
|
-
|
|
1024
|
-
// Block scale
|
|
1025
|
-
// Each thread has to point to a 4 byte scale value
|
|
1026
|
-
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
|
|
1027
|
-
|
|
1028
|
-
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
1029
|
-
|
|
1030
|
-
#pragma unroll
|
|
1031
|
-
for (int n = 0; n < ntx; ++n) {
|
|
1032
|
-
#pragma unroll
|
|
1033
|
-
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
|
|
1034
|
-
const int k0 = k00 + k01;
|
|
1035
|
-
|
|
1036
|
-
load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
|
|
1037
|
-
MMQ_MMA_TILE_X_K_FP4);
|
|
1038
|
-
|
|
1039
|
-
// based on block-scaling document, 2 threads in each quad need to supply to the scale value
|
|
1040
|
-
const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
|
|
1041
|
-
scaleA[n][k01 / (2 * QI_MXFP4)] =
|
|
1042
|
-
*(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
|
|
1043
|
-
}
|
|
1044
|
-
}
|
|
1045
|
-
|
|
1046
|
-
#pragma unroll
|
|
1047
|
-
for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
|
|
1048
|
-
#pragma unroll
|
|
1049
|
-
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
|
|
1050
|
-
tile_B B;
|
|
1051
|
-
uint32_t scaleB; // 2xN scales
|
|
1052
|
-
|
|
1053
|
-
load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
|
|
1054
|
-
|
|
1055
|
-
scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
|
|
1056
|
-
|
|
1057
|
-
#pragma unroll
|
|
1058
|
-
for (int n = 0; n < ntx; ++n) {
|
|
1059
|
-
tile_C C;
|
|
1060
|
-
|
|
1061
|
-
mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
|
|
1062
|
-
#pragma unroll
|
|
1063
|
-
for (int l = 0; l < tile_C::ne; ++l) {
|
|
1064
|
-
sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
|
|
1065
|
-
}
|
|
1066
|
-
}
|
|
1067
|
-
}
|
|
1068
|
-
}
|
|
1069
|
-
}
|
|
1070
1296
|
|
|
1071
1297
|
template <int mmq_x, int mmq_y>
|
|
1072
1298
|
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
|
|
@@ -1128,13 +1354,13 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
|
|
1128
1354
|
tile_A A[ntx];
|
|
1129
1355
|
#pragma unroll
|
|
1130
1356
|
for (int n = 0; n < ntx; ++n) {
|
|
1131
|
-
|
|
1357
|
+
load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
|
|
1132
1358
|
}
|
|
1133
1359
|
|
|
1134
1360
|
#pragma unroll
|
|
1135
1361
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
1136
1362
|
tile_B B;
|
|
1137
|
-
|
|
1363
|
+
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
1138
1364
|
|
|
1139
1365
|
const int j = j0 + tile_C::get_j(0);
|
|
1140
1366
|
const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
|
@@ -1229,7 +1455,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
|
|
1229
1455
|
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1230
1456
|
}
|
|
1231
1457
|
|
|
1232
|
-
// Used for Q3_K, IQ2_S, and IQ2_XS
|
|
1458
|
+
// Used for NVFP4, Q3_K, IQ2_S, and IQ2_XS
|
|
1233
1459
|
template <int mmq_x, int mmq_y>
|
|
1234
1460
|
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
|
|
1235
1461
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
@@ -1268,57 +1494,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
|
|
|
1268
1494
|
template <int mmq_x, int mmq_y>
|
|
1269
1495
|
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|
1270
1496
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
1271
|
-
#if defined(AMD_MFMA_AVAILABLE)
|
|
1272
|
-
constexpr data_layout input_layout = get_input_data_layout();
|
|
1273
|
-
typedef tile<16, 8, int, input_layout> tile_A;
|
|
1274
|
-
typedef tile<16, 8, int, input_layout> tile_B;
|
|
1275
|
-
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
|
1276
|
-
typedef tile<64, 2, int, input_layout> tile_load;
|
|
1277
|
-
|
|
1278
|
-
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
1279
|
-
constexpr int rows_per_warp = granularity;
|
|
1280
|
-
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
1281
|
-
|
|
1282
|
-
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
|
1283
|
-
|
|
1284
|
-
const int * x_qs = (const int *) x;
|
|
1285
|
-
const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
|
|
1286
|
-
const int * y_qs = (const int *) y + 4;
|
|
1287
|
-
const float * y_df = (const float *) y;
|
|
1288
|
-
|
|
1289
|
-
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
1290
|
-
|
|
1291
|
-
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
|
|
1292
|
-
const int k0 = k00 + k01;
|
|
1293
|
-
|
|
1294
|
-
tile_A A[ntx];
|
|
1295
|
-
#pragma unroll
|
|
1296
|
-
for (int n = 0; n < ntx; ++n) {
|
|
1297
|
-
load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
|
|
1298
|
-
}
|
|
1299
|
-
|
|
1300
|
-
#pragma unroll
|
|
1301
|
-
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
1302
|
-
tile_B B[1];
|
|
1303
|
-
load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
1304
|
-
|
|
1305
|
-
const int j = j0 + tile_C::get_j(0);
|
|
1306
|
-
const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
|
|
1307
|
-
|
|
1308
|
-
#pragma unroll
|
|
1309
|
-
for (int n = 0; n < ntx; ++n) {
|
|
1310
|
-
tile_C C;
|
|
1311
|
-
mma(C, A[n], B[0]);
|
|
1312
|
-
|
|
1313
|
-
#pragma unroll
|
|
1314
|
-
for (int l = 0; l < tile_C::ne; ++l) {
|
|
1315
|
-
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
|
1316
|
-
sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
|
|
1317
|
-
}
|
|
1318
|
-
}
|
|
1319
|
-
}
|
|
1320
|
-
}
|
|
1321
|
-
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
|
1497
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1322
1498
|
constexpr data_layout input_layout = get_input_data_layout();
|
|
1323
1499
|
typedef tile<16, 4, int, input_layout> tile_A;
|
|
1324
1500
|
typedef tile<16, 4, int, input_layout> tile_B;
|
|
@@ -1343,13 +1519,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|
|
1343
1519
|
tile_A A[ntx];
|
|
1344
1520
|
#pragma unroll
|
|
1345
1521
|
for (int n = 0; n < ntx; ++n) {
|
|
1346
|
-
|
|
1522
|
+
load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
|
|
1347
1523
|
}
|
|
1348
1524
|
|
|
1349
1525
|
#pragma unroll
|
|
1350
1526
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
1351
1527
|
tile_B B;
|
|
1352
|
-
|
|
1528
|
+
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
1353
1529
|
|
|
1354
1530
|
const int j = j0 + tile_C::get_j(0);
|
|
1355
1531
|
const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
|
|
@@ -1575,74 +1751,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
|
|
|
1575
1751
|
template <int mmq_x, int mmq_y>
|
|
1576
1752
|
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
|
1577
1753
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
1578
|
-
#if defined(AMD_MFMA_AVAILABLE)
|
|
1579
|
-
constexpr data_layout input_layout = get_input_data_layout();
|
|
1580
|
-
typedef tile<16, 8, int, input_layout> tile_A;
|
|
1581
|
-
typedef tile<16, 8, int, input_layout> tile_B;
|
|
1582
|
-
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
|
1583
|
-
typedef tile<64, 2, int, input_layout> tile_load;
|
|
1584
|
-
|
|
1585
|
-
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
1586
|
-
constexpr int rows_per_warp = granularity;
|
|
1587
|
-
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
1588
|
-
|
|
1589
|
-
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
|
1590
|
-
|
|
1591
|
-
const int * x_qs = (const int *) x;
|
|
1592
|
-
const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
|
|
1593
|
-
const int * y_qs = (const int *) y + 4;
|
|
1594
|
-
const half2 * y_ds = (const half2 *) y;
|
|
1595
|
-
|
|
1596
|
-
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
1597
|
-
|
|
1598
|
-
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
|
|
1599
|
-
const int k0 = k00 + k01;
|
|
1600
|
-
|
|
1601
|
-
tile_A A[ntx];
|
|
1602
|
-
#pragma unroll
|
|
1603
|
-
for (int n = 0; n < ntx; ++n) {
|
|
1604
|
-
load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
|
|
1605
|
-
}
|
|
1606
|
-
|
|
1607
|
-
#pragma unroll
|
|
1608
|
-
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
1609
|
-
tile_B B[1];
|
|
1610
|
-
load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
1611
|
-
|
|
1612
|
-
const int j = j0 + tile_C::get_j(0);
|
|
1613
|
-
const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2;
|
|
1614
|
-
const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
|
|
1615
|
-
: (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
|
|
1616
|
-
: __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
|
|
1617
|
-
|
|
1618
|
-
tile_C Cm;
|
|
1619
|
-
if (k01 >= MMQ_TILE_NE_K * 3/4) {
|
|
1620
|
-
tile_A A1;
|
|
1621
|
-
A1.x[0] = 0x01010101;
|
|
1622
|
-
A1.x[1] = 0x01010101;
|
|
1623
|
-
mma(Cm, A1, B[0]);
|
|
1624
|
-
}
|
|
1625
|
-
|
|
1626
|
-
#pragma unroll
|
|
1627
|
-
for (int n = 0; n < ntx; ++n) {
|
|
1628
|
-
tile_C Cd;
|
|
1629
|
-
mma(Cd, A[n], B[0]);
|
|
1630
|
-
|
|
1631
|
-
#pragma unroll
|
|
1632
|
-
for (int l = 0; l < tile_C::ne; ++l) {
|
|
1633
|
-
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
|
1634
|
-
const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
|
|
1635
|
-
float tmp = Cd.x[l]*dm.x;
|
|
1636
|
-
if (k01 >= MMQ_TILE_NE_K * 3/4) {
|
|
1637
|
-
tmp -= Cm.x[l]*dm.y;
|
|
1638
|
-
}
|
|
1639
|
-
sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
|
|
1640
|
-
sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
|
|
1641
|
-
}
|
|
1642
|
-
}
|
|
1643
|
-
}
|
|
1644
|
-
}
|
|
1645
|
-
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
|
1754
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1646
1755
|
constexpr data_layout input_layout = get_input_data_layout();
|
|
1647
1756
|
typedef tile<16, 4, int, input_layout> tile_A;
|
|
1648
1757
|
typedef tile<16, 4, int, input_layout> tile_B;
|
|
@@ -1667,13 +1776,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
|
|
1667
1776
|
tile_A A[ntx];
|
|
1668
1777
|
#pragma unroll
|
|
1669
1778
|
for (int n = 0; n < ntx; ++n) {
|
|
1670
|
-
|
|
1779
|
+
load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
|
|
1671
1780
|
}
|
|
1672
1781
|
|
|
1673
1782
|
#pragma unroll
|
|
1674
1783
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
1675
1784
|
tile_B B;
|
|
1676
|
-
|
|
1785
|
+
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
1677
1786
|
|
|
1678
1787
|
const int j = j0 + tile_C::get_j(0);
|
|
1679
1788
|
const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y;
|
|
@@ -2406,59 +2515,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
|
|
|
2406
2515
|
template <int mmq_x, int mmq_y>
|
|
2407
2516
|
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
|
2408
2517
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
2409
|
-
#if defined(AMD_MFMA_AVAILABLE)
|
|
2410
|
-
constexpr data_layout input_layout = get_input_data_layout();
|
|
2411
|
-
typedef tile<16, 8, int, input_layout> tile_A;
|
|
2412
|
-
typedef tile<16, 8, int, input_layout> tile_B;
|
|
2413
|
-
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
|
2414
|
-
typedef tile<64, 2, int, input_layout> tile_load;
|
|
2415
|
-
|
|
2416
|
-
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
2417
|
-
constexpr int rows_per_warp = granularity;
|
|
2418
|
-
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
2419
|
-
|
|
2420
|
-
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
|
2421
|
-
|
|
2422
|
-
const int * x_qs = (const int *) x;
|
|
2423
|
-
const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
|
|
2424
|
-
const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
|
|
2425
|
-
const int * y_qs = (const int *) y + 4;
|
|
2426
|
-
const float * y_df = (const float *) y;
|
|
2427
|
-
|
|
2428
|
-
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
2429
|
-
|
|
2430
|
-
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
|
|
2431
|
-
const int k0 = k00 + k01;
|
|
2432
|
-
|
|
2433
|
-
tile_A A[ntx];
|
|
2434
|
-
#pragma unroll
|
|
2435
|
-
for (int n = 0; n < ntx; ++n) {
|
|
2436
|
-
load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
|
|
2437
|
-
}
|
|
2438
|
-
|
|
2439
|
-
#pragma unroll
|
|
2440
|
-
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
2441
|
-
tile_B B[1];
|
|
2442
|
-
load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
2443
|
-
|
|
2444
|
-
const int j = j0 + tile_C::get_j(0);
|
|
2445
|
-
const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
|
|
2446
|
-
|
|
2447
|
-
#pragma unroll
|
|
2448
|
-
for (int n = 0; n < ntx; ++n) {
|
|
2449
|
-
tile_C C;
|
|
2450
|
-
mma(C, A[n], B[0]);
|
|
2451
|
-
|
|
2452
|
-
#pragma unroll
|
|
2453
|
-
for (int l = 0; l < tile_C::ne; ++l) {
|
|
2454
|
-
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
|
2455
|
-
const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
|
|
2456
|
-
sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
|
|
2457
|
-
}
|
|
2458
|
-
}
|
|
2459
|
-
}
|
|
2460
|
-
}
|
|
2461
|
-
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
|
2518
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2462
2519
|
constexpr data_layout input_layout = get_input_data_layout();
|
|
2463
2520
|
typedef tile<16, 4, int, input_layout> tile_A;
|
|
2464
2521
|
typedef tile<16, 4, int, input_layout> tile_B;
|
|
@@ -2484,13 +2541,13 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
|
|
2484
2541
|
tile_A A[ntx];
|
|
2485
2542
|
#pragma unroll
|
|
2486
2543
|
for (int n = 0; n < ntx; ++n) {
|
|
2487
|
-
|
|
2544
|
+
load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
|
|
2488
2545
|
}
|
|
2489
2546
|
|
|
2490
2547
|
#pragma unroll
|
|
2491
2548
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
2492
2549
|
tile_B B;
|
|
2493
|
-
|
|
2550
|
+
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
2494
2551
|
|
|
2495
2552
|
const int j = j0 + tile_C::get_j(0);
|
|
2496
2553
|
const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
|
|
@@ -2715,14 +2772,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2715
2772
|
|
|
2716
2773
|
#pragma unroll
|
|
2717
2774
|
for (int l = 0; l < QR2_XXS; ++l) {
|
|
2718
|
-
const
|
|
2719
|
-
const
|
|
2775
|
+
const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]];
|
|
2776
|
+
const uint32_t signs = unpack_ksigns(aux32 >> (7 * l));
|
|
2720
2777
|
|
|
2721
|
-
const int signs0 = __vcmpne4(
|
|
2722
|
-
const int grid0 = __vsub4(grid_pos
|
|
2778
|
+
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
|
2779
|
+
const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
|
|
2723
2780
|
|
|
2724
|
-
const int signs1 = __vcmpne4(
|
|
2725
|
-
const int grid1 = __vsub4(grid_pos
|
|
2781
|
+
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
|
2782
|
+
const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
|
|
2726
2783
|
|
|
2727
2784
|
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2728
2785
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
|
|
@@ -2733,12 +2790,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2733
2790
|
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2734
2791
|
}
|
|
2735
2792
|
|
|
2736
|
-
const int ls = aux32 >>
|
|
2793
|
+
const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
|
|
2737
2794
|
const float d = bxi->d;
|
|
2738
2795
|
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2739
|
-
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (
|
|
2796
|
+
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
|
|
2740
2797
|
#else
|
|
2741
|
-
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (
|
|
2798
|
+
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
|
|
2742
2799
|
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2743
2800
|
}
|
|
2744
2801
|
}
|
|
@@ -2776,11 +2833,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2776
2833
|
|
|
2777
2834
|
#pragma unroll
|
|
2778
2835
|
for (int l = 0; l < QR2_XS; ++l) {
|
|
2779
|
-
const
|
|
2780
|
-
const uint32_t
|
|
2836
|
+
const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF];
|
|
2837
|
+
const uint32_t signs = unpack_ksigns(q2[l] >> 9);
|
|
2781
2838
|
|
|
2782
|
-
const int
|
|
2783
|
-
const int
|
|
2839
|
+
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
|
2840
|
+
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
|
|
2841
|
+
|
|
2842
|
+
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
|
2843
|
+
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
|
|
2784
2844
|
|
|
2785
2845
|
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2786
2846
|
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
|
|
@@ -2904,11 +2964,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2904
2964
|
#pragma unroll
|
|
2905
2965
|
for (int l = 0; l < QR3_XXS; ++l) {
|
|
2906
2966
|
const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
|
|
2967
|
+
const uint32_t signs = unpack_ksigns(aux32 >> (7*l));
|
|
2907
2968
|
|
|
2908
|
-
const int
|
|
2969
|
+
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
|
2970
|
+
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
|
|
2909
2971
|
|
|
2910
|
-
const int
|
|
2911
|
-
const int grid_h = __vsub4(grid_pos.y ^
|
|
2972
|
+
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
|
2973
|
+
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
|
|
2912
2974
|
|
|
2913
2975
|
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2914
2976
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
|
|
@@ -3203,6 +3265,14 @@ static __device__ __forceinline__ void mmq_write_back_mma(
|
|
|
3203
3265
|
template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
|
|
3204
3266
|
struct mmq_type_traits;
|
|
3205
3267
|
|
|
3268
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
|
3269
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q1_0> {
|
|
3270
|
+
static constexpr int vdr = VDR_Q1_0_Q8_1_MMQ;
|
|
3271
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q1_0<mmq_y, need_check>;
|
|
3272
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
|
3273
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
|
3274
|
+
};
|
|
3275
|
+
|
|
3206
3276
|
template <int mmq_x, int mmq_y, bool need_check>
|
|
3207
3277
|
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
|
|
3208
3278
|
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
|
|
@@ -3248,7 +3318,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
|
|
|
3248
3318
|
static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
|
|
3249
3319
|
#ifdef BLACKWELL_MMA_AVAILABLE
|
|
3250
3320
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
|
|
3251
|
-
static constexpr vec_dot_mmq_t vec_dot_mma =
|
|
3321
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_MXFP4>;
|
|
3252
3322
|
#else
|
|
3253
3323
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
|
|
3254
3324
|
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
|
@@ -3256,6 +3326,19 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
|
|
|
3256
3326
|
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
|
3257
3327
|
};
|
|
3258
3328
|
|
|
3329
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
|
3330
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_NVFP4> {
|
|
3331
|
+
static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ;
|
|
3332
|
+
#ifdef BLACKWELL_MMA_AVAILABLE
|
|
3333
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4_nvfp4<mmq_y, need_check>;
|
|
3334
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_NVFP4>;
|
|
3335
|
+
#else
|
|
3336
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4<mmq_y, need_check>;
|
|
3337
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
|
|
3338
|
+
#endif // BLACKWELL_MMA_AVAILABLE
|
|
3339
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
|
|
3340
|
+
};
|
|
3341
|
+
|
|
3259
3342
|
template <int mmq_x, int mmq_y, bool need_check>
|
|
3260
3343
|
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
|
|
3261
3344
|
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
|
|
@@ -3387,7 +3470,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
|
|
|
3387
3470
|
|
|
3388
3471
|
#if defined(BLACKWELL_MMA_AVAILABLE)
|
|
3389
3472
|
// FP4 tile stores 8 blocks
|
|
3390
|
-
constexpr int ne_block = (type == GGML_TYPE_MXFP4) ?
|
|
3473
|
+
constexpr int ne_block = (type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4) ? QK_K : 4 * QK8_1;
|
|
3391
3474
|
#else
|
|
3392
3475
|
constexpr int ne_block = 4 * QK8_1;
|
|
3393
3476
|
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
|
@@ -3459,10 +3542,10 @@ template <ggml_type type, int mmq_x, bool need_check>
|
|
|
3459
3542
|
static __global__ void mul_mat_q(
|
|
3460
3543
|
const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
|
|
3461
3544
|
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
|
|
3462
|
-
const
|
|
3463
|
-
const
|
|
3464
|
-
const
|
|
3465
|
-
const
|
|
3545
|
+
const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
|
|
3546
|
+
const uint3 channel_ratio, const uint3 nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
|
3547
|
+
const uint3 sample_ratio, const uint3 nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
|
3548
|
+
const uint3 ntx) {
|
|
3466
3549
|
|
|
3467
3550
|
// Skip unused template specializations for faster compilation:
|
|
3468
3551
|
if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
|
|
@@ -3476,8 +3559,7 @@ static __global__ void mul_mat_q(
|
|
|
3476
3559
|
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
|
3477
3560
|
constexpr int mmq_y = get_mmq_y_device();
|
|
3478
3561
|
|
|
3479
|
-
const
|
|
3480
|
-
const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
|
|
3562
|
+
const uint32_t nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
|
|
3481
3563
|
|
|
3482
3564
|
// Initialize the ids for writing back data with just the index.
|
|
3483
3565
|
// For regular matrix multiplications this is never changed.
|
|
@@ -3498,8 +3580,9 @@ static __global__ void mul_mat_q(
|
|
|
3498
3580
|
// On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
|
|
3499
3581
|
#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
|
3500
3582
|
{
|
|
3501
|
-
const
|
|
3502
|
-
const int
|
|
3583
|
+
const uint2 tmp2 = fast_div_modulo(blockIdx.z, nchannels_y);
|
|
3584
|
+
const int wt = tmp2.x;
|
|
3585
|
+
const int zt = tmp2.y;
|
|
3503
3586
|
const int jt = blockIdx.y;
|
|
3504
3587
|
const int it = blockIdx.x;
|
|
3505
3588
|
|
|
@@ -3542,40 +3625,40 @@ static __global__ void mul_mat_q(
|
|
|
3542
3625
|
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
|
|
3543
3626
|
const int tile_y_max_j = col_diff - jt*mmq_x - 1;
|
|
3544
3627
|
|
|
3545
|
-
const int offset_x = (wt
|
|
3628
|
+
const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
|
3546
3629
|
|
|
3547
3630
|
constexpr bool fixup = false;
|
|
3548
3631
|
mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
|
|
3549
3632
|
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
|
|
3550
|
-
tile_x_max_i, tile_y_max_j, 0,
|
|
3633
|
+
tile_x_max_i, tile_y_max_j, 0, blocks_per_ne00.z);
|
|
3551
3634
|
return;
|
|
3552
3635
|
}
|
|
3553
|
-
#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
|
3636
|
+
#endif // (defined(GGML_USE_HIP) && !defined(CDNA4) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
|
3554
3637
|
|
|
3555
|
-
constexpr int ITER_K
|
|
3556
|
-
|
|
3557
|
-
const int64_t blocks_per_ne00 = ncols_x / qk;
|
|
3558
|
-
constexpr int blocks_per_iter = ITER_K / qk;
|
|
3638
|
+
constexpr int ITER_K = get_iter_k(type);
|
|
3639
|
+
constexpr int blocks_per_iter = ITER_K / qk;
|
|
3559
3640
|
|
|
3560
3641
|
// kbc == k block continuous, current index in continuous ijk space.
|
|
3561
|
-
|
|
3562
|
-
|
|
3642
|
+
int kbc = int64_t(blockIdx.x) *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
|
|
3643
|
+
int kbc_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
|
|
3563
3644
|
|
|
3564
|
-
kbc -= (kbc
|
|
3565
|
-
kbc_stop -= (kbc_stop
|
|
3645
|
+
kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter;
|
|
3646
|
+
kbc_stop -= fastmodulo(kbc_stop, blocks_per_ne00) % blocks_per_iter;
|
|
3566
3647
|
|
|
3567
3648
|
// kb0 == k index when doing the matrix multiplication for an output tile.
|
|
3568
|
-
int kb0_start = kbc
|
|
3569
|
-
int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
|
|
3570
|
-
while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
|
|
3571
|
-
int tmp = kbc;
|
|
3572
|
-
|
|
3573
|
-
|
|
3574
|
-
|
|
3575
|
-
|
|
3576
|
-
const int zt =
|
|
3577
|
-
tmp
|
|
3578
|
-
|
|
3649
|
+
int kb0_start = fastmodulo(kbc, blocks_per_ne00);
|
|
3650
|
+
int kb0_stop = min(blocks_per_ne00.z, uint32_t(kb0_start + kbc_stop - kbc));
|
|
3651
|
+
while (kbc < kbc_stop && kb0_stop == int(blocks_per_ne00.z)) {
|
|
3652
|
+
int tmp = fastdiv(kbc, blocks_per_ne00);
|
|
3653
|
+
uint2 tmp2 = fast_div_modulo(tmp, ntx);
|
|
3654
|
+
const int jt = tmp2.y;
|
|
3655
|
+
tmp = tmp2.x;
|
|
3656
|
+
tmp2 = fast_div_modulo(tmp, nchannels_y);
|
|
3657
|
+
const int zt = tmp2.y;
|
|
3658
|
+
tmp = tmp2.x;
|
|
3659
|
+
tmp2 = fast_div_modulo(tmp, nsamples_y);
|
|
3660
|
+
const int wt = tmp2.y;
|
|
3661
|
+
const int it = tmp2.x;
|
|
3579
3662
|
|
|
3580
3663
|
// Defaults for regular matrix multiplication:
|
|
3581
3664
|
int col_low = 0;
|
|
@@ -3593,11 +3676,11 @@ static __global__ void mul_mat_q(
|
|
|
3593
3676
|
offset_dst = 0;
|
|
3594
3677
|
|
|
3595
3678
|
if (jt*mmq_x >= col_diff) {
|
|
3596
|
-
kbc += blocks_per_ne00;
|
|
3597
|
-
kbc -= kbc
|
|
3679
|
+
kbc += blocks_per_ne00.z;
|
|
3680
|
+
kbc -= fastmodulo(kbc, blocks_per_ne00);
|
|
3598
3681
|
|
|
3599
3682
|
kb0_start = 0;
|
|
3600
|
-
kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
|
|
3683
|
+
kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc));
|
|
3601
3684
|
|
|
3602
3685
|
continue;
|
|
3603
3686
|
}
|
|
@@ -3622,32 +3705,34 @@ static __global__ void mul_mat_q(
|
|
|
3622
3705
|
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
|
|
3623
3706
|
const int tile_y_max_j = col_diff - jt*mmq_x - 1;
|
|
3624
3707
|
|
|
3625
|
-
const int offset_x = (wt
|
|
3708
|
+
const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
|
3626
3709
|
|
|
3627
3710
|
constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
|
3628
3711
|
mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
|
|
3629
3712
|
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
|
|
3630
3713
|
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
|
|
3631
3714
|
|
|
3632
|
-
kbc += blocks_per_ne00;
|
|
3633
|
-
kbc -= kbc
|
|
3715
|
+
kbc += blocks_per_ne00.z;
|
|
3716
|
+
kbc -= fastmodulo(kbc, blocks_per_ne00);
|
|
3634
3717
|
|
|
3635
3718
|
kb0_start = 0;
|
|
3636
|
-
kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
|
|
3719
|
+
kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc));
|
|
3637
3720
|
}
|
|
3638
3721
|
|
|
3639
3722
|
if (kbc >= kbc_stop) {
|
|
3640
3723
|
return;
|
|
3641
3724
|
}
|
|
3642
3725
|
|
|
3643
|
-
int tmp = kbc;
|
|
3644
|
-
|
|
3645
|
-
|
|
3646
|
-
|
|
3647
|
-
|
|
3648
|
-
const int zt =
|
|
3649
|
-
tmp
|
|
3650
|
-
|
|
3726
|
+
int tmp = fastdiv(kbc, blocks_per_ne00);
|
|
3727
|
+
uint2 tmp2 = fast_div_modulo(tmp, ntx);
|
|
3728
|
+
const int jt = tmp2.y;
|
|
3729
|
+
tmp = tmp2.x;
|
|
3730
|
+
tmp2 = fast_div_modulo(tmp, nchannels_y);
|
|
3731
|
+
const int zt = tmp2.y;
|
|
3732
|
+
tmp = tmp2.x;
|
|
3733
|
+
tmp2 = fast_div_modulo(tmp, nsamples_y);
|
|
3734
|
+
const int wt = tmp2.y;
|
|
3735
|
+
const int it = tmp2.x;
|
|
3651
3736
|
|
|
3652
3737
|
// Defaults for regular matrix multiplication:
|
|
3653
3738
|
int col_low = 0;
|
|
@@ -3689,7 +3774,7 @@ static __global__ void mul_mat_q(
|
|
|
3689
3774
|
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
|
|
3690
3775
|
const int tile_y_max_j = col_diff - jt*mmq_x - 1;
|
|
3691
3776
|
|
|
3692
|
-
const int offset_x = (wt
|
|
3777
|
+
const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
|
3693
3778
|
|
|
3694
3779
|
constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
|
3695
3780
|
mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
|
|
@@ -3697,40 +3782,38 @@ static __global__ void mul_mat_q(
|
|
|
3697
3782
|
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
|
|
3698
3783
|
}
|
|
3699
3784
|
|
|
3700
|
-
|
|
3701
3785
|
template <ggml_type type, int mmq_x, bool need_check>
|
|
3786
|
+
__launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device()/2, 1)
|
|
3702
3787
|
static __global__ void mul_mat_q_stream_k_fixup(
|
|
3703
|
-
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst,
|
|
3704
|
-
|
|
3705
|
-
const int
|
|
3706
|
-
const int
|
|
3707
|
-
constexpr int
|
|
3708
|
-
constexpr int
|
|
3709
|
-
constexpr int
|
|
3710
|
-
|
|
3711
|
-
constexpr int blocks_per_iter = ITER_K / qk;
|
|
3712
|
-
const int64_t blocks_per_ne00 = ncols_x / qk;
|
|
3788
|
+
const int32_t * __restrict__ ids_dst, const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
|
|
3789
|
+
float * __restrict__ tmp_last_tile, const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst,
|
|
3790
|
+
const int stride_col_dst, const uint3 nchannels_y, const int stride_channel_dst, const uint3 nsamples_y,
|
|
3791
|
+
const int stride_sample_dst, const uint3 ntx) {
|
|
3792
|
+
constexpr int mmq_y = get_mmq_y_device();
|
|
3793
|
+
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
|
3794
|
+
constexpr int ITER_K = get_iter_k(type);
|
|
3795
|
+
constexpr int blocks_per_iter = ITER_K / qk;
|
|
3713
3796
|
|
|
3714
|
-
constexpr int nwarps = mmq_get_nwarps_device();
|
|
3797
|
+
constexpr int nwarps = mmq_get_nwarps_device()/2;
|
|
3715
3798
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
3716
3799
|
|
|
3717
|
-
float sum[mmq_x
|
|
3800
|
+
float sum[mmq_x / nwarps] = {0.0f};
|
|
3801
|
+
const int i = blockIdx.y*warp_size + threadIdx.x;
|
|
3718
3802
|
|
|
3719
|
-
const int
|
|
3720
|
-
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
|
|
3803
|
+
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
|
|
3721
3804
|
|
|
3722
3805
|
const int bidx0 = blockIdx.x;
|
|
3723
3806
|
|
|
3724
3807
|
// kbc == k block continuous, current index in continuous ijk space.
|
|
3725
|
-
|
|
3726
|
-
|
|
3808
|
+
int kbc0 = int64_t(blockIdx.x) *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
|
|
3809
|
+
int kbc0_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
|
|
3727
3810
|
|
|
3728
|
-
kbc0 -= (kbc0
|
|
3729
|
-
kbc0_stop -= (kbc0_stop
|
|
3811
|
+
kbc0 -= fastmodulo(kbc0, blocks_per_ne00) % blocks_per_iter;
|
|
3812
|
+
kbc0_stop -= fastmodulo(kbc0_stop, blocks_per_ne00) % blocks_per_iter;
|
|
3730
3813
|
|
|
3731
3814
|
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
|
3732
|
-
const bool wrote_beginning_of_tile = kbc0
|
|
3733
|
-
const bool did_not_write_last = kbc0
|
|
3815
|
+
const bool wrote_beginning_of_tile = fastmodulo(kbc0, blocks_per_ne00) == 0;
|
|
3816
|
+
const bool did_not_write_last = fastdiv(kbc0, blocks_per_ne00) == fastdiv(kbc0_stop, blocks_per_ne00) && fastmodulo(kbc0_stop, blocks_per_ne00) != 0;
|
|
3734
3817
|
if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
|
|
3735
3818
|
return;
|
|
3736
3819
|
}
|
|
@@ -3739,11 +3822,11 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
|
|
3739
3822
|
|
|
3740
3823
|
// Iterate over previous blocks and sum up partial sums written to fixup buffer.
|
|
3741
3824
|
// All CUDA blocks that get here must have a previous block that needs a fixup.
|
|
3742
|
-
|
|
3743
|
-
|
|
3825
|
+
int bidx = bidx0 - 1;
|
|
3826
|
+
int kbc_stop = kbc0;
|
|
3744
3827
|
while(true) {
|
|
3745
|
-
|
|
3746
|
-
kbc -= (kbc
|
|
3828
|
+
int kbc = int64_t(bidx)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
|
|
3829
|
+
kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter;
|
|
3747
3830
|
|
|
3748
3831
|
if (kbc == kbc_stop) { // Did not have any data.
|
|
3749
3832
|
bidx--;
|
|
@@ -3753,20 +3836,16 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
|
|
3753
3836
|
|
|
3754
3837
|
any_fixup = true;
|
|
3755
3838
|
|
|
3839
|
+
|
|
3756
3840
|
#pragma unroll
|
|
3757
3841
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
3758
3842
|
const int j = j0 + threadIdx.y;
|
|
3759
3843
|
|
|
3760
|
-
|
|
3761
|
-
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
|
3762
|
-
const int i = i0 + threadIdx.x;
|
|
3763
|
-
|
|
3764
|
-
sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
|
|
3765
|
-
}
|
|
3844
|
+
sum[j0/nwarps] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
|
|
3766
3845
|
}
|
|
3767
3846
|
|
|
3768
3847
|
// If this block started in a previous tile we are done and don't need to combine additional partial results.
|
|
3769
|
-
if (kbc
|
|
3848
|
+
if (fastmodulo(kbc, blocks_per_ne00) == 0 || fastdiv(kbc, blocks_per_ne00) < fastdiv(kbc0, blocks_per_ne00)) {
|
|
3770
3849
|
break;
|
|
3771
3850
|
}
|
|
3772
3851
|
bidx--;
|
|
@@ -3777,14 +3856,16 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
|
|
3777
3856
|
return;
|
|
3778
3857
|
}
|
|
3779
3858
|
|
|
3780
|
-
int tmp = kbc0;
|
|
3781
|
-
|
|
3782
|
-
|
|
3783
|
-
|
|
3784
|
-
|
|
3785
|
-
const int zt =
|
|
3786
|
-
tmp
|
|
3787
|
-
|
|
3859
|
+
int tmp = fastdiv(kbc0, blocks_per_ne00);
|
|
3860
|
+
uint2 tmp2 = fast_div_modulo(tmp, ntx);
|
|
3861
|
+
const int jt = tmp2.y;
|
|
3862
|
+
tmp = tmp2.x;
|
|
3863
|
+
tmp2 = fast_div_modulo(tmp, nchannels_y);
|
|
3864
|
+
const int zt = tmp2.y;
|
|
3865
|
+
tmp = tmp2.x;
|
|
3866
|
+
tmp2 = fast_div_modulo(tmp, nsamples_y);
|
|
3867
|
+
const int wt = tmp2.y;
|
|
3868
|
+
const int it = tmp2.x;
|
|
3788
3869
|
|
|
3789
3870
|
if (!ids_dst) {
|
|
3790
3871
|
const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
|
|
@@ -3792,6 +3873,9 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
|
|
3792
3873
|
|
|
3793
3874
|
const int i_max = nrows_x - it*mmq_y - 1;
|
|
3794
3875
|
const int j_max = ncols_dst - jt*mmq_x - 1;
|
|
3876
|
+
if (need_check && i > i_max) {
|
|
3877
|
+
return;
|
|
3878
|
+
}
|
|
3795
3879
|
|
|
3796
3880
|
#pragma unroll
|
|
3797
3881
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
@@ -3801,16 +3885,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
|
|
3801
3885
|
return;
|
|
3802
3886
|
}
|
|
3803
3887
|
|
|
3804
|
-
|
|
3805
|
-
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
|
3806
|
-
const int i = i0 + threadIdx.x;
|
|
3807
|
-
|
|
3808
|
-
if (need_check && i > i_max) {
|
|
3809
|
-
continue;
|
|
3810
|
-
}
|
|
3811
|
-
|
|
3812
|
-
dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
|
|
3813
|
-
}
|
|
3888
|
+
dst[j*stride_col_dst + i] += sum[j0/nwarps];
|
|
3814
3889
|
}
|
|
3815
3890
|
return;
|
|
3816
3891
|
}
|
|
@@ -3830,6 +3905,9 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
|
|
3830
3905
|
|
|
3831
3906
|
const int i_max = nrows_x - it*mmq_y - 1;
|
|
3832
3907
|
const int j_max = col_diff - jt*mmq_x - 1;
|
|
3908
|
+
if (need_check && i > i_max) {
|
|
3909
|
+
return;
|
|
3910
|
+
}
|
|
3833
3911
|
|
|
3834
3912
|
#pragma unroll
|
|
3835
3913
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
@@ -3839,16 +3917,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
|
|
3839
3917
|
return;
|
|
3840
3918
|
}
|
|
3841
3919
|
|
|
3842
|
-
|
|
3843
|
-
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
|
3844
|
-
const int i = i0 + threadIdx.x;
|
|
3845
|
-
|
|
3846
|
-
if (need_check && i > i_max) {
|
|
3847
|
-
continue;
|
|
3848
|
-
}
|
|
3849
|
-
|
|
3850
|
-
dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
|
|
3851
|
-
}
|
|
3920
|
+
dst[ids_dst_shared[j]*stride_col_dst + i] += sum[j0/nwarps];
|
|
3852
3921
|
}
|
|
3853
3922
|
}
|
|
3854
3923
|
|
|
@@ -3896,29 +3965,44 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
|
|
3896
3965
|
const int channel_ratio = args.nchannels_y / args.nchannels_x;
|
|
3897
3966
|
const int sample_ratio = args.nsamples_y / args.nsamples_x;
|
|
3898
3967
|
|
|
3968
|
+
const uint3 blocks_per_ne00_fd = init_fastdiv_values(args.ncols_x / ggml_cuda_type_traits<type>::qk);
|
|
3969
|
+
const uint3 ntx_fd = init_fastdiv_values(ntx);
|
|
3970
|
+
const uint3 nchannels_y_fd = init_fastdiv_values(args.nchannels_y);
|
|
3971
|
+
const uint3 nsamples_y_fd = init_fastdiv_values(args.nsamples_y);
|
|
3972
|
+
const uint3 channel_ratio_fd = init_fastdiv_values(channel_ratio);
|
|
3973
|
+
const uint3 sample_ratio_fd = init_fastdiv_values(sample_ratio);
|
|
3974
|
+
|
|
3899
3975
|
if (!args.use_stream_k) {
|
|
3900
3976
|
if (args.nrows_x % mmq_y == 0) {
|
|
3901
3977
|
constexpr bool need_check = false;
|
|
3902
3978
|
mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
|
|
3903
3979
|
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
|
|
3904
|
-
|
|
3905
|
-
|
|
3906
|
-
|
|
3907
|
-
|
|
3980
|
+
blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
|
3981
|
+
channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
|
3982
|
+
sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
|
3983
|
+
ntx_fd);
|
|
3908
3984
|
} else {
|
|
3909
3985
|
constexpr bool need_check = true;
|
|
3910
3986
|
mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
|
|
3911
3987
|
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
|
|
3912
|
-
|
|
3913
|
-
|
|
3914
|
-
|
|
3915
|
-
|
|
3988
|
+
blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
|
3989
|
+
channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
|
3990
|
+
sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
|
3991
|
+
ntx_fd);
|
|
3916
3992
|
}
|
|
3917
3993
|
return;
|
|
3918
3994
|
}
|
|
3919
3995
|
|
|
3920
|
-
|
|
3921
|
-
|
|
3996
|
+
// For the stream-k kernel it is possible to run it with tiling by setting the number of CUDA blocks equal to the number of tiles.
|
|
3997
|
+
// This is worthwhile if the efficiency of tiling is high and skipping the fixup kernel is more important.
|
|
3998
|
+
const int ntiles_dst = ntx * nty * ntzw;
|
|
3999
|
+
const int tiles_nwaves = (ntiles_dst + nsm - 1) / nsm;
|
|
4000
|
+
const int tiles_efficiency_percent = 100 * ntiles_dst / (nsm*tiles_nwaves);
|
|
4001
|
+
const dim3 block_nums_stream_k(GGML_CUDA_CC_IS_NVIDIA(cc) && tiles_efficiency_percent >= 90 ? ntiles_dst : nsm, 1, 1);
|
|
4002
|
+
|
|
4003
|
+
GGML_ASSERT(ntiles_dst * blocks_per_ne00_fd.z < (1 << 30)); // Assert that variable kbc will not overflow.
|
|
4004
|
+
|
|
4005
|
+
const bool fixup_needed = ntiles_dst % block_nums_stream_k.x != 0;
|
|
3922
4006
|
|
|
3923
4007
|
ggml_cuda_pool & pool = ctx.pool(id);
|
|
3924
4008
|
ggml_cuda_pool_alloc<float> tmp_fixup(pool);
|
|
@@ -3926,40 +4010,45 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
|
|
3926
4010
|
tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y);
|
|
3927
4011
|
}
|
|
3928
4012
|
|
|
4013
|
+
const dim3 block_nums_fixup(block_nums_stream_k.x, mmq_y/warp_size, 1);
|
|
4014
|
+
const dim3 block_dims_fixup(block_dims.x, block_dims.y/2, block_dims.z);
|
|
4015
|
+
|
|
3929
4016
|
if (args.nrows_x % mmq_y == 0) {
|
|
3930
4017
|
constexpr bool need_check = false;
|
|
3931
4018
|
mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
|
|
3932
4019
|
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
|
|
3933
|
-
|
|
3934
|
-
|
|
3935
|
-
|
|
3936
|
-
|
|
4020
|
+
blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
|
4021
|
+
channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
|
4022
|
+
sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
|
4023
|
+
ntx_fd);
|
|
3937
4024
|
|
|
3938
4025
|
if (!fixup_needed) {
|
|
3939
4026
|
return;
|
|
3940
4027
|
}
|
|
3941
4028
|
|
|
3942
|
-
|
|
3943
|
-
|
|
3944
|
-
|
|
3945
|
-
args.
|
|
4029
|
+
CUDA_CHECK(cudaGetLastError());
|
|
4030
|
+
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_fixup, block_dims_fixup, 0, stream>>>
|
|
4031
|
+
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst,
|
|
4032
|
+
args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst,
|
|
4033
|
+
ntx_fd);
|
|
3946
4034
|
} else {
|
|
3947
4035
|
constexpr bool need_check = true;
|
|
3948
4036
|
mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
|
|
3949
4037
|
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
|
|
3950
|
-
|
|
3951
|
-
|
|
3952
|
-
|
|
3953
|
-
|
|
4038
|
+
blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
|
4039
|
+
channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
|
4040
|
+
sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
|
4041
|
+
ntx_fd);
|
|
3954
4042
|
|
|
3955
4043
|
if (!fixup_needed) {
|
|
3956
4044
|
return;
|
|
3957
4045
|
}
|
|
3958
4046
|
|
|
3959
|
-
|
|
3960
|
-
|
|
3961
|
-
|
|
3962
|
-
args.
|
|
4047
|
+
CUDA_CHECK(cudaGetLastError());
|
|
4048
|
+
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_fixup, block_dims_fixup, 0, stream>>>
|
|
4049
|
+
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst,
|
|
4050
|
+
args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst,
|
|
4051
|
+
ntx_fd);
|
|
3963
4052
|
}
|
|
3964
4053
|
}
|
|
3965
4054
|
|
|
@@ -4057,6 +4146,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
|
|
|
4057
4146
|
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
|
|
4058
4147
|
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
|
|
4059
4148
|
extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
|
|
4149
|
+
extern DECL_MMQ_CASE(GGML_TYPE_NVFP4);
|
|
4060
4150
|
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
|
|
4061
4151
|
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
|
|
4062
4152
|
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
|
|
@@ -4083,3 +4173,4 @@ void ggml_cuda_op_mul_mat_q(
|
|
|
4083
4173
|
const int64_t src1_padded_row_size, cudaStream_t stream);
|
|
4084
4174
|
|
|
4085
4175
|
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);
|
|
4176
|
+
|