whispercpp 1.3.4 → 1.3.6
This diff reflects the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +158 -44
- data/ext/extconf.rb +3 -2
- data/ext/ruby_whisper.c +34 -6
- data/ext/ruby_whisper.h +67 -0
- data/ext/ruby_whisper_context.c +236 -144
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +12 -13
- data/ext/ruby_whisper_params.c +47 -24
- data/ext/ruby_whisper_segment.c +84 -20
- data/ext/ruby_whisper_token.c +371 -0
- data/ext/ruby_whisper_transcribe.cpp +5 -2
- data/ext/ruby_whisper_vad_context.c +122 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +138 -0
- data/ext/ruby_whisper_vad_segments.c +105 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +129 -112
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +28 -15
- data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
- data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
- data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
- data/ext/sources/examples/talk-llama/llama-context.h +70 -23
- data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
- data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
- data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
- data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
- data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
- data/ext/sources/examples/talk-llama/llama-model.h +112 -18
- data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
- data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
- data/ext/sources/examples/talk-llama/llama.cpp +802 -21
- data/ext/sources/examples/talk-llama/llama.h +210 -39
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +704 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +90 -56
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +5 -2
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +6 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +14 -12
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +246 -21
- data/ext/sources/ggml/src/CMakeLists.txt +85 -11
- data/ext/sources/ggml/src/ggml-alloc.c +128 -50
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
- data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
- data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
- data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
- data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
- data/ext/sources/ggml/src/ggml-impl.h +129 -6
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
- data/ext/sources/ggml/src/ggml.c +590 -64
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +106 -62
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +162 -4
- data/test/test_context_params.rb +82 -0
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +81 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +100 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +27 -0
- data/whispercpp.gemspec +1 -1
- metadata +502 -37
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -5,284 +5,290 @@
 
 using namespace ggml_cuda_mma;
 
-
-typedef tile< 8, 8, half2> tile_B;
-typedef tile<16, 8, half2> tile_B_16;
-typedef tile<16, 8, float> tile_C_KQ;
-typedef tile<16, 16, float> tile_C_KQ_16;
-typedef tile<16, 4, half2> tile_C_VKQ;
-typedef tile<16, 8, half2> tile_C_VKQ_16;
-
-// Config options for specific head sizes.
+// Config options for the MMA kernel.
 // Should not affect results, only speed/register pressure/shared memory use.
-
-//
-//
-//
-//
-//
-//
-//
-
-
-
-
-
-
-    static constexpr int nbatch_fa = 64;
-    static constexpr int nwarps_max = 4;
-    static constexpr bool Q_in_reg = true;
-    static constexpr int nstages_target = 2;
-
-    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
-        return 32;
-    }
-
-    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
-        return 32;
-    }
-
-    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
-        return 32;
-    }
-
-    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
-        return 32;
-    }
-
-    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
-        return 32;
-    }
-
-    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
-        return 32;
-    }
+struct fattn_mma_config {
+    int  nthreads;       // Number of threads per CUDA block.
+    int  occupancy;      // Targeted occupancy for the MMA kernel.
+    int  nbatch_fa;      // Number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
+    int  nbatch_K2;      // Number of K half2 values in direction of DKQ to load in parallel.
+    int  nbatch_V2;      // Number of V half2 values in direction of DV to load in parallel.
+    int  nbatch_combine; // Number of VKQ half2 values in direction of DV to combine in parallel.
+    int  nstages_target; // Number of pipeline stages to use ideally, 1 == always load data synchronously, 2 == preload data if there is hardware support.
+    bool Q_in_reg;       // Whether the Q values should be kept permanently in registers.
+
+    constexpr __host__ __device__ fattn_mma_config(
+            int nthreads, int occupancy, int nbatch_fa, int nbatch_K2, int nbatch_V2, int nbatch_combine, int nstages_target, bool Q_in_reg) :
+        nthreads(nthreads), occupancy(occupancy), nbatch_fa(nbatch_fa), nbatch_K2(nbatch_K2), nbatch_V2(nbatch_V2), nbatch_combine(nbatch_combine),
+        nstages_target(nstages_target), Q_in_reg(Q_in_reg) {}
 };
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+#define GGML_CUDA_FATTN_MMA_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads_, occupancy_, nbatch_fa_, nbatch_K2_, nbatch_V2_, nbatch_combine_, nstages_target_, Q_in_reg_) \
+    if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \
+        static_assert((nthreads_) % 32 == 0 && (nthreads_) <= 512, "bad nthreads"); \
+        static_assert( (occupancy_) <= 8, "bad occupancy"); \
+        static_assert((nbatch_fa_) % 32 == 0 && (nbatch_fa_) <= 256, "bad nbatch_fa"); \
+        static_assert((nbatch_K2_) % 4 == 0 && (nbatch_K2_) <= 512, "bad nbatch_K2"); \
+        static_assert((nbatch_V2_) % 4 == 0 && (nbatch_V2_) <= 256, "bad nbatch_V2"); \
+        static_assert((nbatch_combine_) % 4 == 0 && (nbatch_combine_) <= 128, "bad nbatch_combine"); \
+        static_assert((nstages_target_) >= 1 && (nstages_target_) <= 2, "bad nstages_target"); \
+        return fattn_mma_config{(nthreads_), (occupancy_), (nbatch_fa_), (nbatch_K2_), (nbatch_V2_), (nbatch_combine_), (nstages_target_), (Q_in_reg_)}; \
+    } \
+
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) {
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 2, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 2, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 48, 48, 2, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 2, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
+
+    return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
+}
 
-
-
-
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) {
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 128, 2, 64, 128, 128, 128, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
 
-
-
-
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
 
-
-
-}
+    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+}
 
-
-
-
-
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
 
-
-
-
-    static constexpr int nwarps_max = 4;
-    static constexpr bool Q_in_reg = true;
-    static constexpr int nstages_target = 2;
+    // TODO tune specifically for Volta
+    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+}
 
-
-
-
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) {
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
 
-
-
-
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
 
-
-
-
+    // TODO tune specifically for RDNA
+    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+}
 
-
-
-
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) {
+    // Conservative configs for CDNA (MI100+): 64KB LDS, wavefront64, nstages=1 (no cp.async).
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 32, 32, 32, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 256, 2, 64, 40, 40, 40, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 256, 2, 64, 48, 48, 48, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2, 64, 56, 56, 56, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2, 64, 64, 64, 64, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2, 32, 128, 128, 128, 1, true);
+
+    // Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy
+    // compile-time static_asserts even though the kernel guard prevents runtime execution.
+    // nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility.
+    return fattn_mma_config(256, 1, 128, 4, 4, 4, 1, false);
+}
 
-
-
+static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
+    if (ampere_mma_available(cc)) {
+        return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
     }
-
-
-        return 48;
+    if (turing_mma_available(cc)) {
+        return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
    }
-
-
-template <>
-struct fattn_mma_f16_config<112, 112> {
-    static constexpr int nbatch_fa = 64;
-    static constexpr int nwarps_max = 4;
-    static constexpr bool Q_in_reg = true;
-    static constexpr int nstages_target = 2;
-
-    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
-        return 56;
+    if (amd_mfma_available(cc)) {
+        return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
    }
-
-
-        return 56;
+    if (amd_wmma_available(cc)) {
+        return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
    }
+    GGML_ASSERT(volta_mma_available(cc));
+    return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
+}
 
-
-
-
+static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) {
+#if defined(AMPERE_MMA_AVAILABLE)
+    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+#elif defined(TURING_MMA_AVAILABLE)
+    return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
+#elif defined(AMD_MFMA_AVAILABLE)
+    return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
+#elif defined(VOLTA_MMA_AVAILABLE)
+    return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
+#elif defined(AMD_WMMA_AVAILABLE)
+    return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
+#else
+    GGML_UNUSED_VARS(DKQ, DV, ncols);
+    return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
+#endif // defined(AMPERE_MMA_AVAILABLE)
+}
 
-
-
-
+static __host__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nthreads;
+}
 
-
-
-
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nthreads;
+}
 
-
-
-
-};
+static __host__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).occupancy;
+}
 
-
-
-
-    static constexpr int nwarps_max = 4;
-    static constexpr bool Q_in_reg = true;
-    static constexpr int nstages_target = 2;
+static constexpr __device__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).occupancy;
+}
 
-
-
-
+static __host__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_fa;
+}
 
-
-
-
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_fa;
+}
 
-
-
-
+static __host__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols, const int cc) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_K2;
+}
 
-
-
-
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_K2;
+}
 
-
-
-
+static __host__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols, const int cc) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_V2;
+}
 
-
-
-
-};
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_V2;
+}
 
-
-
-
-    static constexpr int nwarps_max = 4;
-    static constexpr bool Q_in_reg = true;
-    static constexpr int nstages_target = 2;
+static __host__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols, const int cc) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_combine;
+}
 
-
-
-
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_combine;
+}
 
-
-
-
+static __host__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols, const int cc) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nstages_target;
+}
 
-
-
-
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nstages_target;
+}
 
-
-
-
+static __host__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols, const int cc) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).Q_in_reg;
+}
 
-
-
-
-    }
-        return 64;
-    }
+static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
+}
 
-
-#if
-
+static constexpr __device__ int get_cols_per_thread() {
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+    return 1; // AMD has a single column per thread.
 #else
-
-
-
-    }
-};
-
-template <>
-struct fattn_mma_f16_config<576, 512> {
-    static constexpr int nbatch_fa = 32;
-    static constexpr int nwarps_max = 8;
-    static constexpr bool Q_in_reg = false;
-    static constexpr int nstages_target = 1;
+    return 2; // This is specifically KQ columns, Volta only has a single VKQ column.
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+}
 
-
-
-
-
-
+static __host__ int get_cols_per_warp(const int cc) {
+    if (turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc)) {
+        return 16;
+    } else {
+        // Volta
+        return 32;
    }
+}
 
-
-#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
-        return ncols <= 16 ? 96 : 160;
-#else
-        return ncols <= 16 ? 288 : 160;
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
-    }
+// ------------------------------------------------------------------------------------------------------------------
 
-
-
-
-    }
-        return ncols <= 16 ? 256 : 128;
-    }
+static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) {
+    return cp_async_available(cc) && ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2, cc) : 0;
+}
 
-
-#
-
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2) {
+#ifdef CP_ASYNC_AVAILABLE
+    return ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2) : 0;
 #else
-
-
-
-
-    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
-        return 128;
-    }
-
-    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
-        return 128;
-    }
-};
+    GGML_UNUSED_VARS(DKQ, DV, ncols1, ncols2);
+    return 0;
+#endif // CP_ASYNC_AVAILABLE
+}
 
 // ------------------------------------------------------------------------------------------------------------------
 
-template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
+template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
 static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
-        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
-
+        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
     // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
-
-
+    if constexpr (use_cp_async) {
+        static_assert(!oob_check, "OOB check not compatible with cp_async");
         constexpr int preload = 64;
         constexpr int h2_per_chunk = 16/sizeof(half2);
         const int chunks_per_row = D2 / h2_per_chunk;
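
Note: the hunk above replaces the per-head-size `fattn_mma_f16_config` template specializations with one `fattn_mma_config` struct plus a macro-generated, compile-time lookup keyed on (DKQ, DV, ncols). A minimal sketch of that dispatch pattern follows; all names in it (`my_config`, `MY_CONFIG_CASE`, `my_get_config`) are illustrative stand-ins, not ggml's API:

    // Each case is an `if` + static_asserts + `return`; with literal arguments the
    // asserts are checked at compile time and the whole call folds to a constant.
    struct my_config {
        int nthreads;
        int nbatch_fa;
        constexpr __host__ __device__ my_config(int nthreads, int nbatch_fa) :
            nthreads(nthreads), nbatch_fa(nbatch_fa) {}
    };

    #define MY_CONFIG_CASE(DKQ_, ncols_, nthreads_, nbatch_fa_) \
        if (DKQ == (DKQ_) && ncols == (ncols_)) { \
            static_assert((nthreads_) % 32 == 0, "bad nthreads"); \
            return my_config{(nthreads_), (nbatch_fa_)}; \
        } \

    static constexpr __host__ __device__ my_config my_get_config(const int DKQ, const int ncols) {
        MY_CONFIG_CASE(128,  8, 128, 128);
        MY_CONFIG_CASE(128, 32, 128,  64);
        return my_config(32, 0); // fallback; callers are expected to guard against it
    }

    // Inside a kernel template the result is a compile-time constant:
    //   constexpr int nbatch_fa = my_get_config(DKQ, ncols).nbatch_fa;

Because the same function body compiles as both `__host__` and `__device__`, the host-side launcher and the kernel read identical tuning values without duplicating the tables.
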
@@ -290,10 +296,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
         const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
 
         auto load = [&] __device__ (auto n) {
-            const int stride_k =
-            const int k0_start = stride_k ==
+            const int stride_k = warp_size >> n;
+            const int k0_start = stride_k == warp_size ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
             const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
-            const int stride_i =
+            const int stride_i = warp_size / stride_k;
 
             if (k0_start == k0_stop) {
                 return;
@@ -301,7 +307,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 
 #pragma unroll
             for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
-                const int i = i0 + threadIdx.y*stride_i + (stride_k ==
+                const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
 
                 if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
                     break;
@@ -309,20 +315,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k ==
+                    const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
 
                     cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
                 }
             }
         };
-
+        // 1: max 32*16=512 bytes, 256 half
+        // 2: max 16*16=256 bytes, 128 half
+        // 3: max 8*16=128 bytes, 64 half
+        // 4: max 4*16= 64 bytes, 32 half
+        // 5: max 2*16= 32 bytes, 16 half
+        // 6: max 1*16= 16 bytes, 8 half
+        ggml_cuda_unroll<6>{}(load);
     } else {
-
+        // TODO use ggml_cuda_memcpy_1
         auto load = [&] __device__ (const int n) {
-            const int stride_k =
-            const int k0_start = stride_k ==
+            const int stride_k = warp_size >> n;
+            const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k);
             const int k0_stop = D2 - D2 % (1*stride_k);
-            const int stride_i =
+            const int stride_i = warp_size / stride_k;
 
             if (k0_start == k0_stop) {
                 return;
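
Note: the `load` lambda above is invoked once per granularity pass via `ggml_cuda_unroll<6>{}(load)`: pass n gives `stride_k = warp_size >> n`, so pass 0 copies full-warp-wide chunks and each later pass mops up the row remainder at half the previous width (the byte counts in the comments). A sketch of how such an unroll helper can drive the lambda with compile-time indices; `my_unroll` is an illustrative stand-in for the real `ggml_cuda_unroll` in ggml's common headers:

    #include <type_traits>

    // Applies f(0), f(1), ..., f(N-1); each index is passed as std::integral_constant
    // so the callee can use it as a constexpr value (e.g. stride_k = warp_size >> n).
    template <int N>
    struct my_unroll {
        template <typename F>
        __device__ __forceinline__ void operator()(const F & f) const {
            my_unroll<N - 1>{}(f);                     // recurse: handle 0 .. N-2 first
            f(std::integral_constant<int, N - 1>{});   // then pass N-1
        }
    };

    template <>
    struct my_unroll<0> {
        template <typename F>
        __device__ __forceinline__ void operator()(const F &) const {}
    };

    // Usage mirroring the hunk above: my_unroll<6>{}(load);
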
@@ -330,7 +342,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 
 #pragma unroll
             for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
-                const int i = i0 + threadIdx.y*stride_i + (stride_k ==
+                const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
 
                 if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
                     break;
@@ -338,73 +350,114 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k ==
+                    const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
 
-                    tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
+                    tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
                 }
             }
         };
-
+        // 1: max 32* 4=128 bytes, 64 half
+        // 2: max 16* 4= 64 bytes, 32 half
+        // 3: max 8* 4= 32 bytes, 16 half
+        // 4: max 4* 4= 16 bytes, 8 half
+        ggml_cuda_unroll<4>{}(load);
     }
 }
 
-template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async>
+template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
 static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
-        const
-
-
-    if (use_cp_async) {
+        const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,
+        const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    if constexpr (use_cp_async) {
+        static_assert(nbatch_fa <= 8*warp_size && nbatch_fa % 8 == 0, "bad nbatch_fa");
+        static_assert(!oob_check, "OOB check incompatible with cp_async");
         constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
-        constexpr int cols_per_warp = 8*
+        constexpr int cols_per_warp = 8*warp_size/nbatch_fa;
         constexpr int stride_j = nwarps * cols_per_warp;
 
         const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
 
 #pragma unroll
-        for (int
-            const int
-
+        for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
+            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
+            const int j_vram = fastmodulo(j0 + j_sram, ne01);
 
-            if (
+            if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
                 break;
             }
 
-            const int i =
+            const int i = 8 * (threadIdx.x % (nbatch_fa/8));
 
-            cp_async_cg_16<preload>(tile_mask_32 +
+            cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i);
         }
-
-
+    } else if constexpr (oob_check) {
+#pragma unroll
+        for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
+            const int j_sram = j1 + threadIdx.y;
+            const int j_vram = fastmodulo(j0 + j_sram, ne01);
+
+            if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
+                break;
+            }
 
-        constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
-        constexpr int stride_j = nwarps * cols_per_warp;
 #pragma unroll
-
-
+            for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) {
+                const int i = i0 + threadIdx.x;
 
-
-
+                tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
+            }
         }
+    } else if constexpr (nbatch_fa < 2*warp_size) {
+        constexpr int cols_per_warp = 2*warp_size/nbatch_fa;
+        constexpr int stride_j = nwarps * cols_per_warp;
+#pragma unroll
+        for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
+            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
+            const int j_vram = fastmodulo(j0 + j_sram, ne01);
 
-
+            if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
+                break;
+            }
+
+            const int i = threadIdx.x % (warp_size/cols_per_warp);
 
-
+            ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
+        }
+    } else {
+#pragma unroll
+        for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
+            const int j_sram = j1 + threadIdx.y;
+            const int j_vram = fastmodulo(j0 + j_sram, ne01);
+
+            if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
+                break;
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) {
+                const int i = i0 + 2*threadIdx.x;
+
+                ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
+            }
+        }
     }
 }
 
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps,
-         bool use_logit_softcap, bool
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps,
+         bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
+         typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ>
 static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const float2 * const __restrict__ Q_f2,
         const half2 * const __restrict__ K_h2,
         const half2 * const __restrict__ V_h2,
-        const
+        const half * const __restrict__ mask_h,
         float2 * const __restrict__ dstk,
         float2 * const __restrict__ dstk_fixup,
         const float scale,
         const float slope,
         const float logit_softcap,
-        const
+        const uint3 ne01,
         const int ne02,
         const int stride_K,
         const int stride_V,
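
Note: the mask loader above wraps its logical column index with `fastmodulo(j0 + j_sram, ne01)`, where `ne01` arrives as a `uint3` carrying the divisor plus precomputed constants so that no hardware integer division runs inside the kernel. A hedged sketch of the underlying Lemire-style division-free modulo; the struct layout and helper names below are illustrative, and ggml's exact constant packing into its `uint3` may differ:

    #include <cstdint>

    // With c = 2^64 / d + 1 (valid for d > 1), n % d equals the high 64 bits
    // of (c * n mod 2^64) * d.
    struct my_fastdiv {
        uint64_t c;
        uint32_t d;
    };

    __host__ inline my_fastdiv my_make_fastdiv(uint32_t d) {
        return { UINT64_MAX / d + 1, d }; // assumes d > 1
    }

    __device__ __forceinline__ uint32_t my_fastmod(uint32_t n, my_fastdiv f) {
        const uint64_t lowbits = f.c * n;           // wraps mod 2^64 by design
        return (uint32_t) __umul64hi(lowbits, f.d); // == n % f.d
    }

The host precomputes the constants once per tensor shape and the kernel reuses them for every element, which matters here because the wraparound index is recomputed per loaded mask row.
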
@@ -412,66 +465,68 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         half2 * const __restrict__ tile_Q,
         half2 * const __restrict__ tile_K,
         half2 * const __restrict__ tile_V,
-
-
-
+        half * const __restrict__ tile_mask,
+        T_B_KQ * const __restrict__ Q_B,
+        T_C_VKQ * const __restrict__ VKQ_C,
         float * const __restrict__ KQ_max,
         float * const __restrict__ KQ_rowsum,
-        const int
-
-
-
-
-    constexpr int
-
-    constexpr int
-
-
-    constexpr int
-    constexpr int
-    constexpr
-    constexpr int
-    constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
-    constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
+        const int jt,
+        const int kb0,
+        const int k_VKQ_sup) {
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    constexpr int ncols = ncols1 * ncols2;
+    constexpr int cols_per_warp = T_B_KQ::I;
+    constexpr int cols_per_thread = get_cols_per_thread();
+    constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
+    constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
+    constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
+    constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
+    constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols);
+    constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2);
 
     constexpr int stride_tile_Q = DKQ/2 + 4;
     constexpr int stride_tile_K = nbatch_K2 + 4;
 
-
-    constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
+    constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
 
-    const int k_VKQ_0 = kb0 *
-
-
-
-
-
-
+    const int k_VKQ_0 = kb0 * nbatch_fa;
+#if defined(TURING_MMA_AVAILABLE)
+    T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+    T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
+#else // Volta
+    T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
+#endif // defined(TURING_MMA_AVAILABLE)
 
     if constexpr (nstages > 1) {
-        static_assert(!
+        static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
+        static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading");
         static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
         constexpr bool use_cp_async = true;
         cp_async_wait_all();
         __syncthreads();
-        flash_attn_ext_f16_load_tile<stride_tile_V, nwarps,
-            (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
+        flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
+            (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V, k_VKQ_sup);
     } else {
         constexpr bool use_cp_async = nstages == 1;
-        if (ncols2 > 1 ||
-            flash_attn_ext_f16_load_mask<ncols1, nwarps,
+        if (ncols2 > 1 || mask_h) {
+            flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
+                (mask_h + k_VKQ_0, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
         }
     }
 
+    // For MLA K and V have the same data.
+    // Therefore, iterate over K in reverse and later re-use the data if possible.
 #pragma unroll
-    for (int k0_start =
+    for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) {
         const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
         const int k0_diff = k0_stop - k0_start;
 
-        if (nstages <= 1) {
+        if constexpr (nstages <= 1) {
             constexpr bool use_cp_async = nstages == 1;
-            flash_attn_ext_f16_load_tile<stride_tile_K, nwarps,
-                (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
+            flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
+                (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup);
             if (use_cp_async) {
                 cp_async_wait_all();
             }
@@ -479,55 +534,68 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         }
 
         // Calculate tile of KQ:
-        if constexpr (
+        if constexpr (Q_in_reg) {
 #pragma unroll
-            for (int i_KQ_00 = 0; i_KQ_00 <
-                const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*
+            for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) {
+                const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I;
 #pragma unroll
-                for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 +=
-
+                for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
+                    T_A_KQ K_A;
                     load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
-                    if (
-                        mma(KQ_C[i_KQ_00/(np*
+                    if constexpr (cols_per_warp == 8) {
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
                     } else {
-
-
-
-
-
+                        // Wide version of KQ_C is column-major
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                        // AMD matrix C is column-major.
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
+#else
+                        // swap A and B for CUDA.
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
                     }
                 }
             }
         } else {
-            static_assert(ntiles == 2, "ntiles != 2 not implemented");
 #pragma unroll
-            for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 +=
-                load_ldmatrix(
+            for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
+                load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
 
 #pragma unroll
-                for (int i_KQ_00 = 0; i_KQ_00 <
-                    const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*
+                for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) {
+                    const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I;
 
-
+                    T_A_KQ K_A;
                     load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
 
-
-
+                    if constexpr (cols_per_warp == 8) {
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+                    } else {
+                        // Wide version of KQ_C is column-major
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                        // AMD matrix C is column-major.
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+#else
+                        // swap A and B for CUDA.
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    }
                 }
             }
         }
 
-        if (nstages <= 1) {
+        if constexpr (nstages <= 1) {
             __syncthreads(); // Only needed if tile_K == tile_V.
         }
     }
 
     if (use_logit_softcap) {
-
+        constexpr int stride = cols_per_warp == 8 ? np*T_C_KQ::I : np*T_C_KQ::J;
+        static_assert(nbatch_fa % stride == 0, "bad loop size");
 #pragma unroll
-        for (int i = 0; i <
+        for (int i = 0; i < nbatch_fa/stride; ++i) {
 #pragma unroll
-            for (int l = 0; l <
+            for (int l = 0; l < T_C_KQ::ne; ++l) {
                 KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
             }
        }
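
Note: the `use_logit_softcap` branch at the end of this hunk squashes the raw attention logits through tanh. In full form the operation is softcap(x) = cap * tanh(x / cap), which bounds logits to (-cap, +cap) (used by e.g. Gemma-style models); the kernel applies only the `cap * tanhf(...)` half, which presumably works because the 1/cap factor can be folded into the Q scale upstream. A scalar sketch with an illustrative helper name, not the library's API:

    // Smooth, monotonic soft cap: near-identity for |x| << cap, saturating at +/-cap.
    __device__ __forceinline__ float my_softcap(float x, float cap) {
        return cap * tanhf(x / cap);
    }
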
@@ -540,109 +608,145 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     }
     float KQ_rowsum_add[cols_per_thread] = {0.0f};
 
-    if (
-        if (ncols2 > 1 ||
+    if constexpr (cols_per_warp == 8) {
+        if (ncols2 > 1 || mask_h) {
 #pragma unroll
-            for (int i00 = 0; i00 <
-                const int i0 = i00 + (threadIdx.y % np)*
+            for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::I) {
+                const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::I;
 #pragma unroll
-                for (int l = 0; l <
-                    const int i = i0 +
-                    const int j = ((threadIdx.y / np)*
+                for (int l = 0; l < T_C_KQ::ne; ++l) {
+                    const int i = i0 + T_C_KQ::get_i(l);
+                    const int j = ((threadIdx.y / np)*T_C_KQ::J + T_C_KQ::get_j(l)) / ncols2;
 
-                    KQ_C[i00/(np*
-                        __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]);
+                    KQ_C[i00/(np*T_C_KQ::I)].x[l] += slope * __half2float(tile_mask[j*(nbatch_fa + 8) + i]);
                 }
             }
        }
 
         // Calculate softmax for each KQ column using the current max. value.
         // The divisor is stored in KQ_rowsum and will be applied at the end.
-        static_assert(
+        static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size");
 #pragma unroll
-        for (int
+        for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
 #pragma unroll
-            for (int l = 0; l <
-
+            for (int l = 0; l < T_C_KQ::ne; ++l) {
+                if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
+                    // Turing + Volta:
+                    const int KQ_idx = l % 2;
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
+                }
             }
        }
 
-        // Values per KQ column are spread across 8 threads
+        // Values per KQ column are spread across 8 threads:
 #pragma unroll
         for (int col = 0; col < cols_per_thread; ++col) {
 #pragma unroll
             for (int offset = 16; offset >= 4; offset >>= 1) {
-                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset,
+                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
             }
        }
 
-        static_assert(
+        static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size");
 #pragma unroll
-        for (int
+        for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
 #pragma unroll
-            for (int l = 0; l <
-
-
-
+            for (int l = 0; l < T_C_KQ::ne; ++l) {
+                if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
+                    // Turing + Volta:
+                    const int KQ_idx = l % 2;
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);
+                    KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
+                } else {
+                    KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f;
+                }
             }
         }
-    } else { //
-        if (ncols2 > 1 ||
+    } else { // not Turing mma or T_B_KQ::I > 8
+        if (ncols2 > 1 || mask_h) {
 #pragma unroll
-            for (int i00 = 0; i00 <
-                const int i0 = i00 + (threadIdx.y % np)*
+            for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) {
+                const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J;
 #pragma unroll
-                for (int
-
-
-                    const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2;
-                    const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2;
+                for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) {
+                    const int i = (i0 + T_C_KQ::get_j(l0)) / 2;
+                    const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l0)) / ncols2;
 
-
-
-
-                    KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y;
-                }
+                    const float2 tmp = __half22float2(((const half2 *)tile_mask)[j*(nbatch_fa/2 + 4) + i]);
+                    KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x;
+                    KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y;
                 }
             }
        }
 
         // Calculate softmax for each KQ column using the current max. value.
         // The divisor is stored in KQ_rowsum and will be applied at the end.
-        static_assert(
+        static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size");
 #pragma unroll
-        for (int
+        for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
 #pragma unroll
-            for (int
-
-
-
-
+            for (int l = 0; l < T_C_KQ::ne; ++l) {
+                if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
+                    // Turing + Volta:
+                    const int KQ_idx = (l/2) % 2;
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
+                }
             }
        }
 
-        // Values per KQ column are spread across 4 threads, does not need full warp reduce:
 #pragma unroll
         for (int col = 0; col < cols_per_thread; ++col) {
+#if defined(TURING_MMA_AVAILABLE)
+            // Values per KQ column are spread across 4 threads:
+            constexpr int offset_first = 2;
+            constexpr int offset_last = 1;
+#elif defined(AMD_MFMA_AVAILABLE)
+            // MFMA: 4 threads per Q column (threadIdx.x % 16 == col, spaced by 16).
+            constexpr int offset_first = 32;
+            constexpr int offset_last = 16;
+#elif defined(AMD_WMMA_AVAILABLE)
+            // Values per KQ column are spread across 2 threads:
+            constexpr int offset_first = 16;
+            constexpr int offset_last = 16;
+#else // Volta
+            // Values per KQ column are spread across 2 threads:
+            constexpr int offset_first = 2;
+            constexpr int offset_last = 2;
+#endif // defined(TURING_MMA_AVAILABLE)
 #pragma unroll
-            for (int offset =
-                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset,
+            for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
+                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
             }
        }
 
-        static_assert(
-#pragma unroll
-        for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
+        static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size");
 #pragma unroll
-
+        for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
 #pragma unroll
-
-
-
-
-
-
+            for (int l = 0; l < T_C_KQ::ne; ++l) {
+                if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
+                    // Turing + Volta:
+                    const int KQ_idx = (l/2) % 2;
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);
+                    KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
+                } else {
+                    KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f;
                 }
             }
        }
@@ -662,12 +766,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
     }

-
-
+#if defined(TURING_MMA_AVAILABLE)
+    if constexpr (cols_per_warp == 8) {
+        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
 #pragma unroll
-        for (int i = 0; i < DV/
+        for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
 #pragma unroll
-            for (int l = 0; l <
+            for (int l = 0; l < T_C_VKQ::ne; ++l) {
                 VKQ_C[i].x[l] *= KQ_max_scale_h2;
             }
         }
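The rescaling above is the accumulator half of the online-softmax update: whenever a new K/V batch raises the running maximum, the previously accumulated rowsum and VKQ values must be scaled by exp(m_old - m_new) before the new contributions are added. A scalar sketch of the same update, with plain floats standing in for the half2 MMA tiles:

    #include <algorithm>
    #include <cmath>

    // One online-softmax step over a batch of n scores and their value rows.
    // m/rowsum/acc are the running max, denominator, and V-weighted numerator.
    void online_softmax_step(float & m, float & rowsum, float * acc, int nv,
                             const float * scores, const float * v, int n) {
        float m_new = m;
        for (int i = 0; i < n; ++i) m_new = std::max(m_new, scores[i]);
        const float scale = std::exp(m - m_new); // KQ_max_scale
        rowsum *= scale;                                 // rescale old denominator
        for (int j = 0; j < nv; ++j) acc[j] *= scale;    // rescale old accumulator
        for (int i = 0; i < n; ++i) {
            const float p = std::exp(scores[i] - m_new);
            rowsum += p;
            for (int j = 0; j < nv; ++j) acc[j] += p * v[i*nv + j];
        }
        m = m_new;
    }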
@@ -676,165 +781,281 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         for (int col = 0; col < cols_per_thread; ++col) {
             const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
 #pragma unroll
-            for (int i = 0; i < DV/
+            for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
 #pragma unroll
-                for (int l0 = 0; l0 <
-
+                for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) {
+                    VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2;
                 }
             }
         }
     }
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+    const half2 KQ_max_scale_h2 = make_half2(
+        KQ_max_scale[0], KQ_max_scale[0]);
+#pragma unroll
+    for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+        for (int l = 0; l < T_C_VKQ::ne; ++l) {
+            VKQ_C[i].x[l] *= KQ_max_scale_h2;
+        }
+    }
+#else // Volta
+    const half2 KQ_max_scale_h2 = make_half2(
+        KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]);
+#pragma unroll
+    for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+        for (int l = 0; l < T_C_VKQ::ne; ++l) {
+            VKQ_C[i].x[l] *= KQ_max_scale_h2;
+        }
+    }
+#endif // defined(TURING_MMA_AVAILABLE)
     }

     // Convert KQ C tiles into B tiles for VKQ calculation:
-
-
-
-    if (ntiles == 1) {
+    T_B_VKQ B[nbatch_fa/(np*2*T_B_VKQ::J)];
+    static_assert(nbatch_fa % (np*2*T_B_VKQ::J) == 0, "bad loop size");
+    if constexpr (cols_per_warp == 8) {
 #pragma unroll
-        for (int k = 0; k <
+        for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) {
             B[k] = get_transposed(get_half2(KQ_C[k]));
         }
     } else {
-        for (int k = 0; k <
-
-        for (int t = 0; t < ntiles/2; ++t) {
-            B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]);
-        }
+        for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) {
+            B[k] = get_half2(KQ_C[k]);
         }
     }

-    if (nstages > 1) {
+    if constexpr (nstages > 1) {
+        static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading");
         // Preload K tile for next iteration:
         constexpr bool use_cp_async = true;
         cp_async_wait_all();
         __syncthreads();
         if (!last_iter) {
-            if (ncols2 > 1 ||
-                flash_attn_ext_f16_load_mask<ncols1, nwarps,
-                    (
+            if (ncols2 > 1 || mask_h) {
+                flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
+                    (mask_h + k_VKQ_0 + nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
             }
-            flash_attn_ext_f16_load_tile<stride_tile_K, nwarps,
-                (K_h2 + int64_t(k_VKQ_0 +
+            flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
+                (K_h2 + int64_t(k_VKQ_0 + nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
         }
     }


-
-
-
-
-
+#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
+    T_A_VKQ A_identity;
+    make_identity_mat(A_identity);
+#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
+
+    // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
 #pragma unroll
-    for (int
-
-        const int
+    for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
+        static_assert(DV % (2*nbatch_V2) == 0, "bad loop size");
+        const int i0_stop = i0_start + 2*nbatch_V2;
+        const int i0_diff = i0_stop - i0_start;

-        if (nstages <= 1
-
-
-
-
-
+        if constexpr (nstages <= 1) {
+            if (!V_is_K_view || i0_stop > 2*nbatch_K2) {
+                constexpr bool use_cp_async = nstages == 1;
+                flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
+                    (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
+                if (use_cp_async) {
+                    cp_async_wait_all();
+                }
+                __syncthreads();
             }
-            __syncthreads();
         }
-        const half2 * tile_V_i =
+        const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;

-
+#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+        constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
 #pragma unroll
-        for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 +=
-            static_assert((
+        for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
+            static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size");
 #pragma unroll
-            for (int k00 = 0; k00 <
-                const int k0 = k00 + (threadIdx.y % np)*
+            for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) {
+                const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;

-
+                T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
+#if defined(LDMATRIX_TRANS_AVAILABLE)
                 load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
-
-
-
-
-
-
-
+#elif defined(AMD_MFMA_AVAILABLE)
+                // MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg].
+                // Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T.
+                // Load with transposed addressing: 4 strided half loads.
+                {
+                    const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2;
+                    const half * xs0_h = (const half *) xs0;
+                    const int stride_h = stride_tile_V * 2; // stride in half units
+                    half * A_h = (half *) A.x;
+#pragma unroll
+                    for (int l = 0; l < 4; ++l) {
+                        A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16];
                     }
                 }
+#else
+                // TODO: Try to transpose tile_V when loading gmem to smem.
+                // Use mma to transpose T_A_VKQ for RDNA.
+                T_A_VKQ A_trans;
+                load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+                mma(A, A_trans, A_identity);
+#endif // defined(LDMATRIX_TRANS_AVAILABLE)
+                if constexpr (T_B_KQ::I == 8) {
+                    mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
+                } else {
+                    // Wide version of VKQ_C is column-major.
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    // AMD matrix C is column-major.
+                    mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
+#else
+                    // swap A and B for CUDA.
+                    mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                }
             }
         }
+#else // Volta
+        constexpr int i0_stride = 2*T_C_VKQ::J;
+#pragma unroll
+        for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
+            static_assert(nbatch_fa % (np*T_A_VKQ::I) == 0, "bad loop size");
+            static_assert(2*T_B_VKQ::J == T_A_VKQ::I, "bad tile sizes");
+#pragma unroll
+            for (int k00 = 0; k00 < nbatch_fa; k00 += np*T_A_VKQ::I) {
+                const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::I;

-
+                T_A_VKQ A; // Transposed in both SRAM and registers, load normally.
+                load_ldmatrix(A, tile_V_i + k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+                mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
+            }
+        }
+#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+
+        if constexpr (nstages <= 1) {
             __syncthreads(); // Only needed if tile_K == tile_V.
         }
     }
 #else
-    GGML_UNUSED_VARS(Q_f2, K_h2, V_h2,
+    GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup,
         scale, slope, logit_softcap, ne01, ne02,
         stride_K, stride_V, stride_mask,
         tile_Q, tile_K, tile_V, tile_mask,
         Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     NO_DEVICE_CODE;
-#endif // TURING_MMA_AVAILABLE
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
 }

-
+#if defined(TURING_MMA_AVAILABLE)
+template<int ncols> struct mma_tile_sizes {
+    using T_A_KQ = tile<16, 8, half2>; // row-major
+    using T_B_KQ = tile<16, 8, half2>; // column-major
+    using T_C_KQ = tile<16, 16, float>; // column-major
+    using T_A_VKQ = tile<16, 8, half2>; // row-major
+    using T_B_VKQ = tile<16, 8, half2>; // column-major
+    using T_C_VKQ = tile<16, 8, half2>; // column-major
+};
+template<> struct mma_tile_sizes<8> {
+    using T_A_KQ = tile<16, 8, half2>; // row-major
+    using T_B_KQ = tile< 8, 8, half2>; // column-major
+    using T_C_KQ = tile<16, 8, float>; // row-major
+    using T_A_VKQ = tile<16, 8, half2>; // row-major
+    using T_B_VKQ = tile< 8, 8, half2>; // column-major
+    using T_C_VKQ = tile<16, 4, half2>; // row-major
+};
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+template<int ncols> struct mma_tile_sizes {
+    using T_A_KQ = tile<16, 8, half2>; // row-major
+    using T_B_KQ = tile<16, 8, half2>; // column-major
+    using T_C_KQ = tile<16, 16, float>; // column-major
+    using T_A_VKQ = tile<16, 8, half2>; // row-major
+    using T_B_VKQ = tile<16, 8, half2>; // column-major
+    using T_C_VKQ = tile<16, 8, half2>; // column-major
+};
+#else // Volta
+template<int ncols> struct mma_tile_sizes {
+    using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
+    using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
+    using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major
+    using T_A_VKQ = tile< 8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED>; // column-major
+    using T_B_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
+    using T_C_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
+};
+#endif // defined(TURING_MMA_AVAILABLE)
+
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup>
 static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const float2 * const __restrict__ Q_f2,
         const half2 * const __restrict__ K_h2,
         const half2 * const __restrict__ V_h2,
-        const
+        const half * const __restrict__ mask_h,
         const float * const __restrict__ sinks_f,
         float2 * const __restrict__ dstk,
         float2 * const __restrict__ dstk_fixup,
         const float scale,
         const float slope,
         const float logit_softcap,
-        const
+        const uint3 ne01,
         const int ne02,
+        const int gqa_ratio,
+        const int ne11,
         const int stride_Q1,
         const int stride_Q2,
         const int stride_K,
         const int stride_V,
         const int stride_mask,
         const int jt,
+        const int zt_gqa,
         const int kb0_start,
         const int kb0_stop) {
-#
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

-
-
-
-
-
-
-
-
-
-    constexpr int
-    constexpr int
-    constexpr int
-    constexpr int
-    constexpr int
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    constexpr int ncols = ncols1 * ncols2;
+    using T_A_KQ = typename mma_tile_sizes<ncols>::T_A_KQ;
+    using T_B_KQ = typename mma_tile_sizes<ncols>::T_B_KQ;
+    using T_C_KQ = typename mma_tile_sizes<ncols>::T_C_KQ;
+    using T_A_VKQ = typename mma_tile_sizes<ncols>::T_A_VKQ;
+    using T_B_VKQ = typename mma_tile_sizes<ncols>::T_B_VKQ;
+    using T_C_VKQ = typename mma_tile_sizes<ncols>::T_C_VKQ;
+
+    constexpr int cols_per_warp = T_B_KQ::I;
+    constexpr int cols_per_thread = get_cols_per_thread();
+    constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
+    constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
+    constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
+    constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
+    constexpr int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols);
+    constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols);
+    constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2);
+
+    if (cols_per_warp > ncols) {
+        NO_DEVICE_CODE;
+        return;
+    }

     static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");

     constexpr int stride_tile_Q = DKQ/2 + 4;
     constexpr int stride_tile_K = nbatch_K2 + 4;

-
-    constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
+    constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
     constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;

     extern __shared__ half2 tile_Q[];
-    half2 * tile_K =
-    half2 * tile_V =
-
-
-
-
-
-
-
+    half2 * tile_K = Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q;
+    half2 * tile_V = nstages > 1 ? tile_K + nbatch_fa * stride_tile_K : tile_K;
+    half * tile_mask = (half *) (nstages > 1 ? tile_V + nbatch_fa * stride_tile_V : tile_V + nbatch_fa * stride_tile_KV_max);
+
+    T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
+#if defined(TURING_MMA_AVAILABLE)
+    T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+    T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
+#else // Volta
+    T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
+#endif // defined(TURING_MMA_AVAILABLE)

     float KQ_rowsum[cols_per_thread] = {0.0f};
     float KQ_max[cols_per_thread];
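The `mma_tile_sizes` templates added above pick per-architecture MMA tile shapes at compile time, with a specialization overriding the defaults for the narrow `ncols == 8` case. A reduced sketch of the pattern; `Tile` and `tile_sizes` are illustrative stand-ins, not the kernel's real `tile` type:

    // Primary template supplies default shapes; a specialization overrides them.
    template <int I_, int J_> struct Tile { static constexpr int I = I_, J = J_; };

    template <int ncols> struct tile_sizes    { using T_B_KQ = Tile<16, 8>; };
    template <>          struct tile_sizes<8> { using T_B_KQ = Tile< 8, 8>; };

    template <int ncols> constexpr int cols_per_warp() {
        using B = typename tile_sizes<ncols>::T_B_KQ;
        return B::I; // 8 Q columns per warp when ncols == 8, else 16
    }
    static_assert(cols_per_warp<8>()  ==  8, "narrow batch uses 8-column tiles");
    static_assert(cols_per_warp<32>() == 16, "default uses 16-column tiles");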
@@ -848,10 +1069,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     // The loading is done with decreasing granularity for D for better memory bandwidth.
     const half2 scale_h2 = make_half2(scale, scale);
 #pragma unroll
-    for (int stride_k : {
-        const int k0_start = stride_k ==
+    for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
+        const int k0_start = stride_k == warp_size ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
         const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k);
-        const int stride_jc =
+        const int stride_jc = warp_size / stride_k;

         if (k0_start == k0_stop) {
             continue;
@@ -859,7 +1080,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(

 #pragma unroll
         for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
-            const int jc = jc0 + threadIdx.y*stride_jc + (stride_k ==
+            const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);

             if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
                 break;
@@ -868,10 +1089,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             const int j = jc / ncols2;
             const int c = jc % ncols2;

-            if (jt*ncols1 + j < ne01) {
+            if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k ==
+                    const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);

                     const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
                     tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
@@ -879,7 +1100,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             } else {
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k ==
+                    const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);

                     tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
                 }
@@ -889,81 +1110,118 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(

     __syncthreads();

-    if (
+    if (Q_in_reg) {
         const int j0 = (threadIdx.y / np) * cols_per_warp;

 #pragma unroll
-        for (int k0 = 0; k0 < DKQ/2; k0 +=
-
-            load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
-        } else {
-#pragma unroll
-            for (int t = 0; t < ntiles/2; ++t) {
-                load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t],
-                    tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q);
-            }
-        }
+        for (int k0 = 0; k0 < DKQ/2; k0 += T_B_KQ::J) {
+            load_ldmatrix(Q_B[k0/T_B_KQ::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
         }
     }

     __syncthreads();

+    int kb0 = kb0_start;
+
     // Preload mask and K data for first iteration when using cp_async with multiple stages:
     if constexpr (nstages > 1) {
         static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
         constexpr bool use_cp_async = true;
-
-
-
+        constexpr bool oob_check = false;
+        constexpr int k_VKQ_sup = nbatch_fa;
+        if (ncols2 > 1 || mask_h) {
+            flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
+                (mask_h + kb0*nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
+        }
+        flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
+            (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
+    }
+
+    // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
+    if constexpr (ncols2 == 1) {
+        constexpr bool oob_check = true;
+        for (; kb0 < kb0_stop-1; ++kb0) {
+            constexpr bool last_iter = false;
+            constexpr int k_VKQ_sup = nbatch_fa;
+            flash_attn_ext_f16_iter
+                <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
+                 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+                (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+                 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+                 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
+        }
+        constexpr bool last_iter = true;
+        const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
+        flash_attn_ext_f16_iter
+            <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
+             T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+            (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
+    } else {
+        constexpr bool oob_check = false;
+        for (; kb0 < kb0_stop-1; ++kb0) {
+            constexpr bool last_iter = false;
+            constexpr int k_VKQ_sup = nbatch_fa;
+            flash_attn_ext_f16_iter
+                <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
+                 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+                (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+                 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+                 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
         }
-        flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-            (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
-    }
-
-    // Iterate over ne11 == previous tokens:
-    int kb0 = kb0_start;
-    for (; kb0 < kb0_stop-1; ++kb0) {
-        constexpr bool last_iter = false;
-        flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
-            (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
-    }
-    { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
         constexpr bool last_iter = true;
-
-
-
+        constexpr int k_VKQ_sup = nbatch_fa;
+        flash_attn_ext_f16_iter
+            <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
+             T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+            (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
     }

     // With multi-stage loading there is no __syncthreads at the end of the iter,
     // there can be a race condition on shared memory access for combining/writing back results.
-    if (nstages > 1 && nwarps*cols_per_warp >
+    if constexpr (nstages > 1 && nwarps*cols_per_warp > nbatch_fa) {
         __syncthreads();
     }

     // Finally, sum up partial KQ rowsums.
-    // The partial sums are spread across 8/4 threads each, does not need full reduce.
     {
-
-
+#if defined(TURING_MMA_AVAILABLE)
+        // The partial sums are spread across 8/4 threads.
+        constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
+        constexpr int offset_last = cols_per_warp == 8 ? 4 : 1;
+#elif defined(AMD_MFMA_AVAILABLE)
+        // The partial sums are spread across 4 threads (wavefront64, 16 cols).
+        constexpr int offset_first = 32;
+        constexpr int offset_last = 16;
+#elif defined(AMD_WMMA_AVAILABLE)
+        // The partial sums are spread across 2 threads.
+        constexpr int offset_first = 16;
+        constexpr int offset_last = 16;
+#else // Volta
+        // The partial sums are spread across 2 threads.
+        constexpr int offset_first = 2;
+        constexpr int offset_last = 2;
+#endif // defined(TURING_MMA_AVAILABLE)
 #pragma unroll
         for (int col = 0; col < cols_per_thread; ++col) {
 #pragma unroll
             for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
-                KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset,
+                KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, warp_size);
             }
         }
     }

     // If attention sinks are used, potentially re-scale if KQ_max is small.
-    // Also add the sink as a value to KQ_rowsum, this is done after
+    // Also add the sink as a value to KQ_rowsum, this is done after synchronization of KQ_rowsum
     // so it's being done unconditionally for every thread.
     if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
         float KQ_max_scale[cols_per_thread];
 #pragma unroll
         for (int col = 0; col < cols_per_thread; ++col) {
-
-            const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
+            const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col);
             const float sink = sinks_f[jc % ncols2];

             const float KQ_max_new = fmaxf(KQ_max[col], sink);
@@ -977,12 +1235,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
         }

-
-
+#if defined(TURING_MMA_AVAILABLE)
+        if constexpr (cols_per_warp == 8) {
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
 #pragma unroll
-            for (int i = 0; i < DV/
+            for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
 #pragma unroll
-                for (int l = 0; l <
+                for (int l = 0; l < T_C_VKQ::ne; ++l) {
                     VKQ_C[i].x[l] *= KQ_max_scale_h2;
                 }
             }
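The sink handling above folds an attention sink into the running softmax state: the sink may raise `KQ_max`, which forces a rescale, and it contributes `exp(sink - KQ_max_new)` to `KQ_rowsum` without adding any value vector. A scalar sketch of that step, with plain floats standing in for the tile registers (illustrative only):

    #include <algorithm>
    #include <cmath>

    // Fold one attention sink into a running softmax state (m, rowsum, acc).
    void apply_sink(float & m, float & rowsum, float * acc, int nv, float sink) {
        const float m_new = std::max(m, sink);   // sink may raise the running max
        const float scale = std::exp(m - m_new); // KQ_max_scale
        rowsum = scale*rowsum + std::exp(sink - m_new); // adds probability mass...
        for (int j = 0; j < nv; ++j) acc[j] *= scale;   // ...but no value vector
        m = m_new;
    }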
@@ -991,30 +1250,49 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         for (int col = 0; col < cols_per_thread; ++col) {
             const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
 #pragma unroll
-            for (int i = 0; i < DV/
+            for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
 #pragma unroll
-                for (int l0 = 0; l0 <
-
+                for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) {
+                    VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2;
                 }
             }
         }
     }
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
+#pragma unroll
+        for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+            for (int l = 0; l < T_C_VKQ::ne; ++l) {
+                VKQ_C[i].x[l] *= KQ_max_scale_h2;
+            }
+        }
+#else // Volta
+        const int col = (threadIdx.x / 2) % 2;
+        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+        for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+            for (int l = 0; l < T_C_VKQ::ne; ++l) {
+                VKQ_C[i].x[l] *= KQ_max_scale_h2;
+            }
+        }
+#endif // defined(TURING_MMA_AVAILABLE)
     }

     // Combine VKQ accumulator values if np > 1.
     // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
     // So also write VKQ accumulators to shared memory in column-major format if np == 1.

-    constexpr int
-    constexpr int tile_stride = nbatch_combine + 4;
+    constexpr int tile_stride = nbatch_combine + 4;
     static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");

-    if constexpr (
-        const int jc_cwmo = (threadIdx.x % (2*
-        const int jc_cwm = threadIdx.y*(2*
+    if constexpr (cols_per_warp == 8) {
+        const int jc_cwmo = (threadIdx.x % (2*T_C_VKQ::J)) / T_C_VKQ::J; // jc combine write meta offset
+        const int jc_cwm = threadIdx.y*(2*T_C_VKQ::J) + 2*T_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta
         const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum

-        if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*
+        if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*T_C_VKQ::J) {
             // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
             ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
         }
@@ -1023,24 +1301,34 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(

         if (np == 1) {
             // No combination is needed, the meta data can be directly written from registers to VRAM.
-            if (needs_fixup && threadIdx.x <
+            if (needs_fixup && threadIdx.x < T_B_KQ::I) {
                 float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
                 dstk_fixup_meta[jc_cwm] = KQ_cmr;
             }
-            if (is_fixup && threadIdx.x <
+            if (is_fixup && threadIdx.x < T_B_KQ::I) {
                 float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
                 dstk_fixup_meta[jc_cwm] = KQ_cmr;
             }
         }
     } else {
-
-
-
-
-        const
-
-
-
+        // jc_cwm = jc combine write meta
+        // KQ_cmr = KQ combine max rowsum
+        // Use the 16 bytes of padding in each Q column to store the meta data: KQ max, KQ rowsum, KQ max scale.
+#if defined(TURING_MMA_AVAILABLE)
+        const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
+        const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
+        const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+        const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);
+        const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);
+        const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;
+#else // Volta
+        const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2);
+        const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]);
+        const bool thread_should_write = T_C_KQ::J == 8 || T_C_KQ::get_j(threadIdx.x & 2) < 8;
+#endif // defined(TURING_MMA_AVAILABLE)
+
+        if (((!needs_fixup && !is_fixup) || np > 1) && thread_should_write) {
             ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
         }

@@ -1048,31 +1336,30 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(

         if (np == 1) {
             // No combination is needed, the meta data can be directly written from registers to VRAM.
-            if (needs_fixup &&
+            if (needs_fixup && thread_should_write) {
                 float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
                 dstk_fixup_meta[jc_cwm] = KQ_cmr;
             }
-            if (is_fixup &&
+            if (is_fixup && thread_should_write) {
                 float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
                 dstk_fixup_meta[jc_cwm] = KQ_cmr;
             }
         }
     }

-    static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles");
     if (np > 1 && threadIdx.y % np == 0) {
         // Combine the meta data for parallel warps via shared memory.
         // Warps with threadIdx.y % np != 0 must NOT return early.
         // All threads must return simultaneously to avoid race conditions with work on the next tile.

-        constexpr int nmeta = np*cols_per_warp >=
+        constexpr int nmeta = np*cols_per_warp >= warp_size ? np*cols_per_warp/warp_size : 1;

-        const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp <
+        const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
         float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
         float2 meta[nmeta];
 #pragma unroll
         for (int imeta = 0; imeta < nmeta; ++imeta) {
-            meta[imeta] = meta_ptr[imeta *
+            meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2];
         }

         float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
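The meta-data combination that begins here merges per-warp `(KQ max, KQ rowsum)` pairs; the same merge is what the fixup pass later applies across CUDA blocks that split one tile. In scalar form, merging two partial softmax states is (a sketch, not the kernel's tile code):

    #include <algorithm>
    #include <cmath>

    struct Meta { float m; float s; }; // running max and rowsum of one partial softmax

    // Merge two partial states into one: rescale both rowsums to the shared max.
    Meta combine(Meta a, Meta b) {
        const float m = std::max(a.m, b.m);
        return { m, a.s*std::exp(a.m - m) + b.s*std::exp(b.m - m) };
    }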
@@ -1082,8 +1369,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
 #pragma unroll
         for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
-            if (offset <
-                KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset,
+            if (offset < warp_size) {
+                KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size));
             }
         }

@@ -1100,8 +1387,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
 #pragma unroll
         for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
-            if (offset <
-                KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset,
+            if (offset < warp_size) {
+                KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size);
             }
         }

@@ -1110,19 +1397,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         // Write back combined meta data:
 #pragma unroll
         for (int imeta = 0; imeta < nmeta; ++imeta) {
-            if (np*cols_per_warp >=
+            if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) {
                 // Combined KQ max scale + rowsum.
-                meta_ptr[imeta *
+                meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
             }
         }

         // Combined KQ max + rowsum.
-        static_assert(cols_per_warp <=
-        if (needs_fixup && (cols_per_warp ==
+        static_assert(cols_per_warp <= warp_size);
+        if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
             float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
             dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
         }
-        if (is_fixup && (cols_per_warp ==
+        if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
             float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
             dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
         }
@@ -1135,32 +1422,29 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(

 #pragma unroll
     for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
-        if (
-            const int jc_cwd = threadIdx.y*
+        if constexpr (cols_per_warp == 8) {
+            const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data
 #pragma unroll
-            for (int
-                const
+            for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) {
+                const T_B_KQ B = get_transposed(VKQ_C[(k00 + k1)/T_B_KQ::J]); // Conversion of C to B matrix puts it in column-major format.

 #pragma unroll
-                for (int l = 0; l <
-                    const int k =
+                for (int l = 0; l < T_B_KQ::ne; ++l) {
+                    const int k = k1 + T_B_KQ::get_j(l);

                     tile_Q[jc_cwd*tile_stride + k] = B.x[l];
                 }
             }
         } else {
+            const int j0 = threadIdx.y*cols_per_warp;
 #pragma unroll
-            for (int
-                const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I;
+            for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) {
 #pragma unroll
-                for (int
-
-
-                    const int j = j0 + tile_C_VKQ_16::get_i(l);
-                    const int k = k0 + tile_C_VKQ_16::get_j(l);
+                for (int l = 0; l < T_C_VKQ::ne; ++l) {
+                    const int j = j0 + T_C_VKQ::get_i(l);
+                    const int k = k1 + T_C_VKQ::get_j(l);

-
-                }
+                    tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l];
                 }
             }
         }
@@ -1173,10 +1457,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));

 #pragma unroll
-    for (int stride_k : {
-        const int k0_start = stride_k ==
+    for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
+        const int k0_start = stride_k == warp_size ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
         const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k);
-        const int stride_jc =
+        const int stride_jc = warp_size / stride_k;

         if (k0_start == k0_stop) {
             continue;
@@ -1184,7 +1468,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(

 #pragma unroll
         for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
-            const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k ==
+            const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);

             if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
                 break;
@@ -1195,14 +1479,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             const int j_dst = jc_dst / ncols2;
             const int c_dst = jc_dst % ncols2;

-            if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
+            if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) {
                 continue;
             }

             const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
 #pragma unroll
             for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                const int k = k0 + (stride_k ==
+                const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);

                 float2 dstk_val = make_float2(0.0f, 0.0f);
 #pragma unroll
@@ -1233,16 +1517,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
     }
 #else
-    GGML_UNUSED_VARS(Q_f2, K_h2, V_h2,
-        scale, slope, logit_softcap, ne01, ne02,
+    GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
+        scale, slope, logit_softcap, ne01, ne02, gqa_ratio,
         stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
         jt, kb0_start, kb0_stop);
     NO_DEVICE_CODE;
-#endif // TURING_MMA_AVAILABLE
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
 }

-template<int DKQ, int DV, int ncols1, int ncols2,
-__launch_bounds__(
+template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>
+__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
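The new `__launch_bounds__` line derives both arguments from shared helper functions. A minimal sketch of the idiom, with template parameters standing in for the `ggml_cuda_fattn_mma_get_*` helpers: the first argument caps the threads per block the kernel may be launched with, the second requests a minimum number of resident blocks per SM, which bounds per-thread register use.

    template <int nthreads, int min_blocks_per_sm>
    __launch_bounds__(nthreads, min_blocks_per_sm)
    static __global__ void bounded_kernel(float * dst) {
        dst[blockIdx.x*nthreads + threadIdx.x] = 0.0f;
    }

    // Example instantiation: 256 threads per block, at least 2 blocks per SM:
    // bounded_kernel<256, 2><<<grid, 256>>>(dst);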
@@ -1258,20 +1542,27 @@ static __global__ void flash_attn_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int32_t ne00, const
+        const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
         const int32_t nb01, const int32_t nb02, const int32_t nb03,
         const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
         const int32_t nb11, const int32_t nb12, const int64_t nb13,
         const int32_t nb21, const int32_t nb22, const int64_t nb23,
         const int32_t ne31, const int32_t ne32, const int32_t ne33,
         const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
+#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))

     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
         NO_DEVICE_CODE;
         return;
     }
+#ifdef VOLTA_MMA_AVAILABLE
+    if (ncols1*ncols2 < 32) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // VOLTA_MMA_AVAILABLE
+
 #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
     if (ncols1*ncols2 > 32) {
         NO_DEVICE_CODE;
@@ -1279,29 +1570,42 @@ static __global__ void flash_attn_ext_f16(
     }
 #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING

-
+#if defined(AMD_WMMA_AVAILABLE)
+    if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // defined(AMD_WMMA_AVAILABLE)

-
+#if defined(AMD_MFMA_AVAILABLE)
+    if (DKQ != 64 && DKQ != 80 && DKQ != 96 && DKQ != 112 && DKQ != 128) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // defined(AMD_MFMA_AVAILABLE)

-
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    constexpr int ncols = ncols1 * ncols2;
+    constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
+    constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
+    constexpr int nwarps = nthreads / warp_size;

     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.

     const int stride_Q1 = nb01 / sizeof(float2);
     const int stride_Q2 = nb02 / sizeof(float2);
     const int stride_K = nb11 / sizeof(half2);
-    const int stride_mask = nb31 / sizeof(
-
-    const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
+    const int stride_mask = nb31 / sizeof(half);

-    const int
-    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
+    const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);

-
+    const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
+    const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
+    const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;

     // kbc == k block continuous, current index in continuous ijk space.
-    int kbc = (blockIdx.x + 0)*(iter_k*iter_j*
-    const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*
+    int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
+    const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;

     // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
     // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1312,41 +1616,39 @@ static __global__ void flash_attn_ext_f16(
     int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);

     while (kbc < kbc_stop && kb0_stop == iter_k) {
-
-        const int
-        const int
+        // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
+        const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
+        const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
+        const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
+        const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;

-        const int
+        const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.

-        const float2 * Q_f2
-        const half2 * K_h2
-        const
-            (const
-        float2 * dstk
+        const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
+        const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
+        const half * mask_h = ncols2 == 1 && !mask ? nullptr :
+            (const half *) (mask + nb33*(sequence % ne33));
+        float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);

-        const half2 * V_h2 =
-        const float * sinks_f = sinks ? (const float *) sinks +
+        const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
+        const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;

-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias,
-
-        const int kb0_start_kernel = kb0_start * kb_niter;
-        int kb0_stop_kernel = kb0_stop * kb_niter;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;

         if (KV_max) {
-
+            kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
         }
-
         constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
         if (kb0_start == 0) {
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
-            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps,
-                (Q_f2, K_h2, V_h2,
-                 ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt,
+            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
+                (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
+                 ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
         } else {
-            constexpr bool needs_fixup = true; // CUDA block is
-            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps,
-                (Q_f2, K_h2, V_h2,
-                 ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt,
+            constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
+            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
+                (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
+                 ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
         }

         kbc += iter_k;
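The `sequence`/`z_KV`/`zt_gqa`/`jt` arithmetic above splits the flattened `kbc` work counter by successive division. The same decomposition written as a plain helper, a sketch mirroring the kernel's arithmetic (`TileCoord` is illustrative):

    // Decompose a flattened counter over (sequence, K/V head, GQA head block,
    // token block, k block) into its components by repeated div/mod.
    struct TileCoord { int sequence, z_KV, zt_gqa, jt, k; };

    TileCoord decompose(int kbc, int iter_k, int iter_j, int iter_z_gqa, int ne12) {
        TileCoord c;
        c.sequence = kbc / (iter_k*iter_j*iter_z_gqa*ne12);
        kbc       -= c.sequence * iter_k*iter_j*iter_z_gqa*ne12;
        c.z_KV     = kbc / (iter_k*iter_j*iter_z_gqa);
        kbc       -= c.z_KV * iter_k*iter_j*iter_z_gqa;
        c.zt_gqa   = kbc / (iter_k*iter_j);
        kbc       -= c.zt_gqa * iter_k*iter_j;
        c.jt       = kbc / iter_k;
        c.k        = kbc % iter_k;
        return c;
    }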
@@ -1360,35 +1662,34 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1360
1662
|
return;
|
|
1361
1663
|
}
|
|
1362
1664
|
|
|
1363
|
-
|
|
1364
|
-
const int
|
|
1365
|
-
const int
|
|
1665
|
+
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.
|
|
1666
|
+
const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
|
|
1667
|
+
const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
|
|
1668
|
+
const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
|
|
1669
|
+
const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
|
|
1366
1670
|
|
|
1367
|
-
const int
|
|
1671
|
+
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
|
|
1368
1672
|
|
|
1369
|
-
const float2 * Q_f2
|
|
1370
|
-
const half2 * K_h2
|
|
1371
|
-
const
|
|
1372
|
-
(const
|
|
1373
|
-
float2 * dstk
|
|
1673
|
+
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
|
|
1674
|
+
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
|
|
1675
|
+
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
|
|
1676
|
+
(const half *) (mask + nb33*(sequence % ne33));
|
|
1677
|
+
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
|
|
1374
1678
|
|
|
1375
|
-
const half2 * V_h2 =
|
|
1376
|
-
const float * sinks_f = sinks ? (const float *) sinks +
|
|
1679
|
+
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
|
|
1680
|
+
const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
|
|
1377
1681
|
|
|
1378
|
-
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias,
|
|
1379
|
-
|
|
1380
|
-
const int kb0_start_kernel = kb0_start * kb_niter;
|
|
1381
|
-
int kb0_stop_kernel = kb0_stop * kb_niter;
|
|
1682
|
+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
|
|
1382
1683
|
|
|
1383
1684
|
if (KV_max) {
|
|
1384
|
-
|
|
1685
|
+
kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
|
|
1385
1686
|
}
|
|
1386
1687
|
|
|
1387
1688
|
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
|
1388
1689
|
constexpr bool needs_fixup = false;
|
|
1389
|
-
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps,
|
|
1390
|
-
(Q_f2, K_h2, V_h2,
|
|
1391
|
-
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt,
|
|
1690
|
+
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
|
|
1691
|
+
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
|
1692
|
+
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
|
|
1392
1693
|
#else
|
|
1393
1694
|
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
|
1394
1695
|
max_bias, m0, m1, n_head_log2, logit_softcap,
|
|
@@ -1400,7 +1701,7 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1400
1701
|
ne31, ne32, ne33,
|
|
1401
1702
|
nb31, nb32, nb33);
|
|
1402
1703
|
NO_DEVICE_CODE;
|
|
1403
|
-
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
|
|
1704
|
+
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
|
|
1404
1705
|
}
|
|
1405
1706
|
|
|
1406
1707
|
template <int DKQ, int DV, int ncols1, int ncols2>
|
|
@@ -1409,69 +1710,69 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     const int id = ggml_cuda_get_device();
     const int cc = ggml_cuda_info().devices[id].cc;
 
-
-
-    const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
+    constexpr int ncols = ncols1 * ncols2;
 
-
-
-
-
-
-
+    const int  nthreads       = ggml_cuda_fattn_mma_get_nthreads      (DKQ, DV, ncols, cc);
+    const int  nbatch_fa      = ggml_cuda_fattn_mma_get_nbatch_fa     (DKQ, DV, ncols, cc);
+    const int  nbatch_K2      = ggml_cuda_fattn_mma_get_nbatch_K2     (DKQ, DV, ncols, cc);
+    const int  nbatch_V2      = ggml_cuda_fattn_mma_get_nbatch_V2     (DKQ, DV, ncols, cc);
+    const int  nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols, cc);
+    const bool Q_in_reg       = ggml_cuda_fattn_mma_get_Q_in_reg      (DKQ, DV, ncols, cc);
+    const int  nstages        = ggml_cuda_fattn_mma_get_nstages       (DKQ, DV, ncols1, ncols2, cc);
 
-
+    const int cols_per_warp  = std::min(ncols, get_cols_per_warp(cc));
+    const int warp_size_host = ggml_cuda_info().devices[ctx.device].warp_size;
+    const int nwarps         = nthreads / warp_size_host;
 
-
-    const int nbatch_V2      = c::get_nbatch_K2_host     (cc, ncols);
-    const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
+    constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
 
-
-
-    static_assert(ncols % cols_per_warp == 0, "bad ncols");
-
-    const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
-    const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
+    const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
+    const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
     const size_t nbytes_shared_Q         = ncols * (DKQ/2 + 4) * sizeof(half2);
-    const size_t nbytes_shared_mask      = ncols1 * (
+    const size_t nbytes_shared_mask      = ncols1 * (nbatch_fa/2 + 4) * sizeof(half2);
     const size_t nbytes_shared_combine   = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
 
     const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
 
-    const size_t nbytes_shared_total = std::max(nbytes_shared_combine,
+    const size_t nbytes_shared_total = std::max(nbytes_shared_combine, Q_in_reg ?
         std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) :
         nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask);
 
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
+#if defined(GGML_USE_HIP)
+    using fattn_kernel_ptr_t = const void*;
+#else
+    using fattn_kernel_ptr_t = fattn_kernel_t;
+#endif // defined(GGML_USE_HIP)
     fattn_kernel_t fattn_kernel;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2,
+        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
 
-#if !defined(
+#if !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
         if (!shared_memory_limit_raised[id]) {
-            CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+            CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
             shared_memory_limit_raised[id] = true;
         }
-#endif // !defined(
+#endif // !defined(GGML_USE_MUSA)
     } else {
         constexpr bool use_logit_softcap = true;
-        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2,
+        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
 
-#if !defined(
+#if !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
         if (!shared_memory_limit_raised[id]) {
-            CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+            CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
             shared_memory_limit_raised[id] = true;
         }
-#endif // !defined(
+#endif // !defined(GGML_USE_MUSA)
     }
 
     launch_fattn<DV, ncols1, ncols2>
-        (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total,
+        (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true, warp_size_host);
 }
 
 
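Two host-side details in the hunk above are easy to miss: requesting more than the default 48 KiB of dynamic shared memory requires an explicit per-kernel opt-in via `cudaFuncSetAttribute`, and HIP's equivalent call takes the kernel as `const void *`, which is what the `fattn_kernel_ptr_t` alias and `reinterpret_cast` paper over. A standalone sketch of the same pattern; the kernel, sizes, and reuse of the `GGML_USE_HIP` macro here are illustrative only:

    #include <cuda_runtime.h>

    __global__ void big_smem_kernel(float * dst) {
        extern __shared__ float buf[]; // dynamic shared memory
        buf[threadIdx.x] = (float) threadIdx.x;
        __syncthreads();
        dst[threadIdx.x] = buf[threadIdx.x];
    }

    int main() {
        const int nbytes_shared = 96*1024; // above the 48 KiB default on most GPUs

    // Mirror the diff's alias: HIP wants const void *, CUDA a typed pointer.
    #if defined(GGML_USE_HIP)
        using kernel_ptr_t = const void *;
    #else
        using kernel_ptr_t = void (*)(float *);
    #endif

        // Opt in once per device before launching with that much shared memory.
        cudaFuncSetAttribute(reinterpret_cast<kernel_ptr_t>(big_smem_kernel),
                             cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared);

        float * dst = nullptr;
        cudaMalloc(&dst, 256*sizeof(float));
        big_smem_kernel<<<1, 256, nbytes_shared>>>(dst);
        cudaDeviceSynchronize();
        cudaFree(dst);
        return 0;
    }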
@@ -1518,3 +1819,10 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
+
+// For GLM 4.7 Flash
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);
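The added lines declare extra ncols1/ncols2 tile shapes for the 576/512 head-size case. `extern DECL_FATTN_MMA_F16_CASE(...)` follows the usual explicit-instantiation split: each case is defined in its own translation unit and merely declared everywhere else, keeping per-file compile times and object sizes bounded. A sketch of that pattern with a hypothetical macro of the same shape (the real macro body lives in the fattn headers):

    // One template, many precompiled cases. DECL_CASE_SKETCH is a
    // hypothetical stand-in for DECL_FATTN_MMA_F16_CASE.
    template <int DKQ, int DV, int ncols1, int ncols2>
    void fattn_case_sketch() {
        // ... dispatch into the kernel for this tile shape ...
    }

    #define DECL_CASE_SKETCH(DKQ, DV, ncols1, ncols2) \
        template void fattn_case_sketch<DKQ, DV, ncols1, ncols2>()

    // In exactly one .cu file (defines the instantiation):
    //     DECL_CASE_SKETCH(576, 512, 4, 4);
    // In headers and every other file (declares it only):
    //     extern DECL_CASE_SKETCH(576, 512, 4, 4);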