whispercpp 1.3.3 → 1.3.5
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +79 -25
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/addon.cpp +19 -19
- data/ext/sources/examples/addon.node/index.js +7 -5
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +26 -16
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +122 -111
- data/ext/sources/examples/command/command.cpp +26 -24
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/lsp/lsp.cpp +19 -17
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +34 -24
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +4 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
- data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
- data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
- data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
- data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
- data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
- data/ext/sources/examples/talk-llama/llama-context.h +99 -36
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
- data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
- data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
- data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
- data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
- data/ext/sources/examples/talk-llama/llama-model.h +104 -12
- data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
- data/ext/sources/examples/talk-llama/llama.cpp +794 -12
- data/ext/sources/examples/talk-llama/llama.h +246 -190
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
- data/ext/sources/ggml/CMakeLists.txt +135 -79
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +21 -2
- data/ext/sources/ggml/include/ggml-cpu.h +2 -1
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +406 -23
- data/ext/sources/ggml/src/CMakeLists.txt +99 -13
- data/ext/sources/ggml/src/ggml-alloc.c +368 -161
- data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
- data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
- data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
- data/ext/sources/ggml/src/ggml-common.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
- data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
- data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
- data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
- data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
- data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
- data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
- data/ext/sources/ggml/src/ggml-impl.h +186 -15
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +111 -16
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +901 -129
- data/ext/sources/ggml/src/gguf.cpp +8 -1
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +124 -81
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +7 -5
- data/ext/sources/tests/test-vad.cpp +3 -3
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +126 -2
- data/test/test_params.rb +24 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +8 -1
- data/whispercpp.gemspec +1 -1
- metadata +439 -179
- data/ext/sources/build-xcframework.sh +0 -547
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
|
@@ -10,11 +10,21 @@
|
|
|
10
10
|
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
|
|
11
11
|
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
|
|
12
12
|
|
|
13
|
+
// log(2) = 0.6931, by adding this to the KQ maximum used for the softmax the numerical range representable
|
|
14
|
+
// by the VKQ accumulators is effectively being shifted up by a factor of 2.
|
|
15
|
+
// This reduces issues with numerical overflow but also causes larger values to be flushed to zero.
|
|
16
|
+
// However, as the output from FlashAttention will usually be used as an input for a matrix multiplication this should be negligible.
|
|
17
|
+
// Still, the value range should be shifted as much as necessary but as little as possible.
|
|
18
|
+
// The macro on the following line shifts it by a factor of 2**3=8, as was needed to fix https://github.com/ggml-org/llama.cpp/issues/18606 .
|
|
19
|
+
#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f)
|
|
20
|
+
|
|
13
21
|
typedef void (* fattn_kernel_t)(
|
|
14
22
|
const char * __restrict__ Q,
|
|
15
23
|
const char * __restrict__ K,
|
|
16
24
|
const char * __restrict__ V,
|
|
17
25
|
const char * __restrict__ mask,
|
|
26
|
+
const char * __restrict__ sinks,
|
|
27
|
+
const int * __restrict__ KV_max,
|
|
18
28
|
float * __restrict__ dst,
|
|
19
29
|
float2 * __restrict__ dst_meta,
|
|
20
30
|
const float scale,
|
|
@@ -23,300 +33,238 @@ typedef void (* fattn_kernel_t)(
|
|
|
23
33
|
const float m1,
|
|
24
34
|
const uint32_t n_head_log2,
|
|
25
35
|
const float logit_softcap,
|
|
26
|
-
const
|
|
27
|
-
|
|
28
|
-
const
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
const int nb31,
|
|
36
|
-
const int nb01,
|
|
37
|
-
const int nb02,
|
|
38
|
-
const int nb03,
|
|
39
|
-
const int nb11,
|
|
40
|
-
const int nb12,
|
|
41
|
-
const int nb13,
|
|
42
|
-
const int nb21,
|
|
43
|
-
const int nb22,
|
|
44
|
-
const int nb23,
|
|
45
|
-
const int ne0,
|
|
46
|
-
const int ne1,
|
|
47
|
-
const int ne2,
|
|
48
|
-
const int ne3);
|
|
49
|
-
|
|
50
|
-
typedef half (*vec_dot_KQ_f16_t)(
|
|
51
|
-
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
|
|
52
|
-
typedef float (*vec_dot_KQ_f32_t)(
|
|
36
|
+
const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
|
|
37
|
+
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
|
38
|
+
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
|
39
|
+
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
|
40
|
+
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
|
41
|
+
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
|
42
|
+
const int32_t nb31, const int32_t nb32, const int64_t nb33);
|
|
43
|
+
|
|
44
|
+
typedef float (*vec_dot_KQ_t)(
|
|
53
45
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
|
|
54
46
|
|
|
55
|
-
template<
|
|
56
|
-
static __device__ __forceinline__
|
|
47
|
+
template <int D, int nthreads>
|
|
48
|
+
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
|
|
49
|
+
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
|
|
50
|
+
|
|
51
|
+
const half2 * K_h2 = (const half2 *) K_c;
|
|
52
|
+
GGML_UNUSED(Q_q8);
|
|
53
|
+
GGML_UNUSED(Q_ds_v);
|
|
54
|
+
|
|
55
|
+
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
|
|
56
|
+
constexpr int cpy_ne = cpy_nb / 4;
|
|
57
|
+
|
|
58
|
+
float sum = 0.0f;
|
|
59
|
+
|
|
60
|
+
#pragma unroll
|
|
61
|
+
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
|
|
62
|
+
half2 tmp[cpy_ne];
|
|
63
|
+
ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
|
|
64
|
+
#pragma unroll
|
|
65
|
+
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
|
|
66
|
+
#ifdef V_DOT2_F32_F16_AVAILABLE
|
|
67
|
+
ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
|
68
|
+
#else
|
|
69
|
+
ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
|
70
|
+
#endif // V_DOT2_F32_F16_AVAILABLE
|
|
71
|
+
}
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
return sum;
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
template<int D, int nthreads>
|
|
78
|
+
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0(
|
|
57
79
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
|
58
80
|
|
|
59
81
|
const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
|
|
60
82
|
GGML_UNUSED(Q_v);
|
|
61
83
|
|
|
62
|
-
|
|
84
|
+
float sum = 0.0f;
|
|
63
85
|
|
|
64
86
|
#pragma unroll
|
|
65
|
-
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 +=
|
|
66
|
-
const int k_KQ = k_KQ_0 + threadIdx.x;
|
|
87
|
+
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
|
|
88
|
+
const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
|
|
67
89
|
|
|
68
90
|
const int ib = k_KQ / QI8_1;
|
|
69
91
|
const int iqs4 = k_KQ % QI4_0;
|
|
70
92
|
const int shift = k_KQ & (QI8_1/2);
|
|
71
93
|
|
|
72
|
-
|
|
73
|
-
|
|
94
|
+
int v;
|
|
95
|
+
ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q4_0[ib].qs + sizeof(int)*iqs4);
|
|
96
|
+
v = (v >> shift) & 0x0F0F0F0F;
|
|
97
|
+
const int u = Q_q8[k_KQ_0/nthreads];
|
|
74
98
|
|
|
75
99
|
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
|
76
100
|
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
|
80
|
-
|
|
81
|
-
const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size];
|
|
82
|
-
sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
|
|
83
|
-
} else
|
|
84
|
-
#endif // FP16_AVAILABLE
|
|
85
|
-
{
|
|
86
|
-
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
|
87
|
-
|
|
88
|
-
sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
|
|
89
|
-
}
|
|
101
|
+
const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
|
|
102
|
+
sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x - (8/QI8_1)*Q_ds.y);
|
|
90
103
|
}
|
|
91
104
|
|
|
92
105
|
return sum;
|
|
93
106
|
}
|
|
94
107
|
|
|
95
|
-
template<
|
|
96
|
-
static __device__ __forceinline__
|
|
108
|
+
template<int D, int nthreads>
|
|
109
|
+
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_1(
|
|
97
110
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
|
98
111
|
|
|
99
112
|
const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
|
|
100
113
|
GGML_UNUSED(Q_v);
|
|
101
114
|
|
|
102
|
-
|
|
115
|
+
float sum = 0.0f;
|
|
103
116
|
|
|
104
117
|
#pragma unroll
|
|
105
|
-
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 +=
|
|
106
|
-
const int k_KQ = k_KQ_0 + threadIdx.x;
|
|
118
|
+
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
|
|
119
|
+
const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
|
|
107
120
|
|
|
108
121
|
const int ib = k_KQ / QI8_1;
|
|
109
122
|
const int iqs4 = k_KQ % QI4_1;
|
|
110
123
|
const int shift = k_KQ & (QI8_1/2);
|
|
111
124
|
|
|
112
|
-
|
|
113
|
-
|
|
125
|
+
int v;
|
|
126
|
+
ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q4_1[ib].qs + sizeof(int)*iqs4);
|
|
127
|
+
v = (v >> shift) & 0x0F0F0F0F;
|
|
128
|
+
const int u = Q_q8[k_KQ_0/nthreads];
|
|
114
129
|
|
|
115
130
|
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
|
116
131
|
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
|
120
|
-
|
|
121
|
-
const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size];
|
|
122
|
-
const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
|
|
123
|
-
sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
|
|
124
|
-
} else
|
|
125
|
-
#endif // FP16_AVAILABLE
|
|
126
|
-
{
|
|
127
|
-
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
|
128
|
-
|
|
129
|
-
const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
|
|
130
|
-
const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
|
|
132
|
+
const float2 K_dm = __half22float2(K_q4_1[ib].dm);
|
|
133
|
+
const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
|
|
131
134
|
|
|
132
|
-
|
|
133
|
-
}
|
|
135
|
+
sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1;
|
|
134
136
|
}
|
|
135
137
|
|
|
136
138
|
return sum;
|
|
137
139
|
}
|
|
138
140
|
|
|
139
|
-
template<
|
|
140
|
-
static __device__ __forceinline__
|
|
141
|
+
template<int D, int nthreads>
|
|
142
|
+
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_0(
|
|
141
143
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
|
142
144
|
|
|
143
145
|
const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
|
|
144
146
|
GGML_UNUSED(Q_v);
|
|
145
147
|
|
|
146
|
-
|
|
148
|
+
float sum = 0.0f;
|
|
147
149
|
|
|
148
150
|
#pragma unroll
|
|
149
|
-
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 +=
|
|
150
|
-
const int k_KQ = k_KQ_0 + threadIdx.x;
|
|
151
|
+
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
|
|
152
|
+
const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
|
|
151
153
|
|
|
152
154
|
const int ib = k_KQ / QI8_1;
|
|
153
155
|
const int iqs4 = k_KQ % QI5_0;
|
|
154
156
|
const int iqs8 = k_KQ % QI8_1;
|
|
155
157
|
const int shift = k_KQ & (QI8_1/2);
|
|
156
158
|
|
|
157
|
-
int v
|
|
158
|
-
|
|
159
|
-
v
|
|
160
|
-
v |= (vh << 11) & 0x00001000; // 1 -> 12
|
|
161
|
-
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
|
162
|
-
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
|
159
|
+
int v;
|
|
160
|
+
ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q5_0[ib].qs + sizeof(int)*iqs4);
|
|
161
|
+
v = (v >> shift) & 0x0F0F0F0F;
|
|
163
162
|
|
|
164
|
-
|
|
163
|
+
{
|
|
164
|
+
int vh;
|
|
165
|
+
ggml_cuda_memcpy_1<sizeof(int), 2>(&vh, K_q5_0[ib].qh);
|
|
166
|
+
vh >>= iqs8 * QI5_0;
|
|
167
|
+
|
|
168
|
+
v |= (vh << 4) & 0x00000010; // 0 -> 4
|
|
169
|
+
v |= (vh << 11) & 0x00001000; // 1 -> 12
|
|
170
|
+
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
|
171
|
+
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
|
172
|
+
}
|
|
165
173
|
|
|
166
|
-
const int
|
|
174
|
+
const int u = Q_q8[k_KQ_0/nthreads];
|
|
167
175
|
|
|
168
|
-
|
|
169
|
-
if (std::is_same<T, half>::value) {
|
|
170
|
-
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
|
176
|
+
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
|
171
177
|
|
|
172
|
-
|
|
173
|
-
sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
|
|
174
|
-
} else
|
|
175
|
-
#endif // FP16_AVAILABLE
|
|
176
|
-
{
|
|
177
|
-
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
|
178
|
+
const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
|
|
178
179
|
|
|
179
|
-
|
|
180
|
-
}
|
|
180
|
+
sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x - (16/QI8_1)*Q_ds.y);
|
|
181
181
|
}
|
|
182
182
|
|
|
183
183
|
return sum;
|
|
184
184
|
}
|
|
185
185
|
|
|
186
|
-
template<
|
|
187
|
-
static __device__ __forceinline__
|
|
186
|
+
template<int D, int nthreads>
|
|
187
|
+
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_1(
|
|
188
188
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
|
189
189
|
|
|
190
190
|
const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
|
|
191
191
|
GGML_UNUSED(Q_v);
|
|
192
192
|
|
|
193
|
-
|
|
193
|
+
float sum = 0.0f;
|
|
194
194
|
|
|
195
195
|
#pragma unroll
|
|
196
|
-
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 +=
|
|
197
|
-
const int k_KQ = k_KQ_0 + threadIdx.x;
|
|
196
|
+
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
|
|
197
|
+
const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
|
|
198
198
|
|
|
199
199
|
const int ib = k_KQ / QI8_1;
|
|
200
200
|
const int iqs4 = k_KQ % QI5_1;
|
|
201
201
|
const int iqs8 = k_KQ % QI8_1;
|
|
202
202
|
const int shift = k_KQ & (QI8_1/2);
|
|
203
203
|
|
|
204
|
-
int v
|
|
205
|
-
|
|
206
|
-
v
|
|
207
|
-
v |= (vh << 11) & 0x00001000; // 1 -> 12
|
|
208
|
-
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
|
209
|
-
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
|
210
|
-
|
|
211
|
-
const int u = Q_q8[k_KQ_0/warp_size];
|
|
204
|
+
int v;
|
|
205
|
+
ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q5_1[ib].qs + sizeof(int)*iqs4);
|
|
206
|
+
v = (v >> shift) & 0x0F0F0F0F;
|
|
212
207
|
|
|
213
|
-
|
|
208
|
+
{
|
|
209
|
+
int vh;
|
|
210
|
+
ggml_cuda_memcpy_1<sizeof(int)>(&vh, K_q5_1[ib].qh);
|
|
211
|
+
vh >>= iqs8 * QI5_0;
|
|
212
|
+
|
|
213
|
+
v |= (vh << 4) & 0x00000010; // 0 -> 4
|
|
214
|
+
v |= (vh << 11) & 0x00001000; // 1 -> 12
|
|
215
|
+
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
|
216
|
+
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
|
217
|
+
}
|
|
214
218
|
|
|
215
|
-
|
|
216
|
-
if (std::is_same<T, half>::value) {
|
|
217
|
-
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
|
219
|
+
const int u = Q_q8[k_KQ_0/nthreads];
|
|
218
220
|
|
|
219
|
-
|
|
220
|
-
const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
|
|
221
|
-
sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
|
|
222
|
-
} else
|
|
223
|
-
#endif // FP16_AVAILABLE
|
|
224
|
-
{
|
|
225
|
-
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
|
221
|
+
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
|
226
222
|
|
|
227
|
-
|
|
228
|
-
|
|
223
|
+
const float2 K_dm = __half22float2(K_q5_1[ib].dm);
|
|
224
|
+
const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
|
|
229
225
|
|
|
230
|
-
|
|
231
|
-
}
|
|
226
|
+
sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1;
|
|
232
227
|
}
|
|
233
228
|
|
|
234
229
|
return sum;
|
|
235
230
|
}
|
|
236
231
|
|
|
237
|
-
template <
|
|
238
|
-
static __device__ __forceinline__
|
|
232
|
+
template <int D, int nthreads>
|
|
233
|
+
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q8_0(
|
|
239
234
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
|
240
235
|
|
|
241
236
|
const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
|
|
242
237
|
GGML_UNUSED(Q_v);
|
|
243
238
|
|
|
244
|
-
|
|
239
|
+
float sum = 0.0f;
|
|
245
240
|
|
|
246
241
|
#pragma unroll
|
|
247
|
-
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 +=
|
|
248
|
-
const int k_KQ = k_KQ_0 + threadIdx.x;
|
|
242
|
+
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
|
|
243
|
+
const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
|
|
249
244
|
|
|
250
245
|
const int ib = k_KQ / QI8_0;
|
|
251
246
|
const int iqs = k_KQ % QI8_0;
|
|
252
247
|
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
T Q_d;
|
|
256
|
-
if (std::is_same<T, half>::value) {
|
|
257
|
-
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
|
258
|
-
Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
|
|
259
|
-
} else {
|
|
260
|
-
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
|
261
|
-
Q_d = Q_ds[k_KQ_0/warp_size].x;
|
|
262
|
-
}
|
|
263
|
-
|
|
264
|
-
sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
|
|
265
|
-
}
|
|
266
|
-
|
|
267
|
-
return sum;
|
|
268
|
-
}
|
|
269
|
-
|
|
270
|
-
template <typename T, int D, int warp_size>
|
|
271
|
-
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
|
|
272
|
-
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
|
|
273
|
-
|
|
274
|
-
const half2 * K_h2 = (const half2 *) K_c;
|
|
275
|
-
GGML_UNUSED(Q_q8);
|
|
276
|
-
GGML_UNUSED(Q_ds_v);
|
|
277
|
-
|
|
278
|
-
#ifdef FP16_AVAILABLE
|
|
279
|
-
if (std::is_same<T, half>::value) {
|
|
280
|
-
const half2 * Q_h2 = (const half2 *) Q_v;
|
|
281
|
-
|
|
282
|
-
half2 sum2 = make_half2(0.0f, 0.0f);
|
|
283
|
-
|
|
284
|
-
#pragma unroll
|
|
285
|
-
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
|
|
286
|
-
const int k_KQ = k_KQ_0 + threadIdx.x;
|
|
287
|
-
|
|
288
|
-
const half2 K_ik = K_h2[k_KQ];
|
|
289
|
-
sum2 += K_ik * Q_h2[k_KQ_0/warp_size];
|
|
290
|
-
}
|
|
291
|
-
|
|
292
|
-
return __low2half(sum2) + __high2half(sum2);
|
|
293
|
-
}
|
|
294
|
-
#endif // FP16_AVAILABLE
|
|
248
|
+
int v;
|
|
249
|
+
ggml_cuda_memcpy_1<sizeof(v), 2>(&v, K_q8_0[ib].qs + 4*iqs);
|
|
295
250
|
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
float sum = 0.0f;
|
|
251
|
+
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
|
252
|
+
const float Q_d = Q_ds[k_KQ_0/nthreads].x;
|
|
299
253
|
|
|
300
|
-
|
|
301
|
-
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
|
|
302
|
-
const int k_KQ = k_KQ_0 + threadIdx.x;
|
|
303
|
-
|
|
304
|
-
const half2 K_ik = K_h2[k_KQ];
|
|
305
|
-
sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x;
|
|
306
|
-
sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y;
|
|
254
|
+
sum += vec_dot_q8_0_q8_1_impl<float, 1>(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d);
|
|
307
255
|
}
|
|
308
256
|
|
|
309
257
|
return sum;
|
|
310
258
|
}
|
|
311
259
|
|
|
312
|
-
template <typename Tds>
|
|
260
|
+
template <typename Tds, int ni>
|
|
313
261
|
static __device__ __forceinline__ void quantize_q8_1_to_shared(
|
|
314
262
|
const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {
|
|
315
263
|
|
|
316
264
|
float vals[sizeof(int)] = {0.0f};
|
|
317
265
|
#pragma unroll
|
|
318
266
|
for (int l = 0; l < int(sizeof(int)); ++l) {
|
|
319
|
-
vals[l] = scale * x[4*threadIdx.x + l];
|
|
267
|
+
vals[l] = (ni == WARP_SIZE || threadIdx.x < ni) ? scale * x[4*threadIdx.x + l] : 0.0f;
|
|
320
268
|
}
|
|
321
269
|
|
|
322
270
|
float amax = fabsf(vals[0]);
|
|
@@ -344,7 +292,7 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
|
|
|
344
292
|
}
|
|
345
293
|
|
|
346
294
|
yq32[threadIdx.x] = q32;
|
|
347
|
-
if (threadIdx.x % QI8_1 == 0) {
|
|
295
|
+
if (threadIdx.x % QI8_1 == 0 && (ni == WARP_SIZE || threadIdx.x < ni)) {
|
|
348
296
|
if (std::is_same<Tds, half2>::value) {
|
|
349
297
|
((half2 *) yds)[threadIdx.x/QI8_1] = make_half2(d, sum);
|
|
350
298
|
} else {
|
|
@@ -353,173 +301,336 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
|
|
|
353
301
|
}
|
|
354
302
|
}
|
|
355
303
|
|
|
356
|
-
typedef
|
|
357
|
-
|
|
304
|
+
typedef void (*dequantize_V_t)(const void *, void *, const int64_t);
|
|
305
|
+
|
|
306
|
+
template <typename T, int ne>
|
|
307
|
+
static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
|
308
|
+
if constexpr (std::is_same_v<T, half>) {
|
|
309
|
+
ggml_cuda_memcpy_1<ne*sizeof(half)>(dst, (const half *) vx + i0);
|
|
310
|
+
} else if constexpr (std::is_same_v<T, float>) {
|
|
311
|
+
static_assert(ne % 2 == 0, "bad ne");
|
|
312
|
+
half2 tmp[ne/2];
|
|
313
|
+
ggml_cuda_memcpy_1<ne*sizeof(half)>(tmp, (const half *) vx + i0);
|
|
314
|
+
float2 * dst_f2 = (float2 *) dst;
|
|
315
|
+
#pragma unroll
|
|
316
|
+
for (int l = 0; l < ne/2; ++l) {
|
|
317
|
+
dst_f2[l] = __half22float2(tmp[l]);
|
|
318
|
+
}
|
|
319
|
+
} else {
|
|
320
|
+
static_assert(std::is_same_v<T, void>, "unsupported type");
|
|
321
|
+
}
|
|
322
|
+
}
|
|
358
323
|
|
|
359
|
-
template <typename T>
|
|
360
|
-
static __device__ __forceinline__
|
|
324
|
+
template <typename T, int ne>
|
|
325
|
+
static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
|
361
326
|
const block_q4_0 * x = (const block_q4_0 *) vx;
|
|
362
327
|
|
|
363
|
-
const int64_t ib =
|
|
364
|
-
const int iqs =
|
|
365
|
-
const int shift = (
|
|
328
|
+
const int64_t ib = i0 / QK4_0;
|
|
329
|
+
const int iqs = i0 % (QK4_0/2);
|
|
330
|
+
const int shift = (i0 % QK4_0) / (QK4_0/2);
|
|
331
|
+
|
|
332
|
+
int q;
|
|
333
|
+
static_assert(ne == 2 || ne == 4, "bad ne");
|
|
334
|
+
ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
|
|
335
|
+
q >>= 4*shift;
|
|
336
|
+
q &= 0x0F0F0F0F;
|
|
337
|
+
q = __vsubss4(q, 0x08080808);
|
|
366
338
|
|
|
367
|
-
const
|
|
368
|
-
const int q0 = x[ib].qs[iqs];
|
|
369
|
-
const int q = ((q0 >> (4*shift)) & 0x0F) - 8;
|
|
339
|
+
const int8_t * q8 = (const int8_t *) &q;
|
|
370
340
|
|
|
371
341
|
#ifdef FP16_AVAILABLE
|
|
372
|
-
if (std::
|
|
373
|
-
|
|
374
|
-
|
|
342
|
+
if constexpr (std::is_same_v<T, half>) {
|
|
343
|
+
const half2 d = __half2half2(x[ib].d);
|
|
344
|
+
|
|
345
|
+
#pragma unroll
|
|
346
|
+
for (int l0 = 0; l0 < ne; l0 += 2) {
|
|
347
|
+
((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]);
|
|
348
|
+
}
|
|
349
|
+
} else
|
|
375
350
|
#endif // FP16_AVAILABLE
|
|
351
|
+
if constexpr (std::is_same_v<T, float>) {
|
|
352
|
+
const float d = x[ib].d;
|
|
376
353
|
|
|
377
|
-
|
|
354
|
+
#pragma unroll
|
|
355
|
+
for (int l = 0; l < ne; ++l) {
|
|
356
|
+
((float *) dst)[l] = d * q8[l];
|
|
357
|
+
}
|
|
358
|
+
} else {
|
|
359
|
+
static_assert(std::is_same_v<T, void>, "bad type");
|
|
360
|
+
}
|
|
378
361
|
}
|
|
379
362
|
|
|
380
|
-
template <typename T>
|
|
381
|
-
static __device__ __forceinline__
|
|
363
|
+
template <typename T, int ne>
|
|
364
|
+
static __device__ __forceinline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
|
382
365
|
const block_q4_1 * x = (const block_q4_1 *) vx;
|
|
383
366
|
|
|
384
|
-
const int64_t ib =
|
|
385
|
-
const int iqs =
|
|
386
|
-
const int shift = (
|
|
367
|
+
const int64_t ib = i0 / QK4_1;
|
|
368
|
+
const int iqs = i0 % (QK4_1/2);
|
|
369
|
+
const int shift = (i0 % QK4_1) / (QK4_1/2);
|
|
370
|
+
|
|
371
|
+
int q;
|
|
372
|
+
static_assert(ne == 2 || ne == 4, "bad ne");
|
|
373
|
+
ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs);
|
|
374
|
+
q >>= 4*shift;
|
|
375
|
+
q &= 0x0F0F0F0F;
|
|
387
376
|
|
|
388
|
-
const
|
|
389
|
-
const int q0 = x[ib].qs[iqs];
|
|
390
|
-
const int q = ((q0 >> (4*shift)) & 0x0F);
|
|
377
|
+
const int8_t * q8 = (const int8_t *) &q;
|
|
391
378
|
|
|
392
379
|
#ifdef FP16_AVAILABLE
|
|
393
|
-
if (std::
|
|
394
|
-
|
|
395
|
-
|
|
380
|
+
if constexpr (std::is_same_v<T, half>) {
|
|
381
|
+
const half2 dm = x[ib].dm;
|
|
382
|
+
const half2 d = __half2half2( __low2half(dm));
|
|
383
|
+
const half2 m = __half2half2(__high2half(dm));
|
|
384
|
+
|
|
385
|
+
#pragma unroll
|
|
386
|
+
for (int l0 = 0; l0 < ne; l0 += 2) {
|
|
387
|
+
((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m;
|
|
388
|
+
}
|
|
389
|
+
} else
|
|
396
390
|
#endif // FP16_AVAILABLE
|
|
391
|
+
if constexpr (std::is_same_v<T, float>) {
|
|
392
|
+
const float2 dm = __half22float2(x[ib].dm);
|
|
397
393
|
|
|
398
|
-
|
|
394
|
+
#pragma unroll
|
|
395
|
+
for (int l = 0; l < ne; ++l) {
|
|
396
|
+
((float *) dst)[l] = dm.x * q8[l] + dm.y;
|
|
397
|
+
}
|
|
398
|
+
} else {
|
|
399
|
+
static_assert(std::is_same_v<T, void>, "bad type");
|
|
400
|
+
}
|
|
399
401
|
}
|
|
400
402
|
|
|
401
|
-
template <typename T>
|
|
402
|
-
static __device__ __forceinline__
|
|
403
|
+
template <typename T, int ne>
|
|
404
|
+
static __device__ __forceinline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
|
403
405
|
const block_q5_0 * x = (const block_q5_0 *) vx;
|
|
404
406
|
|
|
405
|
-
const int64_t ib =
|
|
406
|
-
const int idq =
|
|
407
|
-
const int iqs =
|
|
408
|
-
const int shift = (
|
|
407
|
+
const int64_t ib = i0 / QK5_0;
|
|
408
|
+
const int idq = i0 % QK5_0;
|
|
409
|
+
const int iqs = i0 % (QK5_0/2);
|
|
410
|
+
const int shift = (i0 % QK5_0) / (QK5_0/2);
|
|
409
411
|
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
const int q = (ql | qh) - 16;
|
|
412
|
+
int q;
|
|
413
|
+
static_assert(ne == 2 || ne == 4, "bad ne");
|
|
414
|
+
ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
|
|
415
|
+
q >>= 4*shift;
|
|
416
|
+
q &= 0x0F0F0F0F;
|
|
416
417
|
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
418
|
+
{
|
|
419
|
+
int qh;
|
|
420
|
+
ggml_cuda_memcpy_1<ne, 2>(&qh, x[ib].qh);
|
|
421
|
+
#pragma unroll
|
|
422
|
+
for (int l = 0; l < ne; ++l) {
|
|
423
|
+
q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
|
|
424
|
+
}
|
|
420
425
|
}
|
|
426
|
+
|
|
427
|
+
q = __vsubss4(q, 0x10101010);
|
|
428
|
+
|
|
429
|
+
const int8_t * q8 = (const int8_t *) &q;
|
|
430
|
+
|
|
431
|
+
#ifdef FP16_AVAILABLE
|
|
432
|
+
if constexpr (std::is_same_v<T, half>) {
|
|
433
|
+
const half2 d = __half2half2(x[ib].d);
|
|
434
|
+
|
|
435
|
+
#pragma unroll
|
|
436
|
+
for (int l0 = 0; l0 < ne; l0 += 2) {
|
|
437
|
+
((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]);
|
|
438
|
+
}
|
|
439
|
+
} else
|
|
421
440
|
#endif // FP16_AVAILABLE
|
|
441
|
+
if constexpr (std::is_same_v<T, float>) {
|
|
442
|
+
const float d = x[ib].d;
|
|
422
443
|
|
|
423
|
-
|
|
444
|
+
#pragma unroll
|
|
445
|
+
for (int l = 0; l < ne; ++l) {
|
|
446
|
+
((float *) dst)[l] = d * q8[l];
|
|
447
|
+
}
|
|
448
|
+
} else {
|
|
449
|
+
static_assert(std::is_same_v<T, void>, "bad type");
|
|
450
|
+
}
|
|
424
451
|
}
|
|
425
452
|
|
|
426
|
-
template <typename T>
|
|
427
|
-
static __device__ __forceinline__
|
|
453
|
+
template <typename T, int ne>
|
|
454
|
+
static __device__ __forceinline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
|
428
455
|
const block_q5_1 * x = (const block_q5_1 *) vx;
|
|
429
456
|
|
|
430
|
-
const int64_t ib =
|
|
431
|
-
const int idq =
|
|
432
|
-
const int iqs =
|
|
433
|
-
const int shift = (
|
|
457
|
+
const int64_t ib = i0 / QK5_1;
|
|
458
|
+
const int idq = i0 % QK5_1;
|
|
459
|
+
const int iqs = i0 % (QK5_1/2);
|
|
460
|
+
const int shift = (i0 % QK5_1) / (QK5_1/2);
|
|
434
461
|
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
const int q = (ql | qh);
|
|
462
|
+
int q;
|
|
463
|
+
static_assert(ne == 2 || ne == 4, "bad ne");
|
|
464
|
+
ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs);
|
|
465
|
+
q >>= 4*shift;
|
|
466
|
+
q &= 0x0F0F0F0F;
|
|
441
467
|
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
468
|
+
{
|
|
469
|
+
int qh;
|
|
470
|
+
ggml_cuda_memcpy_1<ne>(&qh, x[ib].qh);
|
|
471
|
+
#pragma unroll
|
|
472
|
+
for (int l = 0; l < ne; ++l) {
|
|
473
|
+
q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
|
|
474
|
+
}
|
|
445
475
|
}
|
|
476
|
+
|
|
477
|
+
const int8_t * q8 = (const int8_t *) &q;
|
|
478
|
+
|
|
479
|
+
#ifdef FP16_AVAILABLE
|
|
480
|
+
if constexpr (std::is_same_v<T, half>) {
|
|
481
|
+
const half2 dm = x[ib].dm;
|
|
482
|
+
const half2 d = __half2half2( __low2half(dm));
|
|
483
|
+
const half2 m = __half2half2(__high2half(dm));
|
|
484
|
+
|
|
485
|
+
#pragma unroll
|
|
486
|
+
for (int l0 = 0; l0 < ne; l0 += 2) {
|
|
487
|
+
((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m;
|
|
488
|
+
}
|
|
489
|
+
} else
|
|
446
490
|
#endif // FP16_AVAILABLE
|
|
491
|
+
if constexpr (std::is_same_v<T, float>) {
|
|
492
|
+
const float2 dm = __half22float2(x[ib].dm);
|
|
447
493
|
|
|
448
|
-
|
|
494
|
+
#pragma unroll
|
|
495
|
+
for (int l = 0; l < ne; ++l) {
|
|
496
|
+
((float *) dst)[l] = dm.x * q8[l] + dm.y;
|
|
497
|
+
}
|
|
498
|
+
} else {
|
|
499
|
+
static_assert(std::is_same_v<T, void>, "bad type");
|
|
500
|
+
}
|
|
449
501
|
}
|
|
450
502
|
|
|
451
|
-
template <typename T>
|
|
452
|
-
static __device__ __forceinline__
|
|
503
|
+
template <typename T, int ne>
|
|
504
|
+
static __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
|
453
505
|
const block_q8_0 * x = (const block_q8_0 *) vx;
|
|
454
506
|
|
|
455
|
-
const int64_t ib =
|
|
456
|
-
const int iqs =
|
|
507
|
+
const int64_t ib = i0 / QK8_0;
|
|
508
|
+
const int iqs = i0 % QK8_0;
|
|
457
509
|
|
|
458
|
-
|
|
459
|
-
|
|
510
|
+
static_assert(ne % 2 == 0, "bad ne");
|
|
511
|
+
int8_t qs[ne];
|
|
512
|
+
ggml_cuda_memcpy_1<ne, 2>(qs, x[ib].qs + iqs);
|
|
460
513
|
|
|
461
514
|
#ifdef FP16_AVAILABLE
|
|
462
|
-
if (std::is_same<T, half>::value) {
|
|
463
|
-
|
|
464
|
-
|
|
515
|
+
if constexpr (std::is_same<T, half>::value) {
|
|
516
|
+
const half2 d = __half2half2(x[ib].d);
|
|
517
|
+
|
|
518
|
+
#pragma unroll
|
|
519
|
+
for (int l0 = 0; l0 < ne; l0 += 2) {
|
|
520
|
+
((half2 *) dst)[l0/2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]);
|
|
521
|
+
}
|
|
522
|
+
} else
|
|
465
523
|
#endif // FP16_AVAILABLE
|
|
524
|
+
if constexpr (std::is_same<T, float>::value) {
|
|
525
|
+
const float d = x[ib].d;
|
|
466
526
|
|
|
467
|
-
|
|
527
|
+
#pragma unroll
|
|
528
|
+
for (int l = 0; l < ne; ++l) {
|
|
529
|
+
((float *) dst)[l] = d * qs[l];
|
|
530
|
+
}
|
|
531
|
+
} else {
|
|
532
|
+
static_assert(std::is_same_v<T, void>, "unsupported type");
|
|
533
|
+
}
|
|
468
534
|
}
|
|
469
535
|
|
|
470
|
-
template <
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
536
|
+
template <ggml_type type_K, int D, int nthreads>
|
|
537
|
+
constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
|
|
538
|
+
if constexpr (type_K == GGML_TYPE_F16) {
|
|
539
|
+
return vec_dot_fattn_vec_KQ_f16<D, nthreads>;
|
|
540
|
+
} else if constexpr (type_K == GGML_TYPE_Q4_0) {
|
|
541
|
+
return vec_dot_fattn_vec_KQ_q4_0<D, nthreads>;
|
|
542
|
+
} else if constexpr (type_K == GGML_TYPE_Q4_1) {
|
|
543
|
+
return vec_dot_fattn_vec_KQ_q4_1<D, nthreads>;
|
|
544
|
+
} else if constexpr (type_K == GGML_TYPE_Q5_0) {
|
|
545
|
+
return vec_dot_fattn_vec_KQ_q5_0<D, nthreads>;
|
|
546
|
+
} else if constexpr (type_K == GGML_TYPE_Q5_1) {
|
|
547
|
+
return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>;
|
|
548
|
+
} else if constexpr (type_K == GGML_TYPE_Q8_0) {
|
|
549
|
+
return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
|
|
550
|
+
} else {
|
|
551
|
+
static_assert(type_K == -1, "bad type");
|
|
552
|
+
return nullptr;
|
|
553
|
+
}
|
|
475
554
|
}
|
|
476
555
|
|
|
477
|
-
template <
|
|
478
|
-
constexpr __device__
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
556
|
+
template <ggml_type type_V, typename T, int ne>
|
|
557
|
+
constexpr __device__ dequantize_V_t get_dequantize_V() {
|
|
558
|
+
if constexpr (type_V == GGML_TYPE_F16) {
|
|
559
|
+
return dequantize_V_f16<T, ne>;
|
|
560
|
+
} else if constexpr (type_V == GGML_TYPE_Q4_0) {
|
|
561
|
+
return dequantize_V_q4_0<T, ne>;
|
|
562
|
+
} else if constexpr (type_V == GGML_TYPE_Q4_1) {
|
|
563
|
+
return dequantize_V_q4_1<T, ne>;
|
|
564
|
+
} else if constexpr (type_V == GGML_TYPE_Q5_0) {
|
|
565
|
+
return dequantize_V_q5_0<T, ne>;
|
|
566
|
+
} else if constexpr (type_V == GGML_TYPE_Q5_1) {
|
|
567
|
+
return dequantize_V_q5_1<T, ne>;
|
|
568
|
+
} else if constexpr (type_V == GGML_TYPE_Q8_0) {
|
|
569
|
+
return dequantize_V_q8_0<T, ne>;
|
|
570
|
+
} else {
|
|
571
|
+
static_assert(type_V == -1, "bad type");
|
|
572
|
+
return nullptr;
|
|
573
|
+
}
|
|
486
574
|
}
|
|
487
575
|
|
|
488
|
-
template <int
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
nullptr;
|
|
497
|
-
}
|
|
576
|
+
template <int ncols1>
|
|
577
|
+
__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
|
|
578
|
+
static __global__ void flash_attn_mask_to_KV_max(
|
|
579
|
+
const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
|
|
580
|
+
const int ne31 = gridDim.x;
|
|
581
|
+
const int tid = threadIdx.x;
|
|
582
|
+
const int sequence = blockIdx.y;
|
|
583
|
+
const int jt = blockIdx.x;
|
|
498
584
|
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
585
|
+
mask += sequence*s33 + jt*ncols1*s31;
|
|
586
|
+
|
|
587
|
+
__shared__ int buf_iw[WARP_SIZE];
|
|
588
|
+
if (tid < WARP_SIZE) {
|
|
589
|
+
buf_iw[tid] = 1;
|
|
590
|
+
}
|
|
591
|
+
__syncthreads();
|
|
592
|
+
|
|
593
|
+
int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
|
|
594
|
+
for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
|
|
595
|
+
int all_inf = 1;
|
|
596
|
+
|
|
597
|
+
#pragma unroll
|
|
598
|
+
for (int j = 0; j < ncols1; ++j) {
|
|
599
|
+
const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
|
|
600
|
+
all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
|
|
601
|
+
}
|
|
602
|
+
|
|
603
|
+
all_inf = warp_reduce_all(all_inf);
|
|
604
|
+
if (tid % WARP_SIZE == 0) {
|
|
605
|
+
buf_iw[tid / WARP_SIZE] = all_inf;
|
|
606
|
+
}
|
|
607
|
+
__syncthreads();
|
|
608
|
+
all_inf = buf_iw[tid % WARP_SIZE];
|
|
609
|
+
__syncthreads();
|
|
610
|
+
all_inf = warp_reduce_all(all_inf);
|
|
508
611
|
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
612
|
+
if (!all_inf) {
|
|
613
|
+
break;
|
|
614
|
+
}
|
|
615
|
+
}
|
|
616
|
+
|
|
617
|
+
// If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
|
|
618
|
+
// If the break was triggered it's the lower edge of the tile with the first non-masked values.
|
|
619
|
+
// In either case, walk back the decrementation by FATTN_KQ_STRIDE.
|
|
620
|
+
KV_max_sj += FATTN_KQ_STRIDE;
|
|
621
|
+
|
|
622
|
+
if (threadIdx.x != 0) {
|
|
623
|
+
return;
|
|
624
|
+
}
|
|
625
|
+
|
|
626
|
+
KV_max[sequence*ne31 + jt] = KV_max_sj;
|
|
517
627
|
}
|
|
518
628
|
|
|
519
629
|
template<int D, int ncols1, int ncols2> // D == head size
|
|
520
630
|
__launch_bounds__(D, 1)
|
|
521
631
|
static __global__ void flash_attn_stream_k_fixup(
|
|
522
|
-
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11
|
|
632
|
+
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
|
|
633
|
+
const int nbatch_fa) {
|
|
523
634
|
constexpr int ncols = ncols1*ncols2;
|
|
524
635
|
|
|
525
636
|
const int bidx0 = blockIdx.x;
|
|
@@ -530,11 +641,11 @@ static __global__ void flash_attn_stream_k_fixup(
|
|
|
530
641
|
|
|
531
642
|
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
|
|
532
643
|
|
|
533
|
-
const int iter_k = ne11 /
|
|
534
|
-
const int iter_j = (ne01 + (ncols1
|
|
644
|
+
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
|
645
|
+
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
|
535
646
|
|
|
536
|
-
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
|
537
|
-
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
|
647
|
+
const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
648
|
+
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
538
649
|
|
|
539
650
|
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
|
540
651
|
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
|
@@ -543,14 +654,15 @@ static __global__ void flash_attn_stream_k_fixup(
|
|
|
543
654
|
return;
|
|
544
655
|
}
|
|
545
656
|
|
|
546
|
-
const int
|
|
547
|
-
const int
|
|
657
|
+
const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
|
|
658
|
+
const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
|
|
659
|
+
const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
|
|
548
660
|
|
|
549
661
|
if (jt*ncols1 + j >= ne01) {
|
|
550
662
|
return;
|
|
551
663
|
}
|
|
552
664
|
|
|
553
|
-
dst += jt*ne02*(ncols1*D) +
|
|
665
|
+
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
|
|
554
666
|
|
|
555
667
|
// Load the partial result that needs a fixup:
|
|
556
668
|
float dst_val = 0.0f;
|
|
@@ -569,7 +681,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
|
|
569
681
|
int bidx = bidx0 - 1;
|
|
570
682
|
int kbc_stop = kbc0;
|
|
571
683
|
while(true) {
|
|
572
|
-
const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
|
684
|
+
const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
573
685
|
if (kbc == kbc_stop) { // Did not have any data.
|
|
574
686
|
bidx--;
|
|
575
687
|
kbc_stop = kbc;
|
|
@@ -607,24 +719,37 @@ static __global__ void flash_attn_stream_k_fixup(
|
|
|
607
719
|
}
|
|
608
720
|
|
|
609
721
|
template<int D> // D == head size
|
|
610
|
-
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
611
722
|
__launch_bounds__(D, 1)
|
|
612
|
-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
613
723
|
static __global__ void flash_attn_combine_results(
|
|
614
724
|
const float * __restrict__ VKQ_parts,
|
|
615
725
|
const float2 * __restrict__ VKQ_meta,
|
|
616
726
|
float * __restrict__ dst,
|
|
617
727
|
const int parallel_blocks) {
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
728
|
+
// Dimension 0: threadIdx.x
|
|
729
|
+
// Dimension 1: blockIdx.x
|
|
730
|
+
// Dimension 2: blockIdx.y
|
|
731
|
+
// Dimension 3: blockIdx.z
|
|
732
|
+
// Memory layout is permuted with [0, 2, 1, 3]
|
|
733
|
+
|
|
734
|
+
const int ne01 = gridDim.x;
|
|
735
|
+
const int ne02 = gridDim.y;
|
|
736
|
+
|
|
737
|
+
const int col = blockIdx.x;
|
|
738
|
+
const int head = blockIdx.y;
|
|
739
|
+
const int sequence = blockIdx.z;
|
|
740
|
+
|
|
741
|
+
const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
|
|
742
|
+
|
|
743
|
+
VKQ_parts += j_dst_unrolled * parallel_blocks*D;
|
|
744
|
+
VKQ_meta += j_dst_unrolled * parallel_blocks;
|
|
745
|
+
dst += j_dst_unrolled * D;
|
|
621
746
|
|
|
622
747
|
const int tid = threadIdx.x;
|
|
623
748
|
__builtin_assume(tid < D);
|
|
624
749
|
|
|
625
750
|
extern __shared__ float2 meta[];
|
|
626
751
|
for (int i = tid; i < 2*parallel_blocks; i += D) {
|
|
627
|
-
((float *) meta)[i] = ((const float *)VKQ_meta) [
|
|
752
|
+
((float *) meta)[i] = ((const float *)VKQ_meta) [i];
|
|
628
753
|
}
|
|
629
754
|
|
|
630
755
|
__syncthreads();
|
|
@@ -637,44 +762,19 @@ static __global__ void flash_attn_combine_results(
|
|
|
637
762
|
float VKQ_numerator = 0.0f;
|
|
638
763
|
float VKQ_denominator = 0.0f;
|
|
639
764
|
for (int l = 0; l < parallel_blocks; ++l) {
|
|
640
|
-
const float
|
|
641
|
-
float KQ_max_scale = expf(diff);
|
|
642
|
-
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
|
643
|
-
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
|
|
765
|
+
const float KQ_max_scale = expf(meta[l].x - kqmax);
|
|
644
766
|
|
|
645
|
-
VKQ_numerator += KQ_max_scale * VKQ_parts[l*
|
|
767
|
+
VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
|
|
646
768
|
VKQ_denominator += KQ_max_scale * meta[l].y;
|
|
647
769
|
}
|
|
648
770
|
|
|
649
|
-
dst[
|
|
650
|
-
}
|
|
651
|
-
|
|
652
|
-
[[noreturn]]
|
|
653
|
-
static void on_no_fattn_vec_case(const int D) {
|
|
654
|
-
if (D == 64) {
|
|
655
|
-
fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
|
|
656
|
-
fprintf(stderr, "By default only f16 KV cache is supported.\n");
|
|
657
|
-
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
|
|
658
|
-
GGML_ABORT("fatal error");
|
|
659
|
-
} else if (D == 128) {
|
|
660
|
-
fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
|
|
661
|
-
fprintf(stderr, "Supported combinations:\n");
|
|
662
|
-
fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n");
|
|
663
|
-
fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n");
|
|
664
|
-
fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n");
|
|
665
|
-
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
|
|
666
|
-
GGML_ABORT("fatal error");
|
|
667
|
-
} else {
|
|
668
|
-
fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
|
|
669
|
-
fprintf(stderr, "Only f16 is supported.\n");
|
|
670
|
-
GGML_ABORT("fatal error");
|
|
671
|
-
}
|
|
771
|
+
dst[tid] = VKQ_numerator / VKQ_denominator;
|
|
672
772
|
}
|
|
673
773
|
|
|
674
774
|
template <int DV, int ncols1, int ncols2>
|
|
675
775
|
void launch_fattn(
|
|
676
776
|
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
|
|
677
|
-
const int
|
|
777
|
+
const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
|
|
678
778
|
) {
|
|
679
779
|
constexpr int ncols = ncols1 * ncols2;
|
|
680
780
|
|
|
@@ -686,7 +786,8 @@ void launch_fattn(
|
|
|
686
786
|
|
|
687
787
|
GGML_ASSERT(V || is_mla);
|
|
688
788
|
|
|
689
|
-
const ggml_tensor * mask
|
|
789
|
+
const ggml_tensor * mask = dst->src[3];
|
|
790
|
+
const ggml_tensor * sinks = dst->src[4];
|
|
690
791
|
|
|
691
792
|
ggml_tensor * KQV = dst;
|
|
692
793
|
|
|
@@ -698,12 +799,6 @@ void launch_fattn(
|
|
|
698
799
|
GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
|
|
699
800
|
|
|
700
801
|
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
|
701
|
-
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
|
702
|
-
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
|
703
|
-
|
|
704
|
-
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
|
705
|
-
|
|
706
|
-
GGML_ASSERT(Q->ne[3] == 1);
|
|
707
802
|
|
|
708
803
|
ggml_cuda_pool & pool = ctx.pool();
|
|
709
804
|
cudaStream_t main_stream = ctx.stream();
|
|
@@ -713,6 +808,7 @@ void launch_fattn(
|
|
|
713
808
|
|
|
714
809
|
ggml_cuda_pool_alloc<half> K_f16(pool);
|
|
715
810
|
ggml_cuda_pool_alloc<half> V_f16(pool);
|
|
811
|
+
ggml_cuda_pool_alloc<int> KV_max(pool);
|
|
716
812
|
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
|
717
813
|
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
|
718
814
|
|
|
@@ -727,43 +823,87 @@ void launch_fattn(
|
|
|
727
823
|
size_t nb23 = V ? V->nb[3] : nb13;
|
|
728
824
|
|
|
729
825
|
if (need_f16_K && K->type != GGML_TYPE_F16) {
|
|
730
|
-
GGML_ASSERT(ggml_is_contiguously_allocated(K));
|
|
731
|
-
K_f16.alloc(ggml_nelements(K));
|
|
732
|
-
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
|
|
733
|
-
to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
|
|
734
|
-
K_data = (char *) K_f16.ptr;
|
|
735
|
-
|
|
736
826
|
const size_t bs = ggml_blck_size(K->type);
|
|
737
827
|
const size_t ts = ggml_type_size(K->type);
|
|
738
828
|
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
829
|
+
K_f16.alloc(ggml_nelements(K));
|
|
830
|
+
if (ggml_is_contiguously_allocated(K)) {
|
|
831
|
+
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
|
|
832
|
+
to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
|
|
833
|
+
|
|
834
|
+
nb11 = nb11*bs*sizeof(half)/ts;
|
|
835
|
+
nb12 = nb12*bs*sizeof(half)/ts;
|
|
836
|
+
nb13 = nb13*bs*sizeof(half)/ts;
|
|
837
|
+
} else {
|
|
838
|
+
GGML_ASSERT(K->nb[0] == ts);
|
|
839
|
+
to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);
|
|
840
|
+
const int64_t s01 = nb11 / ts;
|
|
841
|
+
const int64_t s02 = nb12 / ts;
|
|
842
|
+
const int64_t s03 = nb13 / ts;
|
|
843
|
+
to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
|
|
844
|
+
|
|
845
|
+
nb11 = K->ne[0] * sizeof(half);
|
|
846
|
+
nb12 = K->ne[1] * nb11;
|
|
847
|
+
nb13 = K->ne[2] * nb12;
|
|
848
|
+
}
|
|
849
|
+
K_data = (char *) K_f16.ptr;
|
|
742
850
|
}
|
|
743
851
|
|
|
744
852
|
if (V && need_f16_V && V->type != GGML_TYPE_F16) {
|
|
745
|
-
GGML_ASSERT(ggml_is_contiguously_allocated(V));
|
|
746
|
-
V_f16.alloc(ggml_nelements(V));
|
|
747
|
-
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
|
748
|
-
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
|
|
749
|
-
V_data = (char *) V_f16.ptr;
|
|
750
|
-
|
|
751
853
|
const size_t bs = ggml_blck_size(V->type);
|
|
752
854
|
const size_t ts = ggml_type_size(V->type);
|
|
753
855
|
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
856
|
+
V_f16.alloc(ggml_nelements(V));
|
|
857
|
+
if (ggml_is_contiguously_allocated(V)) {
|
|
858
|
+
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
|
859
|
+
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
|
|
860
|
+
V_data = (char *) V_f16.ptr;
|
|
861
|
+
|
|
862
|
+
nb21 = nb21*bs*sizeof(half)/ts;
|
|
863
|
+
nb22 = nb22*bs*sizeof(half)/ts;
|
|
864
|
+
nb23 = nb23*bs*sizeof(half)/ts;
|
|
865
|
+
} else {
|
|
866
|
+
GGML_ASSERT(V->nb[0] == ts);
|
|
867
|
+
to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
|
|
868
|
+
const int64_t s01 = nb21 / ts;
|
|
869
|
+
const int64_t s02 = nb22 / ts;
|
|
870
|
+
const int64_t s03 = nb23 / ts;
|
|
871
|
+
to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
|
|
872
|
+
|
|
873
|
+
nb21 = V->ne[0] * sizeof(half);
|
|
874
|
+
nb22 = V->ne[1] * nb21;
|
|
875
|
+
nb23 = V->ne[2] * nb22;
|
|
876
|
+
}
|
|
877
|
+
V_data = (char *) V_f16.ptr;
|
|
757
878
|
}
|
|
758
879
|
|
|
759
|
-
int parallel_blocks = 1;
|
|
760
|
-
|
|
761
880
|
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
|
762
881
|
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
|
763
882
|
|
|
883
|
+
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
|
|
884
|
+
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
|
|
885
|
+
// multiple sequences of possibly different lengths.
|
|
886
|
+
if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
|
|
887
|
+
const int s31 = mask->nb[1] / sizeof(half2);
|
|
888
|
+
const int s33 = mask->nb[3] / sizeof(half2);
|
|
889
|
+
|
|
890
|
+
const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
|
|
891
|
+
const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
|
|
892
|
+
|
|
893
|
+
const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
|
|
894
|
+
const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
|
|
895
|
+
|
|
896
|
+
KV_max.alloc(ne_KV_max);
|
|
897
|
+
flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
|
|
898
|
+
((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
|
|
899
|
+
CUDA_CHECK(cudaGetLastError());
|
|
900
|
+
}
|
|
901
|
+
|
|
764
902
|
const dim3 block_dim(warp_size, nwarps, 1);
|
|
765
903
|
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
|
|
766
904
|
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
|
|
905
|
+
GGML_ASSERT(max_blocks_per_sm > 0);
|
|
906
|
+
int parallel_blocks = max_blocks_per_sm;
|
|
767
907
|
|
|
768
908
|
dim3 blocks_num;
|
|
769
909
|
if (stream_k) {
|
|
@@ -780,13 +920,11 @@ void launch_fattn(
|
|
|
780
920
|
blocks_num.y = 1;
|
|
781
921
|
blocks_num.z = 1;
|
|
782
922
|
|
|
783
|
-
|
|
923
|
+
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
|
924
|
+
dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
|
|
925
|
+
}
|
|
784
926
|
} else {
|
|
785
|
-
|
|
786
|
-
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
|
787
|
-
|
|
788
|
-
// parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
|
|
789
|
-
parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
|
|
927
|
+
const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
|
|
790
928
|
|
|
791
929
|
// parallel_blocks must not be larger than what the tensor size allows:
|
|
792
930
|
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
|
|
@@ -802,7 +940,7 @@ void launch_fattn(
|
|
|
802
940
|
const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
|
|
803
941
|
|
|
804
942
|
// Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
|
|
805
|
-
if (efficiency_percent_best >=
|
|
943
|
+
if (efficiency_percent_best >= 95 && nwaves > nwaves_best) {
|
|
806
944
|
break;
|
|
807
945
|
}
|
|
808
946
|
|
|
@@ -815,7 +953,7 @@ void launch_fattn(
|
|
|
815
953
|
|
|
816
954
|
blocks_num.x = ntiles_x;
|
|
817
955
|
blocks_num.y = parallel_blocks;
|
|
818
|
-
blocks_num.z = Q->ne[2]*Q->ne[3];
|
|
956
|
+
blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
|
|
819
957
|
|
|
820
958
|
if (parallel_blocks > 1) {
|
|
821
959
|
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
|
@@ -841,21 +979,24 @@ void launch_fattn(
|
|
|
841
979
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
842
980
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
843
981
|
|
|
982
|
+
// TODO other tensor dimensions after removal of WMMA kernel:
|
|
983
|
+
const uint3 ne01 = init_fastdiv_values(Q->ne[1]);
|
|
984
|
+
|
|
844
985
|
GGML_ASSERT(block_dim.x % warp_size == 0);
|
|
845
986
|
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
|
|
846
987
|
(const char *) Q->data,
|
|
847
988
|
K_data,
|
|
848
989
|
V_data,
|
|
849
990
|
mask ? ((const char *) mask->data) : nullptr,
|
|
991
|
+
sinks ? ((const char *) sinks->data) : nullptr,
|
|
992
|
+
KV_max.ptr,
|
|
850
993
|
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
|
851
994
|
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
|
852
|
-
Q->ne[0], Q->ne[
|
|
853
|
-
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
|
854
|
-
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
|
855
|
-
Q->nb[1], Q->nb[2], Q->nb[3],
|
|
856
|
-
nb11, nb12, nb13,
|
|
995
|
+
Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
|
|
996
|
+
K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
|
|
857
997
|
nb21, nb22, nb23,
|
|
858
|
-
|
|
998
|
+
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
|
|
999
|
+
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
|
|
859
1000
|
);
|
|
860
1001
|
CUDA_CHECK(cudaGetLastError());
|
|
861
1002
|
|
|
@@ -866,11 +1007,11 @@ void launch_fattn(
|
|
|
866
1007
|
|
|
867
1008
|
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
|
|
868
1009
|
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
|
869
|
-
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
|
1010
|
+
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
|
|
870
1011
|
}
|
|
871
1012
|
} else if (parallel_blocks > 1) {
|
|
872
1013
|
const dim3 block_dim_combine(DV, 1, 1);
|
|
873
|
-
const dim3 blocks_num_combine(Q->ne[1],
|
|
1014
|
+
const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
|
|
874
1015
|
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
|
|
875
1016
|
|
|
876
1017
|
flash_attn_combine_results<DV>
|