whispercpp 1.3.1 → 1.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +4 -3
- data/README.md +92 -31
- data/Rakefile +26 -7
- data/ext/.gitignore +5 -7
- data/ext/dependencies.rb +61 -0
- data/ext/extconf.rb +21 -198
- data/ext/options.rb +221 -0
- data/ext/ruby_whisper.c +159 -0
- data/ext/ruby_whisper.h +17 -2
- data/ext/ruby_whisper_context.c +641 -0
- data/ext/ruby_whisper_error.c +52 -0
- data/ext/ruby_whisper_model.c +232 -0
- data/ext/ruby_whisper_params.c +1301 -0
- data/ext/ruby_whisper_segment.c +143 -0
- data/ext/ruby_whisper_transcribe.cpp +87 -0
- data/ext/ruby_whisper_vad_params.c +288 -0
- data/ext/sources/.dockerignore +3 -0
- data/ext/sources/.github/workflows/bindings-ruby.yml +21 -0
- data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
- data/ext/sources/CMakeLists.txt +251 -0
- data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
- data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
- data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
- data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
- data/ext/sources/bindings/javascript/package.json +26 -0
- data/ext/sources/bindings/javascript/whisper.js +19 -0
- data/ext/sources/build-xcframework.sh +547 -0
- data/ext/sources/ci/run.sh +336 -0
- data/ext/sources/close-issue.yml +28 -0
- data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
- data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
- data/ext/sources/cmake/build-info.cmake +60 -0
- data/ext/sources/cmake/git-vars.cmake +22 -0
- data/ext/sources/cmake/whisper-config.cmake.in +65 -0
- data/ext/sources/cmake/whisper.pc.in +10 -0
- data/ext/sources/examples/CMakeLists.txt +124 -0
- data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +37 -0
- data/ext/sources/examples/addon.node/addon.cpp +438 -0
- data/ext/sources/examples/addon.node/index.js +54 -0
- data/ext/sources/examples/addon.node/package.json +16 -0
- data/ext/sources/examples/bench/CMakeLists.txt +8 -0
- data/ext/sources/examples/bench/bench.cpp +175 -0
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
- data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
- data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
- data/ext/sources/examples/cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/cli/cli.cpp +1294 -0
- data/ext/sources/examples/coi-serviceworker.js +146 -0
- data/ext/sources/examples/command/CMakeLists.txt +10 -0
- data/ext/sources/examples/command/command.cpp +776 -0
- data/ext/sources/examples/command/commands.txt +9 -0
- data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
- data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
- data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
- data/ext/sources/examples/common-ggml.cpp +238 -0
- data/ext/sources/examples/common-ggml.h +18 -0
- data/ext/sources/examples/common-sdl.cpp +227 -0
- data/ext/sources/examples/common-sdl.h +49 -0
- data/ext/sources/examples/common-whisper.cpp +168 -0
- data/ext/sources/examples/common-whisper.h +24 -0
- data/ext/sources/examples/common.cpp +675 -0
- data/ext/sources/examples/common.h +322 -0
- data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
- data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
- data/ext/sources/examples/generate-karaoke.sh +57 -0
- data/ext/sources/examples/grammar-parser.cpp +423 -0
- data/ext/sources/examples/grammar-parser.h +29 -0
- data/ext/sources/examples/helpers.js +191 -0
- data/ext/sources/examples/json.hpp +24596 -0
- data/ext/sources/examples/livestream.sh +112 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
- data/ext/sources/examples/lsp/lsp.cpp +467 -0
- data/ext/sources/examples/lsp/whisper.vim +362 -0
- data/ext/sources/examples/miniaudio.h +93468 -0
- data/ext/sources/examples/python/test_whisper_processor.py +7 -0
- data/ext/sources/examples/python/whisper_processor.py +54 -0
- data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
- data/ext/sources/examples/quantize/quantize.cpp +223 -0
- data/ext/sources/examples/server/CMakeLists.txt +12 -0
- data/ext/sources/examples/server/bench.js +29 -0
- data/ext/sources/examples/server/httplib.h +10497 -0
- data/ext/sources/examples/server/server.cpp +1091 -0
- data/ext/sources/examples/server.py +115 -0
- data/ext/sources/examples/stb_vorbis.c +5584 -0
- data/ext/sources/examples/stream/CMakeLists.txt +10 -0
- data/ext/sources/examples/stream/stream.cpp +429 -0
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
- data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
- data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
- data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
- data/ext/sources/examples/sycl/build.sh +22 -0
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
- data/ext/sources/examples/sycl/run-whisper.sh +17 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +40 -0
- data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
- data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +1746 -0
- data/ext/sources/examples/talk-llama/llama-arch.h +437 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +374 -0
- data/ext/sources/examples/talk-llama/llama-batch.h +89 -0
- data/ext/sources/examples/talk-llama/llama-chat.cpp +663 -0
- data/ext/sources/examples/talk-llama/llama-chat.h +58 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +2676 -0
- data/ext/sources/examples/talk-llama/llama-context.h +276 -0
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
- data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
- data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
- data/ext/sources/examples/talk-llama/llama-graph.cpp +1618 -0
- data/ext/sources/examples/talk-llama/llama-graph.h +640 -0
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +95 -0
- data/ext/sources/examples/talk-llama/llama-hparams.h +190 -0
- data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
- data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
- data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
- data/ext/sources/examples/talk-llama/llama-io.h +35 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2739 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +502 -0
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +379 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +32 -0
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
- data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1138 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +281 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +13814 -0
- data/ext/sources/examples/talk-llama/llama-model.h +425 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +966 -0
- data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
- data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +3340 -0
- data/ext/sources/examples/talk-llama/llama-vocab.h +131 -0
- data/ext/sources/examples/talk-llama/llama.cpp +354 -0
- data/ext/sources/examples/talk-llama/llama.h +1377 -0
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
- data/ext/sources/examples/talk-llama/speak +40 -0
- data/ext/sources/examples/talk-llama/speak.bat +1 -0
- data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
- data/ext/sources/examples/talk-llama/talk-llama.cpp +808 -0
- data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
- data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +849 -0
- data/ext/sources/examples/talk-llama/unicode.h +66 -0
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +143 -0
- data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
- data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +249 -0
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
- data/ext/sources/ggml/CMakeLists.txt +390 -0
- data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
- data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
- data/ext/sources/ggml/cmake/common.cmake +26 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
- data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
- data/ext/{ggml → sources/ggml}/include/ggml-backend.h +9 -7
- data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
- data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +9 -1
- data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
- data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
- data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
- data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
- data/ext/{ggml → sources/ggml}/include/ggml.h +182 -265
- data/ext/sources/ggml/include/gguf.h +202 -0
- data/ext/sources/ggml/src/CMakeLists.txt +346 -0
- data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
- data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +87 -53
- data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +26 -14
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +74 -0
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
- data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
- data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +135 -1
- data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +564 -146
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
- data/ext/{ggml → sources/ggml}/src/ggml-common.h +12 -8
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +504 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +2 -1
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +72 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/cpu-feats-x86.cpp +5 -1
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +6431 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +163 -41
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.c +4029 -1117
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3510 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +67 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3544 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +8903 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.h +110 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +252 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +818 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
- data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
- data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
- data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
- data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +828 -0
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
- data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +730 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +26 -0
- data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
- data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
- data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
- data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1471 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +634 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3505 -0
- data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
- data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +336 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
- data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +153 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
- data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +289 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +59 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
- data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
- data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +131 -0
- data/ext/{ggml → sources/ggml}/src/ggml-impl.h +64 -19
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +120 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +622 -0
- data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2178 -1064
- data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +1575 -1218
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
- data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +96 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +5124 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
- data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
- data/ext/{ggml → sources/ggml}/src/ggml-quants.c +114 -120
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
- data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +480 -73
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +345 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +589 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +32 -33
- data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +4 -2
- data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/convert.cpp +104 -28
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +700 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +791 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +156 -17
- data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2957 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1511 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +75 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +99 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1004 -1240
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
- data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +0 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmvq.cpp +261 -166
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/norm.cpp +204 -81
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
- data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +83 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +361 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +35 -25
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1215 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +293 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +196 -0
- data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
- data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3130 -1087
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +52 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
- data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
- data/ext/{ggml → sources/ggml}/src/ggml.c +676 -1820
- data/ext/sources/ggml/src/gguf.cpp +1330 -0
- data/ext/{include → sources/include}/whisper.h +68 -2
- data/ext/sources/src/CMakeLists.txt +143 -0
- data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
- data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +35 -10
- data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
- data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +28 -3
- data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
- data/ext/sources/src/whisper-arch.h +197 -0
- data/ext/{src → sources/src}/whisper.cpp +1905 -374
- data/ext/sources/tests/CMakeLists.txt +105 -0
- data/ext/sources/tests/earnings21/eval.mk +58 -0
- data/ext/sources/tests/earnings21/eval.py +68 -0
- data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
- data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
- data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
- data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
- data/ext/sources/tests/earnings21/requirements.txt +6 -0
- data/ext/sources/tests/en-0-ref.txt +1 -0
- data/ext/sources/tests/en-1-ref.txt +1 -0
- data/ext/sources/tests/en-2-ref.txt +1 -0
- data/ext/sources/tests/es-0-ref.txt +1 -0
- data/ext/sources/tests/librispeech/eval.mk +39 -0
- data/ext/sources/tests/librispeech/eval.py +47 -0
- data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
- data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
- data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
- data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
- data/ext/sources/tests/librispeech/requirements.txt +6 -0
- data/ext/sources/tests/run-tests.sh +130 -0
- data/ext/sources/tests/test-c.c +3 -0
- data/ext/sources/tests/test-vad-full.cpp +54 -0
- data/ext/sources/tests/test-vad.cpp +83 -0
- data/ext/sources/tests/test-whisper.js +58 -0
- data/extsources.rb +33 -5
- data/lib/whisper/model/uri.rb +149 -128
- data/sig/whisper.rbs +480 -0
- data/tests/helper.rb +28 -0
- data/tests/test_callback.rb +45 -3
- data/tests/test_error.rb +2 -2
- data/tests/test_model.rb +38 -0
- data/tests/test_package.rb +18 -3
- data/tests/test_params.rb +145 -8
- data/tests/test_segment.rb +10 -19
- data/tests/test_vad.rb +19 -0
- data/tests/test_vad_params.rb +103 -0
- data/tests/test_whisper.rb +37 -37
- data/whispercpp.gemspec +5 -4
- metadata +766 -111
- data/ext/cpu.mk +0 -9
- data/ext/examples/dr_wav.h +0 -8815
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
- data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
- data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
- data/ext/metal-embed.mk +0 -17
- data/ext/metal.mk +0 -6
- data/ext/ruby_whisper.cpp +0 -1909
- data/ext/scripts/get-flags.mk +0 -38
- data/lib/whisper.rb +0 -2
- /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-aarch64.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
- /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
- /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
- /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
@@ -0,0 +1,1471 @@
+#include "common.cuh"
+#include "cp-async.cuh"
+#include "mma.cuh"
+#include "fattn-common.cuh"
+
+using namespace ggml_cuda_mma;
+
+typedef tile<16, 8, half2> tile_A;
+typedef tile< 8, 8, half2> tile_B;
+typedef tile<16, 8, half2> tile_B_16;
+typedef tile<16, 8, float> tile_C_KQ;
+typedef tile<16, 16, float> tile_C_KQ_16;
+typedef tile<16, 4, half2> tile_C_VKQ;
+typedef tile<16, 8, half2> tile_C_VKQ_16;
+
+// Config options for specific head sizes.
+// Should not affect results, only speed/register pressure/shared memory use.
+//
+// nbatch_fa:      number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
+// nwarps_max:     maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory).
+// Q_in_reg:       whether the Q values should be kept permanently in registers.
+// nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading.
+// nbatch_K2:      number of K half2 values in direction of DKQ to load in parallel.
+// nbatch_V2:      number of V half2 values in direction of DV to load in parallel.
+// nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel.
+
+template <int DKQ, int DV>
+struct fattn_mma_f16_config;
+
+template <>
+struct fattn_mma_f16_config< 64, 64> {
+    static constexpr int  nbatch_fa      = 64;
+    static constexpr int  nwarps_max     = 4;
+    static constexpr bool Q_in_reg       = true;
+    static constexpr int  nstages_target = 2;
+
+    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+        return 32;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+        return 32;
+    }
+
+    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+        return 32;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+        return 32;
+    }
+
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 32;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 32;
+    }
+};
+
+template <>
+struct fattn_mma_f16_config< 80, 80> {
+    static constexpr int  nbatch_fa      = 64;
+    static constexpr int  nwarps_max     = 4;
+    static constexpr bool Q_in_reg       = true;
+    static constexpr int  nstages_target = 2;
+
+    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+        return 40;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+        return 40;
+    }
+
+    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+        return 40;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+        return 40;
+    }
+
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 40;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 40;
+    }
+};
+
+template <>
+struct fattn_mma_f16_config< 96, 96> {
+    static constexpr int  nbatch_fa      = 64;
+    static constexpr int  nwarps_max     = 4;
+    static constexpr bool Q_in_reg       = true;
+    static constexpr int  nstages_target = 2;
+
+    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+        return 48;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+        return 48;
+    }
+
+    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+        return 48;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+        return 48;
+    }
+
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 48;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 48;
+    }
+};
+
+template <>
+struct fattn_mma_f16_config<112, 112> {
+    static constexpr int  nbatch_fa      = 64;
+    static constexpr int  nwarps_max     = 4;
+    static constexpr bool Q_in_reg       = true;
+    static constexpr int  nstages_target = 2;
+
+    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+        return 56;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+        return 56;
+    }
+
+    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+        return 56;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+        return 56;
+    }
+
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 56;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 56;
+    }
+};
+
+template <>
+struct fattn_mma_f16_config<128, 128> {
+    static constexpr int  nbatch_fa      = 64;
+    static constexpr int  nwarps_max     = 4;
+    static constexpr bool Q_in_reg       = true;
+    static constexpr int  nstages_target = 2;
+
+    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+        return 64;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+        return 64;
+    }
+
+    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+        return 64;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+        return 64;
+    }
+
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 64;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 64;
+    }
+};
+
+template <>
+struct fattn_mma_f16_config<256, 256> {
+    static constexpr int  nbatch_fa      = 32;
+    static constexpr int  nwarps_max     = 4;
+    static constexpr bool Q_in_reg       = true;
+    static constexpr int  nstages_target = 2;
+
+    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+        return 128;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+        return 128;
+    }
+
+    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+        return 128;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+        return 128;
+    }
+
+    static int get_nbatch_combine_host(const int cc, const int ncols) {
+        if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
+            return ncols <= 16 ? 128 : 64;
+        }
+        return 64;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int ncols) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+        return ncols <= 16 ? 128 : 64;
+#else
+        GGML_UNUSED(ncols);
+        return 128;
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+    }
+};
+
+template <>
+struct fattn_mma_f16_config<576, 512> {
+    static constexpr int  nbatch_fa      = 32;
+    static constexpr int  nwarps_max     = 8;
+    static constexpr bool Q_in_reg       = false;
+    static constexpr int  nstages_target = 1;
+
+    static int get_nbatch_K2_host(const int cc, const int ncols) {
+        if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
+            return ncols <= 16 ? 96 : 160;
+        }
+        return ncols <= 16 ? 288 : 160;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int ncols) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+        return ncols <= 16 ? 96 : 160;
+#else
+        return ncols <= 16 ? 288 : 160;
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+    }
+
+    static int get_nbatch_V2_host(const int cc, const int ncols) {
+        if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
+            return ncols <= 16 ? 64 : 128;
+        }
+        return ncols <= 16 ? 256 : 128;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int ncols) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+        return ncols <= 16 ? 64 : 128;
+#else
+        return ncols <= 16 ? 256 : 128;
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+    }
+
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 128;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 128;
+    }
+};
+
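Each head-size specialization above pairs a host getter with a `constexpr __device__` twin: the host side must size the kernel's dynamic shared memory at launch time, while the device side needs a compile-time constant for array bounds and unrolling. A minimal standalone sketch of that split (illustrative only, not part of the diff; `demo_config`, `demo_kernel`, and `demo_launch` are hypothetical names shaped like the structs above):

```cuda
#include <cuda_fp16.h>

// Hypothetical stand-in with the same shape as the specializations above.
struct demo_config {
    static constexpr int nbatch_fa = 64;
    static int get_nbatch_K2_host(int /*cc*/, int /*ncols*/) { return 64; }
    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { return 64; }
};

__global__ void demo_kernel() {
    // Device side: the getter folds to a compile-time constant.
    constexpr int nbatch_K2 = demo_config::get_nbatch_K2_device(8);
    extern __shared__ half2 tile_K[];
    tile_K[threadIdx.x % (demo_config::nbatch_fa * (nbatch_K2 + 4))] = __float2half2_rn(0.0f);
}

void demo_launch(int cc) {
    // Host side: the same number is needed to size dynamic shared memory,
    // where a __device__ constexpr cannot be called. Host and device getters
    // must therefore stay in sync.
    const int    nbatch_K2     = demo_config::get_nbatch_K2_host(cc, 8);
    const size_t nbytes_shared = demo_config::nbatch_fa * (nbatch_K2 + 4) * sizeof(half2);
    demo_kernel<<<1, 32, nbytes_shared>>>();
}
```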
+// ------------------------------------------------------------------------------------------------------------------
+
+template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
+        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
+
+    // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
+    // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
+
+    if (use_cp_async) {
+        constexpr int preload = 64;
+        constexpr int h2_per_chunk = 16/sizeof(half2);
+        const int chunks_per_row = D2 / h2_per_chunk;
+
+        const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
+
+        auto load = [&] __device__ (auto n) {
+            const int stride_k = WARP_SIZE >> n;
+            const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
+            const int k0_stop  = chunks_per_row - chunks_per_row % (1*stride_k);
+            const int stride_i = WARP_SIZE / stride_k;
+
+            if (k0_start == k0_stop) {
+                return;
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
+                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+                if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
+                    break;
+                }
+
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
+                }
+            }
+        };
+        ggml_cuda_unroll<5>{}(load);
+    } else {
+        static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
+        auto load = [&] __device__ (const int n) {
+            const int stride_k = WARP_SIZE >> n;
+            const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
+            const int k0_stop  = D2 - D2 % (1*stride_k);
+            const int stride_i = WARP_SIZE / stride_k;
+
+            if (k0_start == k0_stop) {
+                return;
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
+                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+                if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
+                    break;
+                }
+
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
+                }
+            }
+        };
+        ggml_cuda_unroll<3>{}(load);
+    }
+}
+
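The loader above covers each row in passes of decreasing granularity: the bulk of a row is read with all 32 lanes of a warp, and the tail that is not a multiple of the warp width is read by re-partitioning the warp into narrower strides so accesses stay coalesced. A simplified standalone sketch of the same pattern for plain `float` data (hypothetical `load_rows` helper, single-warp assumption, minimum granularity of 8 in this sketch):

```cuda
// Sketch only: copy `rows` rows of `D` floats into a padded shared tile.
template <int stride_tile>
__device__ void load_rows(const float * src, float * tile, int rows, int D, int stride_src) {
    for (int stride_k = 32; stride_k >= 8; stride_k /= 2) {
        // The first pass covers [0, D - D%32); later passes only the remainder.
        const int k0_start = stride_k == 32 ? 0 : D - D % (2*stride_k);
        const int k0_stop  =                      D - D % (1*stride_k);
        const int stride_i = 32 / stride_k; // rows handled per warp per step
        if (k0_start == k0_stop) {
            continue;
        }
        for (int i0 = 0; i0 < rows; i0 += stride_i) {
            // Narrower strides split the warp across stride_i consecutive rows:
            const int i = i0 + (stride_k == 32 ? 0 : (int) threadIdx.x / stride_k);
            if (i >= rows) {
                break;
            }
            for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                const int k = k0 + (stride_k == 32 ? (int) threadIdx.x : (int) threadIdx.x % stride_k);
                tile[i*stride_tile + k] = src[i*stride_src + k];
            }
        }
    }
}

__global__ void load_rows_demo(const float * src, float * dst, int rows, int D) {
    __shared__ float tile[64 * 68]; // up to 64 rows of D <= 64, padded stride 68
    load_rows<68>(src, tile, rows, D, D);
    __syncthreads();
    for (int i = threadIdx.x; i < rows*D; i += 32) {
        dst[i] = tile[(i/D)*68 + (i % D)]; // copy back so the loads are observable
    }
}
```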
+template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
+        const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
+    static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter");
+
+    if (use_cp_async) {
+        constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
+        constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
+        constexpr int stride_j = nwarps * cols_per_warp;
+
+        const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
+            const int j = j0 + threadIdx.y*cols_per_warp +
+                (nbatch_fa == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp));
+
+            if (j0 + stride_j > ncols1 && j >= ncols1) {
+                break;
+            }
+
+            const int i = 4 * (threadIdx.x % (nbatch_fa/8));
+
+            cp_async_cg_16<preload>(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i);
+        }
+        return;
+    }
+
+    constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
+    constexpr int stride_j = nwarps * cols_per_warp;
+#pragma unroll
+    for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
+        const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp));
+
+        if (j0 + stride_j > ncols1 && j >= ncols1) {
+            break;
+        }
+
+        const int i = nbatch_fa == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp);
+
+        tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i];
+    }
+}
+
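Both loaders store rows with a few elements of padding (`+ 4` half2 per row, or 16 extra bytes with cp.async). Further down, the combine step explicitly reuses the 16 padding bytes of each `tile_Q` row for per-column meta data; presumably the padding also keeps column-wise shared-memory reads off a single bank, the usual reason for non-power-of-two row strides. A toy illustration of that bank-conflict argument (our assumption, not taken from the diff):

```cuda
__global__ void bank_demo() {
    // Shared memory has 32 banks of 4 bytes. With a row stride of 32 floats,
    // element (i, 0) of every row maps to bank 0, so a 32-thread column read
    // serializes; an odd stride spreads the rows over all banks.
    __shared__ float unpadded[32 * 32]; // stride 32: every row starts in bank 0
    __shared__ float padded  [32 * 33]; // stride 33: row i starts in bank i % 32
    const float a = unpadded[threadIdx.x * 32]; // 32-way bank conflict
    const float b = padded  [threadIdx.x * 33]; // conflict-free
    if (a + b == 42.0f) { __trap(); } // keep the reads from being optimized out
}
```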
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
+static __device__ __forceinline__ void flash_attn_ext_f16_iter(
+        const float2 * const __restrict__ Q_f2,
+        const half2 * const __restrict__ K_h2,
+        const half2 * const __restrict__ V_h2,
+        const half2 * const __restrict__ mask_h2,
+        float2 * const __restrict__ dstk,
+        float2 * const __restrict__ dstk_fixup,
+        const float scale,
+        const float slope,
+        const float logit_softcap,
+        const int ne01,
+        const int ne02,
+        const int stride_K,
+        const int stride_V,
+        const int stride_mask,
+        const int jt,
+        half2 * const __restrict__ tile_Q,
+        half2 * const __restrict__ tile_K,
+        half2 * const __restrict__ tile_V,
+        half2 * const __restrict__ tile_mask,
+        const tile_B * const __restrict__ Q_B,
+        tile_C_VKQ * const __restrict__ VKQ_C,
+        float * const __restrict__ KQ_max,
+        float * const __restrict__ KQ_rowsum,
+        const int kb0) {
+#ifdef NEW_MMA_AVAILABLE
+    typedef fattn_mma_f16_config<DKQ, DV> c;
+
+#ifdef CP_ASYNC_AVAILABLE
+    constexpr int nstages = c::nstages_target;
+#else
+    constexpr int nstages = 0;
+#endif // CP_ASYNC_AVAILABLE
+
+    constexpr int cols_per_warp = ntiles * tile_B::I;
+    constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
+    constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+    constexpr int ncols = ncols1 * ncols2;
+    constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
+    constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
+
+    constexpr int stride_tile_Q = DKQ/2 + 4;
+    constexpr int stride_tile_K = nbatch_K2 + 4;
+
+    static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
+    constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
+
+    const int k_VKQ_0 = kb0 * c::nbatch_fa;
+    tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
+
+    // Use wide variants of tiles if ntiles >= 2.
+    tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
+    tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
+    tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
+
+    if constexpr (nstages > 1) {
+        static_assert(!mla, "multi-stage loading not implemented for MLA");
+        static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
+        constexpr bool use_cp_async = true;
+        cp_async_wait_all();
+        __syncthreads();
+        flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
+            (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
+    } else {
+        constexpr bool use_cp_async = nstages == 1;
+        if (ncols2 > 1 || mask_h2) {
+            flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
+        }
+    }
+
+#pragma unroll
+    for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
+        const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
+        const int k0_diff = k0_stop - k0_start;
+
+        if (nstages <= 1) {
+            constexpr bool use_cp_async = nstages == 1;
+            flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
+                (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
+            if (use_cp_async) {
+                cp_async_wait_all();
+            }
+            __syncthreads();
+        }
+
+        // Calculate tile of KQ:
+        if constexpr (c::Q_in_reg) {
+#pragma unroll
+            for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) {
+                const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
+#pragma unroll
+                for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
+                    tile_A K_A;
+                    load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
+                    if (ntiles == 1) {
+                        mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
+                    } else {
+#pragma unroll
+                        for (int t = 0; t < ntiles/2; ++t) {
+                            // Wide version of KQ_C is column-major => swap A and B.
+                            mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A);
+                        }
+                    }
+                }
+            }
+        } else {
+            static_assert(ntiles == 2, "ntiles != 2 not implemented");
+#pragma unroll
+            for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
+                load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
+
+#pragma unroll
+                for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) {
+                    const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
+
+                    tile_A K_A;
+                    load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
+
+                    // Wide version of KQ_C is column-major => swap A and B.
+                    mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A);
+                }
+            }
+        }
+
+        if (nstages <= 1) {
+            __syncthreads(); // Only needed if tile_K == tile_V.
+        }
+    }
+
+    if (use_logit_softcap) {
+        static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+        for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) {
+#pragma unroll
+            for (int l = 0; l < tile_C_KQ::ne; ++l) {
+                KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
+            }
+        }
+    }
+
+    float KQ_max_new[cols_per_thread];
+#pragma unroll
+    for (int col = 0; col < cols_per_thread; ++col) {
+        KQ_max_new[col] = KQ_max[col];
+    }
+    float KQ_rowsum_add[cols_per_thread] = {0.0f};
+
+    if (ntiles == 1) {
+        if (ncols2 > 1 || mask_h2) {
+#pragma unroll
+            for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
+                const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
+#pragma unroll
+                for (int l = 0; l < tile_C_KQ::ne; ++l) {
+                    const int i = i0 + tile_C_KQ::get_i(l);
+                    const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2;
+
+                    KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope *
+                        __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]);
+                }
+            }
+        }
+
+        // Calculate softmax for each KQ column using the current max. value.
+        // The divisor is stored in KQ_rowsum and will be applied at the end.
+        static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+        for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
+#pragma unroll
+            for (int l = 0; l < tile_C_KQ::ne; ++l) {
+                KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
+            }
+        }
+
+        // Values per KQ column are spread across 8 threads, does not need full warp reduce:
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+#pragma unroll
+            for (int offset = 16; offset >= 4; offset >>= 1) {
+                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+            }
+        }
+
+        static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+        for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
+#pragma unroll
+            for (int l = 0; l < tile_C_KQ::ne; ++l) {
+                KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]);
+
+                KQ_rowsum_add[l % 2] += KQ_C[k].x[l];
+            }
+        }
+    } else { // ntiles > 1
+        if (ncols2 > 1 || mask_h2) {
+#pragma unroll
+            for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) {
+                const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J;
+#pragma unroll
+                for (int t = 0; t < ntiles/2; ++t) {
+#pragma unroll
+                    for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) {
+                        const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2;
+                        const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2;
+
+                        const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]);
+                        const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t;
+                        KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x;
+                        KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y;
+                    }
+                }
+            }
+        }
+
+        // Calculate softmax for each KQ column using the current max. value.
+        // The divisor is stored in KQ_rowsum and will be applied at the end.
+        static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+        for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
+#pragma unroll
+            for (int t = 0; t < ntiles/2; ++t) {
+#pragma unroll
+                for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
+                    const int KQ_index = 2*t + (l/2) % 2;
+                    KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]);
+                }
+            }
+        }
+
+        // Values per KQ column are spread across 4 threads, does not need full warp reduce:
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+#pragma unroll
+            for (int offset = 2; offset >= 1; offset >>= 1) {
+                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+            }
+        }
+
+        static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size");
+#pragma unroll
+        for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
+#pragma unroll
+            for (int t = 0; t < ntiles/2; ++t) {
+#pragma unroll
+                for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
+                    const int KQ_index = 2*t + (l/2) % 2;
+
+                    KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]);
+
+                    KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l];
+                }
+            }
+        }
+    }
+
+    {
+        float KQ_max_scale[cols_per_thread];
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+            KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]);
+            KQ_max[col] = KQ_max_new[col];
+
+            // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+            KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
+        }
+
+        if (ntiles == 1) {
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+#pragma unroll
+            for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
+#pragma unroll
+                for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+                    VKQ_C[i].x[l] *= KQ_max_scale_h2;
+                }
+            }
+        } else {
+#pragma unroll
+            for (int col = 0; col < cols_per_thread; ++col) {
+                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+                for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
+#pragma unroll
+                    for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
+                        VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
+                    }
+                }
+            }
+        }
+    }
+
+    // Convert KQ C tiles into B tiles for VKQ calculation:
+    tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles];
+    tile_B_16 * B_16 = (tile_B_16 *) B;
+    static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size");
+    if (ntiles == 1) {
+#pragma unroll
+        for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) {
+            B[k] = get_transposed(get_half2(KQ_C[k]));
+        }
+    } else {
+        for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) {
+#pragma unroll
+            for (int t = 0; t < ntiles/2; ++t) {
+                B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]);
+            }
+        }
+    }
+
+    if (nstages > 1) {
+        // Preload K tile for next iteration:
+        constexpr bool use_cp_async = true;
+        cp_async_wait_all();
+        __syncthreads();
+        if (!last_iter) {
+            if (ncols2 > 1 || mask_h2) {
+                flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
+                    (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
+            }
+            flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
+                (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
+        }
+    }
+
+
+    // For MLA K and V have the same data.
+    // Therefore, iterate over V in reverse and re-use the data if possible.
+    static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
+    constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
+#pragma unroll
+    for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
+        const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
+        const int i0_diff = i0_stop - i0_start;
+
+        if (nstages <= 1 && i0_start < reusable_cutoff) {
+            constexpr bool use_cp_async = nstages == 1;
+            flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
+                (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
+            if (use_cp_async) {
+                cp_async_wait_all();
+            }
+            __syncthreads();
+        }
+        const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
+
+        // Calculate VKQ tile:
+#pragma unroll
+        for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) {
+            static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size");
+#pragma unroll
+            for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) {
+                const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
+
+                tile_A A;
+                load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+                if (ntiles == 1) {
+                    mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
+                } else {
+#pragma unroll
+                    for (int t = 0; t < ntiles/2; ++t) {
+                        // Wide version of VKQ_C is column-major => swap A and B.
+                        mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A);
+                    }
+                }
+            }
+        }
+
+        if (nstages <= 1) {
+            __syncthreads(); // Only needed if tile_K == tile_V.
+        }
+    }
+#else
+    GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
+    GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
+    GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
+    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
+    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
+    GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
+    GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
+    GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
+    NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+}
+
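In the code's terms, the rescaling block above implements the standard online-softmax recurrence over KV batches: `KQ_max` is the running row maximum m, `KQ_rowsum` the running denominator r, and the `VKQ_C` accumulators the running numerator O; the division by r happens only once, during the final combine. As a recurrence (notation ours, not from the source), with $s_i$ the `nbatch_fa` new KQ scores of the current iteration and $V_i$ the corresponding value rows:

```latex
m' = \max\Bigl(m,\; \max_i s_i\Bigr), \qquad
r' = e^{\,m - m'}\, r \;+\; \sum_i e^{\,s_i - m'}, \qquad
O' = e^{\,m - m'}\, O \;+\; \sum_i e^{\,s_i - m'}\, V_i
```

`KQ_max_scale` is the factor $e^{m - m'}$, `KQ_rowsum_add` the new partial sum, and the MMA over the exponentiated `KQ_C` tiles supplies the $\sum_i e^{s_i - m'} V_i$ term.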
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
+static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
+        const float2 * const __restrict__ Q_f2,
+        const half2 * const __restrict__ K_h2,
+        const half2 * const __restrict__ V_h2,
+        const half2 * const __restrict__ mask_h2,
+        float2 * const __restrict__ dstk,
+        float2 * const __restrict__ dstk_fixup,
+        const float scale,
+        const float slope,
+        const float logit_softcap,
+        const int ne01,
+        const int ne02,
+        const int stride_Q1,
+        const int stride_Q2,
+        const int stride_K,
+        const int stride_V,
+        const int stride_mask,
+        const int jt,
+        const int kb0_start,
+        const int kb0_stop) {
+#ifdef NEW_MMA_AVAILABLE
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    typedef fattn_mma_f16_config<DKQ, DV> c;
+
+#ifdef CP_ASYNC_AVAILABLE
+    constexpr int nstages = c::nstages_target;
+#else
+    constexpr int nstages = 0;
+#endif // CP_ASYNC_AVAILABLE
+
+    constexpr int ncols = ncols1 * ncols2;
+    constexpr int cols_per_warp = ntiles * tile_B::I;
+    constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
+    constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+    constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
+    constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
+
+    static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
+
+    constexpr int stride_tile_Q = DKQ/2 + 4;
+    constexpr int stride_tile_K = nbatch_K2 + 4;
+
+    static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
+    constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
+    constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
+
+    extern __shared__ half2 tile_Q[];
+    half2 * tile_K = c::Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q;
+    half2 * tile_V = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K;
+    half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max;
+
+    tile_B Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles];
+    tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I * ntiles];
+
+    tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
+    tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
+
+    float KQ_rowsum[cols_per_thread] = {0.0f};
+    float KQ_max[cols_per_thread];
+#pragma unroll
+    for (int col = 0; col < cols_per_thread; ++col) {
+        KQ_max[col] = -FLT_MAX/2.0f;
+    }
+
+    // Load Q data into tile_Q, either temporarily or permanently.
+    // Q in registers is faster, but register pressure is the biggest bottleneck.
+    // The loading is done with decreasing granularity for D for better memory bandwidth.
+    const half2 scale_h2 = make_half2(scale, scale);
+#pragma unroll
+    for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+        const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
+        const int k0_stop  = DKQ/2 - (DKQ/2) % (1*stride_k);
+        const int stride_jc = WARP_SIZE / stride_k;
+
+        if (k0_start == k0_stop) {
+            continue;
+        }
+
+#pragma unroll
+        for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
+            const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+            if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
+                break;
+            }
+
+            const int j = jc / ncols2;
+            const int c = jc % ncols2;
+
+            if (jt*ncols1 + j < ne01) {
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
+                    tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
+                }
+            } else {
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
+                }
+            }
+        }
+    }
+
+    __syncthreads();
+
+    if (c::Q_in_reg) {
+        const int j0 = (threadIdx.y / np) * cols_per_warp;
+
+#pragma unroll
+        for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) {
+            if (ntiles == 1) {
+                load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
+            } else {
+#pragma unroll
+                for (int t = 0; t < ntiles/2; ++t) {
+                    load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t],
+                        tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q);
+                }
+            }
+        }
+    }
+
+    __syncthreads();
+
+    // Preload mask and K data for first iteration when using cp_async with multiple stages:
+    if constexpr (nstages > 1) {
+        static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
+        constexpr bool use_cp_async = true;
+        if (ncols2 > 1 || mask_h2) {
+            flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
+                (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
+        }
+        flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
+            (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
+    }
+
+    // Iterate over ne11 == previous tokens:
+    for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
+        constexpr bool last_iter = false;
+        flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
+            (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
+             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+    }
+    { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
+        constexpr bool last_iter = true;
+        flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
+            (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
+             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+    }
+
+    // With multi-stage loading there is no __syncthreads at the end of the iter,
+    // so there can be a race condition on shared memory access for combining/writing back results.
+    if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) {
+        __syncthreads();
+    }
+
+    // Finally, sum up partial KQ rowsums.
+    // The partial sums are spread across 8/4 threads each, does not need full reduce.
+    {
+        constexpr int offset_first = ntiles == 1 ? 16 : 2;
+        constexpr int offset_last  = ntiles == 1 ?  4 : 1;
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+#pragma unroll
+            for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
+                KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
+            }
+        }
+    }
+
+    // Combine VKQ accumulator values if np > 1.
+    // It's also faster to do small writes to shared memory, then a large write to VRAM than to do small writes to VRAM.
+    // So also write VKQ accumulators to shared memory in column-major format if np == 1.
+
+    constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols);
+    constexpr int tile_stride = nbatch_combine + 4;
+    static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
+
+    if constexpr (ntiles == 1) {
+        const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset
+        const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta
+        const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum
+
+        if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
+            // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
+            ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
+        }
+
+        __syncthreads();
+
+        if (np == 1) {
+            // No combination is needed, the meta data can be directly written from registers to VRAM.
+            if (needs_fixup && threadIdx.x < tile_B::I) {
+                float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+                dstk_fixup_meta[jc_cwm] = KQ_cmr;
+            }
+            if (is_fixup && threadIdx.x < tile_B::I) {
+                float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+                dstk_fixup_meta[jc_cwm] = KQ_cmr;
+            }
+        }
+    } else {
+        static_assert(ntiles == 2 || ntiles == 4, "bad ntiles");
+        const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta
+            + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0)
+            + tile_C_VKQ_16::get_i(threadIdx.x % 4);
+        const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum
+
+        if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) {
+            // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
+            ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
+        }
+
+        __syncthreads();
+
+        if (np == 1) {
+            // No combination is needed, the meta data can be directly written from registers to VRAM.
+            if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
+                float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+                dstk_fixup_meta[jc_cwm] = KQ_cmr;
+            }
+            if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
+                float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+                dstk_fixup_meta[jc_cwm] = KQ_cmr;
+            }
+        }
+    }
+
+    static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles");
+    if (np > 1 && threadIdx.y % np == 0) {
+        // Combine the meta data for parallel warps via shared memory.
+        // Warps with threadIdx.y % np != 0 must NOT return early.
+        // All threads must return simultaneously to avoid race conditions with work on the next tile.
+
+        constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
+
+        const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
+        float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
+        float2 meta[nmeta];
+#pragma unroll
+        for (int imeta = 0; imeta < nmeta; ++imeta) {
+            meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2];
+        }
+
+        float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
+#pragma unroll
+        for (int imeta = 1; imeta < nmeta; ++imeta) {
+            KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x);
+        }
+#pragma unroll
+        for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
+            if (offset < WARP_SIZE) {
+                KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
+            }
+        }
+
+        float KQ_cms[nmeta]; // KQ combine max scale per warp.
+#pragma unroll
+        for (int imeta = 0; imeta < nmeta; ++imeta) {
+            KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn);
+        }
+
+        float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps.
+#pragma unroll
+        for (int imeta = 1; imeta < nmeta; ++imeta) {
+            KQ_crs += KQ_cms[imeta]*meta[imeta].y;
+        }
+#pragma unroll
+        for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
+            if (offset < WARP_SIZE) {
+                KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
+            }
+        }
+
+        __syncthreads();
+
+        // Write back combined meta data:
+#pragma unroll
+        for (int imeta = 0; imeta < nmeta; ++imeta) {
+            if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
+                // Combined KQ max scale + rowsum.
+                meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
+            }
+        }
+
+        // Combined KQ max + rowsum.
+        static_assert(cols_per_warp <= WARP_SIZE);
+        if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
+            float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+            dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+        }
+        if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
+            float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+            dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+        }
+    } else if (np > 1) {
+        // Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch.
+        // Therefore, all other warps also need to execute a __syncthreads().
+        // Otherwise the points at which warps synchronize with each other would become misaligned.
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
+        if (ntiles == 1) {
+            const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
+#pragma unroll
+            for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) {
+                const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
+
+#pragma unroll
+                for (int l = 0; l < tile_B::ne; ++l) {
+                    const int k = k0 + tile_B::get_j(l);
+
+                    tile_Q[jc_cwd*tile_stride + k] = B.x[l];
+                }
+            }
+        } else {
+#pragma unroll
+            for (int t = 0; t < ntiles/2; ++t) {
+                const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I;
+#pragma unroll
+                for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) {
+#pragma unroll
+                    for (int l = 0; l < tile_C_VKQ_16::ne; ++l) {
+                        const int j = j0 + tile_C_VKQ_16::get_i(l);
+                        const int k = k0 + tile_C_VKQ_16::get_j(l);
+
+                        tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l];
+                    }
+                }
+            }
+        }
+
+        __syncthreads();
+
+        if (np == 1 || threadIdx.y % np == 0) {
+            // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
+            // The values after that are for the partial results of the individual blocks.
+            float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
+
+#pragma unroll
+            for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+                const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
+                const int k0_stop  = nbatch_combine - nbatch_combine % (1*stride_k);
+                const int stride_jc = WARP_SIZE / stride_k;
+
+                if (k0_start == k0_stop) {
+                    continue;
+                }
+
+#pragma unroll
+                for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
+                    const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+                    if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
+                        break;
+                    }
+
+                    const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp;
+
+                    const int j_dst = jc_dst / ncols2;
+                    const int c_dst = jc_dst % ncols2;
+
+                    if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
+                        continue;
+                    }
+
+                    const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
+#pragma unroll
+                    for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                        const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                        float2 dstk_val = make_float2(0.0f, 0.0f);
+#pragma unroll
+                        for (int ip = 0; ip < np; ++ip) {
+                            const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * tile_stride + 0];
+                            const float2 dstk_val_add = __half22float2(tile_Q[(jc_tile_K + ip*cols_per_warp) * tile_stride + k]);
+                            dstk_val.x += dstk_val_add.x*KQ_crs;
+                            dstk_val.y += dstk_val_add.y*KQ_crs;
+                        }
+
+                        if (!needs_fixup && !is_fixup) {
+                            const float KQ_rowsum_j = meta_j[1];
+                            dstk_val.x /= KQ_rowsum_j;
+                            dstk_val.y /= KQ_rowsum_j;
+                        }
+
+                        if (is_fixup) {
+                            dstk_fixup_data[jc_dst*(DV/2) + k00 + k] = dstk_val;
+                        } else {
+                            dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(DV/2) + k00 + k] = dstk_val;
+                        }
+                    }
+                }
+            }
+        }
+        if (np > 1) {
+            __syncthreads();
+        }
+    }
+#else
+    GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
+    GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
+    GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
+    GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask);
+    GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
+    NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+}
+
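When several warps cooperate on the same Q column (np > 1), each finishes with its own partial triple: maximum, rowsum, and accumulator. The shared-memory meta exchange above merges them: `KQ_cmn` is the combined maximum, `KQ_cms[w]` the per-warp scale, and `KQ_crs` the combined rowsum; the per-warp scales are then applied to the data tiles in the write-back loop. In formulas (notation ours, not from the source), for per-warp partials $(m_w, r_w, O_w)$:

```latex
m = \max_w m_w, \qquad c_w = e^{\,m_w - m}, \qquad
r = \sum_w c_w\, r_w, \qquad O = \frac{1}{r} \sum_w c_w\, O_w
```

The same merge rule is applied once more on the host side of the fixup path when two CUDA blocks split one output tile, which is why only $(m, r)$ plus the unscaled partials need to be written out.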
|
1199
|
+
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
|
1200
|
+
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
1201
|
+
static __global__ void flash_attn_ext_f16(
|
1202
|
+
const char * __restrict__ Q,
|
1203
|
+
const char * __restrict__ K,
|
1204
|
+
const char * __restrict__ V,
|
1205
|
+
const char * __restrict__ mask,
|
1206
|
+
float * __restrict__ dst,
|
1207
|
+
float2 * __restrict__ dst_meta,
|
1208
|
+
const float scale,
|
1209
|
+
const float max_bias,
|
1210
|
+
const float m0,
|
1211
|
+
const float m1,
|
1212
|
+
const uint32_t n_head_log2,
|
1213
|
+
const float logit_softcap,
|
1214
|
+
const int ne00,
|
1215
|
+
const int ne01,
|
1216
|
+
const int ne02,
|
1217
|
+
const int ne03,
|
1218
|
+
const int ne10,
|
1219
|
+
const int ne11,
|
1220
|
+
const int ne12,
|
1221
|
+
const int ne13,
|
1222
|
+
const int ne31,
|
1223
|
+
const int nb31,
|
1224
|
+
const int nb01,
|
1225
|
+
const int nb02,
|
1226
|
+
const int nb03,
|
1227
|
+
const int nb11,
|
1228
|
+
const int nb12,
|
1229
|
+
const int nb13,
|
1230
|
+
const int nb21,
|
1231
|
+
const int nb22,
|
1232
|
+
const int nb23,
|
1233
|
+
const int ne0,
|
1234
|
+
const int ne1,
|
1235
|
+
const int ne2,
|
1236
|
+
const int ne3) {
|
1237
|
+
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
1238
|
+
|
1239
|
+
// Skip unused kernel variants for faster compilation:
|
1240
|
+
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
|
1241
|
+
NO_DEVICE_CODE;
|
1242
|
+
return;
|
1243
|
+
}
|
1244
|
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
1245
|
+
if (ncols1*ncols2 > 32) {
|
1246
|
+
NO_DEVICE_CODE;
|
1247
|
+
return;
|
1248
|
+
}
|
1249
|
+
#endif __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
1250
|
+
|
1251
|
+
static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
|
1252
|
+
|
1253
|
+
typedef fattn_mma_f16_config<DKQ, DV> c;
|
1254
|
+
|
1255
|
+
static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config<DKQ, DV>::nbatch_fa == 0, "bad nbatch_fa");
|
1256
|
+
|
1257
|
+
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
1258
|
+
|
1259
|
+
const int stride_Q1 = nb01 / sizeof(float2);
|
1260
|
+
const int stride_Q2 = nb02 / sizeof(float2);
|
1261
|
+
const int stride_K = nb11 / sizeof(half2);
|
1262
|
+
const int stride_mask = nb31 / sizeof(half2);
|
1263
|
+
|
1264
|
+
const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
|
1265
|
+
|
1266
|
+
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
1267
|
+
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
1268
|
+
|
1269
|
+
constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
|
1270
|
+
|
1271
|
+
// kbc == k block continuous, current index in continuous ijk space.
|
1272
|
+
int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
1273
|
+
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
1274
|
+
|
1275
|
+
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
1276
|
+
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
1277
|
+
// In the most general case >2 seams can fall into the same tile.
|
1278
|
+
|
1279
|
+
// kb0 == k start index when in the output tile.
|
1280
|
+
int kb0_start = kbc % iter_k;
|
1281
|
+
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
|
1282
|
+
+    while (kbc < kbc_stop && kb0_stop == iter_k) {
+        const int channel = kbc / (iter_k*iter_j);
+        const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+
+        const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
+        const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+        const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
+        float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
+
+        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+
+        const int kb0_start_kernel = kb0_start * kb_niter;
+        const int kb0_stop_kernel  = kb0_stop  * kb_niter;
+
+        constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
+        if (kb0_start == 0) {
+            constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
+            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
+                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                 ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+        } else {
+            constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
+            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
+                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                 ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+        }
+
+        kbc += iter_k;
+        kbc -= kbc % iter_k;
+
+        kb0_start = 0;
+        kb0_stop  = min(iter_k, kbc_stop - kbc);
+    }
+
+    if (kbc >= kbc_stop) {
+        return;
+    }
+
+    const int channel = kbc / (iter_k*iter_j);
+    const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+
+    const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
+    const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+    const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
+    float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
+
+    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+
+    const int kb0_start_kernel = kb0_start * kb_niter;
+    const int kb0_stop_kernel  = kb0_stop  * kb_niter;
+
+    constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
+    constexpr bool needs_fixup = false;
+    flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
+        (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+         ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+#else
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
+    GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
+    GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
+    GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    NO_DEVICE_CODE;
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
+}
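For orientation: in the hunk above, each CUDA block walks a flattened work index kbc; dividing by iter_k*iter_j recovers the channel (head), the remainder divided by iter_k gives the j tile, and the residue kbc % iter_k mirrors the kb0_start computed earlier in the kernel. Chunks that start mid-tile set needs_fixup so their partial results are merged with a neighboring block's output, and the final chunk writes to the fixup buffer (is_fixup) instead of dst. A minimal host-side sketch of the same index arithmetic, with illustrative iter_k/iter_j values (in the kernel they derive from the K and Q tile counts):

    // Sketch only: decodes the flattened work index the way the kernel does.
    // The iter_k/iter_j values here are assumptions for demonstration.
    #include <cstdio>

    int main() {
        const int iter_k = 8; // assumed number of K-block chunks per (channel, jt) pair
        const int iter_j = 4; // assumed number of j tiles per channel

        for (int kbc = 0; kbc < 2*iter_k*iter_j; kbc += 5) {
            const int channel = kbc / (iter_k*iter_j);                  // head index
            const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j tile index
            const int kb0     = kbc % iter_k;                           // first K block; nonzero -> needs_fixup
            std::printf("kbc=%2d -> channel=%d jt=%d kb0=%d\n", kbc, channel, jt, kb0);
        }
        return 0;
    }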
+
+template <int DKQ, int DV, int ncols1, int ncols2>
+void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
+
+    typedef fattn_mma_f16_config<DKQ, DV> c;
+
+    const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
+
+    constexpr int ncols         = ncols1 * ncols2;
+    constexpr int ntiles        = ncols <= 8 ? 1 : 2; // Number of tiles per warp.
+    constexpr int cols_per_warp = ntiles * tile_B::I;
+    constexpr int nwarps_max_x  = ncols / cols_per_warp;
+    constexpr int nwarps_max_y  = c::nbatch_fa / tile_A::I;
+    constexpr int nwarps        = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
+
+    constexpr bool mla = DKQ == 576;
+
+    const int nbatch_K2      = c::get_nbatch_K2_host     (cc, ncols);
+    const int nbatch_V2      = c::get_nbatch_K2_host     (cc, ncols);
+    const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
+
+    static_assert(DKQ   % tile_B::J     == 0, "bad DKQ");
+    static_assert(DV    % tile_A::J     == 0, "bad DV");
+    static_assert(ncols % cols_per_warp == 0, "bad ncols");
+
+    const size_t nbytes_shared_KV_1stage = c::nbatch_fa         * std::max(nbatch_K2 + 4,  nbatch_V2 + 4) * sizeof(half2);
+    const size_t nbytes_shared_KV_2stage = c::nbatch_fa         *         (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
+    const size_t nbytes_shared_Q         = ncols                * (DKQ/2          + 4) * sizeof(half2);
+    const size_t nbytes_shared_mask      = ncols1               * (c::nbatch_fa/2 + 4) * sizeof(half2);
+    const size_t nbytes_shared_combine   = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
+
+    const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
+
+    const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ?
+        std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) :
+        nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask);
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    fattn_kernel_t fattn_kernel;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
+
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+        if (!shared_memory_limit_raised[id]) {
+            CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+            shared_memory_limit_raised[id] = true;
+        }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+    } else {
+        constexpr bool use_logit_softcap = true;
+        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
+
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+        if (!shared_memory_limit_raised[id]) {
+            CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+            shared_memory_limit_raised[id] = true;
+        }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+    }
+
+    launch_fattn<DV, ncols1, ncols2>
+        (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
+}
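The #if-guarded blocks above raise the kernel's dynamic shared-memory ceiling: CUDA caps dynamic shared memory at 48 KiB per block unless cudaFuncSetAttribute is called for the kernel, and the static shared_memory_limit_raised array makes that driver call happen once per device rather than on every launch. A standalone sketch of the same pattern, where my_kernel and SMEM_BYTES are hypothetical stand-ins:

    // Sketch: opting a kernel into >48 KiB of dynamic shared memory, once per device.
    #include <cuda_runtime.h>

    __global__ void my_kernel(float * out) {
        extern __shared__ float buf[]; // dynamic shared memory
        buf[threadIdx.x] = threadIdx.x;
        __syncthreads();
        out[threadIdx.x] = buf[threadIdx.x];
    }

    void launch(float * out, int device, cudaStream_t stream) {
        constexpr int SMEM_BYTES = 64*1024;     // assumed size above the 48 KiB default
        static bool limit_raised[16] = {false}; // per-device flags, akin to GGML_CUDA_MAX_DEVICES
        if (!limit_raised[device]) {
            cudaFuncSetAttribute(my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_BYTES);
            limit_raised[device] = true;
        }
        my_kernel<<<1, 32, SMEM_BYTES, stream>>>(out);
    }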
+
+
+#define DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2)                          \
+    template void ggml_cuda_flash_attn_ext_mma_f16_case                           \
+    <DKQ, DV, ncols1, ncols2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(DKQ, DV, ncols)   \
+    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 1,  1); \
+    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 2,  2); \
+    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 4,  4); \
+    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 8,  8); \
+    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/16, 16); \
+
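As a worked example of the two macros above, DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(64, 64, 16) (used below) expands to five extern explicit-instantiation declarations, one per (ncols1, ncols2) factorization of the 16 total columns:

    extern template void ggml_cuda_flash_attn_ext_mma_f16_case<64, 64, 16,  1>(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
    extern template void ggml_cuda_flash_attn_ext_mma_f16_case<64, 64,  8,  2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
    extern template void ggml_cuda_flash_attn_ext_mma_f16_case<64, 64,  4,  4>(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
    extern template void ggml_cuda_flash_attn_ext_mma_f16_case<64, 64,  2,  8>(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
    extern template void ggml_cuda_flash_attn_ext_mma_f16_case<64, 64,  1, 16>(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

(For ncols = 8 the last line degenerates to ncols1 = 0; since these are declarations only, nothing appears to be instantiated unless a matching definition exists in another translation unit.)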
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64,  8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  80,  8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  96,  8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112,  8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128,  8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256,  8)
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  80, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  96, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 16)
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  80, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  96, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 32)
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  80, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  96, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
+
+// The number of viable configurations for Deepseek is very limited:
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
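The three Deepseek declarations correspond to the multi-head latent attention path (mla, i.e. DKQ == 576 with DV == 512 above), where V is read out of the K buffer rather than from separate storage. A hypothetical sketch of how a caller might pick among these instantiations by the number of Q rows; the real selection lives in the surrounding fattn dispatch code and also weighs occupancy:

    // Sketch only: flash_attn_mla_dispatch and n_q_rows are illustrative names,
    // not identifiers from the source.
    static void flash_attn_mla_dispatch(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_q_rows) {
        if (n_q_rows >= 4) {
            ggml_cuda_flash_attn_ext_mma_f16_case<576, 512, 4, 16>(ctx, dst);
        } else if (n_q_rows >= 2) {
            ggml_cuda_flash_attn_ext_mma_f16_case<576, 512, 2, 16>(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_mma_f16_case<576, 512, 1, 16>(ctx, dst);
        }
    }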