whispercpp 1.3.4 → 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +47 -23
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/cli/cli.cpp +121 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +10 -11
- data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
- data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
- data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
- data/ext/sources/examples/talk-llama/llama-context.h +57 -9
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
- data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
- data/ext/sources/examples/talk-llama/llama-model.h +44 -3
- data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
- data/ext/sources/examples/talk-llama/llama.cpp +729 -2
- data/ext/sources/examples/talk-llama/llama.h +152 -14
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +82 -54
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +4 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +190 -12
- data/ext/sources/ggml/src/CMakeLists.txt +82 -11
- data/ext/sources/ggml/src/ggml-alloc.c +124 -41
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
- data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
- data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
- data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- data/ext/sources/ggml/src/ggml-impl.h +67 -6
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +425 -33
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +101 -35
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +119 -2
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +7 -0
- data/whispercpp.gemspec +1 -1
- metadata +287 -34
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -33,7 +33,7 @@ static __global__ void flash_attn_ext_vec(
|
|
|
33
33
|
const float m1,
|
|
34
34
|
const uint32_t n_head_log2,
|
|
35
35
|
const float logit_softcap,
|
|
36
|
-
const int32_t ne00, const
|
|
36
|
+
const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
|
|
37
37
|
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
|
38
38
|
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
|
39
39
|
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
|
@@ -86,11 +86,11 @@ static __global__ void flash_attn_ext_vec(
|
|
|
86
86
|
|
|
87
87
|
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
|
|
88
88
|
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
|
|
89
|
-
#ifdef
|
|
89
|
+
#ifdef V_DOT2_F32_F16_AVAILABLE
|
|
90
90
|
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
|
|
91
91
|
#else
|
|
92
92
|
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
|
|
93
|
-
#endif //
|
|
93
|
+
#endif // V_DOT2_F32_F16_AVAILABLE
|
|
94
94
|
|
|
95
95
|
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
|
96
96
|
|
|
@@ -112,13 +112,13 @@ static __global__ void flash_attn_ext_vec(
|
|
|
112
112
|
|
|
113
113
|
constexpr int ne_KQ = ncols*D;
|
|
114
114
|
constexpr int ne_combine = nwarps*V_cols_per_iter*D;
|
|
115
|
-
#ifdef
|
|
115
|
+
#ifdef V_DOT2_F32_F16_AVAILABLE
|
|
116
116
|
half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
|
|
117
117
|
__shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
|
|
118
118
|
#else
|
|
119
119
|
float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
|
|
120
120
|
__shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
|
|
121
|
-
#endif //
|
|
121
|
+
#endif // V_DOT2_F32_F16_AVAILABLE
|
|
122
122
|
|
|
123
123
|
float KQ_max[ncols];
|
|
124
124
|
float KQ_sum[ncols];
|
|
@@ -129,11 +129,11 @@ static __global__ void flash_attn_ext_vec(
|
|
|
129
129
|
}
|
|
130
130
|
|
|
131
131
|
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
|
|
132
|
-
#ifdef
|
|
132
|
+
#ifdef V_DOT2_F32_F16_AVAILABLE
|
|
133
133
|
half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
|
|
134
134
|
#else
|
|
135
135
|
float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
|
|
136
|
-
#endif //
|
|
136
|
+
#endif // V_DOT2_F32_F16_AVAILABLE
|
|
137
137
|
int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
|
|
138
138
|
float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
|
|
139
139
|
if constexpr (Q_q8_1) {
|
|
@@ -150,12 +150,12 @@ static __global__ void flash_attn_ext_vec(
|
|
|
150
150
|
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
|
|
151
151
|
|
|
152
152
|
// Set memory to zero if out of bounds:
|
|
153
|
-
if (ncols > 1 && ic0 + j >= ne01) {
|
|
153
|
+
if (ncols > 1 && ic0 + j >= int(ne01.z)) {
|
|
154
154
|
#pragma unroll
|
|
155
155
|
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
|
|
156
156
|
const int i = i0 + threadIdx.x;
|
|
157
157
|
|
|
158
|
-
if (i0 + WARP_SIZE <= D/sizeof(int) || i < D/sizeof(int)) {
|
|
158
|
+
if (i0 + WARP_SIZE <= int(D/sizeof(int)) || i < int(D/sizeof(int))) {
|
|
159
159
|
tmp_q_i32[i] = 0;
|
|
160
160
|
}
|
|
161
161
|
}
|
|
@@ -191,7 +191,7 @@ static __global__ void flash_attn_ext_vec(
|
|
|
191
191
|
|
|
192
192
|
__syncthreads();
|
|
193
193
|
} else {
|
|
194
|
-
#ifdef
|
|
194
|
+
#ifdef V_DOT2_F32_F16_AVAILABLE
|
|
195
195
|
const half2 scale_h2 = make_half2(scale, scale);
|
|
196
196
|
#pragma unroll
|
|
197
197
|
for (int j = 0; j < ncols; ++j) {
|
|
@@ -201,7 +201,7 @@ static __global__ void flash_attn_ext_vec(
|
|
|
201
201
|
const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
|
|
202
202
|
|
|
203
203
|
float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
|
|
204
|
-
if (ncols == 1 || ic0 + j < ne01) {
|
|
204
|
+
if (ncols == 1 || ic0 + j < int(ne01.z)) {
|
|
205
205
|
ggml_cuda_memcpy_1<cpy_nb>(tmp, &Q_j[i]);
|
|
206
206
|
ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
|
|
207
207
|
}
|
|
@@ -222,7 +222,7 @@ static __global__ void flash_attn_ext_vec(
|
|
|
222
222
|
#pragma unroll
|
|
223
223
|
for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
|
|
224
224
|
const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
|
|
225
|
-
if (ncols == 1 || ic0 + j < ne01) {
|
|
225
|
+
if (ncols == 1 || ic0 + j < int(ne01.z)) {
|
|
226
226
|
ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]);
|
|
227
227
|
ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);
|
|
228
228
|
}
|
|
@@ -233,7 +233,7 @@ static __global__ void flash_attn_ext_vec(
|
|
|
233
233
|
Q_reg[j][k].y *= scale;
|
|
234
234
|
}
|
|
235
235
|
}
|
|
236
|
-
#endif //
|
|
236
|
+
#endif // V_DOT2_F32_F16_AVAILABLE
|
|
237
237
|
}
|
|
238
238
|
|
|
239
239
|
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
|
|
@@ -266,13 +266,13 @@ static __global__ void flash_attn_ext_vec(
|
|
|
266
266
|
sum = logit_softcap*tanhf(sum);
|
|
267
267
|
}
|
|
268
268
|
|
|
269
|
-
if (mask) {
|
|
269
|
+
if (mask && (ncols == 1 || ic0 + j < int(ne01.z))) {
|
|
270
270
|
sum += slope*__half2float(maskh[j*ne11 + i_KQ]);
|
|
271
271
|
}
|
|
272
272
|
|
|
273
|
-
KQ_max_new[j] = fmaxf(KQ_max_new[j], sum);
|
|
273
|
+
KQ_max_new[j] = fmaxf(KQ_max_new[j], sum + FATTN_KQ_MAX_OFFSET);
|
|
274
274
|
|
|
275
|
-
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == i_KQ_0) {
|
|
275
|
+
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) {
|
|
276
276
|
KQ_reg[j] = sum;
|
|
277
277
|
}
|
|
278
278
|
}
|
|
@@ -291,7 +291,7 @@ static __global__ void flash_attn_ext_vec(
|
|
|
291
291
|
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
|
|
292
292
|
KQ[j*nthreads + tid] = KQ_reg[j];
|
|
293
293
|
|
|
294
|
-
#ifdef
|
|
294
|
+
#ifdef V_DOT2_F32_F16_AVAILABLE
|
|
295
295
|
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
|
296
296
|
#pragma unroll
|
|
297
297
|
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
|
@@ -303,7 +303,7 @@ static __global__ void flash_attn_ext_vec(
|
|
|
303
303
|
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
|
|
304
304
|
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
|
|
305
305
|
}
|
|
306
|
-
#endif //
|
|
306
|
+
#endif // V_DOT2_F32_F16_AVAILABLE
|
|
307
307
|
}
|
|
308
308
|
|
|
309
309
|
#ifndef GGML_USE_HIP
|
|
@@ -314,7 +314,7 @@ static __global__ void flash_attn_ext_vec(
|
|
|
314
314
|
for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
|
|
315
315
|
const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);
|
|
316
316
|
|
|
317
|
-
#ifdef
|
|
317
|
+
#ifdef V_DOT2_F32_F16_AVAILABLE
|
|
318
318
|
half2 KQ_k[ncols];
|
|
319
319
|
#pragma unroll
|
|
320
320
|
for (int j = 0; j < ncols; ++j) {
|
|
@@ -353,7 +353,7 @@ static __global__ void flash_attn_ext_vec(
|
|
|
353
353
|
}
|
|
354
354
|
}
|
|
355
355
|
}
|
|
356
|
-
#endif //
|
|
356
|
+
#endif // V_DOT2_F32_F16_AVAILABLE
|
|
357
357
|
}
|
|
358
358
|
}
|
|
359
359
|
|
|
@@ -374,7 +374,7 @@ static __global__ void flash_attn_ext_vec(
|
|
|
374
374
|
|
|
375
375
|
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
|
|
376
376
|
|
|
377
|
-
#ifdef
|
|
377
|
+
#ifdef V_DOT2_F32_F16_AVAILABLE
|
|
378
378
|
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
|
379
379
|
#pragma unroll
|
|
380
380
|
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
|
@@ -386,7 +386,7 @@ static __global__ void flash_attn_ext_vec(
|
|
|
386
386
|
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
|
|
387
387
|
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
|
|
388
388
|
}
|
|
389
|
-
#endif //
|
|
389
|
+
#endif // V_DOT2_F32_F16_AVAILABLE
|
|
390
390
|
}
|
|
391
391
|
}
|
|
392
392
|
|
|
@@ -412,7 +412,7 @@ static __global__ void flash_attn_ext_vec(
|
|
|
412
412
|
|
|
413
413
|
#pragma unroll
|
|
414
414
|
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
|
415
|
-
if (ncols > 1 && ic0 + j_VKQ >= ne01) {
|
|
415
|
+
if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z)) {
|
|
416
416
|
break;
|
|
417
417
|
}
|
|
418
418
|
|
|
@@ -421,7 +421,7 @@ static __global__ void flash_attn_ext_vec(
|
|
|
421
421
|
const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
|
|
422
422
|
KQ_max[j_VKQ] = kqmax_new;
|
|
423
423
|
|
|
424
|
-
#ifdef
|
|
424
|
+
#ifdef V_DOT2_F32_F16_AVAILABLE
|
|
425
425
|
half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
|
|
426
426
|
+ (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
|
|
427
427
|
|
|
@@ -452,7 +452,7 @@ static __global__ void flash_attn_ext_vec(
|
|
|
452
452
|
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
|
|
453
453
|
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
|
|
454
454
|
}
|
|
455
|
-
#endif //
|
|
455
|
+
#endif // V_DOT2_F32_F16_AVAILABLE
|
|
456
456
|
|
|
457
457
|
KQ_sum[j_VKQ] *= kqmax_scale;
|
|
458
458
|
KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
|
|
@@ -479,7 +479,7 @@ static __global__ void flash_attn_ext_vec(
|
|
|
479
479
|
if (gridDim.y == 1) {
|
|
480
480
|
dst_val /= KQ_sum[j_VKQ];
|
|
481
481
|
}
|
|
482
|
-
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val;
|
|
482
|
+
dst[(((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val;
|
|
483
483
|
}
|
|
484
484
|
}
|
|
485
485
|
|
|
@@ -489,8 +489,8 @@ static __global__ void flash_attn_ext_vec(
|
|
|
489
489
|
|
|
490
490
|
}
|
|
491
491
|
|
|
492
|
-
if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < ne01)) {
|
|
493
|
-
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]);
|
|
492
|
+
if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z))) {
|
|
493
|
+
dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]);
|
|
494
494
|
}
|
|
495
495
|
#else
|
|
496
496
|
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
|
@@ -516,8 +516,8 @@ void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggm
|
|
|
516
516
|
const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);
|
|
517
517
|
const int nwarps = nthreads / WARP_SIZE;
|
|
518
518
|
fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;
|
|
519
|
-
|
|
520
|
-
|
|
519
|
+
const bool need_f16_K = type_K == GGML_TYPE_F16;
|
|
520
|
+
const bool need_f16_V = type_V == GGML_TYPE_F16;
|
|
521
521
|
constexpr size_t nbytes_shared = 0;
|
|
522
522
|
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
|
523
523
|
}
|
|
@@ -526,17 +526,10 @@ template <int D, ggml_type type_K, ggml_type type_V>
|
|
|
526
526
|
void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
527
527
|
const ggml_tensor * KQV = dst;
|
|
528
528
|
const ggml_tensor * Q = dst->src[0];
|
|
529
|
-
const ggml_tensor * K = dst->src[1];
|
|
530
|
-
const ggml_tensor * V = dst->src[2];
|
|
531
|
-
|
|
532
|
-
GGML_ASSERT(K->type == type_K);
|
|
533
|
-
GGML_ASSERT(V->type == type_V);
|
|
534
529
|
|
|
535
530
|
float logit_softcap;
|
|
536
531
|
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
|
537
532
|
|
|
538
|
-
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
|
539
|
-
|
|
540
533
|
if (Q->ne[1] == 1) {
|
|
541
534
|
constexpr int cols_per_block = 1;
|
|
542
535
|
if (logit_softcap == 0.0f) {
|
|
@@ -6,19 +6,19 @@
|
|
|
6
6
|
#include "fattn-common.cuh"
|
|
7
7
|
#include "fattn-wmma-f16.cuh"
|
|
8
8
|
|
|
9
|
-
#ifdef
|
|
9
|
+
#ifdef GGML_USE_WMMA_FATTN
|
|
10
10
|
#if !defined(GGML_USE_HIP)
|
|
11
11
|
#include <mma.h>
|
|
12
|
-
#
|
|
12
|
+
#if defined(GGML_USE_MUSA)
|
|
13
13
|
namespace wmma = mtmusa::wmma;
|
|
14
14
|
#else // GGML_USE_MUSA
|
|
15
15
|
namespace wmma = nvcuda::wmma;
|
|
16
16
|
#endif // GGML_USE_MUSA
|
|
17
|
-
#elif defined(
|
|
17
|
+
#elif defined(GGML_USE_HIP)
|
|
18
18
|
#include <rocwmma/rocwmma.hpp>
|
|
19
19
|
namespace wmma = rocwmma;
|
|
20
20
|
#endif // !defined(GGML_USE_HIP)
|
|
21
|
-
#endif //
|
|
21
|
+
#endif // GGML_USE_WMMA_FATTN
|
|
22
22
|
|
|
23
23
|
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
|
24
24
|
template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
|
|
@@ -38,14 +38,14 @@ static __global__ void flash_attn_ext_f16(
|
|
|
38
38
|
const float m1,
|
|
39
39
|
const uint32_t n_head_log2,
|
|
40
40
|
const float logit_softcap,
|
|
41
|
-
const int32_t ne00, const
|
|
41
|
+
const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
|
|
42
42
|
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
|
43
43
|
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
|
44
44
|
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
|
45
45
|
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
|
46
46
|
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
|
47
47
|
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
|
48
|
-
#if defined(FLASH_ATTN_AVAILABLE) && (
|
|
48
|
+
#if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))
|
|
49
49
|
// Skip unused kernel variants for faster compilation:
|
|
50
50
|
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
|
51
51
|
NO_DEVICE_CODE;
|
|
@@ -149,7 +149,7 @@ static __global__ void flash_attn_ext_f16(
|
|
|
149
149
|
if (i0 + warp_size > D && i >= D) {
|
|
150
150
|
break;
|
|
151
151
|
}
|
|
152
|
-
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
|
|
152
|
+
KQ[j*D_padded + i] = ic0 + j < int(ne01.z) ? Q_f[j*stride_Q + i] * scale : 0.0f;
|
|
153
153
|
}
|
|
154
154
|
}
|
|
155
155
|
|
|
@@ -218,8 +218,9 @@ static __global__ void flash_attn_ext_f16(
|
|
|
218
218
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
|
|
219
219
|
const int k = k0 + threadIdx.x;
|
|
220
220
|
|
|
221
|
-
KQ_f_tmp[k0/warp_size] += mask
|
|
222
|
-
|
|
221
|
+
KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ?
|
|
222
|
+
__half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
|
|
223
|
+
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size] + FATTN_KQ_MAX_OFFSET);
|
|
223
224
|
}
|
|
224
225
|
KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
|
|
225
226
|
|
|
@@ -270,7 +271,7 @@ static __global__ void flash_attn_ext_f16(
|
|
|
270
271
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
|
|
271
272
|
const int k = k0 + threadIdx.x;
|
|
272
273
|
|
|
273
|
-
KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
|
|
274
|
+
KQ2_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
|
|
274
275
|
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
|
|
275
276
|
}
|
|
276
277
|
KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
|
|
@@ -431,7 +432,7 @@ static __global__ void flash_attn_ext_f16(
|
|
|
431
432
|
#pragma unroll
|
|
432
433
|
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
|
433
434
|
const int j_VKQ = j0 + threadIdx.y;
|
|
434
|
-
if (ic0 + j_VKQ >= ne01) {
|
|
435
|
+
if (ic0 + j_VKQ >= int(ne01.z)) {
|
|
435
436
|
return;
|
|
436
437
|
}
|
|
437
438
|
|
|
@@ -442,7 +443,7 @@ static __global__ void flash_attn_ext_f16(
|
|
|
442
443
|
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
|
|
443
444
|
}
|
|
444
445
|
|
|
445
|
-
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
|
446
|
+
const int j_dst_unrolled = ((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
|
446
447
|
|
|
447
448
|
#pragma unroll
|
|
448
449
|
for (int i0 = 0; i0 < D; i0 += warp_size) {
|
|
@@ -481,7 +482,7 @@ static __global__ void flash_attn_ext_f16(
|
|
|
481
482
|
ne31, ne32, ne33,
|
|
482
483
|
nb31, nb32, nb33);
|
|
483
484
|
NO_DEVICE_CODE;
|
|
484
|
-
#endif // defined(FLASH_ATTN_AVAILABLE) && (
|
|
485
|
+
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))
|
|
485
486
|
}
|
|
486
487
|
|
|
487
488
|
constexpr int get_max_power_of_2(int x) {
|
|
@@ -1,3 +1,51 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
1
3
|
#include "common.cuh"
|
|
2
4
|
|
|
5
|
+
#if defined(GGML_USE_MUSA)
|
|
6
|
+
#define GGML_USE_WMMA_FATTN
|
|
7
|
+
#endif // defined(GGML_USE_MUSA)
|
|
8
|
+
|
|
9
|
+
#if defined(GGML_HIP_ROCWMMA_FATTN)
|
|
10
|
+
#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
|
|
11
|
+
#define GGML_USE_WMMA_FATTN
|
|
12
|
+
#elif defined(CDNA)
|
|
13
|
+
#warning "rocwmma fattn on CDNA is broken on rocwmma v2.0.0, expect degraded performance"
|
|
14
|
+
#endif // defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
|
|
15
|
+
#if defined(RDNA3)
|
|
16
|
+
#define GGML_USE_WMMA_FATTN
|
|
17
|
+
#endif // defined(RDNA3)
|
|
18
|
+
#if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
|
|
19
|
+
#define GGML_USE_WMMA_FATTN
|
|
20
|
+
#elif defined(RDNA4)
|
|
21
|
+
#warning "rocwmma fattn is not suported on RDNA4 on rocwmma < v2.0.0, expect degraded performance"
|
|
22
|
+
#endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
|
|
23
|
+
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
|
|
24
|
+
|
|
25
|
+
// WMMA flash attention requires FP16 matrix instructions to be available for ggml code.
|
|
26
|
+
static bool ggml_cuda_should_use_wmma_fattn(const int cc) {
|
|
27
|
+
#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
|
|
28
|
+
return false;
|
|
29
|
+
#else
|
|
30
|
+
if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA) ||
|
|
31
|
+
GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
|
|
32
|
+
return true;
|
|
33
|
+
} else if (GGML_CUDA_CC_IS_CDNA(cc)){
|
|
34
|
+
#if defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
|
|
35
|
+
return true;
|
|
36
|
+
#else
|
|
37
|
+
return false;
|
|
38
|
+
#endif // defined(GGML_HIP_ROCWMMA_FATTN) (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
|
|
39
|
+
} else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
|
|
40
|
+
#if defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
|
|
41
|
+
return true;
|
|
42
|
+
#else
|
|
43
|
+
return false;
|
|
44
|
+
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
|
|
45
|
+
} else {
|
|
46
|
+
return false;
|
|
47
|
+
}
|
|
48
|
+
#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
|
|
49
|
+
}
|
|
50
|
+
|
|
3
51
|
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|
@@ -12,13 +12,13 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
|
|
|
12
12
|
const ggml_tensor * Q = dst->src[0];
|
|
13
13
|
|
|
14
14
|
if constexpr (ncols2 <= 8) {
|
|
15
|
-
if (Q->ne[1] <= 8/ncols2) {
|
|
15
|
+
if (turing_mma_available(cc) && Q->ne[1] <= 8/ncols2) {
|
|
16
16
|
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
|
|
17
17
|
return;
|
|
18
18
|
}
|
|
19
19
|
}
|
|
20
20
|
|
|
21
|
-
if (Q->ne[1] <= 16/ncols2) {
|
|
21
|
+
if (turing_mma_available(cc) && Q->ne[1] <= 16/ncols2) {
|
|
22
22
|
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
|
|
23
23
|
return;
|
|
24
24
|
}
|
|
@@ -36,12 +36,26 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
|
|
|
36
36
|
const ggml_tensor * KQV = dst;
|
|
37
37
|
const ggml_tensor * Q = dst->src[0];
|
|
38
38
|
const ggml_tensor * K = dst->src[1];
|
|
39
|
+
const ggml_tensor * V = dst->src[2];
|
|
39
40
|
const ggml_tensor * mask = dst->src[3];
|
|
40
41
|
|
|
41
42
|
float max_bias = 0.0f;
|
|
42
43
|
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
|
43
44
|
|
|
44
|
-
|
|
45
|
+
// Edge cases like no mask, ALiBi, unpadded K/V, or misaligned addresses for large data transfers
|
|
46
|
+
// are put into the template specialization without GQA optimizations.
|
|
47
|
+
bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
|
48
|
+
for (const ggml_tensor * t : {Q, K, V, mask}) {
|
|
49
|
+
if (t == nullptr) {
|
|
50
|
+
continue;
|
|
51
|
+
}
|
|
52
|
+
for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
|
|
53
|
+
if (t->nb[i] % 16 != 0) {
|
|
54
|
+
use_gqa_opt = false;
|
|
55
|
+
break;
|
|
56
|
+
}
|
|
57
|
+
}
|
|
58
|
+
}
|
|
45
59
|
|
|
46
60
|
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
|
47
61
|
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
|
@@ -116,11 +130,15 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
|
|
116
130
|
}
|
|
117
131
|
}
|
|
118
132
|
|
|
119
|
-
#define FATTN_VEC_CASE(D, type_K, type_V)
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
133
|
+
#define FATTN_VEC_CASE(D, type_K, type_V)                                                    \
    {                                                                                        \
        /* An F32 K/V tensor is also accepted where the F16 instantiation is requested. */   \
        const bool k_matches = (K->type == (type_K)) ||                                      \
                               ((type_K) == GGML_TYPE_F16 && K->type == GGML_TYPE_F32);      \
        const bool v_matches = (V->type == (type_V)) ||                                      \
                               ((type_V) == GGML_TYPE_F16 && V->type == GGML_TYPE_F32);      \
        /* Dispatch to the matching vec-kernel instantiation and leave the caller. */        \
        if ((D) == Q->ne[0] && k_matches && v_matches) {                                     \
            ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst);                 \
            return;                                                                          \
        }                                                                                    \
    }                                                                                        \
|
|
124
142
|
|
|
125
143
|
#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
|
|
126
144
|
FATTN_VEC_CASE( 64, type_K, type_V) \
|
|
@@ -198,6 +216,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|
|
198
216
|
return BEST_FATTN_KERNEL_NONE;
|
|
199
217
|
#endif// FLASH_ATTN_AVAILABLE
|
|
200
218
|
|
|
219
|
+
const ggml_tensor * KQV = dst;
|
|
201
220
|
const ggml_tensor * Q = dst->src[0];
|
|
202
221
|
const ggml_tensor * K = dst->src[1];
|
|
203
222
|
const ggml_tensor * V = dst->src[2];
|
|
@@ -206,31 +225,33 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|
|
206
225
|
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
|
207
226
|
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
|
208
227
|
|
|
228
|
+
float max_bias = 0.0f;
|
|
229
|
+
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
|
230
|
+
|
|
231
|
+
// The effective batch size for the kernel can be increased by gqa_ratio.
|
|
232
|
+
// The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
|
|
233
|
+
const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
|
234
|
+
|
|
209
235
|
const int cc = ggml_cuda_info().devices[device].cc;
|
|
210
236
|
|
|
211
237
|
switch (K->ne[0]) {
|
|
238
|
+
case 40:
|
|
212
239
|
case 64:
|
|
213
|
-
case
|
|
214
|
-
case 256:
|
|
215
|
-
if (V->ne[0] != K->ne[0]) {
|
|
216
|
-
return BEST_FATTN_KERNEL_NONE;
|
|
217
|
-
}
|
|
218
|
-
break;
|
|
240
|
+
case 72:
|
|
219
241
|
case 80:
|
|
220
242
|
case 96:
|
|
243
|
+
case 128:
|
|
221
244
|
case 112:
|
|
245
|
+
case 256:
|
|
222
246
|
if (V->ne[0] != K->ne[0]) {
|
|
223
247
|
return BEST_FATTN_KERNEL_NONE;
|
|
224
248
|
}
|
|
225
|
-
if (!fp16_mma_available(cc) && !turing_mma_available(cc)) {
|
|
226
|
-
return BEST_FATTN_KERNEL_NONE;
|
|
227
|
-
}
|
|
228
249
|
break;
|
|
229
250
|
case 576:
|
|
230
251
|
if (V->ne[0] != 512) {
|
|
231
252
|
return BEST_FATTN_KERNEL_NONE;
|
|
232
253
|
}
|
|
233
|
-
if (!
|
|
254
|
+
if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
|
|
234
255
|
return BEST_FATTN_KERNEL_NONE;
|
|
235
256
|
}
|
|
236
257
|
break;
|
|
@@ -245,6 +266,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|
|
245
266
|
#endif // GGML_CUDA_FA_ALL_QUANTS
|
|
246
267
|
|
|
247
268
|
switch (K->type) {
|
|
269
|
+
case GGML_TYPE_F32:
|
|
248
270
|
case GGML_TYPE_F16:
|
|
249
271
|
break;
|
|
250
272
|
case GGML_TYPE_Q4_1:
|
|
@@ -264,47 +286,71 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|
|
264
286
|
return BEST_FATTN_KERNEL_NONE;
|
|
265
287
|
}
|
|
266
288
|
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
// If Turing tensor cores available, use them except for some cases with batch size 1:
|
|
270
|
-
if (turing_mma_available(cc)) {
|
|
271
|
-
best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16;
|
|
289
|
+
// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
|
|
290
|
+
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
|
272
291
|
|
|
292
|
+
// If Turing tensor cores are available, use them:
|
|
293
|
+
if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
|
|
273
294
|
if (can_use_vector_kernel) {
|
|
274
|
-
if (K->type
|
|
295
|
+
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
|
275
296
|
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
|
|
276
|
-
|
|
297
|
+
return BEST_FATTN_KERNEL_VEC;
|
|
277
298
|
}
|
|
278
299
|
} else {
|
|
279
300
|
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
|
|
280
301
|
if (Q->ne[1] <= 2) {
|
|
281
|
-
|
|
302
|
+
return BEST_FATTN_KERNEL_VEC;
|
|
282
303
|
}
|
|
283
304
|
} else {
|
|
284
305
|
if (Q->ne[1] == 1) {
|
|
285
|
-
|
|
306
|
+
return BEST_FATTN_KERNEL_VEC;
|
|
286
307
|
}
|
|
287
308
|
}
|
|
288
309
|
}
|
|
289
|
-
if (
|
|
290
|
-
|
|
310
|
+
if (!gqa_opt_applies && Q->ne[1] == 1) {
|
|
311
|
+
return BEST_FATTN_KERNEL_VEC;
|
|
291
312
|
}
|
|
292
313
|
}
|
|
293
|
-
|
|
294
|
-
return best;
|
|
314
|
+
return BEST_FATTN_KERNEL_MMA_F16;
|
|
295
315
|
}
|
|
296
316
|
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
317
|
+
if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
|
|
318
|
+
int gqa_ratio_eff = 1;
|
|
319
|
+
const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
|
|
320
|
+
while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
|
|
321
|
+
gqa_ratio_eff *= 2;
|
|
322
|
+
}
|
|
323
|
+
if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) {
|
|
324
|
+
return BEST_FATTN_KERNEL_VEC;
|
|
325
|
+
}
|
|
326
|
+
if (Q->ne[1] * gqa_ratio_eff <= 16) {
|
|
327
|
+
return BEST_FATTN_KERNEL_TILE; // On Volta tensor cores are only faster for sufficiently large matrices.
|
|
328
|
+
}
|
|
329
|
+
return BEST_FATTN_KERNEL_MMA_F16;
|
|
300
330
|
}
|
|
301
331
|
|
|
302
|
-
//
|
|
303
|
-
if (
|
|
332
|
+
// Use the WMMA kernel if possible:
|
|
333
|
+
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
|
|
334
|
+
if (can_use_vector_kernel && Q->ne[1] <= 2) {
|
|
335
|
+
return BEST_FATTN_KERNEL_VEC;
|
|
336
|
+
}
|
|
304
337
|
return BEST_FATTN_KERNEL_WMMA_F16;
|
|
305
338
|
}
|
|
306
339
|
|
|
307
|
-
// If there
|
|
340
|
+
// If there are no tensor cores available, use the generic tile kernel:
|
|
341
|
+
if (can_use_vector_kernel) {
|
|
342
|
+
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
|
343
|
+
if (Q->ne[1] == 1) {
|
|
344
|
+
if (!gqa_opt_applies) {
|
|
345
|
+
return BEST_FATTN_KERNEL_VEC;
|
|
346
|
+
}
|
|
347
|
+
}
|
|
348
|
+
} else {
|
|
349
|
+
if (Q->ne[1] <= 2) {
|
|
350
|
+
return BEST_FATTN_KERNEL_VEC;
|
|
351
|
+
}
|
|
352
|
+
}
|
|
353
|
+
}
|
|
308
354
|
return BEST_FATTN_KERNEL_TILE;
|
|
309
355
|
}
|
|
310
356
|
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
#include "fill.cuh"
|
|
2
|
+
#include "convert.cuh"
|
|
3
|
+
|
|
4
|
+
#define CUDA_FILL_BLOCK_SIZE 256
|
|
5
|
+
|
|
6
|
+
// Writes `value` into dst[0..k).
// Launch layout: 1D grid of 1D blocks, no dynamic shared memory. The grid-stride loop makes the
// result independent of the grid size, so the launcher may clamp gridDim.x below ceil(k / blockDim.x).
template <typename T>
static __global__ void fill_kernel(T * __restrict__ dst, const int64_t k, const T value) {
    const int64_t stride = (int64_t) gridDim.x * blockDim.x;
    for (int64_t i = (int64_t) blockDim.x * blockIdx.x + threadIdx.x; i < k; i += stride) {
        dst[i] = value;
    }
}

// Host-side launcher for fill_kernel on `stream`.
// An empty range is skipped because launching a grid with zero blocks is an invalid configuration.
// The block count is clamped to 65535, which is within the gridDim.x limit on every compute
// capability. The kernel's grid-stride loop covers elements a single pass of the clamped grid misses.
template <typename T>
static void fill_cuda(T * dst, const int64_t k, const T value, cudaStream_t stream) {
    if (k <= 0) {
        return;
    }

    const int64_t max_blocks = 65535;
    int64_t num_blocks = (k + CUDA_FILL_BLOCK_SIZE - 1) / CUDA_FILL_BLOCK_SIZE;
    if (num_blocks > max_blocks) {
        num_blocks = max_blocks;
    }

    fill_kernel<<<(unsigned int) num_blocks, CUDA_FILL_BLOCK_SIZE, 0, stream>>>(dst, k, value);
}

// Fill op: sets every element of the contiguous tensor `dst` to the float stored in the first
// sizeof(float) bytes of dst->op_params. For F16 tensors the value is converted to half first.
// Only F32 and F16 destinations are supported; any other type aborts, even for empty tensors.
void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    void * dst_d = dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(dst));

    float value;
    memcpy(&value, dst->op_params, sizeof(float));

    const int64_t k = ggml_nelements(dst);

    switch (dst->type) {
        case GGML_TYPE_F32:
            fill_cuda((float *) dst_d, k, value, stream);
            break;
        case GGML_TYPE_F16:
            fill_cuda((half *) dst_d, k, ggml_cuda_cast<half>(value), stream);
            break;
        default:
            GGML_ABORT("unsupported type");
    }
}
|