whispercpp 1.3.3 → 1.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/ruby_whisper_params.c +55 -25
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/build-xcframework.sh +24 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/addon.cpp +19 -19
- data/ext/sources/examples/addon.node/index.js +7 -5
- data/ext/sources/examples/bench/bench.cpp +26 -16
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +4 -2
- data/ext/sources/examples/command/command.cpp +26 -24
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/lsp.cpp +19 -17
- data/ext/sources/examples/server/server.cpp +24 -13
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +4 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
- data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
- data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
- data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
- data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
- data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
- data/ext/sources/examples/talk-llama/llama-context.h +44 -29
- data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
- data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
- data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
- data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
- data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
- data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
- data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
- data/ext/sources/examples/talk-llama/llama-model.h +60 -9
- data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
- data/ext/sources/examples/talk-llama/llama.cpp +65 -10
- data/ext/sources/examples/talk-llama/llama.h +95 -177
- data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
- data/ext/sources/ggml/CMakeLists.txt +59 -31
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-backend.h +17 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -1
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml.h +221 -16
- data/ext/sources/ggml/src/CMakeLists.txt +17 -2
- data/ext/sources/ggml/src/ggml-alloc.c +265 -141
- data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
- data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
- data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
- data/ext/sources/ggml/src/ggml-common.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
- data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
- data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
- data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
- data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
- data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
- data/ext/sources/ggml/src/ggml-impl.h +119 -9
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +111 -16
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml.c +478 -98
- data/ext/sources/ggml/src/gguf.cpp +8 -1
- data/ext/sources/src/whisper.cpp +23 -46
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +3 -3
- data/ext/sources/tests/test-vad.cpp +2 -2
- data/lib/whisper/model/uri.rb +1 -1
- data/sig/whisper.rbs +7 -0
- data/test/test_params.rb +8 -0
- data/test/test_whisper.rb +1 -1
- data/whispercpp.gemspec +1 -1
- metadata +164 -157
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
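The body of the diff below is the page's one expanded excerpt, from data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh (+498 -367 in the list above); the excerpt cuts off mid-hunk. In summary: fattn_kernel_t gains sinks and KV_max arguments and takes explicitly sized int32_t/int64_t ne*/nb* shape and stride parameters; the paired vec_dot_KQ_f16_t/vec_dot_KQ_f32_t typedefs are unified into a single float-returning vec_dot_KQ_t; the vec_dot_fattn_vec_KQ_* helpers are re-templated from <typename T, int D, int warp_size> to <int D, int nthreads> with float accumulation and ggml_cuda_memcpy_1 loads; and quantize_q8_1_to_shared takes an ni bound so lanes beyond it contribute zeros. Removed lines that the original side-by-side rendering truncated are restored below from the matching upstream ggml sources.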
@@ -15,6 +15,8 @@ typedef void (* fattn_kernel_t)(
|
|
15
15
|
const char * __restrict__ K,
|
16
16
|
const char * __restrict__ V,
|
17
17
|
const char * __restrict__ mask,
|
18
|
+
const char * __restrict__ sinks,
|
19
|
+
const int * __restrict__ KV_max,
|
18
20
|
float * __restrict__ dst,
|
19
21
|
float2 * __restrict__ dst_meta,
|
20
22
|
const float scale,
|
@@ -23,300 +25,238 @@ typedef void (* fattn_kernel_t)(
|
|
23
25
|
const float m1,
|
24
26
|
const uint32_t n_head_log2,
|
25
27
|
const float logit_softcap,
|
26
|
-
const
|
27
|
-
|
28
|
-
const
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
const int nb31,
|
36
|
-
const int nb01,
|
37
|
-
const int nb02,
|
38
|
-
const int nb03,
|
39
|
-
const int nb11,
|
40
|
-
const int nb12,
|
41
|
-
const int nb13,
|
42
|
-
const int nb21,
|
43
|
-
const int nb22,
|
44
|
-
const int nb23,
|
45
|
-
const int ne0,
|
46
|
-
const int ne1,
|
47
|
-
const int ne2,
|
48
|
-
const int ne3);
|
49
|
-
|
50
|
-
typedef half (*vec_dot_KQ_f16_t)(
|
51
|
-
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
|
52
|
-
typedef float (*vec_dot_KQ_f32_t)(
|
28
|
+
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
29
|
+
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
30
|
+
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
31
|
+
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
32
|
+
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
33
|
+
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
34
|
+
const int32_t nb31, const int32_t nb32, const int64_t nb33);
|
35
|
+
|
36
|
+
typedef float (*vec_dot_KQ_t)(
|
53
37
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
|
54
38
|
|
55
|
-
template<
|
56
|
-
static __device__ __forceinline__
|
39
|
+
template <int D, int nthreads>
|
40
|
+
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
|
41
|
+
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
|
42
|
+
|
43
|
+
const half2 * K_h2 = (const half2 *) K_c;
|
44
|
+
GGML_UNUSED(Q_q8);
|
45
|
+
GGML_UNUSED(Q_ds_v);
|
46
|
+
|
47
|
+
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
|
48
|
+
constexpr int cpy_ne = cpy_nb / 4;
|
49
|
+
|
50
|
+
float sum = 0.0f;
|
51
|
+
|
52
|
+
#pragma unroll
|
53
|
+
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
|
54
|
+
half2 tmp[cpy_ne];
|
55
|
+
ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
|
56
|
+
#pragma unroll
|
57
|
+
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
|
58
|
+
#ifdef FAST_FP16_AVAILABLE
|
59
|
+
ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
60
|
+
#else
|
61
|
+
ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
62
|
+
#endif // FP16_AVAILABLE
|
63
|
+
}
|
64
|
+
}
|
65
|
+
|
66
|
+
return sum;
|
67
|
+
}
|
68
|
+
|
69
|
+
template<int D, int nthreads>
|
70
|
+
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0(
|
57
71
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
58
72
|
|
59
73
|
const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
|
60
74
|
GGML_UNUSED(Q_v);
|
61
75
|
|
62
|
-
|
76
|
+
float sum = 0.0f;
|
63
77
|
|
64
78
|
#pragma unroll
|
65
|
-
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 +=
|
66
|
-
const int k_KQ = k_KQ_0 + threadIdx.x;
|
79
|
+
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
|
80
|
+
const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
|
67
81
|
|
68
82
|
const int ib = k_KQ / QI8_1;
|
69
83
|
const int iqs4 = k_KQ % QI4_0;
|
70
84
|
const int shift = k_KQ & (QI8_1/2);
|
71
85
|
|
72
|
-
|
73
|
-
|
86
|
+
int v;
|
87
|
+
ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q4_0[ib].qs + sizeof(int)*iqs4);
|
88
|
+
v = (v >> shift) & 0x0F0F0F0F;
|
89
|
+
const int u = Q_q8[k_KQ_0/nthreads];
|
74
90
|
|
75
91
|
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
76
92
|
|
77
|
-
|
78
|
-
|
79
|
-
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
80
|
-
|
81
|
-
const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size];
|
82
|
-
sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
|
83
|
-
} else
|
84
|
-
#endif // FP16_AVAILABLE
|
85
|
-
{
|
86
|
-
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
87
|
-
|
88
|
-
sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
|
89
|
-
}
|
93
|
+
const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
|
94
|
+
sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x - (8/QI8_1)*Q_ds.y);
|
90
95
|
}
|
91
96
|
|
92
97
|
return sum;
|
93
98
|
}
|
94
99
|
|
95
|
-
template<
|
96
|
-
static __device__ __forceinline__
|
100
|
+
template<int D, int nthreads>
|
101
|
+
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_1(
|
97
102
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
98
103
|
|
99
104
|
const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
|
100
105
|
GGML_UNUSED(Q_v);
|
101
106
|
|
102
|
-
|
107
|
+
float sum = 0.0f;
|
103
108
|
|
104
109
|
#pragma unroll
|
105
|
-
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 +=
|
106
|
-
const int k_KQ = k_KQ_0 + threadIdx.x;
|
110
|
+
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
|
111
|
+
const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
|
107
112
|
|
108
113
|
const int ib = k_KQ / QI8_1;
|
109
114
|
const int iqs4 = k_KQ % QI4_1;
|
110
115
|
const int shift = k_KQ & (QI8_1/2);
|
111
116
|
|
112
|
-
|
113
|
-
|
117
|
+
int v;
|
118
|
+
ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q4_1[ib].qs + sizeof(int)*iqs4);
|
119
|
+
v = (v >> shift) & 0x0F0F0F0F;
|
120
|
+
const int u = Q_q8[k_KQ_0/nthreads];
|
114
121
|
|
115
122
|
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
116
123
|
|
117
|
-
|
118
|
-
|
119
|
-
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
120
|
-
|
121
|
-
const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size];
|
122
|
-
const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
|
123
|
-
sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
|
124
|
-
} else
|
125
|
-
#endif // FP16_AVAILABLE
|
126
|
-
{
|
127
|
-
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
128
|
-
|
129
|
-
const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
|
130
|
-
const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
|
124
|
+
const float2 K_dm = __half22float2(K_q4_1[ib].dm);
|
125
|
+
const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
|
131
126
|
|
132
|
-
|
133
|
-
}
|
127
|
+
sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1;
|
134
128
|
}
|
135
129
|
|
136
130
|
return sum;
|
137
131
|
}
|
138
132
|
|
139
|
-
template<
|
140
|
-
static __device__ __forceinline__
|
133
|
+
template<int D, int nthreads>
|
134
|
+
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_0(
|
141
135
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
142
136
|
|
143
137
|
const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
|
144
138
|
GGML_UNUSED(Q_v);
|
145
139
|
|
146
|
-
|
140
|
+
float sum = 0.0f;
|
147
141
|
|
148
142
|
#pragma unroll
|
149
|
-
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 +=
|
150
|
-
const int k_KQ = k_KQ_0 + threadIdx.x;
|
143
|
+
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
|
144
|
+
const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
|
151
145
|
|
152
146
|
const int ib = k_KQ / QI8_1;
|
153
147
|
const int iqs4 = k_KQ % QI5_0;
|
154
148
|
const int iqs8 = k_KQ % QI8_1;
|
155
149
|
const int shift = k_KQ & (QI8_1/2);
|
156
150
|
|
157
|
-
int v
|
158
|
-
|
159
|
-
v
|
160
|
-
v |= (vh << 11) & 0x00001000; // 1 -> 12
|
161
|
-
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
162
|
-
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
151
|
+
int v;
|
152
|
+
ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q5_0[ib].qs + sizeof(int)*iqs4);
|
153
|
+
v = (v >> shift) & 0x0F0F0F0F;
|
163
154
|
|
164
|
-
|
155
|
+
{
|
156
|
+
int vh;
|
157
|
+
ggml_cuda_memcpy_1<sizeof(int), 2>(&vh, K_q5_0[ib].qh);
|
158
|
+
vh >>= iqs8 * QI5_0;
|
159
|
+
|
160
|
+
v |= (vh << 4) & 0x00000010; // 0 -> 4
|
161
|
+
v |= (vh << 11) & 0x00001000; // 1 -> 12
|
162
|
+
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
163
|
+
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
164
|
+
}
|
165
165
|
|
166
|
-
const int
|
166
|
+
const int u = Q_q8[k_KQ_0/nthreads];
|
167
167
|
|
168
|
-
|
169
|
-
if (std::is_same<T, half>::value) {
|
170
|
-
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
168
|
+
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
171
169
|
|
172
|
-
|
173
|
-
sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
|
174
|
-
} else
|
175
|
-
#endif // FP16_AVAILABLE
|
176
|
-
{
|
177
|
-
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
170
|
+
const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
|
178
171
|
|
179
|
-
|
180
|
-
}
|
172
|
+
sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x - (16/QI8_1)*Q_ds.y);
|
181
173
|
}
|
182
174
|
|
183
175
|
return sum;
|
184
176
|
}
|
185
177
|
|
186
|
-
template<
|
187
|
-
static __device__ __forceinline__
|
178
|
+
template<int D, int nthreads>
|
179
|
+
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_1(
|
188
180
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
189
181
|
|
190
182
|
const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
|
191
183
|
GGML_UNUSED(Q_v);
|
192
184
|
|
193
|
-
|
185
|
+
float sum = 0.0f;
|
194
186
|
|
195
187
|
#pragma unroll
|
196
|
-
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 +=
|
197
|
-
const int k_KQ = k_KQ_0 + threadIdx.x;
|
188
|
+
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
|
189
|
+
const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
|
198
190
|
|
199
191
|
const int ib = k_KQ / QI8_1;
|
200
192
|
const int iqs4 = k_KQ % QI5_1;
|
201
193
|
const int iqs8 = k_KQ % QI8_1;
|
202
194
|
const int shift = k_KQ & (QI8_1/2);
|
203
195
|
|
204
|
-
int v
|
205
|
-
|
206
|
-
v
|
207
|
-
v |= (vh << 11) & 0x00001000; // 1 -> 12
|
208
|
-
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
209
|
-
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
196
|
+
int v;
|
197
|
+
ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q5_1[ib].qs + sizeof(int)*iqs4);
|
198
|
+
v = (v >> shift) & 0x0F0F0F0F;
|
210
199
|
|
211
|
-
|
212
|
-
|
213
|
-
|
200
|
+
{
|
201
|
+
int vh;
|
202
|
+
ggml_cuda_memcpy_1<sizeof(int)>(&vh, K_q5_1[ib].qh);
|
203
|
+
vh >>= iqs8 * QI5_0;
|
204
|
+
|
205
|
+
v |= (vh << 4) & 0x00000010; // 0 -> 4
|
206
|
+
v |= (vh << 11) & 0x00001000; // 1 -> 12
|
207
|
+
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
208
|
+
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
209
|
+
}
|
214
210
|
|
215
|
-
|
216
|
-
if (std::is_same<T, half>::value) {
|
217
|
-
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
211
|
+
const int u = Q_q8[k_KQ_0/nthreads];
|
218
212
|
|
219
|
-
|
220
|
-
const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
|
221
|
-
sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
|
222
|
-
} else
|
223
|
-
#endif // FP16_AVAILABLE
|
224
|
-
{
|
225
|
-
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
213
|
+
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
226
214
|
|
227
|
-
|
228
|
-
|
215
|
+
const float2 K_dm = __half22float2(K_q5_1[ib].dm);
|
216
|
+
const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
|
229
217
|
|
230
|
-
|
231
|
-
}
|
218
|
+
sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1;
|
232
219
|
}
|
233
220
|
|
234
221
|
return sum;
|
235
222
|
}
|
236
223
|
|
-template <typename T, int D, int warp_size>
-static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
+template <int D, int nthreads>
+static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
     GGML_UNUSED(Q_v);

-    T sum = 0.0f;
+    float sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
-        const int k_KQ = k_KQ_0 + threadIdx.x;
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+        const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);

         const int ib  = k_KQ / QI8_0;
         const int iqs = k_KQ % QI8_0;

-        const int v = get_int_b2(K_q8_0[ib].qs, iqs);
+        int v;
+        ggml_cuda_memcpy_1<sizeof(v), 2>(&v, K_q8_0[ib].qs + 4*iqs);

-        T Q_d;
-        if (std::is_same<T, half>::value) {
-            const half2 * Q_ds = (const half2 *) Q_ds_v;
-            Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
-        } else {
-            const float2 * Q_ds = (const float2 *) Q_ds_v;
-            Q_d = Q_ds[k_KQ_0/warp_size].x;
-        }
+        const float2 * Q_ds = (const float2 *) Q_ds_v;
+        const float    Q_d  = Q_ds[k_KQ_0/nthreads].x;

-        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
+        sum += vec_dot_q8_0_q8_1_impl<float, 1>(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d);
     }

     return sum;
 }

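All of these K*Q dot products funnel through ggml_cuda_dp4a, which on CUDA maps to the __dp4a instruction: four signed 8-bit products accumulated into a 32-bit integer. As a reference for what the instruction computes, a host-side C++ equivalent:

    #include <cstdint>
    #include <cstdio>

    // Host-side model of dp4a: per-byte signed products, accumulated into c.
    static int dp4a_ref(int a, int b, int c) {
        const int8_t * a8 = (const int8_t *) &a;
        const int8_t * b8 = (const int8_t *) &b;
        for (int i = 0; i < 4; ++i) {
            c += int(a8[i]) * int(b8[i]);
        }
        return c;
    }

    int main() {
        const int a = 0x01FF7F80; // bytes (little-endian): -128, 127, -1, 1
        const int b = 0x02020202; // bytes: 2, 2, 2, 2
        std::printf("%d\n", dp4a_ref(a, b, 0)); // (-128 + 127 - 1 + 1)*2 = -2
        return 0;
    }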
-template <typename T, int D, int warp_size>
-static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
-    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
-
-    const half2 * K_h2 = (const half2 *) K_c;
-    GGML_UNUSED(Q_q8);
-    GGML_UNUSED(Q_ds_v);
-
-#ifdef FP16_AVAILABLE
-    if (std::is_same<T, half>::value) {
-        const half2 * Q_h2 = (const half2 *) Q_v;
-
-        half2 sum2 = make_half2(0.0f, 0.0f);
-
-#pragma unroll
-        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
-            const int k_KQ = k_KQ_0 + threadIdx.x;
-
-            const half2 K_ik = K_h2[k_KQ];
-            sum2 += K_ik * Q_h2[k_KQ_0/warp_size];
-        }
-
-        return __low2half(sum2) + __high2half(sum2);
-    }
-#endif // FP16_AVAILABLE
-
-    const float2 * Q_f2 = (const float2 *) Q_v;
-
-    float sum = 0.0f;
-
-#pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
-        const int k_KQ = k_KQ_0 + threadIdx.x;
-
-        const half2 K_ik = K_h2[k_KQ];
-        sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x;
-        sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y;
-    }
-
-    return sum;
-}
-
-template <typename Tds>
+template <typename Tds, int ni>
 static __device__ __forceinline__ void quantize_q8_1_to_shared(
     const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {

     float vals[sizeof(int)] = {0.0f};
 #pragma unroll
     for (int l = 0; l < int(sizeof(int)); ++l) {
-        vals[l] = scale * x[4*threadIdx.x + l];
+        vals[l] = (ni == WARP_SIZE || threadIdx.x < ni) ? scale * x[4*threadIdx.x + l] : 0.0f;
     }

     float amax = fabsf(vals[0]);
@@ -344,7 +284,7 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
     }

     yq32[threadIdx.x] = q32;
-    if (threadIdx.x % QI8_1 == 0) {
+    if (threadIdx.x % QI8_1 == 0 && (ni == WARP_SIZE || threadIdx.x < ni)) {
         if (std::is_same<Tds, half2>::value) {
             ((half2 *) yds)[threadIdx.x/QI8_1] = make_half2(d, sum);
         } else {
@@ -353,173 +293,335 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
     }
 }

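quantize_q8_1_to_shared stores, per 32-value group, a scale d and a group sum that the q*_1 dot products later consume as the min-term correction. A host-side sketch of that quantization rule, under the assumption that the elided middle of the function follows ggml's usual q8_1 scheme (d = amax/127, values rounded to nearest):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Quantize 32 floats to int8 with one shared scale; also return the sum of
    // the quantized values, needed by the q4_1/q5_1 dot products (min term).
    static void quantize_q8_1_ref(const float * x, int8_t * q, float * d, float * sum) {
        float amax = 0.0f;
        for (int i = 0; i < 32; ++i) {
            amax = std::fmax(amax, std::fabs(x[i]));
        }
        *d = amax / 127.0f;
        float s = 0.0f;
        for (int i = 0; i < 32; ++i) {
            q[i] = (int8_t) std::nearbyint(amax == 0.0f ? 0.0f : x[i] / *d);
            s += q[i];
        }
        *sum = s;
    }

    int main() {
        float x[32];
        for (int i = 0; i < 32; ++i) x[i] = 0.1f * (i - 16);
        int8_t q[32]; float d, sum;
        quantize_q8_1_ref(x, q, &d, &sum);
        std::printf("d = %f, sum = %f\n", d, sum);
        return 0;
    }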
-typedef half  (*vec_dot_KQ_f16_t)(const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v);
-typedef float (*vec_dot_KQ_f32_t)(const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v);
+typedef void (*dequantize_V_t)(const void *, void *, const int64_t);
+
+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+    if constexpr (std::is_same_v<T, half>) {
+        ggml_cuda_memcpy_1<ne*sizeof(half)>(dst, (const half *) vx + i0);
+    } else if constexpr (std::is_same_v<T, float>) {
+        static_assert(ne % 2 == 0, "bad ne");
+        half2 tmp[ne/2];
+        ggml_cuda_memcpy_1<ne*sizeof(half)>(tmp, (const half *) vx + i0);
+        float2 * dst_f2 = (float2 *) dst;
+#pragma unroll
+        for (int l = 0; l < ne/2; ++l) {
+            dst_f2[l] = __half22float2(tmp[l]);
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "unsupported type");
+    }
+}

-template <typename T>
-static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) {
+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
     const block_q4_0 * x = (const block_q4_0 *) vx;

-    const int64_t ib    = i          /  QK4_0;
-    const int     iqs   = i          % (QK4_0/2);
-    const int     shift = (i % QK4_0) / (QK4_0/2);
+    const int64_t ib    = i0 / QK4_0;
+    const int     iqs   = i0 % (QK4_0/2);
+    const int     shift = (i0 % QK4_0) / (QK4_0/2);

-    const T   d  = x[ib].d;
-    const int q0 = x[ib].qs[iqs];
-    const int q  = ((q0 >> (4*shift)) & 0x0F) - 8;
+    int q;
+    static_assert(ne == 2 || ne == 4, "bad ne");
+    ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
+    q >>= 4*shift;
+    q &= 0x0F0F0F0F;
+    q = __vsubss4(q, 0x08080808);
+
+    const int8_t * q8 = (const int8_t *) &q;

 #ifdef FP16_AVAILABLE
-    if (std::is_same<T, half>::value) {
-        return ((half) d)*((half) q);
-    }
+    if constexpr (std::is_same_v<T, half>) {
+        const half2 d = __half2half2(x[ib].d);
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]);
+        }
+    } else
 #endif // FP16_AVAILABLE
+    if constexpr (std::is_same_v<T, float>) {
+        const float d = x[ib].d;

-    return ((float) d)*((float) q);
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = d * q8[l];
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "bad type");
+    }
 }

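The q4_0 path above dequantizes ne packed nibbles at once: subtract the implicit offset 8 via __vsubss4, then scale by d, i.e. per value y = d*(q - 8) with q in [0, 15]. A host-side sketch of the scalar rule (simplified to one byte; the real block_q4_0 packs 32 values behind one scale):

    #include <cstdint>
    #include <cstdio>

    // Dequantize two nibbles packed in one byte with scale d: y = d * (q - 8).
    static void dequant_q4_0_byte(uint8_t packed, float d, float out[2]) {
        out[0] = d * (int(packed & 0x0F) - 8); // low nibble
        out[1] = d * (int(packed >> 4)   - 8); // high nibble
    }

    int main() {
        float y[2];
        dequant_q4_0_byte(0xF0, 0.5f, y);      // nibbles 0 and 15
        std::printf("%f %f\n", y[0], y[1]);    // -4.0 and 3.5
        return 0;
    }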
-template <typename T>
-static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) {
+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
     const block_q4_1 * x = (const block_q4_1 *) vx;

-    const int64_t ib    = i          /  QK4_1;
-    const int     iqs   = i          % (QK4_1/2);
-    const int     shift = (i % QK4_1) / (QK4_1/2);
+    const int64_t ib    = i0 / QK4_1;
+    const int     iqs   = i0 % (QK4_1/2);
+    const int     shift = (i0 % QK4_1) / (QK4_1/2);

-    const half2 dm = x[ib].dm;
-    const int   q0 = x[ib].qs[iqs];
-    const int   q  = ((q0 >> (4*shift)) & 0x0F);
+    int q;
+    static_assert(ne == 2 || ne == 4, "bad ne");
+    ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs);
+    q >>= 4*shift;
+    q &= 0x0F0F0F0F;
+
+    const int8_t * q8 = (const int8_t *) &q;

 #ifdef FP16_AVAILABLE
-    if (std::is_same<T, half>::value) {
-        return __low2half(dm)*((half) q) + __high2half(dm);
-    }
+    if constexpr (std::is_same_v<T, half>) {
+        const half2 dm = x[ib].dm;
+        const half2 d  = __half2half2( __low2half(dm));
+        const half2 m  = __half2half2(__high2half(dm));
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m;
+        }
+    } else
 #endif // FP16_AVAILABLE
+    if constexpr (std::is_same_v<T, float>) {
+        const float2 dm = __half22float2(x[ib].dm);

-    return __low2float(dm)*((float) q) + __high2float(dm);
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = dm.x * q8[l] + dm.y;
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "bad type");
+    }
 }

-template <typename T>
-static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ vx, const int64_t i) {
+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
     const block_q5_0 * x = (const block_q5_0 *) vx;

-    const int64_t ib    = i          /  QK5_0;
-    const int     idq   = i          %  QK5_0;
-    const int     iqs   = i          % (QK5_0/2);
-    const int     shift = (i % QK5_0) / (QK5_0/2);
+    const int64_t ib    = i0 / QK5_0;
+    const int     idq   = i0 % QK5_0;
+    const int     iqs   = i0 % (QK5_0/2);
+    const int     shift = (i0 % QK5_0) / (QK5_0/2);

-    const T   d   = x[ib].d;
-    const int ql0 = x[ib].qs[iqs];
-    const int qh0 = get_int_b2(x[ib].qh, 0);
-    const int ql  = ((ql0 >> (4*shift)) & 0x0F);
-    const int qh  = ((qh0 >> idq) << 4) & 0x10;
-    const int q   = (ql | qh) - 16;
+    int q;
+    static_assert(ne == 2 || ne == 4, "bad ne");
+    ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
+    q >>= 4*shift;
+    q &= 0x0F0F0F0F;

-#ifdef FP16_AVAILABLE
-    if (std::is_same<T, half>::value) {
-        return ((half) d)*((half) q);
+    {
+        int qh;
+        ggml_cuda_memcpy_1<ne, 2>(&qh, x[ib].qh);
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
+        }
     }
+
+    q = __vsubss4(q, 0x10101010);
+
+    const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef FP16_AVAILABLE
+    if constexpr (std::is_same_v<T, half>) {
+        const half2 d = __half2half2(x[ib].d);
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]);
+        }
+    } else
 #endif // FP16_AVAILABLE
+    if constexpr (std::is_same_v<T, float>) {
+        const float d = x[ib].d;

-    return ((float) d)*((float) q);
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = d * q8[l];
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "bad type");
+    }
 }

-template <typename T>
-static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ vx, const int64_t i) {
+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
     const block_q5_1 * x = (const block_q5_1 *) vx;

-    const int64_t ib    = i          /  QK5_1;
-    const int     idq   = i          %  QK5_1;
-    const int     iqs   = i          % (QK5_1/2);
-    const int     shift = (i % QK5_1) / (QK5_1/2);
+    const int64_t ib    = i0 / QK5_1;
+    const int     idq   = i0 % QK5_1;
+    const int     iqs   = i0 % (QK5_1/2);
+    const int     shift = (i0 % QK5_1) / (QK5_1/2);

-    const half2 dm  = x[ib].dm;
-    const int   ql0 = x[ib].qs[iqs];
-    const int   qh0 = get_int_b4(x[ib].qh, 0);
-    const int   ql  = ((ql0 >> (4*shift)) & 0x0F);
-    const int   qh  = ((qh0 >> idq) << 4) & 0x10;
-    const int   q   = (ql | qh);
+    int q;
+    static_assert(ne == 2 || ne == 4, "bad ne");
+    ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs);
+    q >>= 4*shift;
+    q &= 0x0F0F0F0F;

-#ifdef FP16_AVAILABLE
-    if (std::is_same<T, half>::value) {
-        return __low2half(dm)*((half) q) + __high2half(dm);
+    {
+        int qh;
+        ggml_cuda_memcpy_1<ne>(&qh, x[ib].qh);
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
+        }
     }
+
+    const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef FP16_AVAILABLE
+    if constexpr (std::is_same_v<T, half>) {
+        const half2 dm = x[ib].dm;
+        const half2 d  = __half2half2( __low2half(dm));
+        const half2 m  = __half2half2(__high2half(dm));
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m;
+        }
+    } else
 #endif // FP16_AVAILABLE
+    if constexpr (std::is_same_v<T, float>) {
+        const float2 dm = __half22float2(x[ib].dm);

-    return __low2float(dm)*((float) q) + __high2float(dm);
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = dm.x * q8[l] + dm.y;
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "bad type");
+    }
 }

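In dequantize_V_q5_0/q5_1 above, the fifth bits are re-inserted per output value: bit (idq + l) of qh becomes bit 4 of byte lane l of q. A host-side check of that indexing (names are illustrative):

    #include <cstdint>
    #include <cstdio>

    // Insert high bit (idq + l) of qh into bit 4 of byte lane l of q.
    static uint32_t insert_q5_high_bits(uint32_t q, uint32_t qh, int idq, int ne) {
        for (int l = 0; l < ne; ++l) {
            q |= ((qh >> (idq + l)) & 1u) << (8*l + 4);
        }
        return q;
    }

    int main() {
        // Lanes 0..3 receive high bits 5..8 of qh when idq == 5.
        const uint32_t q = insert_q5_high_bits(0x03020100u, 0x1E0u, 5, 4);
        std::printf("%08x\n", q); // 13121110
        return 0;
    }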
-template <typename T>
-static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) {
+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
     const block_q8_0 * x = (const block_q8_0 *) vx;

-    const int64_t ib  = i / QK8_0;
-    const int     iqs = i % QK8_0;
+    const int64_t ib  = i0 / QK8_0;
+    const int     iqs = i0 % QK8_0;

-    const T   d = x[ib].d;
-    const int q = x[ib].qs[iqs];
+    static_assert(ne % 2 == 0, "bad ne");
+    int8_t qs[ne];
+    ggml_cuda_memcpy_1<ne, 2>(qs, x[ib].qs + iqs);

 #ifdef FP16_AVAILABLE
-    if (std::is_same<T, half>::value) {
-        return ((half) d)*((half) q);
-    }
+    if constexpr (std::is_same<T, half>::value) {
+        const half2 d = __half2half2(x[ib].d);
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((half2 *) dst)[l0/2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]);
+        }
+    } else
 #endif // FP16_AVAILABLE
+    if constexpr (std::is_same<T, float>::value) {
+        const float d = x[ib].d;

-    return ((float) d)*((float) q);
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = d * qs[l];
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "unsupported type");
+    }
 }

-template <
-
-
-
-
+template <ggml_type type_K, int D, int nthreads>
+constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
+    if constexpr (type_K == GGML_TYPE_F16) {
+        return vec_dot_fattn_vec_KQ_f16<D, nthreads>;
+    } else if constexpr (type_K == GGML_TYPE_Q4_0) {
+        return vec_dot_fattn_vec_KQ_q4_0<D, nthreads>;
+    } else if constexpr (type_K == GGML_TYPE_Q4_1) {
+        return vec_dot_fattn_vec_KQ_q4_1<D, nthreads>;
+    } else if constexpr (type_K == GGML_TYPE_Q5_0) {
+        return vec_dot_fattn_vec_KQ_q5_0<D, nthreads>;
+    } else if constexpr (type_K == GGML_TYPE_Q5_1) {
+        return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>;
+    } else if constexpr (type_K == GGML_TYPE_Q8_0) {
+        return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
+    } else {
+        static_assert(type_K == -1, "bad type");
+        return nullptr;
+    }
 }

-template <
-constexpr __device__
-
-
-
-
-
-
-
+template <ggml_type type_V, typename T, int ne>
+constexpr __device__ dequantize_V_t get_dequantize_V() {
+    if constexpr (type_V == GGML_TYPE_F16) {
+        return dequantize_V_f16<T, ne>;
+    } else if constexpr (type_V == GGML_TYPE_Q4_0) {
+        return dequantize_V_q4_0<T, ne>;
+    } else if constexpr (type_V == GGML_TYPE_Q4_1) {
+        return dequantize_V_q4_1<T, ne>;
+    } else if constexpr (type_V == GGML_TYPE_Q5_0) {
+        return dequantize_V_q5_0<T, ne>;
+    } else if constexpr (type_V == GGML_TYPE_Q5_1) {
+        return dequantize_V_q5_1<T, ne>;
+    } else if constexpr (type_V == GGML_TYPE_Q8_0) {
+        return dequantize_V_q8_0<T, ne>;
+    } else {
+        static_assert(type_V == -1, "bad type");
+        return nullptr;
+    }
 }

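get_vec_dot_KQ and get_dequantize_V select device functions entirely at compile time via if constexpr, so each kernel instantiation bakes in exactly one K/V codepath and dead branches are never emitted. A minimal host-compilable sketch of the same dispatch pattern (the enum and function names here are stand-ins, not the ggml ones):

    #include <cstdio>

    enum my_type { TYPE_A, TYPE_B };      // stand-in for ggml_type

    typedef float (*dot_fn)(float, float);

    static float dot_a(float x, float y) { return x*y; }
    static float dot_b(float x, float y) { return x*y + 1.0f; }

    template <my_type T>
    constexpr dot_fn get_dot() {
        if constexpr (T == TYPE_A) {
            return dot_a;
        } else if constexpr (T == TYPE_B) {
            return dot_b;
        } else {
            // Same trick as in the diff: only fails if this branch is instantiated.
            static_assert(T == -1, "bad type");
            return nullptr;
        }
    }

    int main() {
        constexpr dot_fn f = get_dot<TYPE_B>(); // resolved at compile time
        std::printf("%f\n", f(2.0f, 3.0f));     // 7.0
        return 0;
    }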
-template <int
-
-
-
-
-
-
-
-    nullptr;
-}
+template <int ncols1>
+__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
+static __global__ void flash_attn_mask_to_KV_max(
+        const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
+    const int ne31     = gridDim.x;
+    const int tid      = threadIdx.x;
+    const int sequence = blockIdx.y;
+    const int jt       = blockIdx.x;

-
-
-
-
-
-
-
-    nullptr;
-}
+    mask += sequence*s33 + jt*ncols1*s31;
+
+    __shared__ int buf_iw[WARP_SIZE];
+    if (tid < WARP_SIZE) {
+        buf_iw[tid] = 1;
+    }
+    __syncthreads();

-
-
-
-
-
-
-
-
+    int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
+    for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
+        int all_inf = 1;
+
+#pragma unroll
+        for (int j = 0; j < ncols1; ++j) {
+            const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
+            all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
+        }
+
+        all_inf = warp_reduce_all(all_inf);
+        if (tid % WARP_SIZE == 0) {
+            buf_iw[tid / WARP_SIZE] = all_inf;
+        }
+        __syncthreads();
+        all_inf = buf_iw[tid % WARP_SIZE];
+        __syncthreads();
+        all_inf = warp_reduce_all(all_inf);
+
+        if (!all_inf) {
+            break;
+        }
+    }
+
+    // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
+    // If the break was triggered it's the lower edge of the tile with the first non-masked values.
+    // In either case, walk back the decrementation by FATTN_KQ_STRIDE.
+    KV_max_sj += FATTN_KQ_STRIDE;
+
+    if (threadIdx.x != 0) {
+        return;
+    }
+
+    KV_max[sequence*ne31 + jt] = KV_max_sj;
 }

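flash_attn_mask_to_KV_max walks the mask backwards in FATTN_KQ_STRIDE-wide tiles and records, per (sequence, Q tile), the KV bound below which real work exists; trailing all-(-inf) tiles are then skipped by the attention kernels. A host-side model of the scan, simplified to one mask row:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Return the KV upper bound (a multiple of stride) for one row of mask values:
    // trailing tiles that are entirely -inf contribute nothing and can be skipped.
    static int kv_max_for_row(const std::vector<float> & mask, int stride) {
        int kv_max = (int) mask.size();
        while (kv_max > 0) {
            bool all_inf = true;
            for (int i = kv_max - stride; i < kv_max; ++i) {
                all_inf = all_inf && std::isinf(mask[i]);
            }
            if (!all_inf) break;
            kv_max -= stride;
        }
        return kv_max;
    }

    int main() {
        std::vector<float> mask(16, 0.0f);
        for (int i = 6; i < 16; ++i) mask[i] = -INFINITY; // causal tail
        std::printf("%d\n", kv_max_for_row(mask, 4));     // 8: tiles [8,12) and [12,16) skipped
        return 0;
    }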
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
-        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
     constexpr int ncols = ncols1*ncols2;

     const int bidx0 = blockIdx.x;
@@ -533,8 +635,8 @@ static __global__ void flash_attn_stream_k_fixup(

     const int iter_k = ne11 / FATTN_KQ_STRIDE;
     const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;

-    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0      = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;

     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -543,14 +645,15 @@
         return;
     }

-    const int head = kbc0 / (iter_k*iter_j);
-    const int jt   = (kbc0 - iter_k*iter_j*head) / iter_k; // j index of current tile.
+    const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
+    const int head     = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt       = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.

     if (jt*ncols1 + j >= ne01) {
         return;
     }

-    dst += jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;

     // Load the partial result that needs a fixup:
     float dst_val = 0.0f;
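The fixup now decomposes the flat stream-k work index kbc0 into (sequence, head, jt) by successive division, matching a kbc = ((sequence*heads + head)*iter_j + jt)*iter_k tile ordering. A quick host-side check of that arithmetic (heads stands in for ne02/ncols2):

    #include <cstdio>

    int main() {
        const int iter_k = 8, iter_j = 4, heads = 6;
        // Flatten (sequence=1, head=3, jt=2) the same way the tiles are ordered.
        const int kbc0 = ((1*heads + 3)*iter_j + 2)*iter_k;

        const int sequence = kbc0 / (iter_k*iter_j*heads);
        const int head     = (kbc0 - iter_k*iter_j*heads*sequence) / (iter_k*iter_j);
        const int jt       = (kbc0 - iter_k*iter_j*heads*sequence - iter_k*iter_j*head) / iter_k;

        std::printf("%d %d %d\n", sequence, head, jt); // 1 3 2
        return 0;
    }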
@@ -569,7 +672,7 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+        const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
@@ -607,24 +710,37 @@ static __global__ void flash_attn_stream_k_fixup(
 }

 template<int D> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_combine_results(
         const float  * __restrict__ VKQ_parts,
         const float2 * __restrict__ VKQ_meta,
         float * __restrict__ dst,
         const int parallel_blocks) {
-    VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
-    VKQ_meta  += parallel_blocks   * gridDim.z*blockIdx.x;
-    dst       +=                 D * gridDim.z*blockIdx.x;
+    // Dimension 0: threadIdx.x
+    // Dimension 1: blockIdx.x
+    // Dimension 2: blockIdx.y
+    // Dimension 3: blockIdx.z
+    // Memory layout is permuted with [0, 2, 1, 3]
+
+    const int ne01 = gridDim.x;
+    const int ne02 = gridDim.y;
+
+    const int col      = blockIdx.x;
+    const int head     = blockIdx.y;
+    const int sequence = blockIdx.z;
+
+    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
+
+    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
+    VKQ_meta  += j_dst_unrolled * parallel_blocks;
+    dst       += j_dst_unrolled * D;

     const int tid = threadIdx.x;
     __builtin_assume(tid < D);

     extern __shared__ float2 meta[];
     for (int i = tid; i < 2*parallel_blocks; i += D) {
-        ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
+        ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
     }

     __syncthreads();
@@ -637,38 +753,13 @@ static __global__ void flash_attn_combine_results(
     float VKQ_numerator   = 0.0f;
     float VKQ_denominator = 0.0f;
     for (int l = 0; l < parallel_blocks; ++l) {
-        const float diff = meta[l].x - kqmax;
-        float KQ_max_scale = expf(diff);
-        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
-        *((uint32_t *) &KQ_max_scale) &= ftz_mask;
+        const float KQ_max_scale = expf(meta[l].x - kqmax);

-        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
     }

-    dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
-}
-
-[[noreturn]]
-static void on_no_fattn_vec_case(const int D) {
-    if (D == 64) {
-        fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
-        fprintf(stderr, "By default only f16 KV cache is supported.\n");
-        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
-        GGML_ABORT("fatal error");
-    } else if (D == 128) {
-        fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
-        fprintf(stderr, "Supported combinations:\n");
-        fprintf(stderr, "  - K == q4_0, V == q4_0,  4.50 BPV\n");
-        fprintf(stderr, "  - K == q8_0, V == q8_0,  8.50 BPV\n");
-        fprintf(stderr, "  - K == f16,  V == f16,  16.00 BPV\n");
-        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
-        GGML_ABORT("fatal error");
-    } else {
-        fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
-        fprintf(stderr, "Only f16 is supported.\n");
-        GGML_ABORT("fatal error");
-    }
+    dst[tid] = VKQ_numerator / VKQ_denominator;
 }

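flash_attn_combine_results merges the parallel partial results with the standard streaming-softmax rescale: each block's (running max, denominator) pair in meta is shifted onto the global max before numerator and denominator are summed. A scalar C++ model of the combination step:

    #include <cmath>
    #include <cstdio>

    int main() {
        // Two partial attention results for one output element:
        // meta_x/meta_y = per-block (running KQ max, softmax denominator),
        // part          = per-block V-weighted numerator.
        const float meta_x[2] = {1.0f, 3.0f};
        const float meta_y[2] = {2.5f, 1.5f};
        const float part  [2] = {0.8f, 0.6f};

        const float kqmax = fmaxf(meta_x[0], meta_x[1]);
        float num = 0.0f, den = 0.0f;
        for (int l = 0; l < 2; ++l) {
            const float scale = expf(meta_x[l] - kqmax); // rebase onto the global max
            num += scale * part[l];
            den += scale * meta_y[l];
        }
        std::printf("%f\n", num / den); // matches a one-pass softmax over all columns
        return 0;
    }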
 template <int DV, int ncols1, int ncols2>
@@ -686,7 +777,8 @@ void launch_fattn(

     GGML_ASSERT(V || is_mla);

-    const ggml_tensor * mask = dst->src[3];
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];

     ggml_tensor * KQV = dst;

@@ -703,8 +795,6 @@ void launch_fattn(

     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");

-    GGML_ASSERT(Q->ne[3] == 1);
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -713,6 +803,7 @@ void launch_fattn(

     ggml_cuda_pool_alloc<half>   K_f16(pool);
     ggml_cuda_pool_alloc<half>   V_f16(pool);
+    ggml_cuda_pool_alloc<int>    KV_max(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);

@@ -727,43 +818,86 @@ void launch_fattn(
     size_t nb23 = V ? V->nb[3] : nb13;

     if (need_f16_K && K->type != GGML_TYPE_F16) {
-        GGML_ASSERT(ggml_is_contiguously_allocated(K));
-        K_f16.alloc(ggml_nelements(K));
-        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
-        to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
-        K_data = (char *) K_f16.ptr;
-
         const size_t bs = ggml_blck_size(K->type);
         const size_t ts = ggml_type_size(K->type);

-        nb11 = nb11*bs*sizeof(half)/ts;
-        nb12 = nb12*bs*sizeof(half)/ts;
-        nb13 = nb13*bs*sizeof(half)/ts;
+        K_f16.alloc(ggml_nelements(K));
+        if (ggml_is_contiguously_allocated(K)) {
+            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
+            to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+
+            nb11 = nb11*bs*sizeof(half)/ts;
+            nb12 = nb12*bs*sizeof(half)/ts;
+            nb13 = nb13*bs*sizeof(half)/ts;
+        } else {
+            GGML_ASSERT(K->nb[0] == ts);
+            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);
+            const int64_t s01 = nb11 / ts;
+            const int64_t s02 = nb12 / ts;
+            const int64_t s03 = nb13 / ts;
+            to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
+
+            nb11 = K->ne[0] * sizeof(half);
+            nb12 = K->ne[1] * nb11;
+            nb13 = K->ne[2] * nb12;
+        }
+        K_data = (char *) K_f16.ptr;
     }

     if (V && need_f16_V && V->type != GGML_TYPE_F16) {
-        GGML_ASSERT(ggml_is_contiguously_allocated(V));
-        V_f16.alloc(ggml_nelements(V));
-        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
-        to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
-        V_data = (char *) V_f16.ptr;
-
         const size_t bs = ggml_blck_size(V->type);
         const size_t ts = ggml_type_size(V->type);

-        nb21 = nb21*bs*sizeof(half)/ts;
-        nb22 = nb22*bs*sizeof(half)/ts;
-        nb23 = nb23*bs*sizeof(half)/ts;
+        V_f16.alloc(ggml_nelements(V));
+        if (ggml_is_contiguously_allocated(V)) {
+            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+            to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+            V_data = (char *) V_f16.ptr;
+
+            nb21 = nb21*bs*sizeof(half)/ts;
+            nb22 = nb22*bs*sizeof(half)/ts;
+            nb23 = nb23*bs*sizeof(half)/ts;
+        } else {
+            GGML_ASSERT(V->nb[0] == ts);
+            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
+            const int64_t s01 = nb21 / ts;
+            const int64_t s02 = nb22 / ts;
+            const int64_t s03 = nb23 / ts;
+            to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
+
+            nb21 = V->ne[0] * sizeof(half);
+            nb22 = V->ne[1] * nb21;
+            nb23 = V->ne[2] * nb22;
+        }
+        V_data = (char *) V_f16.ptr;
     }

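When K or V is converted to fp16 above, the byte strides must be restated for the new element size: the contiguous path scales each nbX by (block size * sizeof(half) / type size), while the non-contiguous path rebuilds them from ne. A small sketch of the contiguous-stride arithmetic, using q8_0 (34 bytes per 32-value block) as the example and sizeof(unsigned short) as a host stand-in for sizeof(half):

    #include <cstdio>

    int main() {
        // q8_0: 32 values per block, 34 bytes per block (32 int8 + fp16 scale).
        const size_t bs  = 32, ts = 34;
        const size_t ne0 = 128;                  // row length in values (head size)
        const size_t nb1_quant = ne0/bs * ts;    // bytes per quantized row: 136
        const size_t nb1_f16   = nb1_quant*bs*sizeof(unsigned short)/ts; // -> 256
        std::printf("%zu -> %zu (= ne0*sizeof(half) = %zu)\n",
                    nb1_quant, nb1_f16, ne0*2);
        return 0;
    }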
-    int parallel_blocks = 1;
-
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];

+    // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
+    // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
+    // multiple sequences of possibly different lengths.
+    if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
+        const int s31 = mask->nb[1] / sizeof(half2);
+        const int s33 = mask->nb[3] / sizeof(half2);
+
+        const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
+        const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
+
+        const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
+        const int iter_k    = K->ne[1] / FATTN_KQ_STRIDE;
+
+        KV_max.alloc(ne_KV_max);
+        flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
+            ((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
+        CUDA_CHECK(cudaGetLastError());
+    }
+
     const dim3 block_dim(warp_size, nwarps, 1);
     int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
     CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+    int parallel_blocks = max_blocks_per_sm;

     dim3 blocks_num;
     if (stream_k) {
@@ -785,9 +919,6 @@ void launch_fattn(
     GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
     const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.

-    // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
-    parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
-
     // parallel_blocks must not be larger than what the tensor size allows:
     parallel_blocks = std::min(parallel_blocks, ntiles_KQ);

@@ -802,7 +933,7 @@ void launch_fattn(
     const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);

     // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
-    if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
+    if (efficiency_percent_best >= 95 && nwaves > nwaves_best) {
         break;
     }

@@ -847,15 +978,15 @@ void launch_fattn(
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
+        sinks ? ((const char *) sinks->data) : nullptr,
+        KV_max.ptr,
         !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
-        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-        Q->nb[1], Q->nb[2], Q->nb[3],
-        nb11, nb12, nb13,
+        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
+        K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
         nb21, nb22, nb23,
-
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
     );
     CUDA_CHECK(cudaGetLastError());

@@ -866,11 +997,11 @@ void launch_fattn(

         flash_attn_stream_k_fixup<DV, ncols1, ncols2>
             <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(DV, 1, 1);
-        const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
+        const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);

         flash_attn_combine_results<DV>