whispercpp 1.3.5 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +99 -2
- data/ext/extconf.rb +1 -0
- data/ext/ruby_whisper.c +20 -4
- data/ext/ruby_whisper.h +30 -2
- data/ext/ruby_whisper_context.c +216 -124
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_params.c +0 -1
- data/ext/ruby_whisper_segment.c +0 -1
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +4 -1
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +8 -0
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/server/server.cpp +18 -4
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
- data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
- data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
- data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
- data/ext/sources/examples/talk-llama/llama-context.h +27 -28
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
- data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
- data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
- data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
- data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
- data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
- data/ext/sources/examples/talk-llama/llama-model.h +72 -19
- data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
- data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
- data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -22
- data/ext/sources/examples/talk-llama/llama.h +63 -30
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
- data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
- data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
- data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
- data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
- data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/models.h +181 -46
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
- data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
- data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
- data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
- data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
- data/ext/sources/ggml/CMakeLists.txt +9 -3
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +6 -1
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +56 -9
- data/ext/sources/ggml/src/CMakeLists.txt +3 -0
- data/ext/sources/ggml/src/ggml-alloc.c +4 -9
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
- data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
- data/ext/sources/ggml/src/ggml-impl.h +62 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
- data/ext/sources/ggml/src/ggml.c +167 -33
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/src/whisper.cpp +6 -28
- data/sig/whisper.rbs +43 -2
- data/test/test_context_params.rb +82 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_whisper.rb +20 -0
- data/whispercpp.gemspec +1 -1
- metadata +240 -28
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
|
@@ -113,6 +113,104 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
|
|
|
113
113
|
#endif
|
|
114
114
|
}
|
|
115
115
|
|
|
116
|
+
void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
|
|
117
|
+
assert(k % QK_K == 0);
|
|
118
|
+
block_q8_K * y_blocks = (block_q8_K *)y;
|
|
119
|
+
size_t nb = k / QK_K;
|
|
120
|
+
|
|
121
|
+
#if defined(__riscv_v_intrinsic)
|
|
122
|
+
const size_t vlmax_f32m8 = __riscv_vsetvlmax_e32m8();
|
|
123
|
+
|
|
124
|
+
for (size_t i = 0; i < nb; i++) {
|
|
125
|
+
const float* x_block = x + i * QK_K;
|
|
126
|
+
block_q8_K* y_block = &y_blocks[i];
|
|
127
|
+
|
|
128
|
+
// 1. Calculate Min/Max
|
|
129
|
+
vfloat32m8_t max_v = __riscv_vfmv_v_f_f32m8(-__builtin_inff(), vlmax_f32m8);
|
|
130
|
+
vfloat32m8_t min_v = __riscv_vfmv_v_f_f32m8(__builtin_inff(), vlmax_f32m8);
|
|
131
|
+
|
|
132
|
+
size_t rem = QK_K;
|
|
133
|
+
size_t offset = 0;
|
|
134
|
+
while (rem > 0) {
|
|
135
|
+
size_t vl = __riscv_vsetvl_e32m8(rem);
|
|
136
|
+
vfloat32m8_t v_curr = __riscv_vle32_v_f32m8(x_block + offset, vl);
|
|
137
|
+
max_v = __riscv_vfmax_vv_f32m8(max_v, v_curr, vl);
|
|
138
|
+
min_v = __riscv_vfmin_vv_f32m8(min_v, v_curr, vl);
|
|
139
|
+
rem -= vl;
|
|
140
|
+
offset += vl;
|
|
141
|
+
}
|
|
142
|
+
|
|
143
|
+
vfloat32m1_t v_init_max = __riscv_vfmv_s_f_f32m1(-__builtin_inff(), 1);
|
|
144
|
+
vfloat32m1_t v_init_min = __riscv_vfmv_s_f_f32m1(__builtin_inff(), 1);
|
|
145
|
+
|
|
146
|
+
vfloat32m1_t v_scalar_max = __riscv_vfredmax_vs_f32m8_f32m1(max_v, v_init_max, vlmax_f32m8);
|
|
147
|
+
vfloat32m1_t v_scalar_min = __riscv_vfredmin_vs_f32m8_f32m1(min_v, v_init_min, vlmax_f32m8);
|
|
148
|
+
|
|
149
|
+
float max_val = __riscv_vfmv_f_s_f32m1_f32(v_scalar_max);
|
|
150
|
+
float min_val = __riscv_vfmv_f_s_f32m1_f32(v_scalar_min);
|
|
151
|
+
|
|
152
|
+
float amax = fabsf(max_val) > fabsf(min_val) ? fabsf(max_val) : fabsf(min_val);
|
|
153
|
+
|
|
154
|
+
if (amax == 0.0f) {
|
|
155
|
+
y_block->d = 0.0f;
|
|
156
|
+
memset(y_block->qs, 0, QK_K);
|
|
157
|
+
memset(y_block->bsums, 0, sizeof(y_block->bsums));
|
|
158
|
+
continue;
|
|
159
|
+
}
|
|
160
|
+
|
|
161
|
+
const float iscale = -127.f / (fabsf(max_val) > fabsf(min_val) ? max_val : min_val);
|
|
162
|
+
y_block->d = 1.0f / iscale;
|
|
163
|
+
|
|
164
|
+
// 2. Quantize and Calculate Sums
|
|
165
|
+
offset = 0;
|
|
166
|
+
rem = QK_K;
|
|
167
|
+
vint16m1_t v_zero_sum = __riscv_vmv_v_x_i16m1(0, 1);
|
|
168
|
+
|
|
169
|
+
while (rem > 0) {
|
|
170
|
+
size_t vl = __riscv_vsetvl_e32m8(rem);
|
|
171
|
+
vfloat32m8_t v_f = __riscv_vle32_v_f32m8(x_block + offset, vl);
|
|
172
|
+
|
|
173
|
+
v_f = __riscv_vfmul_vf_f32m8(v_f, iscale, vl);
|
|
174
|
+
|
|
175
|
+
vint32m8_t v_i32 = __riscv_vfcvt_x_f_v_i32m8_rm(v_f, __RISCV_FRM_RNE, vl);
|
|
176
|
+
vint16m4_t v_i16 = __riscv_vnclip_wx_i16m4(v_i32, 0, __RISCV_VXRM_RNE, vl);
|
|
177
|
+
vint8m2_t v_q = __riscv_vnclip_wx_i8m2(v_i16, 0, __RISCV_VXRM_RNE, vl);
|
|
178
|
+
|
|
179
|
+
__riscv_vse8_v_i8m2(y_block->qs + offset, v_q, vl);
|
|
180
|
+
|
|
181
|
+
// first iteration clear
|
|
182
|
+
|
|
183
|
+
int sum_idx;
|
|
184
|
+
vint8m1_t chunk_m1;
|
|
185
|
+
vint16m1_t v_sum;
|
|
186
|
+
sum_idx = offset / 16;
|
|
187
|
+
chunk_m1 = __riscv_vget_v_i8m2_i8m1(v_q, 0);
|
|
188
|
+
v_sum = __riscv_vwredsum_vs_i8m1_i16m1(chunk_m1, v_zero_sum, 16);
|
|
189
|
+
y_block->bsums[sum_idx] = (int16_t)__riscv_vmv_x_s_i16m1_i16(v_sum);
|
|
190
|
+
|
|
191
|
+
// remaining iterations
|
|
192
|
+
vint8m2_t slid_q = v_q;
|
|
193
|
+
for (size_t k = 16; k < vl; k += 16) {
|
|
194
|
+
slid_q = __riscv_vslidedown_vx_i8m2(slid_q, 16, vl);
|
|
195
|
+
|
|
196
|
+
sum_idx = (offset + k) / 16;
|
|
197
|
+
chunk_m1 = __riscv_vget_v_i8m2_i8m1(slid_q, 0);
|
|
198
|
+
|
|
199
|
+
v_sum = __riscv_vwredsum_vs_i8m1_i16m1(chunk_m1, v_zero_sum, 16);
|
|
200
|
+
y_block->bsums[sum_idx] =(int16_t)__riscv_vmv_x_s_i16m1_i16(v_sum);
|
|
201
|
+
}
|
|
202
|
+
|
|
203
|
+
rem -= vl;
|
|
204
|
+
offset += vl;
|
|
205
|
+
}
|
|
206
|
+
}
|
|
207
|
+
#else
|
|
208
|
+
GGML_UNUSED(nb);
|
|
209
|
+
// scalar
|
|
210
|
+
quantize_row_q8_K_ref(x, y, k);
|
|
211
|
+
#endif
|
|
212
|
+
}
|
|
213
|
+
|
|
116
214
|
//===================================== Dot products =================================
|
|
117
215
|
|
|
118
216
|
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
|
@@ -1954,3 +2052,1558 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
|
|
1954
2052
|
#endif
|
|
1955
2053
|
}
|
|
1956
2054
|
|
|
2055
|
+
static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
|
2056
|
+
assert(n % QK_K == 0);
|
|
2057
|
+
assert(nrc == 1);
|
|
2058
|
+
UNUSED(nrc);
|
|
2059
|
+
UNUSED(bx);
|
|
2060
|
+
UNUSED(by);
|
|
2061
|
+
UNUSED(bs);
|
|
2062
|
+
|
|
2063
|
+
const block_iq1_s * GGML_RESTRICT x = vx;
|
|
2064
|
+
const block_q8_K * GGML_RESTRICT y = vy;
|
|
2065
|
+
|
|
2066
|
+
const int nb = n / QK_K;
|
|
2067
|
+
|
|
2068
|
+
float sumf = 0;
|
|
2069
|
+
for (int i = 0; i < nb; ++i) {
|
|
2070
|
+
// Load qh once for the entire superblock.
|
|
2071
|
+
vuint16mf2_t qh = __riscv_vle16_v_u16mf2(x[i].qh, 8);
|
|
2072
|
+
|
|
2073
|
+
// Calculate ls.
|
|
2074
|
+
vuint16mf2_t temp = __riscv_vsrl_vx_u16mf2(qh, 12, 8);
|
|
2075
|
+
temp = __riscv_vand_vx_u16mf2(temp, 7, 8);
|
|
2076
|
+
vint32m1_t ls = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vwmulu_vx_u32m1(temp, 2, 8));
|
|
2077
|
+
ls = __riscv_vadd_vx_i32m1(ls, 1, 8);
|
|
2078
|
+
|
|
2079
|
+
// Calculate delta.
|
|
2080
|
+
vbool32_t mask = __riscv_vmseq_vx_u16mf2_b32(__riscv_vand_vx_u16mf2(qh, 0x8000, 8), 0, 8);
|
|
2081
|
+
vint32m1_t delta_neg = __riscv_vmv_v_x_i32m1(-1, 8);
|
|
2082
|
+
vint32m1_t delta_pos = __riscv_vmv_v_x_i32m1(1, 8);
|
|
2083
|
+
vint32m1_t delta = __riscv_vmerge_vvm_i32m1(delta_neg, delta_pos, mask, 8);
|
|
2084
|
+
|
|
2085
|
+
// Load qs.
|
|
2086
|
+
vuint8m1_t qs = __riscv_vle8_v_u8m1(x[i].qs, 32);
|
|
2087
|
+
|
|
2088
|
+
// Prepare the indices.
|
|
2089
|
+
const uint64_t shift = 0x0009000600030000;
|
|
2090
|
+
vuint16m2_t qh_shift = __riscv_vreinterpret_v_u64m2_u16m2(__riscv_vmv_v_x_u64m2(shift, 8));
|
|
2091
|
+
vuint16m2_t qh_gather_index = __riscv_vreinterpret_v_i16m2_u16m2(
|
|
2092
|
+
__riscv_vdiv_vx_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vid_v_u16m2(32)), 4, 32));
|
|
2093
|
+
vuint16m2_t qh_ext = __riscv_vlmul_ext_v_u16m1_u16m2(__riscv_vlmul_ext_v_u16mf2_u16m1(qh));
|
|
2094
|
+
vuint16m2_t qh_index = __riscv_vrgather_vv_u16m2(qh_ext, qh_gather_index, 32);
|
|
2095
|
+
qh_index = __riscv_vsrl_vv_u16m2(qh_index, qh_shift, 32);
|
|
2096
|
+
qh_index = __riscv_vand_vx_u16m2(qh_index, 7, 32);
|
|
2097
|
+
qh_index = __riscv_vsll_vx_u16m2(qh_index, 8, 32);
|
|
2098
|
+
qh_index = __riscv_vor_vv_u16m2(qh_index, __riscv_vzext_vf2_u16m2(qs, 32), 32);
|
|
2099
|
+
vuint16m2_t index = __riscv_vsll_vx_u16m2(qh_index, 3, 32);
|
|
2100
|
+
|
|
2101
|
+
// Final lsums.
|
|
2102
|
+
int32_t lsums_s[8];
|
|
2103
|
+
vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1);
|
|
2104
|
+
|
|
2105
|
+
// Sub-blocks 1-4
|
|
2106
|
+
{
|
|
2107
|
+
vuint16m1_t grid_index0 = __riscv_vget_v_u16m2_u16m1(index, 0);
|
|
2108
|
+
vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 16));
|
|
2109
|
+
vint8m4_t q80 = __riscv_vle8_v_i8m4(y[i].qs, 128);
|
|
2110
|
+
vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128);
|
|
2111
|
+
lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 0), one_scalar, 32));
|
|
2112
|
+
lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 1), one_scalar, 32));
|
|
2113
|
+
lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 2), one_scalar, 32));
|
|
2114
|
+
lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 3), one_scalar, 32));
|
|
2115
|
+
}
|
|
2116
|
+
__asm__ __volatile__("" ::: "memory");
|
|
2117
|
+
// Sub-blocks 5-8
|
|
2118
|
+
{
|
|
2119
|
+
vuint16m1_t grid_index1 = __riscv_vget_v_u16m2_u16m1(index, 1);
|
|
2120
|
+
vint8m4_t grid1 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index1, 16));
|
|
2121
|
+
vint8m4_t q81 = __riscv_vle8_v_i8m4(&y[i].qs[128], 128);
|
|
2122
|
+
vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(grid1, q81, 128);
|
|
2123
|
+
lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 0), one_scalar, 32));
|
|
2124
|
+
lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 1), one_scalar, 32));
|
|
2125
|
+
lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 2), one_scalar, 32));
|
|
2126
|
+
lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 3), one_scalar, 32));
|
|
2127
|
+
}
|
|
2128
|
+
__asm__ __volatile__("" ::: "memory");
|
|
2129
|
+
vint32m1_t lsums = __riscv_vle32_v_i32m1(&lsums_s[0], 8);
|
|
2130
|
+
|
|
2131
|
+
// Calculate the bsums.
|
|
2132
|
+
vint16m1_t bsums_0 = __riscv_vle16_v_i16m1(y[i].bsums, 16);
|
|
2133
|
+
const vuint32m1_t bsums_i32 = __riscv_vreinterpret_v_u16m1_u32m1(__riscv_vreinterpret_v_i16m1_u16m1(bsums_0));
|
|
2134
|
+
const vint16mf2_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 0, 8));
|
|
2135
|
+
const vint16mf2_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 16, 8));
|
|
2136
|
+
const vint32m1_t bsums = __riscv_vwadd_vv_i32m1(bsums_i32_0, bsums_i32_1, 8);
|
|
2137
|
+
|
|
2138
|
+
// Accumulation.
|
|
2139
|
+
vint32m1_t sumi_v = __riscv_vmul_vv_i32m1(ls, lsums, 8);
|
|
2140
|
+
vint32m1_t sumi1_v = __riscv_vmul_vv_i32m1(__riscv_vmul_vv_i32m1(ls, delta, 8), bsums, 8);
|
|
2141
|
+
|
|
2142
|
+
// Update sumf.
|
|
2143
|
+
int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8));
|
|
2144
|
+
int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8));
|
|
2145
|
+
sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1);
|
|
2146
|
+
}
|
|
2147
|
+
|
|
2148
|
+
*s = sumf;
|
|
2149
|
+
}
|
|
2150
|
+
|
|
2151
|
+
void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
|
2152
|
+
#if defined __riscv_v_intrinsic
|
|
2153
|
+
switch (__riscv_vlenb() * 8) {
|
|
2154
|
+
case 256:
|
|
2155
|
+
ggml_vec_dot_iq1_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
|
|
2156
|
+
break;
|
|
2157
|
+
default:
|
|
2158
|
+
ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
|
2159
|
+
break;
|
|
2160
|
+
}
|
|
2161
|
+
#else
|
|
2162
|
+
ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
|
2163
|
+
#endif
|
|
2164
|
+
}
|
|
2165
|
+
|
|
2166
|
+
static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
|
2167
|
+
assert(n % QK_K == 0);
|
|
2168
|
+
assert(nrc == 1);
|
|
2169
|
+
UNUSED(nrc);
|
|
2170
|
+
UNUSED(bx);
|
|
2171
|
+
UNUSED(by);
|
|
2172
|
+
UNUSED(bs);
|
|
2173
|
+
|
|
2174
|
+
const block_iq1_m * GGML_RESTRICT x = vx;
|
|
2175
|
+
const block_q8_K * GGML_RESTRICT y = vy;
|
|
2176
|
+
|
|
2177
|
+
const int nb = n / QK_K;
|
|
2178
|
+
|
|
2179
|
+
iq1m_scale_t scale;
|
|
2180
|
+
float sumf = 0.0f;
|
|
2181
|
+
for (int i = 0; i < nb; ++i) {
|
|
2182
|
+
const int8_t * q8 = y[i].qs;
|
|
2183
|
+
const uint8_t * qs = x[i].qs;
|
|
2184
|
+
const uint8_t * qh = x[i].qh;
|
|
2185
|
+
const uint16_t * sc = (const uint16_t *)x[i].scales;
|
|
2186
|
+
|
|
2187
|
+
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
|
2188
|
+
|
|
2189
|
+
// Accumulators.
|
|
2190
|
+
vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 16);
|
|
2191
|
+
vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 16);
|
|
2192
|
+
|
|
2193
|
+
// We process 4 sub-blocks together.
|
|
2194
|
+
for (int ib = 0; ib < QK_K/128; ib++) {
|
|
2195
|
+
// Load qh for 4 sub-blocks.
|
|
2196
|
+
const vuint8mf4_t qh_8 = __riscv_vle8_v_u8mf4(qh, 8);
|
|
2197
|
+
const vuint16mf2_t qh_16_lo = __riscv_vzext_vf2_u16mf2(qh_8, 8);
|
|
2198
|
+
const vuint16mf2_t qh_16_hi = __riscv_vsll_vx_u16mf2(qh_16_lo, 8, 8);
|
|
2199
|
+
const vuint16m1_t qhb = __riscv_vzext_vf2_u16m1(
|
|
2200
|
+
__riscv_vreinterpret_v_u16mf2_u8mf2(__riscv_vor_vv_u16mf2(qh_16_lo, qh_16_hi, 8)), 16);
|
|
2201
|
+
qh += 8;
|
|
2202
|
+
|
|
2203
|
+
// Prepare grid indices.
|
|
2204
|
+
const vuint16m1_t qsb = __riscv_vzext_vf2_u16m1(__riscv_vle8_v_u8mf2(&qs[0], 16), 16);
|
|
2205
|
+
const vuint16m1_t shift = __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00040008, 8));
|
|
2206
|
+
vuint16m1_t index = __riscv_vor_vv_u16m1(qsb, __riscv_vand_vx_u16m1(__riscv_vsll_vv_u16m1(qhb, shift, 16), 0x700, 16), 16);
|
|
2207
|
+
index = __riscv_vsll_vx_u16m1(index, 3, 16);
|
|
2208
|
+
qs += 16;
|
|
2209
|
+
|
|
2210
|
+
// Load the grid.
|
|
2211
|
+
const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4(
|
|
2212
|
+
__riscv_vluxei16_v_u64m4(iq1s_grid, index, 16)));
|
|
2213
|
+
|
|
2214
|
+
// Prepare the deltas.
|
|
2215
|
+
const vbool16_t mask = __riscv_vmsgtu_vx_u16m1_b16(
|
|
2216
|
+
__riscv_vand_vv_u16m1(qhb, __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00800008, 8)), 16), 0, 16);
|
|
2217
|
+
const vint64m4_t delta_pos = __riscv_vmv_v_x_i64m4(0x0101010101010101, 16);
|
|
2218
|
+
const vint64m4_t delta_neg = __riscv_vmv_v_x_i64m4(0xffffffffffffffff, 16);
|
|
2219
|
+
const vint8m4_t delta = __riscv_vreinterpret_v_i64m4_i8m4(
|
|
2220
|
+
__riscv_vmerge_vvm_i64m4(delta_pos, delta_neg, mask, 16));
|
|
2221
|
+
|
|
2222
|
+
// Load q8 for sub-blocks.
|
|
2223
|
+
const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128);
|
|
2224
|
+
q8 += 128;
|
|
2225
|
+
|
|
2226
|
+
// Calculate the lsums.
|
|
2227
|
+
const vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(iq1b, q8b, 128);
|
|
2228
|
+
const vint16m8_t lsum2 = __riscv_vwmul_vv_i16m8(delta, q8b, 128);
|
|
2229
|
+
|
|
2230
|
+
// Prepare the scales.
|
|
2231
|
+
const int16_t ls_0_0 = 2*((sc[0] >> 0) & 0x7) + 1;
|
|
2232
|
+
const int16_t ls_0_1 = 2*((sc[0] >> 3) & 0x7) + 1;
|
|
2233
|
+
const int16_t ls_1_0 = 2*((sc[0] >> 6) & 0x7) + 1;
|
|
2234
|
+
const int16_t ls_1_1 = 2*((sc[0] >> 9) & 0x7) + 1;
|
|
2235
|
+
const int16_t ls_2_0 = 2*((sc[1] >> 0) & 0x7) + 1;
|
|
2236
|
+
const int16_t ls_2_1 = 2*((sc[1] >> 3) & 0x7) + 1;
|
|
2237
|
+
const int16_t ls_3_0 = 2*((sc[1] >> 6) & 0x7) + 1;
|
|
2238
|
+
const int16_t ls_3_1 = 2*((sc[1] >> 9) & 0x7) + 1;
|
|
2239
|
+
sc += 2;
|
|
2240
|
+
|
|
2241
|
+
// Accumulate in acc0 and acc1 for each sub-block.
|
|
2242
|
+
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 16);
|
|
2243
|
+
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 16);
|
|
2244
|
+
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 16);
|
|
2245
|
+
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 16);
|
|
2246
|
+
//
|
|
2247
|
+
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 16);
|
|
2248
|
+
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 16);
|
|
2249
|
+
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 16);
|
|
2250
|
+
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 16);
|
|
2251
|
+
//
|
|
2252
|
+
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 16);
|
|
2253
|
+
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 16);
|
|
2254
|
+
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 16);
|
|
2255
|
+
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 16);
|
|
2256
|
+
//
|
|
2257
|
+
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 16);
|
|
2258
|
+
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 16);
|
|
2259
|
+
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 16);
|
|
2260
|
+
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 16);
|
|
2261
|
+
}
|
|
2262
|
+
|
|
2263
|
+
// Reduce and accumulate in `sumf`.
|
|
2264
|
+
vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1);
|
|
2265
|
+
int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 16));
|
|
2266
|
+
int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 16));
|
|
2267
|
+
sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2);
|
|
2268
|
+
}
|
|
2269
|
+
|
|
2270
|
+
*s = sumf;
|
|
2271
|
+
}
|
|
2272
|
+
|
|
2273
|
+
void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
|
2274
|
+
#if defined __riscv_v_intrinsic
|
|
2275
|
+
switch (__riscv_vlenb() * 8) {
|
|
2276
|
+
case 256:
|
|
2277
|
+
ggml_vec_dot_iq1_m_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
|
|
2278
|
+
break;
|
|
2279
|
+
default:
|
|
2280
|
+
ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
|
2281
|
+
break;
|
|
2282
|
+
}
|
|
2283
|
+
#else
|
|
2284
|
+
ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
|
2285
|
+
#endif
|
|
2286
|
+
}
|
|
2287
|
+
|
|
2288
|
+
static const uint8_t sign_gather_indices_arr[64] = {
|
|
2289
|
+
0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3,
|
|
2290
|
+
4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7
|
|
2291
|
+
};
|
|
2292
|
+
|
|
2293
|
+
static const uint8_t sign_bit_masks_arr[64] = {
|
|
2294
|
+
1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128,
|
|
2295
|
+
1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128
|
|
2296
|
+
};
|
|
2297
|
+
|
|
2298
|
+
|
|
2299
|
+
static void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
|
2300
|
+
assert(n % QK_K == 0);
|
|
2301
|
+
UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);
|
|
2302
|
+
|
|
2303
|
+
const block_iq2_s * GGML_RESTRICT x = vx;
|
|
2304
|
+
const block_q8_K * GGML_RESTRICT y = vy;
|
|
2305
|
+
|
|
2306
|
+
const int nb = n / QK_K;
|
|
2307
|
+
const uint64_t * grid64 = (const uint64_t *)iq2s_grid;
|
|
2308
|
+
|
|
2309
|
+
// Pre-load Constants
|
|
2310
|
+
vuint8m2_t v_ids = __riscv_vid_v_u8m2(32);
|
|
2311
|
+
vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 32);
|
|
2312
|
+
vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 32);
|
|
2313
|
+
vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 32);
|
|
2314
|
+
vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 32);
|
|
2315
|
+
uint16_t shift_qh_arr[4] = {11, 9, 7, 5};
|
|
2316
|
+
vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 4);
|
|
2317
|
+
|
|
2318
|
+
float sumf = 0.0f;
|
|
2319
|
+
|
|
2320
|
+
for (int i = 0; i < nb; ++i) {
|
|
2321
|
+
const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
|
|
2322
|
+
|
|
2323
|
+
const uint8_t * GGML_RESTRICT qs = x[i].qs;
|
|
2324
|
+
const uint8_t * GGML_RESTRICT qh = x[i].qh;
|
|
2325
|
+
const uint8_t * GGML_RESTRICT scales = x[i].scales;
|
|
2326
|
+
const int8_t * GGML_RESTRICT q8 = y[i].qs;
|
|
2327
|
+
|
|
2328
|
+
const uint8_t * signs_ptr = qs + 32;
|
|
2329
|
+
float sum_block = 0.0f;
|
|
2330
|
+
|
|
2331
|
+
for (int ib = 0; ib < 8; ++ib) {
|
|
2332
|
+
|
|
2333
|
+
// Load Low Bits [4 bytes]
|
|
2334
|
+
vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 4);
|
|
2335
|
+
qs += 4;
|
|
2336
|
+
|
|
2337
|
+
// Load 1 byte. It contains bits for 4 mini-blocks.
|
|
2338
|
+
uint8_t qh_val = *qh++;
|
|
2339
|
+
|
|
2340
|
+
// Combine Low + High bits of 10bit indices
|
|
2341
|
+
vuint8mf4_t v_qh_raw = __riscv_vmv_v_x_u8mf4(qh_val, 4);
|
|
2342
|
+
vuint16mf2_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qh_raw, 4);
|
|
2343
|
+
vuint16mf2_t v_qh_mf2 = __riscv_vsll_vv_u16mf2(v_qh_u16, v_shift_qh, 4);
|
|
2344
|
+
v_qh_mf2 = __riscv_vand_vx_u16mf2(v_qh_mf2, 0x1800, 4);
|
|
2345
|
+
vuint16mf2_t v_qs_u16_mf2 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 4);
|
|
2346
|
+
vuint16mf2_t v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16_mf2, 3, 4);
|
|
2347
|
+
vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_mf2, 4);
|
|
2348
|
+
|
|
2349
|
+
// Lookup Grid
|
|
2350
|
+
vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(__riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 4)));
|
|
2351
|
+
|
|
2352
|
+
vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 4);
|
|
2353
|
+
signs_ptr += 4;
|
|
2354
|
+
vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);
|
|
2355
|
+
vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 32);
|
|
2356
|
+
|
|
2357
|
+
// generating sign mask
|
|
2358
|
+
vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 32);
|
|
2359
|
+
vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 32);
|
|
2360
|
+
|
|
2361
|
+
vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 32);
|
|
2362
|
+
q8 += 32;
|
|
2363
|
+
|
|
2364
|
+
// apply signs
|
|
2365
|
+
vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative,v_q8, v_q8, 0, 32);
|
|
2366
|
+
vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 32);
|
|
2367
|
+
|
|
2368
|
+
// Reduction
|
|
2369
|
+
vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);
|
|
2370
|
+
|
|
2371
|
+
// Reduce 0-15 (First Half)
|
|
2372
|
+
int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(
|
|
2373
|
+
__riscv_vget_v_i16m4_i16m2(v_dot, 0), v_zero, 16));
|
|
2374
|
+
|
|
2375
|
+
// Reduce 16-31 (Second Half)
|
|
2376
|
+
int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(
|
|
2377
|
+
__riscv_vget_v_i16m4_i16m2(v_dot, 1), v_zero, 16));
|
|
2378
|
+
|
|
2379
|
+
// Apply sub Scales
|
|
2380
|
+
uint8_t sc = *scales++;
|
|
2381
|
+
|
|
2382
|
+
sum_block += s0 * (2 * (sc & 0xF) + 1);
|
|
2383
|
+
sum_block += s1 * (2 * (sc >> 4) + 1);
|
|
2384
|
+
}
|
|
2385
|
+
sumf += sum_block * combined_scale;
|
|
2386
|
+
}
|
|
2387
|
+
*s = 0.125f * sumf;
|
|
2388
|
+
}
|
|
2389
|
+
|
|
2390
|
+
// Dot product of one iq2_s-quantized row (vx) with one q8_K-quantized row (vy),
// specialized for RISC-V V with VLEN == 256 (selected by the dispatcher below).
// iq2_s packs each mini-block of 8 weights as a 10-bit grid index (8 low bits in
// qs, 2 high bits in qh) plus per-element sign bits and 4-bit sub-scales.
// n must be a multiple of QK_K; bs/bx/by/nrc are part of the common vec_dot
// signature and unused here (single row x single row).
static void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);

    const block_iq2_s * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;
    // iq2s_grid is indexed as 8-byte entries (8 dequant magnitudes per index).
    const uint64_t * grid64 = (const uint64_t *)iq2s_grid;

    // --- Pre-load Constants ---
    // Each qh byte supplies high bits for 4 mini-blocks; the gather replicates
    // qh[0] into lanes 0-3 and qh[1] into lanes 4-7.
    uint16_t gather_qh_arr[8] = {0, 0, 0, 0, 1, 1, 1, 1};
    vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 8);
    // Per-lane shifts that move each mini-block's 2 high bits into bits 11-12.
    uint16_t shift_qh_arr[8] = {11, 9, 7, 5, 11, 9, 7, 5};
    vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 8);

    // Constants for sign extraction (tables defined elsewhere in this file).
    vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64);
    vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64);

    float sumf = 0.0f;

    for (int i = 0; i < nb; ++i) {
        // Per-super-block scale: fp16 iq2_s scale times q8_K float scale.
        const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;

        const uint8_t * GGML_RESTRICT qs = x[i].qs;
        const uint8_t * GGML_RESTRICT qh = x[i].qh;
        const uint8_t * GGML_RESTRICT scales = x[i].scales;
        const int8_t * GGML_RESTRICT q8 = y[i].qs;

        // Sign bytes live in the second half of the qs array.
        const uint8_t * signs_ptr = qs + 32;

        float sum_block = 0.0f;

        // 4 iterations x 64 weights = QK_K (256) weights per super-block.
        for (int ib = 0; ib < 4; ++ib) {
            // Combine low + high bits
            vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 8);
            qs += 8;
            uint16_t qh_val;
            memcpy(&qh_val, qh, 2); // unaligned-safe 2-byte load
            qh += 2;
            vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8((const uint8_t*)&qh_val, 2);
            vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 2);
            vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16);
            vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 8);
            v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 8);

            // Mask: We want bits 11-12. 0x1800 = 0001 1000 0000 0000
            v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 8);
            vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 8);

            // Multiply by 8 to get byte offset, instead of element offset
            v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 8);
            vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 8);

            // Lookup Grid using Byte Offsets (indexed gather of 8 u64 entries)
            vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 8);

            vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals);
            vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8);

            // Load signs and generate sign mask
            vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 8);
            signs_ptr += 8;

            // Broadcast each sign byte to its 8 weight lanes, then isolate
            // each lane's own bit; nonzero bit => negate that q8 element.
            vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);
            vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64);

            vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64);
            vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64);

            vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64);
            q8 += 64;

            // Masked reverse-subtract from 0 == conditional negation of q8.
            vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64);
            vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 64);

            vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);

            // Reduce each 16-element quarter separately so each gets its own
            // 4-bit sub-scale below.
            int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
                __riscv_vget_v_i16m4_i16m1(v_dot, 0), v_zero, 16));
            int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
                __riscv_vget_v_i16m4_i16m1(v_dot, 1), v_zero, 16));
            int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
                __riscv_vget_v_i16m4_i16m1(v_dot, 2), v_zero, 16));
            int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
                __riscv_vget_v_i16m4_i16m1(v_dot, 3), v_zero, 16));

            uint8_t sc0 = scales[0];
            uint8_t sc1 = scales[1];
            scales += 2;

            // Sub-scale decoding: nibble sc -> effective scale (2*sc + 1).
            sum_block += s0 * (2 * (sc0 & 0xF) + 1);
            sum_block += s1 * (2 * (sc0 >> 4) + 1);
            sum_block += s2 * (2 * (sc1 & 0xF) + 1);
            sum_block += s3 * (2 * (sc1 >> 4) + 1);
        }
        sumf += sum_block * combined_scale;
    }
    // 0.125f folds the grid's fixed-point scaling back out (matches the
    // generic iq2_s path).
    *s = 0.125f * sumf;
}
|
|
2491
|
+
|
|
2492
|
+
/**
 * Dot product of an iq2_s row with a q8_K row.
 *
 * Dispatches at runtime to the RVV kernel matching the hardware vector
 * length (VLEN in bits = bytes-per-register * 8); any other VLEN, or a
 * build without RVV intrinsics, falls back to the generic implementation.
 */
void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined __riscv_v_intrinsic
    const size_t vlen_bits = __riscv_vlenb() * 8;
    if (vlen_bits == 128) {
        ggml_vec_dot_iq2_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
    } else if (vlen_bits == 256) {
        ggml_vec_dot_iq2_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
    } else {
        ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
    }
#else
    ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
|
|
2509
|
+
|
|
2510
|
+
#if defined(__riscv_v)
|
|
2511
|
+
static const int8_t keven_signs_q2xs[1024] = {
|
|
2512
|
+
1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
|
|
2513
|
+
1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
|
|
2514
|
+
1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
|
|
2515
|
+
1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
|
|
2516
|
+
1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
|
|
2517
|
+
1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
|
|
2518
|
+
1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
|
|
2519
|
+
1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
|
|
2520
|
+
1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
|
|
2521
|
+
1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
|
|
2522
|
+
1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
|
|
2523
|
+
1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
|
|
2524
|
+
1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
|
|
2525
|
+
1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
|
|
2526
|
+
1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
|
|
2527
|
+
1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
|
|
2528
|
+
1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
|
|
2529
|
+
1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
|
|
2530
|
+
1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
|
|
2531
|
+
1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
|
|
2532
|
+
1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
|
|
2533
|
+
1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
|
|
2534
|
+
1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
|
|
2535
|
+
1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
|
|
2536
|
+
1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
|
|
2537
|
+
1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
|
|
2538
|
+
1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
|
|
2539
|
+
1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
|
|
2540
|
+
1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
|
|
2541
|
+
1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
|
|
2542
|
+
1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
|
|
2543
|
+
1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
|
|
2544
|
+
};
|
|
2545
|
+
#endif
|
|
2546
|
+
|
|
2547
|
+
// Dot product of one iq2_xs row (vx) with one q8_K row (vy), RVV kernel for
// VLEN == 256. iq2_xs packs each group of 8 weights as a 16-bit word: low
// 9 bits index iq2xs_grid (magnitudes), high 7 bits index keven_signs_q2xs
// (signs). Four 16-element sub-sums per 64-weight sub-block each get their
// own 4-bit sub-scale. n must be a multiple of QK_K; nrc must be 1.
static void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq2_xs * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;
    // Both tables are read as 8-byte entries (8 int8 values per index).
    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
    const uint64_t * grid64 = (const uint64_t *)iq2xs_grid;

    float sumf = 0.0f;

    for (int i = 0; i < nb; ++i) {
        // Per-super-block scale: fp16 iq2_xs scale times q8_K float scale.
        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        const uint16_t * GGML_RESTRICT qs = x[i].qs;
        const int8_t * GGML_RESTRICT q8 = y[i].qs;
        const uint8_t * GGML_RESTRICT scales = x[i].scales;

        int32_t sum_int = 0;

        // Loop over 4 subblocks of 64 elements (QK_K = 256)
        for (int ib64 = 0; ib64 < QK_K / 64; ++ib64) {
            // Load 8 uint16 indices (controls 64 values)
            vuint16mf2_t v_qs = __riscv_vle16_v_u16mf2(qs, 8);
            qs += 8;

            // Extract indices for grid (low 9 bits) and signs (high 7 bits)
            // Multiply by 8 (<< 3) for byte offsets into the uint64 tables
            vuint16mf2_t vidx_grid = __riscv_vsll_vx_u16mf2(__riscv_vand_vx_u16mf2(v_qs, 511, 8), 3, 8);
            vuint16mf2_t vidx_sign = __riscv_vsll_vx_u16mf2(__riscv_vsrl_vx_u16mf2(v_qs, 9, 8), 3, 8);

            // Indexed gathers: 8 u64 entries == 64 int8 values each.
            vuint64m2_t vq2_64 = __riscv_vluxei16_v_u64m2(grid64, vidx_grid, 8);
            vuint64m2_t vs2_64 = __riscv_vluxei16_v_u64m2(signs64, vidx_sign, 8);

            vint8m2_t q2u = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64));
            vint8m2_t q2s = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64));

            // Signs are +/-1, so multiplying applies them to the magnitudes.
            vint8m2_t q2_final = __riscv_vmul_vv_i8m2(q2u, q2s, 64);

            vint8m2_t q8v = __riscv_vle8_v_i8m2(q8, 64);
            q8 += 64;

            vint16m4_t prod = __riscv_vwmul_vv_i16m4(q2_final, q8v, 64);

            vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1);

            // Reduce each 16-element quarter separately so each can be
            // weighted by its own sub-scale nibble.
            int32_t sum0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
                __riscv_vget_v_i16m4_i16m1(prod, 0), zero_vec, 16));
            int32_t sum1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
                __riscv_vget_v_i16m4_i16m1(prod, 1), zero_vec, 16));
            int32_t sum2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
                __riscv_vget_v_i16m4_i16m1(prod, 2), zero_vec, 16));
            int32_t sum3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
                __riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 16));

            const uint8_t scale_byte_1 = scales[0];
            const uint8_t scale_byte_2 = scales[1];
            scales += 2;

            // Sub-scale decoding: nibble sc -> effective scale (2*sc + 1).
            sum_int += sum0 * ((scale_byte_1 & 0x0F) * 2 + 1);
            sum_int += sum1 * ((scale_byte_1 >> 4) * 2 + 1);
            sum_int += sum2 * ((scale_byte_2 & 0x0F) * 2 + 1);
            sum_int += sum3 * ((scale_byte_2 >> 4) * 2 + 1);
        }

        sumf += d * sum_int;
    }
    // 0.125f matches the fixed-point factor of the generic iq2_xs path.
    *s = 0.125f * sumf;
}
|
|
2621
|
+
|
|
2622
|
+
/**
 * Dot product of an iq2_xs row with a q8_K row.
 *
 * At runtime, picks the VLEN-256 RVV kernel when the hardware vector
 * length matches; every other configuration (including builds without
 * RVV intrinsics) uses the generic implementation.
 */
void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined __riscv_v_intrinsic
    const size_t vlen_bits = __riscv_vlenb() * 8;
    if (vlen_bits == 256) {
        ggml_vec_dot_iq2_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
    } else {
        ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
    }
#else
    ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
|
|
2636
|
+
|
|
2637
|
+
// Dot product of one iq2_xxs row (vx) with one q8_K row (vy), RVV kernel for
// VLEN == 128. iq2_xxs stores, per 32-weight group, 4 grid-index bytes
// followed by one packed 32-bit word holding four 7-bit sign indices plus a
// 4-bit scale in the top nibble. n must be a multiple of QK_K; nrc must be 1.
static void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq2_xxs * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;
    // Both tables are read as 8-byte entries (8 int8 values per index).
    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
    const uint64_t * grid64 = (const uint64_t *)iq2xxs_grid;

    // Per-lane shifts that unpack the four 7-bit sign indices from one u32.
    uint32_t shift_constants[4] = {0, 7, 14, 21};
    vuint32m1_t v_shifts = __riscv_vle32_v_u32m1(shift_constants, 4);

    float sumf = 0.0f;
    for (int i = 0; i < nb; ++i) {
        // Per-super-block scale: fp16 iq2_xxs scale times q8_K float scale.
        const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;

        const uint8_t * GGML_RESTRICT q2_ptr = (const uint8_t *) x[i].qs;
        const int8_t * GGML_RESTRICT q8 = y[i].qs;

        float sum = 0.0f;

// Unrolling hurts here (register pressure at LMUL=2) — keep the loop rolled.
#pragma GCC unroll 1
        for (int ib32 = 0; ib32 < QK_K / 32; ib32 += 2) {
            // Two 32-weight groups per iteration.
            vint8m2_t q8_1 = __riscv_vle8_v_i8m2(q8, 32); q8 += 32;
            vint8m2_t q8_2 = __riscv_vle8_v_i8m2(q8, 32); q8 += 32;

            // 4 grid-index bytes per group; each indexes 8 magnitudes.
            vuint8mf4_t v_raw_q2_1 = __riscv_vle8_v_u8mf4(q2_ptr, 4);
            vuint8mf4_t v_raw_q2_2 = __riscv_vle8_v_u8mf4(q2_ptr + 8, 4);

            vuint16mf2_t vidx_q2_1 = __riscv_vwcvtu_x_x_v_u16mf2(v_raw_q2_1, 4);
            vuint16mf2_t vidx_q2_2 = __riscv_vwcvtu_x_x_v_u16mf2(v_raw_q2_2, 4);

            // << 3: convert element index to byte offset into the u64 table.
            vidx_q2_1 = __riscv_vsll_vx_u16mf2(vidx_q2_1, 3, 4);
            vidx_q2_2 = __riscv_vsll_vx_u16mf2(vidx_q2_2, 3, 4);

            // Packed sign/scale word for each group (unaligned-safe load).
            uint32_t s_packed_1, s_packed_2;
            memcpy(&s_packed_1, q2_ptr + 4, 4);
            memcpy(&s_packed_2, q2_ptr + 12, 4);

            vuint32m1_t v_s_1 = __riscv_vmv_v_x_u32m1(s_packed_1, 4);
            vuint32m1_t v_s_2 = __riscv_vmv_v_x_u32m1(s_packed_2, 4);
            v_s_1 = __riscv_vsrl_vv_u32m1(v_s_1, v_shifts, 4);
            v_s_2 = __riscv_vsrl_vv_u32m1(v_s_2, v_shifts, 4);

            // Keep only the 7-bit sign indices.
            v_s_1 = __riscv_vand_vx_u32m1(v_s_1, 127, 4);
            v_s_2 = __riscv_vand_vx_u32m1(v_s_2, 127, 4);

            // Narrow to u16 and scale by 8 to get byte offsets.
            vuint16mf2_t vidx_s2_1 = __riscv_vsll_vx_u16mf2(__riscv_vncvt_x_x_w_u16mf2(v_s_1, 4), 3, 4);
            vuint16mf2_t vidx_s2_2 = __riscv_vsll_vx_u16mf2(__riscv_vncvt_x_x_w_u16mf2(v_s_2, 4), 3, 4);

            // Gather 32 magnitudes per group from the grid.
            vuint64m2_t vq2_64_1 = __riscv_vluxei16_v_u64m2(grid64, vidx_q2_1, 4);
            vuint64m2_t vq2_64_2 = __riscv_vluxei16_v_u64m2(grid64, vidx_q2_2, 4);

            vint8m2_t q2_1 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64_1));
            vint8m2_t q2_2 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64_2));

            // Gather the +/-1 sign vectors.
            vuint64m2_t vs2_64_1 = __riscv_vluxei16_v_u64m2(signs64, vidx_s2_1, 4);
            vuint64m2_t vs2_64_2 = __riscv_vluxei16_v_u64m2(signs64, vidx_s2_2, 4);
            vint8m2_t s2_1 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64_1));
            vint8m2_t s2_2 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64_2));

            // Apply signs to q8, then widen-multiply with the magnitudes.
            vint8m2_t q8s_1 = __riscv_vmul_vv_i8m2(q8_1, s2_1, 32);
            vint8m2_t q8s_2 = __riscv_vmul_vv_i8m2(q8_2, s2_2, 32);

            vint16m4_t dot1 = __riscv_vwmul_vv_i16m4(q8s_1, q2_1, 32);
            vint16m4_t dot2 = __riscv_vwmul_vv_i16m4(q8s_2, q2_2, 32);

            vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1);
            vint32m1_t sumv1 = __riscv_vwredsum_vs_i16m4_i32m1(dot1, zero_vec, 32);
            vint32m1_t sumv2 = __riscv_vwredsum_vs_i16m4_i32m1(dot2, zero_vec, 32);

            int32_t scalar_sum1 = __riscv_vmv_x_s_i32m1_i32(sumv1);
            int32_t scalar_sum2 = __riscv_vmv_x_s_i32m1_i32(sumv2);

            // Top nibble of the packed word is the group scale: 2*sc + 1.
            int16_t scale1 = 2 * ((s_packed_1 >> 28) & 0xF) + 1;
            int16_t scale2 = 2 * ((s_packed_2 >> 28) & 0xF) + 1;

            sum += scalar_sum1 * scale1 + scalar_sum2 * scale2;
            q2_ptr += 16;
        }
        sumf += sum * combined_scale;
    }
    // 0.125f matches the fixed-point factor of the generic iq2_xxs path.
    *s = 0.125f * sumf;
}
|
|
2727
|
+
|
|
2728
|
+
// Dot product of one iq2_xxs row (vx) with one q8_K row (vy), RVV kernel for
// VLEN >= 256. Same algorithm as the vl128 variant but each 32-element group
// fits a single m1 register (fractional LMUL for the narrow index vectors).
// n must be a multiple of QK_K; nrc must be 1.
static void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq2_xxs * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;
    // Both tables are read as 8-byte entries (8 int8 values per index).
    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
    const uint64_t * grid64 = (const uint64_t *)iq2xxs_grid;

    // Per-lane shifts that unpack the four 7-bit sign indices from one u32.
    uint32_t shift_constants[4] = {0, 7, 14, 21};
    vuint32mf2_t v_shifts = __riscv_vle32_v_u32mf2(shift_constants, 4);

    float sumf = 0.0f;

    for (int i = 0; i < nb; ++i) {
        // Per-super-block scale: fp16 iq2_xxs scale times q8_K float scale.
        const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;

        const uint8_t * GGML_RESTRICT q2_ptr = (const uint8_t *) x[i].qs;
        const int8_t * GGML_RESTRICT q8 = y[i].qs;

        float sum = 0.0f;

        for (int ib32 = 0; ib32 < QK_K / 32; ib32 += 2) {
            // Two 32-weight groups per iteration.
            vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8, 32); q8 += 32;
            vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8, 32); q8 += 32;

            // 4 grid-index bytes per group; each indexes 8 magnitudes.
            vuint8mf8_t v_raw_q2_1 = __riscv_vle8_v_u8mf8(q2_ptr, 4);
            vuint8mf8_t v_raw_q2_2 = __riscv_vle8_v_u8mf8(q2_ptr + 8, 4);

            vuint16mf4_t vidx_q2_1 = __riscv_vwcvtu_x_x_v_u16mf4(v_raw_q2_1, 4);
            vuint16mf4_t vidx_q2_2 = __riscv_vwcvtu_x_x_v_u16mf4(v_raw_q2_2, 4);

            // << 3: convert element index to byte offset into the u64 table.
            vidx_q2_1 = __riscv_vsll_vx_u16mf4(vidx_q2_1, 3, 4);
            vidx_q2_2 = __riscv_vsll_vx_u16mf4(vidx_q2_2, 3, 4);

            // Packed sign/scale word for each group (unaligned-safe load).
            uint32_t s_packed_1, s_packed_2;
            memcpy(&s_packed_1, q2_ptr + 4, 4);
            memcpy(&s_packed_2, q2_ptr + 12, 4);

            vuint32mf2_t v_s_1 = __riscv_vmv_v_x_u32mf2(s_packed_1, 4);
            vuint32mf2_t v_s_2 = __riscv_vmv_v_x_u32mf2(s_packed_2, 4);

            v_s_1 = __riscv_vsrl_vv_u32mf2(v_s_1, v_shifts, 4);
            v_s_2 = __riscv_vsrl_vv_u32mf2(v_s_2, v_shifts, 4);

            // Keep only the 7-bit sign indices.
            v_s_1 = __riscv_vand_vx_u32mf2(v_s_1, 127, 4);
            v_s_2 = __riscv_vand_vx_u32mf2(v_s_2, 127, 4);

            // Narrow u32 -> u16 (vncvt) and Scale by 8 to get byte offsets
            vuint16mf4_t vidx_s2_1 = __riscv_vsll_vx_u16mf4(__riscv_vncvt_x_x_w_u16mf4(v_s_1, 4), 3, 4);
            vuint16mf4_t vidx_s2_2 = __riscv_vsll_vx_u16mf4(__riscv_vncvt_x_x_w_u16mf4(v_s_2, 4), 3, 4);

            // Load q2 values from lookup grid
            vuint64m1_t vq2_64_1 = __riscv_vluxei16_v_u64m1(grid64, vidx_q2_1, 4);
            vuint64m1_t vq2_64_2 = __riscv_vluxei16_v_u64m1(grid64, vidx_q2_2, 4);
            vint8m1_t q2_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vq2_64_1));
            vint8m1_t q2_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vq2_64_2));

            // Load sign values
            vuint64m1_t vs2_64_1 = __riscv_vluxei16_v_u64m1(signs64, vidx_s2_1, 4);
            vuint64m1_t vs2_64_2 = __riscv_vluxei16_v_u64m1(signs64, vidx_s2_2, 4);
            vint8m1_t s2_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vs2_64_1));
            vint8m1_t s2_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vs2_64_2));

            // Apply signs to q8
            vint8m1_t q8s_1 = __riscv_vmul_vv_i8m1(q8_1, s2_1, 32);
            vint8m1_t q8s_2 = __riscv_vmul_vv_i8m1(q8_2, s2_2, 32);

            // multiplying q2 with q8
            vint16m2_t dot1 = __riscv_vwmul_vv_i16m2(q8s_1, q2_1, 32);
            vint16m2_t dot2 = __riscv_vwmul_vv_i16m2(q8s_2, q2_2, 32);

            vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1);
            vint32m1_t sumv1 = __riscv_vwredsum_vs_i16m2_i32m1(dot1, zero_vec, 32);
            vint32m1_t sumv2 = __riscv_vwredsum_vs_i16m2_i32m1(dot2, zero_vec, 32);
            int32_t scalar_sum1 = __riscv_vmv_x_s_i32m1_i32(sumv1);
            int32_t scalar_sum2 = __riscv_vmv_x_s_i32m1_i32(sumv2);
            // Top nibble of the packed word is the group scale: 2*sc + 1.
            int16_t scale1 = 2 * ((s_packed_1 >> 28) & 0xF) + 1;
            int16_t scale2 = 2 * ((s_packed_2 >> 28) & 0xF) + 1;

            sum += scalar_sum1 * scale1 + scalar_sum2 * scale2;
            q2_ptr += 16;
        }
        sumf += sum * combined_scale;
    }
    // 0.125f matches the fixed-point factor of the generic iq2_xxs path.
    *s = 0.125f * sumf;
}
|
|
2821
|
+
|
|
2822
|
+
/**
 * Dot product of an iq2_xxs row with a q8_K row.
 *
 * Dispatches at runtime on the RVV hardware vector length: the VLEN-128
 * kernel for 128-bit vectors, otherwise the VLEN-256 kernel (RVV mandates
 * VLEN >= 128, and the vl256 kernel only requires VLEN >= 256 worth of m1
 * capacity for larger vectors).
 *
 * BUG FIX: the non-RVV fallback previously called ggml_vec_dot_iq2_xxs_q8_K
 * itself, causing unbounded recursion (stack overflow) on builds without
 * __riscv_v_intrinsic. It now calls the generic implementation, matching
 * the sibling dispatchers (iq2_s, iq2_xs, iq3_s).
 */
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined __riscv_v_intrinsic
    switch (__riscv_vlenb() * 8) {
        case 128:
            ggml_vec_dot_iq2_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
            break;
        default:
            ggml_vec_dot_iq2_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
            break;
    }
#else
    ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
|
|
2836
|
+
|
|
2837
|
+
// Dot product of one iq3_s row (vx) with one q8_K row (vy), RVV kernel for
// VLEN == 256. iq3_s encodes each mini-block of 4 weights as a grid index
// built from one qs byte (8 low bits) plus one qh bit (bit 8), looking up
// 4 packed uint8 magnitudes in iq3s_grid; signs come from a separate signs
// array and sub-scales from 4-bit nibbles. n must be a multiple of QK_K.
static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq3_s * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

    const uint64_t * grid64 = (const uint64_t *)iq3s_grid;

    // --- Pre-load Constants ---
    // One shift per mini-block so every lane can isolate its own qh bit.
    const uint16_t qh_bit_shifts_arr[16] = {
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
    };
    // Sign-expansion tables shared with the iq2_s kernel (defined elsewhere
    // in this file).
    vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64);
    vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64);
    vuint16m1_t v_qh_shifts = __riscv_vle16_v_u16m1(qh_bit_shifts_arr, 16);

    float sumf = 0.0f;

    for (int i = 0; i < nb; ++i) {
        // Per-super-block scale: fp16 iq3_s scale times q8_K float scale.
        const float d = GGML_CPU_FP16_TO_FP32(x[i].d);
        const float combined_scale = d * y[i].d;

        const uint8_t * GGML_RESTRICT qs = x[i].qs;
        const uint8_t * GGML_RESTRICT qh = x[i].qh;
        const uint8_t * GGML_RESTRICT scales = x[i].scales;
        const uint8_t * GGML_RESTRICT signs = x[i].signs;
        const int8_t * GGML_RESTRICT q8 = y[i].qs;

        float sum_block = 0.0f;

        // Loop: Process 64 weights (16 mini-blocks of 4) per iteration
        for (int ib = 0; ib < 4; ++ib) {

            vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 16);
            qs += 16;

            uint16_t qh_val;
            memcpy(&qh_val, qh, 2); // unaligned-safe 2-byte load
            qh += 2;

            vuint16m1_t v_qh_val = __riscv_vmv_v_x_u16m1(qh_val, 16);
            // Extract bits: (qh >> i) & 1
            v_qh_val = __riscv_vsrl_vv_u16m1(v_qh_val, v_qh_shifts, 16);
            v_qh_val = __riscv_vand_vx_u16m1(v_qh_val, 1, 16);

            // Build byte offsets into the grid: qs index * 4 (entries are
            // 4 bytes wide) with the qh bit contributing bit 10 (256 * 4).
            vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 16);
            v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 16);
            v_qh_val = __riscv_vsll_vx_u16m1(v_qh_val, 10, 16);
            vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_val, 16);

            // Grid value is 4xuint8
            vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2((const uint32_t *)grid64, v_grid_offsets, 16);
            vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed);
            vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs, 8);
            signs += 8;

            // Generate sign mask: broadcast each sign byte to its 8 lanes,
            // isolate each lane's own bit; nonzero => negate that q8 element.
            vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);
            vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64);
            vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64);
            vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64);

            vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64);
            q8 += 64;

            // Apply Signs (masked reverse-subtract from 0 == negation),
            // then widen signed*unsigned multiply against grid magnitudes.
            vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64);
            vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 64);

            // Reduction: two 32-element halves, one per sub-scale nibble.
            vint16m2_t v_dot_lo = __riscv_vget_v_i16m4_i16m2(v_dot, 0);
            vint16m2_t v_dot_hi = __riscv_vget_v_i16m4_i16m2(v_dot, 1);
            vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);

            int32_t s_lo = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_lo, v_zero, 32));
            int32_t s_hi = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_hi, v_zero, 32));

            // Apply sub-scales: nibble sc -> effective scale (2*sc + 1).
            uint8_t sc_byte = *scales++;
            int sc_lo = (sc_byte & 0xF) * 2 + 1;
            int sc_hi = (sc_byte >> 4) * 2 + 1;

            sum_block += s_lo * sc_lo + s_hi * sc_hi;
        }
        sumf += sum_block * combined_scale;
    }
    // Unlike the iq2 kernels there is no 0.125f factor here; iq3_s grid
    // values appear to be stored unscaled (matches the generic path).
    *s = sumf;
}
|
|
2931
|
+
|
|
2932
|
+
/**
 * Dot product of an iq3_s row with a q8_K row.
 *
 * Uses the VLEN-256 RVV kernel when the runtime vector length is exactly
 * 256 bits; all other vector lengths, and builds without RVV intrinsics,
 * go through the generic implementation.
 */
void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined __riscv_v_intrinsic
    const size_t vlen_bits = __riscv_vlenb() * 8;
    if (vlen_bits == 256) {
        ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
    } else {
        ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
    }
#else
    ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
|
|
2946
|
+
|
|
2947
|
+
static void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
|
2948
|
+
assert(n % QK_K == 0);
|
|
2949
|
+
assert(nrc == 1);
|
|
2950
|
+
UNUSED(nrc);
|
|
2951
|
+
UNUSED(bx);
|
|
2952
|
+
UNUSED(by);
|
|
2953
|
+
UNUSED(bs);
|
|
2954
|
+
|
|
2955
|
+
const block_iq3_xxs * GGML_RESTRICT x = vx;
|
|
2956
|
+
const block_q8_K * GGML_RESTRICT y = vy;
|
|
2957
|
+
const int nb = n / QK_K;
|
|
2958
|
+
|
|
2959
|
+
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
|
|
2960
|
+
const uint32_t * grid32 = (const uint32_t *)iq3xxs_grid;
|
|
2961
|
+
|
|
2962
|
+
// constants for unpacking logic
|
|
2963
|
+
const uint32_t shifts_val[8] = {0, 7, 14, 21, 0, 7, 14, 21};
|
|
2964
|
+
vuint32m1_t v_shifts = __riscv_vle32_v_u32m1(shifts_val, 8);
|
|
2965
|
+
|
|
2966
|
+
const uint32_t gather_idx_val[8] = {0, 0, 0, 0, 1, 1, 1, 1};
|
|
2967
|
+
vuint32m1_t v_gather_idx = __riscv_vle32_v_u32m1(gather_idx_val, 8);
|
|
2968
|
+
|
|
2969
|
+
uint32_t aux32[2];
|
|
2970
|
+
float sumf = 0.0f;
|
|
2971
|
+
|
|
2972
|
+
for (int i = 0; i < nb; ++i) {
|
|
2973
|
+
const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
|
|
2974
|
+
|
|
2975
|
+
const uint8_t * GGML_RESTRICT q3_indices = x[i].qs;
|
|
2976
|
+
const uint8_t * GGML_RESTRICT metadata = x[i].qs + QK_K/4;
|
|
2977
|
+
const int8_t * GGML_RESTRICT q8 = y[i].qs;
|
|
2978
|
+
|
|
2979
|
+
float block_sum = 0.0f;
|
|
2980
|
+
|
|
2981
|
+
for (int ib = 0; ib < QK_K / 64; ++ib) {
|
|
2982
|
+
// Load q8 (64 bytes)
|
|
2983
|
+
vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64);
|
|
2984
|
+
q8 += 64;
|
|
2985
|
+
|
|
2986
|
+
// load of metadata via memcpy
|
|
2987
|
+
memcpy(aux32, metadata, 2 * sizeof(uint32_t));
|
|
2988
|
+
metadata += 2 * sizeof(uint32_t);
|
|
2989
|
+
|
|
2990
|
+
// Load q3 indices and gather magnitudes
|
|
2991
|
+
vuint8mf2_t v_q3_idx_u8 = __riscv_vle8_v_u8mf2(q3_indices, 16);
|
|
2992
|
+
q3_indices += 16;
|
|
2993
|
+
|
|
2994
|
+
vuint16m1_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m1(v_q3_idx_u8, 4, 16);
|
|
2995
|
+
vuint32m2_t v_q3_magnitudes_u32 = __riscv_vluxei16_v_u32m2(grid32, v_q3_idx_u16, 16);
|
|
2996
|
+
vint8m2_t v_q3_magnitudes = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u32m2_u8m2(v_q3_magnitudes_u32));
|
|
2997
|
+
|
|
2998
|
+
// --- Unpacking of Sign Indices ---
|
|
2999
|
+
|
|
3000
|
+
// 1. Load the 2 auxiliary 32-bit integers into a vector
|
|
3001
|
+
vuint32m1_t v_aux = __riscv_vle32_v_u32m1(aux32, 2);
|
|
3002
|
+
|
|
3003
|
+
// 2. Broadcast/Gather: replicate aux[0] to first 4 lanes, aux[1] to next 4 lanes
|
|
3004
|
+
vuint32m1_t v_aux_expanded = __riscv_vrgather_vv_u32m1(v_aux, v_gather_idx, 8);
|
|
3005
|
+
|
|
3006
|
+
// 3. Apply Shifts and Mask: ((val >> shift) & 127)
|
|
3007
|
+
vuint32m1_t v_s_vals_raw = __riscv_vand_vx_u32m1(__riscv_vsrl_vv_u32m1(v_aux_expanded, v_shifts, 8), 127, 8);
|
|
3008
|
+
|
|
3009
|
+
// 4. Narrow to u16 (required for vluxei index) and multiply by 8 (byte offset for u64 table)
|
|
3010
|
+
vuint16mf2_t sign_indices_byte_offset = __riscv_vsll_vx_u16mf2(__riscv_vncvt_x_x_w_u16mf2(v_s_vals_raw, 8), 3, 8);
|
|
3011
|
+
|
|
3012
|
+
// 5. Gather Signs
|
|
3013
|
+
vuint64m2_t v_s_vals_u64 = __riscv_vluxei16_v_u64m2(signs64, sign_indices_byte_offset, 8);
|
|
3014
|
+
vint8m2_t v_s_vals = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(v_s_vals_u64));
|
|
3015
|
+
|
|
3016
|
+
vint8m2_t v_q3_signed = __riscv_vmul_vv_i8m2(v_q3_magnitudes, v_s_vals, 64);
|
|
3017
|
+
vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_q8, v_q3_signed, 64);
|
|
3018
|
+
|
|
3019
|
+
vint16m2_t v_dot_1 = __riscv_vget_v_i16m4_i16m2(v_dot, 0);
|
|
3020
|
+
vint16m2_t v_dot_2 = __riscv_vget_v_i16m4_i16m2(v_dot, 1);
|
|
3021
|
+
|
|
3022
|
+
vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);
|
|
3023
|
+
vint32m1_t v_sum_1 = __riscv_vwredsum_vs_i16m2_i32m1(v_dot_1, v_zero, 32);
|
|
3024
|
+
vint32m1_t v_sum_2 = __riscv_vwredsum_vs_i16m2_i32m1(v_dot_2, v_zero, 32);
|
|
3025
|
+
|
|
3026
|
+
int32_t sum1_i = __riscv_vmv_x_s_i32m1_i32(v_sum_1);
|
|
3027
|
+
int32_t sum2_i = __riscv_vmv_x_s_i32m1_i32(v_sum_2);
|
|
3028
|
+
|
|
3029
|
+
const float scale1_f = (float)(2 * (aux32[0] >> 28) + 1);
|
|
3030
|
+
const float scale2_f = (float)(2 * (aux32[1] >> 28) + 1);
|
|
3031
|
+
|
|
3032
|
+
block_sum += sum1_i * scale1_f + sum2_i * scale2_f;
|
|
3033
|
+
}
|
|
3034
|
+
|
|
3035
|
+
sumf += d * block_sum;
|
|
3036
|
+
}
|
|
3037
|
+
*s = 0.25f * sumf;
|
|
3038
|
+
}
|
|
3039
|
+
|
|
3040
|
+
// Public entry point: route the IQ3_XXS x Q8_K dot product to the VLEN=256 RVV
// kernel when the hardware vector length matches, otherwise use the portable
// scalar implementation.
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined __riscv_v_intrinsic
    const size_t vlen_bits = __riscv_vlenb() * 8;
    if (vlen_bits == 256) {
        ggml_vec_dot_iq3_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
    } else {
        ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
    }
#else
    ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
|
|
3054
|
+
|
|
3055
|
+
// Dot product of one IQ4_NL row against one Q8_0 row, specialized for RVV with VLEN=128.
// Each IQ4_NL block packs 32 4-bit indices into 16 bytes (low nibbles = elements 0..15,
// high nibbles = elements 16..31); each index selects a value from kvalues_iq4nl.
// Fixes vs. previous version:
//  - iq4bits1/iq4bits2 were read uninitialized by the first __riscv_vset; they are now
//    seeded with __riscv_vundefined_u8m2().
//  - when nb is odd the pairwise loop dropped the final block; a scalar tail now handles it.
static void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK4_NL == 0);
    static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");

    const block_iq4_nl * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    const int nb = n / QK4_NL;

    int ib = 0;
    float sumf = 0;

    // Load the lookup table once.
    const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16);
    int acc1, acc2;

    // We process 2 blocks at once.
    for (; ib + 1 < nb; ib += 2) {
        // Weights and activations.
        vuint8m1_t iq4_packed1 = __riscv_vle8_v_u8m1(x[ib + 0].qs, 16);
        vint8m2_t q8b1 = __riscv_vle8_v_i8m2(y[ib + 0].qs, 32);
        vuint8m1_t iq4_packed2 = __riscv_vle8_v_u8m1(x[ib + 1].qs, 16);
        vint8m2_t q8b2 = __riscv_vle8_v_i8m2(y[ib + 1].qs, 32);

        // Unpack the weight blocks (vundefined seeds the register group so the
        // first vset does not read an uninitialized value).
        vuint8m2_t iq4bits1 = __riscv_vundefined_u8m2();
        iq4bits1 = __riscv_vset_v_u8m1_u8m2(iq4bits1, 0, __riscv_vand_vx_u8m1(iq4_packed1, 0xf, 16));
        iq4bits1 = __riscv_vset_v_u8m1_u8m2(iq4bits1, 1, __riscv_vsrl_vx_u8m1(iq4_packed1, 4, 16));
        vuint8m2_t iq4bits2 = __riscv_vundefined_u8m2();
        iq4bits2 = __riscv_vset_v_u8m1_u8m2(iq4bits2, 0, __riscv_vand_vx_u8m1(iq4_packed2, 0xf, 16));
        iq4bits2 = __riscv_vset_v_u8m1_u8m2(iq4bits2, 1, __riscv_vsrl_vx_u8m1(iq4_packed2, 4, 16));

        // Gather values from the lookup table.
        vint8m2_t iq4b1 = __riscv_vrgather_vv_i8m2(values, iq4bits1, 32);
        vint8m2_t iq4b2 = __riscv_vrgather_vv_i8m2(values, iq4bits2, 32);

        // Accumulation: widening multiply then widening reduction to a scalar.
        vint16m4_t sum1 = __riscv_vwmul_vv_i16m4(q8b1, iq4b1, 32);
        vint16m4_t sum2 = __riscv_vwmul_vv_i16m4(q8b2, iq4b2, 32);
        __riscv_vse32_v_i32m1(&acc1, __riscv_vwredsum_vs_i16m4_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 32), 1);
        __riscv_vse32_v_i32m1(&acc2, __riscv_vwredsum_vs_i16m4_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 32), 1);
        sumf += GGML_CPU_FP16_TO_FP32(x[ib + 0].d) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1;
        sumf += GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2;
    }

    // Scalar tail: process the last block when nb is odd (previously dropped).
    for (; ib < nb; ++ib) {
        const float d = GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d);
        int sumi1 = 0;
        int sumi2 = 0;
        for (int j = 0; j < QK4_NL/2; ++j) {
            sumi1 += y[ib].qs[j +        0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
            sumi2 += y[ib].qs[j + QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >>  4];
        }
        sumf += d * (sumi1 + sumi2);
    }

    *s = sumf;
}
|
|
3107
|
+
|
|
3108
|
+
// Dot product of one IQ4_NL row against one Q8_0 row for RVV with VLEN>=256
// (fractional LMUL mf2 keeps each 16-element nibble group in one register).
// Fix vs. previous version: when nb is odd the pairwise loop silently dropped
// the final block; a scalar tail now handles it.
static void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK4_NL == 0);
    static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");

    const block_iq4_nl * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    const int nb = n / QK4_NL;

    int ib = 0;
    float sumf = 0;

    // Load the lookup table once.
    const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16);
    int acc1, acc2;

    // We process 2 blocks at once.
    for (; ib + 1 < nb; ib += 2) {
        // Weights and activations (low/high halves of each 32-byte q8 block).
        vuint8mf2_t iq4_packed1 = __riscv_vle8_v_u8mf2(x[ib + 0].qs, 16);
        vint8mf2_t q8b_lo1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs, 16);
        vint8mf2_t q8b_hi1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs + 16, 16);
        vuint8mf2_t iq4_packed2 = __riscv_vle8_v_u8mf2(x[ib + 1].qs, 16);
        vint8mf2_t q8b_lo2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs, 16);
        vint8mf2_t q8b_hi2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs + 16, 16);

        // Unpack the weight blocks: low nibbles are elements 0..15, high nibbles 16..31.
        vuint8mf2_t iq4bits_lo1 = __riscv_vand_vx_u8mf2(iq4_packed1, 0xf, 16);
        vuint8mf2_t iq4bits_hi1 = __riscv_vsrl_vx_u8mf2(iq4_packed1, 4, 16);
        vuint8mf2_t iq4bits_lo2 = __riscv_vand_vx_u8mf2(iq4_packed2, 0xf, 16);
        vuint8mf2_t iq4bits_hi2 = __riscv_vsrl_vx_u8mf2(iq4_packed2, 4, 16);

        // Gather values from the lookup table.
        vint8mf2_t iq4b_lo1 = __riscv_vrgather_vv_i8mf2(values, iq4bits_lo1, 16);
        vint8mf2_t iq4b_hi1 = __riscv_vrgather_vv_i8mf2(values, iq4bits_hi1, 16);
        vint8mf2_t iq4b_lo2 = __riscv_vrgather_vv_i8mf2(values, iq4bits_lo2, 16);
        vint8mf2_t iq4b_hi2 = __riscv_vrgather_vv_i8mf2(values, iq4bits_hi2, 16);

        // Accumulation: widening multiply-accumulate, then reduce to a scalar.
        vint16m1_t sum1 = __riscv_vwmul_vv_i16m1(q8b_lo1, iq4b_lo1, 16);
        sum1 = __riscv_vwmacc_vv_i16m1(sum1, q8b_hi1, iq4b_hi1, 16);
        vint16m1_t sum2 = __riscv_vwmul_vv_i16m1(q8b_lo2, iq4b_lo2, 16);
        sum2 = __riscv_vwmacc_vv_i16m1(sum2, q8b_hi2, iq4b_hi2, 16);
        __riscv_vse32_v_i32m1(&acc1, __riscv_vwredsum_vs_i16m1_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 16), 1);
        __riscv_vse32_v_i32m1(&acc2, __riscv_vwredsum_vs_i16m1_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 16), 1);
        sumf += GGML_CPU_FP16_TO_FP32(x[ib + 0].d) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1;
        sumf += GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2;
    }

    // Scalar tail: process the last block when nb is odd (previously dropped).
    for (; ib < nb; ++ib) {
        const float d = GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d);
        int sumi1 = 0;
        int sumi2 = 0;
        for (int j = 0; j < QK4_NL/2; ++j) {
            sumi1 += y[ib].qs[j +        0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
            sumi2 += y[ib].qs[j + QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >>  4];
        }
        sumf += d * (sumi1 + sumi2);
    }

    *s = sumf;
}
|
|
3164
|
+
|
|
3165
|
+
// Public entry point: pick the IQ4_NL x Q8_0 kernel by hardware vector length.
// VLEN=128 gets the m1/m2 kernel; any wider VLEN uses the mf2-based kernel
// (its explicit vl of 16 fits for all VLEN >= 256). Without RVV intrinsics the
// portable scalar implementation is used.
void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined __riscv_v_intrinsic
    const size_t vlen_bits = __riscv_vlenb() * 8;
    if (vlen_bits == 128) {
        ggml_vec_dot_iq4_nl_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc);
    } else {
        ggml_vec_dot_iq4_nl_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc);
    }
#else
    ggml_vec_dot_iq4_nl_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
|
|
3179
|
+
|
|
3180
|
+
// Dot product of one IQ4_XS row against one Q8_K row, specialized for RVV with VLEN=256.
// Processes 128 elements per inner iteration; nibbles are unpacked, reordered to match
// the q8 layout via a 64-bit element gather, mapped through kvalues_iq4nl, and reduced
// in four 32-element groups so each group can carry its own 6-bit scale
// (low 4 bits from scales_l, high 2 bits from scales_h).
// Fixes vs. previous version: iq4bits was read uninitialized by the first __riscv_vset
// (now seeded with __riscv_vundefined_u8m4()); the reorder table is static const so it
// is not rebuilt on every call.
static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK_K == 0);

    const block_iq4_xs * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __riscv_v_intrinsic
    const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16);
    float sumf = 0;
    int acc[4];

    // Indices for re-ordering IQ4 data: interleave the low-nibble and high-nibble
    // halves so 8-byte groups line up with the q8 element order.
    static const uint64_t index[16] = {
        0, 1,  8,  9,
        2, 3, 10, 11,
        4, 5, 12, 13,
        6, 7, 14, 15,
    };
    vuint64m4_t i_vec = __riscv_vle64_v_u64m4(index, 16);

    for (int ibl = 0; ibl < nb; ++ibl) {
        const int8_t * q8 = y[ibl].qs;
        const uint8_t * iq4 = x[ibl].qs;
        uint16_t h = x[ibl].scales_h;

        int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;

        for (int ib = 0; ib < QK_K / 128; ++ib) {
            // Weights and activations.
            vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 64);
            vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128);
            iq4 += 64;
            q8 += 128;

            // Unpack the weight blocks (vundefined seeds the register group so the
            // first vset does not read an uninitialized value).
            vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 64);
            vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 64);
            vuint8m4_t iq4bits = __riscv_vundefined_u8m4();
            iq4bits = __riscv_vset_v_u8m2_u8m4(iq4bits, 0, iq4bits_lo);
            iq4bits = __riscv_vset_v_u8m2_u8m4(iq4bits, 1, iq4bits_hi);
            // Reorder 8-byte groups, then gather values from the lookup table.
            vuint8m4_t iq4bits_reorder = __riscv_vreinterpret_v_u64m4_u8m4(__riscv_vrgather_vv_u64m4(__riscv_vreinterpret_v_u8m4_u64m4(iq4bits), i_vec, 16));
            vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 128);

            // Multiply with activations.
            vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 128);

            // Reduce the four 32-element groups separately (each has its own scale).
            __riscv_vse32_v_i32m1(&acc[0], __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32), 1);
            __riscv_vse32_v_i32m1(&acc[1], __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32), 1);
            __riscv_vse32_v_i32m1(&acc[2], __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32), 1);
            __riscv_vse32_v_i32m1(&acc[3], __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32), 1);

            // 6-bit signed scales: low 4 bits from scales_l nibbles, high 2 from h.
            int ls1 = ((x[ibl].scales_l[ib * 2 + 0] & 0xf) | ((h << 4) & 0x30)) - 32;
            int ls2 = ((x[ibl].scales_l[ib * 2 + 0] >> 4) | ((h << 2) & 0x30)) - 32;
            int ls3 = ((x[ibl].scales_l[ib * 2 + 1] & 0xf) | ((h << 0) & 0x30)) - 32;
            int ls4 = ((x[ibl].scales_l[ib * 2 + 1] >> 4) | ((h >> 2) & 0x30)) - 32;
            h >>= 8;

            sumi1 += acc[0] * ls1;
            sumi2 += acc[1] * ls2;
            sumi3 += acc[2] * ls3;
            sumi4 += acc[3] * ls4;
        }

        sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2 + sumi3 + sumi4);
    }

    *s = sumf;

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
|
|
3263
|
+
|
|
3264
|
+
// Public entry point: route the IQ4_XS x Q8_K dot product to the VLEN=256 RVV
// kernel when the hardware vector length matches, otherwise fall back to the
// portable scalar implementation.
void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined __riscv_v_intrinsic
    const size_t vlen_bits = __riscv_vlenb() * 8;
    if (vlen_bits == 256) {
        ggml_vec_dot_iq4_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
    } else {
        ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
    }
#else
    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
|
|
3278
|
+
|
|
3279
|
+
// Dot product of one TQ1_0 (ternary, base-3 packed) row against one Q8_K row for
// RVV with VLEN=256. Each packed byte holds five base-3 digits; digit k is
// extracted as ((q * 3^k * 3) >> 8), which yields 0/1/2, then shifted to -1/0/+1
// by subtracting 1. The three stages below cover the three packing regions of
// block_tq1_0 (qs[0..31] -> 160 elems, qs[32..47] -> 80 elems, qh -> last 16).
static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_tq1_0 * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

    float sumf = 0.0f;
    // per-lane powers of 3 for the qh stage: four lanes per digit position
    uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};

    for (int i = 0; i < nb; i++) {
        // First loop: qs[0..31], digits 0..4 -> elements 0..159.
        vint32m4_t suml1;
        {
            const int vl = 32;
            vuint8m1_t tq = __riscv_vle8_v_u8m1(x[i].qs, vl);

            // (q * 3^k * 3) >> 8 extracts base-3 digit k as 0/1/2
            vuint16m2_t tq0 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(tq, 3, vl), 8, vl);
            vuint16m2_t tq1 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 3, vl), 3, vl), 8, vl);
            vuint16m2_t tq2 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 9, vl), 3, vl), 8, vl);
            vuint16m2_t tq3 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 27, vl), 3, vl), 8, vl);
            vuint16m2_t tq4 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 81, vl), 3, vl), 8, vl);

            // widen the matching q8 activations to i16
            vint16m2_t q80 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 0, vl), vl);
            vint16m2_t q81 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 32, vl), vl);
            vint16m2_t q82 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 64, vl), vl);
            vint16m2_t q83 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 96, vl), vl);
            vint16m2_t q84 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 128, vl), vl);

            // (digit - 1) maps 0/1/2 -> -1/0/+1, then multiply by activation
            vint16m2_t sum0 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq0, 1, vl)), q80, vl);
            vint16m2_t sum1 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq1, 1, vl)), q81, vl);
            vint16m2_t sum2 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq2, 1, vl)), q82, vl);
            vint16m2_t sum3 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq3, 1, vl)), q83, vl);
            vint16m2_t sum4 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq4, 1, vl)), q84, vl);

            // widen and combine the five partial sums into one i32 vector
            vint32m4_t sumi0 = __riscv_vwadd_vv_i32m4(sum0, sum1, vl);
            vint32m4_t sumi1 = __riscv_vwadd_vv_i32m4(sum2, sum3, vl);
            suml1 = __riscv_vadd_vv_i32m4(__riscv_vwcvt_x_x_v_i32m4(sum4, vl), __riscv_vadd_vv_i32m4(sumi0, sumi1, vl), vl);
        }

        // Second loop: qs[32..47], digits 0..4 -> elements 160..239.
        vint32m2_t suml2;
        {
            const int vl = 16;
            vuint8mf2_t tq = __riscv_vle8_v_u8mf2(x[i].qs + 32, vl);

            vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(tq, 3 * 1, vl), 8, vl);
            vuint16m1_t tq1 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 3, vl), 3, vl), 8, vl);
            vuint16m1_t tq2 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 9, vl), 3, vl), 8, vl);
            vuint16m1_t tq3 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 27, vl), 3, vl), 8, vl);
            vuint16m1_t tq4 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 81, vl), 3, vl), 8, vl);

            vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 160, vl), vl);
            vint16m1_t q81 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 176, vl), vl);
            vint16m1_t q82 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 192, vl), vl);
            vint16m1_t q83 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 208, vl), vl);
            vint16m1_t q84 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 224, vl), vl);

            vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl);
            vint16m1_t sum1 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq1, 1, vl)), q81, vl);
            vint16m1_t sum2 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq2, 1, vl)), q82, vl);
            vint16m1_t sum3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq3, 1, vl)), q83, vl);
            vint16m1_t sum4 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq4, 1, vl)), q84, vl);

            vint32m2_t sumi0 = __riscv_vwadd_vv_i32m2(sum0, sum1, vl);
            vint32m2_t sumi1 = __riscv_vwadd_vv_i32m2(sum2, sum3, vl);
            suml2 = __riscv_vadd_vv_i32m2(__riscv_vwcvt_x_x_v_i32m2(sum4, vl), __riscv_vadd_vv_i32m2(sumi0, sumi1, vl), vl);
        }

        // Third loop: qh bytes -> final 16 elements (one digit per lane,
        // per-lane power of 3 supplied by `pow`).
        vint32m2_t suml3;
        {
            const int vl = 16;

            uint32_t qh;
            memcpy(&qh, &x[i].qh[0], 4);
            // Prevent fusion with vmv.
            __asm__ __volatile__("" : "+r"(qh));
            // broadcast the 4 qh bytes across all 16 lanes (as 4 replicated u32s)
            vuint8mf2_t tq = __riscv_vreinterpret_v_u32mf2_u8mf2(__riscv_vmv_v_x_u32mf2(qh, vl / 4));

            vuint8mf2_t p = __riscv_vle8_v_u8mf2(pow, vl);

            vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vv_u8mf2(tq, p, vl), 3, vl), 8, vl);

            vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 240, vl), vl);

            vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl);
            suml3 = __riscv_vwcvt_x_x_v_i32m2(sum0, vl);
        }

        // Fold the three stage sums and reduce to a single i32.
        vint32m2_t sumb = __riscv_vadd_vv_i32m2(__riscv_vget_v_i32m4_i32m2(suml1, 0), __riscv_vget_v_i32m4_i32m2(suml1, 1), 16);
        sumb = __riscv_vadd_vv_i32m2(sumb, suml2, 16);
        sumb = __riscv_vadd_vv_i32m2(sumb, suml3, 16);

        vint32m1_t sum = __riscv_vredsum_vs_i32m2_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16);
        sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
    }

    *s = sumf;
}
|
|
3384
|
+
|
|
3385
|
+
// Public entry point: route the TQ1_0 x Q8_K dot product to the VLEN=256 RVV
// kernel when the hardware vector length matches; otherwise use the portable
// scalar implementation.
void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined __riscv_v_intrinsic
    const size_t vlen_bits = __riscv_vlenb() * 8;
    if (vlen_bits == 256) {
        ggml_vec_dot_tq1_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
    } else {
        ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
    }
#else
    ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
|
|
3399
|
+
|
|
3400
|
+
// Dot product of one TQ2_0 (ternary, 2-bit packed) row against one Q8_K row for
// RVV with VLEN=256. Each packed byte holds four 2-bit fields; field l (bits
// 2l+1:2l) encodes 0/1/2, shifted to -1/0/+1 by subtracting 1. Field l of byte
// group j corresponds to q8 elements at offset j*4 + l*32.
static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_tq2_0 * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

    float sumf = 0.0f;
    for (int i = 0; i < nb; ++i) {
        int32_t sumi = 0;

        // 32 packed bytes (= 128 ternary elements) per iteration
        for (size_t j = 0; j < sizeof(x[0].qs); j += 32) {
            // four q8 spans matched to the four 2-bit fields of this byte group
            const int8_t * py0 = &y[i].qs[j * 4 + 0 * 32];
            const int8_t * py1 = &y[i].qs[j * 4 + 1 * 32];
            const int8_t * py2 = &y[i].qs[j * 4 + 2 * 32];
            const int8_t * py3 = &y[i].qs[j * 4 + 3 * 32];
            const uint8_t* px = &x[i].qs[j];

            size_t vlmax_16m2 = __riscv_vsetvl_e16m2(32);
            // i16 accumulator: 4 products of at most +/-127 per lane cannot overflow
            vint16m2_t vacc16 = __riscv_vmv_v_x_i16m2(0, vlmax_16m2);

            size_t vl = __riscv_vsetvl_e8m1(32);

            vuint8m1_t vx_u8 = __riscv_vle8_v_u8m1(px, vl);

            vint8m1_t vy0 = __riscv_vle8_v_i8m1(py0, vl);
            vint8m1_t vy1 = __riscv_vle8_v_i8m1(py1, vl);
            vint8m1_t vy2 = __riscv_vle8_v_i8m1(py2, vl);
            vint8m1_t vy3 = __riscv_vle8_v_i8m1(py3, vl);

            // l=0 (bits 1:0)
            vuint8m1_t t0 = __riscv_vand_vx_u8m1(vx_u8, 0x03, vl);
            vint8m1_t vq0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t0), 1, vl);

            // l=1 (bits 3:2)
            vuint8m1_t t1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 2, vl), 0x03, vl);
            vint8m1_t vq1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t1), 1, vl);

            // l=2 (bits 5:4)
            vuint8m1_t t2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 4, vl), 0x03, vl);
            vint8m1_t vq2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t2), 1, vl);

            // l=3 (bits 7:6)
            vuint8m1_t t3 = __riscv_vsrl_vx_u8m1(vx_u8, 6, vl); // No final AND needed as vsrl shifts in zeros
            vint8m1_t vq3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t3), 1, vl);

            // 4. Multiply and accumulate (widening 8x8 -> 16-bit MAC)
            vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq0, vy0, vl);
            vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq1, vy1, vl);
            vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq2, vy2, vl);
            vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq3, vy3, vl);

            // widening reduction of the i16 accumulator to a single i32
            vlmax_16m2 = __riscv_vsetvl_e16m2(32);
            vint32m1_t vzero32 = __riscv_vmv_v_x_i32m1(0, 1);
            vint32m1_t vred32 = __riscv_vwredsum_vs_i16m2_i32m1(vacc16, vzero32, vlmax_16m2);

            sumi += __riscv_vmv_x_s_i32m1_i32(vred32);
        }
        // scale by the per-block combined scale
        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
        sumf += (float)sumi * d;
    }

    *s = sumf;
}
|
|
3470
|
+
|
|
3471
|
+
// Public entry point: route the TQ2_0 x Q8_K dot product to the VLEN=256 RVV
// kernel when the hardware vector length matches; otherwise use the portable
// scalar implementation.
void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined __riscv_v_intrinsic
    const size_t vlen_bits = __riscv_vlenb() * 8;
    if (vlen_bits == 256) {
        ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
    } else {
        ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
    }
#else
    ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
|
|
3485
|
+
|
|
3486
|
+
// Dot product of one MXFP4 row against one Q8_0 row, specialized for RVV with VLEN=128.
// Each MXFP4 block packs 32 4-bit indices into 16 bytes (low nibbles first) mapped
// through kvalues_mxfp4; the block scale is an E8M0 exponent (x[ib].e).
// Fixes vs. previous version:
//  - mxbits1/mxbits2 were read uninitialized by the first __riscv_vset; they are now
//    seeded with __riscv_vundefined_u8m2().
//  - when nb is odd the pairwise loop dropped the final block; a scalar tail now handles it.
static void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK_MXFP4 == 0);
    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");

    const block_mxfp4 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    const int nb = n / QK_MXFP4;

    int ib = 0;
    float sumf = 0;

    // Load the lookup table once.
    const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_mxfp4, 16);
    int acc1, acc2;

    // We process 2 blocks at once.
    for (; ib + 1 < nb; ib += 2) {
        // Weights and activations.
        vuint8m1_t mx_packed1 = __riscv_vle8_v_u8m1(x[ib + 0].qs, 16);
        vint8m2_t q8b1 = __riscv_vle8_v_i8m2(y[ib + 0].qs, 32);
        vuint8m1_t mx_packed2 = __riscv_vle8_v_u8m1(x[ib + 1].qs, 16);
        vint8m2_t q8b2 = __riscv_vle8_v_i8m2(y[ib + 1].qs, 32);

        // Unpack the weight blocks (vundefined seeds the register group so the
        // first vset does not read an uninitialized value).
        vuint8m2_t mxbits1 = __riscv_vundefined_u8m2();
        mxbits1 = __riscv_vset_v_u8m1_u8m2(mxbits1, 0, __riscv_vand_vx_u8m1(mx_packed1, 0xf, 16));
        mxbits1 = __riscv_vset_v_u8m1_u8m2(mxbits1, 1, __riscv_vsrl_vx_u8m1(mx_packed1, 4, 16));
        vuint8m2_t mxbits2 = __riscv_vundefined_u8m2();
        mxbits2 = __riscv_vset_v_u8m1_u8m2(mxbits2, 0, __riscv_vand_vx_u8m1(mx_packed2, 0xf, 16));
        mxbits2 = __riscv_vset_v_u8m1_u8m2(mxbits2, 1, __riscv_vsrl_vx_u8m1(mx_packed2, 4, 16));

        // Gather values from the lookup table.
        vint8m2_t mxb1 = __riscv_vrgather_vv_i8m2(values, mxbits1, 32);
        vint8m2_t mxb2 = __riscv_vrgather_vv_i8m2(values, mxbits2, 32);

        // Accumulation: widening multiply then widening reduction to a scalar.
        vint16m4_t sum1 = __riscv_vwmul_vv_i16m4(q8b1, mxb1, 32);
        vint16m4_t sum2 = __riscv_vwmul_vv_i16m4(q8b2, mxb2, 32);
        __riscv_vse32_v_i32m1(&acc1, __riscv_vwredsum_vs_i16m4_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 32), 1);
        __riscv_vse32_v_i32m1(&acc2, __riscv_vwredsum_vs_i16m4_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 32), 1);
        sumf += GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1;
        sumf += GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2;
    }

    // Scalar tail: process the last block when nb is odd (previously dropped).
    for (; ib < nb; ++ib) {
        const float d = GGML_E8M0_TO_FP32_HALF(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d);
        int sumi1 = 0;
        int sumi2 = 0;
        for (int j = 0; j < QK_MXFP4/2; ++j) {
            sumi1 += y[ib].qs[j +          0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >>  4];
        }
        sumf += d * (sumi1 + sumi2);
    }

    *s = sumf;
}
|
|
3538
|
+
|
|
3539
|
+
// Dot product of one MXFP4 row against one Q8_0 row for RVV with VLEN>=256
// (fractional LMUL mf2 keeps each 16-element nibble group in one register).
// Fix vs. previous version: when nb is odd the pairwise loop silently dropped
// the final block; a scalar tail now handles it.
static void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK_MXFP4 == 0);
    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");

    const block_mxfp4 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    const int nb = n / QK_MXFP4;

    int ib = 0;
    float sumf = 0;

    // Load the lookup table once.
    const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_mxfp4, 16);
    int acc1, acc2;

    // We process 2 blocks at once.
    for (; ib + 1 < nb; ib += 2) {
        // Weights and activations (low/high halves of each 32-byte q8 block).
        vuint8mf2_t mx_packed1 = __riscv_vle8_v_u8mf2(x[ib + 0].qs, 16);
        vint8mf2_t q8b_lo1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs, 16);
        vint8mf2_t q8b_hi1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs + 16, 16);
        vuint8mf2_t mx_packed2 = __riscv_vle8_v_u8mf2(x[ib + 1].qs, 16);
        vint8mf2_t q8b_lo2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs, 16);
        vint8mf2_t q8b_hi2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs + 16, 16);

        // Unpack the weight blocks: low nibbles are elements 0..15, high nibbles 16..31.
        vuint8mf2_t mxbits_lo1 = __riscv_vand_vx_u8mf2(mx_packed1, 0xf, 16);
        vuint8mf2_t mxbits_hi1 = __riscv_vsrl_vx_u8mf2(mx_packed1, 4, 16);
        vuint8mf2_t mxbits_lo2 = __riscv_vand_vx_u8mf2(mx_packed2, 0xf, 16);
        vuint8mf2_t mxbits_hi2 = __riscv_vsrl_vx_u8mf2(mx_packed2, 4, 16);

        // Gather values from the lookup table.
        vint8mf2_t mxb_lo1 = __riscv_vrgather_vv_i8mf2(values, mxbits_lo1, 16);
        vint8mf2_t mxb_hi1 = __riscv_vrgather_vv_i8mf2(values, mxbits_hi1, 16);
        vint8mf2_t mxb_lo2 = __riscv_vrgather_vv_i8mf2(values, mxbits_lo2, 16);
        vint8mf2_t mxb_hi2 = __riscv_vrgather_vv_i8mf2(values, mxbits_hi2, 16);

        // Accumulation: widening multiply-accumulate, then reduce to a scalar.
        vint16m1_t sum1 = __riscv_vwmul_vv_i16m1(q8b_lo1, mxb_lo1, 16);
        sum1 = __riscv_vwmacc_vv_i16m1(sum1, q8b_hi1, mxb_hi1, 16);
        vint16m1_t sum2 = __riscv_vwmul_vv_i16m1(q8b_lo2, mxb_lo2, 16);
        sum2 = __riscv_vwmacc_vv_i16m1(sum2, q8b_hi2, mxb_hi2, 16);
        __riscv_vse32_v_i32m1(&acc1, __riscv_vwredsum_vs_i16m1_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 16), 1);
        __riscv_vse32_v_i32m1(&acc2, __riscv_vwredsum_vs_i16m1_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 16), 1);
        sumf += GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1;
        sumf += GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2;
    }

    // Scalar tail: process the last block when nb is odd (previously dropped).
    for (; ib < nb; ++ib) {
        const float d = GGML_E8M0_TO_FP32_HALF(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d);
        int sumi1 = 0;
        int sumi2 = 0;
        for (int j = 0; j < QK_MXFP4/2; ++j) {
            sumi1 += y[ib].qs[j +          0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >>  4];
        }
        sumf += d * (sumi1 + sumi2);
    }

    *s = sumf;
}
|
|
3595
|
+
|
|
3596
|
+
// Public entry point: pick the MXFP4 x Q8_0 kernel by hardware vector length.
// Fix vs. previous version: the fallback used `return expr;` in a void function,
// which is an ISO C constraint violation (C11 6.8.6.4) and inconsistent with the
// sibling dispatchers — the `return` keyword is dropped.
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined __riscv_v_intrinsic
    switch (__riscv_vlenb() * 8) {
        case 128:
            ggml_vec_dot_mxfp4_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc);
            break;
        default:
            ggml_vec_dot_mxfp4_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc);
            break;
    }
#else
    ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
|