whispercpp 1.3.6 → 1.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/README.md +38 -5
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -8
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +36 -42
- data/ext/ruby_whisper.h +135 -0
- data/ext/ruby_whisper_context.c +107 -28
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -65
- data/ext/ruby_whisper_segment.c +6 -6
- data/ext/ruby_whisper_transcribe.cpp +42 -15
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +1 -1
- data/ext/sources/examples/cli/cli.cpp +43 -9
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +199 -163
- data/ext/sources/ggml/CMakeLists.txt +21 -13
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +72 -10
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-rpc.h +3 -3
- data/ext/sources/ggml/include/ggml.h +101 -9
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +22 -5
- data/ext/sources/ggml/src/ggml-alloc.c +5 -1
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
- data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
- data/ext/sources/ggml/src/ggml-impl.h +6 -1
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
- data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +289 -114
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
- data/ext/sources/ggml/src/ggml.c +110 -28
- data/ext/sources/ggml/src/gguf.cpp +173 -28
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +56 -12
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +411 -62
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +24 -6
- data/whispercpp.gemspec +2 -2
- metadata +215 -281
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
- data/ext/sources/examples/talk-llama/llama-context.h +0 -359
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
- data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
- data/ext/sources/examples/talk-llama/llama-model.h +0 -597
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
- data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
- data/ext/sources/examples/talk-llama/llama.h +0 -1573
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -704
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
- /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
|
@@ -10,9 +10,9 @@
|
|
|
10
10
|
using namespace ggml_cuda_mma;
|
|
11
11
|
|
|
12
12
|
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
|
|
13
|
-
#define MMQ_ITER_K
|
|
14
|
-
#define
|
|
15
|
-
#define MMQ_NWARPS
|
|
13
|
+
#define MMQ_ITER_K 256
|
|
14
|
+
#define MMQ_ITER_K_FP4 512
|
|
15
|
+
#define MMQ_NWARPS 8
|
|
16
16
|
|
|
17
17
|
typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
|
|
18
18
|
typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00);
|
|
@@ -46,9 +46,12 @@ struct block_q8_1_mmq {
|
|
|
46
46
|
int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
|
|
47
47
|
};
|
|
48
48
|
|
|
49
|
+
// this struct is used for fp4 data types (currently only used for Blackwell)
|
|
50
|
+
// mxfp4 has block size 32, each int32 of d4 contains 2 e8m0 scales in the lower 16 bits
|
|
51
|
+
// nvfp4 has block size 16, each int32 of d4 contains 4 ue4m3 scales
|
|
49
52
|
struct block_fp4_mmq {
|
|
50
|
-
uint32_t d4[4];
|
|
51
|
-
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte)
|
|
53
|
+
uint32_t d4[4];
|
|
54
|
+
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte)
|
|
52
55
|
};
|
|
53
56
|
|
|
54
57
|
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
|
|
@@ -57,6 +60,8 @@ static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected b
|
|
|
57
60
|
|
|
58
61
|
static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
|
|
59
62
|
switch (type_x) {
|
|
63
|
+
case GGML_TYPE_Q1_0:
|
|
64
|
+
return MMQ_Q8_1_DS_LAYOUT_D4;
|
|
60
65
|
case GGML_TYPE_Q4_0:
|
|
61
66
|
case GGML_TYPE_Q4_1:
|
|
62
67
|
return MMQ_Q8_1_DS_LAYOUT_DS4;
|
|
@@ -68,6 +73,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
|
|
|
68
73
|
return MMQ_Q8_1_DS_LAYOUT_D4;
|
|
69
74
|
case GGML_TYPE_MXFP4:
|
|
70
75
|
return MMQ_Q8_1_DS_LAYOUT_D4;
|
|
76
|
+
case GGML_TYPE_NVFP4:
|
|
77
|
+
return MMQ_Q8_1_DS_LAYOUT_D4;
|
|
71
78
|
case GGML_TYPE_Q2_K:
|
|
72
79
|
return MMQ_Q8_1_DS_LAYOUT_D2S6;
|
|
73
80
|
case GGML_TYPE_Q3_K:
|
|
@@ -100,7 +107,7 @@ struct tile_x_sizes {
|
|
|
100
107
|
};
|
|
101
108
|
|
|
102
109
|
static int get_mmq_x_max_host(const int cc) {
|
|
103
|
-
return (
|
|
110
|
+
return (turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
|
|
104
111
|
GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
|
|
105
112
|
#ifdef GGML_CUDA_FORCE_MMQ
|
|
106
113
|
128 : 64;
|
|
@@ -110,9 +117,9 @@ static int get_mmq_x_max_host(const int cc) {
|
|
|
110
117
|
}
|
|
111
118
|
|
|
112
119
|
static constexpr __device__ int get_mmq_x_max_device() {
|
|
113
|
-
#if defined(
|
|
120
|
+
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
114
121
|
return 128;
|
|
115
|
-
#else // defined(
|
|
122
|
+
#else // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
116
123
|
|
|
117
124
|
#if defined(GGML_USE_HIP)
|
|
118
125
|
return 64;
|
|
@@ -139,10 +146,11 @@ static int get_mmq_y_host(const int cc) {
|
|
|
139
146
|
|
|
140
147
|
static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
|
|
141
148
|
#if defined(BLACKWELL_MMA_AVAILABLE)
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
149
|
+
if (type == GGML_TYPE_NVFP4 || type == GGML_TYPE_MXFP4) {
|
|
150
|
+
return MMQ_ITER_K_FP4;
|
|
151
|
+
}
|
|
145
152
|
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
|
153
|
+
return MMQ_ITER_K;
|
|
146
154
|
}
|
|
147
155
|
|
|
148
156
|
static constexpr __device__ int get_mmq_y_device() {
|
|
@@ -183,12 +191,14 @@ static constexpr __device__ int get_mmq_y_device() {
|
|
|
183
191
|
|
|
184
192
|
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
|
|
185
193
|
switch (type) {
|
|
194
|
+
case GGML_TYPE_Q1_0: return MMQ_DP4A_TXS_Q8_0;
|
|
186
195
|
case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
|
|
187
196
|
case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
|
|
188
197
|
case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
|
|
189
198
|
case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
|
|
190
199
|
case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
|
|
191
200
|
case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
|
|
201
|
+
case GGML_TYPE_NVFP4: return MMQ_DP4A_TXS_Q8_0_16;
|
|
192
202
|
case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
|
|
193
203
|
case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
|
|
194
204
|
case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
|
|
@@ -206,12 +216,13 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
|
|
|
206
216
|
}
|
|
207
217
|
}
|
|
208
218
|
|
|
209
|
-
#define MMQ_MMA_TILE_X_K_Q8_0
|
|
210
|
-
#define MMQ_MMA_TILE_X_K_FP4
|
|
211
|
-
#define
|
|
212
|
-
#define
|
|
213
|
-
#define
|
|
214
|
-
#define
|
|
219
|
+
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
|
220
|
+
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4 and NVFP4 Blackwell
|
|
221
|
+
#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4 Generic
|
|
222
|
+
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
|
223
|
+
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
|
|
224
|
+
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
|
|
225
|
+
#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
|
|
215
226
|
|
|
216
227
|
static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
|
|
217
228
|
static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
|
|
@@ -220,9 +231,12 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
|
|
|
220
231
|
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
|
|
221
232
|
static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
|
|
222
233
|
static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
|
|
234
|
+
static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
|
|
235
|
+
|
|
223
236
|
|
|
224
237
|
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
|
225
238
|
switch (type) {
|
|
239
|
+
case GGML_TYPE_Q1_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
226
240
|
case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
227
241
|
case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
|
|
228
242
|
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
@@ -230,6 +244,11 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
|
|
230
244
|
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
231
245
|
// tile sizes are the same for Q8_1 and FP4 for blackwell
|
|
232
246
|
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
|
|
247
|
+
#if defined(BLACKWELL_MMA_AVAILABLE)
|
|
248
|
+
case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_FP4;
|
|
249
|
+
#else
|
|
250
|
+
case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4;
|
|
251
|
+
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
|
233
252
|
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
|
|
234
253
|
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
|
|
235
254
|
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
|
|
@@ -295,6 +314,87 @@ static constexpr __device__ int mmq_get_nwarps_device() {
|
|
|
295
314
|
|
|
296
315
|
// ------------------------------------------------------------
|
|
297
316
|
|
|
317
|
+
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q1_0(
|
|
318
|
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
|
319
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
|
320
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
321
|
+
|
|
322
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
323
|
+
int * x_qs = (int *) x_tile;
|
|
324
|
+
float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
|
|
325
|
+
#else
|
|
326
|
+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
|
|
327
|
+
int * x_qs = (int *) x_tile;
|
|
328
|
+
float * x_df = (float *) (x_qs + txs.qs);
|
|
329
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
330
|
+
|
|
331
|
+
constexpr int blocks_per_iter = MMQ_ITER_K / QK1_0;
|
|
332
|
+
constexpr int threads_per_row = blocks_per_iter * QI1_0;
|
|
333
|
+
constexpr int nrows = warp_size / threads_per_row;
|
|
334
|
+
constexpr int scale_entries_per_block = QK1_0 / QK8_1;
|
|
335
|
+
constexpr int scale_entries_per_row = blocks_per_iter * scale_entries_per_block;
|
|
336
|
+
|
|
337
|
+
const int txi = threadIdx.x % threads_per_row;
|
|
338
|
+
const int kbx = txi / QI1_0;
|
|
339
|
+
const int kqsx = txi % QI1_0;
|
|
340
|
+
|
|
341
|
+
#pragma unroll
|
|
342
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
|
|
343
|
+
int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
|
|
344
|
+
|
|
345
|
+
if (need_check) {
|
|
346
|
+
i = min(i, i_max);
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + kbx;
|
|
350
|
+
const int qs_offset = 4*kqsx;
|
|
351
|
+
const int qs0 = bxi->qs[qs_offset + 0] | (bxi->qs[qs_offset + 1] << 8) |
|
|
352
|
+
(bxi->qs[qs_offset + 2] << 16) | (bxi->qs[qs_offset + 3] << 24);
|
|
353
|
+
|
|
354
|
+
int unpacked_bytes[8];
|
|
355
|
+
#pragma unroll
|
|
356
|
+
for (int j = 0; j < 8; ++j) {
|
|
357
|
+
const int shift = j * 4;
|
|
358
|
+
const int bits4 = (qs0 >> shift) & 0x0F;
|
|
359
|
+
const int b0 = (bits4 & 0x01) ? 1 : -1;
|
|
360
|
+
const int b1 = (bits4 & 0x02) ? 1 : -1;
|
|
361
|
+
const int b2 = (bits4 & 0x04) ? 1 : -1;
|
|
362
|
+
const int b3 = (bits4 & 0x08) ? 1 : -1;
|
|
363
|
+
unpacked_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
|
|
364
|
+
}
|
|
365
|
+
|
|
366
|
+
const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0;
|
|
367
|
+
#pragma unroll
|
|
368
|
+
for (int j = 0; j < 8; ++j) {
|
|
369
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
370
|
+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j];
|
|
371
|
+
#else
|
|
372
|
+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + dst_offset + j] = unpacked_bytes[j];
|
|
373
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
374
|
+
}
|
|
375
|
+
}
|
|
376
|
+
|
|
377
|
+
const int ksx = threadIdx.x % scale_entries_per_row;
|
|
378
|
+
const int scale_block = ksx / scale_entries_per_block;
|
|
379
|
+
|
|
380
|
+
#pragma unroll
|
|
381
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
|
382
|
+
int i = i0 + threadIdx.y;
|
|
383
|
+
|
|
384
|
+
if (need_check) {
|
|
385
|
+
i = min(i, i_max);
|
|
386
|
+
}
|
|
387
|
+
|
|
388
|
+
const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + scale_block;
|
|
389
|
+
|
|
390
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
391
|
+
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d;
|
|
392
|
+
#else
|
|
393
|
+
x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + ksx] = bxi->d;
|
|
394
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
395
|
+
}
|
|
396
|
+
}
|
|
397
|
+
|
|
298
398
|
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
|
|
299
399
|
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
|
300
400
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
@@ -379,17 +479,25 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
|
|
|
379
479
|
#pragma unroll
|
|
380
480
|
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
|
381
481
|
const int i = i0 + threadIdx.x;
|
|
382
|
-
|
|
383
482
|
const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
|
|
384
483
|
|
|
385
484
|
int u[2*VDR_Q4_0_Q8_1_MMQ];
|
|
386
485
|
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
486
|
+
constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes();
|
|
487
|
+
constexpr int mcpy_int = max_cpy / sizeof(int);
|
|
488
|
+
static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ");
|
|
489
|
+
|
|
490
|
+
int tmp0[4], tmp1[4];
|
|
491
|
+
|
|
492
|
+
#pragma unroll
|
|
493
|
+
for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) {
|
|
494
|
+
ggml_cuda_memcpy_1<max_cpy>(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] );
|
|
495
|
+
ggml_cuda_memcpy_1<max_cpy>(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0 + l0 * mcpy_int]);
|
|
391
496
|
}
|
|
392
497
|
|
|
498
|
+
u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3];
|
|
499
|
+
u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3];
|
|
500
|
+
|
|
393
501
|
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
|
|
394
502
|
(&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
|
|
395
503
|
x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
|
@@ -482,17 +590,25 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
|
|
|
482
590
|
#pragma unroll
|
|
483
591
|
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
|
484
592
|
const int i = i0 + threadIdx.x;
|
|
485
|
-
|
|
486
593
|
const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
|
|
487
594
|
|
|
488
595
|
int u[2*VDR_Q4_1_Q8_1_MMQ];
|
|
489
596
|
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
597
|
+
constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes();
|
|
598
|
+
constexpr int mcpy_int = max_cpy / sizeof(int);
|
|
599
|
+
static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ");
|
|
600
|
+
|
|
601
|
+
int tmp0[4], tmp1[4];
|
|
602
|
+
|
|
603
|
+
#pragma unroll
|
|
604
|
+
for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) {
|
|
605
|
+
ggml_cuda_memcpy_1<max_cpy>(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] );
|
|
606
|
+
ggml_cuda_memcpy_1<max_cpy>(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_1 + l0 * mcpy_int]);
|
|
494
607
|
}
|
|
495
608
|
|
|
609
|
+
u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3];
|
|
610
|
+
u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3];
|
|
611
|
+
|
|
496
612
|
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
|
|
497
613
|
(&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,
|
|
498
614
|
x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
|
@@ -826,6 +942,187 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
|
|
|
826
942
|
}
|
|
827
943
|
}
|
|
828
944
|
|
|
945
|
+
#ifdef BLACKWELL_MMA_AVAILABLE
|
|
946
|
+
template <int mmq_y, bool need_check>
|
|
947
|
+
static __device__ __forceinline__ void load_tiles_nvfp4_nvfp4(const char * __restrict__ x,
|
|
948
|
+
int * __restrict__ x_tile,
|
|
949
|
+
const int kbx0,
|
|
950
|
+
const int i_max,
|
|
951
|
+
const int stride) {
|
|
952
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
|
953
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
954
|
+
constexpr int iter_k = get_iter_k(GGML_TYPE_NVFP4);
|
|
955
|
+
constexpr int threads_per_row = iter_k / QK_NVFP4; // each thread processes 1 block
|
|
956
|
+
constexpr int rows_per_warp = warp_size / threads_per_row;
|
|
957
|
+
|
|
958
|
+
uint32_t * x_u32 = (uint32_t *) x_tile;
|
|
959
|
+
|
|
960
|
+
const int txi = threadIdx.x;
|
|
961
|
+
const int kbx = txi % threads_per_row;
|
|
962
|
+
const int row_in_warp = txi / threads_per_row;
|
|
963
|
+
|
|
964
|
+
const block_nvfp4 * bxi_base = (const block_nvfp4 *) x + kbx0 + kbx;
|
|
965
|
+
uint32_t * x_u32_scale = x_u32 + 64 + kbx;
|
|
966
|
+
|
|
967
|
+
#pragma unroll
|
|
968
|
+
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
|
|
969
|
+
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
|
|
970
|
+
|
|
971
|
+
if constexpr (need_check) {
|
|
972
|
+
i = min(i, i_max);
|
|
973
|
+
}
|
|
974
|
+
|
|
975
|
+
const block_nvfp4 * bxi = bxi_base + i * stride;
|
|
976
|
+
const int row_base = i * MMQ_MMA_TILE_X_K_FP4;
|
|
977
|
+
const int q_base = row_base + 8 * kbx;
|
|
978
|
+
|
|
979
|
+
const uint32_t * src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
|
|
980
|
+
|
|
981
|
+
#pragma unroll
|
|
982
|
+
for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
|
|
983
|
+
x_u32[q_base + 2 * sub + 0] = src_qs[2 * sub + 0];
|
|
984
|
+
x_u32[q_base + 2 * sub + 1] = src_qs[2 * sub + 1];
|
|
985
|
+
}
|
|
986
|
+
|
|
987
|
+
x_u32_scale[row_base] = get_int_b4(bxi->d, 0);
|
|
988
|
+
}
|
|
989
|
+
}
|
|
990
|
+
|
|
991
|
+
// Shared MMA kernel for MXFP4 and NVFP4 on Blackwell.
|
|
992
|
+
// Both quantizations encode values as e2m1 (FP4) and produce one uint32 scale per
|
|
993
|
+
// m16n8k64 MMA call; only the PTX kind (scale_vec::2X ue8m0 vs scale_vec::4X ue4m3)
|
|
994
|
+
// and the per-type stride constant differ.
|
|
995
|
+
template <int mmq_x, int mmq_y, ggml_type type>
|
|
996
|
+
static __device__ __forceinline__ void vec_dot_fp4_fp4_mma(const int * __restrict__ x,
|
|
997
|
+
const int * __restrict__ y,
|
|
998
|
+
float * __restrict__ sum,
|
|
999
|
+
const int k00) {
|
|
1000
|
+
static_assert(type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4,
|
|
1001
|
+
"vec_dot_fp4_fp4_mma: type must be MXFP4 or NVFP4");
|
|
1002
|
+
|
|
1003
|
+
typedef tile<16, 8, int> tile_A;
|
|
1004
|
+
typedef tile<8, 8, int> tile_B;
|
|
1005
|
+
typedef tile<16, 8, float> tile_C;
|
|
1006
|
+
|
|
1007
|
+
constexpr int stride = MMQ_MMA_TILE_X_K_FP4;
|
|
1008
|
+
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
1009
|
+
constexpr int rows_per_warp = 2 * granularity;
|
|
1010
|
+
constexpr int ntx = rows_per_warp / tile_C::I;
|
|
1011
|
+
constexpr int nfrags = MMQ_TILE_NE_K / tile_A::J;
|
|
1012
|
+
|
|
1013
|
+
y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_K);
|
|
1014
|
+
|
|
1015
|
+
const int * x_qs = (const int *) x;
|
|
1016
|
+
const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
|
|
1017
|
+
const int * y_qs = (const int *) y + 4;
|
|
1018
|
+
const uint32_t * y_sc = (const uint32_t *) y;
|
|
1019
|
+
|
|
1020
|
+
// 2 threads per quad supply the packed scale register to the block_scale MMA,
|
|
1021
|
+
// see https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
|
|
1022
|
+
const int tidx_A = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
|
|
1023
|
+
const int tidx_B = threadIdx.x / 4;
|
|
1024
|
+
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
1025
|
+
|
|
1026
|
+
tile_A A[ntx][nfrags];
|
|
1027
|
+
uint32_t scaleA[ntx][nfrags];
|
|
1028
|
+
|
|
1029
|
+
#pragma unroll
|
|
1030
|
+
for (int n = 0; n < ntx; ++n) {
|
|
1031
|
+
#pragma unroll
|
|
1032
|
+
for (int frag = 0; frag < nfrags; ++frag) {
|
|
1033
|
+
const int k0 = k00 + frag * tile_A::J;
|
|
1034
|
+
load_ldmatrix(A[n][frag], x_qs + (i0 + n * tile_A::I) * stride + k0, stride);
|
|
1035
|
+
scaleA[n][frag] = x_sc[(i0 + n * tile_A::I + tidx_A) * stride + k0 / tile_A::J];
|
|
1036
|
+
}
|
|
1037
|
+
}
|
|
1038
|
+
|
|
1039
|
+
#pragma unroll
|
|
1040
|
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
|
|
1041
|
+
tile_B B[nfrags];
|
|
1042
|
+
uint32_t scaleB[nfrags];
|
|
1043
|
+
|
|
1044
|
+
#pragma unroll
|
|
1045
|
+
for (int frag = 0; frag < nfrags; ++frag) {
|
|
1046
|
+
const int k0 = frag * tile_B::J;
|
|
1047
|
+
load_generic(B[frag], y_qs + j0 * MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K);
|
|
1048
|
+
scaleB[frag] = y_sc[(j0 + tidx_B) * MMQ_TILE_Y_K + frag];
|
|
1049
|
+
}
|
|
1050
|
+
|
|
1051
|
+
#pragma unroll
|
|
1052
|
+
for (int n = 0; n < ntx; ++n) {
|
|
1053
|
+
#pragma unroll
|
|
1054
|
+
for (int frag = 0; frag < nfrags; ++frag) {
|
|
1055
|
+
tile_C C = {};
|
|
1056
|
+
mma_block_scaled_fp4<type>(C, A[n][frag], B[frag], scaleA[n][frag], scaleB[frag]);
|
|
1057
|
+
#pragma unroll
|
|
1058
|
+
for (int l = 0; l < tile_C::ne; ++l) {
|
|
1059
|
+
sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
|
|
1060
|
+
}
|
|
1061
|
+
}
|
|
1062
|
+
}
|
|
1063
|
+
}
|
|
1064
|
+
}
|
|
1065
|
+
#endif // BLACKWELL_MMA_AVAILABLE
|
|
1066
|
+
|
|
1067
|
+
|
|
1068
|
+
template <int mmq_y, bool need_check>
|
|
1069
|
+
static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x,
|
|
1070
|
+
int * __restrict__ x_tile,
|
|
1071
|
+
const int kb0,
|
|
1072
|
+
const int i_max,
|
|
1073
|
+
const int stride) {
|
|
1074
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
|
1075
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
1076
|
+
|
|
1077
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1078
|
+
int * x_qs = (int *) x_tile;
|
|
1079
|
+
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
|
1080
|
+
#else
|
|
1081
|
+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_NVFP4, mmq_y);
|
|
1082
|
+
int * x_qs = (int *) x_tile;
|
|
1083
|
+
float * x_df = (float *) (x_qs + txs.qs);
|
|
1084
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1085
|
+
|
|
1086
|
+
constexpr int threads_per_row = MMQ_ITER_K / QK_NVFP4;
|
|
1087
|
+
constexpr int rows_per_warp = warp_size / threads_per_row;
|
|
1088
|
+
const int kbx = threadIdx.x % threads_per_row;
|
|
1089
|
+
const int row_in_warp = threadIdx.x / threads_per_row;
|
|
1090
|
+
|
|
1091
|
+
#pragma unroll
|
|
1092
|
+
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
|
|
1093
|
+
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
|
|
1094
|
+
|
|
1095
|
+
if constexpr (need_check) {
|
|
1096
|
+
i = min(i, i_max);
|
|
1097
|
+
}
|
|
1098
|
+
|
|
1099
|
+
const block_nvfp4 * bxi = (const block_nvfp4 *) x + kb0 + i * stride + kbx;
|
|
1100
|
+
const uint32_t * __restrict__ src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
|
|
1101
|
+
const int kqs = 16 * kbx;
|
|
1102
|
+
const int ksc = 4 * kbx;
|
|
1103
|
+
|
|
1104
|
+
#pragma unroll
|
|
1105
|
+
for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
|
|
1106
|
+
const int2 q0 = get_int_from_table_16(src_qs[2 * sub + 0], kvalues_mxfp4);
|
|
1107
|
+
const int2 q1 = get_int_from_table_16(src_qs[2 * sub + 1], kvalues_mxfp4);
|
|
1108
|
+
|
|
1109
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1110
|
+
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 0] = q0.x;
|
|
1111
|
+
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 1] = q1.x;
|
|
1112
|
+
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 2] = q0.y;
|
|
1113
|
+
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 3] = q1.y;
|
|
1114
|
+
x_df[i * MMQ_MMA_TILE_X_K_NVFP4 + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
|
|
1115
|
+
#else
|
|
1116
|
+
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 0] = q0.x;
|
|
1117
|
+
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 1] = q1.x;
|
|
1118
|
+
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 2] = q0.y;
|
|
1119
|
+
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 3] = q1.y;
|
|
1120
|
+
x_df[i * (2 * MMQ_TILE_NE_K * 2 / QI_NVFP4) + i / (QK_NVFP4_SUB / QI_NVFP4) + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
|
|
1121
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1122
|
+
}
|
|
1123
|
+
}
|
|
1124
|
+
}
|
|
1125
|
+
|
|
829
1126
|
template <int mmq_x, int mmq_y>
|
|
830
1127
|
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
|
|
831
1128
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
@@ -887,13 +1184,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
|
|
887
1184
|
tile_A A[ntx];
|
|
888
1185
|
#pragma unroll
|
|
889
1186
|
for (int n = 0; n < ntx; ++n) {
|
|
890
|
-
|
|
1187
|
+
load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
|
|
891
1188
|
}
|
|
892
1189
|
|
|
893
1190
|
#pragma unroll
|
|
894
1191
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
895
1192
|
tile_B B;
|
|
896
|
-
|
|
1193
|
+
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
897
1194
|
|
|
898
1195
|
float dB;
|
|
899
1196
|
const int j = j0 + tile_C::get_j(0);
|
|
@@ -996,77 +1293,6 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
|
|
996
1293
|
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
997
1294
|
}
|
|
998
1295
|
|
|
999
|
-
template <int mmq_x, int mmq_y>
|
|
1000
|
-
static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
|
|
1001
|
-
const int * __restrict__ y,
|
|
1002
|
-
float * __restrict__ sum,
|
|
1003
|
-
const int k00) {
|
|
1004
|
-
typedef tile<16, 8, int> tile_A;
|
|
1005
|
-
typedef tile<8, 8, int> tile_B;
|
|
1006
|
-
typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA
|
|
1007
|
-
|
|
1008
|
-
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
1009
|
-
constexpr int rows_per_warp = 2 * granularity;
|
|
1010
|
-
constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
|
|
1011
|
-
|
|
1012
|
-
y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
|
|
1013
|
-
|
|
1014
|
-
// Match layout from load_tiles_mxfp4_fp4
|
|
1015
|
-
const int * x_qs = (const int *) x;
|
|
1016
|
-
const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
|
|
1017
|
-
const int * y_qs = (const int *) y + 4;
|
|
1018
|
-
const uint32_t * y_sc = (const uint32_t *) y;
|
|
1019
|
-
|
|
1020
|
-
// tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
|
|
1021
|
-
tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
|
|
1022
|
-
uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
|
|
1023
|
-
|
|
1024
|
-
// Block scale
|
|
1025
|
-
// Each thread has to point to a 4 byte scale value
|
|
1026
|
-
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
|
|
1027
|
-
|
|
1028
|
-
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
1029
|
-
|
|
1030
|
-
#pragma unroll
|
|
1031
|
-
for (int n = 0; n < ntx; ++n) {
|
|
1032
|
-
#pragma unroll
|
|
1033
|
-
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
|
|
1034
|
-
const int k0 = k00 + k01;
|
|
1035
|
-
|
|
1036
|
-
load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
|
|
1037
|
-
MMQ_MMA_TILE_X_K_FP4);
|
|
1038
|
-
|
|
1039
|
-
// based on block-scaling document, 2 threads in each quad need to supply to the scale value
|
|
1040
|
-
const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
|
|
1041
|
-
scaleA[n][k01 / (2 * QI_MXFP4)] =
|
|
1042
|
-
*(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
|
|
1043
|
-
}
|
|
1044
|
-
}
|
|
1045
|
-
|
|
1046
|
-
#pragma unroll
|
|
1047
|
-
for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
|
|
1048
|
-
#pragma unroll
|
|
1049
|
-
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
|
|
1050
|
-
tile_B B;
|
|
1051
|
-
uint32_t scaleB; // 2xN scales
|
|
1052
|
-
|
|
1053
|
-
load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
|
|
1054
|
-
|
|
1055
|
-
scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
|
|
1056
|
-
|
|
1057
|
-
#pragma unroll
|
|
1058
|
-
for (int n = 0; n < ntx; ++n) {
|
|
1059
|
-
tile_C C;
|
|
1060
|
-
|
|
1061
|
-
mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
|
|
1062
|
-
#pragma unroll
|
|
1063
|
-
for (int l = 0; l < tile_C::ne; ++l) {
|
|
1064
|
-
sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
|
|
1065
|
-
}
|
|
1066
|
-
}
|
|
1067
|
-
}
|
|
1068
|
-
}
|
|
1069
|
-
}
|
|
1070
1296
|
|
|
1071
1297
|
template <int mmq_x, int mmq_y>
|
|
1072
1298
|
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
|
|
@@ -1128,13 +1354,13 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
|
|
1128
1354
|
tile_A A[ntx];
|
|
1129
1355
|
#pragma unroll
|
|
1130
1356
|
for (int n = 0; n < ntx; ++n) {
|
|
1131
|
-
|
|
1357
|
+
load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
|
|
1132
1358
|
}
|
|
1133
1359
|
|
|
1134
1360
|
#pragma unroll
|
|
1135
1361
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
1136
1362
|
tile_B B;
|
|
1137
|
-
|
|
1363
|
+
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
1138
1364
|
|
|
1139
1365
|
const int j = j0 + tile_C::get_j(0);
|
|
1140
1366
|
const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
|
@@ -1229,7 +1455,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
|
|
1229
1455
|
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1230
1456
|
}
|
|
1231
1457
|
|
|
1232
|
-
// Used for Q3_K, IQ2_S, and IQ2_XS
|
|
1458
|
+
// Used for NVFP4, Q3_K, IQ2_S, and IQ2_XS
|
|
1233
1459
|
template <int mmq_x, int mmq_y>
|
|
1234
1460
|
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
|
|
1235
1461
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
@@ -1268,57 +1494,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
|
|
|
1268
1494
|
template <int mmq_x, int mmq_y>
|
|
1269
1495
|
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|
1270
1496
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
1271
|
-
#if defined(AMD_MFMA_AVAILABLE)
|
|
1272
|
-
constexpr data_layout input_layout = get_input_data_layout();
|
|
1273
|
-
typedef tile<16, 8, int, input_layout> tile_A;
|
|
1274
|
-
typedef tile<16, 8, int, input_layout> tile_B;
|
|
1275
|
-
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
|
1276
|
-
typedef tile<64, 2, int, input_layout> tile_load;
|
|
1277
|
-
|
|
1278
|
-
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
1279
|
-
constexpr int rows_per_warp = granularity;
|
|
1280
|
-
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
1281
|
-
|
|
1282
|
-
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
|
1283
|
-
|
|
1284
|
-
const int * x_qs = (const int *) x;
|
|
1285
|
-
const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
|
|
1286
|
-
const int * y_qs = (const int *) y + 4;
|
|
1287
|
-
const float * y_df = (const float *) y;
|
|
1288
|
-
|
|
1289
|
-
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
1290
|
-
|
|
1291
|
-
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
|
|
1292
|
-
const int k0 = k00 + k01;
|
|
1293
|
-
|
|
1294
|
-
tile_A A[ntx];
|
|
1295
|
-
#pragma unroll
|
|
1296
|
-
for (int n = 0; n < ntx; ++n) {
|
|
1297
|
-
load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
|
|
1298
|
-
}
|
|
1299
|
-
|
|
1300
|
-
#pragma unroll
|
|
1301
|
-
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
1302
|
-
tile_B B[1];
|
|
1303
|
-
load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
1304
|
-
|
|
1305
|
-
const int j = j0 + tile_C::get_j(0);
|
|
1306
|
-
const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
|
|
1307
|
-
|
|
1308
|
-
#pragma unroll
|
|
1309
|
-
for (int n = 0; n < ntx; ++n) {
|
|
1310
|
-
tile_C C;
|
|
1311
|
-
mma(C, A[n], B[0]);
|
|
1312
|
-
|
|
1313
|
-
#pragma unroll
|
|
1314
|
-
for (int l = 0; l < tile_C::ne; ++l) {
|
|
1315
|
-
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
|
1316
|
-
sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
|
|
1317
|
-
}
|
|
1318
|
-
}
|
|
1319
|
-
}
|
|
1320
|
-
}
|
|
1321
|
-
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
|
1497
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1322
1498
|
constexpr data_layout input_layout = get_input_data_layout();
|
|
1323
1499
|
typedef tile<16, 4, int, input_layout> tile_A;
|
|
1324
1500
|
typedef tile<16, 4, int, input_layout> tile_B;
|
|
@@ -1343,13 +1519,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|
|
1343
1519
|
tile_A A[ntx];
|
|
1344
1520
|
#pragma unroll
|
|
1345
1521
|
for (int n = 0; n < ntx; ++n) {
|
|
1346
|
-
|
|
1522
|
+
load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
|
|
1347
1523
|
}
|
|
1348
1524
|
|
|
1349
1525
|
#pragma unroll
|
|
1350
1526
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
1351
1527
|
tile_B B;
|
|
1352
|
-
|
|
1528
|
+
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
1353
1529
|
|
|
1354
1530
|
const int j = j0 + tile_C::get_j(0);
|
|
1355
1531
|
const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
|
|
@@ -1575,74 +1751,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
|
|
|
1575
1751
|
template <int mmq_x, int mmq_y>
|
|
1576
1752
|
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
|
1577
1753
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
1578
|
-
#if defined(AMD_MFMA_AVAILABLE)
|
|
1579
|
-
constexpr data_layout input_layout = get_input_data_layout();
|
|
1580
|
-
typedef tile<16, 8, int, input_layout> tile_A;
|
|
1581
|
-
typedef tile<16, 8, int, input_layout> tile_B;
|
|
1582
|
-
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
|
1583
|
-
typedef tile<64, 2, int, input_layout> tile_load;
|
|
1584
|
-
|
|
1585
|
-
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
1586
|
-
constexpr int rows_per_warp = granularity;
|
|
1587
|
-
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
1588
|
-
|
|
1589
|
-
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
|
1590
|
-
|
|
1591
|
-
const int * x_qs = (const int *) x;
|
|
1592
|
-
const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
|
|
1593
|
-
const int * y_qs = (const int *) y + 4;
|
|
1594
|
-
const half2 * y_ds = (const half2 *) y;
|
|
1595
|
-
|
|
1596
|
-
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
1597
|
-
|
|
1598
|
-
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
|
|
1599
|
-
const int k0 = k00 + k01;
|
|
1600
|
-
|
|
1601
|
-
tile_A A[ntx];
|
|
1602
|
-
#pragma unroll
|
|
1603
|
-
for (int n = 0; n < ntx; ++n) {
|
|
1604
|
-
load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
|
|
1605
|
-
}
|
|
1606
|
-
|
|
1607
|
-
#pragma unroll
|
|
1608
|
-
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
1609
|
-
tile_B B[1];
|
|
1610
|
-
load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
1611
|
-
|
|
1612
|
-
const int j = j0 + tile_C::get_j(0);
|
|
1613
|
-
const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2;
|
|
1614
|
-
const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
|
|
1615
|
-
: (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
|
|
1616
|
-
: __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
|
|
1617
|
-
|
|
1618
|
-
tile_C Cm;
|
|
1619
|
-
if (k01 >= MMQ_TILE_NE_K * 3/4) {
|
|
1620
|
-
tile_A A1;
|
|
1621
|
-
A1.x[0] = 0x01010101;
|
|
1622
|
-
A1.x[1] = 0x01010101;
|
|
1623
|
-
mma(Cm, A1, B[0]);
|
|
1624
|
-
}
|
|
1625
|
-
|
|
1626
|
-
#pragma unroll
|
|
1627
|
-
for (int n = 0; n < ntx; ++n) {
|
|
1628
|
-
tile_C Cd;
|
|
1629
|
-
mma(Cd, A[n], B[0]);
|
|
1630
|
-
|
|
1631
|
-
#pragma unroll
|
|
1632
|
-
for (int l = 0; l < tile_C::ne; ++l) {
|
|
1633
|
-
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
|
1634
|
-
const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
|
|
1635
|
-
float tmp = Cd.x[l]*dm.x;
|
|
1636
|
-
if (k01 >= MMQ_TILE_NE_K * 3/4) {
|
|
1637
|
-
tmp -= Cm.x[l]*dm.y;
|
|
1638
|
-
}
|
|
1639
|
-
sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
|
|
1640
|
-
sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
|
|
1641
|
-
}
|
|
1642
|
-
}
|
|
1643
|
-
}
|
|
1644
|
-
}
|
|
1645
|
-
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
|
1754
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1646
1755
|
constexpr data_layout input_layout = get_input_data_layout();
|
|
1647
1756
|
typedef tile<16, 4, int, input_layout> tile_A;
|
|
1648
1757
|
typedef tile<16, 4, int, input_layout> tile_B;
|
|
@@ -1667,13 +1776,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
|
|
1667
1776
|
tile_A A[ntx];
|
|
1668
1777
|
#pragma unroll
|
|
1669
1778
|
for (int n = 0; n < ntx; ++n) {
|
|
1670
|
-
|
|
1779
|
+
load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
|
|
1671
1780
|
}
|
|
1672
1781
|
|
|
1673
1782
|
#pragma unroll
|
|
1674
1783
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
1675
1784
|
tile_B B;
|
|
1676
|
-
|
|
1785
|
+
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
1677
1786
|
|
|
1678
1787
|
const int j = j0 + tile_C::get_j(0);
|
|
1679
1788
|
const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y;
|
|
@@ -2406,59 +2515,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
|
|
|
2406
2515
|
template <int mmq_x, int mmq_y>
|
|
2407
2516
|
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
|
2408
2517
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
2409
|
-
#if defined(AMD_MFMA_AVAILABLE)
|
|
2410
|
-
constexpr data_layout input_layout = get_input_data_layout();
|
|
2411
|
-
typedef tile<16, 8, int, input_layout> tile_A;
|
|
2412
|
-
typedef tile<16, 8, int, input_layout> tile_B;
|
|
2413
|
-
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
|
2414
|
-
typedef tile<64, 2, int, input_layout> tile_load;
|
|
2415
|
-
|
|
2416
|
-
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
2417
|
-
constexpr int rows_per_warp = granularity;
|
|
2418
|
-
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
2419
|
-
|
|
2420
|
-
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
|
2421
|
-
|
|
2422
|
-
const int * x_qs = (const int *) x;
|
|
2423
|
-
const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
|
|
2424
|
-
const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
|
|
2425
|
-
const int * y_qs = (const int *) y + 4;
|
|
2426
|
-
const float * y_df = (const float *) y;
|
|
2427
|
-
|
|
2428
|
-
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
2429
|
-
|
|
2430
|
-
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
|
|
2431
|
-
const int k0 = k00 + k01;
|
|
2432
|
-
|
|
2433
|
-
tile_A A[ntx];
|
|
2434
|
-
#pragma unroll
|
|
2435
|
-
for (int n = 0; n < ntx; ++n) {
|
|
2436
|
-
load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
|
|
2437
|
-
}
|
|
2438
|
-
|
|
2439
|
-
#pragma unroll
|
|
2440
|
-
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
2441
|
-
tile_B B[1];
|
|
2442
|
-
load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
2443
|
-
|
|
2444
|
-
const int j = j0 + tile_C::get_j(0);
|
|
2445
|
-
const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
|
|
2446
|
-
|
|
2447
|
-
#pragma unroll
|
|
2448
|
-
for (int n = 0; n < ntx; ++n) {
|
|
2449
|
-
tile_C C;
|
|
2450
|
-
mma(C, A[n], B[0]);
|
|
2451
|
-
|
|
2452
|
-
#pragma unroll
|
|
2453
|
-
for (int l = 0; l < tile_C::ne; ++l) {
|
|
2454
|
-
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
|
2455
|
-
const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
|
|
2456
|
-
sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
|
|
2457
|
-
}
|
|
2458
|
-
}
|
|
2459
|
-
}
|
|
2460
|
-
}
|
|
2461
|
-
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
|
2518
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2462
2519
|
constexpr data_layout input_layout = get_input_data_layout();
|
|
2463
2520
|
typedef tile<16, 4, int, input_layout> tile_A;
|
|
2464
2521
|
typedef tile<16, 4, int, input_layout> tile_B;
|
|
@@ -2484,13 +2541,13 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
|
|
2484
2541
|
tile_A A[ntx];
|
|
2485
2542
|
#pragma unroll
|
|
2486
2543
|
for (int n = 0; n < ntx; ++n) {
|
|
2487
|
-
|
|
2544
|
+
load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
|
|
2488
2545
|
}
|
|
2489
2546
|
|
|
2490
2547
|
#pragma unroll
|
|
2491
2548
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
2492
2549
|
tile_B B;
|
|
2493
|
-
|
|
2550
|
+
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
2494
2551
|
|
|
2495
2552
|
const int j = j0 + tile_C::get_j(0);
|
|
2496
2553
|
const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
|
|
@@ -3208,6 +3265,14 @@ static __device__ __forceinline__ void mmq_write_back_mma(
|
|
|
3208
3265
|
template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
|
|
3209
3266
|
struct mmq_type_traits;
|
|
3210
3267
|
|
|
3268
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
|
3269
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q1_0> {
|
|
3270
|
+
static constexpr int vdr = VDR_Q1_0_Q8_1_MMQ;
|
|
3271
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q1_0<mmq_y, need_check>;
|
|
3272
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
|
3273
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
|
3274
|
+
};
|
|
3275
|
+
|
|
3211
3276
|
template <int mmq_x, int mmq_y, bool need_check>
|
|
3212
3277
|
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
|
|
3213
3278
|
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
|
|
@@ -3253,7 +3318,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
|
|
|
3253
3318
|
static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
|
|
3254
3319
|
#ifdef BLACKWELL_MMA_AVAILABLE
|
|
3255
3320
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
|
|
3256
|
-
static constexpr vec_dot_mmq_t vec_dot_mma =
|
|
3321
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_MXFP4>;
|
|
3257
3322
|
#else
|
|
3258
3323
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
|
|
3259
3324
|
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
|
@@ -3261,6 +3326,19 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
|
|
|
3261
3326
|
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
|
3262
3327
|
};
|
|
3263
3328
|
|
|
3329
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
|
3330
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_NVFP4> {
|
|
3331
|
+
static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ;
|
|
3332
|
+
#ifdef BLACKWELL_MMA_AVAILABLE
|
|
3333
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4_nvfp4<mmq_y, need_check>;
|
|
3334
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_NVFP4>;
|
|
3335
|
+
#else
|
|
3336
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4<mmq_y, need_check>;
|
|
3337
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
|
|
3338
|
+
#endif // BLACKWELL_MMA_AVAILABLE
|
|
3339
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
|
|
3340
|
+
};
|
|
3341
|
+
|
|
3264
3342
|
template <int mmq_x, int mmq_y, bool need_check>
|
|
3265
3343
|
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
|
|
3266
3344
|
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
|
|
@@ -3392,7 +3470,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
|
|
|
3392
3470
|
|
|
3393
3471
|
#if defined(BLACKWELL_MMA_AVAILABLE)
|
|
3394
3472
|
// FP4 tile stores 8 blocks
|
|
3395
|
-
constexpr int ne_block = (type == GGML_TYPE_MXFP4) ?
|
|
3473
|
+
constexpr int ne_block = (type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4) ? QK_K : 4 * QK8_1;
|
|
3396
3474
|
#else
|
|
3397
3475
|
constexpr int ne_block = 4 * QK8_1;
|
|
3398
3476
|
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
|
@@ -3464,10 +3542,10 @@ template <ggml_type type, int mmq_x, bool need_check>
|
|
|
3464
3542
|
static __global__ void mul_mat_q(
|
|
3465
3543
|
const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
|
|
3466
3544
|
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
|
|
3467
|
-
const
|
|
3468
|
-
const
|
|
3469
|
-
const
|
|
3470
|
-
const
|
|
3545
|
+
const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
|
|
3546
|
+
const uint3 channel_ratio, const uint3 nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
|
3547
|
+
const uint3 sample_ratio, const uint3 nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
|
3548
|
+
const uint3 ntx) {
|
|
3471
3549
|
|
|
3472
3550
|
// Skip unused template specializations for faster compilation:
|
|
3473
3551
|
if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
|
|
@@ -3481,8 +3559,7 @@ static __global__ void mul_mat_q(
|
|
|
3481
3559
|
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
|
3482
3560
|
constexpr int mmq_y = get_mmq_y_device();
|
|
3483
3561
|
|
|
3484
|
-
const
|
|
3485
|
-
const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
|
|
3562
|
+
const uint32_t nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
|
|
3486
3563
|
|
|
3487
3564
|
// Initialize the ids for writing back data with just the index.
|
|
3488
3565
|
// For regular matrix multiplications this is never changed.
|
|
@@ -3503,8 +3580,9 @@ static __global__ void mul_mat_q(
|
|
|
3503
3580
|
// On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
|
|
3504
3581
|
#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
|
3505
3582
|
{
|
|
3506
|
-
const
|
|
3507
|
-
const int
|
|
3583
|
+
const uint2 tmp2 = fast_div_modulo(blockIdx.z, nchannels_y);
|
|
3584
|
+
const int wt = tmp2.x;
|
|
3585
|
+
const int zt = tmp2.y;
|
|
3508
3586
|
const int jt = blockIdx.y;
|
|
3509
3587
|
const int it = blockIdx.x;
|
|
3510
3588
|
|
|
@@ -3547,40 +3625,40 @@ static __global__ void mul_mat_q(
|
|
|
3547
3625
|
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
|
|
3548
3626
|
const int tile_y_max_j = col_diff - jt*mmq_x - 1;
|
|
3549
3627
|
|
|
3550
|
-
const int offset_x = (wt
|
|
3628
|
+
const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
|
3551
3629
|
|
|
3552
3630
|
constexpr bool fixup = false;
|
|
3553
3631
|
mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
|
|
3554
3632
|
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
|
|
3555
|
-
tile_x_max_i, tile_y_max_j, 0,
|
|
3633
|
+
tile_x_max_i, tile_y_max_j, 0, blocks_per_ne00.z);
|
|
3556
3634
|
return;
|
|
3557
3635
|
}
|
|
3558
|
-
#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
|
3559
|
-
|
|
3560
|
-
constexpr int ITER_K = get_iter_k(type);
|
|
3636
|
+
#endif // (defined(GGML_USE_HIP) && !defined(CDNA4) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
|
3561
3637
|
|
|
3562
|
-
|
|
3563
|
-
constexpr int
|
|
3638
|
+
constexpr int ITER_K = get_iter_k(type);
|
|
3639
|
+
constexpr int blocks_per_iter = ITER_K / qk;
|
|
3564
3640
|
|
|
3565
3641
|
// kbc == k block continuous, current index in continuous ijk space.
|
|
3566
|
-
|
|
3567
|
-
|
|
3642
|
+
int kbc = int64_t(blockIdx.x) *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
|
|
3643
|
+
int kbc_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
|
|
3568
3644
|
|
|
3569
|
-
kbc -= (kbc
|
|
3570
|
-
kbc_stop -= (kbc_stop
|
|
3645
|
+
kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter;
|
|
3646
|
+
kbc_stop -= fastmodulo(kbc_stop, blocks_per_ne00) % blocks_per_iter;
|
|
3571
3647
|
|
|
3572
3648
|
// kb0 == k index when doing the matrix multiplication for an output tile.
|
|
3573
|
-
int kb0_start = kbc
|
|
3574
|
-
int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
|
|
3575
|
-
while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
|
|
3576
|
-
int tmp = kbc;
|
|
3577
|
-
|
|
3578
|
-
|
|
3579
|
-
|
|
3580
|
-
|
|
3581
|
-
const int zt =
|
|
3582
|
-
tmp
|
|
3583
|
-
|
|
3649
|
+
int kb0_start = fastmodulo(kbc, blocks_per_ne00);
|
|
3650
|
+
int kb0_stop = min(blocks_per_ne00.z, uint32_t(kb0_start + kbc_stop - kbc));
|
|
3651
|
+
while (kbc < kbc_stop && kb0_stop == int(blocks_per_ne00.z)) {
|
|
3652
|
+
int tmp = fastdiv(kbc, blocks_per_ne00);
|
|
3653
|
+
uint2 tmp2 = fast_div_modulo(tmp, ntx);
|
|
3654
|
+
const int jt = tmp2.y;
|
|
3655
|
+
tmp = tmp2.x;
|
|
3656
|
+
tmp2 = fast_div_modulo(tmp, nchannels_y);
|
|
3657
|
+
const int zt = tmp2.y;
|
|
3658
|
+
tmp = tmp2.x;
|
|
3659
|
+
tmp2 = fast_div_modulo(tmp, nsamples_y);
|
|
3660
|
+
const int wt = tmp2.y;
|
|
3661
|
+
const int it = tmp2.x;
|
|
3584
3662
|
|
|
3585
3663
|
// Defaults for regular matrix multiplication:
|
|
3586
3664
|
int col_low = 0;
|
|
@@ -3598,11 +3676,11 @@ static __global__ void mul_mat_q(
|
|
|
3598
3676
|
offset_dst = 0;
|
|
3599
3677
|
|
|
3600
3678
|
if (jt*mmq_x >= col_diff) {
|
|
3601
|
-
kbc += blocks_per_ne00;
|
|
3602
|
-
kbc -= kbc
|
|
3679
|
+
kbc += blocks_per_ne00.z;
|
|
3680
|
+
kbc -= fastmodulo(kbc, blocks_per_ne00);
|
|
3603
3681
|
|
|
3604
3682
|
kb0_start = 0;
|
|
3605
|
-
kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
|
|
3683
|
+
kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc));
|
|
3606
3684
|
|
|
3607
3685
|
continue;
|
|
3608
3686
|
}
|
|
@@ -3627,32 +3705,34 @@ static __global__ void mul_mat_q(
|
|
|
3627
3705
|
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
|
|
3628
3706
|
const int tile_y_max_j = col_diff - jt*mmq_x - 1;
|
|
3629
3707
|
|
|
3630
|
-
const int offset_x = (wt
|
|
3708
|
+
const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
|
3631
3709
|
|
|
3632
3710
|
constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
|
3633
3711
|
mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
|
|
3634
3712
|
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
|
|
3635
3713
|
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
|
|
3636
3714
|
|
|
3637
|
-
kbc += blocks_per_ne00;
|
|
3638
|
-
kbc -= kbc
|
|
3715
|
+
kbc += blocks_per_ne00.z;
|
|
3716
|
+
kbc -= fastmodulo(kbc, blocks_per_ne00);
|
|
3639
3717
|
|
|
3640
3718
|
kb0_start = 0;
|
|
3641
|
-
kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
|
|
3719
|
+
kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc));
|
|
3642
3720
|
}
|
|
3643
3721
|
|
|
3644
3722
|
if (kbc >= kbc_stop) {
|
|
3645
3723
|
return;
|
|
3646
3724
|
}
|
|
3647
3725
|
|
|
3648
|
-
int tmp = kbc;
|
|
3649
|
-
|
|
3650
|
-
|
|
3651
|
-
|
|
3652
|
-
|
|
3653
|
-
const int zt =
|
|
3654
|
-
tmp
|
|
3655
|
-
|
|
3726
|
+
int tmp = fastdiv(kbc, blocks_per_ne00);
|
|
3727
|
+
uint2 tmp2 = fast_div_modulo(tmp, ntx);
|
|
3728
|
+
const int jt = tmp2.y;
|
|
3729
|
+
tmp = tmp2.x;
|
|
3730
|
+
tmp2 = fast_div_modulo(tmp, nchannels_y);
|
|
3731
|
+
const int zt = tmp2.y;
|
|
3732
|
+
tmp = tmp2.x;
|
|
3733
|
+
tmp2 = fast_div_modulo(tmp, nsamples_y);
|
|
3734
|
+
const int wt = tmp2.y;
|
|
3735
|
+
const int it = tmp2.x;
|
|
3656
3736
|
|
|
3657
3737
|
// Defaults for regular matrix multiplication:
|
|
3658
3738
|
int col_low = 0;
|
|
@@ -3694,7 +3774,7 @@ static __global__ void mul_mat_q(
|
|
|
3694
3774
|
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
|
|
3695
3775
|
const int tile_y_max_j = col_diff - jt*mmq_x - 1;
|
|
3696
3776
|
|
|
3697
|
-
const int offset_x = (wt
|
|
3777
|
+
const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
|
3698
3778
|
|
|
3699
3779
|
constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
|
3700
3780
|
mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
|
|
@@ -3703,46 +3783,37 @@ static __global__ void mul_mat_q(
|
|
|
3703
3783
|
}
|
|
3704
3784
|
|
|
3705
3785
|
template <ggml_type type, int mmq_x, bool need_check>
|
|
3706
|
-
|
|
3707
|
-
|
|
3708
|
-
|
|
3709
|
-
|
|
3710
|
-
|
|
3711
|
-
|
|
3712
|
-
|
|
3713
|
-
|
|
3714
|
-
|
|
3715
|
-
|
|
3716
|
-
const int nsamples_y,
|
|
3717
|
-
const size_t stride_sample_dst,
|
|
3718
|
-
const int ncols_max) {
|
|
3719
|
-
constexpr int mmq_y = get_mmq_y_device();
|
|
3720
|
-
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
|
3721
|
-
constexpr int ITER_K = get_iter_k(type);
|
|
3722
|
-
|
|
3723
|
-
constexpr int blocks_per_iter = ITER_K / qk;
|
|
3724
|
-
const int64_t blocks_per_ne00 = ncols_x / qk;
|
|
3786
|
+
__launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device()/2, 1)
|
|
3787
|
+
static __global__ void mul_mat_q_stream_k_fixup(
|
|
3788
|
+
const int32_t * __restrict__ ids_dst, const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
|
|
3789
|
+
float * __restrict__ tmp_last_tile, const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst,
|
|
3790
|
+
const int stride_col_dst, const uint3 nchannels_y, const int stride_channel_dst, const uint3 nsamples_y,
|
|
3791
|
+
const int stride_sample_dst, const uint3 ntx) {
|
|
3792
|
+
constexpr int mmq_y = get_mmq_y_device();
|
|
3793
|
+
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
|
3794
|
+
constexpr int ITER_K = get_iter_k(type);
|
|
3795
|
+
constexpr int blocks_per_iter = ITER_K / qk;
|
|
3725
3796
|
|
|
3726
|
-
constexpr int nwarps = mmq_get_nwarps_device();
|
|
3797
|
+
constexpr int nwarps = mmq_get_nwarps_device()/2;
|
|
3727
3798
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
3728
3799
|
|
|
3729
|
-
float sum[mmq_x
|
|
3800
|
+
float sum[mmq_x / nwarps] = {0.0f};
|
|
3801
|
+
const int i = blockIdx.y*warp_size + threadIdx.x;
|
|
3730
3802
|
|
|
3731
|
-
const int
|
|
3732
|
-
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
|
|
3803
|
+
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
|
|
3733
3804
|
|
|
3734
3805
|
const int bidx0 = blockIdx.x;
|
|
3735
3806
|
|
|
3736
3807
|
// kbc == k block continuous, current index in continuous ijk space.
|
|
3737
|
-
|
|
3738
|
-
|
|
3808
|
+
int kbc0 = int64_t(blockIdx.x) *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
|
|
3809
|
+
int kbc0_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
|
|
3739
3810
|
|
|
3740
|
-
kbc0 -= (kbc0
|
|
3741
|
-
kbc0_stop -= (kbc0_stop
|
|
3811
|
+
kbc0 -= fastmodulo(kbc0, blocks_per_ne00) % blocks_per_iter;
|
|
3812
|
+
kbc0_stop -= fastmodulo(kbc0_stop, blocks_per_ne00) % blocks_per_iter;
|
|
3742
3813
|
|
|
3743
3814
|
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
|
3744
|
-
const bool wrote_beginning_of_tile = kbc0
|
|
3745
|
-
const bool did_not_write_last = kbc0
|
|
3815
|
+
const bool wrote_beginning_of_tile = fastmodulo(kbc0, blocks_per_ne00) == 0;
|
|
3816
|
+
const bool did_not_write_last = fastdiv(kbc0, blocks_per_ne00) == fastdiv(kbc0_stop, blocks_per_ne00) && fastmodulo(kbc0_stop, blocks_per_ne00) != 0;
|
|
3746
3817
|
if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
|
|
3747
3818
|
return;
|
|
3748
3819
|
}
|
|
@@ -3751,11 +3822,11 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
|
|
|
3751
3822
|
|
|
3752
3823
|
// Iterate over previous blocks and sum up partial sums written to fixup buffer.
|
|
3753
3824
|
// All CUDA blocks that get here must have a previous block that needs a fixup.
|
|
3754
|
-
|
|
3755
|
-
|
|
3825
|
+
int bidx = bidx0 - 1;
|
|
3826
|
+
int kbc_stop = kbc0;
|
|
3756
3827
|
while(true) {
|
|
3757
|
-
|
|
3758
|
-
kbc -= (kbc
|
|
3828
|
+
int kbc = int64_t(bidx)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
|
|
3829
|
+
kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter;
|
|
3759
3830
|
|
|
3760
3831
|
if (kbc == kbc_stop) { // Did not have any data.
|
|
3761
3832
|
bidx--;
|
|
@@ -3765,20 +3836,16 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
|
|
|
3765
3836
|
|
|
3766
3837
|
any_fixup = true;
|
|
3767
3838
|
|
|
3839
|
+
|
|
3768
3840
|
#pragma unroll
|
|
3769
3841
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
3770
3842
|
const int j = j0 + threadIdx.y;
|
|
3771
3843
|
|
|
3772
|
-
|
|
3773
|
-
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
|
3774
|
-
const int i = i0 + threadIdx.x;
|
|
3775
|
-
|
|
3776
|
-
sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
|
|
3777
|
-
}
|
|
3844
|
+
sum[j0/nwarps] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
|
|
3778
3845
|
}
|
|
3779
3846
|
|
|
3780
3847
|
// If this block started in a previous tile we are done and don't need to combine additional partial results.
|
|
3781
|
-
if (kbc
|
|
3848
|
+
if (fastmodulo(kbc, blocks_per_ne00) == 0 || fastdiv(kbc, blocks_per_ne00) < fastdiv(kbc0, blocks_per_ne00)) {
|
|
3782
3849
|
break;
|
|
3783
3850
|
}
|
|
3784
3851
|
bidx--;
|
|
@@ -3789,14 +3856,16 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
|
|
|
3789
3856
|
return;
|
|
3790
3857
|
}
|
|
3791
3858
|
|
|
3792
|
-
int tmp = kbc0;
|
|
3793
|
-
|
|
3794
|
-
|
|
3795
|
-
|
|
3796
|
-
|
|
3797
|
-
const int zt =
|
|
3798
|
-
tmp
|
|
3799
|
-
|
|
3859
|
+
int tmp = fastdiv(kbc0, blocks_per_ne00);
|
|
3860
|
+
uint2 tmp2 = fast_div_modulo(tmp, ntx);
|
|
3861
|
+
const int jt = tmp2.y;
|
|
3862
|
+
tmp = tmp2.x;
|
|
3863
|
+
tmp2 = fast_div_modulo(tmp, nchannels_y);
|
|
3864
|
+
const int zt = tmp2.y;
|
|
3865
|
+
tmp = tmp2.x;
|
|
3866
|
+
tmp2 = fast_div_modulo(tmp, nsamples_y);
|
|
3867
|
+
const int wt = tmp2.y;
|
|
3868
|
+
const int it = tmp2.x;
|
|
3800
3869
|
|
|
3801
3870
|
if (!ids_dst) {
|
|
3802
3871
|
const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
|
|
@@ -3804,6 +3873,9 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
|
|
|
3804
3873
|
|
|
3805
3874
|
const int i_max = nrows_x - it*mmq_y - 1;
|
|
3806
3875
|
const int j_max = ncols_dst - jt*mmq_x - 1;
|
|
3876
|
+
if (need_check && i > i_max) {
|
|
3877
|
+
return;
|
|
3878
|
+
}
|
|
3807
3879
|
|
|
3808
3880
|
#pragma unroll
|
|
3809
3881
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
@@ -3813,16 +3885,7 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
|
|
|
3813
3885
|
return;
|
|
3814
3886
|
}
|
|
3815
3887
|
|
|
3816
|
-
|
|
3817
|
-
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
|
3818
|
-
const int i = i0 + threadIdx.x;
|
|
3819
|
-
|
|
3820
|
-
if (need_check && i > i_max) {
|
|
3821
|
-
continue;
|
|
3822
|
-
}
|
|
3823
|
-
|
|
3824
|
-
dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
|
|
3825
|
-
}
|
|
3888
|
+
dst[j*stride_col_dst + i] += sum[j0/nwarps];
|
|
3826
3889
|
}
|
|
3827
3890
|
return;
|
|
3828
3891
|
}
|
|
@@ -3842,6 +3905,9 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
|
|
|
3842
3905
|
|
|
3843
3906
|
const int i_max = nrows_x - it*mmq_y - 1;
|
|
3844
3907
|
const int j_max = col_diff - jt*mmq_x - 1;
|
|
3908
|
+
if (need_check && i > i_max) {
|
|
3909
|
+
return;
|
|
3910
|
+
}
|
|
3845
3911
|
|
|
3846
3912
|
#pragma unroll
|
|
3847
3913
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
@@ -3851,16 +3917,7 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
|
|
|
3851
3917
|
return;
|
|
3852
3918
|
}
|
|
3853
3919
|
|
|
3854
|
-
|
|
3855
|
-
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
|
3856
|
-
const int i = i0 + threadIdx.x;
|
|
3857
|
-
|
|
3858
|
-
if (need_check && i > i_max) {
|
|
3859
|
-
continue;
|
|
3860
|
-
}
|
|
3861
|
-
|
|
3862
|
-
dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
|
|
3863
|
-
}
|
|
3920
|
+
dst[ids_dst_shared[j]*stride_col_dst + i] += sum[j0/nwarps];
|
|
3864
3921
|
}
|
|
3865
3922
|
}
|
|
3866
3923
|
|
|
@@ -3908,29 +3965,44 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
|
|
3908
3965
|
const int channel_ratio = args.nchannels_y / args.nchannels_x;
|
|
3909
3966
|
const int sample_ratio = args.nsamples_y / args.nsamples_x;
|
|
3910
3967
|
|
|
3968
|
+
const uint3 blocks_per_ne00_fd = init_fastdiv_values(args.ncols_x / ggml_cuda_type_traits<type>::qk);
|
|
3969
|
+
const uint3 ntx_fd = init_fastdiv_values(ntx);
|
|
3970
|
+
const uint3 nchannels_y_fd = init_fastdiv_values(args.nchannels_y);
|
|
3971
|
+
const uint3 nsamples_y_fd = init_fastdiv_values(args.nsamples_y);
|
|
3972
|
+
const uint3 channel_ratio_fd = init_fastdiv_values(channel_ratio);
|
|
3973
|
+
const uint3 sample_ratio_fd = init_fastdiv_values(sample_ratio);
|
|
3974
|
+
|
|
3911
3975
|
if (!args.use_stream_k) {
|
|
3912
3976
|
if (args.nrows_x % mmq_y == 0) {
|
|
3913
3977
|
constexpr bool need_check = false;
|
|
3914
3978
|
mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
|
|
3915
3979
|
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
|
|
3916
|
-
|
|
3917
|
-
|
|
3918
|
-
|
|
3919
|
-
|
|
3980
|
+
blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
|
3981
|
+
channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
|
3982
|
+
sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
|
3983
|
+
ntx_fd);
|
|
3920
3984
|
} else {
|
|
3921
3985
|
constexpr bool need_check = true;
|
|
3922
3986
|
mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
|
|
3923
3987
|
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
|
|
3924
|
-
|
|
3925
|
-
|
|
3926
|
-
|
|
3927
|
-
|
|
3988
|
+
blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
|
3989
|
+
channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
|
3990
|
+
sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
|
3991
|
+
ntx_fd);
|
|
3928
3992
|
}
|
|
3929
3993
|
return;
|
|
3930
3994
|
}
|
|
3931
3995
|
|
|
3932
|
-
|
|
3933
|
-
|
|
3996
|
+
// For the stream-k kernel it is possible to run it with tiling by setting the number of CUDA blocks equal to the number of tiles.
|
|
3997
|
+
// This is worthwhile if the efficiency of tiling is high and skipping the fixup kernel is more important.
|
|
3998
|
+
const int ntiles_dst = ntx * nty * ntzw;
|
|
3999
|
+
const int tiles_nwaves = (ntiles_dst + nsm - 1) / nsm;
|
|
4000
|
+
const int tiles_efficiency_percent = 100 * ntiles_dst / (nsm*tiles_nwaves);
|
|
4001
|
+
const dim3 block_nums_stream_k(GGML_CUDA_CC_IS_NVIDIA(cc) && tiles_efficiency_percent >= 90 ? ntiles_dst : nsm, 1, 1);
|
|
4002
|
+
|
|
4003
|
+
GGML_ASSERT(ntiles_dst * blocks_per_ne00_fd.z < (1 << 30)); // Assert that variable kbc will not overflow.
|
|
4004
|
+
|
|
4005
|
+
const bool fixup_needed = ntiles_dst % block_nums_stream_k.x != 0;
|
|
3934
4006
|
|
|
3935
4007
|
ggml_cuda_pool & pool = ctx.pool(id);
|
|
3936
4008
|
ggml_cuda_pool_alloc<float> tmp_fixup(pool);
|
|
@@ -3938,40 +4010,45 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
|
|
3938
4010
|
tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y);
|
|
3939
4011
|
}
|
|
3940
4012
|
|
|
4013
|
+
const dim3 block_nums_fixup(block_nums_stream_k.x, mmq_y/warp_size, 1);
|
|
4014
|
+
const dim3 block_dims_fixup(block_dims.x, block_dims.y/2, block_dims.z);
|
|
4015
|
+
|
|
3941
4016
|
if (args.nrows_x % mmq_y == 0) {
|
|
3942
4017
|
constexpr bool need_check = false;
|
|
3943
4018
|
mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
|
|
3944
4019
|
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
|
|
3945
|
-
|
|
3946
|
-
|
|
3947
|
-
|
|
3948
|
-
|
|
4020
|
+
blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
|
4021
|
+
channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
|
4022
|
+
sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
|
4023
|
+
ntx_fd);
|
|
3949
4024
|
|
|
3950
4025
|
if (!fixup_needed) {
|
|
3951
4026
|
return;
|
|
3952
4027
|
}
|
|
3953
4028
|
|
|
3954
|
-
|
|
3955
|
-
|
|
3956
|
-
|
|
3957
|
-
args.
|
|
4029
|
+
CUDA_CHECK(cudaGetLastError());
|
|
4030
|
+
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_fixup, block_dims_fixup, 0, stream>>>
|
|
4031
|
+
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst,
|
|
4032
|
+
args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst,
|
|
4033
|
+
ntx_fd);
|
|
3958
4034
|
} else {
|
|
3959
4035
|
constexpr bool need_check = true;
|
|
3960
4036
|
mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
|
|
3961
4037
|
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
|
|
3962
|
-
|
|
3963
|
-
|
|
3964
|
-
|
|
3965
|
-
|
|
4038
|
+
blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
|
4039
|
+
channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
|
4040
|
+
sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
|
4041
|
+
ntx_fd);
|
|
3966
4042
|
|
|
3967
4043
|
if (!fixup_needed) {
|
|
3968
4044
|
return;
|
|
3969
4045
|
}
|
|
3970
4046
|
|
|
3971
|
-
|
|
3972
|
-
|
|
3973
|
-
|
|
3974
|
-
args.
|
|
4047
|
+
CUDA_CHECK(cudaGetLastError());
|
|
4048
|
+
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_fixup, block_dims_fixup, 0, stream>>>
|
|
4049
|
+
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst,
|
|
4050
|
+
args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst,
|
|
4051
|
+
ntx_fd);
|
|
3975
4052
|
}
|
|
3976
4053
|
}
|
|
3977
4054
|
|
|
@@ -4069,6 +4146,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
|
|
|
4069
4146
|
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
|
|
4070
4147
|
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
|
|
4071
4148
|
extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
|
|
4149
|
+
extern DECL_MMQ_CASE(GGML_TYPE_NVFP4);
|
|
4072
4150
|
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
|
|
4073
4151
|
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
|
|
4074
4152
|
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
|
|
@@ -4095,3 +4173,4 @@ void ggml_cuda_op_mul_mat_q(
|
|
|
4095
4173
|
const int64_t src1_padded_row_size, cudaStream_t stream);
|
|
4096
4174
|
|
|
4097
4175
|
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);
|
|
4176
|
+
|