whispercpp 1.3.5 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +99 -2
- data/ext/extconf.rb +1 -0
- data/ext/ruby_whisper.c +20 -4
- data/ext/ruby_whisper.h +30 -2
- data/ext/ruby_whisper_context.c +216 -124
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_params.c +0 -1
- data/ext/ruby_whisper_segment.c +0 -1
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +4 -1
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +8 -0
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/server/server.cpp +18 -4
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
- data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
- data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
- data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
- data/ext/sources/examples/talk-llama/llama-context.h +27 -28
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
- data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
- data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
- data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
- data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
- data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
- data/ext/sources/examples/talk-llama/llama-model.h +72 -19
- data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
- data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
- data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -22
- data/ext/sources/examples/talk-llama/llama.h +63 -30
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
- data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
- data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
- data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
- data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
- data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/models.h +181 -46
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
- data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
- data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
- data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
- data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
- data/ext/sources/ggml/CMakeLists.txt +9 -3
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +6 -1
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +56 -9
- data/ext/sources/ggml/src/CMakeLists.txt +3 -0
- data/ext/sources/ggml/src/ggml-alloc.c +4 -9
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
- data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
- data/ext/sources/ggml/src/ggml-impl.h +62 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
- data/ext/sources/ggml/src/ggml.c +167 -33
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/src/whisper.cpp +6 -28
- data/sig/whisper.rbs +43 -2
- data/test/test_context_params.rb +82 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_whisper.rb +20 -0
- data/whispercpp.gemspec +1 -1
- metadata +240 -28
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
|
@@ -98,6 +98,57 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
|
|
98
98
|
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
|
99
99
|
}
|
|
100
100
|
|
|
101
|
+
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) {
|
|
102
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true);
|
|
103
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
|
104
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
|
105
|
+
|
|
106
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
|
107
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
|
108
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
|
109
|
+
|
|
110
|
+
// TODO tune specifically for RDNA
|
|
111
|
+
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) {
|
|
115
|
+
// Conservative configs for CDNA (MI100+): 64KB LDS, wavefront64, nstages=1 (no cp.async).
|
|
116
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 1, true);
|
|
117
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true);
|
|
118
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true);
|
|
119
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 32, 32, 32, 1, true);
|
|
120
|
+
|
|
121
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 1, true);
|
|
122
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 1, true);
|
|
123
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true);
|
|
124
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 256, 2, 64, 40, 40, 40, 1, true);
|
|
125
|
+
|
|
126
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 1, true);
|
|
127
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 1, true);
|
|
128
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true);
|
|
129
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 256, 2, 64, 48, 48, 48, 1, true);
|
|
130
|
+
|
|
131
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 1, true);
|
|
132
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 1, true);
|
|
133
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true);
|
|
134
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2, 64, 56, 56, 56, 1, true);
|
|
135
|
+
|
|
136
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 1, true);
|
|
137
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 1, true);
|
|
138
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true);
|
|
139
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2, 64, 64, 64, 64, 1, true);
|
|
140
|
+
|
|
141
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 1, true);
|
|
142
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 1, true);
|
|
143
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 1, true);
|
|
144
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2, 32, 128, 128, 128, 1, true);
|
|
145
|
+
|
|
146
|
+
// Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy
|
|
147
|
+
// compile-time static_asserts even though the kernel guard prevents runtime execution.
|
|
148
|
+
// nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility.
|
|
149
|
+
return fattn_mma_config(256, 1, 128, 4, 4, 4, 1, false);
|
|
150
|
+
}
|
|
151
|
+
|
|
101
152
|
static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
|
|
102
153
|
if (ampere_mma_available(cc)) {
|
|
103
154
|
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
|
@@ -105,6 +156,12 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c
|
|
|
105
156
|
if (turing_mma_available(cc)) {
|
|
106
157
|
return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
|
|
107
158
|
}
|
|
159
|
+
if (amd_mfma_available(cc)) {
|
|
160
|
+
return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
|
|
161
|
+
}
|
|
162
|
+
if (amd_wmma_available(cc)) {
|
|
163
|
+
return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
|
|
164
|
+
}
|
|
108
165
|
GGML_ASSERT(volta_mma_available(cc));
|
|
109
166
|
return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
|
|
110
167
|
}
|
|
@@ -114,8 +171,12 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons
|
|
|
114
171
|
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
|
115
172
|
#elif defined(TURING_MMA_AVAILABLE)
|
|
116
173
|
return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
|
|
174
|
+
#elif defined(AMD_MFMA_AVAILABLE)
|
|
175
|
+
return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
|
|
117
176
|
#elif defined(VOLTA_MMA_AVAILABLE)
|
|
118
177
|
return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
|
|
178
|
+
#elif defined(AMD_WMMA_AVAILABLE)
|
|
179
|
+
return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
|
|
119
180
|
#else
|
|
120
181
|
GGML_UNUSED_VARS(DKQ, DV, ncols);
|
|
121
182
|
return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
|
|
@@ -186,6 +247,23 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ,
|
|
|
186
247
|
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
|
|
187
248
|
}
|
|
188
249
|
|
|
250
|
+
static constexpr __device__ int get_cols_per_thread() {
|
|
251
|
+
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
252
|
+
return 1; // AMD has a single column per thread.
|
|
253
|
+
#else
|
|
254
|
+
return 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
|
255
|
+
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
256
|
+
}
|
|
257
|
+
|
|
258
|
+
static __host__ int get_cols_per_warp(const int cc) {
|
|
259
|
+
if (turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc)) {
|
|
260
|
+
return 16;
|
|
261
|
+
} else {
|
|
262
|
+
// Volta
|
|
263
|
+
return 32;
|
|
264
|
+
}
|
|
265
|
+
}
|
|
266
|
+
|
|
189
267
|
// ------------------------------------------------------------------------------------------------------------------
|
|
190
268
|
|
|
191
269
|
static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) {
|
|
@@ -206,6 +284,7 @@ static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, c
|
|
|
206
284
|
template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
|
|
207
285
|
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
208
286
|
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
|
|
287
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
209
288
|
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
|
|
210
289
|
// The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
|
|
211
290
|
if constexpr (use_cp_async) {
|
|
@@ -217,10 +296,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
|
217
296
|
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
|
|
218
297
|
|
|
219
298
|
auto load = [&] __device__ (auto n) {
|
|
220
|
-
const int stride_k =
|
|
221
|
-
const int k0_start = stride_k ==
|
|
299
|
+
const int stride_k = warp_size >> n;
|
|
300
|
+
const int k0_start = stride_k == warp_size ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
|
222
301
|
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
|
|
223
|
-
const int stride_i =
|
|
302
|
+
const int stride_i = warp_size / stride_k;
|
|
224
303
|
|
|
225
304
|
if (k0_start == k0_stop) {
|
|
226
305
|
return;
|
|
@@ -228,7 +307,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
|
228
307
|
|
|
229
308
|
#pragma unroll
|
|
230
309
|
for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
|
|
231
|
-
const int i = i0 + threadIdx.y*stride_i + (stride_k ==
|
|
310
|
+
const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
|
232
311
|
|
|
233
312
|
if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
|
|
234
313
|
break;
|
|
@@ -236,7 +315,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
|
236
315
|
|
|
237
316
|
#pragma unroll
|
|
238
317
|
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
|
239
|
-
const int k = k0 + (stride_k ==
|
|
318
|
+
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
|
240
319
|
|
|
241
320
|
cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
|
|
242
321
|
}
|
|
@@ -252,10 +331,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
|
252
331
|
} else {
|
|
253
332
|
// TODO use ggml_cuda_memcpy_1
|
|
254
333
|
auto load = [&] __device__ (const int n) {
|
|
255
|
-
const int stride_k =
|
|
256
|
-
const int k0_start = stride_k ==
|
|
334
|
+
const int stride_k = warp_size >> n;
|
|
335
|
+
const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k);
|
|
257
336
|
const int k0_stop = D2 - D2 % (1*stride_k);
|
|
258
|
-
const int stride_i =
|
|
337
|
+
const int stride_i = warp_size / stride_k;
|
|
259
338
|
|
|
260
339
|
if (k0_start == k0_stop) {
|
|
261
340
|
return;
|
|
@@ -263,7 +342,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
|
263
342
|
|
|
264
343
|
#pragma unroll
|
|
265
344
|
for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
|
|
266
|
-
const int i = i0 + threadIdx.y*stride_i + (stride_k ==
|
|
345
|
+
const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
|
267
346
|
|
|
268
347
|
if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
|
|
269
348
|
break;
|
|
@@ -271,7 +350,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
|
271
350
|
|
|
272
351
|
#pragma unroll
|
|
273
352
|
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
|
274
|
-
const int k = k0 + (stride_k ==
|
|
353
|
+
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
|
275
354
|
|
|
276
355
|
tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
|
|
277
356
|
}
|
|
@@ -289,18 +368,19 @@ template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_chec
|
|
|
289
368
|
static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
|
290
369
|
const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,
|
|
291
370
|
const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {
|
|
371
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
292
372
|
if constexpr (use_cp_async) {
|
|
293
|
-
static_assert(nbatch_fa <= 8*
|
|
373
|
+
static_assert(nbatch_fa <= 8*warp_size && nbatch_fa % 8 == 0, "bad nbatch_fa");
|
|
294
374
|
static_assert(!oob_check, "OOB check incompatible with cp_async");
|
|
295
375
|
constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
|
|
296
|
-
constexpr int cols_per_warp = 8*
|
|
376
|
+
constexpr int cols_per_warp = 8*warp_size/nbatch_fa;
|
|
297
377
|
constexpr int stride_j = nwarps * cols_per_warp;
|
|
298
378
|
|
|
299
379
|
const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
|
|
300
380
|
|
|
301
381
|
#pragma unroll
|
|
302
382
|
for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
|
|
303
|
-
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (
|
|
383
|
+
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
|
|
304
384
|
const int j_vram = fastmodulo(j0 + j_sram, ne01);
|
|
305
385
|
|
|
306
386
|
if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
|
|
@@ -322,25 +402,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
|
|
322
402
|
}
|
|
323
403
|
|
|
324
404
|
#pragma unroll
|
|
325
|
-
for (int i0 = 0; i0 < nbatch_fa; i0 +=
|
|
405
|
+
for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) {
|
|
326
406
|
const int i = i0 + threadIdx.x;
|
|
327
407
|
|
|
328
408
|
tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
|
|
329
409
|
}
|
|
330
410
|
}
|
|
331
|
-
} else if constexpr (nbatch_fa < 2*
|
|
332
|
-
constexpr int cols_per_warp = 2*
|
|
411
|
+
} else if constexpr (nbatch_fa < 2*warp_size) {
|
|
412
|
+
constexpr int cols_per_warp = 2*warp_size/nbatch_fa;
|
|
333
413
|
constexpr int stride_j = nwarps * cols_per_warp;
|
|
334
414
|
#pragma unroll
|
|
335
415
|
for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
|
|
336
|
-
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (
|
|
416
|
+
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
|
|
337
417
|
const int j_vram = fastmodulo(j0 + j_sram, ne01);
|
|
338
418
|
|
|
339
419
|
if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
|
|
340
420
|
break;
|
|
341
421
|
}
|
|
342
422
|
|
|
343
|
-
const int i = threadIdx.x % (
|
|
423
|
+
const int i = threadIdx.x % (warp_size/cols_per_warp);
|
|
344
424
|
|
|
345
425
|
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
|
|
346
426
|
}
|
|
@@ -355,7 +435,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
|
|
355
435
|
}
|
|
356
436
|
|
|
357
437
|
#pragma unroll
|
|
358
|
-
for (int i0 = 0; i0 < nbatch_fa; i0 += 2*
|
|
438
|
+
for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) {
|
|
359
439
|
const int i = i0 + 2*threadIdx.x;
|
|
360
440
|
|
|
361
441
|
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
|
|
@@ -365,7 +445,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
|
|
365
445
|
}
|
|
366
446
|
|
|
367
447
|
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps,
|
|
368
|
-
bool use_logit_softcap, bool
|
|
448
|
+
bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
|
|
369
449
|
typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ>
|
|
370
450
|
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
371
451
|
const float2 * const __restrict__ Q_f2,
|
|
@@ -393,11 +473,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
393
473
|
const int jt,
|
|
394
474
|
const int kb0,
|
|
395
475
|
const int k_VKQ_sup) {
|
|
396
|
-
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
476
|
+
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
|
477
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
397
478
|
constexpr int ncols = ncols1 * ncols2;
|
|
398
479
|
constexpr int cols_per_warp = T_B_KQ::I;
|
|
399
|
-
constexpr int cols_per_thread =
|
|
400
|
-
constexpr int np = nwarps *
|
|
480
|
+
constexpr int cols_per_thread = get_cols_per_thread();
|
|
481
|
+
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
|
401
482
|
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
|
402
483
|
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
|
|
403
484
|
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
|
|
@@ -407,19 +488,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
407
488
|
constexpr int stride_tile_Q = DKQ/2 + 4;
|
|
408
489
|
constexpr int stride_tile_K = nbatch_K2 + 4;
|
|
409
490
|
|
|
410
|
-
|
|
411
|
-
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
|
491
|
+
constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
|
|
412
492
|
|
|
413
493
|
const int k_VKQ_0 = kb0 * nbatch_fa;
|
|
414
494
|
#if defined(TURING_MMA_AVAILABLE)
|
|
415
495
|
T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
|
|
496
|
+
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
497
|
+
T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
|
|
416
498
|
#else // Volta
|
|
417
499
|
T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
|
|
418
500
|
#endif // defined(TURING_MMA_AVAILABLE)
|
|
419
501
|
|
|
420
502
|
if constexpr (nstages > 1) {
|
|
421
503
|
static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
|
|
422
|
-
static_assert(!
|
|
504
|
+
static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading");
|
|
423
505
|
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
|
|
424
506
|
constexpr bool use_cp_async = true;
|
|
425
507
|
cp_async_wait_all();
|
|
@@ -434,8 +516,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
434
516
|
}
|
|
435
517
|
}
|
|
436
518
|
|
|
519
|
+
// For MLA K and V have the same data.
|
|
520
|
+
// Therefore, iterate over K in reverse and later re-use the data if possible.
|
|
437
521
|
#pragma unroll
|
|
438
|
-
for (int k0_start =
|
|
522
|
+
for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) {
|
|
439
523
|
const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
|
|
440
524
|
const int k0_diff = k0_stop - k0_start;
|
|
441
525
|
|
|
@@ -461,13 +545,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
461
545
|
if constexpr (cols_per_warp == 8) {
|
|
462
546
|
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
|
|
463
547
|
} else {
|
|
464
|
-
// Wide version of KQ_C is column-major
|
|
548
|
+
// Wide version of KQ_C is column-major
|
|
549
|
+
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
550
|
+
// AMD matrix C is column-major.
|
|
551
|
+
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
|
|
552
|
+
#else
|
|
553
|
+
// swap A and B for CUDA.
|
|
465
554
|
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
|
|
555
|
+
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
466
556
|
}
|
|
467
557
|
}
|
|
468
558
|
}
|
|
469
559
|
} else {
|
|
470
|
-
static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
|
|
471
560
|
#pragma unroll
|
|
472
561
|
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
|
|
473
562
|
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
|
|
@@ -479,8 +568,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
479
568
|
T_A_KQ K_A;
|
|
480
569
|
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
|
481
570
|
|
|
482
|
-
|
|
483
|
-
|
|
571
|
+
if constexpr (cols_per_warp == 8) {
|
|
572
|
+
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
|
573
|
+
} else {
|
|
574
|
+
// Wide version of KQ_C is column-major
|
|
575
|
+
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
576
|
+
// AMD matrix C is column-major.
|
|
577
|
+
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
|
578
|
+
#else
|
|
579
|
+
// swap A and B for CUDA.
|
|
580
|
+
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
|
581
|
+
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
582
|
+
}
|
|
484
583
|
}
|
|
485
584
|
}
|
|
486
585
|
}
|
|
@@ -532,7 +631,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
532
631
|
#pragma unroll
|
|
533
632
|
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
|
534
633
|
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
|
|
535
|
-
|
|
634
|
+
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
635
|
+
constexpr int KQ_idx = 0;
|
|
636
|
+
#else
|
|
637
|
+
// Turing + Volta:
|
|
638
|
+
const int KQ_idx = l % 2;
|
|
639
|
+
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
640
|
+
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
|
|
536
641
|
}
|
|
537
642
|
}
|
|
538
643
|
}
|
|
@@ -542,7 +647,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
542
647
|
for (int col = 0; col < cols_per_thread; ++col) {
|
|
543
648
|
#pragma unroll
|
|
544
649
|
for (int offset = 16; offset >= 4; offset >>= 1) {
|
|
545
|
-
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset,
|
|
650
|
+
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
|
|
546
651
|
}
|
|
547
652
|
}
|
|
548
653
|
|
|
@@ -552,8 +657,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
552
657
|
#pragma unroll
|
|
553
658
|
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
|
554
659
|
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
|
|
555
|
-
|
|
556
|
-
|
|
660
|
+
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
661
|
+
constexpr int KQ_idx = 0;
|
|
662
|
+
#else
|
|
663
|
+
// Turing + Volta:
|
|
664
|
+
const int KQ_idx = l % 2;
|
|
665
|
+
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
666
|
+
KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);
|
|
667
|
+
KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
|
|
557
668
|
} else {
|
|
558
669
|
KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f;
|
|
559
670
|
}
|
|
@@ -584,8 +695,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
584
695
|
#pragma unroll
|
|
585
696
|
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
|
586
697
|
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
|
|
698
|
+
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
699
|
+
constexpr int KQ_idx = 0;
|
|
700
|
+
#else
|
|
587
701
|
// Turing + Volta:
|
|
588
|
-
|
|
702
|
+
const int KQ_idx = (l/2) % 2;
|
|
703
|
+
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
704
|
+
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
|
|
589
705
|
}
|
|
590
706
|
}
|
|
591
707
|
}
|
|
@@ -596,14 +712,22 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
596
712
|
// Values per KQ column are spread across 4 threads:
|
|
597
713
|
constexpr int offset_first = 2;
|
|
598
714
|
constexpr int offset_last = 1;
|
|
599
|
-
#
|
|
715
|
+
#elif defined(AMD_MFMA_AVAILABLE)
|
|
716
|
+
// MFMA: 4 threads per Q column (threadIdx.x % 16 == col, spaced by 16).
|
|
717
|
+
constexpr int offset_first = 32;
|
|
718
|
+
constexpr int offset_last = 16;
|
|
719
|
+
#elif defined(AMD_WMMA_AVAILABLE)
|
|
720
|
+
// Values per KQ column are spread across 2 threads:
|
|
721
|
+
constexpr int offset_first = 16;
|
|
722
|
+
constexpr int offset_last = 16;
|
|
723
|
+
#else // Volta
|
|
600
724
|
// Values per KQ column are spread across 2 threads:
|
|
601
725
|
constexpr int offset_first = 2;
|
|
602
726
|
constexpr int offset_last = 2;
|
|
603
727
|
#endif // defined(TURING_MMA_AVAILABLE)
|
|
604
728
|
#pragma unroll
|
|
605
729
|
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
|
|
606
|
-
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset,
|
|
730
|
+
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
|
|
607
731
|
}
|
|
608
732
|
}
|
|
609
733
|
|
|
@@ -612,10 +736,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
612
736
|
for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
|
|
613
737
|
#pragma unroll
|
|
614
738
|
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
|
615
|
-
// Turing + Volta:
|
|
616
739
|
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
|
|
617
|
-
|
|
618
|
-
|
|
740
|
+
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
741
|
+
constexpr int KQ_idx = 0;
|
|
742
|
+
#else
|
|
743
|
+
// Turing + Volta:
|
|
744
|
+
const int KQ_idx = (l/2) % 2;
|
|
745
|
+
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
746
|
+
KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);
|
|
747
|
+
KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
|
|
619
748
|
} else {
|
|
620
749
|
KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f;
|
|
621
750
|
}
|
|
@@ -639,7 +768,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
639
768
|
|
|
640
769
|
#if defined(TURING_MMA_AVAILABLE)
|
|
641
770
|
if constexpr (cols_per_warp == 8) {
|
|
642
|
-
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
|
|
771
|
+
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
|
|
643
772
|
#pragma unroll
|
|
644
773
|
for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
|
|
645
774
|
#pragma unroll
|
|
@@ -660,6 +789,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
660
789
|
}
|
|
661
790
|
}
|
|
662
791
|
}
|
|
792
|
+
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
793
|
+
const half2 KQ_max_scale_h2 = make_half2(
|
|
794
|
+
KQ_max_scale[0], KQ_max_scale[0]);
|
|
795
|
+
#pragma unroll
|
|
796
|
+
for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
|
|
797
|
+
#pragma unroll
|
|
798
|
+
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
|
799
|
+
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
|
800
|
+
}
|
|
801
|
+
}
|
|
663
802
|
#else // Volta
|
|
664
803
|
const half2 KQ_max_scale_h2 = make_half2(
|
|
665
804
|
KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]);
|
|
@@ -688,6 +827,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
688
827
|
}
|
|
689
828
|
|
|
690
829
|
if constexpr (nstages > 1) {
|
|
830
|
+
static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading");
|
|
691
831
|
// Preload K tile for next iteration:
|
|
692
832
|
constexpr bool use_cp_async = true;
|
|
693
833
|
cp_async_wait_all();
|
|
@@ -703,19 +843,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
703
843
|
}
|
|
704
844
|
|
|
705
845
|
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
846
|
+
#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
|
|
847
|
+
T_A_VKQ A_identity;
|
|
848
|
+
make_identity_mat(A_identity);
|
|
849
|
+
#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
|
|
710
850
|
|
|
711
851
|
// Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
|
|
712
852
|
#pragma unroll
|
|
713
|
-
for (int
|
|
714
|
-
|
|
715
|
-
const int
|
|
853
|
+
for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
|
|
854
|
+
static_assert(DV % (2*nbatch_V2) == 0, "bad loop size");
|
|
855
|
+
const int i0_stop = i0_start + 2*nbatch_V2;
|
|
856
|
+
const int i0_diff = i0_stop - i0_start;
|
|
716
857
|
|
|
717
858
|
if constexpr (nstages <= 1) {
|
|
718
|
-
if (
|
|
859
|
+
if (!V_is_K_view || i0_stop > 2*nbatch_K2) {
|
|
719
860
|
constexpr bool use_cp_async = nstages == 1;
|
|
720
861
|
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
|
|
721
862
|
(V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
|
|
@@ -725,9 +866,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
725
866
|
__syncthreads();
|
|
726
867
|
}
|
|
727
868
|
}
|
|
728
|
-
const half2 * tile_V_i =
|
|
869
|
+
const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
|
|
729
870
|
|
|
730
|
-
#if defined(TURING_MMA_AVAILABLE)
|
|
871
|
+
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
731
872
|
constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
|
|
732
873
|
#pragma unroll
|
|
733
874
|
for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
|
|
@@ -737,12 +878,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
737
878
|
const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;
|
|
738
879
|
|
|
739
880
|
T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
|
|
881
|
+
#if defined(LDMATRIX_TRANS_AVAILABLE)
|
|
740
882
|
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
|
883
|
+
#elif defined(AMD_MFMA_AVAILABLE)
|
|
884
|
+
// MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg].
|
|
885
|
+
// Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T.
|
|
886
|
+
// Load with transposed addressing: 4 strided half loads.
|
|
887
|
+
{
|
|
888
|
+
const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2;
|
|
889
|
+
const half * xs0_h = (const half *) xs0;
|
|
890
|
+
const int stride_h = stride_tile_V * 2; // stride in half units
|
|
891
|
+
half * A_h = (half *) A.x;
|
|
892
|
+
#pragma unroll
|
|
893
|
+
for (int l = 0; l < 4; ++l) {
|
|
894
|
+
A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16];
|
|
895
|
+
}
|
|
896
|
+
}
|
|
897
|
+
#else
|
|
898
|
+
// TODO: Try to transpose tile_V when loading gmem to smem.
|
|
899
|
+
// Use mma to transpose T_A_VKQ for RDNA.
|
|
900
|
+
T_A_VKQ A_trans;
|
|
901
|
+
load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
|
902
|
+
mma(A, A_trans, A_identity);
|
|
903
|
+
#endif // defined(LDMATRIX_TRANS_AVAILABLE)
|
|
741
904
|
if constexpr (T_B_KQ::I == 8) {
|
|
742
905
|
mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
|
|
743
906
|
} else {
|
|
744
|
-
// Wide version of VKQ_C is column-major
|
|
907
|
+
// Wide version of VKQ_C is column-major.
|
|
908
|
+
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
909
|
+
// AMD matrix C is column-major.
|
|
910
|
+
mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
|
|
911
|
+
#else
|
|
912
|
+
// swap A and B for CUDA.
|
|
745
913
|
mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
|
|
914
|
+
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
746
915
|
}
|
|
747
916
|
}
|
|
748
917
|
}
|
|
@@ -761,7 +930,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
761
930
|
mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
|
|
762
931
|
}
|
|
763
932
|
}
|
|
764
|
-
#endif // defined(TURING_MMA_AVAILABLE)
|
|
933
|
+
#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
765
934
|
|
|
766
935
|
if constexpr (nstages <= 1) {
|
|
767
936
|
__syncthreads(); // Only needed if tile_K == tile_V.
|
|
@@ -774,7 +943,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
774
943
|
tile_Q, tile_K, tile_V, tile_mask,
|
|
775
944
|
Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
|
776
945
|
NO_DEVICE_CODE;
|
|
777
|
-
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
946
|
+
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
|
778
947
|
}
|
|
779
948
|
|
|
780
949
|
#if defined(TURING_MMA_AVAILABLE)
|
|
@@ -794,6 +963,15 @@ template<> struct mma_tile_sizes<8> {
|
|
|
794
963
|
using T_B_VKQ = tile< 8, 8, half2>; // column-major
|
|
795
964
|
using T_C_VKQ = tile<16, 4, half2>; // row-major
|
|
796
965
|
};
|
|
966
|
+
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
967
|
+
template<int ncols> struct mma_tile_sizes {
|
|
968
|
+
using T_A_KQ = tile<16, 8, half2>; // row-major
|
|
969
|
+
using T_B_KQ = tile<16, 8, half2>; // column-major
|
|
970
|
+
using T_C_KQ = tile<16, 16, float>; // column-major
|
|
971
|
+
using T_A_VKQ = tile<16, 8, half2>; // row-major
|
|
972
|
+
using T_B_VKQ = tile<16, 8, half2>; // column-major
|
|
973
|
+
using T_C_VKQ = tile<16, 8, half2>; // column-major
|
|
974
|
+
};
|
|
797
975
|
#else // Volta
|
|
798
976
|
template<int ncols> struct mma_tile_sizes {
|
|
799
977
|
using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
|
|
@@ -805,7 +983,7 @@ template<int ncols> struct mma_tile_sizes {
|
|
|
805
983
|
};
|
|
806
984
|
#endif // defined(TURING_MMA_AVAILABLE)
|
|
807
985
|
|
|
808
|
-
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool
|
|
986
|
+
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup>
|
|
809
987
|
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
810
988
|
const float2 * const __restrict__ Q_f2,
|
|
811
989
|
const half2 * const __restrict__ K_h2,
|
|
@@ -819,6 +997,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
819
997
|
const float logit_softcap,
|
|
820
998
|
const uint3 ne01,
|
|
821
999
|
const int ne02,
|
|
1000
|
+
const int gqa_ratio,
|
|
822
1001
|
const int ne11,
|
|
823
1002
|
const int stride_Q1,
|
|
824
1003
|
const int stride_Q2,
|
|
@@ -826,11 +1005,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
826
1005
|
const int stride_V,
|
|
827
1006
|
const int stride_mask,
|
|
828
1007
|
const int jt,
|
|
1008
|
+
const int zt_gqa,
|
|
829
1009
|
const int kb0_start,
|
|
830
1010
|
const int kb0_stop) {
|
|
831
|
-
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
1011
|
+
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
|
832
1012
|
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
|
833
1013
|
|
|
1014
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
834
1015
|
constexpr int ncols = ncols1 * ncols2;
|
|
835
1016
|
using T_A_KQ = typename mma_tile_sizes<ncols>::T_A_KQ;
|
|
836
1017
|
using T_B_KQ = typename mma_tile_sizes<ncols>::T_B_KQ;
|
|
@@ -840,8 +1021,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
840
1021
|
using T_C_VKQ = typename mma_tile_sizes<ncols>::T_C_VKQ;
|
|
841
1022
|
|
|
842
1023
|
constexpr int cols_per_warp = T_B_KQ::I;
|
|
843
|
-
constexpr int cols_per_thread =
|
|
844
|
-
constexpr int np = nwarps *
|
|
1024
|
+
constexpr int cols_per_thread = get_cols_per_thread();
|
|
1025
|
+
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
|
845
1026
|
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
|
|
846
1027
|
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
|
|
847
1028
|
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
|
|
@@ -859,8 +1040,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
859
1040
|
constexpr int stride_tile_Q = DKQ/2 + 4;
|
|
860
1041
|
constexpr int stride_tile_K = nbatch_K2 + 4;
|
|
861
1042
|
|
|
862
|
-
|
|
863
|
-
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
|
1043
|
+
constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
|
|
864
1044
|
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
|
|
865
1045
|
|
|
866
1046
|
extern __shared__ half2 tile_Q[];
|
|
@@ -871,6 +1051,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
871
1051
|
T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
|
|
872
1052
|
#if defined(TURING_MMA_AVAILABLE)
|
|
873
1053
|
T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
|
|
1054
|
+
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
1055
|
+
T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
|
|
874
1056
|
#else // Volta
|
|
875
1057
|
T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
|
|
876
1058
|
#endif // defined(TURING_MMA_AVAILABLE)
|
|
@@ -887,10 +1069,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
887
1069
|
// The loading is done with decreasing granularity for D for better memory bandwidth.
|
|
888
1070
|
const half2 scale_h2 = make_half2(scale, scale);
|
|
889
1071
|
#pragma unroll
|
|
890
|
-
for (int stride_k : {
|
|
891
|
-
const int k0_start = stride_k ==
|
|
1072
|
+
for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
|
|
1073
|
+
const int k0_start = stride_k == warp_size ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
|
|
892
1074
|
const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k);
|
|
893
|
-
const int stride_jc =
|
|
1075
|
+
const int stride_jc = warp_size / stride_k;
|
|
894
1076
|
|
|
895
1077
|
if (k0_start == k0_stop) {
|
|
896
1078
|
continue;
|
|
@@ -898,7 +1080,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
898
1080
|
|
|
899
1081
|
#pragma unroll
|
|
900
1082
|
for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
|
|
901
|
-
const int jc = jc0 + threadIdx.y*stride_jc + (stride_k ==
|
|
1083
|
+
const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
|
902
1084
|
|
|
903
1085
|
if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
|
|
904
1086
|
break;
|
|
@@ -907,10 +1089,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
907
1089
|
const int j = jc / ncols2;
|
|
908
1090
|
const int c = jc % ncols2;
|
|
909
1091
|
|
|
910
|
-
if (jt*ncols1 + j < int(ne01.z)) {
|
|
1092
|
+
if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
|
|
911
1093
|
#pragma unroll
|
|
912
1094
|
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
|
913
|
-
const int k = k0 + (stride_k ==
|
|
1095
|
+
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
|
914
1096
|
|
|
915
1097
|
const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
|
|
916
1098
|
tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
|
|
@@ -918,7 +1100,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
918
1100
|
} else {
|
|
919
1101
|
#pragma unroll
|
|
920
1102
|
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
|
921
|
-
const int k = k0 + (stride_k ==
|
|
1103
|
+
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
|
922
1104
|
|
|
923
1105
|
tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
|
|
924
1106
|
}
|
|
@@ -962,7 +1144,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
962
1144
|
constexpr bool last_iter = false;
|
|
963
1145
|
constexpr int k_VKQ_sup = nbatch_fa;
|
|
964
1146
|
flash_attn_ext_f16_iter
|
|
965
|
-
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap,
|
|
1147
|
+
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
|
966
1148
|
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
|
967
1149
|
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
968
1150
|
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
|
@@ -971,7 +1153,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
971
1153
|
constexpr bool last_iter = true;
|
|
972
1154
|
const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
|
|
973
1155
|
flash_attn_ext_f16_iter
|
|
974
|
-
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap,
|
|
1156
|
+
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
|
975
1157
|
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
|
976
1158
|
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
977
1159
|
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
|
@@ -982,7 +1164,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
982
1164
|
constexpr bool last_iter = false;
|
|
983
1165
|
constexpr int k_VKQ_sup = nbatch_fa;
|
|
984
1166
|
flash_attn_ext_f16_iter
|
|
985
|
-
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap,
|
|
1167
|
+
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
|
986
1168
|
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
|
987
1169
|
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
988
1170
|
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
|
@@ -991,7 +1173,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
991
1173
|
constexpr bool last_iter = true;
|
|
992
1174
|
constexpr int k_VKQ_sup = nbatch_fa;
|
|
993
1175
|
flash_attn_ext_f16_iter
|
|
994
|
-
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap,
|
|
1176
|
+
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
|
995
1177
|
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
|
996
1178
|
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
997
1179
|
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
|
@@ -1010,6 +1192,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1010
1192
|
// The partial sums are spread across 8/4 threads.
|
|
1011
1193
|
constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
|
|
1012
1194
|
constexpr int offset_last = cols_per_warp == 8 ? 4 : 1;
|
|
1195
|
+
#elif defined(AMD_MFMA_AVAILABLE)
|
|
1196
|
+
// The partial sums are spread across 4 threads (wavefront64, 16 cols).
|
|
1197
|
+
constexpr int offset_first = 32;
|
|
1198
|
+
constexpr int offset_last = 16;
|
|
1199
|
+
#elif defined(AMD_WMMA_AVAILABLE)
|
|
1200
|
+
// The partial sums are spread across 2 threads.
|
|
1201
|
+
constexpr int offset_first = 16;
|
|
1202
|
+
constexpr int offset_last = 16;
|
|
1013
1203
|
#else // Volta
|
|
1014
1204
|
// The partial sums are spread across 2 threads.
|
|
1015
1205
|
constexpr int offset_first = 2;
|
|
@@ -1019,13 +1209,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1019
1209
|
for (int col = 0; col < cols_per_thread; ++col) {
|
|
1020
1210
|
#pragma unroll
|
|
1021
1211
|
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
|
|
1022
|
-
KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset,
|
|
1212
|
+
KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, warp_size);
|
|
1023
1213
|
}
|
|
1024
1214
|
}
|
|
1025
1215
|
}
|
|
1026
1216
|
|
|
1027
1217
|
// If attention sinks are used, potentially re-scale if KQ_max is small.
|
|
1028
|
-
// Also add the sink as a value to KQ_rowsum, this is done after
|
|
1218
|
+
// Also add the sink as a value to KQ_rowsum, this is done after synchronization of KQ_rowsum
|
|
1029
1219
|
// so it's being done unconditionally for every thread.
|
|
1030
1220
|
if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
|
|
1031
1221
|
float KQ_max_scale[cols_per_thread];
|
|
@@ -1047,7 +1237,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1047
1237
|
|
|
1048
1238
|
#if defined(TURING_MMA_AVAILABLE)
|
|
1049
1239
|
if constexpr (cols_per_warp == 8) {
|
|
1050
|
-
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
|
|
1240
|
+
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
|
|
1051
1241
|
#pragma unroll
|
|
1052
1242
|
for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
|
|
1053
1243
|
#pragma unroll
|
|
@@ -1068,6 +1258,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1068
1258
|
}
|
|
1069
1259
|
}
|
|
1070
1260
|
}
|
|
1261
|
+
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
1262
|
+
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
|
|
1263
|
+
#pragma unroll
|
|
1264
|
+
for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
|
|
1265
|
+
#pragma unroll
|
|
1266
|
+
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
|
1267
|
+
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
|
1268
|
+
}
|
|
1269
|
+
}
|
|
1071
1270
|
#else // Volta
|
|
1072
1271
|
const int col = (threadIdx.x / 2) % 2;
|
|
1073
1272
|
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
|
|
@@ -1119,6 +1318,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1119
1318
|
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
|
|
1120
1319
|
const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
|
|
1121
1320
|
const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
|
|
1321
|
+
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
1322
|
+
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);
|
|
1323
|
+
const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);
|
|
1324
|
+
const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;
|
|
1122
1325
|
#else // Volta
|
|
1123
1326
|
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2);
|
|
1124
1327
|
const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]);
|
|
@@ -1149,14 +1352,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1149
1352
|
// Warps with threadIdx.y % np != 0 must NOT return early.
|
|
1150
1353
|
// All threads must return simultaneously to avoid race conditions with work on the next tile.
|
|
1151
1354
|
|
|
1152
|
-
constexpr int nmeta = np*cols_per_warp >=
|
|
1355
|
+
constexpr int nmeta = np*cols_per_warp >= warp_size ? np*cols_per_warp/warp_size : 1;
|
|
1153
1356
|
|
|
1154
|
-
const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp <
|
|
1357
|
+
const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
|
|
1155
1358
|
float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
|
|
1156
1359
|
float2 meta[nmeta];
|
|
1157
1360
|
#pragma unroll
|
|
1158
1361
|
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
|
1159
|
-
meta[imeta] = meta_ptr[imeta *
|
|
1362
|
+
meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2];
|
|
1160
1363
|
}
|
|
1161
1364
|
|
|
1162
1365
|
float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
|
|
@@ -1166,8 +1369,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1166
1369
|
}
|
|
1167
1370
|
#pragma unroll
|
|
1168
1371
|
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
|
|
1169
|
-
if (offset <
|
|
1170
|
-
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset,
|
|
1372
|
+
if (offset < warp_size) {
|
|
1373
|
+
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size));
|
|
1171
1374
|
}
|
|
1172
1375
|
}
|
|
1173
1376
|
|
|
@@ -1184,8 +1387,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1184
1387
|
}
|
|
1185
1388
|
#pragma unroll
|
|
1186
1389
|
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
|
|
1187
|
-
if (offset <
|
|
1188
|
-
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset,
|
|
1390
|
+
if (offset < warp_size) {
|
|
1391
|
+
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size);
|
|
1189
1392
|
}
|
|
1190
1393
|
}
|
|
1191
1394
|
|
|
@@ -1194,19 +1397,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1194
1397
|
// Write back combined meta data:
|
|
1195
1398
|
#pragma unroll
|
|
1196
1399
|
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
|
1197
|
-
if (np*cols_per_warp >=
|
|
1400
|
+
if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) {
|
|
1198
1401
|
// Combined KQ max scale + rowsum.
|
|
1199
|
-
meta_ptr[imeta *
|
|
1402
|
+
meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
|
|
1200
1403
|
}
|
|
1201
1404
|
}
|
|
1202
1405
|
|
|
1203
1406
|
// Combined KQ max + rowsum.
|
|
1204
|
-
static_assert(cols_per_warp <=
|
|
1205
|
-
if (needs_fixup && (cols_per_warp ==
|
|
1407
|
+
static_assert(cols_per_warp <= warp_size);
|
|
1408
|
+
if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
|
|
1206
1409
|
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
|
|
1207
1410
|
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
|
|
1208
1411
|
}
|
|
1209
|
-
if (is_fixup && (cols_per_warp ==
|
|
1412
|
+
if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
|
|
1210
1413
|
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
|
|
1211
1414
|
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
|
|
1212
1415
|
}
|
|
@@ -1254,10 +1457,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1254
1457
|
float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
|
|
1255
1458
|
|
|
1256
1459
|
#pragma unroll
|
|
1257
|
-
for (int stride_k : {
|
|
1258
|
-
const int k0_start = stride_k ==
|
|
1460
|
+
for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
|
|
1461
|
+
const int k0_start = stride_k == warp_size ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
|
|
1259
1462
|
const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k);
|
|
1260
|
-
const int stride_jc =
|
|
1463
|
+
const int stride_jc = warp_size / stride_k;
|
|
1261
1464
|
|
|
1262
1465
|
if (k0_start == k0_stop) {
|
|
1263
1466
|
continue;
|
|
@@ -1265,7 +1468,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1265
1468
|
|
|
1266
1469
|
#pragma unroll
|
|
1267
1470
|
for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
|
|
1268
|
-
const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k ==
|
|
1471
|
+
const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
|
1269
1472
|
|
|
1270
1473
|
if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
|
|
1271
1474
|
break;
|
|
@@ -1276,14 +1479,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1276
1479
|
const int j_dst = jc_dst / ncols2;
|
|
1277
1480
|
const int c_dst = jc_dst % ncols2;
|
|
1278
1481
|
|
|
1279
|
-
if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) {
|
|
1482
|
+
if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) {
|
|
1280
1483
|
continue;
|
|
1281
1484
|
}
|
|
1282
1485
|
|
|
1283
1486
|
const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
|
|
1284
1487
|
#pragma unroll
|
|
1285
1488
|
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
|
1286
|
-
const int k = k0 + (stride_k ==
|
|
1489
|
+
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
|
1287
1490
|
|
|
1288
1491
|
float2 dstk_val = make_float2(0.0f, 0.0f);
|
|
1289
1492
|
#pragma unroll
|
|
@@ -1315,14 +1518,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1315
1518
|
}
|
|
1316
1519
|
#else
|
|
1317
1520
|
GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
|
|
1318
|
-
scale, slope, logit_softcap, ne01, ne02,
|
|
1521
|
+
scale, slope, logit_softcap, ne01, ne02, gqa_ratio,
|
|
1319
1522
|
stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
|
|
1320
1523
|
jt, kb0_start, kb0_stop);
|
|
1321
1524
|
NO_DEVICE_CODE;
|
|
1322
|
-
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
1525
|
+
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
|
1323
1526
|
}
|
|
1324
1527
|
|
|
1325
|
-
template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool
|
|
1528
|
+
template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>
|
|
1326
1529
|
__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))
|
|
1327
1530
|
static __global__ void flash_attn_ext_f16(
|
|
1328
1531
|
const char * __restrict__ Q,
|
|
@@ -1346,13 +1549,20 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1346
1549
|
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
|
1347
1550
|
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
|
1348
1551
|
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
|
1349
|
-
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
|
|
1552
|
+
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
|
|
1350
1553
|
|
|
1351
1554
|
// Skip unused kernel variants for faster compilation:
|
|
1352
1555
|
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
|
|
1353
1556
|
NO_DEVICE_CODE;
|
|
1354
1557
|
return;
|
|
1355
1558
|
}
|
|
1559
|
+
#ifdef VOLTA_MMA_AVAILABLE
|
|
1560
|
+
if (ncols1*ncols2 < 32) {
|
|
1561
|
+
NO_DEVICE_CODE;
|
|
1562
|
+
return;
|
|
1563
|
+
}
|
|
1564
|
+
#endif // VOLTA_MMA_AVAILABLE
|
|
1565
|
+
|
|
1356
1566
|
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
|
1357
1567
|
if (ncols1*ncols2 > 32) {
|
|
1358
1568
|
NO_DEVICE_CODE;
|
|
@@ -1360,12 +1570,25 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1360
1570
|
}
|
|
1361
1571
|
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
|
1362
1572
|
|
|
1363
|
-
|
|
1573
|
+
#if defined(AMD_WMMA_AVAILABLE)
|
|
1574
|
+
if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) {
|
|
1575
|
+
NO_DEVICE_CODE;
|
|
1576
|
+
return;
|
|
1577
|
+
}
|
|
1578
|
+
#endif // defined(AMD_WMMA_AVAILABLE)
|
|
1579
|
+
|
|
1580
|
+
#if defined(AMD_MFMA_AVAILABLE)
|
|
1581
|
+
if (DKQ != 64 && DKQ != 80 && DKQ != 96 && DKQ != 112 && DKQ != 128) {
|
|
1582
|
+
NO_DEVICE_CODE;
|
|
1583
|
+
return;
|
|
1584
|
+
}
|
|
1585
|
+
#endif // defined(AMD_MFMA_AVAILABLE)
|
|
1364
1586
|
|
|
1587
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
1365
1588
|
constexpr int ncols = ncols1 * ncols2;
|
|
1366
1589
|
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
|
1367
1590
|
constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
|
|
1368
|
-
constexpr int nwarps = nthreads /
|
|
1591
|
+
constexpr int nwarps = nthreads / warp_size;
|
|
1369
1592
|
|
|
1370
1593
|
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
|
1371
1594
|
|
|
@@ -1374,14 +1597,15 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1374
1597
|
const int stride_K = nb11 / sizeof(half2);
|
|
1375
1598
|
const int stride_mask = nb31 / sizeof(half);
|
|
1376
1599
|
|
|
1377
|
-
const int stride_V =
|
|
1600
|
+
const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
|
|
1378
1601
|
|
|
1379
|
-
const int iter_k
|
|
1380
|
-
const int iter_j
|
|
1602
|
+
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
|
1603
|
+
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
|
|
1604
|
+
const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
|
|
1381
1605
|
|
|
1382
1606
|
// kbc == k block continuous, current index in continuous ijk space.
|
|
1383
|
-
int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*
|
|
1384
|
-
const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*
|
|
1607
|
+
int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
|
|
1608
|
+
const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
|
|
1385
1609
|
|
|
1386
1610
|
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
|
1387
1611
|
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
|
@@ -1392,22 +1616,24 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1392
1616
|
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
|
|
1393
1617
|
|
|
1394
1618
|
while (kbc < kbc_stop && kb0_stop == iter_k) {
|
|
1395
|
-
|
|
1396
|
-
const int
|
|
1397
|
-
const int
|
|
1619
|
+
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
|
|
1620
|
+
const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
|
|
1621
|
+
const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
|
|
1622
|
+
const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
|
|
1623
|
+
const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
|
|
1398
1624
|
|
|
1399
|
-
const int
|
|
1625
|
+
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
|
|
1400
1626
|
|
|
1401
|
-
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*
|
|
1402
|
-
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*
|
|
1627
|
+
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
|
|
1628
|
+
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
|
|
1403
1629
|
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
|
|
1404
1630
|
(const half *) (mask + nb33*(sequence % ne33));
|
|
1405
|
-
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 +
|
|
1631
|
+
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
|
|
1406
1632
|
|
|
1407
|
-
const half2 * V_h2 =
|
|
1408
|
-
const float * sinks_f = sinks ? (const float *) sinks +
|
|
1633
|
+
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
|
|
1634
|
+
const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
|
|
1409
1635
|
|
|
1410
|
-
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias,
|
|
1636
|
+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
|
|
1411
1637
|
|
|
1412
1638
|
if (KV_max) {
|
|
1413
1639
|
kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
|
|
@@ -1415,14 +1641,14 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1415
1641
|
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
|
1416
1642
|
if (kb0_start == 0) {
|
|
1417
1643
|
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
|
1418
|
-
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap,
|
|
1644
|
+
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
|
|
1419
1645
|
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
|
1420
|
-
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
|
|
1646
|
+
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
|
|
1421
1647
|
} else {
|
|
1422
1648
|
constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
|
|
1423
|
-
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap,
|
|
1649
|
+
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
|
|
1424
1650
|
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
|
1425
|
-
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
|
|
1651
|
+
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
|
|
1426
1652
|
}
|
|
1427
1653
|
|
|
1428
1654
|
kbc += iter_k;
|
|
@@ -1436,22 +1662,24 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1436
1662
|
return;
|
|
1437
1663
|
}
|
|
1438
1664
|
|
|
1439
|
-
|
|
1440
|
-
const int
|
|
1441
|
-
const int
|
|
1665
|
+
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.
|
|
1666
|
+
const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
|
|
1667
|
+
const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
|
|
1668
|
+
const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
|
|
1669
|
+
const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
|
|
1442
1670
|
|
|
1443
|
-
const int
|
|
1671
|
+
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
|
|
1444
1672
|
|
|
1445
|
-
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*
|
|
1446
|
-
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*
|
|
1673
|
+
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
|
|
1674
|
+
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
|
|
1447
1675
|
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
|
|
1448
1676
|
(const half *) (mask + nb33*(sequence % ne33));
|
|
1449
|
-
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 +
|
|
1677
|
+
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
|
|
1450
1678
|
|
|
1451
|
-
const half2 * V_h2 =
|
|
1452
|
-
const float * sinks_f = sinks ? (const float *) sinks +
|
|
1679
|
+
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
|
|
1680
|
+
const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
|
|
1453
1681
|
|
|
1454
|
-
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias,
|
|
1682
|
+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
|
|
1455
1683
|
|
|
1456
1684
|
if (KV_max) {
|
|
1457
1685
|
kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
|
|
@@ -1459,9 +1687,9 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1459
1687
|
|
|
1460
1688
|
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
|
1461
1689
|
constexpr bool needs_fixup = false;
|
|
1462
|
-
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap,
|
|
1690
|
+
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
|
|
1463
1691
|
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
|
1464
|
-
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
|
|
1692
|
+
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
|
|
1465
1693
|
#else
|
|
1466
1694
|
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
|
1467
1695
|
max_bias, m0, m1, n_head_log2, logit_softcap,
|
|
@@ -1473,7 +1701,7 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1473
1701
|
ne31, ne32, ne33,
|
|
1474
1702
|
nb31, nb32, nb33);
|
|
1475
1703
|
NO_DEVICE_CODE;
|
|
1476
|
-
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
|
|
1704
|
+
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
|
|
1477
1705
|
}
|
|
1478
1706
|
|
|
1479
1707
|
template <int DKQ, int DV, int ncols1, int ncols2>
|
|
@@ -1492,10 +1720,11 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
|
1492
1720
|
const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc);
|
|
1493
1721
|
const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc);
|
|
1494
1722
|
|
|
1495
|
-
const int cols_per_warp = std::min(ncols,
|
|
1496
|
-
const int
|
|
1723
|
+
const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
|
|
1724
|
+
const int warp_size_host = ggml_cuda_info().devices[ctx.device].warp_size;
|
|
1725
|
+
const int nwarps = nthreads / warp_size_host;
|
|
1497
1726
|
|
|
1498
|
-
constexpr bool
|
|
1727
|
+
constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
|
|
1499
1728
|
|
|
1500
1729
|
const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
|
|
1501
1730
|
const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
|
|
@@ -1512,33 +1741,38 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
|
1512
1741
|
float logit_softcap;
|
|
1513
1742
|
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
|
1514
1743
|
|
|
1744
|
+
#if defined(GGML_USE_HIP)
|
|
1745
|
+
using fattn_kernel_ptr_t = const void*;
|
|
1746
|
+
#else
|
|
1747
|
+
using fattn_kernel_ptr_t = fattn_kernel_t;
|
|
1748
|
+
#endif // defined(GGML_USE_HIP)
|
|
1515
1749
|
fattn_kernel_t fattn_kernel;
|
|
1516
1750
|
if (logit_softcap == 0.0f) {
|
|
1517
1751
|
constexpr bool use_logit_softcap = false;
|
|
1518
|
-
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap,
|
|
1752
|
+
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
|
|
1519
1753
|
|
|
1520
|
-
#if !defined(
|
|
1754
|
+
#if !defined(GGML_USE_MUSA)
|
|
1521
1755
|
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
1522
1756
|
if (!shared_memory_limit_raised[id]) {
|
|
1523
|
-
CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
|
|
1757
|
+
CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
|
|
1524
1758
|
shared_memory_limit_raised[id] = true;
|
|
1525
1759
|
}
|
|
1526
|
-
#endif // !defined(
|
|
1760
|
+
#endif // !defined(GGML_USE_MUSA)
|
|
1527
1761
|
} else {
|
|
1528
1762
|
constexpr bool use_logit_softcap = true;
|
|
1529
|
-
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap,
|
|
1763
|
+
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
|
|
1530
1764
|
|
|
1531
|
-
#if !defined(
|
|
1765
|
+
#if !defined(GGML_USE_MUSA)
|
|
1532
1766
|
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
1533
1767
|
if (!shared_memory_limit_raised[id]) {
|
|
1534
|
-
CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
|
|
1768
|
+
CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
|
|
1535
1769
|
shared_memory_limit_raised[id] = true;
|
|
1536
1770
|
}
|
|
1537
|
-
#endif // !defined(
|
|
1771
|
+
#endif // !defined(GGML_USE_MUSA)
|
|
1538
1772
|
}
|
|
1539
1773
|
|
|
1540
1774
|
launch_fattn<DV, ncols1, ncols2>
|
|
1541
|
-
(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true);
|
|
1775
|
+
(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true, warp_size_host);
|
|
1542
1776
|
}
|
|
1543
1777
|
|
|
1544
1778
|
|
|
@@ -1585,3 +1819,10 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
|
|
|
1585
1819
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
|
1586
1820
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
|
1587
1821
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
|
1822
|
+
|
|
1823
|
+
// For GLM 4.7 Flash
|
|
1824
|
+
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
|
1825
|
+
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
|
1826
|
+
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
|
1827
|
+
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
|
|
1828
|
+
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);
|