whispercpp 1.3.5 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +99 -2
- data/ext/extconf.rb +1 -0
- data/ext/ruby_whisper.c +20 -4
- data/ext/ruby_whisper.h +30 -2
- data/ext/ruby_whisper_context.c +216 -124
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_params.c +0 -1
- data/ext/ruby_whisper_segment.c +0 -1
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +4 -1
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +8 -0
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/server/server.cpp +18 -4
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
- data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
- data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
- data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
- data/ext/sources/examples/talk-llama/llama-context.h +27 -28
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
- data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
- data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
- data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
- data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
- data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
- data/ext/sources/examples/talk-llama/llama-model.h +72 -19
- data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
- data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
- data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -22
- data/ext/sources/examples/talk-llama/llama.h +63 -30
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
- data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
- data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
- data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
- data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
- data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/models.h +181 -46
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
- data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
- data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
- data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
- data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
- data/ext/sources/ggml/CMakeLists.txt +9 -3
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +6 -1
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +56 -9
- data/ext/sources/ggml/src/CMakeLists.txt +3 -0
- data/ext/sources/ggml/src/ggml-alloc.c +4 -9
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
- data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
- data/ext/sources/ggml/src/ggml-impl.h +62 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
- data/ext/sources/ggml/src/ggml.c +167 -33
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/src/whisper.cpp +6 -28
- data/sig/whisper.rbs +43 -2
- data/test/test_context_params.rb +82 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_whisper.rb +20 -0
- data/whispercpp.gemspec +1 -1
- metadata +240 -28
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
|
@@ -25,9 +25,8 @@
|
|
|
25
25
|
#define UNUSED GGML_UNUSED
|
|
26
26
|
|
|
27
27
|
#if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD))
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
int8_t * out_scales) {
|
|
28
|
+
// Helper for decoding scales and mins of Q4_K and Q5_K block formats
|
|
29
|
+
static inline void decode_q_Kx8_6bit_scales(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) {
|
|
31
30
|
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
|
32
31
|
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
|
33
32
|
constexpr uint32_t kmask3 = 0x03030303;
|
|
@@ -499,6 +498,81 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
|
|
499
498
|
ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
500
499
|
}
|
|
501
500
|
|
|
501
|
+
void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
502
|
+
const int qk = QK8_0;
|
|
503
|
+
const int nb = n / qk;
|
|
504
|
+
const int ncols_interleaved = 4;
|
|
505
|
+
const int blocklen = 4;
|
|
506
|
+
|
|
507
|
+
assert (n % qk == 0);
|
|
508
|
+
assert (nc % ncols_interleaved == 0);
|
|
509
|
+
|
|
510
|
+
UNUSED(s);
|
|
511
|
+
UNUSED(bs);
|
|
512
|
+
UNUSED(vx);
|
|
513
|
+
UNUSED(vy);
|
|
514
|
+
UNUSED(nr);
|
|
515
|
+
UNUSED(nc);
|
|
516
|
+
UNUSED(nb);
|
|
517
|
+
UNUSED(ncols_interleaved);
|
|
518
|
+
UNUSED(blocklen);
|
|
519
|
+
|
|
520
|
+
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
521
|
+
const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4);
|
|
522
|
+
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
|
523
|
+
float * res_ptr = s;
|
|
524
|
+
|
|
525
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
526
|
+
const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
|
|
527
|
+
|
|
528
|
+
float32x4_t sumf = vdupq_n_f32(0);
|
|
529
|
+
for (int l = 0; l < nb; l++) {
|
|
530
|
+
uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
|
|
531
|
+
uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
|
|
532
|
+
uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
|
|
533
|
+
uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
|
|
534
|
+
|
|
535
|
+
int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
|
|
536
|
+
int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
|
|
537
|
+
int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
|
|
538
|
+
int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
|
|
539
|
+
int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
|
|
540
|
+
int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
|
|
541
|
+
int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
|
|
542
|
+
int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
|
|
543
|
+
|
|
544
|
+
int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
|
|
545
|
+
int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
|
|
546
|
+
|
|
547
|
+
int32x4_t sumi = vdupq_n_s32(0);
|
|
548
|
+
sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
|
|
549
|
+
sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
|
|
550
|
+
sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
|
|
551
|
+
sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
|
|
552
|
+
sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
|
|
553
|
+
sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
|
|
554
|
+
sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
|
|
555
|
+
sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
|
|
556
|
+
|
|
557
|
+
float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
|
|
558
|
+
float32x4_t b_d = {
|
|
559
|
+
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]),
|
|
560
|
+
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]),
|
|
561
|
+
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]),
|
|
562
|
+
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]),
|
|
563
|
+
};
|
|
564
|
+
float32x4_t d = a_d * b_d;
|
|
565
|
+
|
|
566
|
+
sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
|
|
567
|
+
}
|
|
568
|
+
|
|
569
|
+
vst1q_f32(res_ptr + x * 4, sumf);
|
|
570
|
+
}
|
|
571
|
+
return;
|
|
572
|
+
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
|
|
573
|
+
ggml_gemv_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
574
|
+
}
|
|
575
|
+
|
|
502
576
|
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
503
577
|
constexpr int qk = QK_K;
|
|
504
578
|
const int nb = n / qk;
|
|
@@ -561,7 +635,7 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|
|
561
635
|
for (int i = 0; i < 2; i++) {
|
|
562
636
|
int8_t aux_q4sb[8];
|
|
563
637
|
const int offset = sb * 24 + i * 12;
|
|
564
|
-
|
|
638
|
+
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
|
565
639
|
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
|
|
566
640
|
}
|
|
567
641
|
|
|
@@ -701,13 +775,13 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
|
|
|
701
775
|
for (int i = 0; i < 2; i++) {
|
|
702
776
|
int8_t aux_q4sb[8];
|
|
703
777
|
const int offset = sb * 24 + i * 12;
|
|
704
|
-
|
|
778
|
+
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
|
705
779
|
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
|
|
706
780
|
}
|
|
707
781
|
|
|
708
782
|
const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;
|
|
709
783
|
|
|
710
|
-
// Load the 64 quants from q8K duplicated to use vecdots with the
|
|
784
|
+
// Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns
|
|
711
785
|
// but still need the qs to use the low and hi bits from q4
|
|
712
786
|
const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
|
|
713
787
|
int8x16_t q8_qs[8];
|
|
@@ -786,17 +860,18 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
|
|
|
786
860
|
ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
|
787
861
|
}
|
|
788
862
|
|
|
789
|
-
void
|
|
863
|
+
void ggml_gemv_q5_K_8x4_q8_K(int n,
|
|
790
864
|
float * GGML_RESTRICT s,
|
|
791
865
|
size_t bs,
|
|
792
866
|
const void * GGML_RESTRICT vx,
|
|
793
867
|
const void * GGML_RESTRICT vy,
|
|
794
868
|
int nr,
|
|
795
869
|
int nc) {
|
|
796
|
-
|
|
797
|
-
const int
|
|
798
|
-
|
|
799
|
-
|
|
870
|
+
constexpr int qk = QK_K;
|
|
871
|
+
const int nb = n / qk;
|
|
872
|
+
|
|
873
|
+
constexpr int ncols_interleaved = 8;
|
|
874
|
+
constexpr int blocklen = 4;
|
|
800
875
|
|
|
801
876
|
assert(n % qk == 0);
|
|
802
877
|
assert(nc % ncols_interleaved == 0);
|
|
@@ -806,55 +881,156 @@ void ggml_gemv_q8_0_4x4_q8_0(int n,
|
|
|
806
881
|
UNUSED(blocklen);
|
|
807
882
|
|
|
808
883
|
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
809
|
-
|
|
884
|
+
constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
|
|
885
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
|
886
|
+
const uint8x16_t mone = vdupq_n_u8(1);
|
|
887
|
+
const uint8x16_t mtwo = vdupq_n_u8(2);
|
|
888
|
+
|
|
889
|
+
// 1x8 tile = 2 x 4
|
|
890
|
+
float32x4_t acc_f32[col_groups];
|
|
891
|
+
|
|
892
|
+
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
|
|
893
|
+
|
|
894
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
895
|
+
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
|
896
|
+
|
|
897
|
+
for (int i = 0; i < col_groups; i++) {
|
|
898
|
+
acc_f32[i] = vdupq_n_f32(0);
|
|
899
|
+
}
|
|
810
900
|
|
|
811
|
-
for (int c = 0; c < nc; c += ncols_interleaved) {
|
|
812
|
-
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
|
813
|
-
float32x4_t acc = vdupq_n_f32(0);
|
|
814
901
|
for (int b = 0; b < nb; b++) {
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
902
|
+
float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3
|
|
903
|
+
float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7
|
|
904
|
+
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
|
|
905
|
+
float32x4_t sb_scale_0123 = vmulq_f32(q5_d_0, q8_d);
|
|
906
|
+
float32x4_t sb_scale_4567 = vmulq_f32(q5_d_1, q8_d);
|
|
907
|
+
float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3
|
|
908
|
+
float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7
|
|
909
|
+
float32x4_t sb_min_0123 = vmulq_f32(q5_dmin_0, q8_d);
|
|
910
|
+
float32x4_t sb_min_4567 = vmulq_f32(q5_dmin_1, q8_d);
|
|
818
911
|
|
|
819
|
-
|
|
820
|
-
|
|
912
|
+
// interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
|
|
913
|
+
int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
|
|
914
|
+
int32x4_t acc_lo[col_groups];
|
|
915
|
+
int32x4_t acc_hi[col_groups];
|
|
821
916
|
|
|
822
|
-
|
|
917
|
+
// Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
|
|
918
|
+
const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
|
|
919
|
+
int16_t bsums_arr[8];
|
|
920
|
+
vst1q_s16(bsums_arr, bsums);
|
|
823
921
|
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
922
|
+
uint8x16_t qh[col_groups][8];
|
|
923
|
+
for (int c = 0; c < col_groups; c++) {
|
|
924
|
+
for (int i = 0; i < 8; i++) {
|
|
925
|
+
qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);
|
|
926
|
+
}
|
|
927
|
+
}
|
|
828
928
|
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
929
|
+
for (int sb = 0; sb < QK_K / 64; sb++) {
|
|
930
|
+
for (int i = 0; i < col_groups; i++) {
|
|
931
|
+
acc_lo[i] = vdupq_n_s32(0);
|
|
932
|
+
acc_hi[i] = vdupq_n_s32(0);
|
|
933
|
+
}
|
|
934
|
+
// Need scales for the low and high nibbles
|
|
935
|
+
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
|
936
|
+
int16x8_t q5sb_mins[2];
|
|
937
|
+
int16x8_t q5sb_scales[2];
|
|
938
|
+
for (int i = 0; i < 2; i++) {
|
|
939
|
+
int8_t aux_q5sb[8];
|
|
940
|
+
const int offset = sb * 24 + i * 12;
|
|
941
|
+
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
|
|
942
|
+
q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
|
|
943
|
+
}
|
|
833
944
|
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
945
|
+
int8x16_t q8_qs[4];
|
|
946
|
+
for (int i = 0; i < 4; i++) {
|
|
947
|
+
q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
|
|
948
|
+
}
|
|
949
|
+
|
|
950
|
+
for (int c = 0; c < col_groups; c++) {
|
|
951
|
+
uint8x16_t q5_cols[8];
|
|
952
|
+
uint8x16_t hbit_lo[8];
|
|
953
|
+
uint8x16_t hbit_hi[8];
|
|
954
|
+
int8x16_t q5_lo[8];
|
|
955
|
+
int8x16_t q5_hi[8];
|
|
956
|
+
|
|
957
|
+
for (int i = 0; i < 8; i++) {
|
|
958
|
+
q5_cols[i] = vld1q_u8(q5_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
|
|
959
|
+
hbit_lo[i] = vandq_u8(qh[c][i], mone);
|
|
960
|
+
hbit_hi[i] = vshlq_n_u8(vandq_u8(qh[c][i], mtwo), 3);
|
|
961
|
+
qh[c][i] = vshrq_n_u8(qh[c][i], 2);
|
|
962
|
+
q5_lo[i] = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_cols[i], m4b), hbit_lo[i], 4));
|
|
963
|
+
q5_hi[i] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_cols[i], 4), hbit_hi[i]));
|
|
964
|
+
}
|
|
965
|
+
|
|
966
|
+
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[0], q8_qs[0], 0);
|
|
967
|
+
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[1], q8_qs[0], 1);
|
|
968
|
+
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[2], q8_qs[0], 2);
|
|
969
|
+
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[3], q8_qs[0], 3);
|
|
970
|
+
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[4], q8_qs[1], 0);
|
|
971
|
+
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[5], q8_qs[1], 1);
|
|
972
|
+
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[6], q8_qs[1], 2);
|
|
973
|
+
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[7], q8_qs[1], 3);
|
|
974
|
+
|
|
975
|
+
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[0], q8_qs[2], 0);
|
|
976
|
+
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[1], q8_qs[2], 1);
|
|
977
|
+
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[2], q8_qs[2], 2);
|
|
978
|
+
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[3], q8_qs[2], 3);
|
|
979
|
+
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[4], q8_qs[3], 0);
|
|
980
|
+
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[5], q8_qs[3], 1);
|
|
981
|
+
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[6], q8_qs[3], 2);
|
|
982
|
+
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[7], q8_qs[3], 3);
|
|
983
|
+
}
|
|
984
|
+
|
|
985
|
+
// Scales
|
|
986
|
+
// row c0123 blk0 and blk1
|
|
987
|
+
const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]);
|
|
988
|
+
const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]);
|
|
989
|
+
const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
|
|
990
|
+
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
|
|
991
|
+
acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
|
|
992
|
+
// row c4567 blk0 and blk1
|
|
993
|
+
const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]);
|
|
994
|
+
const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]);
|
|
995
|
+
const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
|
|
996
|
+
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
|
|
997
|
+
acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
|
|
998
|
+
|
|
999
|
+
// Bias Correction
|
|
1000
|
+
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
|
|
1001
|
+
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
|
|
1002
|
+
|
|
1003
|
+
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
|
|
1004
|
+
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
|
|
1005
|
+
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
|
|
1006
|
+
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
|
|
1007
|
+
} // for sb
|
|
1008
|
+
|
|
1009
|
+
acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
|
|
1010
|
+
acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
|
|
1011
|
+
} // for b
|
|
842
1012
|
|
|
1013
|
+
int base = x * ncols_interleaved;
|
|
1014
|
+
vst1q_f32(s + base, acc_f32[0]);
|
|
1015
|
+
vst1q_f32(s + base + 4, acc_f32[1]);
|
|
1016
|
+
} // for x
|
|
1017
|
+
return;
|
|
843
1018
|
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
844
|
-
|
|
1019
|
+
ggml_gemv_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
|
845
1020
|
}
|
|
846
1021
|
|
|
847
|
-
void
|
|
1022
|
+
void ggml_gemv_q5_K_8x8_q8_K(int n,
|
|
848
1023
|
float * GGML_RESTRICT s,
|
|
849
1024
|
size_t bs,
|
|
850
1025
|
const void * GGML_RESTRICT vx,
|
|
851
1026
|
const void * GGML_RESTRICT vy,
|
|
852
1027
|
int nr,
|
|
853
1028
|
int nc) {
|
|
854
|
-
|
|
855
|
-
const int
|
|
856
|
-
|
|
857
|
-
|
|
1029
|
+
constexpr int qk = QK_K;
|
|
1030
|
+
const int nb = n / qk;
|
|
1031
|
+
|
|
1032
|
+
constexpr int ncols_interleaved = 8;
|
|
1033
|
+
constexpr int blocklen = 8;
|
|
858
1034
|
|
|
859
1035
|
assert(n % qk == 0);
|
|
860
1036
|
assert(nc % ncols_interleaved == 0);
|
|
@@ -864,269 +1040,1003 @@ void ggml_gemv_q8_0_4x8_q8_0(int n,
|
|
|
864
1040
|
UNUSED(blocklen);
|
|
865
1041
|
|
|
866
1042
|
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
867
|
-
|
|
1043
|
+
constexpr int col_pairs = ncols_interleaved / 2;
|
|
1044
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
|
1045
|
+
const uint8x16_t mone = vdupq_n_u8(1);
|
|
1046
|
+
const uint8x16_t mtwo = vdupq_n_u8(2);
|
|
868
1047
|
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
float32x4_t acc = vdupq_n_f32(0);
|
|
1048
|
+
// 1x8 tile = 2 x 4
|
|
1049
|
+
float32x4_t acc_f32[ncols_interleaved / 4];
|
|
872
1050
|
|
|
873
|
-
|
|
874
|
-
int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
|
|
875
|
-
int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
|
|
876
|
-
float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
|
|
1051
|
+
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
|
|
877
1052
|
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
int8x16_t a1 = vcombine_s8(a_chunks.val[1], a_chunks.val[1]);
|
|
881
|
-
int8x16_t a2 = vcombine_s8(a_chunks.val[2], a_chunks.val[2]);
|
|
882
|
-
int8x16_t a3 = vcombine_s8(a_chunks.val[3], a_chunks.val[3]);
|
|
883
|
-
float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
|
|
1053
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
1054
|
+
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
|
884
1055
|
|
|
885
|
-
|
|
886
|
-
|
|
1056
|
+
for (int i = 0; i < ncols_interleaved / 4; i++) {
|
|
1057
|
+
acc_f32[i] = vdupq_n_f32(0);
|
|
1058
|
+
}
|
|
887
1059
|
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
//
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
ret0 = vdotq_s32(ret0, b_high.val[2], a3);
|
|
899
|
-
ret1 = vdotq_s32(ret1, b_high.val[3], a3);
|
|
1060
|
+
for (int b = 0; b < nb; b++) {
|
|
1061
|
+
float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3
|
|
1062
|
+
float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7
|
|
1063
|
+
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
|
|
1064
|
+
float32x4_t sb_scale_0 = vmulq_f32(q5_d_0, q8_d);
|
|
1065
|
+
float32x4_t sb_scale_1 = vmulq_f32(q5_d_1, q8_d);
|
|
1066
|
+
float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3
|
|
1067
|
+
float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7
|
|
1068
|
+
float32x4_t sb_min_0 = vmulq_f32(q5_dmin_0, q8_d);
|
|
1069
|
+
float32x4_t sb_min_1 = vmulq_f32(q5_dmin_1, q8_d);
|
|
900
1070
|
|
|
901
|
-
|
|
1071
|
+
// 2 sb each iteration
|
|
1072
|
+
int32x4_t acc_lo[col_pairs];
|
|
1073
|
+
int32x4_t acc_hi[col_pairs];
|
|
902
1074
|
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
vst1q_f32(s, acc);
|
|
908
|
-
s += ncols_interleaved;
|
|
909
|
-
}
|
|
910
|
-
return;
|
|
1075
|
+
// Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
|
|
1076
|
+
const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
|
|
1077
|
+
int16_t bsums_arr[8];
|
|
1078
|
+
vst1q_s16(bsums_arr, bsums);
|
|
911
1079
|
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
1080
|
+
// Load qh once per block and shift after each subblock
|
|
1081
|
+
const uint8_t * qh_base = q5_ptr[b].qh;
|
|
1082
|
+
uint8x16_t qh[col_pairs][4];
|
|
1083
|
+
for (int cp = 0; cp < col_pairs; cp++) {
|
|
1084
|
+
qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
|
|
1085
|
+
qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
|
|
1086
|
+
qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
|
|
1087
|
+
qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
|
|
1088
|
+
}
|
|
915
1089
|
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
1090
|
+
for (int sb = 0; sb < QK_K / 64; sb++) {
|
|
1091
|
+
for (int i = 0; i < col_pairs; i++) {
|
|
1092
|
+
acc_lo[i] = vdupq_n_s32(0);
|
|
1093
|
+
acc_hi[i] = vdupq_n_s32(0);
|
|
1094
|
+
}
|
|
1095
|
+
// Need scales for the low and high nibbles
|
|
1096
|
+
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
|
1097
|
+
int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later
|
|
1098
|
+
int16x8_t q5sb_scales[2];
|
|
1099
|
+
for (int i = 0; i < 2; i++) {
|
|
1100
|
+
int8_t aux_q5sb[8];
|
|
1101
|
+
const int offset = sb * 24 + i * 12;
|
|
1102
|
+
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
|
|
1103
|
+
q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
|
|
1104
|
+
}
|
|
921
1105
|
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
1106
|
+
const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K;
|
|
1107
|
+
|
|
1108
|
+
// Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns
|
|
1109
|
+
const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
|
|
1110
|
+
int8x16_t q8_qs[8];
|
|
1111
|
+
for (int i = 0; i < 8; i++) {
|
|
1112
|
+
q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
|
|
1113
|
+
}
|
|
1114
|
+
|
|
1115
|
+
// Q5s column pair loop unrolled
|
|
1116
|
+
{
|
|
1117
|
+
// Cols 01
|
|
1118
|
+
uint8x16_t qs_0 = vld1q_u8(qs_base);
|
|
1119
|
+
uint8x16_t qs_1 = vld1q_u8(qs_base + 64);
|
|
1120
|
+
uint8x16_t qs_2 = vld1q_u8(qs_base + 128);
|
|
1121
|
+
uint8x16_t qs_3 = vld1q_u8(qs_base + 192);
|
|
1122
|
+
|
|
1123
|
+
uint8x16_t hbit_lo_0 = vandq_u8(qh[0][0], mone);
|
|
1124
|
+
uint8x16_t hbit_lo_1 = vandq_u8(qh[0][1], mone);
|
|
1125
|
+
uint8x16_t hbit_lo_2 = vandq_u8(qh[0][2], mone);
|
|
1126
|
+
uint8x16_t hbit_lo_3 = vandq_u8(qh[0][3], mone);
|
|
1127
|
+
uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3);
|
|
1128
|
+
uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3);
|
|
1129
|
+
uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3);
|
|
1130
|
+
uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3);
|
|
1131
|
+
|
|
1132
|
+
qh[0][0] = vshrq_n_u8(qh[0][0], 2);
|
|
1133
|
+
qh[0][1] = vshrq_n_u8(qh[0][1], 2);
|
|
1134
|
+
qh[0][2] = vshrq_n_u8(qh[0][2], 2);
|
|
1135
|
+
qh[0][3] = vshrq_n_u8(qh[0][3], 2);
|
|
1136
|
+
|
|
1137
|
+
acc_lo[0] = ggml_vdotq_s32(
|
|
1138
|
+
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
|
|
1139
|
+
acc_lo[0] = ggml_vdotq_s32(
|
|
1140
|
+
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
|
|
1141
|
+
acc_lo[0] = ggml_vdotq_s32(
|
|
1142
|
+
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
|
|
1143
|
+
acc_lo[0] = ggml_vdotq_s32(
|
|
1144
|
+
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
|
|
1145
|
+
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
|
|
1146
|
+
q8_qs[4]);
|
|
1147
|
+
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
|
|
1148
|
+
q8_qs[5]);
|
|
1149
|
+
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
|
|
1150
|
+
q8_qs[6]);
|
|
1151
|
+
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
|
|
1152
|
+
q8_qs[7]);
|
|
1153
|
+
|
|
1154
|
+
// Cols 23
|
|
1155
|
+
qs_0 = vld1q_u8(qs_base + 16);
|
|
1156
|
+
qs_1 = vld1q_u8(qs_base + 80);
|
|
1157
|
+
qs_2 = vld1q_u8(qs_base + 144);
|
|
1158
|
+
qs_3 = vld1q_u8(qs_base + 208);
|
|
1159
|
+
|
|
1160
|
+
hbit_lo_0 = vandq_u8(qh[1][0], mone);
|
|
1161
|
+
hbit_lo_1 = vandq_u8(qh[1][1], mone);
|
|
1162
|
+
hbit_lo_2 = vandq_u8(qh[1][2], mone);
|
|
1163
|
+
hbit_lo_3 = vandq_u8(qh[1][3], mone);
|
|
1164
|
+
hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3);
|
|
1165
|
+
hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3);
|
|
1166
|
+
hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3);
|
|
1167
|
+
hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[1][3], mtwo), 3);
|
|
1168
|
+
|
|
1169
|
+
qh[1][0] = vshrq_n_u8(qh[1][0], 2);
|
|
1170
|
+
qh[1][1] = vshrq_n_u8(qh[1][1], 2);
|
|
1171
|
+
qh[1][2] = vshrq_n_u8(qh[1][2], 2);
|
|
1172
|
+
qh[1][3] = vshrq_n_u8(qh[1][3], 2);
|
|
1173
|
+
|
|
1174
|
+
acc_lo[1] = ggml_vdotq_s32(
|
|
1175
|
+
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
|
|
1176
|
+
acc_lo[1] = ggml_vdotq_s32(
|
|
1177
|
+
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
|
|
1178
|
+
acc_lo[1] = ggml_vdotq_s32(
|
|
1179
|
+
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
|
|
1180
|
+
acc_lo[1] = ggml_vdotq_s32(
|
|
1181
|
+
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
|
|
1182
|
+
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
|
|
1183
|
+
q8_qs[4]);
|
|
1184
|
+
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
|
|
1185
|
+
q8_qs[5]);
|
|
1186
|
+
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
|
|
1187
|
+
q8_qs[6]);
|
|
1188
|
+
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
|
|
1189
|
+
q8_qs[7]);
|
|
1190
|
+
|
|
1191
|
+
// Cols 45
|
|
1192
|
+
qs_0 = vld1q_u8(qs_base + 32);
|
|
1193
|
+
qs_1 = vld1q_u8(qs_base + 96);
|
|
1194
|
+
qs_2 = vld1q_u8(qs_base + 160);
|
|
1195
|
+
qs_3 = vld1q_u8(qs_base + 224);
|
|
1196
|
+
|
|
1197
|
+
hbit_lo_0 = vandq_u8(qh[2][0], mone);
|
|
1198
|
+
hbit_lo_1 = vandq_u8(qh[2][1], mone);
|
|
1199
|
+
hbit_lo_2 = vandq_u8(qh[2][2], mone);
|
|
1200
|
+
hbit_lo_3 = vandq_u8(qh[2][3], mone);
|
|
1201
|
+
hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3);
|
|
1202
|
+
hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3);
|
|
1203
|
+
hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3);
|
|
1204
|
+
hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[2][3], mtwo), 3);
|
|
1205
|
+
|
|
1206
|
+
qh[2][0] = vshrq_n_u8(qh[2][0], 2);
|
|
1207
|
+
qh[2][1] = vshrq_n_u8(qh[2][1], 2);
|
|
1208
|
+
qh[2][2] = vshrq_n_u8(qh[2][2], 2);
|
|
1209
|
+
qh[2][3] = vshrq_n_u8(qh[2][3], 2);
|
|
1210
|
+
|
|
1211
|
+
acc_lo[2] = ggml_vdotq_s32(
|
|
1212
|
+
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
|
|
1213
|
+
acc_lo[2] = ggml_vdotq_s32(
|
|
1214
|
+
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
|
|
1215
|
+
acc_lo[2] = ggml_vdotq_s32(
|
|
1216
|
+
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
|
|
1217
|
+
acc_lo[2] = ggml_vdotq_s32(
|
|
1218
|
+
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
|
|
1219
|
+
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
|
|
1220
|
+
q8_qs[4]);
|
|
1221
|
+
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
|
|
1222
|
+
q8_qs[5]);
|
|
1223
|
+
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
|
|
1224
|
+
q8_qs[6]);
|
|
1225
|
+
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
|
|
1226
|
+
q8_qs[7]);
|
|
1227
|
+
|
|
1228
|
+
// Cols 45
|
|
1229
|
+
qs_0 = vld1q_u8(qs_base + 48);
|
|
1230
|
+
qs_1 = vld1q_u8(qs_base + 112);
|
|
1231
|
+
qs_2 = vld1q_u8(qs_base + 176);
|
|
1232
|
+
qs_3 = vld1q_u8(qs_base + 240);
|
|
1233
|
+
|
|
1234
|
+
hbit_lo_0 = vandq_u8(qh[3][0], mone);
|
|
1235
|
+
hbit_lo_1 = vandq_u8(qh[3][1], mone);
|
|
1236
|
+
hbit_lo_2 = vandq_u8(qh[3][2], mone);
|
|
1237
|
+
hbit_lo_3 = vandq_u8(qh[3][3], mone);
|
|
1238
|
+
hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3);
|
|
1239
|
+
hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3);
|
|
1240
|
+
hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3);
|
|
1241
|
+
hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[3][3], mtwo), 3);
|
|
1242
|
+
|
|
1243
|
+
qh[3][0] = vshrq_n_u8(qh[3][0], 2);
|
|
1244
|
+
qh[3][1] = vshrq_n_u8(qh[3][1], 2);
|
|
1245
|
+
qh[3][2] = vshrq_n_u8(qh[3][2], 2);
|
|
1246
|
+
qh[3][3] = vshrq_n_u8(qh[3][3], 2);
|
|
1247
|
+
|
|
1248
|
+
acc_lo[3] = ggml_vdotq_s32(
|
|
1249
|
+
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
|
|
1250
|
+
acc_lo[3] = ggml_vdotq_s32(
|
|
1251
|
+
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
|
|
1252
|
+
acc_lo[3] = ggml_vdotq_s32(
|
|
1253
|
+
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
|
|
1254
|
+
acc_lo[3] = ggml_vdotq_s32(
|
|
1255
|
+
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
|
|
1256
|
+
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
|
|
1257
|
+
q8_qs[4]);
|
|
1258
|
+
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
|
|
1259
|
+
q8_qs[5]);
|
|
1260
|
+
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
|
|
1261
|
+
q8_qs[6]);
|
|
1262
|
+
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
|
|
1263
|
+
q8_qs[7]);
|
|
1264
|
+
}
|
|
1265
|
+
|
|
1266
|
+
// Prepare bsum vectors for bias computation
|
|
1267
|
+
// Each pair of subblocks share the same bsums
|
|
1268
|
+
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
|
|
1269
|
+
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
|
|
1270
|
+
|
|
1271
|
+
// Iterates over a pair of column pairs (4 columns) to use a single 128 register
|
|
1272
|
+
// p = 0 -> 0123 p2 -> 4567
|
|
1273
|
+
for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
|
|
1274
|
+
int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]);
|
|
1275
|
+
int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]);
|
|
1276
|
+
int16x4_t group_mins_lo = p == 0 ? vget_low_s16(q5sb_mins[0]) : vget_high_s16(q5sb_mins[0]);
|
|
1277
|
+
int16x4_t group_mins_hi = p == 0 ? vget_low_s16(q5sb_mins[1]) : vget_high_s16(q5sb_mins[1]);
|
|
1278
|
+
float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
|
|
1279
|
+
float32x4_t sb_min = p == 0 ? sb_min_0 : sb_min_1;
|
|
1280
|
+
|
|
1281
|
+
// 0123 or 4567
|
|
1282
|
+
float32x4_t sumf_0 =
|
|
1283
|
+
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
|
|
1284
|
+
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
|
|
1285
|
+
|
|
1286
|
+
float32x4_t sumf_1 =
|
|
1287
|
+
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
|
|
1288
|
+
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
|
|
1289
|
+
|
|
1290
|
+
// FUSED BIAS: Compute and subtract bias immediately
|
|
1291
|
+
// bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min
|
|
1292
|
+
int32x4_t bias = vmull_s16(bsums_vec_lo, group_mins_lo);
|
|
1293
|
+
bias = vmlal_s16(bias, bsums_vec_hi, group_mins_hi);
|
|
1294
|
+
float32x4_t bias_f32 = vcvtq_f32_s32(bias);
|
|
1295
|
+
acc_f32[i] = vmlsq_f32(acc_f32[i], sb_min, bias_f32);
|
|
1296
|
+
}
|
|
1297
|
+
} // for sb
|
|
1298
|
+
} // for b
|
|
1299
|
+
|
|
1300
|
+
int base = x * ncols_interleaved;
|
|
1301
|
+
vst1q_f32(s + base, acc_f32[0]);
|
|
1302
|
+
vst1q_f32(s + base + 4, acc_f32[1]);
|
|
1303
|
+
} // for x
|
|
1304
|
+
return;
|
|
1305
|
+
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
1306
|
+
ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
|
1307
|
+
}
|
|
1308
|
+
|
|
1309
|
+
void ggml_gemv_q6_K_8x4_q8_K(int n,
|
|
1310
|
+
float * GGML_RESTRICT s,
|
|
1311
|
+
size_t bs,
|
|
1312
|
+
const void * GGML_RESTRICT vx,
|
|
1313
|
+
const void * GGML_RESTRICT vy,
|
|
1314
|
+
int nr,
|
|
1315
|
+
int nc) {
|
|
1316
|
+
constexpr int qk = QK_K;
|
|
1317
|
+
const int nb = n / qk;
|
|
1318
|
+
|
|
1319
|
+
constexpr int ncols_interleaved = 8;
|
|
1320
|
+
constexpr int blocklen = 4;
|
|
1321
|
+
|
|
1322
|
+
assert(n % qk == 0);
|
|
1323
|
+
assert(nc % ncols_interleaved == 0);
|
|
925
1324
|
|
|
926
|
-
UNUSED(s);
|
|
927
|
-
UNUSED(bs);
|
|
928
|
-
UNUSED(vx);
|
|
929
|
-
UNUSED(vy);
|
|
930
|
-
UNUSED(nr);
|
|
931
|
-
UNUSED(nc);
|
|
932
1325
|
UNUSED(nb);
|
|
933
1326
|
UNUSED(ncols_interleaved);
|
|
934
1327
|
UNUSED(blocklen);
|
|
935
1328
|
|
|
936
|
-
#if
|
|
937
|
-
|
|
938
|
-
const
|
|
939
|
-
|
|
940
|
-
|
|
1329
|
+
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
1330
|
+
constexpr int col_groups = ncols_interleaved / 4;
|
|
1331
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
|
1332
|
+
const uint8x16_t mask_lo = vdupq_n_u8(0x03);
|
|
1333
|
+
const uint8x16_t mask_hi = vdupq_n_u8(0x30);
|
|
941
1334
|
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
|
|
1084
|
-
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
|
|
1099
|
-
|
|
1100
|
-
|
|
1101
|
-
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
|
|
1105
|
-
|
|
1106
|
-
|
|
1107
|
-
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
|
|
1112
|
-
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
|
|
1125
|
-
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1335
|
+
// 1x8 tile = 2 x 4
|
|
1336
|
+
float32x4_t acc_f32[2];
|
|
1337
|
+
|
|
1338
|
+
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
|
|
1339
|
+
|
|
1340
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
1341
|
+
const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
|
|
1342
|
+
|
|
1343
|
+
for (int i = 0; i < col_groups; i++) {
|
|
1344
|
+
acc_f32[i] = vdupq_n_f32(0);
|
|
1345
|
+
}
|
|
1346
|
+
|
|
1347
|
+
for (int b = 0; b < nb; b++) {
|
|
1348
|
+
float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); // d0 d1 d2 d3
|
|
1349
|
+
float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7
|
|
1350
|
+
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
|
|
1351
|
+
float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);
|
|
1352
|
+
float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);
|
|
1353
|
+
|
|
1354
|
+
int32x4_t acc[col_groups];
|
|
1355
|
+
for (int i = 0; i < col_groups; i++) {
|
|
1356
|
+
acc[i] = vdupq_n_s32(0);
|
|
1357
|
+
}
|
|
1358
|
+
|
|
1359
|
+
// Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
|
|
1360
|
+
// Reused for bias and dequantization later
|
|
1361
|
+
int16_t q6_scales[16 * 8];
|
|
1362
|
+
for (int i = 0; i < 16; i++) {
|
|
1363
|
+
int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
|
|
1364
|
+
vst1q_s16(q6_scales + i * 8, scales);
|
|
1365
|
+
}
|
|
1366
|
+
|
|
1367
|
+
// Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
|
|
1368
|
+
int32x4_t bias_lo = vdupq_n_s32(0);
|
|
1369
|
+
int32x4_t bias_hi = vdupq_n_s32(0);
|
|
1370
|
+
|
|
1371
|
+
// Load bsums in chunks of 4 to process with vectorized operations
|
|
1372
|
+
for (int i = 0; i < 16; i += 4) {
|
|
1373
|
+
int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i);
|
|
1374
|
+
int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
|
|
1375
|
+
int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
|
|
1376
|
+
int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
|
|
1377
|
+
int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
|
|
1378
|
+
int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
|
|
1379
|
+
int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
|
|
1380
|
+
int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
|
|
1381
|
+
int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);
|
|
1382
|
+
|
|
1383
|
+
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);
|
|
1384
|
+
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
|
|
1385
|
+
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
|
|
1386
|
+
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
|
|
1387
|
+
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
|
|
1388
|
+
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
|
|
1389
|
+
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
|
|
1390
|
+
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
|
|
1391
|
+
}
|
|
1392
|
+
bias_lo = vshlq_n_s32(bias_lo, 5);
|
|
1393
|
+
bias_hi = vshlq_n_s32(bias_hi, 5);
|
|
1394
|
+
|
|
1395
|
+
// Process two 128-value halves per superblock
|
|
1396
|
+
for (int half = 0; half < 2; half++) {
|
|
1397
|
+
const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
|
|
1398
|
+
const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
|
|
1399
|
+
|
|
1400
|
+
// A subblock (sb) is a set of weights that share the scale
|
|
1401
|
+
// Since q6_K scales are per 16 elements
|
|
1402
|
+
// num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
|
|
1403
|
+
for (int sb = 0; sb < QK_K / 64; sb++) {
|
|
1404
|
+
const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;
|
|
1405
|
+
const int8_t * q8_base_h = q8_base_l + 64;
|
|
1406
|
+
|
|
1407
|
+
// Load and duplicate q8 values (each register covers four interleaved columns of q6)
|
|
1408
|
+
int8x16_t q8_l[4];
|
|
1409
|
+
int8x16_t q8_h[4];
|
|
1410
|
+
for (int i = 0; i < 4; i++) {
|
|
1411
|
+
q8_l[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_l + i * 4));
|
|
1412
|
+
q8_h[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_h + i * 4));
|
|
1413
|
+
}
|
|
1414
|
+
|
|
1415
|
+
const int ql_off_base = sb * QK_K / 2;
|
|
1416
|
+
const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes
|
|
1417
|
+
|
|
1418
|
+
// Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
|
|
1419
|
+
uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
|
|
1420
|
+
uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
|
|
1421
|
+
uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
|
|
1422
|
+
uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
|
|
1423
|
+
|
|
1424
|
+
// Adjust qh for subblocks 2 and 3 (shift right by 2)
|
|
1425
|
+
if (sb > 1) {
|
|
1426
|
+
q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
|
|
1427
|
+
q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
|
|
1428
|
+
q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
|
|
1429
|
+
q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
|
|
1430
|
+
q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
|
|
1431
|
+
q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
|
|
1432
|
+
q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
|
|
1433
|
+
q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
|
|
1434
|
+
}
|
|
1435
|
+
|
|
1436
|
+
const uint8x16_t q6_ql[8] = { q6_ql_0.val[0], q6_ql_0.val[1], q6_ql_0.val[2], q6_ql_0.val[3],
|
|
1437
|
+
q6_ql_1.val[0], q6_ql_1.val[1], q6_ql_1.val[2], q6_ql_1.val[3] };
|
|
1438
|
+
const uint8x16_t q6_qh[8] = { q6_qh_0.val[0], q6_qh_0.val[1], q6_qh_0.val[2], q6_qh_0.val[3],
|
|
1439
|
+
q6_qh_1.val[0], q6_qh_1.val[1], q6_qh_1.val[2], q6_qh_1.val[3] };
|
|
1440
|
+
|
|
1441
|
+
// Process column groups (0-3, 4-7)
|
|
1442
|
+
for (int g = 0; g < col_groups; g++) {
|
|
1443
|
+
int32x4_t sb_acc_l = vdupq_n_s32(0);
|
|
1444
|
+
int32x4_t sb_acc_h = vdupq_n_s32(0);
|
|
1445
|
+
|
|
1446
|
+
for (int chunk = 0; chunk < 4; chunk++) {
|
|
1447
|
+
const int idx = chunk * 2 + g;
|
|
1448
|
+
|
|
1449
|
+
const uint8x16_t q6_qs_l = q6_ql[idx];
|
|
1450
|
+
const uint8x16_t q6_qs_h = q6_qh[idx];
|
|
1451
|
+
|
|
1452
|
+
// Extract high 2 bits for upper nibble reconstruction
|
|
1453
|
+
const uint8x16_t q6_qs_hh = vandq_u8(q6_qs_h, mask_hi);
|
|
1454
|
+
|
|
1455
|
+
// q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
|
|
1456
|
+
const int8x16_t q6_l =
|
|
1457
|
+
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_l, m4b), vandq_u8(q6_qs_h, mask_lo), 4));
|
|
1458
|
+
const int8x16_t q6_h = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_l, 4), q6_qs_hh));
|
|
1459
|
+
|
|
1460
|
+
sb_acc_l = vdotq_s32(sb_acc_l, q6_l, q8_l[chunk]);
|
|
1461
|
+
sb_acc_h = vdotq_s32(sb_acc_h, q6_h, q8_h[chunk]);
|
|
1462
|
+
}
|
|
1463
|
+
|
|
1464
|
+
const int scale_idx_l = half * 8 + sb;
|
|
1465
|
+
const int scale_idx_h = half * 8 + sb + 4;
|
|
1466
|
+
|
|
1467
|
+
const int32x4_t scale_vec_l = vmovl_s16(vld1_s16(q6_scales + scale_idx_l * 8 + g * 4));
|
|
1468
|
+
const int32x4_t scale_vec_h = vmovl_s16(vld1_s16(q6_scales + scale_idx_h * 8 + g * 4));
|
|
1469
|
+
|
|
1470
|
+
acc[g] = vmlaq_s32(acc[g], sb_acc_l, scale_vec_l);
|
|
1471
|
+
acc[g] = vmlaq_s32(acc[g], sb_acc_h, scale_vec_h);
|
|
1472
|
+
}
|
|
1473
|
+
}
|
|
1474
|
+
} // for half
|
|
1475
|
+
|
|
1476
|
+
// Bias correction
|
|
1477
|
+
acc[0] = vsubq_s32(acc[0], bias_lo);
|
|
1478
|
+
acc[1] = vsubq_s32(acc[1], bias_hi);
|
|
1479
|
+
|
|
1480
|
+
// Apply superblock scale (no mins for q6_K)
|
|
1481
|
+
// acc[g] has [c0, c1, c2, c3]
|
|
1482
|
+
float32x4_t w_0123 = vmulq_f32(vcvtq_f32_s32(acc[0]), sb_scale_0);
|
|
1483
|
+
float32x4_t w_4567 = vmulq_f32(vcvtq_f32_s32(acc[1]), sb_scale_1);
|
|
1484
|
+
|
|
1485
|
+
acc_f32[0] = vaddq_f32(acc_f32[0], w_0123);
|
|
1486
|
+
acc_f32[1] = vaddq_f32(acc_f32[1], w_4567);
|
|
1487
|
+
} // for b
|
|
1488
|
+
|
|
1489
|
+
int base = x * ncols_interleaved;
|
|
1490
|
+
vst1q_f32(s + base, acc_f32[0]);
|
|
1491
|
+
vst1q_f32(s + base + 4, acc_f32[1]);
|
|
1492
|
+
} // for x
|
|
1493
|
+
return;
|
|
1494
|
+
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
1495
|
+
ggml_gemv_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
|
1496
|
+
}
|
|
1497
|
+
|
|
1498
|
+
void ggml_gemv_q6_K_8x8_q8_K(int n,
|
|
1499
|
+
float * GGML_RESTRICT s,
|
|
1500
|
+
size_t bs,
|
|
1501
|
+
const void * GGML_RESTRICT vx,
|
|
1502
|
+
const void * GGML_RESTRICT vy,
|
|
1503
|
+
int nr,
|
|
1504
|
+
int nc) {
|
|
1505
|
+
constexpr int qk = QK_K;
|
|
1506
|
+
const int nb = n / qk;
|
|
1507
|
+
|
|
1508
|
+
constexpr int ncols_interleaved = 8;
|
|
1509
|
+
constexpr int blocklen = 8;
|
|
1510
|
+
|
|
1511
|
+
assert(n % qk == 0);
|
|
1512
|
+
assert(nc % ncols_interleaved == 0);
|
|
1513
|
+
|
|
1514
|
+
UNUSED(nb);
|
|
1515
|
+
UNUSED(ncols_interleaved);
|
|
1516
|
+
UNUSED(blocklen);
|
|
1517
|
+
|
|
1518
|
+
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
1519
|
+
constexpr int col_pairs = ncols_interleaved / 2;
|
|
1520
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
|
1521
|
+
const uint8x16_t mask_lo = vdupq_n_u8(0x03);
|
|
1522
|
+
const uint8x16_t mask_hi = vdupq_n_u8(0x30);
|
|
1523
|
+
|
|
1524
|
+
// 1x8 tile = 2 x 4
|
|
1525
|
+
float32x4_t acc_f32[2];
|
|
1526
|
+
|
|
1527
|
+
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
|
|
1528
|
+
|
|
1529
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
1530
|
+
const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
|
|
1531
|
+
|
|
1532
|
+
acc_f32[0] = vdupq_n_f32(0);
|
|
1533
|
+
acc_f32[1] = vdupq_n_f32(0);
|
|
1534
|
+
|
|
1535
|
+
for (int b = 0; b < nb; b++) {
|
|
1536
|
+
float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); // d0 d1 d2 d3
|
|
1537
|
+
float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7
|
|
1538
|
+
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
|
|
1539
|
+
float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);
|
|
1540
|
+
float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);
|
|
1541
|
+
|
|
1542
|
+
int32x2_t acc[col_pairs];
|
|
1543
|
+
for (int i = 0; i < col_pairs; i++) {
|
|
1544
|
+
acc[i] = vdup_n_s32(0);
|
|
1545
|
+
}
|
|
1546
|
+
|
|
1547
|
+
// Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
|
|
1548
|
+
// Reused for bias and dequantization later
|
|
1549
|
+
int16_t q6_scales[16 * 8];
|
|
1550
|
+
for (int i = 0; i < 16; i++) {
|
|
1551
|
+
int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
|
|
1552
|
+
vst1q_s16(q6_scales + i * 8, scales);
|
|
1553
|
+
}
|
|
1554
|
+
|
|
1555
|
+
// Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
|
|
1556
|
+
int32x4_t bias_lo = vdupq_n_s32(0);
|
|
1557
|
+
int32x4_t bias_hi = vdupq_n_s32(0);
|
|
1558
|
+
|
|
1559
|
+
// Load bsums in chunks of 4 to process with vectorized operations
|
|
1560
|
+
for (int i = 0; i < 16; i += 4) {
|
|
1561
|
+
int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i);
|
|
1562
|
+
int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
|
|
1563
|
+
int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
|
|
1564
|
+
int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
|
|
1565
|
+
int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
|
|
1566
|
+
int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
|
|
1567
|
+
int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
|
|
1568
|
+
int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
|
|
1569
|
+
int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);
|
|
1570
|
+
|
|
1571
|
+
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);
|
|
1572
|
+
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
|
|
1573
|
+
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
|
|
1574
|
+
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
|
|
1575
|
+
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
|
|
1576
|
+
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
|
|
1577
|
+
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
|
|
1578
|
+
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
|
|
1579
|
+
}
|
|
1580
|
+
bias_lo = vshlq_n_s32(bias_lo, 5);
|
|
1581
|
+
bias_hi = vshlq_n_s32(bias_hi, 5);
|
|
1582
|
+
|
|
1583
|
+
// Process two 128-value halves per superblock
|
|
1584
|
+
for (int half = 0; half < 2; half++) {
|
|
1585
|
+
const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
|
|
1586
|
+
const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
|
|
1587
|
+
|
|
1588
|
+
// A subblock (sb) is a set of weights that share the scale
|
|
1589
|
+
// Since q6_K scales are per 16 elements
|
|
1590
|
+
// num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
|
|
1591
|
+
for (int sb = 0; sb < QK_K / 64; sb++) {
|
|
1592
|
+
const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;
|
|
1593
|
+
const int8_t * q8_base_h = q8_base_l + 64;
|
|
1594
|
+
|
|
1595
|
+
// Load and duplicate q8 values (each register covers two interleaved columns of q6)
|
|
1596
|
+
int8x16_t q8_l[2];
|
|
1597
|
+
int8x16_t q8_h[2];
|
|
1598
|
+
for (int i = 0; i < 2; i++) {
|
|
1599
|
+
q8_l[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_l + i * 8));
|
|
1600
|
+
q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8));
|
|
1601
|
+
}
|
|
1602
|
+
|
|
1603
|
+
const int ql_off_base = sb * QK_K / 2;
|
|
1604
|
+
const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes
|
|
1605
|
+
|
|
1606
|
+
// Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
|
|
1607
|
+
uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
|
|
1608
|
+
uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
|
|
1609
|
+
uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
|
|
1610
|
+
uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
|
|
1611
|
+
|
|
1612
|
+
// Adjust qh for subblocks 2 and 3 (shift right by 2)
|
|
1613
|
+
if (sb > 1) {
|
|
1614
|
+
q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
|
|
1615
|
+
q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
|
|
1616
|
+
q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
|
|
1617
|
+
q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
|
|
1618
|
+
q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
|
|
1619
|
+
q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
|
|
1620
|
+
q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
|
|
1621
|
+
q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
|
|
1622
|
+
}
|
|
1623
|
+
|
|
1624
|
+
// Process column pairs (0-1, 2-3, 4-5, 6-7)
|
|
1625
|
+
for (int cp = 0; cp < col_pairs; cp++) {
|
|
1626
|
+
const uint8x16_t q6_qs_cp_0_l = q6_ql_0.val[cp];
|
|
1627
|
+
const uint8x16_t q6_qs_cp_1_l = q6_ql_1.val[cp];
|
|
1628
|
+
const uint8x16_t q6_qs_cp_0_h = q6_qh_0.val[cp];
|
|
1629
|
+
const uint8x16_t q6_qs_cp_1_h = q6_qh_1.val[cp];
|
|
1630
|
+
|
|
1631
|
+
// Extract high 2 bits for upper nibble reconstruction
|
|
1632
|
+
const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi);
|
|
1633
|
+
const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi);
|
|
1634
|
+
|
|
1635
|
+
// q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
|
|
1636
|
+
const int8x16_t q6_l0 = vreinterpretq_s8_u8(
|
|
1637
|
+
vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4));
|
|
1638
|
+
const int8x16_t q6_l1 = vreinterpretq_s8_u8(
|
|
1639
|
+
vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4));
|
|
1640
|
+
const int8x16_t q6_h0 =
|
|
1641
|
+
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh));
|
|
1642
|
+
const int8x16_t q6_h1 =
|
|
1643
|
+
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh));
|
|
1644
|
+
|
|
1645
|
+
int32x4_t sb_acc_l = vdupq_n_s32(0);
|
|
1646
|
+
sb_acc_l = vdotq_s32(sb_acc_l, q6_l0, q8_l[0]);
|
|
1647
|
+
sb_acc_l = vdotq_s32(sb_acc_l, q6_l1, q8_l[1]);
|
|
1648
|
+
|
|
1649
|
+
int32x4_t sb_acc_h = vdupq_n_s32(0);
|
|
1650
|
+
sb_acc_h = vdotq_s32(sb_acc_h, q6_h0, q8_h[0]);
|
|
1651
|
+
sb_acc_h = vdotq_s32(sb_acc_h, q6_h1, q8_h[1]);
|
|
1652
|
+
|
|
1653
|
+
// Pairwise add to get per-column sums: [col0, col1]
|
|
1654
|
+
int32x2_t sum_l = vpadd_s32(vget_low_s32(sb_acc_l), vget_high_s32(sb_acc_l));
|
|
1655
|
+
int32x2_t sum_h = vpadd_s32(vget_low_s32(sb_acc_h), vget_high_s32(sb_acc_h));
|
|
1656
|
+
|
|
1657
|
+
const int scale_idx_l = half * 8 + sb;
|
|
1658
|
+
const int scale_idx_h = half * 8 + sb + 4;
|
|
1659
|
+
|
|
1660
|
+
// Access scales using array indexing (scales are interleaved by column)
|
|
1661
|
+
const int32x2_t scale_vec_l = { (int32_t) q6_scales[scale_idx_l * 8 + cp * 2],
|
|
1662
|
+
(int32_t) q6_scales[scale_idx_l * 8 + cp * 2 + 1] };
|
|
1663
|
+
const int32x2_t scale_vec_h = { (int32_t) q6_scales[scale_idx_h * 8 + cp * 2],
|
|
1664
|
+
(int32_t) q6_scales[scale_idx_h * 8 + cp * 2 + 1] };
|
|
1665
|
+
|
|
1666
|
+
// Accumulate scaled results
|
|
1667
|
+
acc[cp] = vmla_s32(acc[cp], sum_l, scale_vec_l);
|
|
1668
|
+
acc[cp] = vmla_s32(acc[cp], sum_h, scale_vec_h);
|
|
1669
|
+
}
|
|
1670
|
+
}
|
|
1671
|
+
} // for half
|
|
1672
|
+
|
|
1673
|
+
// Bias correction
|
|
1674
|
+
acc[0] = vsub_s32(acc[0], vget_low_s32(bias_lo));
|
|
1675
|
+
acc[1] = vsub_s32(acc[1], vget_high_s32(bias_lo));
|
|
1676
|
+
acc[2] = vsub_s32(acc[2], vget_low_s32(bias_hi));
|
|
1677
|
+
acc[3] = vsub_s32(acc[3], vget_high_s32(bias_hi));
|
|
1678
|
+
|
|
1679
|
+
// Apply superblock scale (no mins for q6_K)
|
|
1680
|
+
// acc[cp] has [c0, c1]
|
|
1681
|
+
float32x2_t w_01 = vmul_f32(vcvt_f32_s32(acc[0]), vget_low_f32(sb_scale_0));
|
|
1682
|
+
float32x2_t w_23 = vmul_f32(vcvt_f32_s32(acc[1]), vget_high_f32(sb_scale_0));
|
|
1683
|
+
float32x2_t w_45 = vmul_f32(vcvt_f32_s32(acc[2]), vget_low_f32(sb_scale_1));
|
|
1684
|
+
float32x2_t w_67 = vmul_f32(vcvt_f32_s32(acc[3]), vget_high_f32(sb_scale_1));
|
|
1685
|
+
|
|
1686
|
+
acc_f32[0] = vaddq_f32(acc_f32[0], vcombine_f32(w_01, w_23));
|
|
1687
|
+
acc_f32[1] = vaddq_f32(acc_f32[1], vcombine_f32(w_45, w_67));
|
|
1688
|
+
} // for b
|
|
1689
|
+
|
|
1690
|
+
int base = x * ncols_interleaved;
|
|
1691
|
+
vst1q_f32(s + base, acc_f32[0]);
|
|
1692
|
+
vst1q_f32(s + base + 4, acc_f32[1]);
|
|
1693
|
+
} // for x
|
|
1694
|
+
return;
|
|
1695
|
+
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
1696
|
+
ggml_gemv_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
|
1697
|
+
}
|
|
1698
|
+
|
|
1699
|
+
void ggml_gemv_q8_0_4x4_q8_0(int n,
|
|
1700
|
+
float * GGML_RESTRICT s,
|
|
1701
|
+
size_t bs,
|
|
1702
|
+
const void * GGML_RESTRICT vx,
|
|
1703
|
+
const void * GGML_RESTRICT vy,
|
|
1704
|
+
int nr,
|
|
1705
|
+
int nc) {
|
|
1706
|
+
const int qk = QK8_0;
|
|
1707
|
+
const int nb = n / qk;
|
|
1708
|
+
const int ncols_interleaved = 4;
|
|
1709
|
+
const int blocklen = 4;
|
|
1710
|
+
|
|
1711
|
+
assert(n % qk == 0);
|
|
1712
|
+
assert(nc % ncols_interleaved == 0);
|
|
1713
|
+
|
|
1714
|
+
UNUSED(nb);
|
|
1715
|
+
UNUSED(ncols_interleaved);
|
|
1716
|
+
UNUSED(blocklen);
|
|
1717
|
+
|
|
1718
|
+
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
1719
|
+
const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
|
|
1720
|
+
|
|
1721
|
+
for (int c = 0; c < nc; c += ncols_interleaved) {
|
|
1722
|
+
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
|
1723
|
+
float32x4_t acc = vdupq_n_f32(0);
|
|
1724
|
+
for (int b = 0; b < nb; b++) {
|
|
1725
|
+
int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
|
|
1726
|
+
int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
|
|
1727
|
+
float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
|
|
1728
|
+
|
|
1729
|
+
int8x16x2_t a = vld1q_s8_x2(a_ptr->qs);
|
|
1730
|
+
float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
|
|
1731
|
+
|
|
1732
|
+
int32x4_t ret = vdupq_n_s32(0);
|
|
1733
|
+
|
|
1734
|
+
ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0);
|
|
1735
|
+
ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1);
|
|
1736
|
+
ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2);
|
|
1737
|
+
ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3);
|
|
1738
|
+
|
|
1739
|
+
ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0);
|
|
1740
|
+
ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1);
|
|
1741
|
+
ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2);
|
|
1742
|
+
ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3);
|
|
1743
|
+
|
|
1744
|
+
acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
|
|
1745
|
+
a_ptr++;
|
|
1746
|
+
b_ptr++;
|
|
1747
|
+
}
|
|
1748
|
+
vst1q_f32(s, acc);
|
|
1749
|
+
s += ncols_interleaved;
|
|
1750
|
+
}
|
|
1751
|
+
return;
|
|
1752
|
+
|
|
1753
|
+
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
1754
|
+
ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
1755
|
+
}
|
|
1756
|
+
|
|
1757
|
+
void ggml_gemv_q8_0_4x8_q8_0(int n,
|
|
1758
|
+
float * GGML_RESTRICT s,
|
|
1759
|
+
size_t bs,
|
|
1760
|
+
const void * GGML_RESTRICT vx,
|
|
1761
|
+
const void * GGML_RESTRICT vy,
|
|
1762
|
+
int nr,
|
|
1763
|
+
int nc) {
|
|
1764
|
+
const int qk = QK8_0;
|
|
1765
|
+
const int nb = n / qk;
|
|
1766
|
+
const int ncols_interleaved = 4;
|
|
1767
|
+
const int blocklen = 8;
|
|
1768
|
+
|
|
1769
|
+
assert(n % qk == 0);
|
|
1770
|
+
assert(nc % ncols_interleaved == 0);
|
|
1771
|
+
|
|
1772
|
+
UNUSED(nb);
|
|
1773
|
+
UNUSED(ncols_interleaved);
|
|
1774
|
+
UNUSED(blocklen);
|
|
1775
|
+
|
|
1776
|
+
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
1777
|
+
const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
|
|
1778
|
+
|
|
1779
|
+
for (int c = 0; c < nc; c += ncols_interleaved) {
|
|
1780
|
+
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
|
1781
|
+
float32x4_t acc = vdupq_n_f32(0);
|
|
1782
|
+
|
|
1783
|
+
for (int b = 0; b < nb; b++) {
|
|
1784
|
+
int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
|
|
1785
|
+
int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
|
|
1786
|
+
float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
|
|
1787
|
+
|
|
1788
|
+
int8x8x4_t a_chunks = vld1_s8_x4(a_ptr->qs);
|
|
1789
|
+
int8x16_t a0 = vcombine_s8(a_chunks.val[0], a_chunks.val[0]);
|
|
1790
|
+
int8x16_t a1 = vcombine_s8(a_chunks.val[1], a_chunks.val[1]);
|
|
1791
|
+
int8x16_t a2 = vcombine_s8(a_chunks.val[2], a_chunks.val[2]);
|
|
1792
|
+
int8x16_t a3 = vcombine_s8(a_chunks.val[3], a_chunks.val[3]);
|
|
1793
|
+
float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
|
|
1794
|
+
|
|
1795
|
+
int32x4_t ret0 = vdupq_n_s32(0);
|
|
1796
|
+
int32x4_t ret1 = vdupq_n_s32(0);
|
|
1797
|
+
|
|
1798
|
+
// 0..7
|
|
1799
|
+
ret0 = vdotq_s32(ret0, b_low.val[0], a0);
|
|
1800
|
+
ret1 = vdotq_s32(ret1, b_low.val[1], a0);
|
|
1801
|
+
// 8..15
|
|
1802
|
+
ret0 = vdotq_s32(ret0, b_low.val[2], a1);
|
|
1803
|
+
ret1 = vdotq_s32(ret1, b_low.val[3], a1);
|
|
1804
|
+
// 16..23
|
|
1805
|
+
ret0 = vdotq_s32(ret0, b_high.val[0], a2);
|
|
1806
|
+
ret1 = vdotq_s32(ret1, b_high.val[1], a2);
|
|
1807
|
+
// 24..31
|
|
1808
|
+
ret0 = vdotq_s32(ret0, b_high.val[2], a3);
|
|
1809
|
+
ret1 = vdotq_s32(ret1, b_high.val[3], a3);
|
|
1810
|
+
|
|
1811
|
+
int32x4_t ret = vpaddq_s32(ret0, ret1);
|
|
1812
|
+
|
|
1813
|
+
acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
|
|
1814
|
+
a_ptr++;
|
|
1815
|
+
b_ptr++;
|
|
1816
|
+
}
|
|
1817
|
+
vst1q_f32(s, acc);
|
|
1818
|
+
s += ncols_interleaved;
|
|
1819
|
+
}
|
|
1820
|
+
return;
|
|
1821
|
+
|
|
1822
|
+
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
1823
|
+
ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
1824
|
+
}
|
|
1825
|
+
|
|
1826
|
+
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
1827
|
+
const int qk = QK8_0;
|
|
1828
|
+
const int nb = n / qk;
|
|
1829
|
+
const int ncols_interleaved = 4;
|
|
1830
|
+
const int blocklen = 4;
|
|
1831
|
+
|
|
1832
|
+
assert (n % qk == 0);
|
|
1833
|
+
assert (nr % 4 == 0);
|
|
1834
|
+
assert (nc % ncols_interleaved == 0);
|
|
1835
|
+
|
|
1836
|
+
UNUSED(s);
|
|
1837
|
+
UNUSED(bs);
|
|
1838
|
+
UNUSED(vx);
|
|
1839
|
+
UNUSED(vy);
|
|
1840
|
+
UNUSED(nr);
|
|
1841
|
+
UNUSED(nc);
|
|
1842
|
+
UNUSED(nb);
|
|
1843
|
+
UNUSED(ncols_interleaved);
|
|
1844
|
+
UNUSED(blocklen);
|
|
1845
|
+
|
|
1846
|
+
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
1847
|
+
const void * b_ptr = vx;
|
|
1848
|
+
const void * a_ptr = vy;
|
|
1849
|
+
float * res_ptr = s;
|
|
1850
|
+
size_t res_stride = bs * sizeof(float);
|
|
1851
|
+
|
|
1852
|
+
__asm__ __volatile__(
|
|
1853
|
+
"mov x10, %x[nr]\n"
|
|
1854
|
+
"mov x9, #0x88\n"
|
|
1855
|
+
"cmp x10, #0x10\n"
|
|
1856
|
+
"mul x9, %x[nb], x9\n"
|
|
1857
|
+
"blt 4f\n"
|
|
1858
|
+
"1:" // Row loop
|
|
1859
|
+
"add x28, %x[b_ptr], #0x8\n"
|
|
1860
|
+
"mov x27, %x[nc]\n"
|
|
1861
|
+
"add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
|
|
1862
|
+
"2:" // Column loop
|
|
1863
|
+
"add x25, %x[a_ptr], #0x8\n"
|
|
1864
|
+
"movi v15.16b, #0x0\n"
|
|
1865
|
+
"movi v19.16b, #0x0\n"
|
|
1866
|
+
"mov x24, %x[nb]\n"
|
|
1867
|
+
"add x23, x25, x9\n"
|
|
1868
|
+
"movi v18.16b, #0x0\n"
|
|
1869
|
+
"movi v14.16b, #0x0\n"
|
|
1870
|
+
"add x22, x23, x9\n"
|
|
1871
|
+
"movi v11.16b, #0x0\n"
|
|
1872
|
+
"movi v13.16b, #0x0\n"
|
|
1873
|
+
"add x21, x22, x9\n"
|
|
1874
|
+
"movi v23.16b, #0x0\n"
|
|
1875
|
+
"movi v16.16b, #0x0\n"
|
|
1876
|
+
"movi v25.16b, #0x0\n"
|
|
1877
|
+
"movi v7.16b, #0x0\n"
|
|
1878
|
+
"movi v0.16b, #0x0\n"
|
|
1879
|
+
"movi v4.16b, #0x0\n"
|
|
1880
|
+
"movi v5.16b, #0x0\n"
|
|
1881
|
+
"movi v21.16b, #0x0\n"
|
|
1882
|
+
"movi v8.16b, #0x0\n"
|
|
1883
|
+
"movi v1.16b, #0x0\n"
|
|
1884
|
+
"3:" // Block loop
|
|
1885
|
+
"ldr q3, [x28, #0x0]\n"
|
|
1886
|
+
"ldr q31, [x25, #0x0]\n"
|
|
1887
|
+
"movi v28.16b, #0x4\n"
|
|
1888
|
+
"movi v10.4s, #0x0\n"
|
|
1889
|
+
"ldr q22, [x28, #0x10]\n"
|
|
1890
|
+
"ldr q6, [x25, #0x10]\n"
|
|
1891
|
+
"movi v29.4s, #0x0\n"
|
|
1892
|
+
"movi v9.4s, #0x0\n"
|
|
1893
|
+
"ldr q27, [x28, #0x20]\n"
|
|
1894
|
+
"ldr q30, [x28, #0x30]\n"
|
|
1895
|
+
"movi v20.4s, #0x0\n"
|
|
1896
|
+
"movi v24.16b, #0xf0\n"
|
|
1897
|
+
"ldr d2, [x25, #-0x8]\n"
|
|
1898
|
+
"ldr d26, [x23, #-0x8]\n"
|
|
1899
|
+
"sshl v12.16b, v3.16b, v28.16b\n"
|
|
1900
|
+
"sub x20, x28, #0x8\n"
|
|
1901
|
+
"ldr d17, [x20, #0x0]\n"
|
|
1902
|
+
"and v3.16b, v3.16b, v24.16b\n"
|
|
1903
|
+
"subs x24, x24, #0x1\n"
|
|
1904
|
+
"add x28, x28, #0x48\n"
|
|
1905
|
+
".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n"
|
|
1906
|
+
".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n"
|
|
1907
|
+
".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n"
|
|
1908
|
+
".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n"
|
|
1909
|
+
"sshl v31.16b, v22.16b, v28.16b\n"
|
|
1910
|
+
"and v22.16b, v22.16b, v24.16b\n"
|
|
1911
|
+
"fcvtl v17.4s, v17.4h\n"
|
|
1912
|
+
"fcvtl v2.4s, v2.4h\n"
|
|
1913
|
+
"fcvtl v26.4s, v26.4h\n"
|
|
1914
|
+
".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n"
|
|
1915
|
+
".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n"
|
|
1916
|
+
".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n"
|
|
1917
|
+
".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n"
|
|
1918
|
+
"sshl v6.16b, v27.16b, v28.16b\n"
|
|
1919
|
+
"sshl v28.16b, v30.16b, v28.16b\n"
|
|
1920
|
+
"and v27.16b, v27.16b, v24.16b\n"
|
|
1921
|
+
"and v30.16b, v30.16b, v24.16b\n"
|
|
1922
|
+
"ldr q24, [x25, #0x20]\n"
|
|
1923
|
+
".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n"
|
|
1924
|
+
".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
|
|
1925
|
+
".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n"
|
|
1926
|
+
".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n"
|
|
1927
|
+
"ldr q24, [x25, #0x30]\n"
|
|
1928
|
+
".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n"
|
|
1929
|
+
".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n"
|
|
1930
|
+
".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n"
|
|
1931
|
+
".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n"
|
|
1932
|
+
"ldr q24, [x25, #0x40]\n"
|
|
1933
|
+
".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n"
|
|
1934
|
+
".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
|
|
1935
|
+
".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n"
|
|
1936
|
+
".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n"
|
|
1937
|
+
"ldr q24, [x25, #0x50]\n"
|
|
1938
|
+
".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n"
|
|
1939
|
+
".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n"
|
|
1940
|
+
".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n"
|
|
1941
|
+
".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n"
|
|
1942
|
+
"ldr q24, [x25, #0x60]\n"
|
|
1943
|
+
".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n"
|
|
1944
|
+
".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
|
|
1945
|
+
".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n"
|
|
1946
|
+
".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n"
|
|
1947
|
+
"ldr q24, [x25, #0x70]\n"
|
|
1948
|
+
"add x25, x25, #0x88\n"
|
|
1949
|
+
".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n"
|
|
1950
|
+
".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n"
|
|
1951
|
+
".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n"
|
|
1952
|
+
".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n"
|
|
1953
|
+
"fmul v24.4s, v17.4s, v2.s[0]\n"
|
|
1954
|
+
"scvtf v10.4s, v10.4s, #0x4\n"
|
|
1955
|
+
"scvtf v29.4s, v29.4s, #0x4\n"
|
|
1956
|
+
"scvtf v9.4s, v9.4s, #0x4\n"
|
|
1957
|
+
"scvtf v20.4s, v20.4s, #0x4\n"
|
|
1958
|
+
"fmla v15.4s, v10.4s, v24.4s\n"
|
|
1959
|
+
"ldr q24, [x23, #0x0]\n"
|
|
1960
|
+
"fmul v10.4s, v17.4s, v2.s[1]\n"
|
|
1961
|
+
"fmla v19.4s, v29.4s, v10.4s\n"
|
|
1962
|
+
"ldr q10, [x23, #0x10]\n"
|
|
1963
|
+
"fmul v29.4s, v17.4s, v2.s[2]\n"
|
|
1964
|
+
"fmul v2.4s, v17.4s, v2.s[3]\n"
|
|
1965
|
+
"fmla v18.4s, v9.4s, v29.4s\n"
|
|
1966
|
+
"movi v9.4s, #0x0\n"
|
|
1967
|
+
"movi v29.4s, #0x0\n"
|
|
1968
|
+
".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n"
|
|
1969
|
+
".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n"
|
|
1970
|
+
"fmla v14.4s, v20.4s, v2.4s\n"
|
|
1971
|
+
"movi v20.4s, #0x0\n"
|
|
1972
|
+
"movi v2.4s, #0x0\n"
|
|
1973
|
+
".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n"
|
|
1974
|
+
".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
|
|
1975
|
+
"ldr q24, [x23, #0x20]\n"
|
|
1976
|
+
".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n"
|
|
1977
|
+
".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n"
|
|
1978
|
+
".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n"
|
|
1979
|
+
".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n"
|
|
1980
|
+
"ldr q10, [x23, #0x30]\n"
|
|
1981
|
+
".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n"
|
|
1982
|
+
".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
|
|
1983
|
+
".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n"
|
|
1984
|
+
".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n"
|
|
1985
|
+
"ldr q24, [x23, #0x40]\n"
|
|
1986
|
+
".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n"
|
|
1987
|
+
".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n"
|
|
1988
|
+
".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n"
|
|
1989
|
+
".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n"
|
|
1990
|
+
"ldr q10, [x23, #0x50]\n"
|
|
1991
|
+
".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n"
|
|
1992
|
+
".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
|
|
1993
|
+
".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n"
|
|
1994
|
+
".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n"
|
|
1995
|
+
"ldr q24, [x23, #0x60]\n"
|
|
1996
|
+
".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n"
|
|
1997
|
+
".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n"
|
|
1998
|
+
".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n"
|
|
1999
|
+
".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n"
|
|
2000
|
+
"ldr q10, [x23, #0x70]\n"
|
|
2001
|
+
"add x23, x23, #0x88\n"
|
|
2002
|
+
".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n"
|
|
2003
|
+
".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
|
|
2004
|
+
".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n"
|
|
2005
|
+
".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n"
|
|
2006
|
+
"ldr q24, [x22, #0x0]\n"
|
|
2007
|
+
".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n"
|
|
2008
|
+
".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n"
|
|
2009
|
+
".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n"
|
|
2010
|
+
".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n"
|
|
2011
|
+
"fmul v10.4s, v17.4s, v26.s[0]\n"
|
|
2012
|
+
"scvtf v9.4s, v9.4s, #0x4\n"
|
|
2013
|
+
"scvtf v29.4s, v29.4s, #0x4\n"
|
|
2014
|
+
"scvtf v20.4s, v20.4s, #0x4\n"
|
|
2015
|
+
"scvtf v2.4s, v2.4s, #0x4\n"
|
|
2016
|
+
"fmla v11.4s, v9.4s, v10.4s\n"
|
|
2017
|
+
"ldr q9, [x22, #0x10]\n"
|
|
2018
|
+
"fmul v10.4s, v17.4s, v26.s[1]\n"
|
|
2019
|
+
"fmla v13.4s, v29.4s, v10.4s\n"
|
|
2020
|
+
"ldr d29, [x22, #-0x8]\n"
|
|
2021
|
+
"fmul v10.4s, v17.4s, v26.s[2]\n"
|
|
2022
|
+
"fmul v26.4s, v17.4s, v26.s[3]\n"
|
|
2023
|
+
"fcvtl v29.4s, v29.4h\n"
|
|
2024
|
+
"fmla v23.4s, v20.4s, v10.4s\n"
|
|
2025
|
+
"movi v20.4s, #0x0\n"
|
|
2026
|
+
"movi v10.4s, #0x0\n"
|
|
2027
|
+
"fmla v16.4s, v2.4s, v26.4s\n"
|
|
2028
|
+
"movi v26.4s, #0x0\n"
|
|
2029
|
+
"movi v2.4s, #0x0\n"
|
|
2030
|
+
".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n"
|
|
2031
|
+
".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n"
|
|
2032
|
+
".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n"
|
|
2033
|
+
".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
|
|
2034
|
+
"ldr q24, [x22, #0x20]\n"
|
|
2035
|
+
".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n"
|
|
2036
|
+
".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n"
|
|
2037
|
+
".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n"
|
|
2038
|
+
".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n"
|
|
2039
|
+
"ldr q9, [x22, #0x30]\n"
|
|
1130
2040
|
".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n"
|
|
1131
2041
|
".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n"
|
|
1132
2042
|
".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n"
|
|
@@ -2247,110 +3157,935 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|
|
2247
3157
|
);
|
|
2248
3158
|
return;
|
|
2249
3159
|
}
|
|
2250
|
-
#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
|
|
3160
|
+
#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
|
|
3161
|
+
|
|
3162
|
+
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
|
|
3163
|
+
ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
3164
|
+
}
|
|
3165
|
+
|
|
3166
|
+
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
3167
|
+
const int qk = QK8_0;
|
|
3168
|
+
const int nb = n / qk;
|
|
3169
|
+
const int ncols_interleaved = 4;
|
|
3170
|
+
const int blocklen = 4;
|
|
3171
|
+
|
|
3172
|
+
assert (n % qk == 0);
|
|
3173
|
+
assert (nr % 4 == 0);
|
|
3174
|
+
assert (nc % ncols_interleaved == 0);
|
|
3175
|
+
|
|
3176
|
+
UNUSED(s);
|
|
3177
|
+
UNUSED(bs);
|
|
3178
|
+
UNUSED(vx);
|
|
3179
|
+
UNUSED(vy);
|
|
3180
|
+
UNUSED(nr);
|
|
3181
|
+
UNUSED(nc);
|
|
3182
|
+
UNUSED(nb);
|
|
3183
|
+
UNUSED(ncols_interleaved);
|
|
3184
|
+
UNUSED(blocklen);
|
|
3185
|
+
|
|
3186
|
+
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
3187
|
+
const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
|
|
3188
|
+
|
|
3189
|
+
for (int y = 0; y < nr / 4; y++) {
|
|
3190
|
+
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
|
3191
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
3192
|
+
const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
|
|
3193
|
+
|
|
3194
|
+
float32x4_t sumf[4];
|
|
3195
|
+
for (int m = 0; m < 4; m++) {
|
|
3196
|
+
sumf[m] = vdupq_n_f32(0);
|
|
3197
|
+
}
|
|
3198
|
+
|
|
3199
|
+
for (int l = 0; l < nb; l++) {
|
|
3200
|
+
float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
|
|
3201
|
+
float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
|
|
3202
|
+
|
|
3203
|
+
int32x4_t sumi_0 = vdupq_n_s32(0);
|
|
3204
|
+
int32x4_t sumi_1 = vdupq_n_s32(0);
|
|
3205
|
+
int32x4_t sumi_2 = vdupq_n_s32(0);
|
|
3206
|
+
int32x4_t sumi_3 = vdupq_n_s32(0);
|
|
3207
|
+
|
|
3208
|
+
for (int k = 0; k < 4; k++) {
|
|
3209
|
+
int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
|
|
3210
|
+
int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
|
|
3211
|
+
|
|
3212
|
+
uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
|
|
3213
|
+
int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
|
|
3214
|
+
int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
|
|
3215
|
+
|
|
3216
|
+
sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
|
|
3217
|
+
sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
|
|
3218
|
+
sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
|
|
3219
|
+
sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
|
|
3220
|
+
sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
|
|
3221
|
+
sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
|
|
3222
|
+
sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
|
|
3223
|
+
sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
|
|
3224
|
+
}
|
|
3225
|
+
|
|
3226
|
+
sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
|
|
3227
|
+
sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
|
|
3228
|
+
sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
|
|
3229
|
+
sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
|
|
3230
|
+
}
|
|
3231
|
+
|
|
3232
|
+
for (int m = 0; m < 4; m++) {
|
|
3233
|
+
vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
|
|
3234
|
+
}
|
|
3235
|
+
}
|
|
3236
|
+
}
|
|
3237
|
+
return;
|
|
3238
|
+
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
|
|
3239
|
+
ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
3240
|
+
}
|
|
3241
|
+
|
|
3242
|
+
void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
3243
|
+
const int qk = QK8_0;
|
|
3244
|
+
const int nb = n / qk;
|
|
3245
|
+
const int ncols_interleaved = 4;
|
|
3246
|
+
const int blocklen = 4;
|
|
3247
|
+
|
|
3248
|
+
assert (n % qk == 0);
|
|
3249
|
+
assert (nr % 4 == 0);
|
|
3250
|
+
assert (nc % ncols_interleaved == 0);
|
|
3251
|
+
|
|
3252
|
+
UNUSED(s);
|
|
3253
|
+
UNUSED(bs);
|
|
3254
|
+
UNUSED(vx);
|
|
3255
|
+
UNUSED(vy);
|
|
3256
|
+
UNUSED(nr);
|
|
3257
|
+
UNUSED(nc);
|
|
3258
|
+
UNUSED(nb);
|
|
3259
|
+
UNUSED(ncols_interleaved);
|
|
3260
|
+
UNUSED(blocklen);
|
|
3261
|
+
|
|
3262
|
+
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
3263
|
+
const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4);
|
|
3264
|
+
|
|
3265
|
+
for (int y = 0; y < nr / 4; y++) {
|
|
3266
|
+
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
|
3267
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
3268
|
+
const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
|
|
3269
|
+
|
|
3270
|
+
float32x4_t sumf[4];
|
|
3271
|
+
for (int m = 0; m < 4; m++) {
|
|
3272
|
+
sumf[m] = vdupq_n_f32(0);
|
|
3273
|
+
}
|
|
3274
|
+
|
|
3275
|
+
for (int l = 0; l < nb; l++) {
|
|
3276
|
+
float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
|
|
3277
|
+
float32x4_t b_d = {
|
|
3278
|
+
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]),
|
|
3279
|
+
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]),
|
|
3280
|
+
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]),
|
|
3281
|
+
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]),
|
|
3282
|
+
};
|
|
3283
|
+
|
|
3284
|
+
int32x4_t sumi_0 = vdupq_n_s32(0);
|
|
3285
|
+
int32x4_t sumi_1 = vdupq_n_s32(0);
|
|
3286
|
+
int32x4_t sumi_2 = vdupq_n_s32(0);
|
|
3287
|
+
int32x4_t sumi_3 = vdupq_n_s32(0);
|
|
3288
|
+
|
|
3289
|
+
for (int k = 0; k < 4; k++) {
|
|
3290
|
+
int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
|
|
3291
|
+
int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
|
|
3292
|
+
|
|
3293
|
+
uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
|
|
3294
|
+
int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
|
|
3295
|
+
int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
|
|
3296
|
+
|
|
3297
|
+
sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
|
|
3298
|
+
sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
|
|
3299
|
+
sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
|
|
3300
|
+
sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
|
|
3301
|
+
sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
|
|
3302
|
+
sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
|
|
3303
|
+
sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
|
|
3304
|
+
sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
|
|
3305
|
+
}
|
|
3306
|
+
|
|
3307
|
+
sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
|
|
3308
|
+
sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
|
|
3309
|
+
sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
|
|
3310
|
+
sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
|
|
3311
|
+
}
|
|
3312
|
+
|
|
3313
|
+
for (int m = 0; m < 4; m++) {
|
|
3314
|
+
vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
|
|
3315
|
+
}
|
|
3316
|
+
}
|
|
3317
|
+
}
|
|
3318
|
+
return;
|
|
3319
|
+
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
|
|
3320
|
+
ggml_gemm_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
3321
|
+
}
|
|
3322
|
+
|
|
3323
|
+
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
3324
|
+
constexpr int qk = QK_K;
|
|
3325
|
+
const int nb = n / qk;
|
|
3326
|
+
|
|
3327
|
+
constexpr int ncols_interleaved = 8;
|
|
3328
|
+
constexpr int blocklen = 4;
|
|
3329
|
+
|
|
3330
|
+
assert(n % qk == 0);
|
|
3331
|
+
assert(nr % 4 == 0);
|
|
3332
|
+
assert(nc % ncols_interleaved == 0);
|
|
3333
|
+
|
|
3334
|
+
UNUSED(nb);
|
|
3335
|
+
UNUSED(ncols_interleaved);
|
|
3336
|
+
UNUSED(blocklen);
|
|
3337
|
+
|
|
3338
|
+
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
3339
|
+
constexpr int q8_k_blocklen = 4;
|
|
3340
|
+
constexpr int acc_size = 2 * 4; // 2 row pairs × 4 col pairs
|
|
3341
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
|
3342
|
+
|
|
3343
|
+
// 8 accumulators: 2 row pairs × 4 col pairs
|
|
3344
|
+
float32x4_t acc_f32[acc_size];
|
|
3345
|
+
|
|
3346
|
+
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
|
3347
|
+
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
|
3348
|
+
|
|
3349
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
3350
|
+
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
|
3351
|
+
|
|
3352
|
+
for (int i = 0; i < acc_size; i++) {
|
|
3353
|
+
acc_f32[i] = vdupq_n_f32(0);
|
|
3354
|
+
}
|
|
3355
|
+
|
|
3356
|
+
for (int b = 0; b < nb; b++) {
|
|
3357
|
+
// d4 0 1 2 3, 4 5 6 7
|
|
3358
|
+
float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));
|
|
3359
|
+
float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));
|
|
3360
|
+
// d8 0 1 2 3
|
|
3361
|
+
float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
|
|
3362
|
+
// mins
|
|
3363
|
+
float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));
|
|
3364
|
+
float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));
|
|
3365
|
+
|
|
3366
|
+
// Precomputation of scales and mins
|
|
3367
|
+
float32x4_t sbd_scale_0123[q8_k_blocklen];
|
|
3368
|
+
float32x4_t sbd_scale_4567[q8_k_blocklen];
|
|
3369
|
+
float32x4_t sbd_min_0123[q8_k_blocklen];
|
|
3370
|
+
float32x4_t sbd_min_4567[q8_k_blocklen];
|
|
3371
|
+
|
|
3372
|
+
sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);
|
|
3373
|
+
sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);
|
|
3374
|
+
sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);
|
|
3375
|
+
sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);
|
|
3376
|
+
|
|
3377
|
+
sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);
|
|
3378
|
+
sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);
|
|
3379
|
+
sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);
|
|
3380
|
+
sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);
|
|
3381
|
+
|
|
3382
|
+
sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);
|
|
3383
|
+
sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);
|
|
3384
|
+
sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);
|
|
3385
|
+
sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);
|
|
3386
|
+
|
|
3387
|
+
sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);
|
|
3388
|
+
sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);
|
|
3389
|
+
sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);
|
|
3390
|
+
sbd_min_4567[3] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);
|
|
3391
|
+
|
|
3392
|
+
// Precomputation of bsums, each vpaddq calcs all the bsums for each row
|
|
3393
|
+
const int16x8_t bsums[q8_k_blocklen] = {
|
|
3394
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
|
3395
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
|
3396
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
|
3397
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
|
3398
|
+
};
|
|
3399
|
+
int16_t bsums_arr[QK_K / 64][8];
|
|
3400
|
+
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
|
3401
|
+
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
|
|
3402
|
+
}
|
|
3403
|
+
|
|
3404
|
+
// interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
|
|
3405
|
+
int32x4_t bias_acc[acc_size];
|
|
3406
|
+
for (int i = 0; i < acc_size; i++) {
|
|
3407
|
+
bias_acc[i] = vdupq_n_s32(0);
|
|
3408
|
+
}
|
|
3409
|
+
|
|
3410
|
+
for (int sb = 0; sb < QK_K / 64; sb++) {
|
|
3411
|
+
// Int accumulators for qs vecdot (4 row x 2 col quartets)
|
|
3412
|
+
int32x4_t acc_lo[acc_size];
|
|
3413
|
+
int32x4_t acc_hi[acc_size];
|
|
3414
|
+
for (int i = 0; i < acc_size; i++) {
|
|
3415
|
+
acc_lo[i] = vdupq_n_s32(0);
|
|
3416
|
+
acc_hi[i] = vdupq_n_s32(0);
|
|
3417
|
+
}
|
|
3418
|
+
// Need scales for the low and high nibbles
|
|
3419
|
+
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
|
3420
|
+
int16x8_t q4sb_scales[2];
|
|
3421
|
+
int16x8_t q4sb_mins[2];
|
|
3422
|
+
for (int i = 0; i < 2; i++) {
|
|
3423
|
+
int8_t aux_q4sb[8];
|
|
3424
|
+
const int offset = sb * 24 + i * 12;
|
|
3425
|
+
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
|
3426
|
+
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
|
|
3427
|
+
}
|
|
3428
|
+
|
|
3429
|
+
constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows
|
|
3430
|
+
for (int k = 0; k < reads_per_sb; k++) {
|
|
3431
|
+
const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
|
|
3432
|
+
const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
|
|
3433
|
+
|
|
3434
|
+
// 0..3 & 32..35
|
|
3435
|
+
const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k);
|
|
3436
|
+
const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);
|
|
3437
|
+
|
|
3438
|
+
const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b));
|
|
3439
|
+
const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4));
|
|
3440
|
+
|
|
3441
|
+
acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
|
|
3442
|
+
acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
|
|
3443
|
+
acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
|
|
3444
|
+
acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
|
|
3445
|
+
|
|
3446
|
+
acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
|
|
3447
|
+
acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
|
|
3448
|
+
acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
|
|
3449
|
+
acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
|
|
3450
|
+
|
|
3451
|
+
const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b));
|
|
3452
|
+
const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4));
|
|
3453
|
+
|
|
3454
|
+
acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
|
|
3455
|
+
acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
|
|
3456
|
+
acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
|
|
3457
|
+
acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
|
|
3458
|
+
|
|
3459
|
+
acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
|
|
3460
|
+
acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
|
|
3461
|
+
acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
|
|
3462
|
+
acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
|
|
3463
|
+
}
|
|
3464
|
+
|
|
3465
|
+
// Scale and bias application
|
|
3466
|
+
// acc is stored interleaved to match output layout
|
|
3467
|
+
const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
|
|
3468
|
+
const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
|
|
3469
|
+
const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
|
|
3470
|
+
const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
|
|
3471
|
+
for (int row = 0; row < q8_k_blocklen; row++) {
|
|
3472
|
+
// Bias correction
|
|
3473
|
+
// row c0123 blk0 and blk1
|
|
3474
|
+
const float32x4_t sumf_0123 =
|
|
3475
|
+
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
|
|
3476
|
+
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
|
|
3477
|
+
acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
|
|
3478
|
+
|
|
3479
|
+
// row c4567 blk0 and blk1
|
|
3480
|
+
const float32x4_t sumf_4567 =
|
|
3481
|
+
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
|
|
3482
|
+
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
|
|
3483
|
+
acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
|
|
3484
|
+
|
|
3485
|
+
// Bias
|
|
3486
|
+
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
|
|
3487
|
+
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
|
|
3488
|
+
|
|
3489
|
+
// row c0123 blk0 and blk1
|
|
3490
|
+
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
|
|
3491
|
+
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
|
|
3492
|
+
|
|
3493
|
+
// row c4567 blk0 and blk1
|
|
3494
|
+
bias_acc[2 * row + 1] =
|
|
3495
|
+
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
|
|
3496
|
+
bias_acc[2 * row + 1] =
|
|
3497
|
+
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
|
|
3498
|
+
}
|
|
3499
|
+
} // for sb
|
|
2251
3500
|
|
|
2252
|
-
|
|
2253
|
-
|
|
3501
|
+
for (int row = 0; row < q8_k_blocklen; row++) {
|
|
3502
|
+
acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
|
|
3503
|
+
acc_f32[2 * row + 1] =
|
|
3504
|
+
vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
|
|
3505
|
+
}
|
|
3506
|
+
} // for b
|
|
3507
|
+
|
|
3508
|
+
for (int i = 0; i < q8_k_blocklen; i++) {
|
|
3509
|
+
int row = y * q8_k_blocklen + i;
|
|
3510
|
+
for (int j = 0; j < 2; j++) {
|
|
3511
|
+
int col = x * ncols_interleaved + j * 4;
|
|
3512
|
+
int offset = row * bs + col;
|
|
3513
|
+
vst1q_f32(s + offset, acc_f32[2 * i + j]);
|
|
3514
|
+
}
|
|
3515
|
+
}
|
|
3516
|
+
} // for x
|
|
3517
|
+
} // for y
|
|
3518
|
+
return;
|
|
3519
|
+
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
3520
|
+
ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
|
2254
3521
|
}
|
|
2255
3522
|
|
|
2256
|
-
void
|
|
2257
|
-
|
|
2258
|
-
|
|
2259
|
-
|
|
2260
|
-
|
|
3523
|
+
void ggml_gemm_q5_K_8x4_q8_K(int n,
|
|
3524
|
+
float * GGML_RESTRICT s,
|
|
3525
|
+
size_t bs,
|
|
3526
|
+
const void * GGML_RESTRICT vx,
|
|
3527
|
+
const void * GGML_RESTRICT vy,
|
|
3528
|
+
int nr,
|
|
3529
|
+
int nc) {
|
|
3530
|
+
constexpr int qk = QK_K;
|
|
3531
|
+
const int nb = n / qk;
|
|
2261
3532
|
|
|
2262
|
-
|
|
2263
|
-
|
|
2264
|
-
|
|
3533
|
+
constexpr int ncols_interleaved = 8;
|
|
3534
|
+
constexpr int blocklen = 4;
|
|
3535
|
+
|
|
3536
|
+
assert(n % qk == 0);
|
|
3537
|
+
assert(nr % 4 == 0);
|
|
3538
|
+
assert(nc % ncols_interleaved == 0);
|
|
2265
3539
|
|
|
2266
|
-
UNUSED(s);
|
|
2267
|
-
UNUSED(bs);
|
|
2268
|
-
UNUSED(vx);
|
|
2269
|
-
UNUSED(vy);
|
|
2270
|
-
UNUSED(nr);
|
|
2271
|
-
UNUSED(nc);
|
|
2272
3540
|
UNUSED(nb);
|
|
2273
3541
|
UNUSED(ncols_interleaved);
|
|
2274
3542
|
UNUSED(blocklen);
|
|
2275
3543
|
|
|
2276
|
-
#if
|
|
2277
|
-
|
|
3544
|
+
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
3545
|
+
constexpr int q8_k_blocklen = 4;
|
|
3546
|
+
constexpr int acc_size = 2 * 4; // 2 row pairs, 4 col pairs
|
|
3547
|
+
constexpr int col_groups = ncols_interleaved / 4;
|
|
3548
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
|
3549
|
+
const uint8x16_t mone = vdupq_n_u8(1);
|
|
3550
|
+
const uint8x16_t mtwo = vdupq_n_u8(2);
|
|
3551
|
+
|
|
3552
|
+
// 8 accumulators: 2 row pairs, 4 col pairs
|
|
3553
|
+
float32x4_t acc_f32[acc_size];
|
|
3554
|
+
|
|
3555
|
+
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
|
3556
|
+
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
|
2278
3557
|
|
|
2279
|
-
for (int y = 0; y < nr / 4; y++) {
|
|
2280
|
-
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
|
2281
3558
|
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
2282
|
-
const
|
|
3559
|
+
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
|
2283
3560
|
|
|
2284
|
-
|
|
2285
|
-
|
|
2286
|
-
sumf[m] = vdupq_n_f32(0);
|
|
3561
|
+
for (int i = 0; i < acc_size; i++) {
|
|
3562
|
+
acc_f32[i] = vdupq_n_f32(0);
|
|
2287
3563
|
}
|
|
2288
3564
|
|
|
2289
|
-
for (int
|
|
2290
|
-
|
|
2291
|
-
float32x4_t
|
|
3565
|
+
for (int b = 0; b < nb; b++) {
|
|
3566
|
+
// d5 0 1 2 3, 4 5 6 7
|
|
3567
|
+
float32x4_t q5_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));
|
|
3568
|
+
float32x4_t q5_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));
|
|
3569
|
+
// d8 0 1 2 3
|
|
3570
|
+
float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
|
|
3571
|
+
// mins
|
|
3572
|
+
float32x4_t q5_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));
|
|
3573
|
+
float32x4_t q5_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4));
|
|
2292
3574
|
|
|
2293
|
-
|
|
2294
|
-
|
|
2295
|
-
|
|
2296
|
-
|
|
3575
|
+
// Precomputation of scales and mins
|
|
3576
|
+
float32x4_t sbd_scale_0123[q8_k_blocklen];
|
|
3577
|
+
float32x4_t sbd_scale_4567[q8_k_blocklen];
|
|
3578
|
+
float32x4_t sbd_min_0123[q8_k_blocklen];
|
|
3579
|
+
float32x4_t sbd_min_4567[q8_k_blocklen];
|
|
2297
3580
|
|
|
2298
|
-
|
|
2299
|
-
|
|
2300
|
-
|
|
3581
|
+
sbd_scale_0123[0] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 0);
|
|
3582
|
+
sbd_scale_4567[0] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 0);
|
|
3583
|
+
sbd_min_0123[0] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 0);
|
|
3584
|
+
sbd_min_4567[0] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 0);
|
|
2301
3585
|
|
|
2302
|
-
|
|
2303
|
-
|
|
2304
|
-
|
|
3586
|
+
sbd_scale_0123[1] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 1);
|
|
3587
|
+
sbd_scale_4567[1] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 1);
|
|
3588
|
+
sbd_min_0123[1] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 1);
|
|
3589
|
+
sbd_min_4567[1] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 1);
|
|
2305
3590
|
|
|
2306
|
-
|
|
2307
|
-
|
|
2308
|
-
|
|
2309
|
-
|
|
2310
|
-
|
|
2311
|
-
|
|
2312
|
-
|
|
2313
|
-
|
|
3591
|
+
sbd_scale_0123[2] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 2);
|
|
3592
|
+
sbd_scale_4567[2] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 2);
|
|
3593
|
+
sbd_min_0123[2] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 2);
|
|
3594
|
+
sbd_min_4567[2] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 2);
|
|
3595
|
+
|
|
3596
|
+
sbd_scale_0123[3] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 3);
|
|
3597
|
+
sbd_scale_4567[3] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 3);
|
|
3598
|
+
sbd_min_0123[3] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 3);
|
|
3599
|
+
sbd_min_4567[3] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 3);
|
|
3600
|
+
|
|
3601
|
+
// Precomputation of bsums, each vpaddq calcs all the bsums for each row
|
|
3602
|
+
const int16x8_t bsums[q8_k_blocklen] = {
|
|
3603
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
|
3604
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
|
3605
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
|
3606
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
|
3607
|
+
};
|
|
3608
|
+
int16_t bsums_arr[QK_K / 64][8];
|
|
3609
|
+
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
|
3610
|
+
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
|
|
3611
|
+
}
|
|
3612
|
+
|
|
3613
|
+
// interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
|
|
3614
|
+
int32x4_t bias_acc[acc_size];
|
|
3615
|
+
for (int i = 0; i < acc_size; i++) {
|
|
3616
|
+
bias_acc[i] = vdupq_n_s32(0);
|
|
3617
|
+
}
|
|
3618
|
+
|
|
3619
|
+
uint8x16_t qh[col_groups][8];
|
|
3620
|
+
for (int c = 0; c < col_groups; c++) {
|
|
3621
|
+
for (int i = 0; i < 8; i++) {
|
|
3622
|
+
qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);
|
|
3623
|
+
}
|
|
3624
|
+
}
|
|
3625
|
+
|
|
3626
|
+
for (int sb = 0; sb < QK_K / 64; sb++) {
|
|
3627
|
+
// Int accumulators for qs vecdot (4 row * 2 col quartets)
|
|
3628
|
+
int32x4_t acc_lo[acc_size];
|
|
3629
|
+
int32x4_t acc_hi[acc_size];
|
|
3630
|
+
for (int i = 0; i < acc_size; i++) {
|
|
3631
|
+
acc_lo[i] = vdupq_n_s32(0);
|
|
3632
|
+
acc_hi[i] = vdupq_n_s32(0);
|
|
3633
|
+
}
|
|
3634
|
+
// Need scales for the low and high nibbles
|
|
3635
|
+
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
|
3636
|
+
int16x8_t q5sb_scales[2];
|
|
3637
|
+
int16x8_t q5sb_mins[2];
|
|
3638
|
+
for (int i = 0; i < 2; i++) {
|
|
3639
|
+
int8_t aux_q5sb[8];
|
|
3640
|
+
const int offset = sb * 24 + i * 12;
|
|
3641
|
+
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
|
|
3642
|
+
q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
|
|
3643
|
+
}
|
|
3644
|
+
|
|
3645
|
+
constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows
|
|
3646
|
+
for (int k = 0; k < reads_per_sb; k++) {
|
|
3647
|
+
const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
|
|
3648
|
+
const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
|
|
3649
|
+
|
|
3650
|
+
// 0..3 & 32..35
|
|
3651
|
+
const uint8x16_t q5_0123 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k);
|
|
3652
|
+
const uint8x16_t q5_4567 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k + 16);
|
|
3653
|
+
|
|
3654
|
+
// NOTE: This is the only difference with q4_K
|
|
3655
|
+
const uint8x16_t hbit_lo_0123 = vandq_u8(qh[0][k], mone);
|
|
3656
|
+
const uint8x16_t hbit_hi_0123 = vshlq_n_u8(vandq_u8(qh[0][k], mtwo), 3);
|
|
3657
|
+
qh[0][k] = vshrq_n_u8(qh[0][k], 2);
|
|
3658
|
+
const uint8x16_t hbit_lo_4567 = vandq_u8(qh[1][k], mone);
|
|
3659
|
+
const uint8x16_t hbit_hi_4567 = vshlq_n_u8(vandq_u8(qh[1][k], mtwo), 3);
|
|
3660
|
+
qh[1][k] = vshrq_n_u8(qh[1][k], 2);
|
|
3661
|
+
// From here, same as q4_K
|
|
3662
|
+
|
|
3663
|
+
const int8x16_t q5_0123_lo =
|
|
3664
|
+
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_0123, m4b), hbit_lo_0123, 4));
|
|
3665
|
+
const int8x16_t q5_0123_hi =
|
|
3666
|
+
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_0123, 4), hbit_hi_0123));
|
|
3667
|
+
|
|
3668
|
+
acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q5_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
|
|
3669
|
+
acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q5_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
|
|
3670
|
+
acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q5_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
|
|
3671
|
+
acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q5_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
|
|
3672
|
+
|
|
3673
|
+
acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q5_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
|
|
3674
|
+
acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q5_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
|
|
3675
|
+
acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q5_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
|
|
3676
|
+
acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q5_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
|
|
3677
|
+
|
|
3678
|
+
const int8x16_t q5_4567_lo =
|
|
3679
|
+
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_4567, m4b), hbit_lo_4567, 4));
|
|
3680
|
+
const int8x16_t q5_4567_hi =
|
|
3681
|
+
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_4567, 4), hbit_hi_4567));
|
|
3682
|
+
|
|
3683
|
+
acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q5_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
|
|
3684
|
+
acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q5_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
|
|
3685
|
+
acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q5_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
|
|
3686
|
+
acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q5_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
|
|
3687
|
+
|
|
3688
|
+
acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q5_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
|
|
3689
|
+
acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q5_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
|
|
3690
|
+
acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q5_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
|
|
3691
|
+
acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q5_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
|
|
3692
|
+
}
|
|
3693
|
+
|
|
3694
|
+
// Scale and bias application
|
|
3695
|
+
// acc is stored interleaved to match output layout
|
|
3696
|
+
const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]);
|
|
3697
|
+
const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]);
|
|
3698
|
+
const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]);
|
|
3699
|
+
const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]);
|
|
3700
|
+
for (int row = 0; row < q8_k_blocklen; row++) {
|
|
3701
|
+
// Bias correction
|
|
3702
|
+
// row c0123 blk0 and blk1
|
|
3703
|
+
const float32x4_t sumf_0123 =
|
|
3704
|
+
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
|
|
3705
|
+
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
|
|
3706
|
+
acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
|
|
3707
|
+
|
|
3708
|
+
// row c4567 blk0 and blk1
|
|
3709
|
+
const float32x4_t sumf_4567 =
|
|
3710
|
+
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
|
|
3711
|
+
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
|
|
3712
|
+
acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
|
|
3713
|
+
|
|
3714
|
+
// Bias
|
|
3715
|
+
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
|
|
3716
|
+
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
|
|
3717
|
+
|
|
3718
|
+
// row c0123 blk0 and blk1
|
|
3719
|
+
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
|
|
3720
|
+
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
|
|
3721
|
+
|
|
3722
|
+
// row c4567 blk0 and blk1
|
|
3723
|
+
bias_acc[2 * row + 1] =
|
|
3724
|
+
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
|
|
3725
|
+
bias_acc[2 * row + 1] =
|
|
3726
|
+
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
|
|
3727
|
+
}
|
|
3728
|
+
} // for sb
|
|
3729
|
+
|
|
3730
|
+
for (int row = 0; row < q8_k_blocklen; row++) {
|
|
3731
|
+
acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
|
|
3732
|
+
acc_f32[2 * row + 1] =
|
|
3733
|
+
vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
|
|
3734
|
+
}
|
|
3735
|
+
} // for b
|
|
3736
|
+
|
|
3737
|
+
for (int i = 0; i < q8_k_blocklen; i++) {
|
|
3738
|
+
int row = y * q8_k_blocklen + i;
|
|
3739
|
+
for (int j = 0; j < 2; j++) {
|
|
3740
|
+
int col = x * ncols_interleaved + j * 4;
|
|
3741
|
+
int offset = row * bs + col;
|
|
3742
|
+
vst1q_f32(s + offset, acc_f32[2 * i + j]);
|
|
2314
3743
|
}
|
|
3744
|
+
}
|
|
3745
|
+
} // for x
|
|
3746
|
+
} // for y
|
|
3747
|
+
return;
|
|
3748
|
+
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
3749
|
+
ggml_gemm_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
|
3750
|
+
}
|
|
3751
|
+
|
|
3752
|
+
void ggml_gemm_q4_K_8x8_q8_K(int n,
|
|
3753
|
+
float * GGML_RESTRICT s,
|
|
3754
|
+
size_t bs,
|
|
3755
|
+
const void * GGML_RESTRICT vx,
|
|
3756
|
+
const void * GGML_RESTRICT vy,
|
|
3757
|
+
int nr,
|
|
3758
|
+
int nc) {
|
|
3759
|
+
constexpr int qk = QK_K;
|
|
3760
|
+
const int nb = n / qk;
|
|
3761
|
+
|
|
3762
|
+
constexpr int ncols_interleaved = 8;
|
|
3763
|
+
constexpr int blocklen = 8;
|
|
3764
|
+
|
|
3765
|
+
assert(n % qk == 0);
|
|
3766
|
+
assert(nr % 4 == 0);
|
|
3767
|
+
assert(nc % ncols_interleaved == 0);
|
|
3768
|
+
|
|
3769
|
+
UNUSED(nb);
|
|
3770
|
+
UNUSED(ncols_interleaved);
|
|
3771
|
+
UNUSED(blocklen);
|
|
3772
|
+
|
|
3773
|
+
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
|
|
3774
|
+
if (svcntb() * 8 == 256) {
|
|
3775
|
+
constexpr int q8_k_blocklen = 4;
|
|
3776
|
+
const svuint8_t m4b_1 = svdup_n_u8(0x0f);
|
|
3777
|
+
// 8 accumulators: 2 row pairs × 4 col pairs
|
|
3778
|
+
svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67;
|
|
3779
|
+
uint32_t idx_arr[8] = { 0, 2, 4, 6, 1, 3, 5, 7 };
|
|
3780
|
+
svbool_t pg = svptrue_pat_b32(SV_VL8);
|
|
3781
|
+
svuint32_t idx = svld1(pg, idx_arr);
|
|
3782
|
+
|
|
3783
|
+
static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7};
|
|
3784
|
+
svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data);
|
|
3785
|
+
|
|
3786
|
+
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
|
3787
|
+
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
|
3788
|
+
|
|
3789
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
3790
|
+
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
|
3791
|
+
|
|
3792
|
+
acc_f32_01 = svdup_n_f32(0);
|
|
3793
|
+
acc_f32_23 = svdup_n_f32(0);
|
|
3794
|
+
acc_f32_45 = svdup_n_f32(0);
|
|
3795
|
+
acc_f32_67 = svdup_n_f32(0);
|
|
3796
|
+
|
|
3797
|
+
for (int b = 0; b < nb; b++) {
|
|
3798
|
+
// bsums pairs belongs to the same q8_k subblock
|
|
3799
|
+
// 64 elements loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum
|
|
3800
|
+
const int16x8_t bsums[4]{
|
|
3801
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
|
3802
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
|
3803
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
|
3804
|
+
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
|
3805
|
+
};
|
|
3806
|
+
|
|
3807
|
+
int32_t bsums_arr32[4][8];
|
|
3808
|
+
|
|
3809
|
+
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
|
3810
|
+
int16x8_t v16 = bsums[q8_row];
|
|
3811
|
+
|
|
3812
|
+
// low 4
|
|
3813
|
+
int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16));
|
|
3814
|
+
vst1q_s32(&bsums_arr32[q8_row][0], v32_lo);
|
|
3815
|
+
|
|
3816
|
+
// high 4
|
|
3817
|
+
int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16));
|
|
3818
|
+
vst1q_s32(&bsums_arr32[q8_row][4], v32_hi);
|
|
3819
|
+
}
|
|
3820
|
+
|
|
3821
|
+
svint32_t sb_acc_0 = svdup_n_s32(0);
|
|
3822
|
+
svint32_t sb_acc_2 = svdup_n_s32(0);
|
|
3823
|
+
|
|
3824
|
+
svint32_t acc_00 = svdup_n_s32(0);
|
|
3825
|
+
svint32_t acc_11 = svdup_n_s32(0);
|
|
3826
|
+
svint32_t acc_22 = svdup_n_s32(0);
|
|
3827
|
+
svint32_t acc_33 = svdup_n_s32(0);
|
|
3828
|
+
svint32_t acc_44 = svdup_n_s32(0);
|
|
3829
|
+
svint32_t acc_55 = svdup_n_s32(0);
|
|
3830
|
+
svint32_t acc_66 = svdup_n_s32(0);
|
|
3831
|
+
svint32_t acc_77 = svdup_n_s32(0);
|
|
3832
|
+
|
|
3833
|
+
svint32_t bias_acc_00 = svdup_n_s32(0);
|
|
3834
|
+
svint32_t bias_acc_22 = svdup_n_s32(0);
|
|
3835
|
+
svint32_t bias_acc_44 = svdup_n_s32(0);
|
|
3836
|
+
svint32_t bias_acc_66 = svdup_n_s32(0);
|
|
3837
|
+
|
|
3838
|
+
for (int sb = 0; sb < QK_K / 64; sb++) {
|
|
3839
|
+
// Need scales for the low and high nibbles
|
|
3840
|
+
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
|
3841
|
+
svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3;
|
|
3842
|
+
svint32_t q4sb_mins_0, q4sb_mins_1;
|
|
3843
|
+
{
|
|
3844
|
+
// 2-superblock I am working on
|
|
3845
|
+
const int offset = sb * 24 + 0 * 12;
|
|
3846
|
+
const uint8_t * scales_in = &q4_ptr[b].scales[offset];
|
|
3847
|
+
|
|
3848
|
+
const int offset1 = sb * 24 + 12;
|
|
3849
|
+
const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1];
|
|
3850
|
+
|
|
3851
|
+
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
|
3852
|
+
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
|
3853
|
+
constexpr uint32_t kmask3 = 0x03030303;
|
|
3854
|
+
constexpr uint8_t scales_size = 12;
|
|
3855
|
+
|
|
3856
|
+
uint32_t sm[3];
|
|
3857
|
+
memcpy(sm, scales_in, scales_size);
|
|
3858
|
+
|
|
3859
|
+
uint32_t sm1[3];
|
|
3860
|
+
memcpy(sm1, scales_in1, scales_size);
|
|
3861
|
+
|
|
3862
|
+
const uint32_t mins_0_3 = sm[1] & kmask1;
|
|
3863
|
+
const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
|
|
3864
|
+
|
|
3865
|
+
const uint32_t mins_0_3_1 = sm1[1] & kmask1;
|
|
3866
|
+
const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4);
|
|
3867
|
+
|
|
3868
|
+
svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7));
|
|
3869
|
+
svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1));
|
|
3870
|
+
|
|
3871
|
+
/* reinterpret u32 → u8 */
|
|
3872
|
+
svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp);
|
|
3873
|
+
svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1);
|
|
3874
|
+
|
|
3875
|
+
/* widen u8 → u16->u32 (lower half only) */
|
|
3876
|
+
svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8));
|
|
3877
|
+
svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1));
|
|
3878
|
+
|
|
3879
|
+
q4sb_mins_0 = svreinterpret_s32_u32(mins_u16);
|
|
3880
|
+
q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1);
|
|
3881
|
+
|
|
3882
|
+
uint32_t scales_u32_0 = sm[0] & kmask1;
|
|
3883
|
+
uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
|
|
3884
|
+
uint32_t scales_u32_2 = sm1[0] & kmask1;
|
|
3885
|
+
uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4);
|
|
3886
|
+
|
|
3887
|
+
svuint32_t S01 = svdup_n_u32(scales_u32_0);
|
|
3888
|
+
svuint32_t S23 = svdup_n_u32(scales_u32_1);
|
|
3889
|
+
svuint32_t R01 = svdup_n_u32(scales_u32_2);
|
|
3890
|
+
svuint32_t R23 = svdup_n_u32(scales_u32_3);
|
|
3891
|
+
|
|
3892
|
+
svint8_t S01_b = svreinterpret_s8_u32(S01);
|
|
3893
|
+
svint8_t S23_b = svreinterpret_s8_u32(S23);
|
|
3894
|
+
svint8_t R01_b = svreinterpret_s8_u32(R01);
|
|
3895
|
+
svint8_t R23_b = svreinterpret_s8_u32(R23);
|
|
3896
|
+
|
|
3897
|
+
svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b)));
|
|
3898
|
+
svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b)));
|
|
3899
|
+
svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b)));
|
|
3900
|
+
svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b)));
|
|
3901
|
+
|
|
3902
|
+
block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx);
|
|
3903
|
+
block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx);
|
|
3904
|
+
block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx);
|
|
3905
|
+
block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx);
|
|
3906
|
+
}
|
|
3907
|
+
|
|
3908
|
+
const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256;
|
|
3909
|
+
|
|
3910
|
+
// Load 32-byte per row pair, 1 subblock each time
|
|
3911
|
+
// predicate for activating higher lanes for 16 int8 elements
|
|
3912
|
+
const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
|
|
3913
|
+
// predicate for activating lower lanes for 16 int8 elements
|
|
3914
|
+
const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
|
|
3915
|
+
|
|
3916
|
+
svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112));
|
|
3917
|
+
svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144));
|
|
3918
|
+
svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176));
|
|
3919
|
+
svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208));
|
|
3920
|
+
|
|
3921
|
+
svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128));
|
|
3922
|
+
svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160));
|
|
3923
|
+
svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192));
|
|
3924
|
+
svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224));
|
|
3925
|
+
|
|
3926
|
+
// Q4s columns iterated in pairs (01, 23, 45, 67)
|
|
3927
|
+
for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
|
|
3928
|
+
|
|
3929
|
+
sb_acc_0 = svdup_n_s32(0);
|
|
3930
|
+
sb_acc_2 = svdup_n_s32(0);
|
|
3931
|
+
|
|
3932
|
+
svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);
|
|
3933
|
+
svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);
|
|
3934
|
+
svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);
|
|
3935
|
+
svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);
|
|
3936
|
+
|
|
3937
|
+
svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));
|
|
3938
|
+
svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));
|
|
3939
|
+
svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4));
|
|
3940
|
+
svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4));
|
|
3941
|
+
|
|
3942
|
+
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0);
|
|
3943
|
+
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2);
|
|
3944
|
+
|
|
3945
|
+
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4);
|
|
3946
|
+
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6);
|
|
3947
|
+
|
|
3948
|
+
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1);
|
|
3949
|
+
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3);
|
|
3950
|
+
|
|
3951
|
+
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5);
|
|
3952
|
+
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7);
|
|
3953
|
+
|
|
3954
|
+
if(cp == 0) {
|
|
3955
|
+
acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0);
|
|
3956
|
+
acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0);
|
|
3957
|
+
}
|
|
3958
|
+
if(cp == 1) {
|
|
3959
|
+
acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1);
|
|
3960
|
+
acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1);
|
|
3961
|
+
}
|
|
3962
|
+
if(cp == 2) {
|
|
3963
|
+
acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2);
|
|
3964
|
+
acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2);
|
|
3965
|
+
}
|
|
3966
|
+
if(cp == 3) {
|
|
3967
|
+
acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3);
|
|
3968
|
+
acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3);
|
|
3969
|
+
}
|
|
3970
|
+
}
|
|
3971
|
+
|
|
3972
|
+
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0);
|
|
3973
|
+
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1);
|
|
3974
|
+
|
|
3975
|
+
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0);
|
|
3976
|
+
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1);
|
|
3977
|
+
|
|
3978
|
+
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0);
|
|
3979
|
+
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1);
|
|
3980
|
+
|
|
3981
|
+
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0);
|
|
3982
|
+
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1);
|
|
3983
|
+
} // for sb
|
|
3984
|
+
|
|
3985
|
+
|
|
3986
|
+
acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4));
|
|
3987
|
+
acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4));
|
|
3988
|
+
acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4));
|
|
3989
|
+
acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4));
|
|
3990
|
+
acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4));
|
|
3991
|
+
acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4));
|
|
3992
|
+
acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4));
|
|
3993
|
+
acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4));
|
|
3994
|
+
|
|
3995
|
+
svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1);
|
|
3996
|
+
svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1);
|
|
3997
|
+
|
|
3998
|
+
svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1);
|
|
3999
|
+
svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1);
|
|
4000
|
+
|
|
4001
|
+
// Broadcast q8 scalar
|
|
4002
|
+
svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]);
|
|
4003
|
+
|
|
4004
|
+
svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0)));
|
|
4005
|
+
|
|
4006
|
+
svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0)));
|
|
4007
|
+
|
|
4008
|
+
svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
|
4009
|
+
svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
|
4010
|
+
|
|
4011
|
+
acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1);
|
|
4012
|
+
acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1);
|
|
4013
|
+
|
|
4014
|
+
q8_d = svdup_f32(q8_ptr[b].d[1]);
|
|
4015
|
+
|
|
4016
|
+
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
|
4017
|
+
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
|
2315
4018
|
|
|
2316
|
-
|
|
2317
|
-
|
|
2318
|
-
sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
|
|
2319
|
-
sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
|
|
2320
|
-
}
|
|
4019
|
+
acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1);
|
|
4020
|
+
acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1);
|
|
2321
4021
|
|
|
2322
|
-
|
|
2323
|
-
vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
|
|
2324
|
-
}
|
|
2325
|
-
}
|
|
2326
|
-
}
|
|
2327
|
-
return;
|
|
2328
|
-
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
|
|
2329
|
-
ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
2330
|
-
}
|
|
4022
|
+
q8_d = svdup_f32(q8_ptr[b].d[2]);
|
|
2331
4023
|
|
|
2332
|
-
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
2333
|
-
constexpr int qk = QK_K;
|
|
2334
|
-
const int nb = n / qk;
|
|
2335
4024
|
|
|
2336
|
-
|
|
2337
|
-
|
|
4025
|
+
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
|
4026
|
+
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
|
2338
4027
|
|
|
2339
|
-
|
|
2340
|
-
|
|
2341
|
-
assert(nc % ncols_interleaved == 0);
|
|
4028
|
+
acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1);
|
|
4029
|
+
acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1);
|
|
2342
4030
|
|
|
2343
|
-
|
|
2344
|
-
UNUSED(ncols_interleaved);
|
|
2345
|
-
UNUSED(blocklen);
|
|
4031
|
+
q8_d = svdup_f32(q8_ptr[b].d[3]);
|
|
2346
4032
|
|
|
2347
|
-
|
|
4033
|
+
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
|
4034
|
+
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
|
4035
|
+
|
|
4036
|
+
acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1);
|
|
4037
|
+
acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1);
|
|
4038
|
+
|
|
4039
|
+
} // for b
|
|
4040
|
+
|
|
4041
|
+
// With the previous reorder, the tile is already in the correct memory layout.
|
|
4042
|
+
// Predicate for exactly 4 lanes
|
|
4043
|
+
svbool_t pg4 = svptrue_pat_b32(SV_VL4);
|
|
4044
|
+
for (int i = 0; i < q8_k_blocklen; i++) {
|
|
4045
|
+
int row = y * q8_k_blocklen + i;
|
|
4046
|
+
for (int j = 0; j < 2; j++) {
|
|
4047
|
+
int col = x * ncols_interleaved + j * 4;
|
|
4048
|
+
int offset = row * bs + col;
|
|
4049
|
+
|
|
4050
|
+
if (i == 0 && j == 0) {
|
|
4051
|
+
// acc_f32_0 → lower half of acc_f32_01
|
|
4052
|
+
svst1_f32(pg4, s + offset, acc_f32_01);
|
|
4053
|
+
} else if (i == 0 && j == 1) {
|
|
4054
|
+
// acc_f32_1 → upper half of acc_f32_01
|
|
4055
|
+
svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4));
|
|
4056
|
+
} else if (i == 1 && j == 0) {
|
|
4057
|
+
// acc_f32_2
|
|
4058
|
+
svst1_f32(pg4, s + offset, acc_f32_23);
|
|
4059
|
+
} else if (i == 1 && j == 1) {
|
|
4060
|
+
// acc_f32_3
|
|
4061
|
+
svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4));
|
|
4062
|
+
} else if (i == 2 && j == 0) {
|
|
4063
|
+
// acc_f32_4
|
|
4064
|
+
svst1_f32(pg4, s + offset, acc_f32_45);
|
|
4065
|
+
} else if (i == 2 && j == 1) {
|
|
4066
|
+
// acc_f32_5
|
|
4067
|
+
svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4));
|
|
4068
|
+
} else if (i == 3 && j == 0) {
|
|
4069
|
+
// acc_f32_6
|
|
4070
|
+
svst1_f32(pg4, s + offset, acc_f32_67);
|
|
4071
|
+
} else if (i == 3 && j == 1) {
|
|
4072
|
+
// acc_f32_7
|
|
4073
|
+
svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4));
|
|
4074
|
+
}
|
|
4075
|
+
}
|
|
4076
|
+
}
|
|
4077
|
+
} // for x
|
|
4078
|
+
} // for y
|
|
4079
|
+
return;
|
|
4080
|
+
}
|
|
4081
|
+
#endif // SVE compile-time end
|
|
4082
|
+
|
|
4083
|
+
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
|
2348
4084
|
constexpr int q8_k_blocklen = 4;
|
|
2349
|
-
|
|
2350
|
-
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
|
4085
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
|
2351
4086
|
|
|
2352
4087
|
// 8 accumulators: 2 row pairs × 4 col pairs
|
|
2353
|
-
float32x4_t acc_f32[
|
|
4088
|
+
float32x4_t acc_f32[blocklen];
|
|
2354
4089
|
|
|
2355
4090
|
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
|
2356
4091
|
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
|
@@ -2358,162 +4093,167 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|
|
2358
4093
|
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
2359
4094
|
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
|
2360
4095
|
|
|
2361
|
-
for (int i = 0; i <
|
|
4096
|
+
for (int i = 0; i < blocklen; i++) {
|
|
2362
4097
|
acc_f32[i] = vdupq_n_f32(0);
|
|
2363
4098
|
}
|
|
2364
4099
|
|
|
2365
4100
|
for (int b = 0; b < nb; b++) {
|
|
2366
|
-
//
|
|
2367
|
-
|
|
2368
|
-
float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));
|
|
2369
|
-
// d8 0 1 2 3
|
|
2370
|
-
float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
|
|
2371
|
-
// mins
|
|
2372
|
-
float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));
|
|
2373
|
-
float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));
|
|
2374
|
-
|
|
2375
|
-
// Precomputation of scales and mins
|
|
2376
|
-
float32x4_t sbd_scale_0123[q8_k_blocklen];
|
|
2377
|
-
float32x4_t sbd_scale_4567[q8_k_blocklen];
|
|
2378
|
-
float32x4_t sbd_min_0123[q8_k_blocklen];
|
|
2379
|
-
float32x4_t sbd_min_4567[q8_k_blocklen];
|
|
2380
|
-
|
|
2381
|
-
sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);
|
|
2382
|
-
sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);
|
|
2383
|
-
sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);
|
|
2384
|
-
sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);
|
|
2385
|
-
|
|
2386
|
-
sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);
|
|
2387
|
-
sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);
|
|
2388
|
-
sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);
|
|
2389
|
-
sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);
|
|
2390
|
-
|
|
2391
|
-
sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);
|
|
2392
|
-
sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);
|
|
2393
|
-
sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);
|
|
2394
|
-
sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);
|
|
2395
|
-
|
|
2396
|
-
sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);
|
|
2397
|
-
sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);
|
|
2398
|
-
sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);
|
|
2399
|
-
sbd_min_4567[3] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);
|
|
2400
|
-
|
|
2401
|
-
// Precomputation of bsums, each vpaddq calcs all the bsums for each row
|
|
2402
|
-
const int16x8_t bsums[q8_k_blocklen] = {
|
|
4101
|
+
// bsums pairs belongs to the same q8_k subblock
|
|
4102
|
+
const int16x8_t bsums[4]{
|
|
2403
4103
|
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
|
2404
4104
|
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
|
2405
4105
|
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
|
2406
4106
|
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
|
2407
4107
|
};
|
|
2408
|
-
int16_t bsums_arr[
|
|
4108
|
+
int16_t bsums_arr[4][8];
|
|
2409
4109
|
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
|
2410
4110
|
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
|
|
2411
4111
|
}
|
|
2412
4112
|
|
|
2413
|
-
|
|
2414
|
-
int32x4_t
|
|
2415
|
-
|
|
4113
|
+
int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results
|
|
4114
|
+
int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
|
|
4115
|
+
int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
|
|
4116
|
+
for (int i = 0; i < 8; i++) {
|
|
4117
|
+
acc[i] = vdupq_n_s32(0);
|
|
2416
4118
|
bias_acc[i] = vdupq_n_s32(0);
|
|
2417
4119
|
}
|
|
2418
4120
|
|
|
2419
4121
|
for (int sb = 0; sb < QK_K / 64; sb++) {
|
|
2420
|
-
// Int accumulators for qs vecdot (4 row x 2 col quartets)
|
|
2421
|
-
int32x4_t acc_lo[acc_size];
|
|
2422
|
-
int32x4_t acc_hi[acc_size];
|
|
2423
|
-
for (int i = 0; i < acc_size; i++) {
|
|
2424
|
-
acc_lo[i] = vdupq_n_s32(0);
|
|
2425
|
-
acc_hi[i] = vdupq_n_s32(0);
|
|
2426
|
-
}
|
|
2427
4122
|
// Need scales for the low and high nibbles
|
|
2428
4123
|
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
|
2429
|
-
|
|
2430
|
-
int16x8_t q4sb_mins[2];
|
|
4124
|
+
int8_t q4sb_scales[2][8];
|
|
4125
|
+
int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later
|
|
2431
4126
|
for (int i = 0; i < 2; i++) {
|
|
2432
|
-
int8_t aux_q4sb[8];
|
|
2433
4127
|
const int offset = sb * 24 + i * 12;
|
|
2434
|
-
|
|
2435
|
-
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
|
|
4128
|
+
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
|
|
2436
4129
|
}
|
|
2437
4130
|
|
|
2438
|
-
|
|
2439
|
-
|
|
2440
|
-
const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
|
|
2441
|
-
const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
|
|
4131
|
+
// q8_ptr[b].qs has interleaved Q8 rows (01, 23)
|
|
4132
|
+
const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
|
|
2442
4133
|
|
|
2443
|
-
|
|
2444
|
-
|
|
2445
|
-
const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);
|
|
4134
|
+
int8x16_t q8_qs_01[8];
|
|
4135
|
+
int8x16_t q8_qs_23[8];
|
|
2446
4136
|
|
|
2447
|
-
|
|
2448
|
-
|
|
4137
|
+
// Load 32-byte per row pair, 1 subblock each time
|
|
4138
|
+
for (int i = 0; i < 8; i++) {
|
|
4139
|
+
const int offset = i * 32; // 16 for row 01, 16 for row 23
|
|
4140
|
+
q8_qs_01[i] = vld1q_s8(q8_base + offset);
|
|
4141
|
+
q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
|
|
4142
|
+
}
|
|
2449
4143
|
|
|
2450
|
-
|
|
2451
|
-
|
|
2452
|
-
|
|
2453
|
-
|
|
4144
|
+
const int8x16_t q8s[2][8] = {
|
|
4145
|
+
{ q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
|
|
4146
|
+
q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },
|
|
4147
|
+
{ q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
|
|
4148
|
+
q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },
|
|
4149
|
+
};
|
|
2454
4150
|
|
|
2455
|
-
|
|
2456
|
-
|
|
2457
|
-
|
|
2458
|
-
|
|
4151
|
+
// Q4s columns iterated in pairs (01, 23, 45, 67)
|
|
4152
|
+
for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
|
|
4153
|
+
for (int i = 0; i < 4; i++) {
|
|
4154
|
+
sb_acc[i] = vdupq_n_s32(0);
|
|
4155
|
+
}
|
|
2459
4156
|
|
|
2460
|
-
|
|
2461
|
-
|
|
4157
|
+
uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
|
|
4158
|
+
uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
|
|
4159
|
+
uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
|
|
4160
|
+
uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
|
|
4161
|
+
const int8x16_t q4_nibbles[2][4] = {
|
|
4162
|
+
{
|
|
4163
|
+
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),
|
|
4164
|
+
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),
|
|
4165
|
+
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),
|
|
4166
|
+
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),
|
|
4167
|
+
},
|
|
4168
|
+
{
|
|
4169
|
+
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),
|
|
4170
|
+
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),
|
|
4171
|
+
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),
|
|
4172
|
+
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),
|
|
4173
|
+
}
|
|
4174
|
+
};
|
|
2462
4175
|
|
|
2463
|
-
|
|
2464
|
-
|
|
2465
|
-
|
|
2466
|
-
|
|
4176
|
+
// Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8
|
|
4177
|
+
// for each of the internal 32 qs subblock (blk)
|
|
4178
|
+
for (int rp = 0; rp < 2; rp++) {
|
|
4179
|
+
for (int blk = 0; blk < 2; blk++) {
|
|
4180
|
+
const int8x16_t * q8 = &q8s[rp][4 * blk];
|
|
4181
|
+
const int8x16_t * q4 = q4_nibbles[blk];
|
|
4182
|
+
int32x4_t acc = sb_acc[2 * rp + blk];
|
|
4183
|
+
// mul add for each qs in the same subblock
|
|
4184
|
+
for (int qs_offset = 0; qs_offset < 4; qs_offset++) {
|
|
4185
|
+
acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]);
|
|
4186
|
+
}
|
|
4187
|
+
sb_acc[2 * rp + blk] = acc;
|
|
4188
|
+
}
|
|
4189
|
+
}
|
|
2467
4190
|
|
|
2468
|
-
|
|
2469
|
-
|
|
2470
|
-
|
|
2471
|
-
|
|
4191
|
+
// Scales[i] corresponds to column i
|
|
4192
|
+
const int scale_offset = cp * 2;
|
|
4193
|
+
const int32_t scale_00 = q4sb_scales[0][scale_offset];
|
|
4194
|
+
const int32_t scale_01 = q4sb_scales[0][scale_offset + 1];
|
|
4195
|
+
const int32_t scale_10 = q4sb_scales[1][scale_offset];
|
|
4196
|
+
const int32_t scale_11 = q4sb_scales[1][scale_offset + 1];
|
|
4197
|
+
const int32x4_t block_scale_0 = vcombine_s32(vdup_n_s32(scale_00), vdup_n_s32(scale_01));
|
|
4198
|
+
const int32x4_t block_scale_1 = vcombine_s32(vdup_n_s32(scale_10), vdup_n_s32(scale_11));
|
|
4199
|
+
|
|
4200
|
+
acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale_0);
|
|
4201
|
+
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale_0);
|
|
4202
|
+
acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale_1);
|
|
4203
|
+
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale_1);
|
|
2472
4204
|
}
|
|
2473
4205
|
|
|
2474
|
-
//
|
|
2475
|
-
|
|
2476
|
-
|
|
2477
|
-
|
|
2478
|
-
|
|
2479
|
-
|
|
2480
|
-
for (int row = 0; row < q8_k_blocklen; row++) {
|
|
2481
|
-
// Bias correction
|
|
2482
|
-
// row c0123 blk0 and blk1
|
|
2483
|
-
const float32x4_t sumf_0123 =
|
|
2484
|
-
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
|
|
2485
|
-
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
|
|
2486
|
-
acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
|
|
2487
|
-
|
|
2488
|
-
// row c4567 blk0 and blk1
|
|
2489
|
-
const float32x4_t sumf_4567 =
|
|
2490
|
-
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
|
|
2491
|
-
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
|
|
2492
|
-
acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
|
|
2493
|
-
|
|
2494
|
-
// Bias
|
|
2495
|
-
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
|
|
2496
|
-
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
|
|
2497
|
-
|
|
2498
|
-
// row c0123 blk0 and blk1
|
|
2499
|
-
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
|
|
2500
|
-
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
|
|
4206
|
+
// Multiply Acc bsum + mins
|
|
4207
|
+
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
|
4208
|
+
// Each pair of subblocks share the same bsums
|
|
4209
|
+
// Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
|
|
4210
|
+
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
|
|
4211
|
+
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
|
|
2501
4212
|
|
|
2502
|
-
|
|
2503
|
-
|
|
2504
|
-
|
|
2505
|
-
|
|
2506
|
-
|
|
4213
|
+
bias_acc[2 * q8_row] =
|
|
4214
|
+
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
|
|
4215
|
+
bias_acc[2 * q8_row] =
|
|
4216
|
+
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
|
|
4217
|
+
bias_acc[2 * q8_row + 1] =
|
|
4218
|
+
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
|
|
4219
|
+
bias_acc[2 * q8_row + 1] =
|
|
4220
|
+
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
|
|
2507
4221
|
}
|
|
2508
4222
|
} // for sb
|
|
2509
4223
|
|
|
2510
|
-
|
|
2511
|
-
|
|
2512
|
-
|
|
2513
|
-
|
|
4224
|
+
// Reorder of i8mm output with bias and output layout
|
|
4225
|
+
for (int i = 0; i < 8; i++) {
|
|
4226
|
+
int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
|
|
4227
|
+
acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
|
|
4228
|
+
}
|
|
4229
|
+
int32x4_t reorder_acc[8] = {
|
|
4230
|
+
vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
|
|
4231
|
+
vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
|
|
4232
|
+
vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
|
|
4233
|
+
vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
|
|
4234
|
+
vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
|
|
4235
|
+
vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
|
|
4236
|
+
vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
|
|
4237
|
+
vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
|
|
4238
|
+
};
|
|
4239
|
+
|
|
4240
|
+
for (int i = 0; i < q8_k_blocklen; i++) {
|
|
4241
|
+
for (int j = 0; j < 2; j++) {
|
|
4242
|
+
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
|
|
4243
|
+
float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));
|
|
4244
|
+
const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d);
|
|
4245
|
+
|
|
4246
|
+
float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));
|
|
4247
|
+
const float32x4_t scale = vmulq_f32(q4_d, q8_d);
|
|
4248
|
+
|
|
4249
|
+
acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
|
|
4250
|
+
acc_f32[2 * i + j] =
|
|
4251
|
+
vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
|
|
4252
|
+
}
|
|
2514
4253
|
}
|
|
2515
4254
|
} // for b
|
|
2516
4255
|
|
|
4256
|
+
// With the previous reorder, the tile is already in the correct memory layout.
|
|
2517
4257
|
for (int i = 0; i < q8_k_blocklen; i++) {
|
|
2518
4258
|
int row = y * q8_k_blocklen + i;
|
|
2519
4259
|
for (int j = 0; j < 2; j++) {
|
|
@@ -2525,11 +4265,11 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|
|
2525
4265
|
} // for x
|
|
2526
4266
|
} // for y
|
|
2527
4267
|
return;
|
|
2528
|
-
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(
|
|
2529
|
-
|
|
4268
|
+
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
|
4269
|
+
ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
|
2530
4270
|
}
|
|
2531
4271
|
|
|
2532
|
-
void
|
|
4272
|
+
void ggml_gemm_q5_K_8x8_q8_K(int n,
|
|
2533
4273
|
float * GGML_RESTRICT s,
|
|
2534
4274
|
size_t bs,
|
|
2535
4275
|
const void * GGML_RESTRICT vx,
|
|
@@ -2552,7 +4292,10 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
|
|
2552
4292
|
|
|
2553
4293
|
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
|
2554
4294
|
constexpr int q8_k_blocklen = 4;
|
|
4295
|
+
constexpr int col_pairs = ncols_interleaved / 2;
|
|
2555
4296
|
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
|
4297
|
+
const uint8x16_t mone = vdupq_n_u8(1);
|
|
4298
|
+
const uint8x16_t mtwo = vdupq_n_u8(2);
|
|
2556
4299
|
|
|
2557
4300
|
// 8 accumulators: 2 row pairs × 4 col pairs
|
|
2558
4301
|
float32x4_t acc_f32[blocklen];
|
|
@@ -2561,7 +4304,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
|
|
2561
4304
|
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
|
2562
4305
|
|
|
2563
4306
|
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
2564
|
-
const
|
|
4307
|
+
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
|
2565
4308
|
|
|
2566
4309
|
for (int i = 0; i < blocklen; i++) {
|
|
2567
4310
|
acc_f32[i] = vdupq_n_f32(0);
|
|
@@ -2588,14 +4331,24 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
|
|
2588
4331
|
bias_acc[i] = vdupq_n_s32(0);
|
|
2589
4332
|
}
|
|
2590
4333
|
|
|
4334
|
+
// Load qh once per block and shift after each subblock
|
|
4335
|
+
const uint8_t * qh_base = q5_ptr[b].qh;
|
|
4336
|
+
uint8x16_t qh[col_pairs][4];
|
|
4337
|
+
for (int cp = 0; cp < col_pairs; cp++) {
|
|
4338
|
+
qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
|
|
4339
|
+
qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
|
|
4340
|
+
qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
|
|
4341
|
+
qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
|
|
4342
|
+
}
|
|
4343
|
+
|
|
2591
4344
|
for (int sb = 0; sb < QK_K / 64; sb++) {
|
|
2592
4345
|
// Need scales for the low and high nibbles
|
|
2593
4346
|
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
|
2594
|
-
int8_t
|
|
2595
|
-
int16x8_t
|
|
4347
|
+
int8_t q5sb_scales[2][8];
|
|
4348
|
+
int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later
|
|
2596
4349
|
for (int i = 0; i < 2; i++) {
|
|
2597
4350
|
const int offset = sb * 24 + i * 12;
|
|
2598
|
-
|
|
4351
|
+
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], q5sb_scales[i]);
|
|
2599
4352
|
}
|
|
2600
4353
|
|
|
2601
4354
|
// q8_ptr[b].qs has interleaved Q8 rows (01, 23)
|
|
@@ -2612,64 +4365,89 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
|
|
2612
4365
|
}
|
|
2613
4366
|
|
|
2614
4367
|
const int8x16_t q8s[2][8] = {
|
|
2615
|
-
{ q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
|
|
2616
|
-
|
|
2617
|
-
{ q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
|
|
2618
|
-
|
|
4368
|
+
{ q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], q8_qs_01[4], q8_qs_01[5], q8_qs_01[6],
|
|
4369
|
+
q8_qs_01[7] },
|
|
4370
|
+
{ q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], q8_qs_23[4], q8_qs_23[5], q8_qs_23[6],
|
|
4371
|
+
q8_qs_23[7] },
|
|
2619
4372
|
};
|
|
2620
4373
|
|
|
2621
|
-
//
|
|
2622
|
-
for (int cp = 0; cp <
|
|
4374
|
+
// Q5s columns iterated in pairs (01, 23, 45, 67)
|
|
4375
|
+
for (int cp = 0; cp < col_pairs; cp++) {
|
|
2623
4376
|
for (int i = 0; i < 4; i++) {
|
|
2624
4377
|
sb_acc[i] = vdupq_n_s32(0);
|
|
2625
4378
|
}
|
|
2626
4379
|
|
|
2627
|
-
uint8x16_t
|
|
2628
|
-
uint8x16_t
|
|
2629
|
-
uint8x16_t
|
|
2630
|
-
uint8x16_t
|
|
2631
|
-
|
|
2632
|
-
|
|
2633
|
-
|
|
2634
|
-
|
|
2635
|
-
|
|
2636
|
-
|
|
2637
|
-
|
|
2638
|
-
|
|
2639
|
-
|
|
2640
|
-
|
|
2641
|
-
|
|
2642
|
-
|
|
2643
|
-
|
|
2644
|
-
|
|
2645
|
-
|
|
2646
|
-
|
|
2647
|
-
|
|
2648
|
-
|
|
2649
|
-
|
|
2650
|
-
|
|
2651
|
-
|
|
2652
|
-
|
|
2653
|
-
|
|
2654
|
-
|
|
2655
|
-
|
|
2656
|
-
|
|
2657
|
-
|
|
2658
|
-
|
|
2659
|
-
|
|
4380
|
+
uint8x16_t qs_cp_0 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
|
|
4381
|
+
uint8x16_t qs_cp_1 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
|
|
4382
|
+
uint8x16_t qs_cp_2 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
|
|
4383
|
+
uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
|
|
4384
|
+
|
|
4385
|
+
// This is the only part of the algorithm that differs with Q4_K
|
|
4386
|
+
// Extract High bits and pack into 5 bit weights
|
|
4387
|
+
uint8x16_t hbit_lo_0 = vandq_u8(qh[cp][0], mone);
|
|
4388
|
+
uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3);
|
|
4389
|
+
qh[cp][0] = vshrq_n_u8(qh[cp][0], 2);
|
|
4390
|
+
// Same as Q4_K, i8mm to dequantize the weights.
|
|
4391
|
+
const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4));
|
|
4392
|
+
int32x4_t acc_0 = sb_acc[0];
|
|
4393
|
+
acc_0 = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]);
|
|
4394
|
+
int32x4_t acc_2 = sb_acc[2];
|
|
4395
|
+
acc_2 = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]);
|
|
4396
|
+
const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0));
|
|
4397
|
+
int32x4_t acc_1 = sb_acc[1];
|
|
4398
|
+
acc_1 = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]);
|
|
4399
|
+
int32x4_t acc_3 = sb_acc[3];
|
|
4400
|
+
acc_3 = vmmlaq_s32(acc_3, qs_hi_0, q8s[1][4]);
|
|
4401
|
+
|
|
4402
|
+
// Repeat for the other 3 columns (8..15, 16..23, 24..31)
|
|
4403
|
+
uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3);
|
|
4404
|
+
uint8x16_t hbit_lo_1 = vandq_u8(qh[cp][1], mone);
|
|
4405
|
+
qh[cp][1] = vshrq_n_u8(qh[cp][1], 2);
|
|
4406
|
+
const int8x16_t qs_lo_1 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1, 4));
|
|
4407
|
+
acc_0 = vmmlaq_s32(acc_0, qs_lo_1, q8s[0][1]);
|
|
4408
|
+
acc_2 = vmmlaq_s32(acc_2, qs_lo_1, q8s[1][1]);
|
|
4409
|
+
const int8x16_t qs_hi_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1));
|
|
4410
|
+
acc_1 = vmmlaq_s32(acc_1, qs_hi_1, q8s[0][5]);
|
|
4411
|
+
acc_3 = vmmlaq_s32(acc_3, qs_hi_1, q8s[1][5]);
|
|
4412
|
+
|
|
4413
|
+
uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3);
|
|
4414
|
+
uint8x16_t hbit_lo_2 = vandq_u8(qh[cp][2], mone);
|
|
4415
|
+
qh[cp][2] = vshrq_n_u8(qh[cp][2], 2);
|
|
4416
|
+
const int8x16_t qs_lo_2 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2, 4));
|
|
4417
|
+
acc_0 = vmmlaq_s32(acc_0, qs_lo_2, q8s[0][2]);
|
|
4418
|
+
acc_2 = vmmlaq_s32(acc_2, qs_lo_2, q8s[1][2]);
|
|
4419
|
+
const int8x16_t qs_hi_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2));
|
|
4420
|
+
acc_1 = vmmlaq_s32(acc_1, qs_hi_2, q8s[0][6]);
|
|
4421
|
+
acc_3 = vmmlaq_s32(acc_3, qs_hi_2, q8s[1][6]);
|
|
4422
|
+
|
|
4423
|
+
uint8x16_t hbit_lo_3 = vandq_u8(qh[cp][3], mone);
|
|
4424
|
+
uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3);
|
|
4425
|
+
qh[cp][3] = vshrq_n_u8(qh[cp][3], 2);
|
|
4426
|
+
const int8x16_t qs_lo_3 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3, 4));
|
|
4427
|
+
acc_0 = vmmlaq_s32(acc_0, qs_lo_3, q8s[0][3]);
|
|
4428
|
+
sb_acc[0] = acc_0;
|
|
4429
|
+
acc_2 = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]);
|
|
4430
|
+
sb_acc[2] = acc_2;
|
|
2660
4431
|
|
|
2661
4432
|
// Scales[i] corresponds to column i
|
|
2662
|
-
const int
|
|
2663
|
-
|
|
2664
|
-
|
|
2665
|
-
|
|
2666
|
-
|
|
2667
|
-
|
|
2668
|
-
|
|
2669
|
-
|
|
2670
|
-
|
|
2671
|
-
|
|
2672
|
-
|
|
4433
|
+
const int scale_offset = cp * 2;
|
|
4434
|
+
const int32_t s0 = q5sb_scales[0][scale_offset];
|
|
4435
|
+
const int32_t s1 = q5sb_scales[0][scale_offset + 1];
|
|
4436
|
+
const int32x4_t block_scale = vcombine_s32(vdup_n_s32(s0), vdup_n_s32(s1));
|
|
4437
|
+
acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale);
|
|
4438
|
+
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale);
|
|
4439
|
+
|
|
4440
|
+
const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3));
|
|
4441
|
+
acc_1 = vmmlaq_s32(acc_1, qs_hi_3, q8s[0][7]);
|
|
4442
|
+
sb_acc[1] = acc_1;
|
|
4443
|
+
acc_3 = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]);
|
|
4444
|
+
sb_acc[3] = acc_3;
|
|
4445
|
+
|
|
4446
|
+
const int32_t s2 = q5sb_scales[1][scale_offset];
|
|
4447
|
+
const int32_t s3 = q5sb_scales[1][scale_offset + 1];
|
|
4448
|
+
const int32x4_t block_scale2 = vcombine_s32(vdup_n_s32(s2), vdup_n_s32(s3));
|
|
4449
|
+
acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale2);
|
|
4450
|
+
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale2);
|
|
2673
4451
|
}
|
|
2674
4452
|
|
|
2675
4453
|
// Multiply Acc bsum + mins
|
|
@@ -2680,13 +4458,13 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
|
|
2680
4458
|
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
|
|
2681
4459
|
|
|
2682
4460
|
bias_acc[2 * q8_row] =
|
|
2683
|
-
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(
|
|
4461
|
+
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
|
|
2684
4462
|
bias_acc[2 * q8_row] =
|
|
2685
|
-
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(
|
|
4463
|
+
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
|
|
2686
4464
|
bias_acc[2 * q8_row + 1] =
|
|
2687
|
-
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(
|
|
4465
|
+
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
|
|
2688
4466
|
bias_acc[2 * q8_row + 1] =
|
|
2689
|
-
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(
|
|
4467
|
+
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
|
|
2690
4468
|
}
|
|
2691
4469
|
} // for sb
|
|
2692
4470
|
|
|
@@ -2709,11 +4487,11 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
|
|
2709
4487
|
for (int i = 0; i < q8_k_blocklen; i++) {
|
|
2710
4488
|
for (int j = 0; j < 2; j++) {
|
|
2711
4489
|
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
|
|
2712
|
-
float32x4_t
|
|
2713
|
-
const float32x4_t dmins = vmulq_f32(
|
|
4490
|
+
float32x4_t q5_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].dmin + j * 4)));
|
|
4491
|
+
const float32x4_t dmins = vmulq_f32(q5_dmin, q8_d);
|
|
2714
4492
|
|
|
2715
|
-
float32x4_t
|
|
2716
|
-
const float32x4_t scale = vmulq_f32(
|
|
4493
|
+
float32x4_t q5_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].d + j * 4)));
|
|
4494
|
+
const float32x4_t scale = vmulq_f32(q5_d, q8_d);
|
|
2717
4495
|
|
|
2718
4496
|
acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
|
|
2719
4497
|
acc_f32[2 * i + j] =
|
|
@@ -2735,9 +4513,427 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
|
|
2735
4513
|
} // for y
|
|
2736
4514
|
return;
|
|
2737
4515
|
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
|
2738
|
-
|
|
4516
|
+
ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
|
4517
|
+
}
|
|
4518
|
+
|
|
4519
|
+
void ggml_gemm_q6_K_8x4_q8_K(int n,
|
|
4520
|
+
float * GGML_RESTRICT s,
|
|
4521
|
+
size_t bs,
|
|
4522
|
+
const void * GGML_RESTRICT vx,
|
|
4523
|
+
const void * GGML_RESTRICT vy,
|
|
4524
|
+
int nr,
|
|
4525
|
+
int nc) {
|
|
4526
|
+
constexpr int qk = QK_K;
|
|
4527
|
+
const int nb = n / qk;
|
|
4528
|
+
|
|
4529
|
+
constexpr int ncols_interleaved = 8;
|
|
4530
|
+
constexpr int blocklen = 4;
|
|
4531
|
+
|
|
4532
|
+
assert(n % qk == 0);
|
|
4533
|
+
assert(nr % 4 == 0);
|
|
4534
|
+
assert(nc % ncols_interleaved == 0);
|
|
4535
|
+
|
|
4536
|
+
UNUSED(nb);
|
|
4537
|
+
UNUSED(ncols_interleaved);
|
|
4538
|
+
UNUSED(blocklen);
|
|
4539
|
+
|
|
4540
|
+
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
4541
|
+
constexpr int q8_k_blocklen = 4;
|
|
4542
|
+
constexpr int col_groups = ncols_interleaved / 4;
|
|
4543
|
+
constexpr int acc_size = q8_k_blocklen * col_groups; // 4 rows, 2 column groups
|
|
4544
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
|
4545
|
+
const uint8x16_t mask_lo = vdupq_n_u8(0x03);
|
|
4546
|
+
const uint8x16_t mask_hi = vdupq_n_u8(0x30);
|
|
4547
|
+
const int8x16_t m32s = vdupq_n_s8(32);
|
|
4548
|
+
|
|
4549
|
+
float32x4_t acc_f32[acc_size];
|
|
4550
|
+
|
|
4551
|
+
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
|
4552
|
+
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
|
4553
|
+
|
|
4554
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
4555
|
+
const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
|
|
4556
|
+
|
|
4557
|
+
for (int i = 0; i < acc_size; i++) {
|
|
4558
|
+
acc_f32[i] = vdupq_n_f32(0);
|
|
4559
|
+
}
|
|
4560
|
+
|
|
4561
|
+
for (int b = 0; b < nb; b++) {
|
|
4562
|
+
float32x4_t q6_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));
|
|
4563
|
+
float32x4_t q6_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));
|
|
4564
|
+
float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
|
|
4565
|
+
|
|
4566
|
+
float32x4_t sbd_scale_0123[q8_k_blocklen];
|
|
4567
|
+
float32x4_t sbd_scale_4567[q8_k_blocklen];
|
|
4568
|
+
|
|
4569
|
+
sbd_scale_0123[0] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 0);
|
|
4570
|
+
sbd_scale_4567[0] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 0);
|
|
4571
|
+
sbd_scale_0123[1] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 1);
|
|
4572
|
+
sbd_scale_4567[1] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 1);
|
|
4573
|
+
sbd_scale_0123[2] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 2);
|
|
4574
|
+
sbd_scale_4567[2] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 2);
|
|
4575
|
+
sbd_scale_0123[3] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 3);
|
|
4576
|
+
sbd_scale_4567[3] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 3);
|
|
4577
|
+
|
|
4578
|
+
int32x4_t acc_s32[acc_size];
|
|
4579
|
+
for (int i = 0; i < acc_size; i++) {
|
|
4580
|
+
acc_s32[i] = vdupq_n_s32(0);
|
|
4581
|
+
}
|
|
4582
|
+
|
|
4583
|
+
int16_t q6_scales[8 * 16];
|
|
4584
|
+
for (int i = 0; i < 16; i++) {
|
|
4585
|
+
int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
|
|
4586
|
+
vst1q_s16(q6_scales + i * 8, scales);
|
|
4587
|
+
}
|
|
4588
|
+
|
|
4589
|
+
for (int half = 0; half < 2; half++) {
|
|
4590
|
+
const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
|
|
4591
|
+
const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
|
|
4592
|
+
|
|
4593
|
+
for (int sb = 0; sb < QK_K / 64; sb++) {
|
|
4594
|
+
int32x4_t acc_lo[acc_size];
|
|
4595
|
+
int32x4_t acc_hi[acc_size];
|
|
4596
|
+
for (int i = 0; i < acc_size; i++) {
|
|
4597
|
+
acc_lo[i] = vdupq_n_s32(0);
|
|
4598
|
+
acc_hi[i] = vdupq_n_s32(0);
|
|
4599
|
+
}
|
|
4600
|
+
|
|
4601
|
+
const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;
|
|
4602
|
+
const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;
|
|
4603
|
+
|
|
4604
|
+
// 4 rows * 16 elements per scale
|
|
4605
|
+
// 4 reads of 16 bytes each
|
|
4606
|
+
constexpr int reads_per_sb = 4;
|
|
4607
|
+
int8x16_t q8_l[reads_per_sb];
|
|
4608
|
+
int8x16_t q8_h[reads_per_sb];
|
|
4609
|
+
for (int k = 0; k < reads_per_sb; k++) {
|
|
4610
|
+
q8_l[k] = vld1q_s8(q8_base_l + 16 * k);
|
|
4611
|
+
q8_h[k] = vld1q_s8(q8_base_h + 16 * k);
|
|
4612
|
+
}
|
|
4613
|
+
|
|
4614
|
+
const int ql_off_base = sb * QK_K / 2;
|
|
4615
|
+
const int qh_off_base = ql_off_base & 255;
|
|
4616
|
+
|
|
4617
|
+
uint8x16_t q6_ql_0123[reads_per_sb];
|
|
4618
|
+
uint8x16_t q6_ql_4567[reads_per_sb];
|
|
4619
|
+
uint8x16_t q6_qh_0123[reads_per_sb];
|
|
4620
|
+
uint8x16_t q6_qh_4567[reads_per_sb];
|
|
4621
|
+
|
|
4622
|
+
for (int k = 0; k < reads_per_sb; k++) {
|
|
4623
|
+
q6_ql_0123[k] = vld1q_u8(ql_base + ql_off_base + k * 32);
|
|
4624
|
+
q6_ql_4567[k] = vld1q_u8(ql_base + ql_off_base + k * 32 + 16);
|
|
4625
|
+
q6_qh_0123[k] = vld1q_u8(qh_base + qh_off_base + k * 32);
|
|
4626
|
+
q6_qh_4567[k] = vld1q_u8(qh_base + qh_off_base + k * 32 + 16);
|
|
4627
|
+
}
|
|
4628
|
+
|
|
4629
|
+
if (sb > 1) {
|
|
4630
|
+
for (int k = 0; k < reads_per_sb; k++) {
|
|
4631
|
+
q6_qh_0123[k] = vshrq_n_u8(q6_qh_0123[k], 2);
|
|
4632
|
+
q6_qh_4567[k] = vshrq_n_u8(q6_qh_4567[k], 2);
|
|
4633
|
+
}
|
|
4634
|
+
}
|
|
4635
|
+
|
|
4636
|
+
for (int k = 0; k < reads_per_sb; k++) {
|
|
4637
|
+
// q = (ql | qh) - 32
|
|
4638
|
+
const uint8x16_t hbit_lo_0123 = vandq_u8(q6_qh_0123[k], mask_lo);
|
|
4639
|
+
const uint8x16_t hbit_hi_0123 = vandq_u8(q6_qh_0123[k], mask_hi);
|
|
4640
|
+
const uint8x16_t hbit_lo_4567 = vandq_u8(q6_qh_4567[k], mask_lo);
|
|
4641
|
+
const uint8x16_t hbit_hi_4567 = vandq_u8(q6_qh_4567[k], mask_hi);
|
|
4642
|
+
|
|
4643
|
+
const int8x16_t q6_0123_lo = vsubq_s8(
|
|
4644
|
+
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_0123[k], m4b), hbit_lo_0123, 4)), m32s);
|
|
4645
|
+
const int8x16_t q6_0123_hi = vsubq_s8(
|
|
4646
|
+
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_0123[k], 4), hbit_hi_0123)), m32s);
|
|
4647
|
+
|
|
4648
|
+
acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q6_0123_lo, q8_l[k], 0); // 0..3 r0 c0123
|
|
4649
|
+
acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q6_0123_lo, q8_l[k], 1); // 0..3 r1 c0123
|
|
4650
|
+
acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q6_0123_lo, q8_l[k], 2); // 0..3 r2 c0123
|
|
4651
|
+
acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q6_0123_lo, q8_l[k], 3); // 0..3 r3 c0123
|
|
4652
|
+
|
|
4653
|
+
acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q6_0123_hi, q8_h[k], 0); // 64..67 r0 c0123
|
|
4654
|
+
acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q6_0123_hi, q8_h[k], 1); // 64..67 r1 c0123
|
|
4655
|
+
acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q6_0123_hi, q8_h[k], 2); // 64..67 r2 c0123
|
|
4656
|
+
acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q6_0123_hi, q8_h[k], 3); // 64..67 r3 c0123
|
|
4657
|
+
|
|
4658
|
+
const int8x16_t q6_4567_lo = vsubq_s8(
|
|
4659
|
+
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_4567[k], m4b), hbit_lo_4567, 4)), m32s);
|
|
4660
|
+
const int8x16_t q6_4567_hi = vsubq_s8(
|
|
4661
|
+
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_4567[k], 4), hbit_hi_4567)), m32s);
|
|
4662
|
+
|
|
4663
|
+
acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q6_4567_lo, q8_l[k], 0); // 0..3 r0 c4567
|
|
4664
|
+
acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q6_4567_lo, q8_l[k], 1); // 0..3 r1 c4567
|
|
4665
|
+
acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q6_4567_lo, q8_l[k], 2); // 0..3 r2 c4567
|
|
4666
|
+
acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q6_4567_lo, q8_l[k], 3); // 0..3 r3 c4567
|
|
4667
|
+
|
|
4668
|
+
acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q6_4567_hi, q8_h[k], 0); // 64..67 r0 c4567
|
|
4669
|
+
acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q6_4567_hi, q8_h[k], 1); // 64..67 r1 c4567
|
|
4670
|
+
acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q6_4567_hi, q8_h[k], 2); // 64..67 r2 c4567
|
|
4671
|
+
acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q6_4567_hi, q8_h[k], 3); // 64..67 r3 c4567
|
|
4672
|
+
}
|
|
4673
|
+
|
|
4674
|
+
// Scale and bias
|
|
4675
|
+
const int scale_idx_l = half * 8 + sb;
|
|
4676
|
+
const int scale_idx_h = half * 8 + sb + 4;
|
|
4677
|
+
|
|
4678
|
+
for (int g = 0; g < col_groups; g++) {
|
|
4679
|
+
const int16x4_t scales_l16 = vld1_s16(q6_scales + scale_idx_l * 8 + g * 4);
|
|
4680
|
+
const int16x4_t scales_h16 = vld1_s16(q6_scales + scale_idx_h * 8 + g * 4);
|
|
4681
|
+
const int32x4_t scale_vec_l = vmovl_s16(scales_l16);
|
|
4682
|
+
const int32x4_t scale_vec_h = vmovl_s16(scales_h16);
|
|
4683
|
+
const int acc_offset = g * q8_k_blocklen;
|
|
4684
|
+
|
|
4685
|
+
for (int row = 0; row < q8_k_blocklen; row++) {
|
|
4686
|
+
const int idx = row * 2 + g;
|
|
4687
|
+
acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_lo[acc_offset + row], scale_vec_l);
|
|
4688
|
+
acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_hi[acc_offset + row], scale_vec_h);
|
|
4689
|
+
}
|
|
4690
|
+
}
|
|
4691
|
+
}
|
|
4692
|
+
}
|
|
4693
|
+
|
|
4694
|
+
// Finally we apply the superblock scales
|
|
4695
|
+
for (int row = 0; row < q8_k_blocklen; row++) {
|
|
4696
|
+
const int idx0 = 2 * row;
|
|
4697
|
+
const int idx1 = 2 * row + 1;
|
|
4698
|
+
const int32x4_t acc_0123 = acc_s32[idx0];
|
|
4699
|
+
const int32x4_t acc_4567 = acc_s32[idx1];
|
|
4700
|
+
|
|
4701
|
+
acc_f32[idx0] = vmlaq_f32(acc_f32[idx0], vcvtq_f32_s32(acc_0123), sbd_scale_0123[row]);
|
|
4702
|
+
acc_f32[idx1] = vmlaq_f32(acc_f32[idx1], vcvtq_f32_s32(acc_4567), sbd_scale_4567[row]);
|
|
4703
|
+
}
|
|
4704
|
+
} // for b
|
|
4705
|
+
|
|
4706
|
+
for (int i = 0; i < q8_k_blocklen; i++) {
|
|
4707
|
+
int row = y * q8_k_blocklen + i;
|
|
4708
|
+
for (int j = 0; j < 2; j++) {
|
|
4709
|
+
int col = x * ncols_interleaved + j * 4;
|
|
4710
|
+
int offset = row * bs + col;
|
|
4711
|
+
vst1q_f32(s + offset, acc_f32[2 * i + j]);
|
|
4712
|
+
}
|
|
4713
|
+
}
|
|
4714
|
+
} // for x
|
|
4715
|
+
} // for y
|
|
4716
|
+
return;
|
|
4717
|
+
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
4718
|
+
ggml_gemm_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
|
2739
4719
|
}
|
|
2740
4720
|
|
|
4721
|
+
// GEMM for q6_K weights interleaved 8-columns-wide (block_q6_Kx8) against
// q8_K activations batched 4-rows-wide (block_q8_Kx4), producing float
// output in s with row stride bs.
//
// On AArch64 with FEAT_I8MM it uses vmmlaq_s32 to compute 2x2 int32 tiles
// per instruction; otherwise it falls through to the generic scalar kernel.
//
// Parameters:
//   n      - number of columns in the K dimension; must be a multiple of QK_K
//   s      - output buffer (row-major, stride bs floats per row)
//   bs     - output row stride in floats
//   vx     - interleaved q6_K weight blocks (block_q6_Kx8)
//   vy     - batched q8_K activation blocks (block_q8_Kx4)
//   nr     - number of activation rows; must be a multiple of 4
//   nc     - number of weight columns; must be a multiple of 8
void ggml_gemm_q6_K_8x8_q8_K(int n,
                             float * GGML_RESTRICT s,
                             size_t bs,
                             const void * GGML_RESTRICT vx,
                             const void * GGML_RESTRICT vy,
                             int nr,
                             int nc) {
    constexpr int qk = QK_K;
    const int nb = n / qk; // superblocks per row

    constexpr int ncols_interleaved = 8;
    constexpr int blocklen = 8;

    assert(n % qk == 0);
    assert(nr % 4 == 0);
    assert(nc % ncols_interleaved == 0);

    // Silence unused warnings on the non-SIMD fallthrough path.
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
    constexpr int q8_k_blocklen = 4;                 // q8 rows processed per tile
    const uint8x16_t m4b = vdupq_n_u8(0x0f);         // low-nibble mask for ql
    const uint8x16_t mask_lo = vdupq_n_u8(0x03);     // high 2 bits for the low 64 values
    const uint8x16_t mask_hi = vdupq_n_u8(0x30);     // high 2 bits (pre-shifted) for the high 64 values
    const int8x16_t m32s = vdupq_n_s8(32);           // q6_K values are stored with a +32 bias

    // 8 accumulators: 4 q8 rows × 2 col groups (0-3, 4-7)
    float32x4_t acc_f32[blocklen];

    for (int y = 0; y < nr / q8_k_blocklen; y++) {
        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);

        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);

            for (int i = 0; i < blocklen; i++) {
                acc_f32[i] = vdupq_n_f32(0);
            }

            for (int b = 0; b < nb; b++) {
                // Per-superblock integer accumulators in i8mm 2x2-tile order.
                int32x4_t acc[8]; // rows 01 stored in [0][1][2][3], rows 23 stored in [4][5][6][7]
                for (int i = 0; i < 8; i++) {
                    acc[i] = vdupq_n_s32(0);
                }

                // Q6_K has simple 8-bit scales, 16 per block (one per 16 values)
                // Reused for bias and dequantization later
                // Widen all 128 int8 scales to int16 once per superblock.
                int16_t q6_scales[16 * 8];
                for (int i = 0; i < 16; ++i) {
                    int16x8_t s16 = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
                    vst1q_s16(q6_scales + i * 8, s16);
                }

                // Process two 128-value halves per superblock
                for (int half = 0; half < 2; half++) {

                    const uint8_t * ql_base = q6_ptr[b].ql + half * 512; // low 4 bits, 2 values/byte
                    const uint8_t * qh_base = q6_ptr[b].qh + half * 256; // high 2 bits, 4 values/byte

                    // A subblock (sb) is a set of weights that share the scale
                    // Since q6_K scales are per 16 elements
                    // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
                    for (int sb = 0; sb < QK_K / 64; sb++) {
                        // Q6_K weight index increasing by 64 instead of 32 requires
                        // loading various q8 memory regions
                        const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;
                        const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;

                        // q8 activations for row pairs 0/1 and 2/3 (low element range).
                        int8x16_t q8_l_01[2];
                        int8x16_t q8_l_23[2];
                        for (int i = 0; i < 2; i++) {
                            const int offset = i * 32;
                            q8_l_01[i] = vld1q_s8(q8_base_l + offset); // 0..7 & 8..15 (r01)
                            q8_l_23[i] = vld1q_s8(q8_base_l + offset + 16); // 0..7 & 8..15 (r23)
                        }

                        // Same for the high element range (+256 bytes into qs).
                        int8x16_t q8_h_01[2];
                        int8x16_t q8_h_23[2];
                        for (int i = 0; i < 2; i++) {
                            const int offset = i * 32;
                            q8_h_01[i] = vld1q_s8(q8_base_h + offset);
                            q8_h_23[i] = vld1q_s8(q8_base_h + offset + 16);
                        }

                        const int ql_off_base = sb * QK_K / 2;

                        uint8x16_t q6_ql_0[4];
                        uint8x16_t q6_ql_1[4];
                        for (int k = 0; k < 4; k++) {
                            q6_ql_0[k] = vld1q_u8(ql_base + ql_off_base + 16 * k);
                            q6_ql_1[k] = vld1q_u8(ql_base + ql_off_base + 64 + 16 * k);
                        }

                        const int qh_off_base = (sb * QK_K / 2) & 255; // wrap after 256 bytes
                        uint8x16_t q6_qh_0[4];
                        uint8x16_t q6_qh_1[4];
                        for (int k = 0; k < 4; k++) {
                            q6_qh_0[k] = vld1q_u8(qh_base + qh_off_base + 16 * k);
                            q6_qh_1[k] = vld1q_u8(qh_base + qh_off_base + 64 + 16 * k);
                        }

                        // Adjust for the proper high bits (Sb 2 and 3)
                        if (sb > 1) {
                            for (int k = 0; k < 4; k++) {
                                q6_qh_0[k] = vshrq_n_u8(q6_qh_0[k], 2);
                                q6_qh_1[k] = vshrq_n_u8(q6_qh_1[k], 2);
                            }
                        }

                        // Process column pairs (0-1, 2-3, 4-5, 6-7)
                        for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
                            const uint8x16_t q6_qs_cp_0_l = q6_ql_0[cp];
                            const uint8x16_t q6_qs_cp_1_l = q6_ql_1[cp];
                            const uint8x16_t q6_qs_cp_0_h = q6_qh_0[cp];
                            const uint8x16_t q6_qs_cp_1_h = q6_qh_1[cp];

                            // Extract high 2 bits for upper nibble reconstruction
                            const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi);
                            const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi);

                            // q6 = (low4 | high2<<4) - 32
                            // Use vsliq_n_u8 to combine shift-left-insert in one instruction (like Q5_K)
                            const int8x16_t q6_l0 = vsubq_s8(
                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4)),
                                m32s);
                            const int8x16_t q6_l1 = vsubq_s8(
                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4)),
                                m32s);
                            // Upper nibble path: mask_hi bits are already in position 4..5,
                            // so a plain OR with the shifted low bits suffices.
                            const int8x16_t q6_h0 = vsubq_s8(
                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh)), m32s);
                            const int8x16_t q6_h1 = vsubq_s8(
                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh)), m32s);

                            // Each vmmlaq_s32 accumulates a 2x2 int32 tile
                            // (2 q8 rows x 2 q6 columns) from int8 operands.
                            // row pair 0, base_l
                            int32x4_t sb_acc_0l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_01[0]);
                            sb_acc_0l = vmmlaq_s32(sb_acc_0l, q6_l1, q8_l_01[1]);
                            // row pair 0, base_h
                            int32x4_t sb_acc_0h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_01[0]);
                            sb_acc_0h = vmmlaq_s32(sb_acc_0h, q6_h1, q8_h_01[1]);
                            // row pair 1, base_l
                            int32x4_t sb_acc_1l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_23[0]);
                            sb_acc_1l = vmmlaq_s32(sb_acc_1l, q6_l1, q8_l_23[1]);
                            // row pair 1, base_h
                            int32x4_t sb_acc_1h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_23[0]);
                            sb_acc_1h = vmmlaq_s32(sb_acc_1h, q6_h1, q8_h_23[1]);

                            // Scale indices: low/high element ranges of this
                            // subblock use scales offset by 4 rows of 8.
                            const int scale_idx_l = half * 8 + sb;
                            const int scale_idx_h = half * 8 + sb + 4;

                            // Broadcast the two per-column scales to match the
                            // 2x2 tile layout (col0, col0, col1, col1).
                            const int32x4_t scale_vec_l = {
                                q6_scales[scale_idx_l * 8 + cp * 2 + 0],
                                q6_scales[scale_idx_l * 8 + cp * 2 + 0],
                                q6_scales[scale_idx_l * 8 + cp * 2 + 1],
                                q6_scales[scale_idx_l * 8 + cp * 2 + 1],
                            };
                            const int32x4_t scale_vec_h = {
                                q6_scales[scale_idx_h * 8 + cp * 2 + 0],
                                q6_scales[scale_idx_h * 8 + cp * 2 + 0],
                                q6_scales[scale_idx_h * 8 + cp * 2 + 1],
                                q6_scales[scale_idx_h * 8 + cp * 2 + 1],
                            };

                            acc[cp] = vmlaq_s32(acc[cp], sb_acc_0l, scale_vec_l);
                            acc[cp] = vmlaq_s32(acc[cp], sb_acc_0h, scale_vec_h);
                            acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1l, scale_vec_l);
                            acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1h, scale_vec_h);
                        }
                    }
                } // for half

                // Reorder i8mm output to match memory layout
                for (int i = 0; i < 8; i++) {
                    int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
                    acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
                }
                // Gather the 2x2 tiles into row-major 4-column vectors:
                // reorder_acc[2*row + colgroup] covers columns 0-3 or 4-7 of one row.
                int32x4_t reorder_acc[8] = {
                    vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
                    vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
                    vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
                    vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
                    vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
                    vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
                    vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
                    vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
                };

                // Apply superblock scale (no mins for q6_K)
                for (int i = 0; i < q8_k_blocklen; i++) {
                    for (int j = 0; j < 2; j++) {
                        // Combined dequant factor: per-row q8 scale * per-column q6 fp16 scales.
                        float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
                        float32x4_t q6_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q6_ptr[b].d + j * 4)));
                        const float32x4_t scale = vmulq_f32(q6_d, q8_d);

                        acc_f32[2 * i + j] =
                            vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
                    }
                }
            } // for b

            // Store results
            for (int i = 0; i < q8_k_blocklen; i++) {
                int row = y * q8_k_blocklen + i;
                for (int j = 0; j < 2; j++) {
                    int col = x * ncols_interleaved + j * 4;
                    int offset = row * bs + col;
                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
                }
            }
        } // for x
    } // for y
    return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
    // Scalar fallback when the i8mm fast path is not compiled in.
    ggml_gemm_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
|
|
2741
4937
|
|
|
2742
4938
|
void ggml_gemm_q8_0_4x4_q8_0(int n,
|
|
2743
4939
|
float * GGML_RESTRICT s,
|