whispercpp 1.3.5 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +99 -2
- data/ext/extconf.rb +1 -0
- data/ext/ruby_whisper.c +20 -4
- data/ext/ruby_whisper.h +30 -2
- data/ext/ruby_whisper_context.c +216 -124
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_params.c +0 -1
- data/ext/ruby_whisper_segment.c +0 -1
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +4 -1
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +8 -0
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/server/server.cpp +18 -4
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
- data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
- data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
- data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
- data/ext/sources/examples/talk-llama/llama-context.h +27 -28
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
- data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
- data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
- data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
- data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
- data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
- data/ext/sources/examples/talk-llama/llama-model.h +72 -19
- data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
- data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
- data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -22
- data/ext/sources/examples/talk-llama/llama.h +63 -30
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
- data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
- data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
- data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
- data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
- data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/models.h +181 -46
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
- data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
- data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
- data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
- data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
- data/ext/sources/ggml/CMakeLists.txt +9 -3
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +6 -1
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +56 -9
- data/ext/sources/ggml/src/CMakeLists.txt +3 -0
- data/ext/sources/ggml/src/ggml-alloc.c +4 -9
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
- data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
- data/ext/sources/ggml/src/ggml-impl.h +62 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
- data/ext/sources/ggml/src/ggml.c +167 -33
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/src/whisper.cpp +6 -28
- data/sig/whisper.rbs +43 -2
- data/test/test_context_params.rb +82 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_whisper.rb +20 -0
- data/whispercpp.gemspec +1 -1
- metadata +240 -28
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
|
@@ -1,20 +1,31 @@
|
|
|
1
|
-
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
|
|
1
|
+
// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
|
|
2
2
|
// SPDX-License-Identifier: MIT
|
|
3
3
|
//
|
|
4
4
|
#include <arm_neon.h>
|
|
5
5
|
#include <assert.h>
|
|
6
|
+
#include <stdio.h>
|
|
6
7
|
#include <atomic>
|
|
7
8
|
#include <cfloat>
|
|
8
|
-
#include <cmath>
|
|
9
9
|
#include <algorithm>
|
|
10
|
+
#include <cmath>
|
|
10
11
|
#include <stdexcept>
|
|
11
12
|
#include <stdint.h>
|
|
12
13
|
#include <string.h>
|
|
13
14
|
#include <string>
|
|
14
15
|
#include <vector>
|
|
16
|
+
#include <array>
|
|
17
|
+
#include <cstddef>
|
|
18
|
+
#include <cstdint>
|
|
19
|
+
#include <fstream>
|
|
20
|
+
#include <set>
|
|
21
|
+
#include <iostream>
|
|
22
|
+
#include <climits>
|
|
15
23
|
#if defined(__linux__)
|
|
16
24
|
#include <asm/hwcap.h>
|
|
17
25
|
#include <sys/auxv.h>
|
|
26
|
+
#include <sys/types.h>
|
|
27
|
+
#include <sys/stat.h>
|
|
28
|
+
#include <unistd.h>
|
|
18
29
|
#elif defined(__APPLE__)
|
|
19
30
|
#include <string_view>
|
|
20
31
|
#include <sys/sysctl.h>
|
|
@@ -39,11 +50,18 @@
|
|
|
39
50
|
#define GGML_COMMON_DECL_CPP
|
|
40
51
|
#include "ggml-common.h"
|
|
41
52
|
|
|
53
|
+
static constexpr int GGML_KLEIDIAI_MAX_KERNEL_SLOTS = 2;
|
|
54
|
+
static constexpr uint32_t GGML_KLEIDIAI_PACK_MAGIC = 0x4b4c4149; // "KLAI"
|
|
55
|
+
static constexpr uint16_t GGML_KLEIDIAI_PACK_VERSION = 1;
|
|
56
|
+
static constexpr size_t GGML_KLEIDIAI_PACK_ALIGN = 64;
|
|
57
|
+
|
|
42
58
|
struct ggml_kleidiai_context {
|
|
43
59
|
cpu_feature features;
|
|
44
60
|
ggml_kleidiai_kernels * kernels_q4;
|
|
45
61
|
ggml_kleidiai_kernels * kernels_q8;
|
|
46
|
-
|
|
62
|
+
int sme_thread_cap; // <= 0 means “SME disabled/unknown”;
|
|
63
|
+
int thread_hint; // <= 0 means “no hint”
|
|
64
|
+
} static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1 };
|
|
47
65
|
|
|
48
66
|
static const char* cpu_feature_to_string(cpu_feature f) {
|
|
49
67
|
if (f == CPU_FEATURE_NONE) {
|
|
@@ -63,41 +81,335 @@ static const char* cpu_feature_to_string(cpu_feature f) {
|
|
|
63
81
|
}
|
|
64
82
|
}
|
|
65
83
|
|
|
66
|
-
static
|
|
84
|
+
static size_t detect_num_smcus() {
|
|
85
|
+
if (!ggml_cpu_has_sme()) {
|
|
86
|
+
return 0;
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
#if defined(__linux__) && defined(__aarch64__)
|
|
90
|
+
// Linux/aarch64: Best-effort count of Streaming Mode Compute Units (SMCUs) via SMIDR_EL1 sysfs.
|
|
91
|
+
size_t num_private = 0;
|
|
92
|
+
std::set<uint32_t> shared_ids;
|
|
93
|
+
|
|
94
|
+
for (size_t cpu = 0;; ++cpu) {
|
|
95
|
+
const std::string path =
|
|
96
|
+
"/sys/devices/system/cpu/cpu" + std::to_string(cpu) +
|
|
97
|
+
"/regs/identification/smidr_el1";
|
|
98
|
+
|
|
99
|
+
std::ifstream file(path);
|
|
100
|
+
if (!file.is_open()) {
|
|
101
|
+
break;
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
uint64_t smidr = 0;
|
|
105
|
+
if (!(file >> std::hex >> smidr)) {
|
|
106
|
+
continue;
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
// Arm ARM: SMIDR_EL1
|
|
110
|
+
const uint32_t sh = (uint32_t)((smidr >> 13) & 0x3);
|
|
111
|
+
// Build an "affinity-like" identifier for shared SMCUs.
|
|
112
|
+
// Keep the original packing logic, but isolate it here.
|
|
113
|
+
const uint32_t id = (uint32_t)((smidr & 0xFFFu) | ((smidr >> 20) & 0xFFFFF000u));
|
|
114
|
+
|
|
115
|
+
switch (sh) {
|
|
116
|
+
case 0b10: // private SMCU
|
|
117
|
+
++num_private;
|
|
118
|
+
break;
|
|
119
|
+
case 0b11: // shared SMCU
|
|
120
|
+
shared_ids.emplace(id);
|
|
121
|
+
break;
|
|
122
|
+
case 0b00:
|
|
123
|
+
// Ambiguous / implementation-defined. Be conservative:
|
|
124
|
+
// treat id==0 as private, otherwise as shared.
|
|
125
|
+
if (id == 0) ++num_private;
|
|
126
|
+
else shared_ids.emplace(id);
|
|
127
|
+
break;
|
|
128
|
+
default:
|
|
129
|
+
break;
|
|
130
|
+
}
|
|
131
|
+
}
|
|
132
|
+
|
|
133
|
+
return num_private + shared_ids.size();
|
|
134
|
+
|
|
135
|
+
#elif defined(__APPLE__) && defined(__aarch64__)
|
|
136
|
+
// table for known M4 variants. Users can override via GGML_KLEIDIAI_SME=<n>.
|
|
137
|
+
char chip_name[256] = {};
|
|
138
|
+
size_t size = sizeof(chip_name);
|
|
139
|
+
|
|
140
|
+
if (sysctlbyname("machdep.cpu.brand_string", chip_name, &size, nullptr, 0) == 0) {
|
|
141
|
+
const std::string brand(chip_name);
|
|
142
|
+
|
|
143
|
+
struct ModelSMCU { const char *match; size_t smcus; };
|
|
144
|
+
static const ModelSMCU table[] = {
|
|
145
|
+
{ "M4 Ultra", 2 },
|
|
146
|
+
{ "M4 Max", 2 },
|
|
147
|
+
{ "M4 Pro", 2 },
|
|
148
|
+
{ "M4", 1 },
|
|
149
|
+
};
|
|
67
150
|
|
|
151
|
+
for (const auto &e : table) {
|
|
152
|
+
if (brand.find(e.match) != std::string::npos) {
|
|
153
|
+
return e.smcus;
|
|
154
|
+
}
|
|
155
|
+
}
|
|
156
|
+
}
|
|
157
|
+
return 1;
|
|
158
|
+
|
|
159
|
+
#else
|
|
160
|
+
return 1;
|
|
161
|
+
#endif
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
static int parse_uint_env(const char *s, const char *name, bool *ok) {
|
|
165
|
+
if (!s) { *ok = false; return 0; }
|
|
166
|
+
char *end = nullptr;
|
|
167
|
+
long v = strtol(s, &end, 10);
|
|
168
|
+
if (end == s || *end != '\0') {
|
|
169
|
+
GGML_LOG_WARN("kleidiai: invalid %s='%s' (expected integer)\n", name, s);
|
|
170
|
+
*ok = false;
|
|
171
|
+
return 0;
|
|
172
|
+
}
|
|
173
|
+
if (v < 0 || v > INT_MAX) {
|
|
174
|
+
GGML_LOG_WARN("kleidiai: out-of-range %s='%s'\n", name, s);
|
|
175
|
+
*ok = false;
|
|
176
|
+
return 0;
|
|
177
|
+
}
|
|
178
|
+
*ok = true;
|
|
179
|
+
return (int)v;
|
|
180
|
+
}
|
|
181
|
+
|
|
182
|
+
static void init_kleidiai_context(void) {
|
|
68
183
|
ggml_critical_section_start();
|
|
69
184
|
static bool initialized = false;
|
|
70
185
|
|
|
71
186
|
if (!initialized) {
|
|
72
187
|
initialized = true;
|
|
73
|
-
|
|
74
|
-
|
|
188
|
+
|
|
189
|
+
const char *env_sme = getenv("GGML_KLEIDIAI_SME");
|
|
190
|
+
const char *env_threads = getenv("GGML_TOTAL_THREADS");
|
|
191
|
+
|
|
192
|
+
const bool cpu_has_sme = ggml_cpu_has_sme();
|
|
193
|
+
size_t detected_smcus = 0;
|
|
75
194
|
|
|
76
195
|
ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
|
|
77
196
|
(ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
|
|
78
197
|
((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
|
|
79
198
|
|
|
80
|
-
if (
|
|
81
|
-
|
|
199
|
+
if (env_threads) {
|
|
200
|
+
bool ok = false;
|
|
201
|
+
int hint = parse_uint_env(env_threads, "GGML_TOTAL_THREADS", &ok);
|
|
202
|
+
if (ok && hint > 0) {
|
|
203
|
+
ctx.thread_hint = hint;
|
|
204
|
+
}
|
|
82
205
|
}
|
|
83
206
|
|
|
84
|
-
|
|
85
|
-
|
|
207
|
+
// SME policy:
|
|
208
|
+
// - If CPU doesn't support SME: SME always off.
|
|
209
|
+
// - Else:
|
|
210
|
+
// - env unset => auto-detect cores; enable if detected > 0.
|
|
211
|
+
// - env=0 => force off.
|
|
212
|
+
// - env>0 => force N cores (skip detection).
|
|
213
|
+
int sme_cores = 0;
|
|
214
|
+
bool sme_env_ok = false;
|
|
215
|
+
bool sme_env_set = (env_sme != nullptr);
|
|
216
|
+
|
|
217
|
+
if (!cpu_has_sme) {
|
|
218
|
+
if (sme_env_set) {
|
|
219
|
+
bool ok = false;
|
|
220
|
+
int req = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
|
|
221
|
+
if (ok && req > 0) {
|
|
222
|
+
GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME=%d but SME is not supported on this CPU; disabling SME\n", req);
|
|
223
|
+
}
|
|
224
|
+
}
|
|
225
|
+
sme_cores = 0;
|
|
226
|
+
} else {
|
|
227
|
+
if (sme_env_set) {
|
|
228
|
+
bool ok = false;
|
|
229
|
+
int v = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
|
|
230
|
+
sme_env_ok = ok;
|
|
231
|
+
|
|
232
|
+
if (!ok) {
|
|
233
|
+
GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME set but parsing failed; falling back to runtime SME-core detection\n");
|
|
234
|
+
detected_smcus = detect_num_smcus();
|
|
235
|
+
sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
|
|
236
|
+
} else if (v == 0) {
|
|
237
|
+
sme_cores = 0;
|
|
238
|
+
} else {
|
|
239
|
+
sme_cores = v;
|
|
240
|
+
}
|
|
241
|
+
} else {
|
|
242
|
+
detected_smcus = detect_num_smcus();
|
|
243
|
+
sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
|
|
244
|
+
}
|
|
245
|
+
|
|
246
|
+
if (!sme_env_set && sme_cores == 0) {
|
|
247
|
+
GGML_LOG_WARN("kleidiai: SME supported but runtime SME-core detection returned 0; falling back to NEON\n");
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
if (sme_cores > 0) {
|
|
251
|
+
ctx.features |= CPU_FEATURE_SME;
|
|
252
|
+
}
|
|
86
253
|
}
|
|
254
|
+
|
|
255
|
+
// Kernel selection
|
|
87
256
|
ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
|
|
88
257
|
ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
|
|
89
|
-
|
|
90
|
-
if (ctx.kernels_q4) {
|
|
91
|
-
|
|
258
|
+
|
|
259
|
+
if (!ctx.kernels_q4) {
|
|
260
|
+
GGML_LOG_INFO("kleidiai: no compatible q4 kernels found for CPU features mask %d\n", (int)ctx.features);
|
|
261
|
+
} else {
|
|
262
|
+
GGML_LOG_INFO("kleidiai: primary q4 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
|
|
263
|
+
}
|
|
264
|
+
|
|
265
|
+
if (!ctx.kernels_q8) {
|
|
266
|
+
GGML_LOG_INFO("kleidiai: no compatible q8 kernels found for CPU features mask %d\n", (int)ctx.features);
|
|
267
|
+
} else {
|
|
268
|
+
GGML_LOG_INFO("kleidiai: primary q8 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
|
|
92
269
|
}
|
|
93
|
-
|
|
94
|
-
|
|
270
|
+
|
|
271
|
+
ctx.sme_thread_cap = (ctx.features & CPU_FEATURE_SME) ? sme_cores : 0;
|
|
272
|
+
|
|
273
|
+
if (ctx.features & CPU_FEATURE_SME) {
|
|
274
|
+
if (sme_env_set && sme_env_ok && sme_cores > 0) {
|
|
275
|
+
GGML_LOG_INFO("kleidiai: SME enabled (GGML_KLEIDIAI_SME=%d override)\n", sme_cores);
|
|
276
|
+
} else {
|
|
277
|
+
GGML_LOG_INFO("kleidiai: SME enabled (runtime-detected SME cores=%d)\n", sme_cores);
|
|
278
|
+
}
|
|
279
|
+
} else {
|
|
280
|
+
GGML_LOG_INFO("kleidiai: SME disabled\n");
|
|
95
281
|
}
|
|
96
|
-
#endif
|
|
97
282
|
}
|
|
283
|
+
|
|
98
284
|
ggml_critical_section_end();
|
|
99
285
|
}
|
|
100
286
|
|
|
287
|
+
static inline int kleidiai_sme_thread_cap() {
|
|
288
|
+
return ctx.sme_thread_cap;
|
|
289
|
+
}
|
|
290
|
+
|
|
291
|
+
static inline size_t align_up(size_t value, size_t alignment) {
|
|
292
|
+
if (alignment == 0) {
|
|
293
|
+
return value;
|
|
294
|
+
}
|
|
295
|
+
const size_t remainder = value % alignment;
|
|
296
|
+
return remainder == 0 ? value : value + (alignment - remainder);
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
static inline bool kleidiai_pack_fallback_allowed() {
|
|
300
|
+
if (ctx.sme_thread_cap <= 0) {
|
|
301
|
+
return false;
|
|
302
|
+
}
|
|
303
|
+
if (ctx.thread_hint <= 0) {
|
|
304
|
+
return true;
|
|
305
|
+
}
|
|
306
|
+
return ctx.thread_hint > ctx.sme_thread_cap;
|
|
307
|
+
}
|
|
308
|
+
|
|
309
|
+
struct kleidiai_weight_header {
|
|
310
|
+
uint32_t magic;
|
|
311
|
+
uint16_t version;
|
|
312
|
+
uint16_t slot_count;
|
|
313
|
+
uint64_t offsets[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
|
|
314
|
+
uint64_t sizes[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
|
|
315
|
+
};
|
|
316
|
+
|
|
317
|
+
static inline kleidiai_weight_header * kleidiai_weight_header_from_ptr(void * data) {
|
|
318
|
+
return reinterpret_cast<kleidiai_weight_header *>(data);
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
static inline const kleidiai_weight_header * kleidiai_weight_header_from_ptr(const void * data) {
|
|
322
|
+
return reinterpret_cast<const kleidiai_weight_header *>(data);
|
|
323
|
+
}
|
|
324
|
+
|
|
325
|
+
static inline bool kleidiai_is_weight_header_valid(const kleidiai_weight_header * header) {
|
|
326
|
+
if (!header) {
|
|
327
|
+
return false;
|
|
328
|
+
}
|
|
329
|
+
if (header->magic != GGML_KLEIDIAI_PACK_MAGIC || header->version != GGML_KLEIDIAI_PACK_VERSION) {
|
|
330
|
+
return false;
|
|
331
|
+
}
|
|
332
|
+
if (header->slot_count == 0 || header->slot_count > GGML_KLEIDIAI_MAX_KERNEL_SLOTS) {
|
|
333
|
+
return false;
|
|
334
|
+
}
|
|
335
|
+
return true;
|
|
336
|
+
}
|
|
337
|
+
|
|
338
|
+
static inline uint8_t * kleidiai_weight_slot_ptr(kleidiai_weight_header * header, int slot) {
|
|
339
|
+
if (!kleidiai_is_weight_header_valid(header)) {
|
|
340
|
+
return nullptr;
|
|
341
|
+
}
|
|
342
|
+
if (slot < 0 || slot >= header->slot_count) {
|
|
343
|
+
return nullptr;
|
|
344
|
+
}
|
|
345
|
+
return reinterpret_cast<uint8_t *>(header) + header->offsets[slot];
|
|
346
|
+
}
|
|
347
|
+
|
|
348
|
+
static inline const uint8_t * kleidiai_weight_slot_ptr(const kleidiai_weight_header * header, int slot) {
|
|
349
|
+
if (!kleidiai_is_weight_header_valid(header)) {
|
|
350
|
+
return nullptr;
|
|
351
|
+
}
|
|
352
|
+
if (slot < 0 || slot >= header->slot_count) {
|
|
353
|
+
return nullptr;
|
|
354
|
+
}
|
|
355
|
+
return reinterpret_cast<const uint8_t *>(header) + header->offsets[slot];
|
|
356
|
+
}
|
|
357
|
+
|
|
358
|
+
static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q4() {
|
|
359
|
+
return ctx.kernels_q4;
|
|
360
|
+
}
|
|
361
|
+
|
|
362
|
+
static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q8() {
|
|
363
|
+
return ctx.kernels_q8;
|
|
364
|
+
}
|
|
365
|
+
|
|
366
|
+
template <typename SelectFallback>
|
|
367
|
+
static int kleidiai_collect_kernel_chain_common(
|
|
368
|
+
ggml_kleidiai_kernels * primary,
|
|
369
|
+
cpu_feature features,
|
|
370
|
+
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out,
|
|
371
|
+
SelectFallback select_fallback) {
|
|
372
|
+
int count = 0;
|
|
373
|
+
if (!primary) {
|
|
374
|
+
return 0;
|
|
375
|
+
}
|
|
376
|
+
out[count++] = primary;
|
|
377
|
+
|
|
378
|
+
if ((primary->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
|
|
379
|
+
const cpu_feature fallback_mask = static_cast<cpu_feature>(features & ~CPU_FEATURE_SME);
|
|
380
|
+
if (fallback_mask != CPU_FEATURE_NONE) {
|
|
381
|
+
ggml_kleidiai_kernels * fallback = select_fallback(fallback_mask);
|
|
382
|
+
if (fallback && fallback != primary &&
|
|
383
|
+
fallback->lhs_type == primary->lhs_type &&
|
|
384
|
+
fallback->rhs_type == primary->rhs_type &&
|
|
385
|
+
fallback->op_type == primary->op_type) {
|
|
386
|
+
out[count++] = fallback;
|
|
387
|
+
}
|
|
388
|
+
}
|
|
389
|
+
}
|
|
390
|
+
|
|
391
|
+
return count;
|
|
392
|
+
}
|
|
393
|
+
|
|
394
|
+
static int kleidiai_collect_kernel_chain(const struct ggml_tensor * op,
|
|
395
|
+
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
|
|
396
|
+
ggml_kleidiai_kernels * primary = ggml_kleidiai_select_kernels(ctx.features, op);
|
|
397
|
+
return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
|
|
398
|
+
[&](cpu_feature mask) { return ggml_kleidiai_select_kernels(mask, op); });
|
|
399
|
+
}
|
|
400
|
+
|
|
401
|
+
static int kleidiai_collect_q4_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
|
|
402
|
+
ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q4();
|
|
403
|
+
return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
|
|
404
|
+
[&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q4_0(mask); });
|
|
405
|
+
}
|
|
406
|
+
|
|
407
|
+
static int kleidiai_collect_q8_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
|
|
408
|
+
ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q8();
|
|
409
|
+
return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
|
|
410
|
+
[&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q8_0(mask); });
|
|
411
|
+
}
|
|
412
|
+
|
|
101
413
|
static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
|
|
102
414
|
GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
|
|
103
415
|
return tensor->ne[dim];
|
|
@@ -126,49 +438,108 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
|
126
438
|
if (op->op != GGML_OP_MUL_MAT) {
|
|
127
439
|
return false;
|
|
128
440
|
}
|
|
129
|
-
|
|
130
|
-
|
|
441
|
+
|
|
442
|
+
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
|
|
443
|
+
const int slot_count = kleidiai_collect_kernel_chain(op, kernel_chain);
|
|
444
|
+
if (slot_count == 0) {
|
|
131
445
|
return false;
|
|
132
446
|
}
|
|
133
|
-
bool is_gemv = op->src[1]->ne[1] == 1;
|
|
134
|
-
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
|
135
|
-
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
|
136
447
|
|
|
137
|
-
|
|
138
|
-
size_t
|
|
139
|
-
size_t
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
448
|
+
const bool is_gemv = op->src[1]->ne[1] == 1;
|
|
449
|
+
const size_t k = op->src[0]->ne[0];
|
|
450
|
+
const size_t n = op->src[0]->ne[1];
|
|
451
|
+
const size_t m = op->src[1]->ne[1];
|
|
452
|
+
|
|
453
|
+
if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) {
|
|
454
|
+
const size_t qk = (op->src[0]->type == GGML_TYPE_Q4_0) ? QK4_0 : QK8_0;
|
|
455
|
+
|
|
456
|
+
size_t cursor = 0;
|
|
457
|
+
bool any_slot = false;
|
|
458
|
+
|
|
459
|
+
for (int slot = 0; slot < slot_count; ++slot) {
|
|
460
|
+
ggml_kleidiai_kernels * kernels = kernel_chain[slot];
|
|
461
|
+
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
|
462
|
+
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
|
463
|
+
|
|
464
|
+
if (!lhs_info || !lhs_info->packed_size_ex || !kernel) {
|
|
465
|
+
return false;
|
|
466
|
+
}
|
|
467
|
+
|
|
468
|
+
const size_t mr = kernel->get_mr();
|
|
469
|
+
const size_t kr = kernel->get_kr();
|
|
470
|
+
const size_t sr = kernel->get_sr();
|
|
471
|
+
|
|
472
|
+
const size_t packed = lhs_info->packed_size_ex(m, k, qk, mr, kr, sr);
|
|
473
|
+
|
|
474
|
+
cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
475
|
+
cursor += packed;
|
|
476
|
+
any_slot = true;
|
|
477
|
+
}
|
|
478
|
+
|
|
479
|
+
if (!any_slot) {
|
|
480
|
+
return false;
|
|
481
|
+
}
|
|
482
|
+
|
|
483
|
+
size = cursor;
|
|
484
|
+
return true;
|
|
485
|
+
}
|
|
486
|
+
|
|
487
|
+
if (op->src[0]->type == GGML_TYPE_F16) {
|
|
153
488
|
const int64_t lhs_batch_size0 = op->src[1]->ne[2];
|
|
154
489
|
const int64_t rhs_batch_size0 = op->src[0]->ne[2];
|
|
490
|
+
GGML_ASSERT(rhs_batch_size0 > 0);
|
|
155
491
|
const int64_t r = lhs_batch_size0 / rhs_batch_size0;
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
492
|
+
|
|
493
|
+
size_t cursor = 0;
|
|
494
|
+
bool any_slot = false;
|
|
495
|
+
|
|
496
|
+
for (int slot = 0; slot < slot_count; ++slot) {
|
|
497
|
+
ggml_kleidiai_kernels * kernels = kernel_chain[slot];
|
|
498
|
+
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
|
499
|
+
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
|
500
|
+
if (!lhs_info || !lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex || !kernel) {
|
|
501
|
+
return false;
|
|
502
|
+
}
|
|
503
|
+
|
|
504
|
+
const size_t mr = kernel->get_mr();
|
|
505
|
+
const size_t kr = kernel->get_kr();
|
|
506
|
+
const size_t sr = kernel->get_sr();
|
|
507
|
+
|
|
508
|
+
cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
509
|
+
cursor += lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr);
|
|
510
|
+
any_slot = true;
|
|
511
|
+
}
|
|
512
|
+
|
|
513
|
+
for (int slot = 0; slot < slot_count; ++slot) {
|
|
514
|
+
ggml_kleidiai_kernels * kernels = kernel_chain[slot];
|
|
515
|
+
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
|
516
|
+
if (!kernel || !kernels->rhs_info.packed_size_ex) {
|
|
517
|
+
return false;
|
|
518
|
+
}
|
|
519
|
+
cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
520
|
+
cursor += kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0);
|
|
521
|
+
}
|
|
522
|
+
|
|
523
|
+
cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
524
|
+
cursor += k * n * sizeof(float);
|
|
525
|
+
cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
526
|
+
cursor += n * sizeof(float);
|
|
527
|
+
|
|
528
|
+
if (!any_slot) {
|
|
529
|
+
return false;
|
|
530
|
+
}
|
|
531
|
+
|
|
532
|
+
size = cursor;
|
|
533
|
+
return true;
|
|
161
534
|
}
|
|
162
535
|
|
|
163
|
-
return
|
|
536
|
+
return false;
|
|
164
537
|
}
|
|
165
538
|
|
|
166
539
|
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
|
|
167
540
|
if (dst->op == GGML_OP_MUL_MAT) {
|
|
168
|
-
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
|
169
|
-
return
|
|
170
|
-
} else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
|
|
171
|
-
return compute_forward_q8_0(params, dst);
|
|
541
|
+
if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
|
|
542
|
+
return compute_forward_qx(params, dst);
|
|
172
543
|
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
|
173
544
|
return compute_forward_fp16(params, dst);
|
|
174
545
|
}
|
|
@@ -331,204 +702,457 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
|
331
702
|
return true;
|
|
332
703
|
}
|
|
333
704
|
|
|
334
|
-
bool
|
|
335
|
-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
|
|
705
|
+
bool compute_forward_qx(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
|
706
|
+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
|
|
336
707
|
|
|
337
708
|
const ggml_tensor * src0 = dst->src[0];
|
|
338
709
|
const ggml_tensor * src1 = dst->src[1];
|
|
339
710
|
|
|
340
711
|
GGML_TENSOR_BINARY_OP_LOCALS
|
|
341
712
|
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
bool is_gemv = src1->ne[1] == 1;
|
|
348
|
-
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
|
349
|
-
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
|
350
|
-
|
|
351
|
-
GGML_ASSERT(kernel);
|
|
352
|
-
if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
|
|
353
|
-
!kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
|
|
354
|
-
return false;
|
|
355
|
-
}
|
|
713
|
+
const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
|
|
714
|
+
const bool has_header = kleidiai_is_weight_header_valid(header);
|
|
715
|
+
const bool is_gemv = src1->ne[1] == 1;
|
|
716
|
+
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
|
|
717
|
+
const int slot_total = kleidiai_collect_kernel_chain(dst, kernel_chain);
|
|
356
718
|
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
719
|
+
auto weight_for_slot = [&](int slot_index, size_t & size_out) -> const uint8_t * {
|
|
720
|
+
if (slot_index < 0 || slot_index >= slot_total) {
|
|
721
|
+
return nullptr;
|
|
722
|
+
}
|
|
723
|
+
if (has_header) {
|
|
724
|
+
if (slot_index < header->slot_count) {
|
|
725
|
+
size_out = static_cast<size_t>(header->sizes[slot_index]);
|
|
726
|
+
return kleidiai_weight_slot_ptr(header, slot_index);
|
|
727
|
+
}
|
|
728
|
+
return nullptr;
|
|
729
|
+
}
|
|
730
|
+
if (slot_index == 0) {
|
|
731
|
+
size_out = ggml_nbytes(src0);
|
|
732
|
+
return static_cast<const uint8_t *>(src0->data);
|
|
733
|
+
}
|
|
734
|
+
return nullptr;
|
|
735
|
+
};
|
|
736
|
+
|
|
737
|
+
struct runtime_slot {
|
|
738
|
+
int slot_index;
|
|
739
|
+
ggml_kleidiai_kernels * kernels;
|
|
740
|
+
kernel_info * kernel;
|
|
741
|
+
lhs_packing_info * lhs_info;
|
|
742
|
+
size_t mr;
|
|
743
|
+
size_t nr;
|
|
744
|
+
size_t kr;
|
|
745
|
+
size_t sr;
|
|
746
|
+
size_t n_step;
|
|
747
|
+
size_t lhs_packed_size;
|
|
748
|
+
size_t lhs_offset;
|
|
749
|
+
size_t n_offset;
|
|
750
|
+
size_t n_cols;
|
|
751
|
+
int assigned_threads;
|
|
752
|
+
int thread_begin;
|
|
753
|
+
int thread_end;
|
|
754
|
+
const uint8_t * rhs_base;
|
|
755
|
+
};
|
|
756
|
+
|
|
757
|
+
std::array<runtime_slot, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> runtime{};
|
|
758
|
+
int runtime_count = 0;
|
|
759
|
+
|
|
760
|
+
for (int slot = 0; slot < slot_total && runtime_count < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
|
|
761
|
+
ggml_kleidiai_kernels * kernels = kernel_chain[slot];
|
|
762
|
+
kernel_info * kinfo = is_gemv ? &kernels->gemv : &kernels->gemm;
|
|
763
|
+
lhs_packing_info * linfo = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
|
764
|
+
if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex || !linfo->get_offset ||
|
|
765
|
+
!kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset) {
|
|
766
|
+
continue;
|
|
767
|
+
}
|
|
360
768
|
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
769
|
+
size_t rhs_size = 0;
|
|
770
|
+
const uint8_t * rhs_ptr = weight_for_slot(slot, rhs_size);
|
|
771
|
+
if (!rhs_ptr || rhs_size == 0) {
|
|
772
|
+
continue;
|
|
773
|
+
}
|
|
364
774
|
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
775
|
+
runtime[runtime_count] = {
|
|
776
|
+
slot,
|
|
777
|
+
kernels,
|
|
778
|
+
kinfo,
|
|
779
|
+
linfo,
|
|
780
|
+
kinfo->get_mr(),
|
|
781
|
+
kinfo->get_nr(),
|
|
782
|
+
kinfo->get_kr(),
|
|
783
|
+
kinfo->get_sr(),
|
|
784
|
+
kinfo->get_n_step(),
|
|
785
|
+
0,
|
|
786
|
+
0,
|
|
787
|
+
0,
|
|
788
|
+
0,
|
|
789
|
+
0,
|
|
790
|
+
0,
|
|
791
|
+
0,
|
|
792
|
+
rhs_ptr
|
|
793
|
+
};
|
|
794
|
+
++runtime_count;
|
|
795
|
+
}
|
|
368
796
|
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
797
|
+
if (runtime_count == 0) {
|
|
798
|
+
ggml_kleidiai_kernels * fallback = ggml_kleidiai_select_kernels(ctx.features, dst);
|
|
799
|
+
if (!fallback) {
|
|
800
|
+
return false;
|
|
801
|
+
}
|
|
802
|
+
kernel_info * kinfo = is_gemv ? &fallback->gemv : &fallback->gemm;
|
|
803
|
+
lhs_packing_info * linfo = is_gemv ? &fallback->gemv_lhs_info : &fallback->gemm_lhs_info;
|
|
804
|
+
rhs_packing_info * rinfo = &fallback->rhs_info;
|
|
805
|
+
if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex ||
|
|
806
|
+
!kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset ||
|
|
807
|
+
!rinfo || !rinfo->pack_func_ex || !rinfo->packed_size_ex) {
|
|
808
|
+
return false;
|
|
809
|
+
}
|
|
810
|
+
kernel_chain[0] = fallback;
|
|
811
|
+
runtime[0] = {
|
|
812
|
+
0,
|
|
813
|
+
fallback,
|
|
814
|
+
kinfo,
|
|
815
|
+
linfo,
|
|
816
|
+
kinfo->get_mr(),
|
|
817
|
+
kinfo->get_nr(),
|
|
818
|
+
kinfo->get_kr(),
|
|
819
|
+
kinfo->get_sr(),
|
|
820
|
+
kinfo->get_n_step(),
|
|
821
|
+
0,
|
|
822
|
+
0,
|
|
823
|
+
0,
|
|
824
|
+
0,
|
|
825
|
+
0,
|
|
826
|
+
0,
|
|
827
|
+
0,
|
|
828
|
+
nullptr
|
|
829
|
+
};
|
|
830
|
+
size_t rhs_size_fallback = 0;
|
|
831
|
+
const uint8_t * rhs_base = weight_for_slot(0, rhs_size_fallback);
|
|
832
|
+
if (!rhs_base) {
|
|
833
|
+
rhs_base = static_cast<const uint8_t *>(src0->data);
|
|
834
|
+
}
|
|
835
|
+
runtime[0].rhs_base = rhs_base;
|
|
836
|
+
runtime_count = 1;
|
|
837
|
+
}
|
|
372
838
|
|
|
373
|
-
const
|
|
374
|
-
const
|
|
375
|
-
const size_t n_start = ith * num_n_per_thread;
|
|
839
|
+
const int nth_total = params->nth > 0 ? params->nth : 1;
|
|
840
|
+
const int ith_total = params->ith;
|
|
376
841
|
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
842
|
+
int sme_slot = -1;
|
|
843
|
+
for (int i = 0; i < runtime_count; ++i) {
|
|
844
|
+
if ((runtime[i].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
|
|
845
|
+
sme_slot = i;
|
|
846
|
+
break;
|
|
382
847
|
}
|
|
383
848
|
}
|
|
384
849
|
|
|
385
|
-
|
|
386
|
-
const
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
850
|
+
const int sme_cap_limit = ctx.sme_thread_cap;
|
|
851
|
+
const bool use_hybrid = sme_cap_limit > 0 &&
|
|
852
|
+
runtime_count > 1 &&
|
|
853
|
+
nth_total > sme_cap_limit;
|
|
854
|
+
// Heuristic: disable hybrid for very small workloads where per-slot overhead dominates.
|
|
855
|
+
// If rows are small or average columns per thread are small, keep single-slot.
|
|
856
|
+
size_t min_cols_per_thread = 0;
|
|
857
|
+
if (runtime_count > 0 && nth_total > 0) {
|
|
858
|
+
min_cols_per_thread = (size_t) std::max<int64_t>(1, (int64_t)ne01 / (int64_t)nth_total);
|
|
391
859
|
}
|
|
860
|
+
const bool too_small_for_hybrid = (min_cols_per_thread < 2) || (ne11 < 128);
|
|
392
861
|
|
|
393
|
-
|
|
394
|
-
// Transform LHS
|
|
395
|
-
const size_t src_stride = src1->nb[1];
|
|
396
|
-
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
|
|
397
|
-
const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr);
|
|
398
|
-
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
|
|
399
|
-
|
|
400
|
-
// Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer
|
|
401
|
-
lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
|
|
402
|
-
}
|
|
862
|
+
const bool hybrid_enabled = use_hybrid && !too_small_for_hybrid;
|
|
403
863
|
|
|
404
|
-
|
|
864
|
+
if (!hybrid_enabled) {
|
|
865
|
+
int chosen_slot = 0;
|
|
866
|
+
if (too_small_for_hybrid && sme_slot != -1) {
|
|
867
|
+
chosen_slot = sme_slot;
|
|
868
|
+
} else if (runtime_count > 1 && ctx.sme_thread_cap > 0 && nth_total > ctx.sme_thread_cap) {
|
|
869
|
+
chosen_slot = 1;
|
|
870
|
+
}
|
|
871
|
+
if (chosen_slot != 0 && chosen_slot < runtime_count) {
|
|
872
|
+
runtime[0] = runtime[chosen_slot];
|
|
873
|
+
}
|
|
874
|
+
runtime_count = runtime_count > 0 ? 1 : 0;
|
|
405
875
|
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
|
|
876
|
+
// Recompute SME slot based on the collapsed runtime[0]
|
|
877
|
+
sme_slot = -1;
|
|
878
|
+
if (runtime_count > 0 &&
|
|
879
|
+
(runtime[0].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
|
|
880
|
+
sme_slot = 0;
|
|
881
|
+
}
|
|
882
|
+
}
|
|
414
883
|
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
884
|
+
int sme_cap = kleidiai_sme_thread_cap();
|
|
885
|
+
if (sme_cap < 0) {
|
|
886
|
+
sme_cap = nth_total;
|
|
418
887
|
}
|
|
888
|
+
sme_cap = std::min(sme_cap, nth_total);
|
|
419
889
|
|
|
420
|
-
|
|
421
|
-
|
|
890
|
+
int threads_remaining = nth_total;
|
|
891
|
+
if (sme_slot != -1) {
|
|
892
|
+
int sme_threads = std::min(std::max(sme_cap, 0), threads_remaining);
|
|
893
|
+
runtime[sme_slot].assigned_threads = sme_threads;
|
|
894
|
+
threads_remaining -= sme_threads;
|
|
895
|
+
}
|
|
422
896
|
|
|
423
|
-
|
|
424
|
-
|
|
897
|
+
int fallback_indices[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
|
|
898
|
+
int fallback_count = 0;
|
|
899
|
+
for (int i = 0; i < runtime_count; ++i) {
|
|
900
|
+
if (i == sme_slot) {
|
|
901
|
+
continue;
|
|
902
|
+
}
|
|
903
|
+
fallback_indices[fallback_count++] = i;
|
|
904
|
+
}
|
|
425
905
|
|
|
426
|
-
|
|
427
|
-
|
|
906
|
+
for (int fi = 0; fi < fallback_count; ++fi) {
|
|
907
|
+
if (threads_remaining <= 0) {
|
|
908
|
+
break;
|
|
909
|
+
}
|
|
910
|
+
const int slot_index = fallback_indices[fi];
|
|
911
|
+
const int slots_left = fallback_count - fi;
|
|
912
|
+
int share = (threads_remaining + slots_left - 1) / slots_left;
|
|
913
|
+
share = std::min(share, threads_remaining);
|
|
914
|
+
runtime[slot_index].assigned_threads = share;
|
|
915
|
+
threads_remaining -= share;
|
|
916
|
+
}
|
|
428
917
|
|
|
429
|
-
|
|
918
|
+
if (threads_remaining > 0) {
|
|
919
|
+
const int fallback_slot = (sme_slot != -1) ? sme_slot : 0;
|
|
920
|
+
runtime[fallback_slot].assigned_threads += threads_remaining;
|
|
921
|
+
threads_remaining = 0;
|
|
922
|
+
}
|
|
430
923
|
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
924
|
+
int thread_cursor = 0;
|
|
925
|
+
for (int i = 0; i < runtime_count; ++i) {
|
|
926
|
+
runtime[i].thread_begin = thread_cursor;
|
|
927
|
+
thread_cursor += runtime[i].assigned_threads;
|
|
928
|
+
runtime[i].thread_end = thread_cursor;
|
|
434
929
|
}
|
|
435
930
|
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
931
|
+
if (thread_cursor < nth_total && runtime_count > 0) {
|
|
932
|
+
runtime[runtime_count - 1].assigned_threads += nth_total - thread_cursor;
|
|
933
|
+
runtime[runtime_count - 1].thread_end = nth_total;
|
|
934
|
+
}
|
|
439
935
|
|
|
440
|
-
|
|
441
|
-
|
|
936
|
+
int local_slot = -1;
|
|
937
|
+
int local_ith = 0;
|
|
938
|
+
for (int i = 0; i < runtime_count; ++i) {
|
|
939
|
+
if (ith_total >= runtime[i].thread_begin && ith_total < runtime[i].thread_end) {
|
|
940
|
+
local_slot = i;
|
|
941
|
+
local_ith = ith_total - runtime[i].thread_begin;
|
|
942
|
+
break;
|
|
943
|
+
}
|
|
944
|
+
}
|
|
945
|
+
if (local_slot == -1) {
|
|
442
946
|
return false;
|
|
443
947
|
}
|
|
444
948
|
|
|
445
|
-
const int ith = params->ith;
|
|
446
|
-
const int nth_raw = params->nth;
|
|
447
|
-
const int nth = nth_raw > 0 ? nth_raw : 1;
|
|
448
|
-
|
|
449
949
|
const size_t k = ne00;
|
|
450
950
|
const size_t m = ne11;
|
|
451
951
|
const size_t n = ne01;
|
|
452
952
|
|
|
453
|
-
size_t
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
953
|
+
size_t cursor = 0;
|
|
954
|
+
for (int i = 0; i < runtime_count; ++i) {
|
|
955
|
+
const ggml_type slot_rhs_type = runtime[i].kernels->rhs_type;
|
|
956
|
+
const size_t slot_pack_size_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
|
|
957
|
+
slot_rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0;
|
|
958
|
+
runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, slot_pack_size_arg, runtime[i].mr, runtime[i].kr, runtime[i].sr);
|
|
959
|
+
cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
960
|
+
runtime[i].lhs_offset = cursor;
|
|
961
|
+
cursor += runtime[i].lhs_packed_size;
|
|
962
|
+
}
|
|
460
963
|
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
const size_t n_start = ith * num_n_per_thread;
|
|
964
|
+
GGML_ASSERT(cursor <= params->wsize);
|
|
965
|
+
uint8_t * scratch = static_cast<uint8_t *>(params->wdata);
|
|
464
966
|
|
|
465
|
-
size_t
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
967
|
+
size_t assigned_cols = 0;
|
|
968
|
+
uint64_t weighted_total = 0;
|
|
969
|
+
if (runtime_count > 1 && sme_slot != -1) {
|
|
970
|
+
for (int i = 0; i < runtime_count; ++i) {
|
|
971
|
+
const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
|
|
972
|
+
weighted_total += (uint64_t)runtime[i].assigned_threads * weight;
|
|
470
973
|
}
|
|
471
974
|
}
|
|
975
|
+
for (int i = 0; i < runtime_count; ++i) {
|
|
976
|
+
runtime[i].n_offset = assigned_cols;
|
|
977
|
+
if (runtime[i].assigned_threads == 0) {
|
|
978
|
+
runtime[i].n_cols = 0;
|
|
979
|
+
continue;
|
|
980
|
+
}
|
|
981
|
+
const size_t remaining_cols = n - assigned_cols;
|
|
982
|
+
if (remaining_cols == 0) {
|
|
983
|
+
runtime[i].n_cols = 0;
|
|
984
|
+
continue;
|
|
985
|
+
}
|
|
986
|
+
const size_t step = runtime[i].n_step ? runtime[i].n_step : 1;
|
|
987
|
+
size_t target = 0;
|
|
988
|
+
if (weighted_total > 0) {
|
|
989
|
+
const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
|
|
990
|
+
target = (size_t)(((uint64_t)n * runtime[i].assigned_threads * weight) / weighted_total);
|
|
991
|
+
} else {
|
|
992
|
+
target = (size_t)(((uint64_t)n * runtime[i].assigned_threads) / nth_total);
|
|
993
|
+
}
|
|
994
|
+
target = std::min(target, remaining_cols);
|
|
995
|
+
size_t aligned = round_down(target, step);
|
|
996
|
+
if (aligned == 0 && remaining_cols >= step) {
|
|
997
|
+
aligned = step;
|
|
998
|
+
}
|
|
999
|
+
runtime[i].n_cols = aligned;
|
|
1000
|
+
assigned_cols += aligned;
|
|
1001
|
+
}
|
|
472
1002
|
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
1003
|
+
if (assigned_cols < n) {
|
|
1004
|
+
for (int i = runtime_count - 1; i >= 0; --i) {
|
|
1005
|
+
if (runtime[i].assigned_threads > 0) {
|
|
1006
|
+
runtime[i].n_cols += n - assigned_cols;
|
|
1007
|
+
break;
|
|
1008
|
+
}
|
|
1009
|
+
}
|
|
478
1010
|
}
|
|
1011
|
+
const size_t dst_stride = dst->nb[1];
|
|
479
1012
|
|
|
480
|
-
|
|
481
|
-
const
|
|
482
|
-
|
|
483
|
-
const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
|
|
484
|
-
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
|
|
1013
|
+
for (int64_t batch_idx = 0; batch_idx < ne12; ++batch_idx) {
|
|
1014
|
+
const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
|
|
1015
|
+
uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];
|
|
485
1016
|
|
|
486
|
-
|
|
487
|
-
|
|
1017
|
+
if (runtime[local_slot].assigned_threads > 0) {
|
|
1018
|
+
runtime_slot & slot = runtime[local_slot];
|
|
1019
|
+
const ggml_type slot_rhs_type = slot.kernels->rhs_type;
|
|
1020
|
+
const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
|
|
1021
|
+
slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
|
|
1022
|
+
const int64_t m_roundup_mr = kai_roundup((int64_t)m, (int64_t)slot.mr);
|
|
1023
|
+
int64_t max_threads = slot.mr ? (m_roundup_mr / (int64_t)slot.mr) : slot.assigned_threads;
|
|
1024
|
+
max_threads = std::max<int64_t>(1, max_threads);
|
|
1025
|
+
const int64_t use_threads = std::min<int64_t>(slot.assigned_threads, max_threads);
|
|
488
1026
|
|
|
489
|
-
|
|
1027
|
+
if (local_ith < use_threads) {
|
|
1028
|
+
const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / use_threads), slot.mr);
|
|
1029
|
+
const int64_t num_m_per_threadN_1 = (int64_t)m - (use_threads - 1) * num_m_per_thread0;
|
|
490
1030
|
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
1031
|
+
const int64_t m_start = (int64_t)local_ith * num_m_per_thread0;
|
|
1032
|
+
const int64_t m_count = (local_ith == use_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
|
|
1033
|
+
|
|
1034
|
+
const size_t base_packed_off = slot.lhs_info->get_packed_offset_ex(m_start, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
|
|
1035
|
+
const size_t next_block_off = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
|
|
1036
|
+
const size_t row_stride_bytes = slot.mr ? (next_block_off - base_packed_off) / slot.mr : 0;
|
|
1037
|
+
|
|
1038
|
+
int64_t remaining = m_count;
|
|
1039
|
+
int64_t cur = m_start;
|
|
1040
|
+
|
|
1041
|
+
uint8_t * lhs_packed = scratch + slot.lhs_offset;
|
|
1042
|
+
while (remaining > 0) {
|
|
1043
|
+
const int64_t row_in_group = cur;
|
|
1044
|
+
const int64_t avail = (int64_t)m - row_in_group;
|
|
1045
|
+
const int64_t take = std::min(avail, remaining);
|
|
1046
|
+
|
|
1047
|
+
const size_t src_off = slot.lhs_info->get_offset(row_in_group, src1->nb[1]);
|
|
1048
|
+
const void * src_ptr = lhs_batch_base + src_off;
|
|
1049
|
+
const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
|
|
1050
|
+
void * dst_ptr = lhs_packed + dst_off;
|
|
1051
|
+
|
|
1052
|
+
slot.lhs_info->pack_func_ex(take, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr);
|
|
1053
|
+
|
|
1054
|
+
cur += take;
|
|
1055
|
+
remaining -= take;
|
|
1056
|
+
}
|
|
1057
|
+
}
|
|
1058
|
+
}
|
|
1059
|
+
|
|
1060
|
+
ggml_barrier(params->threadpool);
|
|
498
1061
|
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
1062
|
+
runtime_slot & slot = runtime[local_slot];
|
|
1063
|
+
if (slot.n_cols > 0 && slot.assigned_threads > 0) {
|
|
1064
|
+
int64_t active_threads = slot.assigned_threads;
|
|
1065
|
+
const int64_t max_threads = slot.n_step ? (slot.n_cols / slot.n_step) : slot.assigned_threads;
|
|
1066
|
+
if (max_threads > 0) {
|
|
1067
|
+
active_threads = std::min<int64_t>(active_threads, std::max<int64_t>(1, max_threads));
|
|
1068
|
+
}
|
|
1069
|
+
active_threads = std::max<int64_t>(1, active_threads);
|
|
1070
|
+
|
|
1071
|
+
if (local_ith < active_threads) {
|
|
1072
|
+
const size_t step = slot.n_step ? slot.n_step : 1;
|
|
1073
|
+
const size_t chunk0 = round_down((size_t)(slot.n_cols / active_threads), step);
|
|
1074
|
+
const size_t chunkN = slot.n_cols - (active_threads - 1) * chunk0;
|
|
1075
|
+
const size_t local_start = (size_t)local_ith * chunk0;
|
|
1076
|
+
const size_t cols = (local_ith == active_threads - 1) ? chunkN : chunk0;
|
|
1077
|
+
|
|
1078
|
+
if (cols > 0) {
|
|
1079
|
+
const ggml_type slot_rhs_type = slot.kernels->rhs_type;
|
|
1080
|
+
const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
|
|
1081
|
+
slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
|
|
1082
|
+
const size_t slot_rhs_block_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
|
|
1083
|
+
slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
|
|
1084
|
+
const size_t global_start = slot.n_offset + local_start;
|
|
1085
|
+
const size_t lhs_packed_offset = slot.lhs_info->get_packed_offset_ex(0, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
|
|
1086
|
+
const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot_rhs_block_arg);
|
|
1087
|
+
const size_t dst_offset = slot.kernel->get_dst_offset(0, global_start, dst_stride);
|
|
1088
|
+
|
|
1089
|
+
const uint8_t * lhs_ptr = scratch + slot.lhs_offset + lhs_packed_offset;
|
|
1090
|
+
const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset;
|
|
1091
|
+
float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);
|
|
1092
|
+
|
|
1093
|
+
slot.kernel->run_kernel_ex(m, cols, k, slot_rhs_block_arg,
|
|
1094
|
+
lhs_ptr,
|
|
1095
|
+
rhs_ptr,
|
|
1096
|
+
dst_ptr,
|
|
1097
|
+
dst_stride,
|
|
1098
|
+
sizeof(float),
|
|
1099
|
+
-FLT_MAX,
|
|
1100
|
+
FLT_MAX);
|
|
1101
|
+
}
|
|
1102
|
+
}
|
|
1103
|
+
}
|
|
1104
|
+
|
|
1105
|
+
if (batch_idx != ne12 - 1) {
|
|
1106
|
+
ggml_barrier(params->threadpool);
|
|
1107
|
+
}
|
|
502
1108
|
}
|
|
503
1109
|
|
|
504
1110
|
return true;
|
|
505
1111
|
}
|
|
506
1112
|
|
|
507
1113
|
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
|
1114
|
+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
|
|
508
1115
|
const ggml_tensor * src0 = dst->src[0];
|
|
509
1116
|
const ggml_tensor * src1 = dst->src[1];
|
|
510
1117
|
|
|
511
1118
|
GGML_TENSOR_BINARY_OP_LOCALS
|
|
512
1119
|
|
|
1120
|
+
const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
|
|
1121
|
+
const bool has_header = kleidiai_is_weight_header_valid(header);
|
|
1122
|
+
|
|
1123
|
+
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
|
|
1124
|
+
const bool want_q8 = src0->type == GGML_TYPE_Q8_0;
|
|
1125
|
+
const int chain_count = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
|
|
1126
|
+
: kleidiai_collect_q4_chain(kernel_chain);
|
|
1127
|
+
|
|
513
1128
|
ggml_kleidiai_kernels * kernels = nullptr;
|
|
514
|
-
|
|
515
|
-
size_t num_bytes_multiplier = 0;
|
|
1129
|
+
const uint8_t * packed_base = static_cast<const uint8_t *>(src0->data);
|
|
516
1130
|
|
|
517
|
-
if (
|
|
518
|
-
|
|
519
|
-
|
|
1131
|
+
if (has_header && chain_count > 0) {
|
|
1132
|
+
int select_slot = 0;
|
|
1133
|
+
if (select_slot >= header->slot_count) {
|
|
1134
|
+
select_slot = header->slot_count - 1;
|
|
520
1135
|
}
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
1136
|
+
if (select_slot >= 0 && select_slot < chain_count) {
|
|
1137
|
+
kernels = kernel_chain[select_slot];
|
|
1138
|
+
const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, select_slot);
|
|
1139
|
+
if (slot_ptr) {
|
|
1140
|
+
packed_base = slot_ptr;
|
|
1141
|
+
}
|
|
527
1142
|
}
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
1143
|
+
}
|
|
1144
|
+
|
|
1145
|
+
if (!kernels && chain_count > 0) {
|
|
1146
|
+
kernels = kernel_chain[0];
|
|
1147
|
+
if (has_header) {
|
|
1148
|
+
const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, 0);
|
|
1149
|
+
if (slot_ptr) {
|
|
1150
|
+
packed_base = slot_ptr;
|
|
1151
|
+
}
|
|
1152
|
+
}
|
|
1153
|
+
}
|
|
1154
|
+
|
|
1155
|
+
if (!kernels) {
|
|
532
1156
|
return false;
|
|
533
1157
|
}
|
|
534
1158
|
|
|
@@ -541,6 +1165,19 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
|
541
1165
|
const int64_t nc = ne00;
|
|
542
1166
|
const int64_t nr = ggml_nelements(src1);
|
|
543
1167
|
|
|
1168
|
+
const ggml_type rhs_type = kernels->rhs_type;
|
|
1169
|
+
size_t block_len = 0;
|
|
1170
|
+
size_t num_bytes_multiplier = 0;
|
|
1171
|
+
if (rhs_type == GGML_TYPE_Q4_0) {
|
|
1172
|
+
block_len = QK4_0;
|
|
1173
|
+
num_bytes_multiplier = sizeof(uint16_t);
|
|
1174
|
+
} else if (rhs_type == GGML_TYPE_Q8_0) {
|
|
1175
|
+
block_len = QK8_0;
|
|
1176
|
+
num_bytes_multiplier = sizeof(float);
|
|
1177
|
+
} else {
|
|
1178
|
+
return false;
|
|
1179
|
+
}
|
|
1180
|
+
|
|
544
1181
|
const size_t block_rows = kernel->get_nr();
|
|
545
1182
|
const size_t kr = kernel->get_kr();
|
|
546
1183
|
|
|
@@ -559,7 +1196,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
|
559
1196
|
GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
|
|
560
1197
|
|
|
561
1198
|
float *out = (float *)((char *)dst->data + i * nb1);
|
|
562
|
-
rhs_info->to_float(
|
|
1199
|
+
rhs_info->to_float(packed_base, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
|
|
563
1200
|
}
|
|
564
1201
|
|
|
565
1202
|
return true;
|
|
@@ -567,36 +1204,39 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
|
567
1204
|
|
|
568
1205
|
public:
|
|
569
1206
|
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
|
|
1207
|
+
GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0);
|
|
570
1208
|
const size_t n = tensor->ne[1];
|
|
571
1209
|
const size_t k = tensor->ne[0];
|
|
572
1210
|
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
size_t nr = ctx.kernels_q4->gemm.get_nr();
|
|
578
|
-
size_t kr = ctx.kernels_q4->gemm.get_kr();
|
|
579
|
-
size_t sr = ctx.kernels_q4->gemm.get_sr();
|
|
1211
|
+
kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(tensor->data);
|
|
1212
|
+
if (!header) {
|
|
1213
|
+
return -1;
|
|
1214
|
+
}
|
|
580
1215
|
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
1216
|
+
header->magic = GGML_KLEIDIAI_PACK_MAGIC;
|
|
1217
|
+
header->version = GGML_KLEIDIAI_PACK_VERSION;
|
|
1218
|
+
header->slot_count = 0;
|
|
1219
|
+
|
|
1220
|
+
uint8_t * base_ptr = static_cast<uint8_t *>(tensor->data);
|
|
1221
|
+
size_t cursor = sizeof(kleidiai_weight_header);
|
|
1222
|
+
cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
1223
|
+
|
|
1224
|
+
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
|
|
1225
|
+
const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;
|
|
1226
|
+
const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
|
|
1227
|
+
: kleidiai_collect_q4_chain(kernel_chain);
|
|
1228
|
+
const bool allow_fallback = kleidiai_pack_fallback_allowed();
|
|
1229
|
+
|
|
1230
|
+
std::vector<int8_t> qdata;
|
|
1231
|
+
std::vector<float> scales;
|
|
1232
|
+
|
|
1233
|
+
if (want_q8 && slot_total > 0) {
|
|
1234
|
+
qdata.resize(n * k, 0);
|
|
1235
|
+
scales.resize(n, 0.0f);
|
|
593
1236
|
|
|
594
1237
|
const size_t row_stride = tensor->nb[1];
|
|
595
1238
|
const size_t k_blocks = (k + QK8_0 - 1) / QK8_0;
|
|
596
1239
|
|
|
597
|
-
std::vector<int8_t> qdata(n * k, 0);
|
|
598
|
-
std::vector<float> scales(n, 0.0f);
|
|
599
|
-
|
|
600
1240
|
for (size_t row = 0; row < n; ++row) {
|
|
601
1241
|
const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
|
|
602
1242
|
static_cast<const uint8_t *>(data) + row * row_stride);
|
|
@@ -610,7 +1250,7 @@ public:
|
|
|
610
1250
|
if (linear_idx >= k) {
|
|
611
1251
|
break;
|
|
612
1252
|
}
|
|
613
|
-
const float value = d * blk.qs[l];
|
|
1253
|
+
const float value = d * static_cast<float>(blk.qs[l]);
|
|
614
1254
|
max_abs = std::max(max_abs, std::fabs(value));
|
|
615
1255
|
}
|
|
616
1256
|
}
|
|
@@ -627,31 +1267,73 @@ public:
|
|
|
627
1267
|
if (linear_idx >= k) {
|
|
628
1268
|
break;
|
|
629
1269
|
}
|
|
630
|
-
const float value = d * blk.qs[l];
|
|
1270
|
+
const float value = d * static_cast<float>(blk.qs[l]);
|
|
631
1271
|
int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
|
|
632
1272
|
q = std::clamp(q, -127, 127);
|
|
633
1273
|
qdata[row * k + linear_idx] = static_cast<int8_t>(q);
|
|
634
1274
|
}
|
|
635
1275
|
}
|
|
636
1276
|
}
|
|
1277
|
+
}
|
|
1278
|
+
|
|
1279
|
+
for (int slot = 0; slot < slot_total && slot < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
|
|
1280
|
+
if (!allow_fallback && slot > 0) {
|
|
1281
|
+
break;
|
|
1282
|
+
}
|
|
1283
|
+
ggml_kleidiai_kernels * kernels = kernel_chain[slot];
|
|
1284
|
+
kernel_info * kernel = &kernels->gemm;
|
|
1285
|
+
rhs_packing_info * rhs_info = &kernels->rhs_info;
|
|
1286
|
+
if (!rhs_info || !rhs_info->pack_func_ex || !rhs_info->packed_size_ex || !kernel) {
|
|
1287
|
+
continue;
|
|
1288
|
+
}
|
|
1289
|
+
|
|
1290
|
+
const size_t nr = kernel->get_nr();
|
|
1291
|
+
const size_t kr = kernel->get_kr();
|
|
1292
|
+
const size_t sr = kernel->get_sr();
|
|
1293
|
+
const ggml_type rhs_type = kernels->rhs_type;
|
|
1294
|
+
const size_t block_len = rhs_type == GGML_TYPE_Q8_0 ? QK8_0 :
|
|
1295
|
+
rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : 0;
|
|
1296
|
+
if (block_len == 0) {
|
|
1297
|
+
continue;
|
|
1298
|
+
}
|
|
637
1299
|
|
|
638
|
-
size_t
|
|
639
|
-
size_t
|
|
640
|
-
|
|
1300
|
+
const size_t packed_size = rhs_info->packed_size_ex(n, k, nr, kr, block_len);
|
|
1301
|
+
const size_t aligned_cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
1302
|
+
|
|
1303
|
+
uint8_t * dst_ptr = base_ptr + aligned_cursor;
|
|
1304
|
+
|
|
1305
|
+
if (rhs_type == GGML_TYPE_Q4_0) {
|
|
1306
|
+
struct kai_rhs_pack_qs4cxs1s0_param params;
|
|
1307
|
+
params.lhs_zero_point = 1;
|
|
1308
|
+
params.rhs_zero_point = 8;
|
|
1309
|
+
rhs_info->pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
|
|
1310
|
+
static_cast<const uint8_t *>(data), nullptr, nullptr,
|
|
1311
|
+
dst_ptr, 0, ¶ms);
|
|
1312
|
+
} else if (rhs_type == GGML_TYPE_Q8_0) {
|
|
1313
|
+
struct kai_rhs_pack_qsi8cx_params params;
|
|
1314
|
+
params.lhs_zero_point = 1;
|
|
1315
|
+
params.scale_multiplier = 1.0f;
|
|
1316
|
+
rhs_info->pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
|
|
1317
|
+
qdata.data(), nullptr, scales.data(),
|
|
1318
|
+
dst_ptr, 0, ¶ms);
|
|
1319
|
+
} else {
|
|
1320
|
+
continue;
|
|
1321
|
+
}
|
|
1322
|
+
|
|
1323
|
+
header->offsets[header->slot_count] = aligned_cursor;
|
|
1324
|
+
header->sizes[header->slot_count] = packed_size;
|
|
1325
|
+
++header->slot_count;
|
|
641
1326
|
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
params.scale_multiplier = 1.0f;
|
|
1327
|
+
cursor = aligned_cursor + packed_size;
|
|
1328
|
+
}
|
|
645
1329
|
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
return 0;
|
|
1330
|
+
if (header->slot_count == 0) {
|
|
1331
|
+
header->magic = 0;
|
|
1332
|
+
header->version = 0;
|
|
1333
|
+
memcpy(tensor->data, data, data_size);
|
|
651
1334
|
}
|
|
652
1335
|
|
|
653
|
-
|
|
654
|
-
return -1;
|
|
1336
|
+
return 0;
|
|
655
1337
|
}
|
|
656
1338
|
};
|
|
657
1339
|
|
|
@@ -681,9 +1363,8 @@ static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t bu
|
|
|
681
1363
|
}
|
|
682
1364
|
|
|
683
1365
|
static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
|
|
684
|
-
return "CPU_KLEIDIAI";
|
|
685
|
-
|
|
686
1366
|
GGML_UNUSED(buft);
|
|
1367
|
+
return "CPU_KLEIDIAI";
|
|
687
1368
|
}
|
|
688
1369
|
|
|
689
1370
|
static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
|
@@ -702,49 +1383,78 @@ static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(
|
|
|
702
1383
|
}
|
|
703
1384
|
|
|
704
1385
|
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
|
705
|
-
return TENSOR_ALIGNMENT;
|
|
706
|
-
|
|
707
1386
|
GGML_UNUSED(buft);
|
|
1387
|
+
return TENSOR_ALIGNMENT;
|
|
708
1388
|
}
|
|
709
1389
|
|
|
710
1390
|
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
|
|
711
1391
|
GGML_UNUSED(buft);
|
|
712
1392
|
|
|
1393
|
+
if (tensor->type != GGML_TYPE_Q4_0 && tensor->type != GGML_TYPE_Q8_0) {
|
|
1394
|
+
return ggml_nbytes(tensor);
|
|
1395
|
+
}
|
|
1396
|
+
|
|
713
1397
|
const size_t n = tensor->ne[1];
|
|
714
1398
|
const size_t k = tensor->ne[0];
|
|
715
1399
|
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
1400
|
+
size_t cursor = sizeof(kleidiai_weight_header);
|
|
1401
|
+
cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
1402
|
+
|
|
1403
|
+
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
|
|
1404
|
+
const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;
|
|
1405
|
+
const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
|
|
1406
|
+
: kleidiai_collect_q4_chain(kernel_chain);
|
|
1407
|
+
const bool allow_fallback = kleidiai_pack_fallback_allowed();
|
|
1408
|
+
|
|
1409
|
+
size_t slot_count = 0;
|
|
1410
|
+
for (int slot = 0; slot < slot_total; ++slot) {
|
|
1411
|
+
if (!allow_fallback && slot > 0) {
|
|
1412
|
+
break;
|
|
1413
|
+
}
|
|
1414
|
+
ggml_kleidiai_kernels * kernels = kernel_chain[slot];
|
|
1415
|
+
if (!kernels) {
|
|
1416
|
+
continue;
|
|
1417
|
+
}
|
|
1418
|
+
kernel_info * kernel = &kernels->gemm;
|
|
1419
|
+
rhs_packing_info * rhs_info = &kernels->rhs_info;
|
|
1420
|
+
if (!kernel || !rhs_info || !rhs_info->packed_size_ex) {
|
|
1421
|
+
continue;
|
|
1422
|
+
}
|
|
1423
|
+
|
|
1424
|
+
const ggml_type rhs_type = kernels->rhs_type;
|
|
1425
|
+
const size_t block_len = rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
|
|
1426
|
+
rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0;
|
|
1427
|
+
if (block_len == 0) {
|
|
1428
|
+
continue;
|
|
1429
|
+
}
|
|
1430
|
+
|
|
1431
|
+
cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
1432
|
+
cursor += rhs_info->packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), block_len);
|
|
1433
|
+
++slot_count;
|
|
729
1434
|
}
|
|
730
1435
|
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
const size_t raw = ggml_nbytes(tensor);
|
|
1436
|
+
if (slot_count == 0) {
|
|
1437
|
+
return ggml_nbytes(tensor);
|
|
1438
|
+
}
|
|
735
1439
|
|
|
736
|
-
return
|
|
1440
|
+
return std::max(cursor, ggml_nbytes(tensor));
|
|
737
1441
|
}
|
|
738
1442
|
|
|
739
1443
|
namespace ggml::cpu::kleidiai {
|
|
740
1444
|
class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|
741
1445
|
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
|
|
1446
|
+
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
|
|
1447
|
+
const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
|
|
742
1448
|
if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
|
|
743
1449
|
(op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
|
|
744
1450
|
op->src[0]->buffer &&
|
|
745
1451
|
(ggml_n_dims(op->src[0]) == 2) &&
|
|
746
|
-
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()
|
|
747
|
-
|
|
1452
|
+
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() &&
|
|
1453
|
+
slot_total > 0) {
|
|
1454
|
+
if (op->src[0]->type == GGML_TYPE_Q4_0 && ctx.kernels_q4 == nullptr) {
|
|
1455
|
+
return false;
|
|
1456
|
+
}
|
|
1457
|
+
if (op->src[0]->type == GGML_TYPE_Q8_0 && ctx.kernels_q8 == nullptr) {
|
|
748
1458
|
return false;
|
|
749
1459
|
}
|
|
750
1460
|
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
|
@@ -762,14 +1472,17 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|
|
762
1472
|
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
|
|
763
1473
|
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
|
|
764
1474
|
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
|
765
|
-
}
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
1475
|
+
} else {
|
|
1476
|
+
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
|
|
1477
|
+
const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
|
|
1478
|
+
const bool has_kernel = slot_total > 0;
|
|
1479
|
+
if (has_kernel && op->src[1]->ne[1] > 1) {
|
|
1480
|
+
if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
|
|
1481
|
+
(op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
|
|
1482
|
+
return nullptr;
|
|
1483
|
+
}
|
|
1484
|
+
return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
|
|
770
1485
|
}
|
|
771
|
-
|
|
772
|
-
return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
|
|
773
1486
|
}
|
|
774
1487
|
}
|
|
775
1488
|
return nullptr;
|