whispercpp 1.3.3 → 1.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/ruby_whisper_params.c +55 -25
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/build-xcframework.sh +24 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/addon.cpp +19 -19
- data/ext/sources/examples/addon.node/index.js +7 -5
- data/ext/sources/examples/bench/bench.cpp +26 -16
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +4 -2
- data/ext/sources/examples/command/command.cpp +26 -24
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/lsp.cpp +19 -17
- data/ext/sources/examples/server/server.cpp +24 -13
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +4 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
- data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
- data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
- data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
- data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
- data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
- data/ext/sources/examples/talk-llama/llama-context.h +44 -29
- data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
- data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
- data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
- data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
- data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
- data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
- data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
- data/ext/sources/examples/talk-llama/llama-model.h +60 -9
- data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
- data/ext/sources/examples/talk-llama/llama.cpp +65 -10
- data/ext/sources/examples/talk-llama/llama.h +95 -177
- data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
- data/ext/sources/ggml/CMakeLists.txt +59 -31
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-backend.h +17 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -1
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml.h +221 -16
- data/ext/sources/ggml/src/CMakeLists.txt +17 -2
- data/ext/sources/ggml/src/ggml-alloc.c +265 -141
- data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
- data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
- data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
- data/ext/sources/ggml/src/ggml-common.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
- data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
- data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
- data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
- data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
- data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
- data/ext/sources/ggml/src/ggml-impl.h +119 -9
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +111 -16
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml.c +478 -98
- data/ext/sources/ggml/src/gguf.cpp +8 -1
- data/ext/sources/src/whisper.cpp +23 -46
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +3 -3
- data/ext/sources/tests/test-vad.cpp +2 -2
- data/lib/whisper/model/uri.rb +1 -1
- data/sig/whisper.rbs +7 -0
- data/test/test_params.rb +8 -0
- data/test/test_whisper.rb +1 -1
- data/whispercpp.gemspec +1 -1
- metadata +164 -157
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
@@ -392,7 +392,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
|
392
392
|
}
|
393
393
|
}
|
394
394
|
|
395
|
-
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles,
|
395
|
+
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles,
|
396
|
+
bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
|
396
397
|
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
397
398
|
const float2 * const __restrict__ Q_f2,
|
398
399
|
const half2 * const __restrict__ K_h2,
|
@@ -408,7 +409,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
408
409
|
const int stride_K,
|
409
410
|
const int stride_V,
|
410
411
|
const int stride_mask,
|
411
|
-
const int jt,
|
412
412
|
half2 * const __restrict__ tile_Q,
|
413
413
|
half2 * const __restrict__ tile_K,
|
414
414
|
half2 * const __restrict__ tile_V,
|
@@ -418,7 +418,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
418
418
|
float * const __restrict__ KQ_max,
|
419
419
|
float * const __restrict__ KQ_rowsum,
|
420
420
|
const int kb0) {
|
421
|
-
#ifdef
|
421
|
+
#ifdef TURING_MMA_AVAILABLE
|
422
422
|
typedef fattn_mma_f16_config<DKQ, DV> c;
|
423
423
|
|
424
424
|
#ifdef CP_ASYNC_AVAILABLE
|
@@ -455,7 +455,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
455
455
|
cp_async_wait_all();
|
456
456
|
__syncthreads();
|
457
457
|
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
458
|
-
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
|
458
|
+
(V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
|
459
459
|
} else {
|
460
460
|
constexpr bool use_cp_async = nstages == 1;
|
461
461
|
if (ncols2 > 1 || mask_h2) {
|
@@ -471,7 +471,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
471
471
|
if (nstages <= 1) {
|
472
472
|
constexpr bool use_cp_async = nstages == 1;
|
473
473
|
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
474
|
-
(K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
|
474
|
+
(K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
|
475
475
|
if (use_cp_async) {
|
476
476
|
cp_async_wait_all();
|
477
477
|
}
|
@@ -715,7 +715,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
715
715
|
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
|
716
716
|
}
|
717
717
|
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
718
|
-
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
|
718
|
+
(K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
|
719
719
|
}
|
720
720
|
}
|
721
721
|
|
@@ -732,7 +732,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
732
732
|
if (nstages <= 1 && i0_start < reusable_cutoff) {
|
733
733
|
constexpr bool use_cp_async = nstages == 1;
|
734
734
|
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
735
|
-
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
|
735
|
+
(V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
|
736
736
|
if (use_cp_async) {
|
737
737
|
cp_async_wait_all();
|
738
738
|
}
|
@@ -767,17 +767,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
767
767
|
}
|
768
768
|
}
|
769
769
|
#else
|
770
|
-
|
771
|
-
|
772
|
-
|
773
|
-
|
774
|
-
|
775
|
-
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
|
776
|
-
GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
|
777
|
-
GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
|
778
|
-
GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
|
770
|
+
GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup,
|
771
|
+
scale, slope, logit_softcap, ne01, ne02,
|
772
|
+
stride_K, stride_V, stride_mask,
|
773
|
+
tile_Q, tile_K, tile_V, tile_mask,
|
774
|
+
Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
779
775
|
NO_DEVICE_CODE;
|
780
|
-
#endif //
|
776
|
+
#endif // TURING_MMA_AVAILABLE
|
781
777
|
}
|
782
778
|
|
783
779
|
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
|
@@ -786,6 +782,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
786
782
|
const half2 * const __restrict__ K_h2,
|
787
783
|
const half2 * const __restrict__ V_h2,
|
788
784
|
const half2 * const __restrict__ mask_h2,
|
785
|
+
const float * const __restrict__ sinks_f,
|
789
786
|
float2 * const __restrict__ dstk,
|
790
787
|
float2 * const __restrict__ dstk_fixup,
|
791
788
|
const float scale,
|
@@ -801,7 +798,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
801
798
|
const int jt,
|
802
799
|
const int kb0_start,
|
803
800
|
const int kb0_stop) {
|
804
|
-
#ifdef
|
801
|
+
#ifdef TURING_MMA_AVAILABLE
|
805
802
|
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
806
803
|
|
807
804
|
typedef fattn_mma_f16_config<DKQ, DV> c;
|
@@ -920,21 +917,22 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
920
917
|
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
|
921
918
|
}
|
922
919
|
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
923
|
-
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
|
920
|
+
(K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
|
924
921
|
}
|
925
922
|
|
926
923
|
// Iterate over ne11 == previous tokens:
|
927
|
-
|
924
|
+
int kb0 = kb0_start;
|
925
|
+
for (; kb0 < kb0_stop-1; ++kb0) {
|
928
926
|
constexpr bool last_iter = false;
|
929
927
|
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
930
928
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
931
|
-
ne01, ne02, stride_K, stride_V, stride_mask,
|
929
|
+
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
932
930
|
}
|
933
931
|
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
934
932
|
constexpr bool last_iter = true;
|
935
933
|
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
936
934
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
937
|
-
ne01, ne02, stride_K, stride_V, stride_mask,
|
935
|
+
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
938
936
|
}
|
939
937
|
|
940
938
|
// With multi-stage loading there is no __syncthreads at the end of the iter,
|
@@ -957,6 +955,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
957
955
|
}
|
958
956
|
}
|
959
957
|
|
958
|
+
// If attention sinks are used, potentially re-scale if KQ_max is small.
|
959
|
+
// Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum
|
960
|
+
// so it's being done unconditionally for every thread.
|
961
|
+
if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
|
962
|
+
float KQ_max_scale[cols_per_thread];
|
963
|
+
#pragma unroll
|
964
|
+
for (int col = 0; col < cols_per_thread; ++col) {
|
965
|
+
static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
|
966
|
+
const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
|
967
|
+
const float sink = sinks_f[jc % ncols2];
|
968
|
+
|
969
|
+
const float KQ_max_new = fmaxf(KQ_max[col], sink);
|
970
|
+
const float KQ_max_diff = KQ_max[col] - KQ_max_new;
|
971
|
+
KQ_max_scale[col] = expf(KQ_max_diff);
|
972
|
+
KQ_max[col] = KQ_max_new;
|
973
|
+
|
974
|
+
*((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
|
975
|
+
|
976
|
+
const float KQ_max_add = expf(sink - KQ_max_new);
|
977
|
+
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
|
978
|
+
}
|
979
|
+
|
980
|
+
if (ntiles == 1) {
|
981
|
+
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
|
982
|
+
#pragma unroll
|
983
|
+
for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
|
984
|
+
#pragma unroll
|
985
|
+
for (int l = 0; l < tile_C_VKQ::ne; ++l) {
|
986
|
+
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
987
|
+
}
|
988
|
+
}
|
989
|
+
} else {
|
990
|
+
#pragma unroll
|
991
|
+
for (int col = 0; col < cols_per_thread; ++col) {
|
992
|
+
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
|
993
|
+
#pragma unroll
|
994
|
+
for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
|
995
|
+
#pragma unroll
|
996
|
+
for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
|
997
|
+
VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
|
998
|
+
}
|
999
|
+
}
|
1000
|
+
}
|
1001
|
+
}
|
1002
|
+
}
|
1003
|
+
|
960
1004
|
// Combine VKQ accumulator values if np > 1.
|
961
1005
|
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
962
1006
|
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
|
@@ -1189,14 +1233,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
1189
1233
|
}
|
1190
1234
|
}
|
1191
1235
|
#else
|
1192
|
-
|
1193
|
-
|
1194
|
-
|
1195
|
-
|
1196
|
-
GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask);
|
1197
|
-
GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
|
1236
|
+
GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dstk_fixup,
|
1237
|
+
scale, slope, logit_softcap, ne01, ne02,
|
1238
|
+
stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
|
1239
|
+
jt, kb0_start, kb0_stop);
|
1198
1240
|
NO_DEVICE_CODE;
|
1199
|
-
#endif //
|
1241
|
+
#endif // TURING_MMA_AVAILABLE
|
1200
1242
|
}
|
1201
1243
|
|
1202
1244
|
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
|
@@ -1206,6 +1248,8 @@ static __global__ void flash_attn_ext_f16(
|
|
1206
1248
|
const char * __restrict__ K,
|
1207
1249
|
const char * __restrict__ V,
|
1208
1250
|
const char * __restrict__ mask,
|
1251
|
+
const char * __restrict__ sinks,
|
1252
|
+
const int * __restrict__ KV_max,
|
1209
1253
|
float * __restrict__ dst,
|
1210
1254
|
float2 * __restrict__ dst_meta,
|
1211
1255
|
const float scale,
|
@@ -1214,30 +1258,14 @@ static __global__ void flash_attn_ext_f16(
|
|
1214
1258
|
const float m1,
|
1215
1259
|
const uint32_t n_head_log2,
|
1216
1260
|
const float logit_softcap,
|
1217
|
-
const
|
1218
|
-
|
1219
|
-
const
|
1220
|
-
|
1221
|
-
|
1222
|
-
|
1223
|
-
|
1224
|
-
|
1225
|
-
const int ne31,
|
1226
|
-
const int nb31,
|
1227
|
-
const int nb01,
|
1228
|
-
const int nb02,
|
1229
|
-
const int nb03,
|
1230
|
-
const int nb11,
|
1231
|
-
const int nb12,
|
1232
|
-
const int nb13,
|
1233
|
-
const int nb21,
|
1234
|
-
const int nb22,
|
1235
|
-
const int nb23,
|
1236
|
-
const int ne0,
|
1237
|
-
const int ne1,
|
1238
|
-
const int ne2,
|
1239
|
-
const int ne3) {
|
1240
|
-
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
1261
|
+
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
1262
|
+
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
1263
|
+
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
1264
|
+
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
1265
|
+
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
1266
|
+
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
1267
|
+
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
1268
|
+
#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
|
1241
1269
|
|
1242
1270
|
// Skip unused kernel variants for faster compilation:
|
1243
1271
|
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
|
@@ -1272,8 +1300,8 @@ static __global__ void flash_attn_ext_f16(
|
|
1272
1300
|
constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
|
1273
1301
|
|
1274
1302
|
// kbc == k block continuous, current index in continuous ijk space.
|
1275
|
-
int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
1276
|
-
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
1303
|
+
int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
1304
|
+
const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
1277
1305
|
|
1278
1306
|
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
1279
1307
|
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
@@ -1282,32 +1310,42 @@ static __global__ void flash_attn_ext_f16(
|
|
1282
1310
|
// kb0 == k start index when in the output tile.
|
1283
1311
|
int kb0_start = kbc % iter_k;
|
1284
1312
|
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
|
1313
|
+
|
1285
1314
|
while (kbc < kbc_stop && kb0_stop == iter_k) {
|
1286
|
-
const int
|
1287
|
-
const int
|
1315
|
+
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
|
1316
|
+
const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
|
1317
|
+
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
|
1318
|
+
|
1319
|
+
const int head0 = zt * ncols2;
|
1288
1320
|
|
1289
|
-
const float2 * Q_f2 = (const float2 *) (Q + nb02*
|
1290
|
-
const half2 * K_h2 = (const half2 *) (K + nb12*(
|
1291
|
-
const half2 * mask_h2 = ncols2
|
1292
|
-
|
1321
|
+
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
|
1322
|
+
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
|
1323
|
+
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
|
1324
|
+
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
|
1325
|
+
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
|
1293
1326
|
|
1294
|
-
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(
|
1327
|
+
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
|
1328
|
+
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
|
1295
1329
|
|
1296
|
-
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias,
|
1330
|
+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
|
1297
1331
|
|
1298
1332
|
const int kb0_start_kernel = kb0_start * kb_niter;
|
1299
|
-
|
1333
|
+
int kb0_stop_kernel = kb0_stop * kb_niter;
|
1334
|
+
|
1335
|
+
if (KV_max) {
|
1336
|
+
kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
|
1337
|
+
}
|
1300
1338
|
|
1301
1339
|
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
1302
1340
|
if (kb0_start == 0) {
|
1303
1341
|
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
1304
1342
|
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
1305
|
-
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
1343
|
+
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
1306
1344
|
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
1307
1345
|
} else {
|
1308
1346
|
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
1309
1347
|
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
1310
|
-
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
1348
|
+
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
1311
1349
|
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
1312
1350
|
}
|
1313
1351
|
|
@@ -1322,39 +1360,47 @@ static __global__ void flash_attn_ext_f16(
|
|
1322
1360
|
return;
|
1323
1361
|
}
|
1324
1362
|
|
1325
|
-
const int
|
1326
|
-
const int
|
1363
|
+
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
|
1364
|
+
const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
|
1365
|
+
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
|
1366
|
+
|
1367
|
+
const int head0 = zt * ncols2;
|
1327
1368
|
|
1328
|
-
const float2 * Q_f2 = (const float2 *) (Q + nb02*
|
1329
|
-
const half2 * K_h2 = (const half2 *) (K + nb12*(
|
1330
|
-
const half2 * mask_h2 = ncols2
|
1331
|
-
|
1369
|
+
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
|
1370
|
+
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
|
1371
|
+
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
|
1372
|
+
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
|
1373
|
+
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
|
1332
1374
|
|
1333
|
-
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(
|
1375
|
+
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
|
1376
|
+
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
|
1334
1377
|
|
1335
|
-
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias,
|
1378
|
+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
|
1336
1379
|
|
1337
1380
|
const int kb0_start_kernel = kb0_start * kb_niter;
|
1338
|
-
|
1381
|
+
int kb0_stop_kernel = kb0_stop * kb_niter;
|
1382
|
+
|
1383
|
+
if (KV_max) {
|
1384
|
+
kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
|
1385
|
+
}
|
1339
1386
|
|
1340
1387
|
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
1341
1388
|
constexpr bool needs_fixup = false;
|
1342
1389
|
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
1343
|
-
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
1390
|
+
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
1344
1391
|
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
1345
1392
|
#else
|
1346
|
-
|
1347
|
-
|
1348
|
-
|
1349
|
-
|
1350
|
-
|
1351
|
-
|
1352
|
-
|
1353
|
-
|
1354
|
-
|
1355
|
-
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
1393
|
+
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
1394
|
+
max_bias, m0, m1, n_head_log2, logit_softcap,
|
1395
|
+
ne00, ne01, ne02, ne03,
|
1396
|
+
nb01, nb02, nb03,
|
1397
|
+
ne10, ne11, ne12, ne13,
|
1398
|
+
nb11, nb12, nb13,
|
1399
|
+
nb21, nb22, nb23,
|
1400
|
+
ne31, ne32, ne33,
|
1401
|
+
nb31, nb32, nb33);
|
1356
1402
|
NO_DEVICE_CODE;
|
1357
|
-
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(
|
1403
|
+
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
|
1358
1404
|
}
|
1359
1405
|
|
1360
1406
|
template <int DKQ, int DV, int ncols1, int ncols2>
|
@@ -1404,24 +1450,24 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
1404
1450
|
constexpr bool use_logit_softcap = false;
|
1405
1451
|
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
|
1406
1452
|
|
1407
|
-
#if !
|
1453
|
+
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
1408
1454
|
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
1409
1455
|
if (!shared_memory_limit_raised[id]) {
|
1410
1456
|
CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
|
1411
1457
|
shared_memory_limit_raised[id] = true;
|
1412
1458
|
}
|
1413
|
-
#endif // !
|
1459
|
+
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
1414
1460
|
} else {
|
1415
1461
|
constexpr bool use_logit_softcap = true;
|
1416
1462
|
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
|
1417
1463
|
|
1418
|
-
#if !
|
1464
|
+
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
1419
1465
|
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
1420
1466
|
if (!shared_memory_limit_raised[id]) {
|
1421
1467
|
CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
|
1422
1468
|
shared_memory_limit_raised[id] = true;
|
1423
1469
|
}
|
1424
|
-
#endif // !
|
1470
|
+
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
1425
1471
|
}
|
1426
1472
|
|
1427
1473
|
launch_fattn<DV, ncols1, ncols2>
|