whispercpp 1.3.3 → 1.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/ruby_whisper_params.c +55 -25
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/build-xcframework.sh +24 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/addon.cpp +19 -19
- data/ext/sources/examples/addon.node/index.js +7 -5
- data/ext/sources/examples/bench/bench.cpp +26 -16
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +4 -2
- data/ext/sources/examples/command/command.cpp +26 -24
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/lsp.cpp +19 -17
- data/ext/sources/examples/server/server.cpp +24 -13
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +4 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
- data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
- data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
- data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
- data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
- data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
- data/ext/sources/examples/talk-llama/llama-context.h +44 -29
- data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
- data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
- data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
- data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
- data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
- data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
- data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
- data/ext/sources/examples/talk-llama/llama-model.h +60 -9
- data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
- data/ext/sources/examples/talk-llama/llama.cpp +65 -10
- data/ext/sources/examples/talk-llama/llama.h +95 -177
- data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
- data/ext/sources/ggml/CMakeLists.txt +59 -31
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-backend.h +17 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -1
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml.h +221 -16
- data/ext/sources/ggml/src/CMakeLists.txt +17 -2
- data/ext/sources/ggml/src/ggml-alloc.c +265 -141
- data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
- data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
- data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
- data/ext/sources/ggml/src/ggml-common.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
- data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
- data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
- data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
- data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
- data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
- data/ext/sources/ggml/src/ggml-impl.h +119 -9
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +111 -16
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml.c +478 -98
- data/ext/sources/ggml/src/gguf.cpp +8 -1
- data/ext/sources/src/whisper.cpp +23 -46
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +3 -3
- data/ext/sources/tests/test-vad.cpp +2 -2
- data/lib/whisper/model/uri.rb +1 -1
- data/sig/whisper.rbs +7 -0
- data/test/test_params.rb +8 -0
- data/test/test_whisper.rb +1 -1
- data/whispercpp.gemspec +1 -1
- metadata +164 -157
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
@@ -1,9 +1,10 @@
|
|
1
1
|
#include "ggml.h"
|
2
2
|
#include "common.cuh"
|
3
|
-
#include "
|
3
|
+
#include "convert.cuh"
|
4
|
+
#include "mmvf.cuh"
|
4
5
|
|
5
6
|
template <typename T, typename type_acc, int ncols_dst, int block_size>
|
6
|
-
static __global__ void
|
7
|
+
static __global__ void mul_mat_vec_f(
|
7
8
|
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
8
9
|
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
|
9
10
|
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
@@ -37,7 +38,7 @@ static __global__ void mul_mat_vec(
|
|
37
38
|
|
38
39
|
float sumf[ncols_dst] = {0.0f};
|
39
40
|
|
40
|
-
if constexpr (std::
|
41
|
+
if constexpr (std::is_same_v<T, float>) {
|
41
42
|
const float2 * x2 = (const float2 *) x;
|
42
43
|
|
43
44
|
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
@@ -50,10 +51,10 @@ static __global__ void mul_mat_vec(
|
|
50
51
|
sumf[j] += tmpx.y*tmpy.y;
|
51
52
|
}
|
52
53
|
}
|
53
|
-
} else if constexpr (std::
|
54
|
+
} else if constexpr (std::is_same_v<T, half>) {
|
54
55
|
const half2 * x2 = (const half2 *) x;
|
55
56
|
|
56
|
-
if (std::
|
57
|
+
if (std::is_same_v<type_acc, float>) {
|
57
58
|
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
58
59
|
const float2 tmpx = __half22float2(x2[col2]);
|
59
60
|
|
@@ -86,19 +87,19 @@ static __global__ void mul_mat_vec(
|
|
86
87
|
NO_DEVICE_CODE;
|
87
88
|
#endif // FP16_AVAILABLE
|
88
89
|
}
|
89
|
-
} else if constexpr (std::
|
90
|
+
} else if constexpr (std::is_same_v<T, nv_bfloat16>) {
|
90
91
|
const int * x2 = (const int *) x;
|
91
92
|
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
92
93
|
const int tmpx = x2[col2];
|
93
94
|
#pragma unroll
|
94
95
|
for (int j = 0; j < ncols_dst; ++j) {
|
95
96
|
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
96
|
-
sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
|
97
|
-
sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
|
97
|
+
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
|
98
|
+
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
|
98
99
|
}
|
99
100
|
}
|
100
101
|
} else {
|
101
|
-
static_assert(std::
|
102
|
+
static_assert(std::is_same_v<T, void>, "unsupported type");
|
102
103
|
}
|
103
104
|
|
104
105
|
#pragma unroll
|
@@ -126,7 +127,7 @@ static __global__ void mul_mat_vec(
|
|
126
127
|
}
|
127
128
|
|
128
129
|
template <typename T, typename type_acc, int ncols_dst>
|
129
|
-
static void
|
130
|
+
static void launch_mul_mat_vec_f_cuda(
|
130
131
|
const T * x, const float * y, const int32_t * ids, float * dst,
|
131
132
|
const int64_t ncols, const int64_t nrows,
|
132
133
|
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
@@ -141,11 +142,9 @@ static void launch_mul_mat_vec_cuda(
|
|
141
142
|
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
|
142
143
|
const int64_t channel_ratio = nchannels_dst / nchannels_x;
|
143
144
|
const int64_t sample_ratio = nsamples_dst / nsamples_x;
|
144
|
-
int device;
|
145
|
-
int warp_size;
|
146
145
|
|
147
|
-
|
148
|
-
warp_size = ggml_cuda_info().devices[device].warp_size;
|
146
|
+
const int device = ggml_cuda_get_device();
|
147
|
+
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
149
148
|
|
150
149
|
int64_t block_size_best = warp_size;
|
151
150
|
int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size);
|
@@ -161,54 +160,54 @@ static void launch_mul_mat_vec_cuda(
|
|
161
160
|
}
|
162
161
|
}
|
163
162
|
|
164
|
-
const int
|
163
|
+
const int nbytes_shared = warp_size*sizeof(float);
|
165
164
|
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
|
166
165
|
const dim3 block_dims(block_size_best, 1, 1);
|
167
166
|
switch (block_size_best) {
|
168
167
|
case 32: {
|
169
|
-
|
168
|
+
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
|
170
169
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
171
170
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
172
171
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
173
172
|
} break;
|
174
173
|
case 64: {
|
175
|
-
|
174
|
+
mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
|
176
175
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
177
176
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
178
177
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
179
178
|
} break;
|
180
179
|
case 96: {
|
181
|
-
|
180
|
+
mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
|
182
181
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
183
182
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
184
183
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
185
184
|
} break;
|
186
185
|
case 128: {
|
187
|
-
|
186
|
+
mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
|
188
187
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
189
188
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
190
189
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
191
190
|
} break;
|
192
191
|
case 160: {
|
193
|
-
|
192
|
+
mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
|
194
193
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
195
194
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
196
195
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
197
196
|
} break;
|
198
197
|
case 192: {
|
199
|
-
|
198
|
+
mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
|
200
199
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
201
200
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
202
201
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
203
202
|
} break;
|
204
203
|
case 224: {
|
205
|
-
|
204
|
+
mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
|
206
205
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
207
206
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
208
207
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
209
208
|
} break;
|
210
209
|
case 256: {
|
211
|
-
|
210
|
+
mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
|
212
211
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
213
212
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
214
213
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
@@ -220,7 +219,7 @@ static void launch_mul_mat_vec_cuda(
|
|
220
219
|
}
|
221
220
|
|
222
221
|
template <typename T, typename type_acc>
|
223
|
-
static void
|
222
|
+
static void mul_mat_vec_f_cuda_switch_ncols_dst(
|
224
223
|
const T * x, const float * y, const int32_t * ids, float * dst,
|
225
224
|
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
|
226
225
|
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
@@ -230,49 +229,49 @@ static void mul_mat_vec_cuda_switch_ncols_dst(
|
|
230
229
|
cudaStream_t stream) {
|
231
230
|
switch (ncols_dst) {
|
232
231
|
case 1:
|
233
|
-
|
232
|
+
launch_mul_mat_vec_f_cuda<T, type_acc, 1>
|
234
233
|
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
235
234
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
236
235
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
237
236
|
break;
|
238
237
|
case 2:
|
239
|
-
|
238
|
+
launch_mul_mat_vec_f_cuda<T, type_acc, 2>
|
240
239
|
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
241
240
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
242
241
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
243
242
|
break;
|
244
243
|
case 3:
|
245
|
-
|
244
|
+
launch_mul_mat_vec_f_cuda<T, type_acc, 3>
|
246
245
|
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
247
246
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
248
247
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
249
248
|
break;
|
250
249
|
case 4:
|
251
|
-
|
250
|
+
launch_mul_mat_vec_f_cuda<T, type_acc, 4>
|
252
251
|
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
253
252
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
254
253
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
255
254
|
break;
|
256
255
|
case 5:
|
257
|
-
|
256
|
+
launch_mul_mat_vec_f_cuda<T, type_acc, 5>
|
258
257
|
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
259
258
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
260
259
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
261
260
|
break;
|
262
261
|
case 6:
|
263
|
-
|
262
|
+
launch_mul_mat_vec_f_cuda<T, type_acc, 6>
|
264
263
|
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
265
264
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
266
265
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
267
266
|
break;
|
268
267
|
case 7:
|
269
|
-
|
268
|
+
launch_mul_mat_vec_f_cuda<T, type_acc, 7>
|
270
269
|
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
271
270
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
272
271
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
273
272
|
break;
|
274
273
|
case 8:
|
275
|
-
|
274
|
+
launch_mul_mat_vec_f_cuda<T, type_acc, 8>
|
276
275
|
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
277
276
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
278
277
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
@@ -284,7 +283,7 @@ static void mul_mat_vec_cuda_switch_ncols_dst(
|
|
284
283
|
}
|
285
284
|
|
286
285
|
template<typename T>
|
287
|
-
static void
|
286
|
+
static void mul_mat_vec_f_cuda(
|
288
287
|
const T * x, const float * y, const int32_t * ids, float * dst,
|
289
288
|
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
|
290
289
|
const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
|
@@ -292,22 +291,22 @@ static void mul_mat_vec_cuda(
|
|
292
291
|
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
293
292
|
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
294
293
|
enum ggml_prec prec, cudaStream_t stream) {
|
295
|
-
if constexpr(std::
|
294
|
+
if constexpr(std::is_same_v<T, half>) {
|
296
295
|
if (prec == GGML_PREC_DEFAULT) {
|
297
|
-
|
296
|
+
mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
|
298
297
|
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
299
298
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
300
299
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
301
300
|
return;
|
302
301
|
}
|
303
302
|
}
|
304
|
-
|
303
|
+
mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
|
305
304
|
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
306
305
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
307
306
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
308
307
|
}
|
309
308
|
|
310
|
-
void
|
309
|
+
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
311
310
|
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
312
311
|
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
|
313
312
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
@@ -355,19 +354,19 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
|
|
355
354
|
switch (src0->type) {
|
356
355
|
case GGML_TYPE_F32: {
|
357
356
|
const float * src0_d = (const float *) src0->data;
|
358
|
-
|
357
|
+
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
359
358
|
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
360
359
|
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
361
360
|
} break;
|
362
361
|
case GGML_TYPE_F16: {
|
363
362
|
const half * src0_d = (const half *) src0->data;
|
364
|
-
|
363
|
+
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
365
364
|
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
366
365
|
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
367
366
|
} break;
|
368
367
|
case GGML_TYPE_BF16: {
|
369
368
|
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
|
370
|
-
|
369
|
+
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
371
370
|
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
372
371
|
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
373
372
|
} break;
|
@@ -376,7 +375,7 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
|
|
376
375
|
}
|
377
376
|
}
|
378
377
|
|
379
|
-
void
|
378
|
+
void ggml_cuda_op_mul_mat_vec_f(
|
380
379
|
ggml_backend_cuda_context & ctx,
|
381
380
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
382
381
|
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
@@ -414,19 +413,19 @@ void ggml_cuda_op_mul_mat_vec(
|
|
414
413
|
switch (src0->type) {
|
415
414
|
case GGML_TYPE_F32: {
|
416
415
|
const float * src0_d = (const float *) src0_dd_i;
|
417
|
-
|
416
|
+
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
418
417
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
419
418
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
420
419
|
} break;
|
421
420
|
case GGML_TYPE_F16: {
|
422
421
|
const half * src0_d = (const half *) src0_dd_i;
|
423
|
-
|
422
|
+
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
424
423
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
425
424
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
426
425
|
} break;
|
427
426
|
case GGML_TYPE_BF16: {
|
428
427
|
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
|
429
|
-
|
428
|
+
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
430
429
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
431
430
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
432
431
|
} break;
|
@@ -434,23 +433,18 @@ void ggml_cuda_op_mul_mat_vec(
|
|
434
433
|
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
|
435
434
|
}
|
436
435
|
|
437
|
-
|
438
|
-
GGML_UNUSED(src1);
|
439
|
-
GGML_UNUSED(dst);
|
440
|
-
GGML_UNUSED(src1_ddq_i);
|
441
|
-
GGML_UNUSED(src1_ncols);
|
442
|
-
GGML_UNUSED(src1_padded_row_size);
|
436
|
+
GGML_UNUSED_VARS(ctx, src1, dst, src1_ddq_i, src1_ncols, src1_padded_row_size);
|
443
437
|
}
|
444
438
|
|
445
|
-
bool
|
439
|
+
bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
|
446
440
|
if (src0_ne[0] % 2 != 0) {
|
447
441
|
return false;
|
448
442
|
}
|
449
443
|
switch (type) {
|
450
444
|
case GGML_TYPE_F32:
|
451
445
|
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
|
452
|
-
if (cc
|
453
|
-
return ne11 <=
|
446
|
+
if (ampere_mma_available(cc)) {
|
447
|
+
return ne11 <= 3;
|
454
448
|
}
|
455
449
|
if (cc >= GGML_CUDA_CC_TURING) {
|
456
450
|
return ne11 <= 4;
|
@@ -466,6 +460,9 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
|
|
466
460
|
case GGML_TYPE_F16:
|
467
461
|
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
|
468
462
|
const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
|
463
|
+
if (ampere_mma_available(cc)) {
|
464
|
+
return src0_small && ne11 == 1;
|
465
|
+
}
|
469
466
|
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
|
470
467
|
return src0_small && ne11 <= 4;
|
471
468
|
}
|
@@ -486,6 +483,9 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
|
|
486
483
|
case GGML_TYPE_BF16:
|
487
484
|
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
|
488
485
|
const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
|
486
|
+
if (ampere_mma_available(cc)) {
|
487
|
+
return src0_small && ne11 == 1;
|
488
|
+
}
|
489
489
|
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
|
490
490
|
return src0_small && ne11 <= 4;
|
491
491
|
}
|
@@ -1,11 +1,11 @@
|
|
1
1
|
#include "common.cuh"
|
2
2
|
|
3
|
-
void
|
3
|
+
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
4
4
|
|
5
|
-
void
|
5
|
+
void ggml_cuda_op_mul_mat_vec_f(
|
6
6
|
ggml_backend_cuda_context & ctx,
|
7
7
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
8
8
|
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
9
9
|
const int64_t src1_padded_row_size, cudaStream_t stream);
|
10
10
|
|
11
|
-
bool
|
11
|
+
bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
|
@@ -13,6 +13,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
|
|
13
13
|
case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
|
14
14
|
case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
|
15
15
|
case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
|
16
|
+
case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1;
|
16
17
|
case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
|
17
18
|
case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
|
18
19
|
case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
|
@@ -38,6 +39,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
|
|
38
39
|
case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
|
39
40
|
case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
|
40
41
|
case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
|
42
|
+
case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ;
|
41
43
|
case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
|
42
44
|
case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
|
43
45
|
case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
|
@@ -139,9 +141,10 @@ template <ggml_type type, int ncols_dst>
|
|
139
141
|
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
|
140
142
|
static __global__ void mul_mat_vec_q(
|
141
143
|
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
142
|
-
const
|
143
|
-
const
|
144
|
-
const
|
144
|
+
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
|
145
|
+
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
|
146
|
+
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
|
147
|
+
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
|
145
148
|
|
146
149
|
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
147
150
|
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
@@ -159,12 +162,12 @@ static __global__ void mul_mat_vec_q(
|
|
159
162
|
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
|
160
163
|
|
161
164
|
// The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
|
162
|
-
const
|
163
|
-
const
|
164
|
-
const
|
165
|
-
const
|
166
|
-
const
|
167
|
-
const
|
165
|
+
const uint32_t channel_dst = blockIdx.y;
|
166
|
+
const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
|
167
|
+
const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
|
168
|
+
const uint32_t sample_dst = blockIdx.z;
|
169
|
+
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
|
170
|
+
const uint32_t sample_y = sample_dst;
|
168
171
|
|
169
172
|
// partial sum for each thread
|
170
173
|
float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
|
@@ -217,7 +220,7 @@ static __global__ void mul_mat_vec_q(
|
|
217
220
|
tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
|
218
221
|
}
|
219
222
|
|
220
|
-
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 +
|
223
|
+
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
|
221
224
|
dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x];
|
222
225
|
}
|
223
226
|
}
|
@@ -245,8 +248,9 @@ static void mul_mat_vec_q_switch_ncols_dst(
|
|
245
248
|
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
|
246
249
|
GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
|
247
250
|
|
248
|
-
const
|
249
|
-
const
|
251
|
+
const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
|
252
|
+
const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
|
253
|
+
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
|
250
254
|
|
251
255
|
const int device = ggml_cuda_get_device();
|
252
256
|
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
@@ -254,86 +258,70 @@ static void mul_mat_vec_q_switch_ncols_dst(
|
|
254
258
|
|
255
259
|
GGML_ASSERT(!ids || ncols_dst == 1);
|
256
260
|
switch (ncols_dst) {
|
257
|
-
case 1:
|
258
|
-
{
|
261
|
+
case 1: {
|
259
262
|
constexpr int c_ncols_dst = 1;
|
260
263
|
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
261
264
|
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
262
|
-
(vx, vy, ids, dst, ncols_x,
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
case 2:
|
268
|
-
{
|
265
|
+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
266
|
+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
267
|
+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
268
|
+
} break;
|
269
|
+
case 2: {
|
269
270
|
constexpr int c_ncols_dst = 2;
|
270
271
|
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
271
272
|
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
272
|
-
(vx, vy, ids, dst, ncols_x,
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
case 3:
|
278
|
-
{
|
273
|
+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
274
|
+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
275
|
+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
276
|
+
} break;
|
277
|
+
case 3: {
|
279
278
|
constexpr int c_ncols_dst = 3;
|
280
279
|
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
281
280
|
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
282
|
-
(vx, vy, ids, dst, ncols_x,
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
case 4:
|
288
|
-
{
|
281
|
+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
282
|
+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
283
|
+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
284
|
+
} break;
|
285
|
+
case 4: {
|
289
286
|
constexpr int c_ncols_dst = 4;
|
290
287
|
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
291
288
|
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
292
|
-
(vx, vy, ids, dst, ncols_x,
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
case 5:
|
298
|
-
{
|
289
|
+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
290
|
+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
291
|
+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
292
|
+
} break;
|
293
|
+
case 5: {
|
299
294
|
constexpr int c_ncols_dst = 5;
|
300
295
|
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
301
296
|
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
302
|
-
(vx, vy, ids, dst, ncols_x,
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
case 6:
|
308
|
-
{
|
297
|
+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
298
|
+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
299
|
+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
300
|
+
} break;
|
301
|
+
case 6: {
|
309
302
|
constexpr int c_ncols_dst = 6;
|
310
303
|
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
311
304
|
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
312
|
-
(vx, vy, ids, dst, ncols_x,
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
case 7:
|
318
|
-
{
|
305
|
+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
306
|
+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
307
|
+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
308
|
+
} break;
|
309
|
+
case 7: {
|
319
310
|
constexpr int c_ncols_dst = 7;
|
320
311
|
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
321
312
|
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
322
|
-
(vx, vy, ids, dst, ncols_x,
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
case 8:
|
328
|
-
{
|
313
|
+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
314
|
+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
315
|
+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
316
|
+
} break;
|
317
|
+
case 8: {
|
329
318
|
constexpr int c_ncols_dst = 8;
|
330
319
|
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
331
320
|
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
332
|
-
(vx, vy, ids, dst, ncols_x,
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
}
|
321
|
+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
322
|
+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
323
|
+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
324
|
+
} break;
|
337
325
|
default:
|
338
326
|
GGML_ABORT("fatal error");
|
339
327
|
break;
|
@@ -384,6 +372,13 @@ static void mul_mat_vec_q_switch_type(
|
|
384
372
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
385
373
|
stream);
|
386
374
|
break;
|
375
|
+
case GGML_TYPE_MXFP4:
|
376
|
+
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
|
377
|
+
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
378
|
+
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
379
|
+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
380
|
+
stream);
|
381
|
+
break;
|
387
382
|
case GGML_TYPE_Q2_K:
|
388
383
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
|
389
384
|
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
@@ -587,9 +582,5 @@ void ggml_cuda_op_mul_mat_vec_q(
|
|
587
582
|
src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
|
588
583
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
|
589
584
|
|
590
|
-
|
591
|
-
GGML_UNUSED(dst);
|
592
|
-
GGML_UNUSED(src1_ddf_i);
|
593
|
-
GGML_UNUSED(src1_ncols);
|
594
|
-
GGML_UNUSED(src1_padded_row_size);
|
585
|
+
GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
|
595
586
|
}
|