whispercpp 1.3.3 → 1.3.4
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/ruby_whisper_params.c +55 -25
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/build-xcframework.sh +24 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/addon.cpp +19 -19
- data/ext/sources/examples/addon.node/index.js +7 -5
- data/ext/sources/examples/bench/bench.cpp +26 -16
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +4 -2
- data/ext/sources/examples/command/command.cpp +26 -24
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/lsp.cpp +19 -17
- data/ext/sources/examples/server/server.cpp +24 -13
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +4 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
- data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
- data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
- data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
- data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
- data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
- data/ext/sources/examples/talk-llama/llama-context.h +44 -29
- data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
- data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
- data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
- data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
- data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
- data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
- data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
- data/ext/sources/examples/talk-llama/llama-model.h +60 -9
- data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
- data/ext/sources/examples/talk-llama/llama.cpp +65 -10
- data/ext/sources/examples/talk-llama/llama.h +95 -177
- data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
- data/ext/sources/ggml/CMakeLists.txt +59 -31
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-backend.h +17 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -1
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml.h +221 -16
- data/ext/sources/ggml/src/CMakeLists.txt +17 -2
- data/ext/sources/ggml/src/ggml-alloc.c +265 -141
- data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
- data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
- data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
- data/ext/sources/ggml/src/ggml-common.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
- data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
- data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
- data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
- data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
- data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
- data/ext/sources/ggml/src/ggml-impl.h +119 -9
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +111 -16
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml.c +478 -98
- data/ext/sources/ggml/src/gguf.cpp +8 -1
- data/ext/sources/src/whisper.cpp +23 -46
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +3 -3
- data/ext/sources/tests/test-vad.cpp +2 -2
- data/lib/whisper/model/uri.rb +1 -1
- data/sig/whisper.rbs +7 -0
- data/test/test_params.rb +8 -0
- data/test/test_whisper.rb +1 -1
- data/whispercpp.gemspec +1 -1
- metadata +164 -157
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
@@ -0,0 +1,755 @@
|
|
1
|
+
#include "common.cuh"
|
2
|
+
#include "fattn-common.cuh"
|
3
|
+
#include "fattn-tile.cuh"
|
4
|
+
|
5
|
+
// kq_stride == number of KQ rows to process per iteration
|
6
|
+
// kq_nbatch == number of K columns to load in parallel for KQ calculation
|
7
|
+
|
8
|
+
static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) {
|
9
|
+
if (GGML_CUDA_CC_IS_AMD(cc)) {
|
10
|
+
if (GGML_CUDA_CC_IS_RDNA(cc)) {
|
11
|
+
switch (D) {
|
12
|
+
case 64:
|
13
|
+
return 128;
|
14
|
+
case 128:
|
15
|
+
case 256:
|
16
|
+
return ncols <= 16 ? 128 : 64;
|
17
|
+
default:
|
18
|
+
GGML_ABORT("fatal error");
|
19
|
+
return -1;
|
20
|
+
}
|
21
|
+
}
|
22
|
+
switch (D) {
|
23
|
+
case 64:
|
24
|
+
return ncols == 32 ? 128 : 64;
|
25
|
+
case 128:
|
26
|
+
return ncols == 32 ? 64 : 32;
|
27
|
+
case 256:
|
28
|
+
return 32;
|
29
|
+
default:
|
30
|
+
GGML_ABORT("fatal error");
|
31
|
+
return -1;
|
32
|
+
}
|
33
|
+
}
|
34
|
+
if (fast_fp16_available(cc)) {
|
35
|
+
switch (D) {
|
36
|
+
case 64:
|
37
|
+
case 128:
|
38
|
+
case 256:
|
39
|
+
return ncols <= 16 ? 128 : 64;
|
40
|
+
default:
|
41
|
+
GGML_ABORT("fatal error");
|
42
|
+
return -1;
|
43
|
+
}
|
44
|
+
}
|
45
|
+
switch (D) {
|
46
|
+
case 64:
|
47
|
+
return ncols <= 16 ? 128 : 64;
|
48
|
+
case 128:
|
49
|
+
return ncols <= 16 ? 64 : 32;
|
50
|
+
case 256:
|
51
|
+
return 32;
|
52
|
+
default:
|
53
|
+
GGML_ABORT("fatal error");
|
54
|
+
return -1;
|
55
|
+
}
|
56
|
+
GGML_UNUSED(warp_size);
|
57
|
+
}
|
58
|
+
|
59
|
+
static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
|
60
|
+
#ifdef GGML_USE_HIP
|
61
|
+
#ifdef RDNA
|
62
|
+
switch (D) {
|
63
|
+
case 64:
|
64
|
+
return 128;
|
65
|
+
case 128:
|
66
|
+
case 256:
|
67
|
+
return ncols <= 16 ? 128 : 64;
|
68
|
+
default:
|
69
|
+
return -1;
|
70
|
+
}
|
71
|
+
#else
|
72
|
+
switch (D) {
|
73
|
+
case 64:
|
74
|
+
return ncols == 32 ? 128 : 64;
|
75
|
+
case 128:
|
76
|
+
return ncols == 32 ? 64 : 32;
|
77
|
+
case 256:
|
78
|
+
return 32;
|
79
|
+
default:
|
80
|
+
return -1;
|
81
|
+
}
|
82
|
+
#endif // RDNA
|
83
|
+
#else
|
84
|
+
#ifdef FAST_FP16_AVAILABLE
|
85
|
+
switch (D) {
|
86
|
+
case 64:
|
87
|
+
case 128:
|
88
|
+
case 256:
|
89
|
+
return ncols <= 16 ? 128 : 64;
|
90
|
+
default:
|
91
|
+
return -1;
|
92
|
+
}
|
93
|
+
#else
|
94
|
+
switch (D) {
|
95
|
+
case 64:
|
96
|
+
return ncols <= 16 ? 128 : 64;
|
97
|
+
case 128:
|
98
|
+
return ncols <= 16 ? 64 : 32;
|
99
|
+
case 256:
|
100
|
+
return 32;
|
101
|
+
default:
|
102
|
+
return -1;
|
103
|
+
}
|
104
|
+
#endif // FAST_FP16_AVAILABLE
|
105
|
+
#endif // GGML_USE_HIP
|
106
|
+
GGML_UNUSED_VARS(ncols, warp_size);
|
107
|
+
}
|
108
|
+
|
109
|
+
static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols, int warp_size) {
|
110
|
+
#ifdef GGML_USE_HIP
|
111
|
+
switch (D) {
|
112
|
+
case 64:
|
113
|
+
return 64;
|
114
|
+
case 128:
|
115
|
+
case 256:
|
116
|
+
return 128;
|
117
|
+
default:
|
118
|
+
return -1;
|
119
|
+
}
|
120
|
+
#else
|
121
|
+
#ifdef FAST_FP16_AVAILABLE
|
122
|
+
switch (D) {
|
123
|
+
case 64:
|
124
|
+
return 64;
|
125
|
+
case 128:
|
126
|
+
case 256:
|
127
|
+
return 128;
|
128
|
+
default:
|
129
|
+
return -1;
|
130
|
+
}
|
131
|
+
#else
|
132
|
+
switch (D) {
|
133
|
+
case 64:
|
134
|
+
return 64;
|
135
|
+
case 128:
|
136
|
+
return 128;
|
137
|
+
case 256:
|
138
|
+
return ncols <= 16 ? 128 : 64;
|
139
|
+
default:
|
140
|
+
return -1;
|
141
|
+
}
|
142
|
+
#endif // FAST_FP16_AVAILABLE
|
143
|
+
#endif // GGML_USE_HIP
|
144
|
+
GGML_UNUSED_VARS(ncols, warp_size);
|
145
|
+
}
|
146
|
+
|
147
|
+
static int fattn_tile_get_nthreads_host(const int cc, const int ncols) {
|
148
|
+
return 256;
|
149
|
+
GGML_UNUSED_VARS(cc, ncols);
|
150
|
+
}
|
151
|
+
|
152
|
+
static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) {
|
153
|
+
return 256;
|
154
|
+
GGML_UNUSED(ncols);
|
155
|
+
}
|
156
|
+
|
157
|
+
static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) {
|
158
|
+
#ifdef RDNA
|
159
|
+
return 3;
|
160
|
+
#else
|
161
|
+
return ncols <= 16 ? 3 : 2;
|
162
|
+
#endif // RDNA
|
163
|
+
GGML_UNUSED(ncols);
|
164
|
+
}
|
165
|
+
|
166
|
+
template<int D, int ncols, bool use_logit_softcap> // D == head size
|
167
|
+
__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols))
|
168
|
+
static __global__ void flash_attn_tile(
|
169
|
+
const char * __restrict__ Q,
|
170
|
+
const char * __restrict__ K,
|
171
|
+
const char * __restrict__ V,
|
172
|
+
const char * __restrict__ mask,
|
173
|
+
const char * __restrict__ sinks,
|
174
|
+
const int * __restrict__ KV_max,
|
175
|
+
float * __restrict__ dst,
|
176
|
+
float2 * __restrict__ dst_meta,
|
177
|
+
const float scale,
|
178
|
+
const float max_bias,
|
179
|
+
const float m0,
|
180
|
+
const float m1,
|
181
|
+
const uint32_t n_head_log2,
|
182
|
+
const float logit_softcap,
|
183
|
+
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
184
|
+
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
185
|
+
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
186
|
+
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
187
|
+
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
188
|
+
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
189
|
+
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
190
|
+
#ifdef FLASH_ATTN_AVAILABLE
|
191
|
+
|
192
|
+
// Skip unused kernel variants for faster compilation:
|
193
|
+
#ifdef FP16_MMA_AVAILABLE
|
194
|
+
NO_DEVICE_CODE;
|
195
|
+
return;
|
196
|
+
#endif // FP16_MMA_AVAILABLE
|
197
|
+
|
198
|
+
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
199
|
+
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
200
|
+
max_bias, m0, m1, n_head_log2, logit_softcap,
|
201
|
+
ne00, ne01, ne02, ne03,
|
202
|
+
nb01, nb02, nb03,
|
203
|
+
ne10, ne11, ne12, ne13,
|
204
|
+
nb11, nb12, nb13,
|
205
|
+
nb21, nb22, nb23,
|
206
|
+
ne31, ne32, ne33,
|
207
|
+
nb31, nb32, nb33);
|
208
|
+
NO_DEVICE_CODE;
|
209
|
+
return;
|
210
|
+
}
|
211
|
+
|
212
|
+
constexpr int warp_size = 32;
|
213
|
+
constexpr int nwarps = fattn_tile_get_nthreads_device(ncols) / warp_size;
|
214
|
+
constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size);
|
215
|
+
static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size.");
|
216
|
+
constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size);
|
217
|
+
static_assert(kq_nbatch % (2*warp_size) == 0, "bad kq_nbatch");
|
218
|
+
|
219
|
+
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
220
|
+
|
221
|
+
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
222
|
+
|
223
|
+
const int sequence = blockIdx.z / ne02;
|
224
|
+
const int head = blockIdx.z - sequence*ne02;
|
225
|
+
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
226
|
+
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
227
|
+
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
228
|
+
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
229
|
+
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
230
|
+
const float * sinksf = (const float *) (sinks);
|
231
|
+
|
232
|
+
const int stride_KV2 = nb11 / sizeof(half2);
|
233
|
+
|
234
|
+
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
235
|
+
|
236
|
+
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
|
237
|
+
constexpr int cpy_ne = cpy_nb / 4;
|
238
|
+
|
239
|
+
constexpr int cpw = ncols/nwarps; // cols per warp
|
240
|
+
|
241
|
+
// softmax_iter_j == number of KQ columns for which to calculate softmax in parallel.
|
242
|
+
// KQ is originall 2D but uses a Z-shaped memory pattern for larger reads/writes.
|
243
|
+
#ifdef FAST_FP16_AVAILABLE
|
244
|
+
constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
|
245
|
+
|
246
|
+
__shared__ half KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
|
247
|
+
__shared__ half2 Q_tmp[ncols][D/2];
|
248
|
+
__shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
249
|
+
half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
250
|
+
#else
|
251
|
+
constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
|
252
|
+
|
253
|
+
__shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
|
254
|
+
__shared__ float Q_tmp[ncols][D];
|
255
|
+
__shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
256
|
+
float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
257
|
+
#endif // FAST_FP16_AVAILABLE
|
258
|
+
static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j");
|
259
|
+
|
260
|
+
float KQ_max[cpw];
|
261
|
+
#pragma unroll
|
262
|
+
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
263
|
+
KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
|
264
|
+
}
|
265
|
+
float KQ_sum[cpw] = {0.0f};
|
266
|
+
|
267
|
+
// Load Q data, convert to FP16 if fast.
|
268
|
+
#pragma unroll
|
269
|
+
for (int j0 = 0; j0 < cpw; ++j0) {
|
270
|
+
const int j = j0 + threadIdx.y*cpw;
|
271
|
+
|
272
|
+
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
273
|
+
|
274
|
+
#pragma unroll
|
275
|
+
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
276
|
+
float tmp_f[cpy_ne_D] = {0.0f};
|
277
|
+
if (ic0 + j < ne01) {
|
278
|
+
ggml_cuda_memcpy_1<sizeof(tmp_f)>(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]);
|
279
|
+
}
|
280
|
+
|
281
|
+
#pragma unroll
|
282
|
+
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
|
283
|
+
tmp_f[i1] *= scale;
|
284
|
+
}
|
285
|
+
|
286
|
+
#ifdef FAST_FP16_AVAILABLE
|
287
|
+
half2 tmp_h2[cpy_ne_D/2];
|
288
|
+
#pragma unroll
|
289
|
+
for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
|
290
|
+
tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
|
291
|
+
}
|
292
|
+
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2);
|
293
|
+
#else
|
294
|
+
ggml_cuda_memcpy_1<sizeof(tmp_f)> (&Q_tmp[j][i0 + threadIdx.x* cpy_ne_D], tmp_f);
|
295
|
+
#endif // FAST_FP16_AVAILABLE
|
296
|
+
}
|
297
|
+
}
|
298
|
+
|
299
|
+
__syncthreads();
|
300
|
+
|
301
|
+
// Main loop over KV cache:
|
302
|
+
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
|
303
|
+
for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) {
|
304
|
+
// Calculate KQ tile and keep track of new maximum KQ values:
|
305
|
+
|
306
|
+
float KQ_max_new[cpw];
|
307
|
+
#pragma unroll
|
308
|
+
for (int j = 0; j < cpw; ++j) {
|
309
|
+
KQ_max_new[j] = KQ_max[j];
|
310
|
+
}
|
311
|
+
|
312
|
+
float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication.
|
313
|
+
|
314
|
+
// KQ = K @ Q matrix multiplication:
|
315
|
+
#pragma unroll
|
316
|
+
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) {
|
317
|
+
#pragma unroll
|
318
|
+
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) {
|
319
|
+
const int i_KQ = i_KQ_0 + threadIdx.y;
|
320
|
+
|
321
|
+
#ifdef FAST_FP16_AVAILABLE
|
322
|
+
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size);
|
323
|
+
#pragma unroll
|
324
|
+
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) {
|
325
|
+
ggml_cuda_memcpy_1<cpy_ne_kqnb*4>(
|
326
|
+
&KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb],
|
327
|
+
&K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]);
|
328
|
+
}
|
329
|
+
#else
|
330
|
+
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size;
|
331
|
+
#pragma unroll
|
332
|
+
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) {
|
333
|
+
half2 tmp_h2[cpy_ne_kqnb/2];
|
334
|
+
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
|
335
|
+
tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]);
|
336
|
+
|
337
|
+
float2 tmp_f2[cpy_ne_kqnb/2];
|
338
|
+
#pragma unroll
|
339
|
+
for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) {
|
340
|
+
tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]);
|
341
|
+
}
|
342
|
+
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
|
343
|
+
&KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2);
|
344
|
+
}
|
345
|
+
#endif // FAST_FP16_AVAILABLE
|
346
|
+
}
|
347
|
+
|
348
|
+
__syncthreads();
+
+#ifdef FAST_FP16_AVAILABLE
+#pragma unroll
+            for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
+                half2 K_k[kq_stride/warp_size][cpy_ne];
+                half2 Q_k[cpw][cpy_ne];
+#else
+#pragma unroll
+            for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
+                float K_k[kq_stride/warp_size][cpy_ne];
+                float Q_k[cpw][cpy_ne];
+#endif // FAST_FP16_AVAILABLE
+
+#pragma unroll
+                for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
+                    const int i_KQ = i_KQ_0 + threadIdx.x;
+
+#ifdef FAST_FP16_AVAILABLE
+                    ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
+#else
+                    ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
+#endif // FAST_FP16_AVAILABLE
+                }
+#pragma unroll
+                for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
+                    const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
+
+#ifdef FAST_FP16_AVAILABLE
+                    ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
+#else
+                    ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
+#endif // FAST_FP16_AVAILABLE
+                }
+
+#pragma unroll
+                for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
+#pragma unroll
+                    for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
+#pragma unroll
+                        for (int k = 0; k < cpy_ne; ++k) {
+                            ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]);
+                        }
+                    }
+                }
+            }
+
+            if (k_KQ_0 + kq_nbatch < D) {
+                __syncthreads(); // Sync not needed on last iteration.
+            }
+        }
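+        // KQ_acc now holds the complete K·Q dot products for this KV tile.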
+
+        // Apply logit softcap, mask, update KQ_max:
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
+            const int i_KQ = i_KQ_0 + threadIdx.x;
+
+#pragma unroll
+            for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
+                const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
+
+                if (use_logit_softcap) {
+                    KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
+                }
+
+                KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
+
+                KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
+            }
+        }
+
+        __syncthreads();
+
+        // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
+#pragma unroll
+        for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
+#ifdef FAST_FP16_AVAILABLE
+            half  tmp[kq_stride/warp_size][softmax_iter_j];
+#else
+            float tmp[kq_stride/warp_size][softmax_iter_j];
+#endif // FAST_FP16_AVAILABLE
+
+#pragma unroll
+            for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
+                KQ_max_new[j0+j1] = warp_reduce_max<warp_size>(KQ_max_new[j0+j1]);
+                const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]);
+                KQ_max[j0+j1] = KQ_max_new[j0+j1];
+
+                float KQ_sum_add = 0.0f;
+#pragma unroll
+                for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
+                    const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]);
+                    KQ_sum_add += val;
+                    tmp[i0/warp_size][j1] = val;
+                }
+                KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add;
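+                // Online softmax: earlier tiles were accumulated against the old
+                // maximum, so the running sum (and the VKQ accumulators below)
+                // are rescaled by expf(old_max - new_max) before adding new terms.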
+
+#ifdef FAST_FP16_AVAILABLE
+                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                    VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2;
+                }
+#else
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                    VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale;
+                    VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale;
+                }
+#endif // FAST_FP16_AVAILABLE
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
+                const int i = i0 + threadIdx.x;
+
+                ggml_cuda_memcpy_1<sizeof(tmp[0])>(
+                    KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]);
+            }
+        }
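+        // The un-normalized softmax weights are now in the shared KQ buffer;
+        // the normalization by KQ_sum happens once, after the last KV tile.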
+
+        // VKQ = V @ KQ matrix multiplication:
+        constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K.
+        static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter");
+#pragma unroll
+        for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) {
+#pragma unroll
+            for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) {
+                const int k_tile = k1 + threadIdx.y;
+
+#ifdef FAST_FP16_AVAILABLE
+                constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size);
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
+                    ggml_cuda_memcpy_1<cpy_ne_D*4>(
+                        &KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D],
+                        &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]);
+                }
+#else
+                constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
+#pragma unroll
+                for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
+                    half2 tmp_h2[cpy_ne_D/2];
+                    ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
+                        tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]);
+
+                    float2 tmp_f2[cpy_ne_D/2];
+#pragma unroll
+                    for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
+                        tmp_f2[i1] = __half22float2(tmp_h2[i1]);
+                    }
+                    ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
+                        &KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2);
+                }
+#endif // FAST_FP16_AVAILABLE
+            }
+
+            __syncthreads();
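+            // A batch of V columns now occupies KV_tmp, reusing the same shared
+            // memory that held the K sub-tiles above.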
+
+#ifdef FAST_FP16_AVAILABLE
+#pragma unroll
+            for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
+                half2 V_k[(D/2)/warp_size];
+                half2 KQ_k[cpw];
+
+                constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
+                    ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]);
+                }
+#pragma unroll
+                for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
+                    const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
+
+                    half tmp[softmax_iter_j];
+                    ggml_cuda_memcpy_1<softmax_iter_j*sizeof(half)>(
+                        &tmp, KQ[j][k0 + k1]);
+#pragma unroll
+                    for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
+                        KQ_k[j0+j1] = __half2half2(tmp[j1]);
+                    }
+                }
+
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+#pragma unroll
+                    for (int j0 = 0; j0 < cpw; ++j0) {
+                        VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0];
+                    }
+                }
+            }
+#else
+#pragma unroll
+            for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
+                float2 V_k[(D/2)/warp_size];
+                float  KQ_k[cpw];
+
+                constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
+#pragma unroll
+                for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
+                    ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]);
+                }
+#pragma unroll
+                for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
+                    const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
+
+                    ggml_cuda_memcpy_1<softmax_iter_j*sizeof(float)>(
+                        &KQ_k[j0], KQ[j][k0 + k1]);
+                }
+
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+#pragma unroll
+                    for (int j0 = 0; j0 < cpw; ++j0) {
+                        VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0];
+                        VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0];
+                    }
+                }
+            }
+#endif // FAST_FP16_AVAILABLE
+
+            __syncthreads();
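+            // Sync before the next iteration overwrites KV_tmp with new data.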
+        }
+    }
+
+
+    // Attention sink: adjust running max and sum once per head
+    if (sinksf && blockIdx.y == 0) {
+        const float sink = sinksf[head];
+
+#pragma unroll
+        for (int j0 = 0; j0 < cpw; ++j0) {
+            float KQ_max_new_j = fmaxf(KQ_max[j0], sink);
+            KQ_max_new_j = warp_reduce_max<warp_size>(KQ_max_new_j);
+
+            const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j);
+            KQ_max[j0] = KQ_max_new_j;
+
+            const float val = expf(sink - KQ_max[j0]);
+            KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale;
+            if (threadIdx.x == 0) {
+                KQ_sum[j0] += val;
+            }
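+            // Only lane 0 adds the sink term, so the warp_reduce_sum below
+            // counts it exactly once per row.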
+
+#ifdef FAST_FP16_AVAILABLE
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                VKQ[j0][i0/warp_size] *= KQ_max_scale_h2;
+            }
+#else
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                VKQ[j0][i0/warp_size].x *= KQ_max_scale;
+                VKQ[j0][i0/warp_size].y *= KQ_max_scale;
+            }
+#endif // FAST_FP16_AVAILABLE
+        }
+    }
+
+#pragma unroll
+    for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
+        KQ_sum[j_VKQ_0] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ_0]);
+    }
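+    // With a single KV split (gridDim.y == 1) this block saw the whole KV range
+    // and can normalize by 1/KQ_sum directly; otherwise the raw accumulators are
+    // written out and merged by a separate combine pass using dst_meta.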
+    if (gridDim.y == 1) {
+#pragma unroll
+        for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
+#ifdef FAST_FP16_AVAILABLE
+            const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]);
+#pragma unroll
+            for (int i = 0; i < (D/2)/warp_size; ++i) {
+                VKQ[j_VKQ_0][i] *= KQ_sum_j_inv;
+            }
+#else
+            const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0];
+#pragma unroll
+            for (int i = 0; i < (D/2)/warp_size; ++i) {
+                VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv;
+                VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv;
+            }
+#endif // FAST_FP16_AVAILABLE
+        }
+    }
+
+    // Write back results:
+#pragma unroll
+    for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
+        const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw;
+
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+
+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
+#ifdef FAST_FP16_AVAILABLE
+        constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
+#pragma unroll
+        for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
+            float2 tmp[cpy_ne_D];
+#pragma unroll
+            for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+                tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]);
+            }
+            ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
+        }
+#else
+        constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
+            ggml_cuda_memcpy_1<cpy_ne_D*4>(
+                &dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]);
+        }
+#endif // FAST_FP16_AVAILABLE
+
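+        // When the KV cache is split across blocks, store this block's running
+        // (max, sum) so the combine kernel can merge the partial outputs.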
+        if (gridDim.y != 1 && threadIdx.x == 0) {
+            dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]);
+        }
+    }
+#else
+    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+                     max_bias, m0, m1, n_head_log2, logit_softcap,
+                     ne00, ne01, ne02, ne03,
+                     nb01, nb02, nb03,
+                     ne10, ne11, ne12, ne13,
+                     nb11, nb12, nb13,
+                     nb21, nb22, nb23,
+                     ne31, ne32, ne33,
+                     nb31, nb32, nb33);
+    NO_DEVICE_CODE;
+#endif // FLASH_ATTN_AVAILABLE
+}
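+
+// The kernel above maintains, per query column, a running maximum m, a running
+// sum s, and un-normalized output accumulators o. For each new tile of scores
+// x[i] it computes:
+//     m' = max(m, max_i x[i])
+//     s' = s*exp(m - m') + sum_i exp(x[i] - m')
+//     o' = o*exp(m - m') + sum_i exp(x[i] - m') * V[i]
+// so that o'/s' equals softmax(x) @ V once every KV tile has been processed.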
+
+template <int D, bool use_logit_softcap>
+static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+
+    const int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
+    const int warp_size = 32;
+
+    constexpr size_t nbytes_shared = 0;
+
+#ifdef GGML_USE_HIP
+    if constexpr (D <= 128) {
+        if (Q->ne[1] > 32) {
+            constexpr int cols_per_block = 64;
+            const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
+            fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
+            const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
+            launch_fattn<D, cols_per_block, 1>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
+            return;
+        }
+    }
+#endif // GGML_USE_HIP
+
+    if (Q->ne[1] > 16) {
+        constexpr int cols_per_block = 32;
+        const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
+        fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
+        const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
+        launch_fattn<D, cols_per_block, 1>
+            (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
+        return;
+    }
+
+    constexpr int cols_per_block = 16;
+    const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
+    fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
+    const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
+    launch_fattn<D, cols_per_block, 1>
+        (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
+}
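+
+// The switch above picks the widest column tile the batch can fill: 64 query
+// columns on HIP for head sizes <= 128, then 32, with 16 as the fallback for
+// small batches.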
+
+template <bool use_logit_softcap>
+static void launch_fattn_tile_switch_head_size(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case 64: {
+            launch_fattn_tile_switch_ncols< 64, use_logit_softcap>(ctx, dst);
+        } break;
+        case 128: {
+            launch_fattn_tile_switch_ncols<128, use_logit_softcap>(ctx, dst);
+        } break;
+        case 256: {
+            launch_fattn_tile_switch_ncols<256, use_logit_softcap>(ctx, dst);
+        } break;
+        default: {
+            GGML_ABORT("Unsupported head size");
+        } break;
+    }
+}
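+
+// Only head sizes 64, 128, and 256 are instantiated for the tile kernel;
+// any other head size aborts here.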
+
+void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
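+    // use_logit_softcap is a compile-time template parameter, so the tanhf
+    // soft-capping path is compiled out entirely when it is disabled.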
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
+    }
+}