whispercpp 1.3.4 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +158 -44
- data/ext/extconf.rb +3 -2
- data/ext/ruby_whisper.c +34 -6
- data/ext/ruby_whisper.h +67 -0
- data/ext/ruby_whisper_context.c +236 -144
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +12 -13
- data/ext/ruby_whisper_params.c +47 -24
- data/ext/ruby_whisper_segment.c +84 -20
- data/ext/ruby_whisper_token.c +371 -0
- data/ext/ruby_whisper_transcribe.cpp +5 -2
- data/ext/ruby_whisper_vad_context.c +122 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +138 -0
- data/ext/ruby_whisper_vad_segments.c +105 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +129 -112
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +28 -15
- data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
- data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
- data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
- data/ext/sources/examples/talk-llama/llama-context.h +70 -23
- data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
- data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
- data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
- data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
- data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
- data/ext/sources/examples/talk-llama/llama-model.h +112 -18
- data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
- data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
- data/ext/sources/examples/talk-llama/llama.cpp +802 -21
- data/ext/sources/examples/talk-llama/llama.h +210 -39
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +704 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +90 -56
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +5 -2
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +6 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +14 -12
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +246 -21
- data/ext/sources/ggml/src/CMakeLists.txt +85 -11
- data/ext/sources/ggml/src/ggml-alloc.c +128 -50
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
- data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
- data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
- data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
- data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
- data/ext/sources/ggml/src/ggml-impl.h +129 -6
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
- data/ext/sources/ggml/src/ggml.c +590 -64
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +106 -62
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +162 -4
- data/test/test_context_params.rb +82 -0
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +81 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +100 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +27 -0
- data/whispercpp.gemspec +1 -1
- metadata +502 -37
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -1,16 +1,31 @@
|
|
|
1
|
-
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
|
|
1
|
+
// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
|
|
2
2
|
// SPDX-License-Identifier: MIT
|
|
3
3
|
//
|
|
4
4
|
#include <arm_neon.h>
|
|
5
5
|
#include <assert.h>
|
|
6
|
+
#include <stdio.h>
|
|
6
7
|
#include <atomic>
|
|
7
8
|
#include <cfloat>
|
|
9
|
+
#include <algorithm>
|
|
10
|
+
#include <cmath>
|
|
8
11
|
#include <stdexcept>
|
|
9
12
|
#include <stdint.h>
|
|
10
13
|
#include <string.h>
|
|
14
|
+
#include <string>
|
|
15
|
+
#include <vector>
|
|
16
|
+
#include <array>
|
|
17
|
+
#include <cstddef>
|
|
18
|
+
#include <cstdint>
|
|
19
|
+
#include <fstream>
|
|
20
|
+
#include <set>
|
|
21
|
+
#include <iostream>
|
|
22
|
+
#include <climits>
|
|
11
23
|
#if defined(__linux__)
|
|
12
24
|
#include <asm/hwcap.h>
|
|
13
25
|
#include <sys/auxv.h>
|
|
26
|
+
#include <sys/types.h>
|
|
27
|
+
#include <sys/stat.h>
|
|
28
|
+
#include <unistd.h>
|
|
14
29
|
#elif defined(__APPLE__)
|
|
15
30
|
#include <string_view>
|
|
16
31
|
#include <sys/sysctl.h>
|
|
@@ -35,90 +50,369 @@
|
|
|
35
50
|
#define GGML_COMMON_DECL_CPP
|
|
36
51
|
#include "ggml-common.h"
|
|
37
52
|
|
|
53
|
+
static constexpr int GGML_KLEIDIAI_MAX_KERNEL_SLOTS = 2;
|
|
54
|
+
static constexpr uint32_t GGML_KLEIDIAI_PACK_MAGIC = 0x4b4c4149; // "KLAI"
|
|
55
|
+
static constexpr uint16_t GGML_KLEIDIAI_PACK_VERSION = 1;
|
|
56
|
+
static constexpr size_t GGML_KLEIDIAI_PACK_ALIGN = 64;
|
|
57
|
+
|
|
38
58
|
struct ggml_kleidiai_context {
|
|
39
59
|
cpu_feature features;
|
|
40
|
-
ggml_kleidiai_kernels *
|
|
41
|
-
|
|
60
|
+
ggml_kleidiai_kernels * kernels_q4;
|
|
61
|
+
ggml_kleidiai_kernels * kernels_q8;
|
|
62
|
+
int sme_thread_cap; // <= 0 means “SME disabled/unknown”;
|
|
63
|
+
int thread_hint; // <= 0 means “no hint”
|
|
64
|
+
} static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1 };
|
|
42
65
|
|
|
43
66
|
static const char* cpu_feature_to_string(cpu_feature f) {
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
67
|
+
if (f == CPU_FEATURE_NONE) {
|
|
68
|
+
return "NONE";
|
|
69
|
+
} else if ((f & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
|
|
70
|
+
return "SME";
|
|
71
|
+
} else if ((f & CPU_FEATURE_SVE) == CPU_FEATURE_SVE) {
|
|
72
|
+
return "SVE";
|
|
73
|
+
}
|
|
74
|
+
else if ((f & CPU_FEATURE_I8MM) == CPU_FEATURE_I8MM) {
|
|
75
|
+
return "I8MM";
|
|
76
|
+
} else if ((f & CPU_FEATURE_DOTPROD) == CPU_FEATURE_DOTPROD) {
|
|
77
|
+
return "DOTPROD";
|
|
78
|
+
}
|
|
79
|
+
else {
|
|
80
|
+
return "UNKNOWN";
|
|
51
81
|
}
|
|
52
82
|
}
|
|
53
83
|
|
|
54
|
-
static
|
|
84
|
+
static size_t detect_num_smcus() {
|
|
85
|
+
if (!ggml_cpu_has_sme()) {
|
|
86
|
+
return 0;
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
#if defined(__linux__) && defined(__aarch64__)
|
|
90
|
+
// Linux/aarch64: Best-effort count of Streaming Mode Compute Units (SMCUs) via SMIDR_EL1 sysfs.
|
|
91
|
+
size_t num_private = 0;
|
|
92
|
+
std::set<uint32_t> shared_ids;
|
|
93
|
+
|
|
94
|
+
for (size_t cpu = 0;; ++cpu) {
|
|
95
|
+
const std::string path =
|
|
96
|
+
"/sys/devices/system/cpu/cpu" + std::to_string(cpu) +
|
|
97
|
+
"/regs/identification/smidr_el1";
|
|
98
|
+
|
|
99
|
+
std::ifstream file(path);
|
|
100
|
+
if (!file.is_open()) {
|
|
101
|
+
break;
|
|
102
|
+
}
|
|
55
103
|
|
|
104
|
+
uint64_t smidr = 0;
|
|
105
|
+
if (!(file >> std::hex >> smidr)) {
|
|
106
|
+
continue;
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
// Arm ARM: SMIDR_EL1
|
|
110
|
+
const uint32_t sh = (uint32_t)((smidr >> 13) & 0x3);
|
|
111
|
+
// Build an "affinity-like" identifier for shared SMCUs.
|
|
112
|
+
// Keep the original packing logic, but isolate it here.
|
|
113
|
+
const uint32_t id = (uint32_t)((smidr & 0xFFFu) | ((smidr >> 20) & 0xFFFFF000u));
|
|
114
|
+
|
|
115
|
+
switch (sh) {
|
|
116
|
+
case 0b10: // private SMCU
|
|
117
|
+
++num_private;
|
|
118
|
+
break;
|
|
119
|
+
case 0b11: // shared SMCU
|
|
120
|
+
shared_ids.emplace(id);
|
|
121
|
+
break;
|
|
122
|
+
case 0b00:
|
|
123
|
+
// Ambiguous / implementation-defined. Be conservative:
|
|
124
|
+
// treat id==0 as private, otherwise as shared.
|
|
125
|
+
if (id == 0) ++num_private;
|
|
126
|
+
else shared_ids.emplace(id);
|
|
127
|
+
break;
|
|
128
|
+
default:
|
|
129
|
+
break;
|
|
130
|
+
}
|
|
131
|
+
}
|
|
132
|
+
|
|
133
|
+
return num_private + shared_ids.size();
|
|
134
|
+
|
|
135
|
+
#elif defined(__APPLE__) && defined(__aarch64__)
|
|
136
|
+
// table for known M4 variants. Users can override via GGML_KLEIDIAI_SME=<n>.
|
|
137
|
+
char chip_name[256] = {};
|
|
138
|
+
size_t size = sizeof(chip_name);
|
|
139
|
+
|
|
140
|
+
if (sysctlbyname("machdep.cpu.brand_string", chip_name, &size, nullptr, 0) == 0) {
|
|
141
|
+
const std::string brand(chip_name);
|
|
142
|
+
|
|
143
|
+
struct ModelSMCU { const char *match; size_t smcus; };
|
|
144
|
+
static const ModelSMCU table[] = {
|
|
145
|
+
{ "M4 Ultra", 2 },
|
|
146
|
+
{ "M4 Max", 2 },
|
|
147
|
+
{ "M4 Pro", 2 },
|
|
148
|
+
{ "M4", 1 },
|
|
149
|
+
};
|
|
150
|
+
|
|
151
|
+
for (const auto &e : table) {
|
|
152
|
+
if (brand.find(e.match) != std::string::npos) {
|
|
153
|
+
return e.smcus;
|
|
154
|
+
}
|
|
155
|
+
}
|
|
156
|
+
}
|
|
157
|
+
return 1;
|
|
158
|
+
|
|
159
|
+
#else
|
|
160
|
+
return 1;
|
|
161
|
+
#endif
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
static int parse_uint_env(const char *s, const char *name, bool *ok) {
|
|
165
|
+
if (!s) { *ok = false; return 0; }
|
|
166
|
+
char *end = nullptr;
|
|
167
|
+
long v = strtol(s, &end, 10);
|
|
168
|
+
if (end == s || *end != '\0') {
|
|
169
|
+
GGML_LOG_WARN("kleidiai: invalid %s='%s' (expected integer)\n", name, s);
|
|
170
|
+
*ok = false;
|
|
171
|
+
return 0;
|
|
172
|
+
}
|
|
173
|
+
if (v < 0 || v > INT_MAX) {
|
|
174
|
+
GGML_LOG_WARN("kleidiai: out-of-range %s='%s'\n", name, s);
|
|
175
|
+
*ok = false;
|
|
176
|
+
return 0;
|
|
177
|
+
}
|
|
178
|
+
*ok = true;
|
|
179
|
+
return (int)v;
|
|
180
|
+
}
|
|
181
|
+
|
|
182
|
+
static void init_kleidiai_context(void) {
|
|
56
183
|
ggml_critical_section_start();
|
|
57
184
|
static bool initialized = false;
|
|
58
185
|
|
|
59
186
|
if (!initialized) {
|
|
60
187
|
initialized = true;
|
|
61
|
-
|
|
62
|
-
|
|
188
|
+
|
|
189
|
+
const char *env_sme = getenv("GGML_KLEIDIAI_SME");
|
|
190
|
+
const char *env_threads = getenv("GGML_TOTAL_THREADS");
|
|
191
|
+
|
|
192
|
+
const bool cpu_has_sme = ggml_cpu_has_sme();
|
|
193
|
+
size_t detected_smcus = 0;
|
|
63
194
|
|
|
64
195
|
ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
|
|
65
196
|
(ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
|
|
66
|
-
(ggml_cpu_has_sve()
|
|
197
|
+
((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
|
|
67
198
|
|
|
68
|
-
if (
|
|
69
|
-
|
|
199
|
+
if (env_threads) {
|
|
200
|
+
bool ok = false;
|
|
201
|
+
int hint = parse_uint_env(env_threads, "GGML_TOTAL_THREADS", &ok);
|
|
202
|
+
if (ok && hint > 0) {
|
|
203
|
+
ctx.thread_hint = hint;
|
|
204
|
+
}
|
|
70
205
|
}
|
|
71
206
|
|
|
72
|
-
|
|
73
|
-
|
|
207
|
+
// SME policy:
|
|
208
|
+
// - If CPU doesn't support SME: SME always off.
|
|
209
|
+
// - Else:
|
|
210
|
+
// - env unset => auto-detect cores; enable if detected > 0.
|
|
211
|
+
// - env=0 => force off.
|
|
212
|
+
// - env>0 => force N cores (skip detection).
|
|
213
|
+
int sme_cores = 0;
|
|
214
|
+
bool sme_env_ok = false;
|
|
215
|
+
bool sme_env_set = (env_sme != nullptr);
|
|
216
|
+
|
|
217
|
+
if (!cpu_has_sme) {
|
|
218
|
+
if (sme_env_set) {
|
|
219
|
+
bool ok = false;
|
|
220
|
+
int req = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
|
|
221
|
+
if (ok && req > 0) {
|
|
222
|
+
GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME=%d but SME is not supported on this CPU; disabling SME\n", req);
|
|
223
|
+
}
|
|
224
|
+
}
|
|
225
|
+
sme_cores = 0;
|
|
226
|
+
} else {
|
|
227
|
+
if (sme_env_set) {
|
|
228
|
+
bool ok = false;
|
|
229
|
+
int v = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
|
|
230
|
+
sme_env_ok = ok;
|
|
231
|
+
|
|
232
|
+
if (!ok) {
|
|
233
|
+
GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME set but parsing failed; falling back to runtime SME-core detection\n");
|
|
234
|
+
detected_smcus = detect_num_smcus();
|
|
235
|
+
sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
|
|
236
|
+
} else if (v == 0) {
|
|
237
|
+
sme_cores = 0;
|
|
238
|
+
} else {
|
|
239
|
+
sme_cores = v;
|
|
240
|
+
}
|
|
241
|
+
} else {
|
|
242
|
+
detected_smcus = detect_num_smcus();
|
|
243
|
+
sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
|
|
244
|
+
}
|
|
245
|
+
|
|
246
|
+
if (!sme_env_set && sme_cores == 0) {
|
|
247
|
+
GGML_LOG_WARN("kleidiai: SME supported but runtime SME-core detection returned 0; falling back to NEON\n");
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
if (sme_cores > 0) {
|
|
251
|
+
ctx.features |= CPU_FEATURE_SME;
|
|
252
|
+
}
|
|
74
253
|
}
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
254
|
+
|
|
255
|
+
// Kernel selection
|
|
256
|
+
ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
|
|
257
|
+
ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
|
|
258
|
+
|
|
259
|
+
if (!ctx.kernels_q4) {
|
|
260
|
+
GGML_LOG_INFO("kleidiai: no compatible q4 kernels found for CPU features mask %d\n", (int)ctx.features);
|
|
261
|
+
} else {
|
|
262
|
+
GGML_LOG_INFO("kleidiai: primary q4 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
|
|
263
|
+
}
|
|
264
|
+
|
|
265
|
+
if (!ctx.kernels_q8) {
|
|
266
|
+
GGML_LOG_INFO("kleidiai: no compatible q8 kernels found for CPU features mask %d\n", (int)ctx.features);
|
|
267
|
+
} else {
|
|
268
|
+
GGML_LOG_INFO("kleidiai: primary q8 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
|
|
269
|
+
}
|
|
270
|
+
|
|
271
|
+
ctx.sme_thread_cap = (ctx.features & CPU_FEATURE_SME) ? sme_cores : 0;
|
|
272
|
+
|
|
273
|
+
if (ctx.features & CPU_FEATURE_SME) {
|
|
274
|
+
if (sme_env_set && sme_env_ok && sme_cores > 0) {
|
|
275
|
+
GGML_LOG_INFO("kleidiai: SME enabled (GGML_KLEIDIAI_SME=%d override)\n", sme_cores);
|
|
276
|
+
} else {
|
|
277
|
+
GGML_LOG_INFO("kleidiai: SME enabled (runtime-detected SME cores=%d)\n", sme_cores);
|
|
278
|
+
}
|
|
279
|
+
} else {
|
|
280
|
+
GGML_LOG_INFO("kleidiai: SME disabled\n");
|
|
79
281
|
}
|
|
80
|
-
#endif
|
|
81
282
|
}
|
|
283
|
+
|
|
82
284
|
ggml_critical_section_end();
|
|
83
285
|
}
|
|
84
286
|
|
|
85
|
-
static inline
|
|
86
|
-
|
|
87
|
-
return tensor->ne[dim];
|
|
287
|
+
static inline int kleidiai_sme_thread_cap() {
|
|
288
|
+
return ctx.sme_thread_cap;
|
|
88
289
|
}
|
|
89
290
|
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
Args...> || ...);
|
|
291
|
+
static inline size_t align_up(size_t value, size_t alignment) {
|
|
292
|
+
if (alignment == 0) {
|
|
293
|
+
return value;
|
|
294
|
+
}
|
|
295
|
+
const size_t remainder = value % alignment;
|
|
296
|
+
return remainder == 0 ? value : value + (alignment - remainder);
|
|
97
297
|
}
|
|
98
298
|
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
299
|
+
static inline bool kleidiai_pack_fallback_allowed() {
|
|
300
|
+
if (ctx.sme_thread_cap <= 0) {
|
|
301
|
+
return false;
|
|
302
|
+
}
|
|
303
|
+
if (ctx.thread_hint <= 0) {
|
|
304
|
+
return true;
|
|
305
|
+
}
|
|
306
|
+
return ctx.thread_hint > ctx.sme_thread_cap;
|
|
307
|
+
}
|
|
308
|
+
|
|
309
|
+
struct kleidiai_weight_header {
|
|
310
|
+
uint32_t magic;
|
|
311
|
+
uint16_t version;
|
|
312
|
+
uint16_t slot_count;
|
|
313
|
+
uint64_t offsets[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
|
|
314
|
+
uint64_t sizes[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
|
|
315
|
+
};
|
|
316
|
+
|
|
317
|
+
static inline kleidiai_weight_header * kleidiai_weight_header_from_ptr(void * data) {
|
|
318
|
+
return reinterpret_cast<kleidiai_weight_header *>(data);
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
static inline const kleidiai_weight_header * kleidiai_weight_header_from_ptr(const void * data) {
|
|
322
|
+
return reinterpret_cast<const kleidiai_weight_header *>(data);
|
|
323
|
+
}
|
|
324
|
+
|
|
325
|
+
static inline bool kleidiai_is_weight_header_valid(const kleidiai_weight_header * header) {
|
|
326
|
+
if (!header) {
|
|
327
|
+
return false;
|
|
328
|
+
}
|
|
329
|
+
if (header->magic != GGML_KLEIDIAI_PACK_MAGIC || header->version != GGML_KLEIDIAI_PACK_VERSION) {
|
|
330
|
+
return false;
|
|
331
|
+
}
|
|
332
|
+
if (header->slot_count == 0 || header->slot_count > GGML_KLEIDIAI_MAX_KERNEL_SLOTS) {
|
|
333
|
+
return false;
|
|
334
|
+
}
|
|
335
|
+
return true;
|
|
336
|
+
}
|
|
337
|
+
|
|
338
|
+
static inline uint8_t * kleidiai_weight_slot_ptr(kleidiai_weight_header * header, int slot) {
|
|
339
|
+
if (!kleidiai_is_weight_header_valid(header)) {
|
|
340
|
+
return nullptr;
|
|
341
|
+
}
|
|
342
|
+
if (slot < 0 || slot >= header->slot_count) {
|
|
343
|
+
return nullptr;
|
|
344
|
+
}
|
|
345
|
+
return reinterpret_cast<uint8_t *>(header) + header->offsets[slot];
|
|
346
|
+
}
|
|
347
|
+
|
|
348
|
+
static inline const uint8_t * kleidiai_weight_slot_ptr(const kleidiai_weight_header * header, int slot) {
|
|
349
|
+
if (!kleidiai_is_weight_header_valid(header)) {
|
|
350
|
+
return nullptr;
|
|
351
|
+
}
|
|
352
|
+
if (slot < 0 || slot >= header->slot_count) {
|
|
353
|
+
return nullptr;
|
|
354
|
+
}
|
|
355
|
+
return reinterpret_cast<const uint8_t *>(header) + header->offsets[slot];
|
|
356
|
+
}
|
|
357
|
+
|
|
358
|
+
static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q4() {
|
|
359
|
+
return ctx.kernels_q4;
|
|
360
|
+
}
|
|
361
|
+
|
|
362
|
+
static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q8() {
|
|
363
|
+
return ctx.kernels_q8;
|
|
364
|
+
}
|
|
365
|
+
|
|
366
|
+
template <typename SelectFallback>
|
|
367
|
+
static int kleidiai_collect_kernel_chain_common(
|
|
368
|
+
ggml_kleidiai_kernels * primary,
|
|
369
|
+
cpu_feature features,
|
|
370
|
+
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out,
|
|
371
|
+
SelectFallback select_fallback) {
|
|
372
|
+
int count = 0;
|
|
373
|
+
if (!primary) {
|
|
374
|
+
return 0;
|
|
375
|
+
}
|
|
376
|
+
out[count++] = primary;
|
|
377
|
+
|
|
378
|
+
if ((primary->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
|
|
379
|
+
const cpu_feature fallback_mask = static_cast<cpu_feature>(features & ~CPU_FEATURE_SME);
|
|
380
|
+
if (fallback_mask != CPU_FEATURE_NONE) {
|
|
381
|
+
ggml_kleidiai_kernels * fallback = select_fallback(fallback_mask);
|
|
382
|
+
if (fallback && fallback != primary &&
|
|
383
|
+
fallback->lhs_type == primary->lhs_type &&
|
|
384
|
+
fallback->rhs_type == primary->rhs_type &&
|
|
385
|
+
fallback->op_type == primary->op_type) {
|
|
386
|
+
out[count++] = fallback;
|
|
118
387
|
}
|
|
119
|
-
}
|
|
120
|
-
|
|
121
|
-
|
|
388
|
+
}
|
|
389
|
+
}
|
|
390
|
+
|
|
391
|
+
return count;
|
|
392
|
+
}
|
|
393
|
+
|
|
394
|
+
static int kleidiai_collect_kernel_chain(const struct ggml_tensor * op,
|
|
395
|
+
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
|
|
396
|
+
ggml_kleidiai_kernels * primary = ggml_kleidiai_select_kernels(ctx.features, op);
|
|
397
|
+
return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
|
|
398
|
+
[&](cpu_feature mask) { return ggml_kleidiai_select_kernels(mask, op); });
|
|
399
|
+
}
|
|
400
|
+
|
|
401
|
+
static int kleidiai_collect_q4_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
|
|
402
|
+
ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q4();
|
|
403
|
+
return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
|
|
404
|
+
[&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q4_0(mask); });
|
|
405
|
+
}
|
|
406
|
+
|
|
407
|
+
static int kleidiai_collect_q8_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
|
|
408
|
+
ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q8();
|
|
409
|
+
return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
|
|
410
|
+
[&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q8_0(mask); });
|
|
411
|
+
}
|
|
412
|
+
|
|
413
|
+
static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
|
|
414
|
+
GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
|
|
415
|
+
return tensor->ne[dim];
|
|
122
416
|
}
|
|
123
417
|
|
|
124
418
|
namespace ggml::cpu::kleidiai {
|
|
@@ -144,45 +438,113 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
|
144
438
|
if (op->op != GGML_OP_MUL_MAT) {
|
|
145
439
|
return false;
|
|
146
440
|
}
|
|
147
|
-
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
|
|
148
|
-
GGML_ASSERT(kernels);
|
|
149
|
-
bool is_gemv = op->src[1]->ne[1] == 1;
|
|
150
|
-
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
|
151
|
-
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
|
152
441
|
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
442
|
+
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
|
|
443
|
+
const int slot_count = kleidiai_collect_kernel_chain(op, kernel_chain);
|
|
444
|
+
if (slot_count == 0) {
|
|
445
|
+
return false;
|
|
446
|
+
}
|
|
156
447
|
|
|
157
|
-
|
|
158
|
-
size_t
|
|
159
|
-
size_t
|
|
448
|
+
const bool is_gemv = op->src[1]->ne[1] == 1;
|
|
449
|
+
const size_t k = op->src[0]->ne[0];
|
|
450
|
+
const size_t n = op->src[0]->ne[1];
|
|
451
|
+
const size_t m = op->src[1]->ne[1];
|
|
160
452
|
|
|
161
|
-
if (
|
|
162
|
-
|
|
163
|
-
|
|
453
|
+
if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) {
|
|
454
|
+
const size_t qk = (op->src[0]->type == GGML_TYPE_Q4_0) ? QK4_0 : QK8_0;
|
|
455
|
+
|
|
456
|
+
size_t cursor = 0;
|
|
457
|
+
bool any_slot = false;
|
|
458
|
+
|
|
459
|
+
for (int slot = 0; slot < slot_count; ++slot) {
|
|
460
|
+
ggml_kleidiai_kernels * kernels = kernel_chain[slot];
|
|
461
|
+
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
|
462
|
+
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
|
463
|
+
|
|
464
|
+
if (!lhs_info || !lhs_info->packed_size_ex || !kernel) {
|
|
465
|
+
return false;
|
|
466
|
+
}
|
|
467
|
+
|
|
468
|
+
const size_t mr = kernel->get_mr();
|
|
469
|
+
const size_t kr = kernel->get_kr();
|
|
470
|
+
const size_t sr = kernel->get_sr();
|
|
471
|
+
|
|
472
|
+
const size_t packed = lhs_info->packed_size_ex(m, k, qk, mr, kr, sr);
|
|
473
|
+
|
|
474
|
+
cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
475
|
+
cursor += packed;
|
|
476
|
+
any_slot = true;
|
|
477
|
+
}
|
|
478
|
+
|
|
479
|
+
if (!any_slot) {
|
|
480
|
+
return false;
|
|
481
|
+
}
|
|
482
|
+
|
|
483
|
+
size = cursor;
|
|
484
|
+
return true;
|
|
485
|
+
}
|
|
486
|
+
|
|
487
|
+
if (op->src[0]->type == GGML_TYPE_F16) {
|
|
164
488
|
const int64_t lhs_batch_size0 = op->src[1]->ne[2];
|
|
165
489
|
const int64_t rhs_batch_size0 = op->src[0]->ne[2];
|
|
490
|
+
GGML_ASSERT(rhs_batch_size0 > 0);
|
|
166
491
|
const int64_t r = lhs_batch_size0 / rhs_batch_size0;
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
492
|
+
|
|
493
|
+
size_t cursor = 0;
|
|
494
|
+
bool any_slot = false;
|
|
495
|
+
|
|
496
|
+
for (int slot = 0; slot < slot_count; ++slot) {
|
|
497
|
+
ggml_kleidiai_kernels * kernels = kernel_chain[slot];
|
|
498
|
+
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
|
499
|
+
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
|
500
|
+
if (!lhs_info || !lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex || !kernel) {
|
|
501
|
+
return false;
|
|
502
|
+
}
|
|
503
|
+
|
|
504
|
+
const size_t mr = kernel->get_mr();
|
|
505
|
+
const size_t kr = kernel->get_kr();
|
|
506
|
+
const size_t sr = kernel->get_sr();
|
|
507
|
+
|
|
508
|
+
cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
509
|
+
cursor += lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr);
|
|
510
|
+
any_slot = true;
|
|
511
|
+
}
|
|
512
|
+
|
|
513
|
+
for (int slot = 0; slot < slot_count; ++slot) {
|
|
514
|
+
ggml_kleidiai_kernels * kernels = kernel_chain[slot];
|
|
515
|
+
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
|
516
|
+
if (!kernel || !kernels->rhs_info.packed_size_ex) {
|
|
517
|
+
return false;
|
|
518
|
+
}
|
|
519
|
+
cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
520
|
+
cursor += kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0);
|
|
521
|
+
}
|
|
522
|
+
|
|
523
|
+
cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
524
|
+
cursor += k * n * sizeof(float);
|
|
525
|
+
cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
526
|
+
cursor += n * sizeof(float);
|
|
527
|
+
|
|
528
|
+
if (!any_slot) {
|
|
529
|
+
return false;
|
|
530
|
+
}
|
|
531
|
+
|
|
532
|
+
size = cursor;
|
|
533
|
+
return true;
|
|
172
534
|
}
|
|
173
535
|
|
|
174
|
-
return
|
|
536
|
+
return false;
|
|
175
537
|
}
|
|
176
538
|
|
|
177
539
|
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
|
|
178
540
|
if (dst->op == GGML_OP_MUL_MAT) {
|
|
179
|
-
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
|
180
|
-
return
|
|
541
|
+
if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
|
|
542
|
+
return compute_forward_qx(params, dst);
|
|
181
543
|
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
|
182
544
|
return compute_forward_fp16(params, dst);
|
|
183
545
|
}
|
|
184
546
|
} else if (dst->op == GGML_OP_GET_ROWS) {
|
|
185
|
-
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
|
547
|
+
if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
|
|
186
548
|
return compute_forward_get_rows(params, dst);
|
|
187
549
|
}
|
|
188
550
|
}
|
|
@@ -196,12 +558,18 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
|
196
558
|
GGML_TENSOR_BINARY_OP_LOCALS
|
|
197
559
|
|
|
198
560
|
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
|
|
199
|
-
|
|
561
|
+
if (!kernels) {
|
|
562
|
+
return false;
|
|
563
|
+
}
|
|
200
564
|
|
|
201
565
|
const bool is_gemv = src1->ne[1] == 1;
|
|
202
566
|
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
|
203
567
|
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
|
204
568
|
GGML_ASSERT(kernel);
|
|
569
|
+
if (!kernels->rhs_info.pack_func_ex ||
|
|
570
|
+
!kernel->get_lhs_offset_ex || !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex) {
|
|
571
|
+
return false;
|
|
572
|
+
}
|
|
205
573
|
|
|
206
574
|
const int nth = params->nth;
|
|
207
575
|
const int ith = params->ith;
|
|
@@ -228,10 +596,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
|
228
596
|
const int64_t kr = (int64_t) kernel->get_kr();
|
|
229
597
|
const int64_t sr = (int64_t) kernel->get_sr();
|
|
230
598
|
|
|
231
|
-
const size_t lhs_packed_size =
|
|
232
|
-
const size_t rhs_packed_size =
|
|
233
|
-
const size_t kxn_size =
|
|
234
|
-
const size_t bias_size =
|
|
599
|
+
const size_t lhs_packed_size = lhs_info->packed_size_ex(m, k, 0, mr, kr, sr);
|
|
600
|
+
const size_t rhs_packed_size = kernels->rhs_info.packed_size_ex(n, k, nr, kr, 0);
|
|
601
|
+
const size_t kxn_size = k * n * sizeof(float);
|
|
602
|
+
const size_t bias_size = n * sizeof(float);
|
|
235
603
|
|
|
236
604
|
const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
|
|
237
605
|
GGML_ASSERT(wsize_required <= params->wsize);
|
|
@@ -259,10 +627,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
|
259
627
|
const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
|
|
260
628
|
|
|
261
629
|
// Base packed offset (aligned) and per-row stride in bytes
|
|
262
|
-
const size_t base_packed_off
|
|
263
|
-
|
|
264
|
-
const size_t next_block_off = variant_call<size_t>(
|
|
265
|
-
lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
|
|
630
|
+
const size_t base_packed_off = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
|
|
631
|
+
const size_t next_block_off = lhs_info->get_packed_offset_ex(m_start + mr, k, 0, mr, kr, sr);
|
|
266
632
|
const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;
|
|
267
633
|
|
|
268
634
|
int64_t remaining = m_count;
|
|
@@ -278,9 +644,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
|
278
644
|
const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
|
|
279
645
|
void * dst_ptr = lhs_packed + dst_off;
|
|
280
646
|
|
|
281
|
-
|
|
282
|
-
(size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr,
|
|
283
|
-
/*m_idx_start*/ 0, src_ptr, lhs_stride, dst_ptr);
|
|
647
|
+
lhs_info->pack_func_ex(take, k, 0, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
|
|
284
648
|
|
|
285
649
|
cur += take;
|
|
286
650
|
remaining -= take;
|
|
@@ -296,10 +660,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
|
296
660
|
reinterpret_cast<const uint16_t *>(rhs_batch_base),
|
|
297
661
|
rhs_stride);
|
|
298
662
|
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
/*rhs_stride (bytes)*/ (size_t)(n * sizeof(float)),
|
|
302
|
-
rhs_kxn, bias, nullptr, rhs_packed, /*extra_bytes*/ 0, /*params*/ nullptr);
|
|
663
|
+
kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, n * sizeof(float),
|
|
664
|
+
rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
|
|
303
665
|
}
|
|
304
666
|
|
|
305
667
|
ggml_barrier(params->threadpool);
|
|
@@ -320,20 +682,15 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
|
320
682
|
const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
|
|
321
683
|
|
|
322
684
|
// LHS packed base at row 0 (consistent with packing above)
|
|
323
|
-
const size_t lhs_packed_offset0 =
|
|
324
|
-
|
|
325
|
-
const size_t
|
|
326
|
-
const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);
|
|
685
|
+
const size_t lhs_packed_offset0 = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
|
|
686
|
+
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
|
|
687
|
+
const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);
|
|
327
688
|
|
|
328
689
|
const void * lhs_ptr = lhs_packed + lhs_packed_offset0;
|
|
329
690
|
const void * rhs_ptr = rhs_packed + rhs_packed_offset;
|
|
330
691
|
float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);
|
|
331
692
|
|
|
332
|
-
|
|
333
|
-
(size_t)m, (size_t)n_to_process, (size_t)k,
|
|
334
|
-
lhs_ptr, rhs_ptr,
|
|
335
|
-
dst_ptr, dst_stride, sizeof(float),
|
|
336
|
-
-FLT_MAX, FLT_MAX);
|
|
693
|
+
kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
|
|
337
694
|
}
|
|
338
695
|
}
|
|
339
696
|
|
|
@@ -345,108 +702,486 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
|
345
702
|
return true;
|
|
346
703
|
}
|
|
347
704
|
|
|
348
|
-
bool
|
|
349
|
-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
|
|
705
|
+
bool compute_forward_qx(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
|
706
|
+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
|
|
350
707
|
|
|
351
708
|
const ggml_tensor * src0 = dst->src[0];
|
|
352
709
|
const ggml_tensor * src1 = dst->src[1];
|
|
353
710
|
|
|
354
711
|
GGML_TENSOR_BINARY_OP_LOCALS
|
|
355
712
|
|
|
356
|
-
|
|
357
|
-
|
|
713
|
+
const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
|
|
714
|
+
const bool has_header = kleidiai_is_weight_header_valid(header);
|
|
715
|
+
const bool is_gemv = src1->ne[1] == 1;
|
|
716
|
+
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
|
|
717
|
+
const int slot_total = kleidiai_collect_kernel_chain(dst, kernel_chain);
|
|
358
718
|
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
719
|
+
auto weight_for_slot = [&](int slot_index, size_t & size_out) -> const uint8_t * {
|
|
720
|
+
if (slot_index < 0 || slot_index >= slot_total) {
|
|
721
|
+
return nullptr;
|
|
722
|
+
}
|
|
723
|
+
if (has_header) {
|
|
724
|
+
if (slot_index < header->slot_count) {
|
|
725
|
+
size_out = static_cast<size_t>(header->sizes[slot_index]);
|
|
726
|
+
return kleidiai_weight_slot_ptr(header, slot_index);
|
|
727
|
+
}
|
|
728
|
+
return nullptr;
|
|
729
|
+
}
|
|
730
|
+
if (slot_index == 0) {
|
|
731
|
+
size_out = ggml_nbytes(src0);
|
|
732
|
+
return static_cast<const uint8_t *>(src0->data);
|
|
733
|
+
}
|
|
734
|
+
return nullptr;
|
|
735
|
+
};
|
|
736
|
+
|
|
737
|
+
struct runtime_slot {
|
|
738
|
+
int slot_index;
|
|
739
|
+
ggml_kleidiai_kernels * kernels;
|
|
740
|
+
kernel_info * kernel;
|
|
741
|
+
lhs_packing_info * lhs_info;
|
|
742
|
+
size_t mr;
|
|
743
|
+
size_t nr;
|
|
744
|
+
size_t kr;
|
|
745
|
+
size_t sr;
|
|
746
|
+
size_t n_step;
|
|
747
|
+
size_t lhs_packed_size;
|
|
748
|
+
size_t lhs_offset;
|
|
749
|
+
size_t n_offset;
|
|
750
|
+
size_t n_cols;
|
|
751
|
+
int assigned_threads;
|
|
752
|
+
int thread_begin;
|
|
753
|
+
int thread_end;
|
|
754
|
+
const uint8_t * rhs_base;
|
|
755
|
+
};
|
|
756
|
+
|
|
757
|
+
std::array<runtime_slot, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> runtime{};
|
|
758
|
+
int runtime_count = 0;
|
|
759
|
+
|
|
760
|
+
for (int slot = 0; slot < slot_total && runtime_count < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
|
|
761
|
+
ggml_kleidiai_kernels * kernels = kernel_chain[slot];
|
|
762
|
+
kernel_info * kinfo = is_gemv ? &kernels->gemv : &kernels->gemm;
|
|
763
|
+
lhs_packing_info * linfo = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
|
764
|
+
if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex || !linfo->get_offset ||
|
|
765
|
+
!kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset) {
|
|
766
|
+
continue;
|
|
767
|
+
}
|
|
362
768
|
|
|
363
|
-
|
|
769
|
+
size_t rhs_size = 0;
|
|
770
|
+
const uint8_t * rhs_ptr = weight_for_slot(slot, rhs_size);
|
|
771
|
+
if (!rhs_ptr || rhs_size == 0) {
|
|
772
|
+
continue;
|
|
773
|
+
}
|
|
364
774
|
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
775
|
+
runtime[runtime_count] = {
|
|
776
|
+
slot,
|
|
777
|
+
kernels,
|
|
778
|
+
kinfo,
|
|
779
|
+
linfo,
|
|
780
|
+
kinfo->get_mr(),
|
|
781
|
+
kinfo->get_nr(),
|
|
782
|
+
kinfo->get_kr(),
|
|
783
|
+
kinfo->get_sr(),
|
|
784
|
+
kinfo->get_n_step(),
|
|
785
|
+
0,
|
|
786
|
+
0,
|
|
787
|
+
0,
|
|
788
|
+
0,
|
|
789
|
+
0,
|
|
790
|
+
0,
|
|
791
|
+
0,
|
|
792
|
+
rhs_ptr
|
|
793
|
+
};
|
|
794
|
+
++runtime_count;
|
|
795
|
+
}
|
|
796
|
+
|
|
797
|
+
if (runtime_count == 0) {
|
|
798
|
+
ggml_kleidiai_kernels * fallback = ggml_kleidiai_select_kernels(ctx.features, dst);
|
|
799
|
+
if (!fallback) {
|
|
800
|
+
return false;
|
|
801
|
+
}
|
|
802
|
+
kernel_info * kinfo = is_gemv ? &fallback->gemv : &fallback->gemm;
|
|
803
|
+
lhs_packing_info * linfo = is_gemv ? &fallback->gemv_lhs_info : &fallback->gemm_lhs_info;
|
|
804
|
+
rhs_packing_info * rinfo = &fallback->rhs_info;
|
|
805
|
+
if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex ||
|
|
806
|
+
!kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset ||
|
|
807
|
+
!rinfo || !rinfo->pack_func_ex || !rinfo->packed_size_ex) {
|
|
808
|
+
return false;
|
|
809
|
+
}
|
|
810
|
+
kernel_chain[0] = fallback;
|
|
811
|
+
runtime[0] = {
|
|
812
|
+
0,
|
|
813
|
+
fallback,
|
|
814
|
+
kinfo,
|
|
815
|
+
linfo,
|
|
816
|
+
kinfo->get_mr(),
|
|
817
|
+
kinfo->get_nr(),
|
|
818
|
+
kinfo->get_kr(),
|
|
819
|
+
kinfo->get_sr(),
|
|
820
|
+
kinfo->get_n_step(),
|
|
821
|
+
0,
|
|
822
|
+
0,
|
|
823
|
+
0,
|
|
824
|
+
0,
|
|
825
|
+
0,
|
|
826
|
+
0,
|
|
827
|
+
0,
|
|
828
|
+
nullptr
|
|
829
|
+
};
|
|
830
|
+
size_t rhs_size_fallback = 0;
|
|
831
|
+
const uint8_t * rhs_base = weight_for_slot(0, rhs_size_fallback);
|
|
832
|
+
if (!rhs_base) {
|
|
833
|
+
rhs_base = static_cast<const uint8_t *>(src0->data);
|
|
834
|
+
}
|
|
835
|
+
runtime[0].rhs_base = rhs_base;
|
|
836
|
+
runtime_count = 1;
|
|
837
|
+
}
|
|
838
|
+
|
|
839
|
+
const int nth_total = params->nth > 0 ? params->nth : 1;
|
|
840
|
+
const int ith_total = params->ith;
|
|
841
|
+
|
|
842
|
+
int sme_slot = -1;
|
|
843
|
+
for (int i = 0; i < runtime_count; ++i) {
|
|
844
|
+
if ((runtime[i].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
|
|
845
|
+
sme_slot = i;
|
|
846
|
+
break;
|
|
847
|
+
}
|
|
848
|
+
}
|
|
849
|
+
|
|
850
|
+
const int sme_cap_limit = ctx.sme_thread_cap;
|
|
851
|
+
const bool use_hybrid = sme_cap_limit > 0 &&
|
|
852
|
+
runtime_count > 1 &&
|
|
853
|
+
nth_total > sme_cap_limit;
|
|
854
|
+
// Heuristic: disable hybrid for very small workloads where per-slot overhead dominates.
|
|
855
|
+
// If rows are small or average columns per thread are small, keep single-slot.
|
|
856
|
+
size_t min_cols_per_thread = 0;
|
|
857
|
+
if (runtime_count > 0 && nth_total > 0) {
|
|
858
|
+
min_cols_per_thread = (size_t) std::max<int64_t>(1, (int64_t)ne01 / (int64_t)nth_total);
|
|
859
|
+
}
|
|
860
|
+
const bool too_small_for_hybrid = (min_cols_per_thread < 2) || (ne11 < 128);
|
|
861
|
+
|
|
862
|
+
const bool hybrid_enabled = use_hybrid && !too_small_for_hybrid;
|
|
863
|
+
|
|
864
|
+
if (!hybrid_enabled) {
|
|
865
|
+
int chosen_slot = 0;
|
|
866
|
+
if (too_small_for_hybrid && sme_slot != -1) {
|
|
867
|
+
chosen_slot = sme_slot;
|
|
868
|
+
} else if (runtime_count > 1 && ctx.sme_thread_cap > 0 && nth_total > ctx.sme_thread_cap) {
|
|
869
|
+
chosen_slot = 1;
|
|
870
|
+
}
|
|
871
|
+
if (chosen_slot != 0 && chosen_slot < runtime_count) {
|
|
872
|
+
runtime[0] = runtime[chosen_slot];
|
|
873
|
+
}
|
|
874
|
+
runtime_count = runtime_count > 0 ? 1 : 0;
|
|
875
|
+
|
|
876
|
+
// Recompute SME slot based on the collapsed runtime[0]
|
|
877
|
+
sme_slot = -1;
|
|
878
|
+
if (runtime_count > 0 &&
|
|
879
|
+
(runtime[0].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
|
|
880
|
+
sme_slot = 0;
|
|
881
|
+
}
|
|
882
|
+
}
|
|
883
|
+
|
|
884
|
+
int sme_cap = kleidiai_sme_thread_cap();
|
|
885
|
+
if (sme_cap < 0) {
|
|
886
|
+
sme_cap = nth_total;
|
|
887
|
+
}
|
|
888
|
+
sme_cap = std::min(sme_cap, nth_total);
|
|
889
|
+
|
|
890
|
+
int threads_remaining = nth_total;
|
|
891
|
+
if (sme_slot != -1) {
|
|
892
|
+
int sme_threads = std::min(std::max(sme_cap, 0), threads_remaining);
|
|
893
|
+
runtime[sme_slot].assigned_threads = sme_threads;
|
|
894
|
+
threads_remaining -= sme_threads;
|
|
895
|
+
}
|
|
896
|
+
|
|
897
|
+
int fallback_indices[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
|
|
898
|
+
int fallback_count = 0;
|
|
899
|
+
for (int i = 0; i < runtime_count; ++i) {
|
|
900
|
+
if (i == sme_slot) {
|
|
901
|
+
continue;
|
|
902
|
+
}
|
|
903
|
+
fallback_indices[fallback_count++] = i;
|
|
904
|
+
}
|
|
905
|
+
|
|
906
|
+
for (int fi = 0; fi < fallback_count; ++fi) {
|
|
907
|
+
if (threads_remaining <= 0) {
|
|
908
|
+
break;
|
|
909
|
+
}
|
|
910
|
+
const int slot_index = fallback_indices[fi];
|
|
911
|
+
const int slots_left = fallback_count - fi;
|
|
912
|
+
int share = (threads_remaining + slots_left - 1) / slots_left;
|
|
913
|
+
share = std::min(share, threads_remaining);
|
|
914
|
+
runtime[slot_index].assigned_threads = share;
|
|
915
|
+
threads_remaining -= share;
|
|
916
|
+
}
|
|
917
|
+
|
|
918
|
+
if (threads_remaining > 0) {
|
|
919
|
+
const int fallback_slot = (sme_slot != -1) ? sme_slot : 0;
|
|
920
|
+
runtime[fallback_slot].assigned_threads += threads_remaining;
|
|
921
|
+
threads_remaining = 0;
|
|
922
|
+
}
|
|
923
|
+
|
|
924
|
+
int thread_cursor = 0;
|
|
925
|
+
for (int i = 0; i < runtime_count; ++i) {
|
|
926
|
+
runtime[i].thread_begin = thread_cursor;
|
|
927
|
+
thread_cursor += runtime[i].assigned_threads;
|
|
928
|
+
runtime[i].thread_end = thread_cursor;
|
|
929
|
+
}
|
|
930
|
+
|
|
931
|
+
if (thread_cursor < nth_total && runtime_count > 0) {
|
|
932
|
+
runtime[runtime_count - 1].assigned_threads += nth_total - thread_cursor;
|
|
933
|
+
runtime[runtime_count - 1].thread_end = nth_total;
|
|
934
|
+
}
|
|
935
|
+
|
|
936
|
+
int local_slot = -1;
|
|
937
|
+
int local_ith = 0;
|
|
938
|
+
for (int i = 0; i < runtime_count; ++i) {
|
|
939
|
+
if (ith_total >= runtime[i].thread_begin && ith_total < runtime[i].thread_end) {
|
|
940
|
+
local_slot = i;
|
|
941
|
+
local_ith = ith_total - runtime[i].thread_begin;
|
|
942
|
+
break;
|
|
943
|
+
}
|
|
944
|
+
}
|
|
945
|
+
if (local_slot == -1) {
|
|
946
|
+
return false;
|
|
947
|
+
}
|
|
368
948
|
|
|
369
949
|
const size_t k = ne00;
|
|
370
950
|
const size_t m = ne11;
|
|
371
951
|
const size_t n = ne01;
|
|
372
952
|
|
|
373
|
-
size_t
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
953
|
+
size_t cursor = 0;
|
|
954
|
+
for (int i = 0; i < runtime_count; ++i) {
|
|
955
|
+
const ggml_type slot_rhs_type = runtime[i].kernels->rhs_type;
|
|
956
|
+
const size_t slot_pack_size_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
|
|
957
|
+
slot_rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0;
|
|
958
|
+
runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, slot_pack_size_arg, runtime[i].mr, runtime[i].kr, runtime[i].sr);
|
|
959
|
+
cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
960
|
+
runtime[i].lhs_offset = cursor;
|
|
961
|
+
cursor += runtime[i].lhs_packed_size;
|
|
962
|
+
}
|
|
380
963
|
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
const size_t n_start = ith * num_n_per_thread;
|
|
964
|
+
GGML_ASSERT(cursor <= params->wsize);
|
|
965
|
+
uint8_t * scratch = static_cast<uint8_t *>(params->wdata);
|
|
384
966
|
|
|
385
|
-
size_t
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
967
|
+
size_t assigned_cols = 0;
|
|
968
|
+
uint64_t weighted_total = 0;
|
|
969
|
+
if (runtime_count > 1 && sme_slot != -1) {
|
|
970
|
+
for (int i = 0; i < runtime_count; ++i) {
|
|
971
|
+
const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
|
|
972
|
+
weighted_total += (uint64_t)runtime[i].assigned_threads * weight;
|
|
390
973
|
}
|
|
391
974
|
}
|
|
975
|
+
for (int i = 0; i < runtime_count; ++i) {
|
|
976
|
+
runtime[i].n_offset = assigned_cols;
|
|
977
|
+
if (runtime[i].assigned_threads == 0) {
|
|
978
|
+
runtime[i].n_cols = 0;
|
|
979
|
+
continue;
|
|
980
|
+
}
|
|
981
|
+
const size_t remaining_cols = n - assigned_cols;
|
|
982
|
+
if (remaining_cols == 0) {
|
|
983
|
+
runtime[i].n_cols = 0;
|
|
984
|
+
continue;
|
|
985
|
+
}
|
|
986
|
+
const size_t step = runtime[i].n_step ? runtime[i].n_step : 1;
|
|
987
|
+
size_t target = 0;
|
|
988
|
+
if (weighted_total > 0) {
|
|
989
|
+
const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
|
|
990
|
+
target = (size_t)(((uint64_t)n * runtime[i].assigned_threads * weight) / weighted_total);
|
|
991
|
+
} else {
|
|
992
|
+
target = (size_t)(((uint64_t)n * runtime[i].assigned_threads) / nth_total);
|
|
993
|
+
}
|
|
994
|
+
target = std::min(target, remaining_cols);
|
|
995
|
+
size_t aligned = round_down(target, step);
|
|
996
|
+
if (aligned == 0 && remaining_cols >= step) {
|
|
997
|
+
aligned = step;
|
|
998
|
+
}
|
|
999
|
+
runtime[i].n_cols = aligned;
|
|
1000
|
+
assigned_cols += aligned;
|
|
1001
|
+
}
|
|
392
1002
|
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
1003
|
+
if (assigned_cols < n) {
|
|
1004
|
+
for (int i = runtime_count - 1; i >= 0; --i) {
|
|
1005
|
+
if (runtime[i].assigned_threads > 0) {
|
|
1006
|
+
runtime[i].n_cols += n - assigned_cols;
|
|
1007
|
+
break;
|
|
1008
|
+
}
|
|
1009
|
+
}
|
|
399
1010
|
}
|
|
1011
|
+
const size_t dst_stride = dst->nb[1];
|
|
1012
|
+
|
|
1013
|
+
for (int64_t batch_idx = 0; batch_idx < ne12; ++batch_idx) {
|
|
1014
|
+
const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
|
|
1015
|
+
uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];
|
|
400
1016
|
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
1017
|
+
if (runtime[local_slot].assigned_threads > 0) {
|
|
1018
|
+
runtime_slot & slot = runtime[local_slot];
|
|
1019
|
+
const ggml_type slot_rhs_type = slot.kernels->rhs_type;
|
|
1020
|
+
const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
|
|
1021
|
+
slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
|
|
1022
|
+
const int64_t m_roundup_mr = kai_roundup((int64_t)m, (int64_t)slot.mr);
|
|
1023
|
+
int64_t max_threads = slot.mr ? (m_roundup_mr / (int64_t)slot.mr) : slot.assigned_threads;
|
|
1024
|
+
max_threads = std::max<int64_t>(1, max_threads);
|
|
1025
|
+
const int64_t use_threads = std::min<int64_t>(slot.assigned_threads, max_threads);
|
|
407
1026
|
|
|
408
|
-
|
|
409
|
-
|
|
1027
|
+
if (local_ith < use_threads) {
|
|
1028
|
+
const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / use_threads), slot.mr);
|
|
1029
|
+
const int64_t num_m_per_threadN_1 = (int64_t)m - (use_threads - 1) * num_m_per_thread0;
|
|
1030
|
+
|
|
1031
|
+
const int64_t m_start = (int64_t)local_ith * num_m_per_thread0;
|
|
1032
|
+
const int64_t m_count = (local_ith == use_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
|
|
1033
|
+
|
|
1034
|
+
const size_t base_packed_off = slot.lhs_info->get_packed_offset_ex(m_start, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
|
|
1035
|
+
const size_t next_block_off = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
|
|
1036
|
+
const size_t row_stride_bytes = slot.mr ? (next_block_off - base_packed_off) / slot.mr : 0;
|
|
1037
|
+
|
|
1038
|
+
int64_t remaining = m_count;
|
|
1039
|
+
int64_t cur = m_start;
|
|
1040
|
+
|
|
1041
|
+
uint8_t * lhs_packed = scratch + slot.lhs_offset;
|
|
1042
|
+
while (remaining > 0) {
|
|
1043
|
+
const int64_t row_in_group = cur;
|
|
1044
|
+
const int64_t avail = (int64_t)m - row_in_group;
|
|
1045
|
+
const int64_t take = std::min(avail, remaining);
|
|
1046
|
+
|
|
1047
|
+
const size_t src_off = slot.lhs_info->get_offset(row_in_group, src1->nb[1]);
|
|
1048
|
+
const void * src_ptr = lhs_batch_base + src_off;
|
|
1049
|
+
const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
|
|
1050
|
+
void * dst_ptr = lhs_packed + dst_off;
|
|
1051
|
+
|
|
1052
|
+
slot.lhs_info->pack_func_ex(take, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr);
|
|
1053
|
+
|
|
1054
|
+
cur += take;
|
|
1055
|
+
remaining -= take;
|
|
1056
|
+
}
|
|
1057
|
+
}
|
|
1058
|
+
}
|
|
410
1059
|
|
|
411
|
-
|
|
1060
|
+
ggml_barrier(params->threadpool);
|
|
412
1061
|
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
1062
|
+
runtime_slot & slot = runtime[local_slot];
|
|
1063
|
+
if (slot.n_cols > 0 && slot.assigned_threads > 0) {
|
|
1064
|
+
int64_t active_threads = slot.assigned_threads;
|
|
1065
|
+
const int64_t max_threads = slot.n_step ? (slot.n_cols / slot.n_step) : slot.assigned_threads;
|
|
1066
|
+
if (max_threads > 0) {
|
|
1067
|
+
active_threads = std::min<int64_t>(active_threads, std::max<int64_t>(1, max_threads));
|
|
1068
|
+
}
|
|
1069
|
+
active_threads = std::max<int64_t>(1, active_threads);
|
|
1070
|
+
|
|
1071
|
+
if (local_ith < active_threads) {
|
|
1072
|
+
const size_t step = slot.n_step ? slot.n_step : 1;
|
|
1073
|
+
const size_t chunk0 = round_down((size_t)(slot.n_cols / active_threads), step);
|
|
1074
|
+
const size_t chunkN = slot.n_cols - (active_threads - 1) * chunk0;
|
|
1075
|
+
const size_t local_start = (size_t)local_ith * chunk0;
|
|
1076
|
+
const size_t cols = (local_ith == active_threads - 1) ? chunkN : chunk0;
|
|
1077
|
+
|
|
1078
|
+
if (cols > 0) {
|
|
1079
|
+
const ggml_type slot_rhs_type = slot.kernels->rhs_type;
|
|
1080
|
+
const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
|
|
1081
|
+
slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
|
|
1082
|
+
const size_t slot_rhs_block_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
|
|
1083
|
+
slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
|
|
1084
|
+
const size_t global_start = slot.n_offset + local_start;
|
|
1085
|
+
const size_t lhs_packed_offset = slot.lhs_info->get_packed_offset_ex(0, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
|
|
1086
|
+
const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot_rhs_block_arg);
|
|
1087
|
+
const size_t dst_offset = slot.kernel->get_dst_offset(0, global_start, dst_stride);
|
|
1088
|
+
|
|
1089
|
+
const uint8_t * lhs_ptr = scratch + slot.lhs_offset + lhs_packed_offset;
|
|
1090
|
+
const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset;
|
|
1091
|
+
float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);
|
|
1092
|
+
|
|
1093
|
+
slot.kernel->run_kernel_ex(m, cols, k, slot_rhs_block_arg,
|
|
1094
|
+
lhs_ptr,
|
|
1095
|
+
rhs_ptr,
|
|
1096
|
+
dst_ptr,
|
|
1097
|
+
dst_stride,
|
|
1098
|
+
sizeof(float),
|
|
1099
|
+
-FLT_MAX,
|
|
1100
|
+
FLT_MAX);
|
|
1101
|
+
}
|
|
1102
|
+
}
|
|
1103
|
+
}
|
|
421
1104
|
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
1105
|
+
if (batch_idx != ne12 - 1) {
|
|
1106
|
+
ggml_barrier(params->threadpool);
|
|
1107
|
+
}
|
|
425
1108
|
}
|
|
426
1109
|
|
|
427
1110
|
return true;
|
|
428
1111
|
}
|
|
429
1112
|
|
|
430
1113
|
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
|
431
|
-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
|
|
432
|
-
GGML_ASSERT(ctx.kernels);
|
|
433
|
-
|
|
1114
|
+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
|
|
434
1115
|
const ggml_tensor * src0 = dst->src[0];
|
|
435
1116
|
const ggml_tensor * src1 = dst->src[1];
|
|
436
1117
|
|
|
437
1118
|
GGML_TENSOR_BINARY_OP_LOCALS
|
|
438
1119
|
|
|
439
|
-
|
|
440
|
-
|
|
1120
|
+
const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
|
|
1121
|
+
const bool has_header = kleidiai_is_weight_header_valid(header);
|
|
1122
|
+
|
|
1123
|
+
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
|
|
1124
|
+
const bool want_q8 = src0->type == GGML_TYPE_Q8_0;
|
|
1125
|
+
const int chain_count = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
|
|
1126
|
+
: kleidiai_collect_q4_chain(kernel_chain);
|
|
1127
|
+
|
|
1128
|
+
ggml_kleidiai_kernels * kernels = nullptr;
|
|
1129
|
+
const uint8_t * packed_base = static_cast<const uint8_t *>(src0->data);
|
|
1130
|
+
|
|
1131
|
+
if (has_header && chain_count > 0) {
|
|
1132
|
+
int select_slot = 0;
|
|
1133
|
+
if (select_slot >= header->slot_count) {
|
|
1134
|
+
select_slot = header->slot_count - 1;
|
|
1135
|
+
}
|
|
1136
|
+
if (select_slot >= 0 && select_slot < chain_count) {
|
|
1137
|
+
kernels = kernel_chain[select_slot];
|
|
1138
|
+
const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, select_slot);
|
|
1139
|
+
if (slot_ptr) {
|
|
1140
|
+
packed_base = slot_ptr;
|
|
1141
|
+
}
|
|
1142
|
+
}
|
|
1143
|
+
}
|
|
1144
|
+
|
|
1145
|
+
if (!kernels && chain_count > 0) {
|
|
1146
|
+
kernels = kernel_chain[0];
|
|
1147
|
+
if (has_header) {
|
|
1148
|
+
const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, 0);
|
|
1149
|
+
if (slot_ptr) {
|
|
1150
|
+
packed_base = slot_ptr;
|
|
1151
|
+
}
|
|
1152
|
+
}
|
|
1153
|
+
}
|
|
1154
|
+
|
|
1155
|
+
if (!kernels) {
|
|
1156
|
+
return false;
|
|
1157
|
+
}
|
|
1158
|
+
|
|
1159
|
+
rhs_packing_info * rhs_info = &kernels->rhs_info;
|
|
1160
|
+
kernel_info * kernel = &kernels->gemm;
|
|
1161
|
+
if (!rhs_info->to_float || !kernel->get_nr) {
|
|
1162
|
+
return false;
|
|
1163
|
+
}
|
|
441
1164
|
|
|
442
1165
|
const int64_t nc = ne00;
|
|
443
1166
|
const int64_t nr = ggml_nelements(src1);
|
|
444
1167
|
|
|
1168
|
+
const ggml_type rhs_type = kernels->rhs_type;
|
|
1169
|
+
size_t block_len = 0;
|
|
1170
|
+
size_t num_bytes_multiplier = 0;
|
|
1171
|
+
if (rhs_type == GGML_TYPE_Q4_0) {
|
|
1172
|
+
block_len = QK4_0;
|
|
1173
|
+
num_bytes_multiplier = sizeof(uint16_t);
|
|
1174
|
+
} else if (rhs_type == GGML_TYPE_Q8_0) {
|
|
1175
|
+
block_len = QK8_0;
|
|
1176
|
+
num_bytes_multiplier = sizeof(float);
|
|
1177
|
+
} else {
|
|
1178
|
+
return false;
|
|
1179
|
+
}
|
|
1180
|
+
|
|
445
1181
|
const size_t block_rows = kernel->get_nr();
|
|
446
1182
|
const size_t kr = kernel->get_kr();
|
|
447
1183
|
|
|
448
|
-
const size_t
|
|
449
|
-
const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
|
|
1184
|
+
const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, block_len);
|
|
450
1185
|
|
|
451
1186
|
const int ith = params->ith;
|
|
452
1187
|
const int nth = params->nth;
|
|
@@ -461,7 +1196,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
|
461
1196
|
GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
|
|
462
1197
|
|
|
463
1198
|
float *out = (float *)((char *)dst->data + i * nb1);
|
|
464
|
-
rhs_info->to_float(
|
|
1199
|
+
rhs_info->to_float(packed_base, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
|
|
465
1200
|
}
|
|
466
1201
|
|
|
467
1202
|
return true;
|
|
@@ -469,21 +1204,136 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
|
469
1204
|
|
|
470
1205
|
public:
|
|
471
1206
|
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
|
|
472
|
-
GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
|
|
473
|
-
GGML_ASSERT(ctx.kernels);
|
|
1207
|
+
GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0);
|
|
474
1208
|
const size_t n = tensor->ne[1];
|
|
475
1209
|
const size_t k = tensor->ne[0];
|
|
476
|
-
size_t nr = ctx.kernels->gemm.get_nr();
|
|
477
|
-
size_t kr = ctx.kernels->gemm.get_kr();
|
|
478
|
-
size_t sr = ctx.kernels->gemm.get_sr();
|
|
479
1210
|
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
1211
|
+
kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(tensor->data);
|
|
1212
|
+
if (!header) {
|
|
1213
|
+
return -1;
|
|
1214
|
+
}
|
|
1215
|
+
|
|
1216
|
+
header->magic = GGML_KLEIDIAI_PACK_MAGIC;
|
|
1217
|
+
header->version = GGML_KLEIDIAI_PACK_VERSION;
|
|
1218
|
+
header->slot_count = 0;
|
|
1219
|
+
|
|
1220
|
+
uint8_t * base_ptr = static_cast<uint8_t *>(tensor->data);
|
|
1221
|
+
size_t cursor = sizeof(kleidiai_weight_header);
|
|
1222
|
+
cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
1223
|
+
|
|
1224
|
+
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
|
|
1225
|
+
const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;
|
|
1226
|
+
const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
|
|
1227
|
+
: kleidiai_collect_q4_chain(kernel_chain);
|
|
1228
|
+
const bool allow_fallback = kleidiai_pack_fallback_allowed();
|
|
1229
|
+
|
|
1230
|
+
std::vector<int8_t> qdata;
|
|
1231
|
+
std::vector<float> scales;
|
|
1232
|
+
|
|
1233
|
+
if (want_q8 && slot_total > 0) {
|
|
1234
|
+
qdata.resize(n * k, 0);
|
|
1235
|
+
scales.resize(n, 0.0f);
|
|
1236
|
+
|
|
1237
|
+
const size_t row_stride = tensor->nb[1];
|
|
1238
|
+
const size_t k_blocks = (k + QK8_0 - 1) / QK8_0;
|
|
1239
|
+
|
|
1240
|
+
for (size_t row = 0; row < n; ++row) {
|
|
1241
|
+
const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
|
|
1242
|
+
static_cast<const uint8_t *>(data) + row * row_stride);
|
|
1243
|
+
|
|
1244
|
+
float max_abs = 0.0f;
|
|
1245
|
+
for (size_t block = 0; block < k_blocks; ++block) {
|
|
1246
|
+
const block_q8_0 & blk = row_blocks[block];
|
|
1247
|
+
const float d = GGML_FP16_TO_FP32(blk.d);
|
|
1248
|
+
for (size_t l = 0; l < QK8_0; ++l) {
|
|
1249
|
+
const size_t linear_idx = block * QK8_0 + l;
|
|
1250
|
+
if (linear_idx >= k) {
|
|
1251
|
+
break;
|
|
1252
|
+
}
|
|
1253
|
+
const float value = d * static_cast<float>(blk.qs[l]);
|
|
1254
|
+
max_abs = std::max(max_abs, std::fabs(value));
|
|
1255
|
+
}
|
|
1256
|
+
}
|
|
1257
|
+
|
|
1258
|
+
float scale = max_abs > 0.0f ? max_abs / 127.0f : 0.0f;
|
|
1259
|
+
scales[row] = scale;
|
|
1260
|
+
const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;
|
|
1261
|
+
|
|
1262
|
+
for (size_t block = 0; block < k_blocks; ++block) {
|
|
1263
|
+
const block_q8_0 & blk = row_blocks[block];
|
|
1264
|
+
const float d = GGML_FP16_TO_FP32(blk.d);
|
|
1265
|
+
for (size_t l = 0; l < QK8_0; ++l) {
|
|
1266
|
+
const size_t linear_idx = block * QK8_0 + l;
|
|
1267
|
+
if (linear_idx >= k) {
|
|
1268
|
+
break;
|
|
1269
|
+
}
|
|
1270
|
+
const float value = d * static_cast<float>(blk.qs[l]);
|
|
1271
|
+
int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
|
|
1272
|
+
q = std::clamp(q, -127, 127);
|
|
1273
|
+
qdata[row * k + linear_idx] = static_cast<int8_t>(q);
|
|
1274
|
+
}
|
|
1275
|
+
}
|
|
1276
|
+
}
|
|
1277
|
+
}
|
|
1278
|
+
|
|
1279
|
+
for (int slot = 0; slot < slot_total && slot < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
|
|
1280
|
+
if (!allow_fallback && slot > 0) {
|
|
1281
|
+
break;
|
|
1282
|
+
}
|
|
1283
|
+
ggml_kleidiai_kernels * kernels = kernel_chain[slot];
|
|
1284
|
+
kernel_info * kernel = &kernels->gemm;
|
|
1285
|
+
rhs_packing_info * rhs_info = &kernels->rhs_info;
|
|
1286
|
+
if (!rhs_info || !rhs_info->pack_func_ex || !rhs_info->packed_size_ex || !kernel) {
|
|
1287
|
+
continue;
|
|
1288
|
+
}
|
|
1289
|
+
|
|
1290
|
+
const size_t nr = kernel->get_nr();
|
|
1291
|
+
const size_t kr = kernel->get_kr();
|
|
1292
|
+
const size_t sr = kernel->get_sr();
|
|
1293
|
+
const ggml_type rhs_type = kernels->rhs_type;
|
|
1294
|
+
const size_t block_len = rhs_type == GGML_TYPE_Q8_0 ? QK8_0 :
|
|
1295
|
+
rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : 0;
|
|
1296
|
+
if (block_len == 0) {
|
|
1297
|
+
continue;
|
|
1298
|
+
}
|
|
1299
|
+
|
|
1300
|
+
const size_t packed_size = rhs_info->packed_size_ex(n, k, nr, kr, block_len);
|
|
1301
|
+
const size_t aligned_cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
1302
|
+
|
|
1303
|
+
uint8_t * dst_ptr = base_ptr + aligned_cursor;
|
|
1304
|
+
|
|
1305
|
+
if (rhs_type == GGML_TYPE_Q4_0) {
|
|
1306
|
+
struct kai_rhs_pack_qs4cxs1s0_param params;
|
|
1307
|
+
params.lhs_zero_point = 1;
|
|
1308
|
+
params.rhs_zero_point = 8;
|
|
1309
|
+
rhs_info->pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
|
|
1310
|
+
static_cast<const uint8_t *>(data), nullptr, nullptr,
|
|
1311
|
+
dst_ptr, 0, ¶ms);
|
|
1312
|
+
} else if (rhs_type == GGML_TYPE_Q8_0) {
|
|
1313
|
+
struct kai_rhs_pack_qsi8cx_params params;
|
|
1314
|
+
params.lhs_zero_point = 1;
|
|
1315
|
+
params.scale_multiplier = 1.0f;
|
|
1316
|
+
rhs_info->pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
|
|
1317
|
+
qdata.data(), nullptr, scales.data(),
|
|
1318
|
+
dst_ptr, 0, ¶ms);
|
|
1319
|
+
} else {
|
|
1320
|
+
continue;
|
|
1321
|
+
}
|
|
1322
|
+
|
|
1323
|
+
header->offsets[header->slot_count] = aligned_cursor;
|
|
1324
|
+
header->sizes[header->slot_count] = packed_size;
|
|
1325
|
+
++header->slot_count;
|
|
1326
|
+
|
|
1327
|
+
cursor = aligned_cursor + packed_size;
|
|
1328
|
+
}
|
|
1329
|
+
|
|
1330
|
+
if (header->slot_count == 0) {
|
|
1331
|
+
header->magic = 0;
|
|
1332
|
+
header->version = 0;
|
|
1333
|
+
memcpy(tensor->data, data, data_size);
|
|
1334
|
+
}
|
|
484
1335
|
|
|
485
1336
|
return 0;
|
|
486
|
-
GGML_UNUSED(data_size);
|
|
487
1337
|
}
|
|
488
1338
|
};
|
|
489
1339
|
|
|
@@ -513,9 +1363,8 @@ static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t bu
|
|
|
513
1363
|
}
|
|
514
1364
|
|
|
515
1365
|
// Returns the fixed display name of the KleidiAI CPU buffer type.
// The buffer type handle is accepted for interface compatibility but is not read.
static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(buft);

    const char * buffer_type_name = "CPU_KLEIDIAI";
    return buffer_type_name;
}
|
|
520
1369
|
|
|
521
1370
|
static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
|
@@ -534,33 +1383,80 @@ static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(
|
|
|
534
1383
|
}
|
|
535
1384
|
|
|
536
1385
|
// Returns the alignment used for tensors in this buffer type (TENSOR_ALIGNMENT).
// The buffer type handle is accepted for interface compatibility but is not read.
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(buft);

    const size_t alignment = TENSOR_ALIGNMENT;
    return alignment;
}
|
|
541
1389
|
|
|
542
1390
|
// Computes how many bytes must be reserved for `tensor` in a KleidiAI buffer.
//
// Tensors that are neither Q4_0 nor Q8_0 need exactly ggml_nbytes(tensor).
// For Q4_0 / Q8_0 the size is the aligned weight header plus one aligned packed
// copy per usable kernel slot, but never less than ggml_nbytes(tensor). When no
// slot is usable the plain ggml_nbytes(tensor) is returned.
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
    GGML_UNUSED(buft);

    const bool is_q4 = tensor->type == GGML_TYPE_Q4_0;
    const bool is_q8 = tensor->type == GGML_TYPE_Q8_0;
    if (!is_q4 && !is_q8) {
        return ggml_nbytes(tensor);
    }

    const size_t rows = tensor->ne[1];
    const size_t cols = tensor->ne[0];

    // Packed data begins after the (aligned) header.
    size_t packed_end = sizeof(kleidiai_weight_header);
    packed_end = align_up(packed_end, GGML_KLEIDIAI_PACK_ALIGN);

    std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> chain;
    const int chain_len = is_q8 ? kleidiai_collect_q8_chain(chain)
                                : kleidiai_collect_q4_chain(chain);
    const bool fallback_ok = kleidiai_pack_fallback_allowed();

    // Without fallback packing only the first chain entry is considered.
    const int slot_limit = fallback_ok ? chain_len : (chain_len > 0 ? 1 : 0);

    bool any_slot = false;
    for (int idx = 0; idx < slot_limit; ++idx) {
        ggml_kleidiai_kernels * entry = chain[idx];
        if (entry == nullptr) {
            continue;
        }

        rhs_packing_info & packing = entry->rhs_info;
        if (!packing.packed_size_ex) {
            continue;
        }

        size_t block_len = 0;
        if (entry->rhs_type == GGML_TYPE_Q4_0) {
            block_len = QK4_0;
        } else if (entry->rhs_type == GGML_TYPE_Q8_0) {
            block_len = QK8_0;
        }
        if (block_len == 0) {
            continue;
        }

        // Each slot's packed copy starts on an aligned offset.
        packed_end = align_up(packed_end, GGML_KLEIDIAI_PACK_ALIGN);
        packed_end += packing.packed_size_ex(rows, cols, entry->gemm.get_nr(), entry->gemm.get_kr(), block_len);
        any_slot = true;
    }

    if (!any_slot) {
        return ggml_nbytes(tensor);
    }
    return std::max(packed_end, ggml_nbytes(tensor));
}
|
|
555
1442
|
|
|
556
1443
|
namespace ggml::cpu::kleidiai {
|
|
557
1444
|
class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|
558
1445
|
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
|
|
1446
|
+
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
|
|
1447
|
+
const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
|
|
559
1448
|
if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
|
|
560
|
-
op->src[0]->type == GGML_TYPE_Q4_0 &&
|
|
1449
|
+
(op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
|
|
561
1450
|
op->src[0]->buffer &&
|
|
562
1451
|
(ggml_n_dims(op->src[0]) == 2) &&
|
|
563
|
-
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() &&
|
|
1452
|
+
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() &&
|
|
1453
|
+
slot_total > 0) {
|
|
1454
|
+
if (op->src[0]->type == GGML_TYPE_Q4_0 && ctx.kernels_q4 == nullptr) {
|
|
1455
|
+
return false;
|
|
1456
|
+
}
|
|
1457
|
+
if (op->src[0]->type == GGML_TYPE_Q8_0 && ctx.kernels_q8 == nullptr) {
|
|
1458
|
+
return false;
|
|
1459
|
+
}
|
|
564
1460
|
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
|
565
1461
|
return false;
|
|
566
1462
|
}
|
|
@@ -576,14 +1472,17 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|
|
576
1472
|
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
|
|
577
1473
|
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
|
|
578
1474
|
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
|
579
|
-
}
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
1475
|
+
} else {
|
|
1476
|
+
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
|
|
1477
|
+
const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
|
|
1478
|
+
const bool has_kernel = slot_total > 0;
|
|
1479
|
+
if (has_kernel && op->src[1]->ne[1] > 1) {
|
|
1480
|
+
if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
|
|
1481
|
+
(op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
|
|
1482
|
+
return nullptr;
|
|
1483
|
+
}
|
|
1484
|
+
return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
|
|
584
1485
|
}
|
|
585
|
-
|
|
586
|
-
return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
|
|
587
1486
|
}
|
|
588
1487
|
}
|
|
589
1488
|
return nullptr;
|