whispercpp 1.3.3 → 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +79 -25
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/addon.cpp +19 -19
- data/ext/sources/examples/addon.node/index.js +7 -5
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +26 -16
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +122 -111
- data/ext/sources/examples/command/command.cpp +26 -24
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/lsp/lsp.cpp +19 -17
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +34 -24
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +4 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
- data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
- data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
- data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
- data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
- data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
- data/ext/sources/examples/talk-llama/llama-context.h +99 -36
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
- data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
- data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
- data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
- data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
- data/ext/sources/examples/talk-llama/llama-model.h +104 -12
- data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
- data/ext/sources/examples/talk-llama/llama.cpp +794 -12
- data/ext/sources/examples/talk-llama/llama.h +246 -190
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
- data/ext/sources/ggml/CMakeLists.txt +135 -79
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +21 -2
- data/ext/sources/ggml/include/ggml-cpu.h +2 -1
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +406 -23
- data/ext/sources/ggml/src/CMakeLists.txt +99 -13
- data/ext/sources/ggml/src/ggml-alloc.c +368 -161
- data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
- data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
- data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
- data/ext/sources/ggml/src/ggml-common.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
- data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
- data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
- data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
- data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
- data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
- data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
- data/ext/sources/ggml/src/ggml-impl.h +186 -15
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +111 -16
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +901 -129
- data/ext/sources/ggml/src/gguf.cpp +8 -1
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +124 -81
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +7 -5
- data/ext/sources/tests/test-vad.cpp +3 -3
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +126 -2
- data/test/test_params.rb +24 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +8 -1
- data/whispercpp.gemspec +1 -1
- metadata +439 -179
- data/ext/sources/build-xcframework.sh +0 -547
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
|
@@ -5,284 +5,211 @@
|
|
|
5
5
|
|
|
6
6
|
using namespace ggml_cuda_mma;
|
|
7
7
|
|
|
8
|
-
|
|
9
|
-
typedef tile< 8, 8, half2> tile_B;
|
|
10
|
-
typedef tile<16, 8, half2> tile_B_16;
|
|
11
|
-
typedef tile<16, 8, float> tile_C_KQ;
|
|
12
|
-
typedef tile<16, 16, float> tile_C_KQ_16;
|
|
13
|
-
typedef tile<16, 4, half2> tile_C_VKQ;
|
|
14
|
-
typedef tile<16, 8, half2> tile_C_VKQ_16;
|
|
15
|
-
|
|
16
|
-
// Config options for specific head sizes.
|
|
8
|
+
// Config options for the MMA kernel.
|
|
17
9
|
// Should not affect results, only speed/register pressure/shared memory use.
|
|
18
|
-
|
|
19
|
-
//
|
|
20
|
-
//
|
|
21
|
-
//
|
|
22
|
-
//
|
|
23
|
-
//
|
|
24
|
-
//
|
|
25
|
-
//
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
static constexpr int nbatch_fa = 64;
|
|
33
|
-
static constexpr int nwarps_max = 4;
|
|
34
|
-
static constexpr bool Q_in_reg = true;
|
|
35
|
-
static constexpr int nstages_target = 2;
|
|
36
|
-
|
|
37
|
-
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
|
38
|
-
return 32;
|
|
39
|
-
}
|
|
40
|
-
|
|
41
|
-
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
|
42
|
-
return 32;
|
|
43
|
-
}
|
|
44
|
-
|
|
45
|
-
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
|
46
|
-
return 32;
|
|
47
|
-
}
|
|
48
|
-
|
|
49
|
-
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
|
50
|
-
return 32;
|
|
51
|
-
}
|
|
52
|
-
|
|
53
|
-
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
|
54
|
-
return 32;
|
|
55
|
-
}
|
|
56
|
-
|
|
57
|
-
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
|
58
|
-
return 32;
|
|
59
|
-
}
|
|
10
|
+
struct fattn_mma_config {
|
|
11
|
+
int nthreads; // Number of threads per CUDA block.
|
|
12
|
+
int occupancy; // Targeted occupancy for the MMA kernel.
|
|
13
|
+
int nbatch_fa; // Number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
|
|
14
|
+
int nbatch_K2; // Number of K half2 values in direction of DKQ to load in parallel.
|
|
15
|
+
int nbatch_V2; // Number of V half2 values in direction of DV to load in parallel.
|
|
16
|
+
int nbatch_combine; // Number of VKQ half2 values in direction of DV to combine in parallel.
|
|
17
|
+
int nstages_target; // Number of pipeline stages to use ideally, 1 == always load data synchronously, 2 == preload data if there is hardware support.
|
|
18
|
+
bool Q_in_reg; // Whether the Q values should be kept permanently in registers.
|
|
19
|
+
|
|
20
|
+
constexpr __host__ __device__ fattn_mma_config(
|
|
21
|
+
int nthreads, int occupancy, int nbatch_fa, int nbatch_K2, int nbatch_V2, int nbatch_combine, int nstages_target, bool Q_in_reg) :
|
|
22
|
+
nthreads(nthreads), occupancy(occupancy), nbatch_fa(nbatch_fa), nbatch_K2(nbatch_K2), nbatch_V2(nbatch_V2), nbatch_combine(nbatch_combine),
|
|
23
|
+
nstages_target(nstages_target), Q_in_reg(Q_in_reg) {}
|
|
60
24
|
};
|
|
61
25
|
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
|
114
|
-
return 48;
|
|
115
|
-
}
|
|
116
|
-
|
|
117
|
-
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
|
118
|
-
return 48;
|
|
119
|
-
}
|
|
26
|
+
#define GGML_CUDA_FATTN_MMA_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads_, occupancy_, nbatch_fa_, nbatch_K2_, nbatch_V2_, nbatch_combine_, nstages_target_, Q_in_reg_) \
|
|
27
|
+
if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \
|
|
28
|
+
static_assert((nthreads_) % 32 == 0 && (nthreads_) <= 512, "bad nthreads"); \
|
|
29
|
+
static_assert( (occupancy_) <= 8, "bad occupancy"); \
|
|
30
|
+
static_assert((nbatch_fa_) % 32 == 0 && (nbatch_fa_) <= 256, "bad nbatch_fa"); \
|
|
31
|
+
static_assert((nbatch_K2_) % 4 == 0 && (nbatch_K2_) <= 512, "bad nbatch_K2"); \
|
|
32
|
+
static_assert((nbatch_V2_) % 4 == 0 && (nbatch_V2_) <= 256, "bad nbatch_V2"); \
|
|
33
|
+
static_assert((nbatch_combine_) % 4 == 0 && (nbatch_combine_) <= 128, "bad nbatch_combine"); \
|
|
34
|
+
static_assert((nstages_target_) >= 1 && (nstages_target_) <= 2, "bad nstages_target"); \
|
|
35
|
+
return fattn_mma_config{(nthreads_), (occupancy_), (nbatch_fa_), (nbatch_K2_), (nbatch_V2_), (nbatch_combine_), (nstages_target_), (Q_in_reg_)}; \
|
|
36
|
+
} \
|
|
37
|
+
|
|
38
|
+
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) {
|
|
39
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 2, true);
|
|
40
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 2, true);
|
|
41
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 2, true);
|
|
42
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 2, true);
|
|
43
|
+
|
|
44
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 2, true);
|
|
45
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 2, true);
|
|
46
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 2, true);
|
|
47
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 2, true);
|
|
48
|
+
|
|
49
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 2, true);
|
|
50
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 2, true);
|
|
51
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 2, true);
|
|
52
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 48, 48, 2, true);
|
|
53
|
+
|
|
54
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 2, true);
|
|
55
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 2, true);
|
|
56
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 2, true);
|
|
57
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 2, true);
|
|
58
|
+
|
|
59
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 2, true);
|
|
60
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 2, true);
|
|
61
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true);
|
|
62
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true);
|
|
63
|
+
|
|
64
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true);
|
|
65
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true);
|
|
66
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
|
|
67
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
|
|
68
|
+
|
|
69
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
|
|
70
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
|
|
71
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
|
72
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
|
73
|
+
|
|
74
|
+
return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
|
|
75
|
+
}
|
|
120
76
|
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
77
|
+
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) {
|
|
78
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 128, 2, 64, 128, 128, 128, 2, true);
|
|
79
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true);
|
|
80
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
|
81
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
|
125
82
|
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
static constexpr bool Q_in_reg = true;
|
|
131
|
-
static constexpr int nstages_target = 2;
|
|
83
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
|
|
84
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
|
85
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
|
86
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
|
132
87
|
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
}
|
|
88
|
+
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
|
89
|
+
}
|
|
136
90
|
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
91
|
+
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
|
|
92
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
|
|
93
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
|
|
94
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
|
|
95
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
|
|
140
96
|
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
97
|
+
// TODO tune specifically for Volta
|
|
98
|
+
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
|
99
|
+
}
|
|
144
100
|
|
|
145
|
-
|
|
146
|
-
|
|
101
|
+
static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
|
|
102
|
+
if (ampere_mma_available(cc)) {
|
|
103
|
+
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
|
147
104
|
}
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
return 56;
|
|
105
|
+
if (turing_mma_available(cc)) {
|
|
106
|
+
return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
|
|
151
107
|
}
|
|
108
|
+
GGML_ASSERT(volta_mma_available(cc));
|
|
109
|
+
return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
|
|
110
|
+
}
|
|
152
111
|
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
112
|
+
static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) {
|
|
113
|
+
#if defined(AMPERE_MMA_AVAILABLE)
|
|
114
|
+
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
|
115
|
+
#elif defined(TURING_MMA_AVAILABLE)
|
|
116
|
+
return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
|
|
117
|
+
#elif defined(VOLTA_MMA_AVAILABLE)
|
|
118
|
+
return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
|
|
119
|
+
#else
|
|
120
|
+
GGML_UNUSED_VARS(DKQ, DV, ncols);
|
|
121
|
+
return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
|
|
122
|
+
#endif // defined(AMPERE_MMA_AVAILABLE)
|
|
123
|
+
}
|
|
157
124
|
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
static constexpr int nwarps_max = 4;
|
|
162
|
-
static constexpr bool Q_in_reg = true;
|
|
163
|
-
static constexpr int nstages_target = 2;
|
|
125
|
+
static __host__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {
|
|
126
|
+
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nthreads;
|
|
127
|
+
}
|
|
164
128
|
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
129
|
+
static constexpr __device__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols) {
|
|
130
|
+
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nthreads;
|
|
131
|
+
}
|
|
168
132
|
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
133
|
+
static __host__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {
|
|
134
|
+
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).occupancy;
|
|
135
|
+
}
|
|
172
136
|
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
137
|
+
static constexpr __device__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols) {
|
|
138
|
+
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).occupancy;
|
|
139
|
+
}
|
|
176
140
|
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
141
|
+
static __host__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {
|
|
142
|
+
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_fa;
|
|
143
|
+
}
|
|
180
144
|
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
145
|
+
static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {
|
|
146
|
+
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_fa;
|
|
147
|
+
}
|
|
184
148
|
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
};
|
|
149
|
+
static __host__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols, const int cc) {
|
|
150
|
+
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_K2;
|
|
151
|
+
}
|
|
189
152
|
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
static constexpr int nwarps_max = 4;
|
|
194
|
-
static constexpr bool Q_in_reg = true;
|
|
195
|
-
static constexpr int nstages_target = 2;
|
|
153
|
+
static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols) {
|
|
154
|
+
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_K2;
|
|
155
|
+
}
|
|
196
156
|
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
157
|
+
static __host__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols, const int cc) {
|
|
158
|
+
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_V2;
|
|
159
|
+
}
|
|
200
160
|
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
161
|
+
static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols) {
|
|
162
|
+
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_V2;
|
|
163
|
+
}
|
|
204
164
|
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
165
|
+
static __host__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols, const int cc) {
|
|
166
|
+
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_combine;
|
|
167
|
+
}
|
|
208
168
|
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
169
|
+
static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols) {
|
|
170
|
+
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_combine;
|
|
171
|
+
}
|
|
212
172
|
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
}
|
|
217
|
-
return 64;
|
|
218
|
-
}
|
|
173
|
+
static __host__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols, const int cc) {
|
|
174
|
+
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nstages_target;
|
|
175
|
+
}
|
|
219
176
|
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
#else
|
|
224
|
-
GGML_UNUSED(ncols);
|
|
225
|
-
return 128;
|
|
226
|
-
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
|
227
|
-
}
|
|
228
|
-
};
|
|
177
|
+
static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols) {
|
|
178
|
+
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nstages_target;
|
|
179
|
+
}
|
|
229
180
|
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
static constexpr int nwarps_max = 8;
|
|
234
|
-
static constexpr bool Q_in_reg = false;
|
|
235
|
-
static constexpr int nstages_target = 1;
|
|
181
|
+
static __host__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols, const int cc) {
|
|
182
|
+
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).Q_in_reg;
|
|
183
|
+
}
|
|
236
184
|
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
}
|
|
241
|
-
return ncols <= 16 ? 288 : 160;
|
|
242
|
-
}
|
|
185
|
+
static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols) {
|
|
186
|
+
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
|
|
187
|
+
}
|
|
243
188
|
|
|
244
|
-
|
|
245
|
-
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
|
246
|
-
return ncols <= 16 ? 96 : 160;
|
|
247
|
-
#else
|
|
248
|
-
return ncols <= 16 ? 288 : 160;
|
|
249
|
-
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
|
250
|
-
}
|
|
189
|
+
// ------------------------------------------------------------------------------------------------------------------
|
|
251
190
|
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
}
|
|
256
|
-
return ncols <= 16 ? 256 : 128;
|
|
257
|
-
}
|
|
191
|
+
static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) {
|
|
192
|
+
return cp_async_available(cc) && ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2, cc) : 0;
|
|
193
|
+
}
|
|
258
194
|
|
|
259
|
-
|
|
260
|
-
#
|
|
261
|
-
|
|
195
|
+
static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2) {
|
|
196
|
+
#ifdef CP_ASYNC_AVAILABLE
|
|
197
|
+
return ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2) : 0;
|
|
262
198
|
#else
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
|
268
|
-
return 128;
|
|
269
|
-
}
|
|
270
|
-
|
|
271
|
-
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
|
272
|
-
return 128;
|
|
273
|
-
}
|
|
274
|
-
};
|
|
199
|
+
GGML_UNUSED_VARS(DKQ, DV, ncols1, ncols2);
|
|
200
|
+
return 0;
|
|
201
|
+
#endif // CP_ASYNC_AVAILABLE
|
|
202
|
+
}
|
|
275
203
|
|
|
276
204
|
// ------------------------------------------------------------------------------------------------------------------
|
|
277
205
|
|
|
278
|
-
template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
|
|
206
|
+
template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
|
|
279
207
|
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
280
|
-
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
|
|
281
|
-
|
|
208
|
+
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
|
|
282
209
|
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
|
|
283
210
|
// The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
|
|
284
|
-
|
|
285
|
-
|
|
211
|
+
if constexpr (use_cp_async) {
|
|
212
|
+
static_assert(!oob_check, "OOB check not compatible with cp_async");
|
|
286
213
|
constexpr int preload = 64;
|
|
287
214
|
constexpr int h2_per_chunk = 16/sizeof(half2);
|
|
288
215
|
const int chunks_per_row = D2 / h2_per_chunk;
|
|
@@ -315,9 +242,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
|
315
242
|
}
|
|
316
243
|
}
|
|
317
244
|
};
|
|
318
|
-
|
|
245
|
+
// 1: max 32*16=512 bytes, 256 half
|
|
246
|
+
// 2: max 16*16=256 bytes, 128 half
|
|
247
|
+
// 3: max 8*16=128 bytes, 64 half
|
|
248
|
+
// 4: max 4*16= 64 bytes, 32 half
|
|
249
|
+
// 5: max 2*16= 32 bytes, 16 half
|
|
250
|
+
// 6: max 1*16= 16 bytes, 8 half
|
|
251
|
+
ggml_cuda_unroll<6>{}(load);
|
|
319
252
|
} else {
|
|
320
|
-
|
|
253
|
+
// TODO use ggml_cuda_memcpy_1
|
|
321
254
|
auto load = [&] __device__ (const int n) {
|
|
322
255
|
const int stride_k = WARP_SIZE >> n;
|
|
323
256
|
const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
|
|
@@ -340,20 +273,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
|
340
273
|
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
|
341
274
|
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
|
342
275
|
|
|
343
|
-
tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
|
|
276
|
+
tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
|
|
344
277
|
}
|
|
345
278
|
}
|
|
346
279
|
};
|
|
347
|
-
|
|
280
|
+
// 1: max 32* 4=128 bytes, 64 half
|
|
281
|
+
// 2: max 16* 4= 64 bytes, 32 half
|
|
282
|
+
// 3: max 8* 4= 32 bytes, 16 half
|
|
283
|
+
// 4: max 4* 4= 16 bytes, 8 half
|
|
284
|
+
ggml_cuda_unroll<4>{}(load);
|
|
348
285
|
}
|
|
349
286
|
}
|
|
350
287
|
|
|
351
|
-
template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async>
|
|
288
|
+
template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
|
|
352
289
|
static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
|
353
|
-
const
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
290
|
+
const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,
|
|
291
|
+
const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {
|
|
292
|
+
if constexpr (use_cp_async) {
|
|
293
|
+
static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa");
|
|
294
|
+
static_assert(!oob_check, "OOB check incompatible with cp_async");
|
|
357
295
|
constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
|
|
358
296
|
constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
|
|
359
297
|
constexpr int stride_j = nwarps * cols_per_warp;
|
|
@@ -361,78 +299,110 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
|
|
361
299
|
const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
|
|
362
300
|
|
|
363
301
|
#pragma unroll
|
|
364
|
-
for (int
|
|
365
|
-
const int
|
|
366
|
-
|
|
302
|
+
for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
|
|
303
|
+
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
|
|
304
|
+
const int j_vram = fastmodulo(j0 + j_sram, ne01);
|
|
367
305
|
|
|
368
|
-
if (
|
|
306
|
+
if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
|
|
369
307
|
break;
|
|
370
308
|
}
|
|
371
309
|
|
|
372
|
-
const int i =
|
|
310
|
+
const int i = 8 * (threadIdx.x % (nbatch_fa/8));
|
|
373
311
|
|
|
374
|
-
cp_async_cg_16<preload>(tile_mask_32 +
|
|
312
|
+
cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i);
|
|
375
313
|
}
|
|
376
|
-
|
|
377
|
-
|
|
314
|
+
} else if constexpr (oob_check) {
|
|
315
|
+
#pragma unroll
|
|
316
|
+
for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
|
|
317
|
+
const int j_sram = j1 + threadIdx.y;
|
|
318
|
+
const int j_vram = fastmodulo(j0 + j_sram, ne01);
|
|
319
|
+
|
|
320
|
+
if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
|
|
321
|
+
break;
|
|
322
|
+
}
|
|
378
323
|
|
|
379
|
-
constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
|
|
380
|
-
constexpr int stride_j = nwarps * cols_per_warp;
|
|
381
324
|
#pragma unroll
|
|
382
|
-
|
|
383
|
-
|
|
325
|
+
for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) {
|
|
326
|
+
const int i = i0 + threadIdx.x;
|
|
384
327
|
|
|
385
|
-
|
|
386
|
-
|
|
328
|
+
tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
|
|
329
|
+
}
|
|
387
330
|
}
|
|
331
|
+
} else if constexpr (nbatch_fa < 2*WARP_SIZE) {
|
|
332
|
+
constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
|
|
333
|
+
constexpr int stride_j = nwarps * cols_per_warp;
|
|
334
|
+
#pragma unroll
|
|
335
|
+
for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
|
|
336
|
+
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
|
|
337
|
+
const int j_vram = fastmodulo(j0 + j_sram, ne01);
|
|
388
338
|
|
|
389
|
-
|
|
339
|
+
if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
|
|
340
|
+
break;
|
|
341
|
+
}
|
|
342
|
+
|
|
343
|
+
const int i = threadIdx.x % (WARP_SIZE/cols_per_warp);
|
|
344
|
+
|
|
345
|
+
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
|
|
346
|
+
}
|
|
347
|
+
} else {
|
|
348
|
+
#pragma unroll
|
|
349
|
+
for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
|
|
350
|
+
const int j_sram = j1 + threadIdx.y;
|
|
351
|
+
const int j_vram = fastmodulo(j0 + j_sram, ne01);
|
|
390
352
|
|
|
391
|
-
|
|
353
|
+
if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
|
|
354
|
+
break;
|
|
355
|
+
}
|
|
356
|
+
|
|
357
|
+
#pragma unroll
|
|
358
|
+
for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) {
|
|
359
|
+
const int i = i0 + 2*threadIdx.x;
|
|
360
|
+
|
|
361
|
+
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
|
|
362
|
+
}
|
|
363
|
+
}
|
|
392
364
|
}
|
|
393
365
|
}
|
|
394
366
|
|
|
395
|
-
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps,
|
|
367
|
+
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps,
|
|
368
|
+
bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
|
|
369
|
+
typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ>
|
|
396
370
|
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
397
371
|
const float2 * const __restrict__ Q_f2,
|
|
398
372
|
const half2 * const __restrict__ K_h2,
|
|
399
373
|
const half2 * const __restrict__ V_h2,
|
|
400
|
-
const
|
|
374
|
+
const half * const __restrict__ mask_h,
|
|
401
375
|
float2 * const __restrict__ dstk,
|
|
402
376
|
float2 * const __restrict__ dstk_fixup,
|
|
403
377
|
const float scale,
|
|
404
378
|
const float slope,
|
|
405
379
|
const float logit_softcap,
|
|
406
|
-
const
|
|
380
|
+
const uint3 ne01,
|
|
407
381
|
const int ne02,
|
|
408
382
|
const int stride_K,
|
|
409
383
|
const int stride_V,
|
|
410
384
|
const int stride_mask,
|
|
411
|
-
const int jt,
|
|
412
385
|
half2 * const __restrict__ tile_Q,
|
|
413
386
|
half2 * const __restrict__ tile_K,
|
|
414
387
|
half2 * const __restrict__ tile_V,
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
388
|
+
half * const __restrict__ tile_mask,
|
|
389
|
+
T_B_KQ * const __restrict__ Q_B,
|
|
390
|
+
T_C_VKQ * const __restrict__ VKQ_C,
|
|
418
391
|
float * const __restrict__ KQ_max,
|
|
419
392
|
float * const __restrict__ KQ_rowsum,
|
|
420
|
-
const int
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
constexpr int
|
|
426
|
-
|
|
427
|
-
constexpr int
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
constexpr int
|
|
431
|
-
constexpr
|
|
432
|
-
constexpr int
|
|
433
|
-
constexpr int ncols = ncols1 * ncols2;
|
|
434
|
-
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
|
|
435
|
-
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
|
|
393
|
+
const int jt,
|
|
394
|
+
const int kb0,
|
|
395
|
+
const int k_VKQ_sup) {
|
|
396
|
+
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
397
|
+
constexpr int ncols = ncols1 * ncols2;
|
|
398
|
+
constexpr int cols_per_warp = T_B_KQ::I;
|
|
399
|
+
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
|
400
|
+
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
|
401
|
+
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
|
402
|
+
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
|
|
403
|
+
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
|
|
404
|
+
constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols);
|
|
405
|
+
constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2);
|
|
436
406
|
|
|
437
407
|
constexpr int stride_tile_Q = DKQ/2 + 4;
|
|
438
408
|
constexpr int stride_tile_K = nbatch_K2 + 4;
|
|
@@ -440,26 +410,27 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
440
410
|
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
|
|
441
411
|
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
|
442
412
|
|
|
443
|
-
const int k_VKQ_0 = kb0 *
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
|
|
413
|
+
const int k_VKQ_0 = kb0 * nbatch_fa;
|
|
414
|
+
#if defined(TURING_MMA_AVAILABLE)
|
|
415
|
+
T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
|
|
416
|
+
#else // Volta
|
|
417
|
+
T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
|
|
418
|
+
#endif // defined(TURING_MMA_AVAILABLE)
|
|
450
419
|
|
|
451
420
|
if constexpr (nstages > 1) {
|
|
421
|
+
static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
|
|
452
422
|
static_assert(!mla, "multi-stage loading not implemented for MLA");
|
|
453
423
|
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
|
|
454
424
|
constexpr bool use_cp_async = true;
|
|
455
425
|
cp_async_wait_all();
|
|
456
426
|
__syncthreads();
|
|
457
|
-
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps,
|
|
458
|
-
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
|
|
427
|
+
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
|
|
428
|
+
(V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V, k_VKQ_sup);
|
|
459
429
|
} else {
|
|
460
430
|
constexpr bool use_cp_async = nstages == 1;
|
|
461
|
-
if (ncols2 > 1 ||
|
|
462
|
-
flash_attn_ext_f16_load_mask<ncols1, nwarps,
|
|
431
|
+
if (ncols2 > 1 || mask_h) {
|
|
432
|
+
flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
|
|
433
|
+
(mask_h + k_VKQ_0, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
|
|
463
434
|
}
|
|
464
435
|
}
|
|
465
436
|
|
|
@@ -468,10 +439,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
468
439
|
const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
|
|
469
440
|
const int k0_diff = k0_stop - k0_start;
|
|
470
441
|
|
|
471
|
-
if (nstages <= 1) {
|
|
442
|
+
if constexpr (nstages <= 1) {
|
|
472
443
|
constexpr bool use_cp_async = nstages == 1;
|
|
473
|
-
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps,
|
|
474
|
-
(K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
|
|
444
|
+
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
|
|
445
|
+
(K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup);
|
|
475
446
|
if (use_cp_async) {
|
|
476
447
|
cp_async_wait_all();
|
|
477
448
|
}
|
|
@@ -479,55 +450,53 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
479
450
|
}
|
|
480
451
|
|
|
481
452
|
// Calculate tile of KQ:
|
|
482
|
-
if constexpr (
|
|
453
|
+
if constexpr (Q_in_reg) {
|
|
483
454
|
#pragma unroll
|
|
484
|
-
for (int i_KQ_00 = 0; i_KQ_00 <
|
|
485
|
-
const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*
|
|
455
|
+
for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) {
|
|
456
|
+
const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I;
|
|
486
457
|
#pragma unroll
|
|
487
|
-
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 +=
|
|
488
|
-
|
|
458
|
+
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
|
|
459
|
+
T_A_KQ K_A;
|
|
489
460
|
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
|
490
|
-
if (
|
|
491
|
-
mma(KQ_C[i_KQ_00/(np*
|
|
461
|
+
if constexpr (cols_per_warp == 8) {
|
|
462
|
+
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
|
|
492
463
|
} else {
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
// Wide version of KQ_C is column-major => swap A and B.
|
|
496
|
-
mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A);
|
|
497
|
-
}
|
|
464
|
+
// Wide version of KQ_C is column-major => swap A and B.
|
|
465
|
+
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
|
|
498
466
|
}
|
|
499
467
|
}
|
|
500
468
|
}
|
|
501
469
|
} else {
|
|
502
|
-
static_assert(
|
|
470
|
+
static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
|
|
503
471
|
#pragma unroll
|
|
504
|
-
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 +=
|
|
505
|
-
load_ldmatrix(
|
|
472
|
+
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
|
|
473
|
+
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
|
|
506
474
|
|
|
507
475
|
#pragma unroll
|
|
508
|
-
for (int i_KQ_00 = 0; i_KQ_00 <
|
|
509
|
-
const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*
|
|
476
|
+
for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) {
|
|
477
|
+
const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I;
|
|
510
478
|
|
|
511
|
-
|
|
479
|
+
T_A_KQ K_A;
|
|
512
480
|
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
|
513
481
|
|
|
514
482
|
// Wide version of KQ_C is column-major => swap A and B.
|
|
515
|
-
mma(
|
|
483
|
+
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
|
516
484
|
}
|
|
517
485
|
}
|
|
518
486
|
}
|
|
519
487
|
|
|
520
|
-
if (nstages <= 1) {
|
|
488
|
+
if constexpr (nstages <= 1) {
|
|
521
489
|
__syncthreads(); // Only needed if tile_K == tile_V.
|
|
522
490
|
}
|
|
523
491
|
}
|
|
524
492
|
|
|
525
493
|
if (use_logit_softcap) {
|
|
526
|
-
|
|
494
|
+
constexpr int stride = cols_per_warp == 8 ? np*T_C_KQ::I : np*T_C_KQ::J;
|
|
495
|
+
static_assert(nbatch_fa % stride == 0, "bad loop size");
|
|
527
496
|
#pragma unroll
|
|
528
|
-
for (int i = 0; i <
|
|
497
|
+
for (int i = 0; i < nbatch_fa/stride; ++i) {
|
|
529
498
|
#pragma unroll
|
|
530
|
-
for (int l = 0; l <
|
|
499
|
+
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
|
531
500
|
KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
|
|
532
501
|
}
|
|
533
502
|
}
|
|
@@ -540,34 +509,35 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
540
509
|
}
|
|
541
510
|
float KQ_rowsum_add[cols_per_thread] = {0.0f};
|
|
542
511
|
|
|
543
|
-
if (
|
|
544
|
-
if (ncols2 > 1 ||
|
|
512
|
+
if constexpr (cols_per_warp == 8) {
|
|
513
|
+
if (ncols2 > 1 || mask_h) {
|
|
545
514
|
#pragma unroll
|
|
546
|
-
for (int i00 = 0; i00 <
|
|
547
|
-
const int i0 = i00 + (threadIdx.y % np)*
|
|
515
|
+
for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::I) {
|
|
516
|
+
const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::I;
|
|
548
517
|
#pragma unroll
|
|
549
|
-
for (int l = 0; l <
|
|
550
|
-
const int i = i0 +
|
|
551
|
-
const int j = ((threadIdx.y / np)*
|
|
518
|
+
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
|
519
|
+
const int i = i0 + T_C_KQ::get_i(l);
|
|
520
|
+
const int j = ((threadIdx.y / np)*T_C_KQ::J + T_C_KQ::get_j(l)) / ncols2;
|
|
552
521
|
|
|
553
|
-
KQ_C[i00/(np*
|
|
554
|
-
__half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]);
|
|
522
|
+
KQ_C[i00/(np*T_C_KQ::I)].x[l] += slope * __half2float(tile_mask[j*(nbatch_fa + 8) + i]);
|
|
555
523
|
}
|
|
556
524
|
}
|
|
557
525
|
}
|
|
558
526
|
|
|
559
527
|
// Calculate softmax for each KQ column using the current max. value.
|
|
560
528
|
// The divisor is stored in KQ_rowsum and will be applied at the end.
|
|
561
|
-
static_assert(
|
|
529
|
+
static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size");
|
|
562
530
|
#pragma unroll
|
|
563
|
-
for (int
|
|
531
|
+
for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
|
|
564
532
|
#pragma unroll
|
|
565
|
-
for (int l = 0; l <
|
|
566
|
-
|
|
533
|
+
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
|
534
|
+
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
|
|
535
|
+
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
|
|
536
|
+
}
|
|
567
537
|
}
|
|
568
538
|
}
|
|
569
539
|
|
|
570
|
-
// Values per KQ column are spread across 8 threads
|
|
540
|
+
// Values per KQ column are spread across 8 threads:
|
|
571
541
|
#pragma unroll
|
|
572
542
|
for (int col = 0; col < cols_per_thread; ++col) {
|
|
573
543
|
#pragma unroll
|
|
@@ -576,73 +546,78 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
576
546
|
}
|
|
577
547
|
}
|
|
578
548
|
|
|
579
|
-
static_assert(
|
|
549
|
+
static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size");
|
|
580
550
|
#pragma unroll
|
|
581
|
-
for (int
|
|
551
|
+
for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
|
|
582
552
|
#pragma unroll
|
|
583
|
-
for (int l = 0; l <
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
553
|
+
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
|
554
|
+
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
|
|
555
|
+
KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]);
|
|
556
|
+
KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
|
|
557
|
+
} else {
|
|
558
|
+
KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f;
|
|
559
|
+
}
|
|
587
560
|
}
|
|
588
561
|
}
|
|
589
|
-
} else { //
|
|
590
|
-
if (ncols2 > 1 ||
|
|
591
|
-
#pragma unroll
|
|
592
|
-
for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) {
|
|
593
|
-
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J;
|
|
562
|
+
} else { // not Turing mma or T_B_KQ::I > 8
|
|
563
|
+
if (ncols2 > 1 || mask_h) {
|
|
594
564
|
#pragma unroll
|
|
595
|
-
|
|
565
|
+
for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) {
|
|
566
|
+
const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J;
|
|
596
567
|
#pragma unroll
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
568
|
+
for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) {
|
|
569
|
+
const int i = (i0 + T_C_KQ::get_j(l0)) / 2;
|
|
570
|
+
const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l0)) / ncols2;
|
|
600
571
|
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y;
|
|
605
|
-
}
|
|
572
|
+
const float2 tmp = __half22float2(((const half2 *)tile_mask)[j*(nbatch_fa/2 + 4) + i]);
|
|
573
|
+
KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x;
|
|
574
|
+
KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y;
|
|
606
575
|
}
|
|
607
576
|
}
|
|
608
577
|
}
|
|
609
578
|
|
|
610
579
|
// Calculate softmax for each KQ column using the current max. value.
|
|
611
580
|
// The divisor is stored in KQ_rowsum and will be applied at the end.
|
|
612
|
-
static_assert(
|
|
613
|
-
#pragma unroll
|
|
614
|
-
for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
|
|
581
|
+
static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size");
|
|
615
582
|
#pragma unroll
|
|
616
|
-
|
|
583
|
+
for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
|
|
617
584
|
#pragma unroll
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
585
|
+
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
|
586
|
+
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
|
|
587
|
+
// Turing + Volta:
|
|
588
|
+
KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
|
|
621
589
|
}
|
|
622
590
|
}
|
|
623
591
|
}
|
|
624
592
|
|
|
625
|
-
// Values per KQ column are spread across 4 threads, does not need full warp reduce:
|
|
626
593
|
#pragma unroll
|
|
627
594
|
for (int col = 0; col < cols_per_thread; ++col) {
|
|
595
|
+
#if defined(TURING_MMA_AVAILABLE)
|
|
596
|
+
// Values per KQ column are spread across 4 threads:
|
|
597
|
+
constexpr int offset_first = 2;
|
|
598
|
+
constexpr int offset_last = 1;
|
|
599
|
+
#else
|
|
600
|
+
// Values per KQ column are spread across 2 threads:
|
|
601
|
+
constexpr int offset_first = 2;
|
|
602
|
+
constexpr int offset_last = 2;
|
|
603
|
+
#endif // defined(TURING_MMA_AVAILABLE)
|
|
628
604
|
#pragma unroll
|
|
629
|
-
for (int offset =
|
|
605
|
+
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
|
|
630
606
|
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
|
|
631
607
|
}
|
|
632
608
|
}
|
|
633
609
|
|
|
634
|
-
static_assert(
|
|
610
|
+
static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size");
|
|
635
611
|
#pragma unroll
|
|
636
|
-
for (int
|
|
612
|
+
for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
|
|
637
613
|
#pragma unroll
|
|
638
|
-
for (int
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l];
|
|
614
|
+
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
|
615
|
+
// Turing + Volta:
|
|
616
|
+
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
|
|
617
|
+
KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]);
|
|
618
|
+
KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
|
|
619
|
+
} else {
|
|
620
|
+
KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f;
|
|
646
621
|
}
|
|
647
622
|
}
|
|
648
623
|
}
|
|
@@ -662,12 +637,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
662
637
|
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
|
|
663
638
|
}
|
|
664
639
|
|
|
665
|
-
|
|
640
|
+
#if defined(TURING_MMA_AVAILABLE)
|
|
641
|
+
if constexpr (cols_per_warp == 8) {
|
|
666
642
|
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
|
|
667
643
|
#pragma unroll
|
|
668
|
-
for (int i = 0; i < DV/
|
|
644
|
+
for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
|
|
669
645
|
#pragma unroll
|
|
670
|
-
for (int l = 0; l <
|
|
646
|
+
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
|
671
647
|
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
|
672
648
|
}
|
|
673
649
|
}
|
|
@@ -676,46 +652,53 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
676
652
|
for (int col = 0; col < cols_per_thread; ++col) {
|
|
677
653
|
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
|
|
678
654
|
#pragma unroll
|
|
679
|
-
for (int i = 0; i < DV/
|
|
655
|
+
for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
|
|
680
656
|
#pragma unroll
|
|
681
|
-
for (int l0 = 0; l0 <
|
|
682
|
-
|
|
657
|
+
for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) {
|
|
658
|
+
VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2;
|
|
683
659
|
}
|
|
684
660
|
}
|
|
685
661
|
}
|
|
686
662
|
}
|
|
663
|
+
#else // Volta
|
|
664
|
+
const half2 KQ_max_scale_h2 = make_half2(
|
|
665
|
+
KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]);
|
|
666
|
+
#pragma unroll
|
|
667
|
+
for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
|
|
668
|
+
#pragma unroll
|
|
669
|
+
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
|
670
|
+
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
|
671
|
+
}
|
|
672
|
+
}
|
|
673
|
+
#endif // defined(TURING_MMA_AVAILABLE)
|
|
687
674
|
}
|
|
688
675
|
|
|
689
676
|
// Convert KQ C tiles into B tiles for VKQ calculation:
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
if (ntiles == 1) {
|
|
677
|
+
T_B_VKQ B[nbatch_fa/(np*2*T_B_VKQ::J)];
|
|
678
|
+
static_assert(nbatch_fa % (np*2*T_B_VKQ::J) == 0, "bad loop size");
|
|
679
|
+
if constexpr (cols_per_warp == 8) {
|
|
694
680
|
#pragma unroll
|
|
695
|
-
for (int k = 0; k <
|
|
681
|
+
for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) {
|
|
696
682
|
B[k] = get_transposed(get_half2(KQ_C[k]));
|
|
697
683
|
}
|
|
698
684
|
} else {
|
|
699
|
-
for (int k = 0; k <
|
|
700
|
-
|
|
701
|
-
for (int t = 0; t < ntiles/2; ++t) {
|
|
702
|
-
B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]);
|
|
703
|
-
}
|
|
685
|
+
for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) {
|
|
686
|
+
B[k] = get_half2(KQ_C[k]);
|
|
704
687
|
}
|
|
705
688
|
}
|
|
706
689
|
|
|
707
|
-
if (nstages > 1) {
|
|
690
|
+
if constexpr (nstages > 1) {
|
|
708
691
|
// Preload K tile for next iteration:
|
|
709
692
|
constexpr bool use_cp_async = true;
|
|
710
693
|
cp_async_wait_all();
|
|
711
694
|
__syncthreads();
|
|
712
695
|
if (!last_iter) {
|
|
713
|
-
if (ncols2 > 1 ||
|
|
714
|
-
flash_attn_ext_f16_load_mask<ncols1, nwarps,
|
|
715
|
-
(
|
|
696
|
+
if (ncols2 > 1 || mask_h) {
|
|
697
|
+
flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
|
|
698
|
+
(mask_h + k_VKQ_0 + nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
|
|
716
699
|
}
|
|
717
|
-
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps,
|
|
718
|
-
(K_h2 + (k_VKQ_0 +
|
|
700
|
+
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
|
|
701
|
+
(K_h2 + int64_t(k_VKQ_0 + nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
|
|
719
702
|
}
|
|
720
703
|
}
|
|
721
704
|
|
|
@@ -724,75 +707,119 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
724
707
|
// Therefore, iterate over V in reverse and re-use the data if possible.
|
|
725
708
|
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
|
|
726
709
|
constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
|
|
710
|
+
|
|
711
|
+
// Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
|
|
727
712
|
#pragma unroll
|
|
728
713
|
for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
|
|
729
714
|
const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
|
|
730
715
|
const int i0_diff = i0_stop - i0_start;
|
|
731
716
|
|
|
732
|
-
if (nstages <= 1
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
717
|
+
if constexpr (nstages <= 1) {
|
|
718
|
+
if (i0_start < reusable_cutoff) {
|
|
719
|
+
constexpr bool use_cp_async = nstages == 1;
|
|
720
|
+
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
|
|
721
|
+
(V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
|
|
722
|
+
if (use_cp_async) {
|
|
723
|
+
cp_async_wait_all();
|
|
724
|
+
}
|
|
725
|
+
__syncthreads();
|
|
738
726
|
}
|
|
739
|
-
__syncthreads();
|
|
740
727
|
}
|
|
741
728
|
const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
|
|
742
729
|
|
|
743
|
-
|
|
730
|
+
#if defined(TURING_MMA_AVAILABLE)
|
|
731
|
+
constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
|
|
744
732
|
#pragma unroll
|
|
745
|
-
for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 +=
|
|
746
|
-
static_assert((
|
|
733
|
+
for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
|
|
734
|
+
static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size");
|
|
747
735
|
#pragma unroll
|
|
748
|
-
for (int k00 = 0; k00 <
|
|
749
|
-
const int k0 = k00 + (threadIdx.y % np)*
|
|
736
|
+
for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) {
|
|
737
|
+
const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;
|
|
750
738
|
|
|
751
|
-
|
|
739
|
+
T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
|
|
752
740
|
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
|
753
|
-
if (
|
|
754
|
-
mma(VKQ_C[i_VKQ_0/
|
|
741
|
+
if constexpr (T_B_KQ::I == 8) {
|
|
742
|
+
mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
|
|
755
743
|
} else {
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
// Wide version of VKQ_C is column-major => swap A and B.
|
|
759
|
-
mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A);
|
|
760
|
-
}
|
|
744
|
+
// Wide version of VKQ_C is column-major => swap A and B.
|
|
745
|
+
mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
|
|
761
746
|
}
|
|
762
747
|
}
|
|
763
748
|
}
|
|
749
|
+
#else // Volta
|
|
750
|
+
constexpr int i0_stride = 2*T_C_VKQ::J;
|
|
751
|
+
#pragma unroll
|
|
752
|
+
for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
|
|
753
|
+
static_assert(nbatch_fa % (np*T_A_VKQ::I) == 0, "bad loop size");
|
|
754
|
+
static_assert(2*T_B_VKQ::J == T_A_VKQ::I, "bad tile sizes");
|
|
755
|
+
#pragma unroll
|
|
756
|
+
for (int k00 = 0; k00 < nbatch_fa; k00 += np*T_A_VKQ::I) {
|
|
757
|
+
const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::I;
|
|
758
|
+
|
|
759
|
+
T_A_VKQ A; // Transposed in both SRAM and registers, load normally.
|
|
760
|
+
load_ldmatrix(A, tile_V_i + k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
|
761
|
+
mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
|
|
762
|
+
}
|
|
763
|
+
}
|
|
764
|
+
#endif // defined(TURING_MMA_AVAILABLE)
|
|
764
765
|
|
|
765
|
-
if (nstages <= 1) {
|
|
766
|
+
if constexpr (nstages <= 1) {
|
|
766
767
|
__syncthreads(); // Only needed if tile_K == tile_V.
|
|
767
768
|
}
|
|
768
769
|
}
|
|
769
770
|
#else
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
|
|
776
|
-
GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
|
|
777
|
-
GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
|
|
778
|
-
GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
|
|
771
|
+
GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup,
|
|
772
|
+
scale, slope, logit_softcap, ne01, ne02,
|
|
773
|
+
stride_K, stride_V, stride_mask,
|
|
774
|
+
tile_Q, tile_K, tile_V, tile_mask,
|
|
775
|
+
Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
|
779
776
|
NO_DEVICE_CODE;
|
|
780
|
-
#endif //
|
|
777
|
+
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
781
778
|
}
|
|
782
779
|
|
|
783
|
-
|
|
780
|
+
#if defined(TURING_MMA_AVAILABLE)
|
|
781
|
+
template<int ncols> struct mma_tile_sizes {
|
|
782
|
+
using T_A_KQ = tile<16, 8, half2>; // row-major
|
|
783
|
+
using T_B_KQ = tile<16, 8, half2>; // column-major
|
|
784
|
+
using T_C_KQ = tile<16, 16, float>; // column-major
|
|
785
|
+
using T_A_VKQ = tile<16, 8, half2>; // row-major
|
|
786
|
+
using T_B_VKQ = tile<16, 8, half2>; // column-major
|
|
787
|
+
using T_C_VKQ = tile<16, 8, half2>; // column-major
|
|
788
|
+
};
|
|
789
|
+
template<> struct mma_tile_sizes<8> {
|
|
790
|
+
using T_A_KQ = tile<16, 8, half2>; // row-major
|
|
791
|
+
using T_B_KQ = tile< 8, 8, half2>; // column-major
|
|
792
|
+
using T_C_KQ = tile<16, 8, float>; // row-major
|
|
793
|
+
using T_A_VKQ = tile<16, 8, half2>; // row-major
|
|
794
|
+
using T_B_VKQ = tile< 8, 8, half2>; // column-major
|
|
795
|
+
using T_C_VKQ = tile<16, 4, half2>; // row-major
|
|
796
|
+
};
|
|
797
|
+
#else // Volta
|
|
798
|
+
template<int ncols> struct mma_tile_sizes {
|
|
799
|
+
using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
|
|
800
|
+
using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
|
|
801
|
+
using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major
|
|
802
|
+
using T_A_VKQ = tile< 8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED>; // column-major
|
|
803
|
+
using T_B_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
|
|
804
|
+
using T_C_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
|
|
805
|
+
};
|
|
806
|
+
#endif // defined(TURING_MMA_AVAILABLE)
|
|
807
|
+
|
|
808
|
+
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
|
|
784
809
|
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
785
810
|
const float2 * const __restrict__ Q_f2,
|
|
786
811
|
const half2 * const __restrict__ K_h2,
|
|
787
812
|
const half2 * const __restrict__ V_h2,
|
|
788
|
-
const
|
|
813
|
+
const half * const __restrict__ mask_h,
|
|
814
|
+
const float * const __restrict__ sinks_f,
|
|
789
815
|
float2 * const __restrict__ dstk,
|
|
790
816
|
float2 * const __restrict__ dstk_fixup,
|
|
791
817
|
const float scale,
|
|
792
818
|
const float slope,
|
|
793
819
|
const float logit_softcap,
|
|
794
|
-
const
|
|
820
|
+
const uint3 ne01,
|
|
795
821
|
const int ne02,
|
|
822
|
+
const int ne11,
|
|
796
823
|
const int stride_Q1,
|
|
797
824
|
const int stride_Q2,
|
|
798
825
|
const int stride_K,
|
|
@@ -801,23 +828,31 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
801
828
|
const int jt,
|
|
802
829
|
const int kb0_start,
|
|
803
830
|
const int kb0_stop) {
|
|
804
|
-
#
|
|
831
|
+
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
805
832
|
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
|
806
833
|
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
constexpr int
|
|
816
|
-
constexpr int
|
|
817
|
-
constexpr int
|
|
818
|
-
constexpr int
|
|
819
|
-
constexpr int
|
|
820
|
-
constexpr int
|
|
834
|
+
constexpr int ncols = ncols1 * ncols2;
|
|
835
|
+
using T_A_KQ = typename mma_tile_sizes<ncols>::T_A_KQ;
|
|
836
|
+
using T_B_KQ = typename mma_tile_sizes<ncols>::T_B_KQ;
|
|
837
|
+
using T_C_KQ = typename mma_tile_sizes<ncols>::T_C_KQ;
|
|
838
|
+
using T_A_VKQ = typename mma_tile_sizes<ncols>::T_A_VKQ;
|
|
839
|
+
using T_B_VKQ = typename mma_tile_sizes<ncols>::T_B_VKQ;
|
|
840
|
+
using T_C_VKQ = typename mma_tile_sizes<ncols>::T_C_VKQ;
|
|
841
|
+
|
|
842
|
+
constexpr int cols_per_warp = T_B_KQ::I;
|
|
843
|
+
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
|
844
|
+
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
|
845
|
+
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
|
|
846
|
+
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
|
|
847
|
+
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
|
|
848
|
+
constexpr int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols);
|
|
849
|
+
constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols);
|
|
850
|
+
constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2);
|
|
851
|
+
|
|
852
|
+
if (cols_per_warp > ncols) {
|
|
853
|
+
NO_DEVICE_CODE;
|
|
854
|
+
return;
|
|
855
|
+
}
|
|
821
856
|
|
|
822
857
|
static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
|
|
823
858
|
|
|
@@ -829,15 +864,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
829
864
|
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
|
|
830
865
|
|
|
831
866
|
extern __shared__ half2 tile_Q[];
|
|
832
|
-
half2 * tile_K =
|
|
833
|
-
half2 * tile_V =
|
|
834
|
-
|
|
867
|
+
half2 * tile_K = Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q;
|
|
868
|
+
half2 * tile_V = nstages > 1 ? tile_K + nbatch_fa * stride_tile_K : tile_K;
|
|
869
|
+
half * tile_mask = (half *) (nstages > 1 ? tile_V + nbatch_fa * stride_tile_V : tile_V + nbatch_fa * stride_tile_KV_max);
|
|
835
870
|
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
871
|
+
T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
|
|
872
|
+
#if defined(TURING_MMA_AVAILABLE)
|
|
873
|
+
T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
|
|
874
|
+
#else // Volta
|
|
875
|
+
T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
|
|
876
|
+
#endif // defined(TURING_MMA_AVAILABLE)
|
|
841
877
|
|
|
842
878
|
float KQ_rowsum[cols_per_thread] = {0.0f};
|
|
843
879
|
float KQ_max[cols_per_thread];
|
|
@@ -871,7 +907,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
871
907
|
const int j = jc / ncols2;
|
|
872
908
|
const int c = jc % ncols2;
|
|
873
909
|
|
|
874
|
-
if (jt*ncols1 + j < ne01) {
|
|
910
|
+
if (jt*ncols1 + j < int(ne01.z)) {
|
|
875
911
|
#pragma unroll
|
|
876
912
|
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
|
877
913
|
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
|
@@ -892,62 +928,93 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
892
928
|
|
|
893
929
|
__syncthreads();
|
|
894
930
|
|
|
895
|
-
if (
|
|
931
|
+
if (Q_in_reg) {
|
|
896
932
|
const int j0 = (threadIdx.y / np) * cols_per_warp;
|
|
897
933
|
|
|
898
934
|
#pragma unroll
|
|
899
|
-
for (int k0 = 0; k0 < DKQ/2; k0 +=
|
|
900
|
-
|
|
901
|
-
load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
|
|
902
|
-
} else {
|
|
903
|
-
#pragma unroll
|
|
904
|
-
for (int t = 0; t < ntiles/2; ++t) {
|
|
905
|
-
load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t],
|
|
906
|
-
tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q);
|
|
907
|
-
}
|
|
908
|
-
}
|
|
935
|
+
for (int k0 = 0; k0 < DKQ/2; k0 += T_B_KQ::J) {
|
|
936
|
+
load_ldmatrix(Q_B[k0/T_B_KQ::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
|
|
909
937
|
}
|
|
910
938
|
}
|
|
911
939
|
|
|
912
940
|
__syncthreads();
|
|
913
941
|
|
|
942
|
+
int kb0 = kb0_start;
|
|
943
|
+
|
|
914
944
|
// Preload mask and K data for first iteration when using cp_async with multiple stages:
|
|
915
945
|
if constexpr (nstages > 1) {
|
|
916
946
|
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
|
|
917
947
|
constexpr bool use_cp_async = true;
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
948
|
+
constexpr bool oob_check = false;
|
|
949
|
+
constexpr int k_VKQ_sup = nbatch_fa;
|
|
950
|
+
if (ncols2 > 1 || mask_h) {
|
|
951
|
+
flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
|
|
952
|
+
(mask_h + kb0*nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
|
|
953
|
+
}
|
|
954
|
+
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
|
|
955
|
+
(K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
|
|
956
|
+
}
|
|
957
|
+
|
|
958
|
+
// kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
|
959
|
+
if constexpr (ncols2 == 1) {
|
|
960
|
+
constexpr bool oob_check = true;
|
|
961
|
+
for (; kb0 < kb0_stop-1; ++kb0) {
|
|
962
|
+
constexpr bool last_iter = false;
|
|
963
|
+
constexpr int k_VKQ_sup = nbatch_fa;
|
|
964
|
+
flash_attn_ext_f16_iter
|
|
965
|
+
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
|
966
|
+
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
|
967
|
+
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
968
|
+
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
|
969
|
+
KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
|
|
921
970
|
}
|
|
922
|
-
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
|
923
|
-
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
|
|
924
|
-
}
|
|
925
|
-
|
|
926
|
-
// Iterate over ne11 == previous tokens:
|
|
927
|
-
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
|
|
928
|
-
constexpr bool last_iter = false;
|
|
929
|
-
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
|
930
|
-
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
931
|
-
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
|
932
|
-
}
|
|
933
|
-
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
|
934
971
|
constexpr bool last_iter = true;
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
972
|
+
const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
|
|
973
|
+
flash_attn_ext_f16_iter
|
|
974
|
+
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
|
975
|
+
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
|
976
|
+
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
977
|
+
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
|
978
|
+
KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
|
|
979
|
+
} else {
|
|
980
|
+
constexpr bool oob_check = false;
|
|
981
|
+
for (; kb0 < kb0_stop-1; ++kb0) {
|
|
982
|
+
constexpr bool last_iter = false;
|
|
983
|
+
constexpr int k_VKQ_sup = nbatch_fa;
|
|
984
|
+
flash_attn_ext_f16_iter
|
|
985
|
+
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
|
986
|
+
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
|
987
|
+
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
988
|
+
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
|
989
|
+
KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
|
|
990
|
+
}
|
|
991
|
+
constexpr bool last_iter = true;
|
|
992
|
+
constexpr int k_VKQ_sup = nbatch_fa;
|
|
993
|
+
flash_attn_ext_f16_iter
|
|
994
|
+
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
|
995
|
+
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
|
996
|
+
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
997
|
+
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
|
998
|
+
KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
|
|
938
999
|
}
|
|
939
1000
|
|
|
940
1001
|
// With multi-stage loading there is no __syncthreads at the end of the iter,
|
|
941
1002
|
// there can be a race condition on shared memory access for combining/writing back results.
|
|
942
|
-
if (nstages > 1 && nwarps*cols_per_warp >
|
|
1003
|
+
if constexpr (nstages > 1 && nwarps*cols_per_warp > nbatch_fa) {
|
|
943
1004
|
__syncthreads();
|
|
944
1005
|
}
|
|
945
1006
|
|
|
946
1007
|
// Finally, sum up partial KQ rowsums.
|
|
947
|
-
// The partial sums are spread across 8/4 threads each, does not need full reduce.
|
|
948
1008
|
{
|
|
949
|
-
|
|
950
|
-
|
|
1009
|
+
#if defined(TURING_MMA_AVAILABLE)
|
|
1010
|
+
// The partial sums are spread across 8/4 threads.
|
|
1011
|
+
constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
|
|
1012
|
+
constexpr int offset_last = cols_per_warp == 8 ? 4 : 1;
|
|
1013
|
+
#else // Volta
|
|
1014
|
+
// The partial sums are spread across 2 threads.
|
|
1015
|
+
constexpr int offset_first = 2;
|
|
1016
|
+
constexpr int offset_last = 2;
|
|
1017
|
+
#endif // defined(TURING_MMA_AVAILABLE)
|
|
951
1018
|
#pragma unroll
|
|
952
1019
|
for (int col = 0; col < cols_per_thread; ++col) {
|
|
953
1020
|
#pragma unroll
|
|
@@ -957,20 +1024,76 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
957
1024
|
}
|
|
958
1025
|
}
|
|
959
1026
|
|
|
1027
|
+
// If attention sinks are used, potentially re-scale if KQ_max is small.
|
|
1028
|
+
// Also add the sink as a value to KQ_rowsum; this is done after synchronization of KQ_rowsum
|
|
1029
|
+
// so it's being done unconditionally for every thread.
|
|
1030
|
+
if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
|
|
1031
|
+
float KQ_max_scale[cols_per_thread];
|
|
1032
|
+
#pragma unroll
|
|
1033
|
+
for (int col = 0; col < cols_per_thread; ++col) {
|
|
1034
|
+
const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col);
|
|
1035
|
+
const float sink = sinks_f[jc % ncols2];
|
|
1036
|
+
|
|
1037
|
+
const float KQ_max_new = fmaxf(KQ_max[col], sink);
|
|
1038
|
+
const float KQ_max_diff = KQ_max[col] - KQ_max_new;
|
|
1039
|
+
KQ_max_scale[col] = expf(KQ_max_diff);
|
|
1040
|
+
KQ_max[col] = KQ_max_new;
|
|
1041
|
+
|
|
1042
|
+
*((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
|
|
1043
|
+
|
|
1044
|
+
const float KQ_max_add = expf(sink - KQ_max_new);
|
|
1045
|
+
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
|
|
1046
|
+
}
|
|
1047
|
+
|
|
1048
|
+
#if defined(TURING_MMA_AVAILABLE)
|
|
1049
|
+
if constexpr (cols_per_warp == 8) {
|
|
1050
|
+
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
|
|
1051
|
+
#pragma unroll
|
|
1052
|
+
for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
|
|
1053
|
+
#pragma unroll
|
|
1054
|
+
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
|
1055
|
+
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
|
1056
|
+
}
|
|
1057
|
+
}
|
|
1058
|
+
} else {
|
|
1059
|
+
#pragma unroll
|
|
1060
|
+
for (int col = 0; col < cols_per_thread; ++col) {
|
|
1061
|
+
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
|
|
1062
|
+
#pragma unroll
|
|
1063
|
+
for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
|
|
1064
|
+
#pragma unroll
|
|
1065
|
+
for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) {
|
|
1066
|
+
VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2;
|
|
1067
|
+
}
|
|
1068
|
+
}
|
|
1069
|
+
}
|
|
1070
|
+
}
|
|
1071
|
+
#else // Volta
|
|
1072
|
+
const int col = (threadIdx.x / 2) % 2;
|
|
1073
|
+
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
|
|
1074
|
+
#pragma unroll
|
|
1075
|
+
for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
|
|
1076
|
+
#pragma unroll
|
|
1077
|
+
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
|
1078
|
+
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
|
1079
|
+
}
|
|
1080
|
+
}
|
|
1081
|
+
#endif // defined(TURING_MMA_AVAILABLE)
|
|
1082
|
+
}
|
|
1083
|
+
|
|
960
1084
|
// Combine VKQ accumulator values if np > 1.
|
|
961
1085
|
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
|
962
1086
|
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
|
|
963
1087
|
|
|
964
|
-
constexpr int
|
|
965
|
-
constexpr int tile_stride = nbatch_combine + 4;
|
|
1088
|
+
constexpr int tile_stride = nbatch_combine + 4;
|
|
966
1089
|
static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
|
|
967
1090
|
|
|
968
|
-
if constexpr (
|
|
969
|
-
const int jc_cwmo = (threadIdx.x % (2*
|
|
970
|
-
const int jc_cwm = threadIdx.y*(2*
|
|
1091
|
+
if constexpr (cols_per_warp == 8) {
|
|
1092
|
+
const int jc_cwmo = (threadIdx.x % (2*T_C_VKQ::J)) / T_C_VKQ::J; // jc combine write meta offset
|
|
1093
|
+
const int jc_cwm = threadIdx.y*(2*T_C_VKQ::J) + 2*T_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta
|
|
971
1094
|
const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum
|
|
972
1095
|
|
|
973
|
-
if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*
|
|
1096
|
+
if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*T_C_VKQ::J) {
|
|
974
1097
|
// Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
|
|
975
1098
|
((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
|
|
976
1099
|
}
|
|
@@ -979,24 +1102,30 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
979
1102
|
|
|
980
1103
|
if (np == 1) {
|
|
981
1104
|
// No combination is needed, the meta data can be directly written from registers to VRAM.
|
|
982
|
-
if (needs_fixup && threadIdx.x <
|
|
1105
|
+
if (needs_fixup && threadIdx.x < T_B_KQ::I) {
|
|
983
1106
|
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
|
|
984
1107
|
dstk_fixup_meta[jc_cwm] = KQ_cmr;
|
|
985
1108
|
}
|
|
986
|
-
if (is_fixup && threadIdx.x <
|
|
1109
|
+
if (is_fixup && threadIdx.x < T_B_KQ::I) {
|
|
987
1110
|
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
|
|
988
1111
|
dstk_fixup_meta[jc_cwm] = KQ_cmr;
|
|
989
1112
|
}
|
|
990
1113
|
}
|
|
991
1114
|
} else {
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
const
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1115
|
+
// jc_cwm = jc combine write meta
|
|
1116
|
+
// KQ_cmr = KQ combine max rowsum
|
|
1117
|
+
// Use the 16 bytes of padding in each Q column to store the meta data: KQ max, KQ rowsum, KQ max scale.
|
|
1118
|
+
#if defined(TURING_MMA_AVAILABLE)
|
|
1119
|
+
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
|
|
1120
|
+
const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
|
|
1121
|
+
const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
|
|
1122
|
+
#else // Volta
|
|
1123
|
+
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2);
|
|
1124
|
+
const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]);
|
|
1125
|
+
const bool thread_should_write = T_C_KQ::J == 8 || T_C_KQ::get_j(threadIdx.x & 2) < 8;
|
|
1126
|
+
#endif // defined(TURING_MMA_AVAILABLE)
|
|
1127
|
+
|
|
1128
|
+
if (((!needs_fixup && !is_fixup) || np > 1) && thread_should_write) {
|
|
1000
1129
|
((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
|
|
1001
1130
|
}
|
|
1002
1131
|
|
|
@@ -1004,18 +1133,17 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1004
1133
|
|
|
1005
1134
|
if (np == 1) {
|
|
1006
1135
|
// No combination is needed, the meta data can be directly written from registers to VRAM.
|
|
1007
|
-
if (needs_fixup &&
|
|
1136
|
+
if (needs_fixup && thread_should_write) {
|
|
1008
1137
|
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
|
|
1009
1138
|
dstk_fixup_meta[jc_cwm] = KQ_cmr;
|
|
1010
1139
|
}
|
|
1011
|
-
if (is_fixup &&
|
|
1140
|
+
if (is_fixup && thread_should_write) {
|
|
1012
1141
|
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
|
|
1013
1142
|
dstk_fixup_meta[jc_cwm] = KQ_cmr;
|
|
1014
1143
|
}
|
|
1015
1144
|
}
|
|
1016
1145
|
}
|
|
1017
1146
|
|
|
1018
|
-
static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles");
|
|
1019
1147
|
if (np > 1 && threadIdx.y % np == 0) {
|
|
1020
1148
|
// Combine the meta data for parallel warps via shared memory.
|
|
1021
1149
|
// Warps with threadIdx.y % np != 0 must NOT return early.
|
|
@@ -1091,32 +1219,29 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1091
1219
|
|
|
1092
1220
|
#pragma unroll
|
|
1093
1221
|
for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
|
|
1094
|
-
if (
|
|
1095
|
-
const int jc_cwd = threadIdx.y*
|
|
1222
|
+
if constexpr (cols_per_warp == 8) {
|
|
1223
|
+
const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data
|
|
1096
1224
|
#pragma unroll
|
|
1097
|
-
for (int
|
|
1098
|
-
const
|
|
1225
|
+
for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) {
|
|
1226
|
+
const T_B_KQ B = get_transposed(VKQ_C[(k00 + k1)/T_B_KQ::J]); // Conversion of C to B matrix puts it in column-major format.
|
|
1099
1227
|
|
|
1100
1228
|
#pragma unroll
|
|
1101
|
-
for (int l = 0; l <
|
|
1102
|
-
const int k =
|
|
1229
|
+
for (int l = 0; l < T_B_KQ::ne; ++l) {
|
|
1230
|
+
const int k = k1 + T_B_KQ::get_j(l);
|
|
1103
1231
|
|
|
1104
1232
|
tile_Q[jc_cwd*tile_stride + k] = B.x[l];
|
|
1105
1233
|
}
|
|
1106
1234
|
}
|
|
1107
1235
|
} else {
|
|
1236
|
+
const int j0 = threadIdx.y*cols_per_warp;
|
|
1108
1237
|
#pragma unroll
|
|
1109
|
-
for (int
|
|
1110
|
-
const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I;
|
|
1111
|
-
#pragma unroll
|
|
1112
|
-
for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) {
|
|
1238
|
+
for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) {
|
|
1113
1239
|
#pragma unroll
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
|
|
1240
|
+
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
|
1241
|
+
const int j = j0 + T_C_VKQ::get_i(l);
|
|
1242
|
+
const int k = k1 + T_C_VKQ::get_j(l);
|
|
1117
1243
|
|
|
1118
|
-
|
|
1119
|
-
}
|
|
1244
|
+
tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l];
|
|
1120
1245
|
}
|
|
1121
1246
|
}
|
|
1122
1247
|
}
|
|
@@ -1151,7 +1276,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1151
1276
|
const int j_dst = jc_dst / ncols2;
|
|
1152
1277
|
const int c_dst = jc_dst % ncols2;
|
|
1153
1278
|
|
|
1154
|
-
if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
|
|
1279
|
+
if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) {
|
|
1155
1280
|
continue;
|
|
1156
1281
|
}
|
|
1157
1282
|
|
|
@@ -1189,23 +1314,23 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1189
1314
|
}
|
|
1190
1315
|
}
|
|
1191
1316
|
#else
|
|
1192
|
-
|
|
1193
|
-
|
|
1194
|
-
|
|
1195
|
-
|
|
1196
|
-
GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask);
|
|
1197
|
-
GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
|
|
1317
|
+
GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
|
|
1318
|
+
scale, slope, logit_softcap, ne01, ne02,
|
|
1319
|
+
stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
|
|
1320
|
+
jt, kb0_start, kb0_stop);
|
|
1198
1321
|
NO_DEVICE_CODE;
|
|
1199
|
-
#endif //
|
|
1322
|
+
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
1200
1323
|
}
|
|
1201
1324
|
|
|
1202
|
-
template<int DKQ, int DV, int ncols1, int ncols2,
|
|
1203
|
-
__launch_bounds__(
|
|
1325
|
+
template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool mla>
|
|
1326
|
+
__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))
|
|
1204
1327
|
static __global__ void flash_attn_ext_f16(
|
|
1205
1328
|
const char * __restrict__ Q,
|
|
1206
1329
|
const char * __restrict__ K,
|
|
1207
1330
|
const char * __restrict__ V,
|
|
1208
1331
|
const char * __restrict__ mask,
|
|
1332
|
+
const char * __restrict__ sinks,
|
|
1333
|
+
const int * __restrict__ KV_max,
|
|
1209
1334
|
float * __restrict__ dst,
|
|
1210
1335
|
float2 * __restrict__ dst_meta,
|
|
1211
1336
|
const float scale,
|
|
@@ -1214,30 +1339,14 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1214
1339
|
const float m1,
|
|
1215
1340
|
const uint32_t n_head_log2,
|
|
1216
1341
|
const float logit_softcap,
|
|
1217
|
-
const
|
|
1218
|
-
|
|
1219
|
-
const
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
1224
|
-
|
|
1225
|
-
const int ne31,
|
|
1226
|
-
const int nb31,
|
|
1227
|
-
const int nb01,
|
|
1228
|
-
const int nb02,
|
|
1229
|
-
const int nb03,
|
|
1230
|
-
const int nb11,
|
|
1231
|
-
const int nb12,
|
|
1232
|
-
const int nb13,
|
|
1233
|
-
const int nb21,
|
|
1234
|
-
const int nb22,
|
|
1235
|
-
const int nb23,
|
|
1236
|
-
const int ne0,
|
|
1237
|
-
const int ne1,
|
|
1238
|
-
const int ne2,
|
|
1239
|
-
const int ne3) {
|
|
1240
|
-
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
|
1342
|
+
const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
|
|
1343
|
+
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
|
1344
|
+
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
|
1345
|
+
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
|
1346
|
+
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
|
1347
|
+
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
|
1348
|
+
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
|
1349
|
+
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
|
|
1241
1350
|
|
|
1242
1351
|
// Skip unused kernel variants for faster compilation:
|
|
1243
1352
|
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
|
|
@@ -1253,27 +1362,26 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1253
1362
|
|
|
1254
1363
|
static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
|
|
1255
1364
|
|
|
1256
|
-
|
|
1257
|
-
|
|
1258
|
-
|
|
1365
|
+
constexpr int ncols = ncols1 * ncols2;
|
|
1366
|
+
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
|
1367
|
+
constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
|
|
1368
|
+
constexpr int nwarps = nthreads / WARP_SIZE;
|
|
1259
1369
|
|
|
1260
1370
|
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
|
1261
1371
|
|
|
1262
1372
|
const int stride_Q1 = nb01 / sizeof(float2);
|
|
1263
1373
|
const int stride_Q2 = nb02 / sizeof(float2);
|
|
1264
1374
|
const int stride_K = nb11 / sizeof(half2);
|
|
1265
|
-
const int stride_mask = nb31 / sizeof(
|
|
1375
|
+
const int stride_mask = nb31 / sizeof(half);
|
|
1266
1376
|
|
|
1267
1377
|
const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
|
|
1268
1378
|
|
|
1269
|
-
const int iter_k = ne11 /
|
|
1270
|
-
const int iter_j = (ne01 + (ncols1
|
|
1271
|
-
|
|
1272
|
-
constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
|
|
1379
|
+
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
|
1380
|
+
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
|
|
1273
1381
|
|
|
1274
1382
|
// kbc == k block continuous, current index in continuous ijk space.
|
|
1275
|
-
int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
|
1276
|
-
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
|
1383
|
+
int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
1384
|
+
const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
1277
1385
|
|
|
1278
1386
|
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
|
1279
1387
|
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
|
@@ -1282,33 +1390,39 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1282
1390
|
// kb0 == k start index when in the output tile.
|
|
1283
1391
|
int kb0_start = kbc % iter_k;
|
|
1284
1392
|
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
|
|
1393
|
+
|
|
1285
1394
|
while (kbc < kbc_stop && kb0_stop == iter_k) {
|
|
1286
|
-
const int
|
|
1287
|
-
const int
|
|
1395
|
+
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
|
|
1396
|
+
const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
|
|
1397
|
+
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
|
|
1288
1398
|
|
|
1289
|
-
const
|
|
1290
|
-
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
|
1291
|
-
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
|
1292
|
-
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
|
1399
|
+
const int head0 = zt * ncols2;
|
|
1293
1400
|
|
|
1294
|
-
const
|
|
1401
|
+
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
|
|
1402
|
+
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
|
|
1403
|
+
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
|
|
1404
|
+
(const half *) (mask + nb33*(sequence % ne33));
|
|
1405
|
+
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
|
|
1295
1406
|
|
|
1296
|
-
const
|
|
1407
|
+
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
|
|
1408
|
+
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
|
|
1297
1409
|
|
|
1298
|
-
const
|
|
1299
|
-
const int kb0_stop_kernel = kb0_stop * kb_niter;
|
|
1410
|
+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
|
|
1300
1411
|
|
|
1412
|
+
if (KV_max) {
|
|
1413
|
+
kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
|
|
1414
|
+
}
|
|
1301
1415
|
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
|
1302
1416
|
if (kb0_start == 0) {
|
|
1303
1417
|
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
|
1304
|
-
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps,
|
|
1305
|
-
(Q_f2, K_h2, V_h2,
|
|
1306
|
-
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt,
|
|
1418
|
+
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
|
|
1419
|
+
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
|
1420
|
+
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
|
|
1307
1421
|
} else {
|
|
1308
|
-
constexpr bool needs_fixup = true; // CUDA block is
|
|
1309
|
-
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps,
|
|
1310
|
-
(Q_f2, K_h2, V_h2,
|
|
1311
|
-
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt,
|
|
1422
|
+
constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
|
|
1423
|
+
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
|
|
1424
|
+
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
|
1425
|
+
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
|
|
1312
1426
|
}
|
|
1313
1427
|
|
|
1314
1428
|
kbc += iter_k;
|
|
@@ -1322,39 +1436,44 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1322
1436
|
return;
|
|
1323
1437
|
}
|
|
1324
1438
|
|
|
1325
|
-
const int
|
|
1326
|
-
const int
|
|
1439
|
+
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
|
|
1440
|
+
const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
|
|
1441
|
+
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
|
|
1327
1442
|
|
|
1328
|
-
const
|
|
1329
|
-
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
|
1330
|
-
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
|
1331
|
-
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
|
1443
|
+
const int head0 = zt * ncols2;
|
|
1332
1444
|
|
|
1333
|
-
const
|
|
1445
|
+
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
|
|
1446
|
+
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
|
|
1447
|
+
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
|
|
1448
|
+
(const half *) (mask + nb33*(sequence % ne33));
|
|
1449
|
+
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
|
|
1334
1450
|
|
|
1335
|
-
const
|
|
1451
|
+
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
|
|
1452
|
+
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
|
|
1336
1453
|
|
|
1337
|
-
const
|
|
1338
|
-
|
|
1454
|
+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
|
|
1455
|
+
|
|
1456
|
+
if (KV_max) {
|
|
1457
|
+
kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
|
|
1458
|
+
}
|
|
1339
1459
|
|
|
1340
1460
|
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
|
1341
1461
|
constexpr bool needs_fixup = false;
|
|
1342
|
-
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps,
|
|
1343
|
-
(Q_f2, K_h2, V_h2,
|
|
1344
|
-
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt,
|
|
1462
|
+
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
|
|
1463
|
+
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
|
1464
|
+
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
|
|
1345
1465
|
#else
|
|
1346
|
-
|
|
1347
|
-
|
|
1348
|
-
|
|
1349
|
-
|
|
1350
|
-
|
|
1351
|
-
|
|
1352
|
-
|
|
1353
|
-
|
|
1354
|
-
|
|
1355
|
-
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
|
1466
|
+
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
|
1467
|
+
max_bias, m0, m1, n_head_log2, logit_softcap,
|
|
1468
|
+
ne00, ne01, ne02, ne03,
|
|
1469
|
+
nb01, nb02, nb03,
|
|
1470
|
+
ne10, ne11, ne12, ne13,
|
|
1471
|
+
nb11, nb12, nb13,
|
|
1472
|
+
nb21, nb22, nb23,
|
|
1473
|
+
ne31, ne32, ne33,
|
|
1474
|
+
nb31, nb32, nb33);
|
|
1356
1475
|
NO_DEVICE_CODE;
|
|
1357
|
-
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(
|
|
1476
|
+
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
|
|
1358
1477
|
}
|
|
1359
1478
|
|
|
1360
1479
|
template <int DKQ, int DV, int ncols1, int ncols2>
|
|
@@ -1363,36 +1482,30 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
|
1363
1482
|
const int id = ggml_cuda_get_device();
|
|
1364
1483
|
const int cc = ggml_cuda_info().devices[id].cc;
|
|
1365
1484
|
|
|
1366
|
-
|
|
1485
|
+
constexpr int ncols = ncols1 * ncols2;
|
|
1367
1486
|
|
|
1368
|
-
const int
|
|
1487
|
+
const int nthreads = ggml_cuda_fattn_mma_get_nthreads (DKQ, DV, ncols, cc);
|
|
1488
|
+
const int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols, cc);
|
|
1489
|
+
const int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols, cc);
|
|
1490
|
+
const int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols, cc);
|
|
1491
|
+
const int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols, cc);
|
|
1492
|
+
const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc);
|
|
1493
|
+
const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc);
|
|
1369
1494
|
|
|
1370
|
-
|
|
1371
|
-
|
|
1372
|
-
constexpr int cols_per_warp = ntiles * tile_B::I;
|
|
1373
|
-
constexpr int nwarps_max_x = ncols / cols_per_warp;
|
|
1374
|
-
constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
|
|
1375
|
-
constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
|
|
1495
|
+
const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32);
|
|
1496
|
+
const int nwarps = nthreads / WARP_SIZE;
|
|
1376
1497
|
|
|
1377
1498
|
constexpr bool mla = DKQ == 576;
|
|
1378
1499
|
|
|
1379
|
-
const
|
|
1380
|
-
const
|
|
1381
|
-
const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
|
|
1382
|
-
|
|
1383
|
-
static_assert(DKQ % tile_B::J == 0, "bad DKQ");
|
|
1384
|
-
static_assert(DV % tile_A::J == 0, "bad DV");
|
|
1385
|
-
static_assert(ncols % cols_per_warp == 0, "bad ncols");
|
|
1386
|
-
|
|
1387
|
-
const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
|
|
1388
|
-
const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
|
|
1500
|
+
const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
|
|
1501
|
+
const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
|
|
1389
1502
|
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
|
|
1390
|
-
const size_t nbytes_shared_mask = ncols1 * (
|
|
1503
|
+
const size_t nbytes_shared_mask = ncols1 * (nbatch_fa/2 + 4) * sizeof(half2);
|
|
1391
1504
|
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
|
|
1392
1505
|
|
|
1393
1506
|
const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
|
|
1394
1507
|
|
|
1395
|
-
const size_t nbytes_shared_total = std::max(nbytes_shared_combine,
|
|
1508
|
+
const size_t nbytes_shared_total = std::max(nbytes_shared_combine, Q_in_reg ?
|
|
1396
1509
|
std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) :
|
|
1397
1510
|
nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask);
|
|
1398
1511
|
|
|
@@ -1402,30 +1515,30 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
|
1402
1515
|
fattn_kernel_t fattn_kernel;
|
|
1403
1516
|
if (logit_softcap == 0.0f) {
|
|
1404
1517
|
constexpr bool use_logit_softcap = false;
|
|
1405
|
-
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2,
|
|
1518
|
+
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
|
|
1406
1519
|
|
|
1407
|
-
#if !
|
|
1520
|
+
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
|
1408
1521
|
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
1409
1522
|
if (!shared_memory_limit_raised[id]) {
|
|
1410
1523
|
CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
|
|
1411
1524
|
shared_memory_limit_raised[id] = true;
|
|
1412
1525
|
}
|
|
1413
|
-
#endif // !
|
|
1526
|
+
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
|
1414
1527
|
} else {
|
|
1415
1528
|
constexpr bool use_logit_softcap = true;
|
|
1416
|
-
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2,
|
|
1529
|
+
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
|
|
1417
1530
|
|
|
1418
|
-
#if !
|
|
1531
|
+
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
|
1419
1532
|
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
1420
1533
|
if (!shared_memory_limit_raised[id]) {
|
|
1421
1534
|
CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
|
|
1422
1535
|
shared_memory_limit_raised[id] = true;
|
|
1423
1536
|
}
|
|
1424
|
-
#endif // !
|
|
1537
|
+
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
|
1425
1538
|
}
|
|
1426
1539
|
|
|
1427
1540
|
launch_fattn<DV, ncols1, ncols2>
|
|
1428
|
-
(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total,
|
|
1541
|
+
(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true);
|
|
1429
1542
|
}
|
|
1430
1543
|
|
|
1431
1544
|
|