whispercpp 1.3.5 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +99 -2
- data/ext/extconf.rb +1 -0
- data/ext/ruby_whisper.c +20 -4
- data/ext/ruby_whisper.h +30 -2
- data/ext/ruby_whisper_context.c +216 -124
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_params.c +0 -1
- data/ext/ruby_whisper_segment.c +0 -1
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +4 -1
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +8 -0
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/server/server.cpp +18 -4
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
- data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
- data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
- data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
- data/ext/sources/examples/talk-llama/llama-context.h +27 -28
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
- data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
- data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
- data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
- data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
- data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
- data/ext/sources/examples/talk-llama/llama-model.h +72 -19
- data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
- data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
- data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -22
- data/ext/sources/examples/talk-llama/llama.h +63 -30
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
- data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
- data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
- data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
- data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
- data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/models.h +181 -46
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
- data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
- data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
- data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
- data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
- data/ext/sources/ggml/CMakeLists.txt +9 -3
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +6 -1
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +56 -9
- data/ext/sources/ggml/src/CMakeLists.txt +3 -0
- data/ext/sources/ggml/src/ggml-alloc.c +4 -9
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
- data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
- data/ext/sources/ggml/src/ggml-impl.h +62 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
- data/ext/sources/ggml/src/ggml.c +167 -33
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/src/whisper.cpp +6 -28
- data/sig/whisper.rbs +43 -2
- data/test/test_context_params.rb +82 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_whisper.rb +20 -0
- data/whispercpp.gemspec +1 -1
- metadata +240 -28
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
|
@@ -24,6 +24,94 @@
|
|
|
24
24
|
|
|
25
25
|
#define UNUSED GGML_UNUSED
|
|
26
26
|
|
|
27
|
+
void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
|
28
|
+
assert(QK8_0 == 32);
|
|
29
|
+
assert(k % QK8_0 == 0);
|
|
30
|
+
const int nb = k / QK8_0;
|
|
31
|
+
|
|
32
|
+
#if defined(__riscv_v_intrinsic)
|
|
33
|
+
block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
|
|
34
|
+
const size_t vl_calc = __riscv_vsetvl_e32m8(QK8_0);
|
|
35
|
+
const size_t vl_save = __riscv_vsetvl_e64m2(4);
|
|
36
|
+
vfloat32m1_t v_scalar_zero = __riscv_vfmv_s_f_f32m1(0.0f, __riscv_vsetvl_e32m1(1));
|
|
37
|
+
|
|
38
|
+
for (int i = 0; i < nb; i++) {
|
|
39
|
+
const float *x_block_base = x + i * QK8_0;
|
|
40
|
+
vint8m2_t q_r0, q_r1, q_r2, q_r3;
|
|
41
|
+
{
|
|
42
|
+
vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 0 * k, vl_calc);
|
|
43
|
+
vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);
|
|
44
|
+
vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);
|
|
45
|
+
float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);
|
|
46
|
+
|
|
47
|
+
float d = amax / 127.0f;
|
|
48
|
+
y[i].d[0] = GGML_CPU_FP32_TO_FP16(d);
|
|
49
|
+
|
|
50
|
+
float id = d ? 1.0f / d : 0.0f;
|
|
51
|
+
vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);
|
|
52
|
+
vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);
|
|
53
|
+
q_r0 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);
|
|
54
|
+
}
|
|
55
|
+
asm volatile ("" ::: "memory");
|
|
56
|
+
|
|
57
|
+
{
|
|
58
|
+
vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 1 * k, vl_calc);
|
|
59
|
+
vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);
|
|
60
|
+
vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);
|
|
61
|
+
float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);
|
|
62
|
+
|
|
63
|
+
float d = amax / 127.0f;
|
|
64
|
+
y[i].d[1] = GGML_CPU_FP32_TO_FP16(d);
|
|
65
|
+
float id = d ? 1.0f / d : 0.0f;
|
|
66
|
+
|
|
67
|
+
vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);
|
|
68
|
+
vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);
|
|
69
|
+
q_r1 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);
|
|
70
|
+
}
|
|
71
|
+
asm volatile ("" ::: "memory");
|
|
72
|
+
{
|
|
73
|
+
vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 2 * k, vl_calc);
|
|
74
|
+
vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);
|
|
75
|
+
vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);
|
|
76
|
+
float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);
|
|
77
|
+
|
|
78
|
+
float d = amax / 127.0f;
|
|
79
|
+
y[i].d[2] = GGML_CPU_FP32_TO_FP16(d);
|
|
80
|
+
float id = d ? 1.0f / d : 0.0f;
|
|
81
|
+
|
|
82
|
+
vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);
|
|
83
|
+
vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);
|
|
84
|
+
q_r2 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);
|
|
85
|
+
}
|
|
86
|
+
asm volatile ("" ::: "memory");
|
|
87
|
+
{
|
|
88
|
+
vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 3 * k, vl_calc);
|
|
89
|
+
vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);
|
|
90
|
+
vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);
|
|
91
|
+
float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);
|
|
92
|
+
|
|
93
|
+
float d = amax / 127.0f;
|
|
94
|
+
y[i].d[3] = GGML_CPU_FP32_TO_FP16(d);
|
|
95
|
+
float id = d ? 1.0f / d : 0.0f;
|
|
96
|
+
|
|
97
|
+
vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);
|
|
98
|
+
vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);
|
|
99
|
+
q_r3 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);
|
|
100
|
+
}
|
|
101
|
+
vint64m2_t v_q64_r0 = __riscv_vreinterpret_v_i8m2_i64m2(q_r0);
|
|
102
|
+
vint64m2_t v_q64_r1 = __riscv_vreinterpret_v_i8m2_i64m2(q_r1);
|
|
103
|
+
vint64m2_t v_q64_r2 = __riscv_vreinterpret_v_i8m2_i64m2(q_r2);
|
|
104
|
+
vint64m2_t v_q64_r3 = __riscv_vreinterpret_v_i8m2_i64m2(q_r3);
|
|
105
|
+
vint64m2x4_t v_quant_tuple = __riscv_vcreate_v_i64m2x4(v_q64_r0, v_q64_r1, v_q64_r2, v_q64_r3);
|
|
106
|
+
__riscv_vsseg4e64_v_i64m2x4((int64_t*)y[i].qs, v_quant_tuple, vl_save);
|
|
107
|
+
}
|
|
108
|
+
#else
|
|
109
|
+
UNUSED(nb);
|
|
110
|
+
UNUSED(y);
|
|
111
|
+
ggml_quantize_mat_q8_0_4x4_generic(x, vy, k);
|
|
112
|
+
#endif
|
|
113
|
+
}
|
|
114
|
+
|
|
27
115
|
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
28
116
|
const int qk = QK8_0;
|
|
29
117
|
const int nb = n / qk;
|
|
@@ -115,6 +203,486 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|
|
115
203
|
ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
116
204
|
}
|
|
117
205
|
|
|
206
|
+
void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
207
|
+
const int qk = QK8_0;
|
|
208
|
+
const int nb = n / qk;
|
|
209
|
+
const int ncols_interleaved = 16;
|
|
210
|
+
const int blocklen = 1;
|
|
211
|
+
|
|
212
|
+
assert (n % qk == 0);
|
|
213
|
+
assert (nc % ncols_interleaved == 0);
|
|
214
|
+
|
|
215
|
+
UNUSED(s);
|
|
216
|
+
UNUSED(bs);
|
|
217
|
+
UNUSED(vx);
|
|
218
|
+
UNUSED(vy);
|
|
219
|
+
UNUSED(nr);
|
|
220
|
+
UNUSED(nc);
|
|
221
|
+
UNUSED(nb);
|
|
222
|
+
UNUSED(ncols_interleaved);
|
|
223
|
+
UNUSED(blocklen);
|
|
224
|
+
|
|
225
|
+
#if defined __riscv_v_intrinsic
|
|
226
|
+
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
|
227
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
228
|
+
const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb);
|
|
229
|
+
|
|
230
|
+
// 1x16 Accumulator
|
|
231
|
+
vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16);
|
|
232
|
+
|
|
233
|
+
for (int l = 0; l < nb; l++) {
|
|
234
|
+
// 1x16 Integer Accumulator
|
|
235
|
+
vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);
|
|
236
|
+
vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);
|
|
237
|
+
|
|
238
|
+
// Accumulation loop.
|
|
239
|
+
for (int i = 0; i < QK4_0 / 2; i++) {
|
|
240
|
+
// Load `b_ptr`.
|
|
241
|
+
const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16);
|
|
242
|
+
const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16);
|
|
243
|
+
const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16);
|
|
244
|
+
|
|
245
|
+
sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i], b_0_lo, 16);
|
|
246
|
+
sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[16 + i], b_0_hi, 16);
|
|
247
|
+
}
|
|
248
|
+
|
|
249
|
+
const vint32m2_t sumi = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16);
|
|
250
|
+
|
|
251
|
+
const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
|
|
252
|
+
const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16);
|
|
253
|
+
|
|
254
|
+
sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16);
|
|
255
|
+
}
|
|
256
|
+
|
|
257
|
+
__riscv_vse32_v_f32m2(s + x * 16, sumf, 16);
|
|
258
|
+
}
|
|
259
|
+
return;
|
|
260
|
+
#endif
|
|
261
|
+
ggml_gemv_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
262
|
+
}
|
|
263
|
+
|
|
264
|
+
void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
265
|
+
const int qk = QK_K;
|
|
266
|
+
const int nb = n / qk;
|
|
267
|
+
const int ncols_interleaved = 16;
|
|
268
|
+
const int blocklen = 1;
|
|
269
|
+
|
|
270
|
+
assert (n % qk == 0);
|
|
271
|
+
assert (nc % ncols_interleaved == 0);
|
|
272
|
+
|
|
273
|
+
UNUSED(s);
|
|
274
|
+
UNUSED(bs);
|
|
275
|
+
UNUSED(vx);
|
|
276
|
+
UNUSED(vy);
|
|
277
|
+
UNUSED(nr);
|
|
278
|
+
UNUSED(nc);
|
|
279
|
+
UNUSED(nb);
|
|
280
|
+
UNUSED(ncols_interleaved);
|
|
281
|
+
UNUSED(blocklen);
|
|
282
|
+
|
|
283
|
+
#if defined __riscv_v_intrinsic
|
|
284
|
+
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
|
285
|
+
|
|
286
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
287
|
+
const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb);
|
|
288
|
+
|
|
289
|
+
// 1x16 Accumulator
|
|
290
|
+
vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16);
|
|
291
|
+
|
|
292
|
+
for (int l = 0; l < nb; l++) {
|
|
293
|
+
vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, 16);
|
|
294
|
+
|
|
295
|
+
// Load `dmin`.
|
|
296
|
+
const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2(
|
|
297
|
+
__riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d, 16);
|
|
298
|
+
|
|
299
|
+
// We process 4 sub-blocks at once.
|
|
300
|
+
for (int j = 0; j < QK_K / 128; j++) {
|
|
301
|
+
// Extract the scales and the mins.
|
|
302
|
+
//
|
|
303
|
+
// Low bits.
|
|
304
|
+
vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64);
|
|
305
|
+
vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64);
|
|
306
|
+
vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64);
|
|
307
|
+
|
|
308
|
+
// High bits.
|
|
309
|
+
vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64);
|
|
310
|
+
vuint8m2_t scales_hi;
|
|
311
|
+
vuint8m2_t mins_hi;
|
|
312
|
+
if (!j) {
|
|
313
|
+
scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64);
|
|
314
|
+
mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64);
|
|
315
|
+
} else {
|
|
316
|
+
scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64);
|
|
317
|
+
mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64);
|
|
318
|
+
}
|
|
319
|
+
vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64);
|
|
320
|
+
vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64));
|
|
321
|
+
|
|
322
|
+
// Reduce the mins and multiply with `dmin`.
|
|
323
|
+
//
|
|
324
|
+
// Correct in `sumf`.
|
|
325
|
+
vint32m2_t bsums = __riscv_vmv_v_x_i32m2(0, 16);
|
|
326
|
+
bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8] + a_ptr[l].bsums[j * 8 + 1], __riscv_vget_v_i16m4_i16m1(mins, 0), 16);
|
|
327
|
+
bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 2] + a_ptr[l].bsums[j * 8 + 3], __riscv_vget_v_i16m4_i16m1(mins, 1), 16);
|
|
328
|
+
bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), 16);
|
|
329
|
+
bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), 16);
|
|
330
|
+
|
|
331
|
+
sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, 16), 16), 16);
|
|
332
|
+
|
|
333
|
+
// Accumulation for 2 sub-blocks.
|
|
334
|
+
//
|
|
335
|
+
// This might overflow, so we accumulate in two steps.
|
|
336
|
+
//
|
|
337
|
+
// Recheck.
|
|
338
|
+
for (int k = 0; k < 2; k++) {
|
|
339
|
+
vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);
|
|
340
|
+
vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);
|
|
341
|
+
|
|
342
|
+
for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {
|
|
343
|
+
// Load `b_ptr`.
|
|
344
|
+
const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16);
|
|
345
|
+
const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16));
|
|
346
|
+
const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16));
|
|
347
|
+
|
|
348
|
+
sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + i], b_s_0, 16);
|
|
349
|
+
sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 32 + i], b_s_1, 16);
|
|
350
|
+
}
|
|
351
|
+
|
|
352
|
+
sumi = __riscv_vwmacc_vv_i32m2(sumi,
|
|
353
|
+
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),
|
|
354
|
+
sumi_s_0_16, 16);
|
|
355
|
+
sumi = __riscv_vwmacc_vv_i32m2(sumi,
|
|
356
|
+
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),
|
|
357
|
+
sumi_s_1_16, 16);
|
|
358
|
+
}
|
|
359
|
+
// Accumulation for 2 sub-blocks.
|
|
360
|
+
//
|
|
361
|
+
// This might overflow, so we accumulate in two steps.
|
|
362
|
+
//
|
|
363
|
+
// Recheck.
|
|
364
|
+
for (int k = 0; k < 2; k++) {
|
|
365
|
+
vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);
|
|
366
|
+
vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16);
|
|
367
|
+
|
|
368
|
+
for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {
|
|
369
|
+
// Load `b_ptr`.
|
|
370
|
+
const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16);
|
|
371
|
+
const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16));
|
|
372
|
+
const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16));
|
|
373
|
+
|
|
374
|
+
sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + 64 + i], b_s_0, 16);
|
|
375
|
+
sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 96 + i], b_s_1, 16);
|
|
376
|
+
}
|
|
377
|
+
|
|
378
|
+
sumi = __riscv_vwmacc_vv_i32m2(sumi,
|
|
379
|
+
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)),
|
|
380
|
+
sumi_s_0_16, 16);
|
|
381
|
+
sumi = __riscv_vwmacc_vv_i32m2(sumi,
|
|
382
|
+
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)),
|
|
383
|
+
sumi_s_1_16, 16);
|
|
384
|
+
}
|
|
385
|
+
}
|
|
386
|
+
|
|
387
|
+
const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)&b_ptr[l].d[0], 16), 16);
|
|
388
|
+
const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d, 16);
|
|
389
|
+
|
|
390
|
+
sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16);
|
|
391
|
+
}
|
|
392
|
+
|
|
393
|
+
__riscv_vse32_v_f32m2(s + x * 16, sumf, 16);
|
|
394
|
+
}
|
|
395
|
+
return;
|
|
396
|
+
#endif
|
|
397
|
+
ggml_gemv_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
|
398
|
+
}
|
|
399
|
+
|
|
400
|
+
void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
401
|
+
const int qk = QK8_0;
|
|
402
|
+
const int nb = n / qk;
|
|
403
|
+
const int ncols_interleaved = 16;
|
|
404
|
+
const int blocklen = 1;
|
|
405
|
+
|
|
406
|
+
assert (n % qk == 0);
|
|
407
|
+
assert (nc % ncols_interleaved == 0);
|
|
408
|
+
|
|
409
|
+
UNUSED(s);
|
|
410
|
+
UNUSED(bs);
|
|
411
|
+
UNUSED(vx);
|
|
412
|
+
UNUSED(vy);
|
|
413
|
+
UNUSED(nr);
|
|
414
|
+
UNUSED(nc);
|
|
415
|
+
UNUSED(nb);
|
|
416
|
+
UNUSED(ncols_interleaved);
|
|
417
|
+
UNUSED(blocklen);
|
|
418
|
+
|
|
419
|
+
#if defined __riscv_v_intrinsic
|
|
420
|
+
const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16);
|
|
421
|
+
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
|
422
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
423
|
+
const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb);
|
|
424
|
+
|
|
425
|
+
// 1x16 Accumulator1
|
|
426
|
+
vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16);
|
|
427
|
+
|
|
428
|
+
for (int l = 0; l < nb; l++) {
|
|
429
|
+
// 1x16 integer accumulator
|
|
430
|
+
vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, 16);
|
|
431
|
+
|
|
432
|
+
// Accumulation loop.
|
|
433
|
+
for (int i = 0; i < QK4_NL / 2; i++) {
|
|
434
|
+
// Load `b_ptr`.
|
|
435
|
+
const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16);
|
|
436
|
+
const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16);
|
|
437
|
+
const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16);
|
|
438
|
+
// const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);
|
|
439
|
+
// const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);
|
|
440
|
+
|
|
441
|
+
const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i], 16);
|
|
442
|
+
const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[16 + i], 16);
|
|
443
|
+
sumi = __riscv_vadd_vv_i32m2(sumi, __riscv_vwadd_vv_i32m2(sumi_lo, sumi_hi, 16), 16);
|
|
444
|
+
}
|
|
445
|
+
|
|
446
|
+
const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
|
|
447
|
+
const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16);
|
|
448
|
+
|
|
449
|
+
sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16);
|
|
450
|
+
}
|
|
451
|
+
|
|
452
|
+
__riscv_vse32_v_f32m2(s + x * 16, sumf, 16);
|
|
453
|
+
}
|
|
454
|
+
return;
|
|
455
|
+
#endif
|
|
456
|
+
ggml_gemv_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
457
|
+
}
|
|
458
|
+
|
|
459
|
+
void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
460
|
+
const int qk = QK8_0;
|
|
461
|
+
const int nb = n / qk;
|
|
462
|
+
const int ncols_interleaved = 16;
|
|
463
|
+
const int blocklen = 1;
|
|
464
|
+
|
|
465
|
+
assert (n % qk == 0);
|
|
466
|
+
assert (nc % ncols_interleaved == 0);
|
|
467
|
+
|
|
468
|
+
UNUSED(s);
|
|
469
|
+
UNUSED(bs);
|
|
470
|
+
UNUSED(vx);
|
|
471
|
+
UNUSED(vy);
|
|
472
|
+
UNUSED(nr);
|
|
473
|
+
UNUSED(nc);
|
|
474
|
+
UNUSED(nb);
|
|
475
|
+
UNUSED(ncols_interleaved);
|
|
476
|
+
UNUSED(blocklen);
|
|
477
|
+
UNUSED(bs);
|
|
478
|
+
|
|
479
|
+
#if defined __riscv_v_intrinsic
|
|
480
|
+
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
|
481
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
482
|
+
const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb);
|
|
483
|
+
|
|
484
|
+
// 1x16 Accumulator
|
|
485
|
+
vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16);
|
|
486
|
+
|
|
487
|
+
for (int l = 0; l < nb; l++) {
|
|
488
|
+
// 1x16 Integer Accumulator
|
|
489
|
+
vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, 16);
|
|
490
|
+
|
|
491
|
+
// Accumulation loop.
|
|
492
|
+
for (int i = 0; i < QK8_0; i++) {
|
|
493
|
+
// Load `b_ptr`.
|
|
494
|
+
const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16);
|
|
495
|
+
// const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16);
|
|
496
|
+
|
|
497
|
+
sumi = __riscv_vwadd_wv_i32m2(sumi, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i], 16), 16);
|
|
498
|
+
}
|
|
499
|
+
|
|
500
|
+
const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
|
|
501
|
+
const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16);
|
|
502
|
+
|
|
503
|
+
sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16);
|
|
504
|
+
}
|
|
505
|
+
|
|
506
|
+
__riscv_vse32_v_f32m2(s + x * 16, sumf, 16);
|
|
507
|
+
}
|
|
508
|
+
return;
|
|
509
|
+
#endif
|
|
510
|
+
ggml_gemv_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
511
|
+
}
|
|
512
|
+
|
|
513
|
+
void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
514
|
+
assert(n % QK_K == 0);
|
|
515
|
+
assert(nr == 1);
|
|
516
|
+
assert(nc % 16 == 0);
|
|
517
|
+
|
|
518
|
+
UNUSED(bs);
|
|
519
|
+
|
|
520
|
+
const int N_COLS_TILE = 16;
|
|
521
|
+
const int num_k_blocks = n / QK_K;
|
|
522
|
+
|
|
523
|
+
const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE);
|
|
524
|
+
for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) {
|
|
525
|
+
|
|
526
|
+
const block_q8_K* lhs_base_ptr = (const block_q8_K*)vy;
|
|
527
|
+
const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks;
|
|
528
|
+
|
|
529
|
+
vfloat32m2_t v_sumf = __riscv_vfmv_v_f_f32m2(0.0f, vl);
|
|
530
|
+
|
|
531
|
+
for (int k_block = 0; k_block < num_k_blocks; ++k_block) {
|
|
532
|
+
const block_q8_K* lhs_current = &lhs_base_ptr[k_block];
|
|
533
|
+
const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block];
|
|
534
|
+
|
|
535
|
+
// 1. Prepare Global Min Scales
|
|
536
|
+
vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl);
|
|
537
|
+
vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, vl);
|
|
538
|
+
|
|
539
|
+
vfloat32m2_t v_g_min_final = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d, vl);
|
|
540
|
+
|
|
541
|
+
vint32m2_t v_isum = __riscv_vmv_v_x_i32m2(0, vl);
|
|
542
|
+
|
|
543
|
+
const uint8_t* rhs_qs_ptr = rhs_current->qs;
|
|
544
|
+
const uint8_t* rhs_sc_ptr = rhs_current->scales;
|
|
545
|
+
const int8_t* lhs_qs_ptr = lhs_current->qs;
|
|
546
|
+
|
|
547
|
+
// --- Phase Loop (4 phases x 64 elements) ---
|
|
548
|
+
for (int phase = 0; phase < 4; ++phase) {
|
|
549
|
+
|
|
550
|
+
// A. Load Scales/Mins
|
|
551
|
+
vuint16m1_t v_d_sb_0, v_d_sb_1, v_d_sb_2, v_d_sb_3;
|
|
552
|
+
vuint16m1_t v_m_sb_0, v_m_sb_1, v_m_sb_2, v_m_sb_3;
|
|
553
|
+
|
|
554
|
+
{
|
|
555
|
+
vuint8mf2_t v_raw;
|
|
556
|
+
// Sub-block 0
|
|
557
|
+
v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl);
|
|
558
|
+
v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
|
|
559
|
+
v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
|
|
560
|
+
|
|
561
|
+
// Sub-block 1
|
|
562
|
+
v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl);
|
|
563
|
+
v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
|
|
564
|
+
v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
|
|
565
|
+
|
|
566
|
+
// Sub-block 2
|
|
567
|
+
v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl);
|
|
568
|
+
v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
|
|
569
|
+
v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
|
|
570
|
+
|
|
571
|
+
// Sub-block 3
|
|
572
|
+
v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl);
|
|
573
|
+
v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
|
|
574
|
+
v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
|
|
575
|
+
|
|
576
|
+
rhs_sc_ptr += 64;
|
|
577
|
+
}
|
|
578
|
+
|
|
579
|
+
int base_k_phase = (phase < 2) ? (phase * 16) : (128 + (phase-2)*16);
|
|
580
|
+
int k_offsets[4] = {0, 32, 64, 96};
|
|
581
|
+
|
|
582
|
+
// B. Inner Dot Product Loop
|
|
583
|
+
for (int l = 0; l < 16; ++l) {
|
|
584
|
+
vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl);
|
|
585
|
+
rhs_qs_ptr += 16;
|
|
586
|
+
|
|
587
|
+
// Sub-block 0
|
|
588
|
+
{
|
|
589
|
+
vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, vl);
|
|
590
|
+
vint16m1_t v_w = __riscv_vmul_vv_i16m1(
|
|
591
|
+
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
|
|
592
|
+
__riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), vl);
|
|
593
|
+
|
|
594
|
+
int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[0] + l];
|
|
595
|
+
v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl);
|
|
596
|
+
}
|
|
597
|
+
// Sub-block 1
|
|
598
|
+
{
|
|
599
|
+
vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, vl), 3, vl);
|
|
600
|
+
vint16m1_t v_w = __riscv_vmul_vv_i16m1(
|
|
601
|
+
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
|
|
602
|
+
__riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), vl);
|
|
603
|
+
|
|
604
|
+
int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[1] + l];
|
|
605
|
+
v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl);
|
|
606
|
+
}
|
|
607
|
+
// Sub-block 2
|
|
608
|
+
{
|
|
609
|
+
vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, vl), 3, vl);
|
|
610
|
+
vint16m1_t v_w = __riscv_vmul_vv_i16m1(
|
|
611
|
+
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
|
|
612
|
+
__riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), vl);
|
|
613
|
+
|
|
614
|
+
int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[2] + l];
|
|
615
|
+
v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl);
|
|
616
|
+
}
|
|
617
|
+
// Sub-block 3
|
|
618
|
+
{
|
|
619
|
+
vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, vl), 3, vl);
|
|
620
|
+
vint16m1_t v_w = __riscv_vmul_vv_i16m1(
|
|
621
|
+
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
|
|
622
|
+
__riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), vl);
|
|
623
|
+
|
|
624
|
+
int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[3] + l];
|
|
625
|
+
v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl);
|
|
626
|
+
}
|
|
627
|
+
}
|
|
628
|
+
|
|
629
|
+
// correction
|
|
630
|
+
int sb_base_abs = base_k_phase / 16;
|
|
631
|
+
|
|
632
|
+
// Sub-block 0
|
|
633
|
+
{
|
|
634
|
+
int sb_idx = sb_base_abs + (k_offsets[0] / 16);
|
|
635
|
+
int16_t bsum = lhs_current->bsums[sb_idx];
|
|
636
|
+
vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_0);
|
|
637
|
+
vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl);
|
|
638
|
+
vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl);
|
|
639
|
+
v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl);
|
|
640
|
+
}
|
|
641
|
+
// Sub-block 1
|
|
642
|
+
{
|
|
643
|
+
int sb_idx = sb_base_abs + (k_offsets[1] / 16);
|
|
644
|
+
int16_t bsum = lhs_current->bsums[sb_idx];
|
|
645
|
+
vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_1);
|
|
646
|
+
vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl);
|
|
647
|
+
vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl);
|
|
648
|
+
v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl);
|
|
649
|
+
}
|
|
650
|
+
// Sub-block 2
|
|
651
|
+
{
|
|
652
|
+
int sb_idx = sb_base_abs + (k_offsets[2] / 16);
|
|
653
|
+
int16_t bsum = lhs_current->bsums[sb_idx];
|
|
654
|
+
vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_2);
|
|
655
|
+
vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl);
|
|
656
|
+
vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl);
|
|
657
|
+
v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl);
|
|
658
|
+
}
|
|
659
|
+
// Sub-block 3
|
|
660
|
+
{
|
|
661
|
+
int sb_idx = sb_base_abs + (k_offsets[3] / 16);
|
|
662
|
+
int16_t bsum = lhs_current->bsums[sb_idx];
|
|
663
|
+
vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_3);
|
|
664
|
+
vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl);
|
|
665
|
+
vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl);
|
|
666
|
+
v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl);
|
|
667
|
+
}
|
|
668
|
+
|
|
669
|
+
} // End Phase Loop
|
|
670
|
+
|
|
671
|
+
// Apply global Scales
|
|
672
|
+
vfloat16m1_t v_g_all_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, vl);
|
|
673
|
+
vfloat32m2_t v_g_all_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, vl);
|
|
674
|
+
|
|
675
|
+
vfloat32m2_t v_g_all_final = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d, vl);
|
|
676
|
+
vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum, vl);
|
|
677
|
+
v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all_final, vl);
|
|
678
|
+
v_sumf = __riscv_vfadd_vv_f32m2(v_sumf, v_sum, vl);
|
|
679
|
+
|
|
680
|
+
} // End K-Block
|
|
681
|
+
__riscv_vse32_v_f32m2(s + col_tile, v_sumf, vl);
|
|
682
|
+
|
|
683
|
+
}
|
|
684
|
+
}
|
|
685
|
+
|
|
118
686
|
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
119
687
|
const int qk = QK8_0;
|
|
120
688
|
const int nb = n / qk;
|
|
@@ -340,3 +908,826 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|
|
340
908
|
#endif
|
|
341
909
|
ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
342
910
|
}
|
|
911
|
+
|
|
912
|
+
// GEMM micro-kernel: multiplies a 4-row group of q8_0-quantized activations (vy)
// by a 16-column interleaved q4_0 weight tile (vx), writing float results to s.
//   n  - K dimension (elements per row); must be a multiple of QK8_0
//   s  - output, row stride bs floats
//   vx - interleaved weights, blocks of type block_q4_0x16
//   vy - activations, blocks of type block_q8_0x4 (4 rows packed per block)
//   nr - number of output rows (multiple of 4), nc - number of output columns (multiple of 16)
// Falls through to the generic scalar kernel when RVV intrinsics are unavailable.
void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;            // number of quantization blocks along K
    const int ncols_interleaved = 16; // B columns packed side-by-side per x16 block
    const int blocklen = 1;

    assert (n % qk == 0);
    assert (nr % 4 == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if defined __riscv_v_intrinsic
    for (int y = 0; y < nr / 4; y++) {
        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb);

            // 4x16 float accumulators: one 16-wide vector per A row.
            vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
            vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
            vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
            vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16);

            for (int l = 0; l < nb; l++) {
                // 4x16 integer accumulators, split into low-nibble and high-nibble
                // partial sums; widened to i32 only at the end of the block.
                vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0, 16);
                vint16m1_t sumi_1_lo_16 = __riscv_vmv_v_x_i16m1(0, 16);
                vint16m1_t sumi_2_lo_16 = __riscv_vmv_v_x_i16m1(0, 16);
                vint16m1_t sumi_3_lo_16 = __riscv_vmv_v_x_i16m1(0, 16);
                vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0, 16);
                vint16m1_t sumi_1_hi_16 = __riscv_vmv_v_x_i16m1(0, 16);
                vint16m1_t sumi_2_hi_16 = __riscv_vmv_v_x_i16m1(0, 16);
                vint16m1_t sumi_3_hi_16 = __riscv_vmv_v_x_i16m1(0, 16);

                // Accumulation loop over the 16 packed byte groups of the block.
                for (int i = 0; i < QK4_0 / 2; i++) {
                    // Load 16 packed weight bytes (one per interleaved column).
                    const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16);
                    // Sign-extend the low and high nibbles to i8.
                    // NOTE(review): this treats nibbles as signed 4-bit values, i.e. it
                    // assumes the repack step already folded in the q4_0 -8 bias — confirm
                    // against the corresponding repack routine.
                    const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16);
                    const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16);

                    // Multiply-accumulate the broadcast A scalar (one per row) against
                    // the 16 weight columns. Low nibbles pair with A elements 0..63.
                    sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i * 4], b_0_lo, 16);
                    sumi_1_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_1_lo_16, a_ptr[l].qs[i * 4 + 1], b_0_lo, 16);
                    sumi_2_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_2_lo_16, a_ptr[l].qs[i * 4 + 2], b_0_lo, 16);
                    sumi_3_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_3_lo_16, a_ptr[l].qs[i * 4 + 3], b_0_lo, 16);

                    // High nibbles pair with A elements 64..127 (second half of the block).
                    sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[64 + i * 4], b_0_hi, 16);
                    sumi_1_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_1_hi_16, a_ptr[l].qs[64 + i * 4 + 1], b_0_hi, 16);
                    sumi_2_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_2_hi_16, a_ptr[l].qs[64 + i * 4 + 2], b_0_hi, 16);
                    sumi_3_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_3_hi_16, a_ptr[l].qs[64 + i * 4 + 3], b_0_hi, 16);
                }

                // Do the final accumulation in i32 to prevent overflow.
                const vint32m2_t sumi_0 = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16);
                const vint32m2_t sumi_1 = __riscv_vwadd_vv_i32m2(sumi_1_lo_16, sumi_1_hi_16, 16);
                const vint32m2_t sumi_2 = __riscv_vwadd_vv_i32m2(sumi_2_lo_16, sumi_2_hi_16, 16);
                const vint32m2_t sumi_3 = __riscv_vwadd_vv_i32m2(sumi_3_lo_16, sumi_3_hi_16, 16);

                // Combined per-block scale: 16 weight scales times each row's scale.
                // NOTE(review): reinterprets the stored half-precision scale bits as
                // _Float16 via a pointer cast — relies on ggml_half layout; confirm.
                const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
                const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16);
                const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16);
                const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16);
                const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16);

                sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16);
                sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16);
                sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16);
                sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16);
            }

            // Store the 4x16 output tile.
            __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);
        }
    }
    return;
#endif
    ggml_gemm_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}
|
|
1001
|
+
|
|
1002
|
+
// GEMM micro-kernel: 4-row group of q8_K activations (vy) times a 16-column
// interleaved q4_K weight tile (vx). q4_K is a super-block format: per 256-element
// super-block there are 8 sub-blocks, each with a 6-bit scale and 6-bit min,
// plus a global half-precision scale d and min scale dmin.
//   n  - K dimension; multiple of QK_K
//   s  - float output, row stride bs
//   vx - block_q4_Kx16 weights, vy - block_q8_Kx4 activations
//   nr - rows (multiple of 4), nc - columns (multiple of 16)
// Falls through to the generic scalar kernel when RVV intrinsics are unavailable.
void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK_K;
    const int nb = n / qk;            // super-blocks along K
    const int ncols_interleaved = 16; // B columns packed per x16 block
    const int blocklen = 1;

    assert (n % qk == 0);
    assert (nr % 4 == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if defined __riscv_v_intrinsic
    for (int y = 0; y < nr / 4; y++) {
        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb);

            // 4x16 float accumulators, one 16-wide vector per A row.
            vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
            vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
            vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
            vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16);

            for (int l = 0; l < nb; l++) {
                // Per-super-block i32 accumulators (already scale-weighted).
                vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, 16);
                vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, 16);
                vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16);
                vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16);

                // Load `dmin` (global min scale, one per interleaved column).
                const vfloat32m2_t dmins = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16);

                // We process 4 sub-blocks (128 K-elements) per iteration.
                for (int j = 0; j < QK_K / 128; j++) {
                    // Extract the 6-bit sub-block scales and mins.
                    //
                    // Low 4 bits: one byte per (sub-block, column) pair.
                    vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64);
                    vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64);
                    vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64);

                    // High 2 bits: packed four-per-byte at a fixed offset; which bit
                    // field to use depends on which half of the super-block we're in.
                    vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64);
                    vuint8m2_t scales_hi;
                    vuint8m2_t mins_hi;
                    if (!j) {
                        scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64);
                        mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64);
                    } else {
                        scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64);
                        mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64);
                    }
                    // 64 entries = 4 sub-blocks x 16 columns, widened to 16 bits.
                    vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64);
                    vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64));

                    // Min correction: q4_K dequant is d*scale*q - dmin*min, so the
                    // mins times the activation block sums (bsums) are subtracted
                    // from `sumf` below. Each sub-block's min multiplies the sum of
                    // the two 16-element bsum halves for the matching row.
                    vint32m2_t bsums_0 = __riscv_vmv_v_x_i32m2(0, 16);
                    vint32m2_t bsums_1 = __riscv_vmv_v_x_i32m2(0, 16);
                    vint32m2_t bsums_2 = __riscv_vmv_v_x_i32m2(0, 16);
                    vint32m2_t bsums_3 = __riscv_vmv_v_x_i32m2(0, 16);

                    bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0,
                        a_ptr[l].bsums[j * 32] + a_ptr[l].bsums[j * 32 + 4],
                        __riscv_vget_v_i16m4_i16m1(mins, 0), 16);
                    bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1,
                        a_ptr[l].bsums[j * 32 + 1] + a_ptr[l].bsums[j * 32 + 5],
                        __riscv_vget_v_i16m4_i16m1(mins, 0), 16);
                    bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2,
                        a_ptr[l].bsums[j * 32 + 2] + a_ptr[l].bsums[j * 32 + 6],
                        __riscv_vget_v_i16m4_i16m1(mins, 0), 16);
                    bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3,
                        a_ptr[l].bsums[j * 32 + 3] + a_ptr[l].bsums[j * 32 + 7],
                        __riscv_vget_v_i16m4_i16m1(mins, 0), 16);
                    bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0,
                        a_ptr[l].bsums[j * 32 + 8] + a_ptr[l].bsums[j * 32 + 8 + 4],
                        __riscv_vget_v_i16m4_i16m1(mins, 1), 16);
                    bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1,
                        a_ptr[l].bsums[j * 32 + 8 + 1] + a_ptr[l].bsums[j * 32 + 8 + 5],
                        __riscv_vget_v_i16m4_i16m1(mins, 1), 16);
                    bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2,
                        a_ptr[l].bsums[j * 32 + 8 + 2] + a_ptr[l].bsums[j * 32 + 8 + 6],
                        __riscv_vget_v_i16m4_i16m1(mins, 1), 16);
                    bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3,
                        a_ptr[l].bsums[j * 32 + 8 + 3] + a_ptr[l].bsums[j * 32 + 8 + 7],
                        __riscv_vget_v_i16m4_i16m1(mins, 1), 16);
                    bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0,
                        a_ptr[l].bsums[j * 32 + 16] + a_ptr[l].bsums[j * 32 + 16 + 4],
                        __riscv_vget_v_i16m4_i16m1(mins, 2), 16);
                    bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1,
                        a_ptr[l].bsums[j * 32 + 16 + 1] + a_ptr[l].bsums[j * 32 + 16 + 5],
                        __riscv_vget_v_i16m4_i16m1(mins, 2), 16);
                    bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2,
                        a_ptr[l].bsums[j * 32 + 16 + 2] + a_ptr[l].bsums[j * 32 + 16 + 6],
                        __riscv_vget_v_i16m4_i16m1(mins, 2), 16);
                    bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3,
                        a_ptr[l].bsums[j * 32 + 16 + 3] + a_ptr[l].bsums[j * 32 + 16 + 7],
                        __riscv_vget_v_i16m4_i16m1(mins, 2), 16);
                    bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0,
                        a_ptr[l].bsums[j * 32 + 24 + 0] + a_ptr[l].bsums[j * 32 + 24 + 4],
                        __riscv_vget_v_i16m4_i16m1(mins, 3), 16);
                    bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1,
                        a_ptr[l].bsums[j * 32 + 24 + 1] + a_ptr[l].bsums[j * 32 + 24 + 5],
                        __riscv_vget_v_i16m4_i16m1(mins, 3), 16);
                    bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2,
                        a_ptr[l].bsums[j * 32 + 24 + 2] + a_ptr[l].bsums[j * 32 + 24 + 6],
                        __riscv_vget_v_i16m4_i16m1(mins, 3), 16);
                    bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3,
                        a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7],
                        __riscv_vget_v_i16m4_i16m1(mins, 3), 16);

                    const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[0], 16);
                    const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[1], 16);
                    const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[2], 16);
                    const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[3], 16);

                    sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, 16), 16), 16);
                    sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, 16), 16), 16);
                    sumf_2 = __riscv_vfsub_vv_f32m2(sumf_2, __riscv_vfmul_vv_f32m2(dmins_d_2, __riscv_vfcvt_f_x_v_f32m2(bsums_2, 16), 16), 16);
                    sumf_3 = __riscv_vfsub_vv_f32m2(sumf_3, __riscv_vfmul_vv_f32m2(dmins_d_3, __riscv_vfcvt_f_x_v_f32m2(bsums_3, 16), 16), 16);


                    // Accumulation for sub-blocks 0 and 1 of this 128-element group.
                    // i16 partial sums are flushed to i32 (with the sub-block scales
                    // applied) after every 16 K-steps to avoid i16 overflow.
                    // NOTE(review): worst-case |q4*q8| * 16 = 15*127*16 fits i16 only
                    // marginally per flush — verify the bound against the repack layout.
                    for (int k = 0; k < 2; k++) {
                        // 4 rows x 2 nibble-planes of i16 accumulators.
                        vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
                        vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
                        vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
                        vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
                        vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
                        vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
                        vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
                        vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);

                        for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {
                            // Load 16 packed weight bytes; q4_K quants are unsigned.
                            const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16);
                            const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16));
                            const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16));

                            sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + i * 4], b_s_0, 16);
                            sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 1], b_s_0, 16);
                            sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 2], b_s_0, 16);
                            sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 3], b_s_0, 16);

                            sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4], b_s_1, 16);
                            sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 1], b_s_1, 16);
                            sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 2], b_s_1, 16);
                            sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 3], b_s_1, 16);
                        }

                        // Apply sub-block scales 0 (low nibbles) and 1 (high nibbles)
                        // while widening the partial sums into the i32 accumulators.
                        sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0,
                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),
                            sumi_0_s_0_16, 16);
                        sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0,
                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),
                            sumi_0_s_1_16, 16);
                        sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1,
                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),
                            sumi_1_s_0_16, 16);
                        sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1,
                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),
                            sumi_1_s_1_16, 16);
                        sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2,
                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),
                            sumi_2_s_0_16, 16);
                        sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2,
                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),
                            sumi_2_s_1_16, 16);
                        sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3,
                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),
                            sumi_3_s_0_16, 16);
                        sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3,
                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),
                            sumi_3_s_1_16, 16);
                    }
                    // Accumulation for sub-blocks 2 and 3 — same structure as above,
                    // shifted by half the group in both qs streams and using scale
                    // slots 2 and 3.
                    for (int k = 0; k < 2; k++) {
                        vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
                        vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
                        vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
                        vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
                        vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
                        vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
                        vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
                        vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);

                        for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {
                            const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16);
                            const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16));
                            const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16));

                            sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4], b_s_0, 16);
                            sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 1], b_s_0, 16);
                            sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 2], b_s_0, 16);
                            sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 3], b_s_0, 16);

                            sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4], b_s_1, 16);
                            sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 1], b_s_1, 16);
                            sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 2], b_s_1, 16);
                            sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 3], b_s_1, 16);
                        }

                        sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0,
                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)),
                            sumi_0_s_0_16, 16);
                        sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0,
                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)),
                            sumi_0_s_1_16, 16);
                        sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1,
                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)),
                            sumi_1_s_0_16, 16);
                        sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1,
                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)),
                            sumi_1_s_1_16, 16);
                        sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2,
                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)),
                            sumi_2_s_0_16, 16);
                        sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2,
                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)),
                            sumi_2_s_1_16, 16);
                        sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3,
                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)),
                            sumi_3_s_0_16, 16);
                        sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3,
                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)),
                            sumi_3_s_1_16, 16);
                    }
                }

                // Apply the global super-block scales and fold into the float sums.
                const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16), 16);
                const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[0], 16);
                const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[1], 16);
                const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[2], 16);
                const vfloat32m2_t d_3 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[3], 16);

                sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16);
                sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16);
                sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16);
                sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16);
            }

            // Store the 4x16 output tile.
            __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);
        }
    }
    return;
#endif
    ggml_gemm_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
|
|
1274
|
+
|
|
1275
|
+
// GEMM micro-kernel: 4-row group of q8_0 activations (vy) times a 16-column
// interleaved iq4_nl weight tile (vx). iq4_nl stores 4-bit indices into the
// non-linear codebook kvalues_iq4nl; nibbles are expanded with a vector gather.
//   n  - K dimension; multiple of QK8_0
//   s  - float output, row stride bs
//   vx - block_iq4_nlx16 weights, vy - block_q8_0x4 activations
//   nr - rows (multiple of 4), nc - columns (multiple of 16)
// Falls through to the generic scalar kernel when RVV intrinsics are unavailable.
void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;            // quantization blocks along K
    const int ncols_interleaved = 16; // B columns packed per x16 block
    const int blocklen = 1;

    assert (n % qk == 0);
    assert (nr % 4 == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if defined __riscv_v_intrinsic
    // Codebook lookup table: 16 signed i8 values, indexed by the 4-bit quant.
    const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16);
    for (int y = 0; y < nr / 4; y++) {
        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb);

            // 4x16 float accumulators, one 16-wide vector per A row.
            vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
            vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
            vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
            vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16);

            for (int l = 0; l < nb; l++) {
                // 4x16 integer accumulators for this block.
                vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, 16);
                vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, 16);
                vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16);
                vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16);

                // Accumulation loop over the 16 packed byte groups of the block.
                for (int i = 0; i < QK4_NL / 2; i++) {
                    // Load 16 packed weight bytes and decode both nibbles through
                    // the codebook with a register gather.
                    const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16);
                    const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16);
                    const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16);

                    // Widening multiply by the broadcast A scalar per row; low
                    // nibbles pair with A elements 0..63, high with 64..127.
                    const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4], 16);
                    const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 1], 16);
                    const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 2], 16);
                    const vint16m1_t sumi_3_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 3], 16);

                    const vint16m1_t sumi_0_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4], 16);
                    const vint16m1_t sumi_1_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 1], 16);
                    const vint16m1_t sumi_2_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 2], 16);
                    const vint16m1_t sumi_3_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 3], 16);

                    // Widen to i32 immediately so the sums cannot overflow.
                    sumi_0 = __riscv_vadd_vv_i32m2(sumi_0, __riscv_vwadd_vv_i32m2(sumi_0_lo, sumi_0_hi, 16), 16);
                    sumi_1 = __riscv_vadd_vv_i32m2(sumi_1, __riscv_vwadd_vv_i32m2(sumi_1_lo, sumi_1_hi, 16), 16);
                    sumi_2 = __riscv_vadd_vv_i32m2(sumi_2, __riscv_vwadd_vv_i32m2(sumi_2_lo, sumi_2_hi, 16), 16);
                    sumi_3 = __riscv_vadd_vv_i32m2(sumi_3, __riscv_vwadd_vv_i32m2(sumi_3_lo, sumi_3_hi, 16), 16);
                }

                // Combined per-block scale: 16 weight scales times each row's scale.
                // NOTE(review): reinterprets stored half-precision scale bits as
                // _Float16 via a pointer cast — relies on ggml_half layout; confirm.
                const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
                const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16);
                const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16);
                const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16);
                const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16);

                sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16);
                sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16);
                sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16);
                sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16);
            }

            // Store the 4x16 output tile.
            __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);
        }
    }
    return;
#endif
    ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}
|
|
1362
|
+
|
|
1363
|
+
// GEMM micro-kernel: 4-row Q8_0 LHS tiles (block_q8_0x4, from vy) times
// 16-column interleaved Q8_0 RHS tiles (block_q8_0x16, from vx), writing a
// 4x16 float tile into s. RISC-V Vector path; falls back to the generic
// implementation when RVV intrinsics are unavailable.
//
// n  - number of elements along K (must be a multiple of QK8_0)
// s  - output matrix, row stride bs floats
// bs - output row stride (in floats)
// vx - interleaved RHS blocks (16 columns per block)
// vy - interleaved LHS blocks (4 rows per block)
// nr - number of LHS rows (must be a multiple of 4)
// nc - number of RHS columns (must be a multiple of 16)
void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 16;
    const int blocklen = 1;

    assert (n % qk == 0);
    assert (nr % 4 == 0);
    assert (nc % ncols_interleaved == 0);

    // Silence unused-parameter warnings on builds that take the generic path only.
    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if defined __riscv_v_intrinsic
    for (int y = 0; y < nr / 4; y++) {
        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb);

            // 4x16 float accumulators: one 16-lane vector per LHS row.
            vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
            vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
            vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
            vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16);

            for (int l = 0; l < nb; l++) {
                // 4x16 integer accumulators for this block.
                // Fix: splat an integer zero — the original passed 0.0f to the
                // integer-scalar intrinsic, relying on implicit float->int conversion.
                vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, 16);
                vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, 16);
                vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16);
                vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16);

                // Accumulate over the QK8_0 quant positions of the block.
                for (int i = 0; i < QK8_0; i++) {
                    // 16 interleaved column quants at position i.
                    const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16);

                    // Widening multiply by each row's scalar quant, then
                    // widening add into the 32-bit accumulators.
                    sumi_0 = __riscv_vwadd_wv_i32m2(sumi_0, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 0], 16), 16);
                    sumi_1 = __riscv_vwadd_wv_i32m2(sumi_1, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 1], 16), 16);
                    sumi_2 = __riscv_vwadd_wv_i32m2(sumi_2, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 2], 16), 16);
                    sumi_3 = __riscv_vwadd_wv_i32m2(sumi_3, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 3], 16), 16);
                }

                // Combined scale: per-column RHS scale (f16 vector, widened) times
                // the per-row LHS scale (f16 scalar).
                const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
                const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16);
                const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16);
                const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16);
                const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16);

                // sumf += (float)sumi * d, per row.
                sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16);
                sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16);
                sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16);
                sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16);
            }

            // Store the 4x16 output tile.
            __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);
        }
    }
    return;
#endif
    ggml_gemm_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}
|
|
1436
|
+
|
|
1437
|
+
// GEMM micro-kernel: 4-row Q8_K LHS tiles (block_q8_Kx4, from vy) times
// 16-column interleaved Q2_K RHS tiles (block_q2_Kx16, from vx), writing a
// 4x16 float tile into s per (row_tile, col_tile).
//
// Result per element: sum over super-blocks of
//   d_lhs * d_rhs * sum(scale_sb * q2 * q8)  -  d_lhs * dmin_rhs * sum(min_sb * bsum)
// The integer dot products are accumulated first; the min-based correction is
// subtracted per sub-block, and the main scales are applied per super-block.
//
// NOTE(review): unlike the sibling 16x1 kernels in this file, this function has
// no `#if defined __riscv_v_intrinsic` guard and no generic fallback call —
// presumably this translation unit is only compiled for RVV targets; verify.
void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    assert(n % QK_K == 0);
    const int num_k_blocks = n / QK_K;
    const int N_ROWS_TILE = 4;
    const int N_COLS_TILE = 16;
    assert(nr % N_ROWS_TILE == 0);
    assert(nc % N_COLS_TILE == 0);

    // One vector covers the 16 output columns of a tile.
    const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE);
    // --- Tiling Loops ---
    // `#pragma GCC unroll 1` keeps the compiler from unrolling the outer loops
    // (register pressure is already high from the vector accumulators).
#pragma GCC unroll 1
    for (int row_tile = 0; row_tile < nr; row_tile += N_ROWS_TILE) {
#pragma GCC unroll 1
        for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) {
            // Base Pointers: one packed block sequence per 4-row / 16-column tile.
            const block_q8_Kx4* lhs_base_ptr = (const block_q8_Kx4*)vy + (row_tile / N_ROWS_TILE) * num_k_blocks;
            const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks;

            // Persistent Float Accumulators (one 16-lane vector per LHS row).
            vfloat32m2_t v_sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, vl);
            vfloat32m2_t v_sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, vl);
            vfloat32m2_t v_sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, vl);
            vfloat32m2_t v_sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, vl);

            // --- Super-Block Loop (K=0..255) ---
#pragma GCC unroll 1
            for (int k_block = 0; k_block < num_k_blocks; ++k_block) {
                const block_q8_Kx4* lhs_current = &lhs_base_ptr[k_block];
                const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block];

                // 1. Load Global Min Scales (Keep as F16/LMUL=1 to save registers)
                vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl);
                vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, vl);

                // 2. Initialize Integer Accumulators (one per LHS row).
                vint32m2_t v_isum_0 = __riscv_vmv_v_x_i32m2(0, vl);
                vint32m2_t v_isum_1 = __riscv_vmv_v_x_i32m2(0, vl);
                vint32m2_t v_isum_2 = __riscv_vmv_v_x_i32m2(0, vl);
                vint32m2_t v_isum_3 = __riscv_vmv_v_x_i32m2(0, vl);

                const uint8_t* rhs_qs_ptr = rhs_current->qs;
                const uint8_t* rhs_sc_ptr = rhs_current->scales;
                const int8_t* lhs_qs_ptr = lhs_current->qs;

                // --- Phase Loop (4 phases x 64 elements) ---
                // Each phase processes 4 interleaved 16-element sub-blocks
                // packed into the same bytes at 2-bit positions 0/2/4/6.
#pragma GCC unroll 1
                for (int phase = 0; phase < 4; ++phase) {

                    // A. Load Scales/Mins for the 4 interleaved sub-blocks.
                    // Each scales byte packs: low nibble = scale (d), high nibble = min (m).
                    vuint16m1_t v_d_sb_0, v_d_sb_1, v_d_sb_2, v_d_sb_3;
                    vuint16m1_t v_m_sb_0, v_m_sb_1, v_m_sb_2, v_m_sb_3;

                    // Unrolled Load Logic
                    {
                        vuint8mf2_t v_raw;
                        // Sub-block 0
                        v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl);
                        v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
                        v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);

                        // Sub-block 1
                        v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl);
                        v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
                        v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);

                        // Sub-block 2
                        v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl);
                        v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
                        v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);

                        // Sub-block 3
                        v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl);
                        v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
                        v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);

                        rhs_sc_ptr += 64;
                    }

                    // Starting K index of this phase in the LHS (phases 0,1 cover
                    // K 0..31; phases 2,3 cover K 128..159) and the K offsets of
                    // the 4 sub-blocks packed at bit positions 0/2/4/6.
                    int base_k_phase = (phase < 2) ? (phase * 16) : (128 + (phase-2)*16);
                    int k_offsets[4] = {0, 32, 64, 96};

                    // B. Inner Dot Product Loop (16 element positions per phase).
#pragma GCC unroll 1
                    for (int l = 0; l < 16; ++l) {
                        // 16 interleaved column bytes, each holding 4 two-bit quants.
                        vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl);
                        rhs_qs_ptr += 16;

                        // Unroll over 4 sub-blocks (0, 1, 2, 3 relative to phase)

                        // --- Sub-block 0 --- (bits 0..1)
                        {
                            vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, vl);
                            // Pre-scale the quant by the sub-block scale: w = q2 * d_sb.
                            vint16m1_t v_w = __riscv_vmul_vv_i16m1(
                                __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
                                __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), vl);

                            // 4 row quants for this K position (rows interleaved by 4).
                            const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[0] + l) * 4];
                            v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl);
                            v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl);
                            v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl);
                            v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl);
                        }
                        // --- Sub-block 1 --- (bits 2..3)
                        {
                            vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, vl), 3, vl);
                            vint16m1_t v_w = __riscv_vmul_vv_i16m1(
                                __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
                                __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), vl);

                            const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[1] + l) * 4];
                            v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl);
                            v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl);
                            v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl);
                            v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl);
                        }
                        // --- Sub-block 2 --- (bits 4..5)
                        {
                            vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, vl), 3, vl);
                            vint16m1_t v_w = __riscv_vmul_vv_i16m1(
                                __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
                                __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), vl);

                            const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[2] + l) * 4];
                            v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl);
                            v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl);
                            v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl);
                            v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl);
                        }
                        // --- Sub-block 3 --- (bits 6..7)
                        {
                            vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, vl), 3, vl);
                            vint16m1_t v_w = __riscv_vmul_vv_i16m1(
                                __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
                                __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), vl);

                            const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[3] + l) * 4];
                            v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl);
                            v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl);
                            v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl);
                            v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl);
                        }
                    }

                    // C. Min correction: subtract d_lhs * dmin_rhs * min_sb * bsum
                    // per sub-block and per row (bsums holds per-sub-block LHS sums,
                    // interleaved by 4 rows).
                    int sb_base_abs = base_k_phase / 16;

                    // --- Correction Sub-block 0 ---
                    {
                        int sb_abs = sb_base_abs + (k_offsets[0] / 16);
                        vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_0);

                        // Row 0
                        vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl);
                        vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl);
                        vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl);

                        // Row 1
                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl);

                        // Row 2
                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl);

                        // Row 3
                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl);
                    }

                    // --- Correction Sub-block 1 ---
                    {
                        int sb_abs = sb_base_abs + (k_offsets[1] / 16);
                        vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_1);

                        vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl);
                        vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl);
                        vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl);

                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl);

                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl);

                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl);
                    }

                    // --- Correction Sub-block 2 ---
                    {
                        int sb_abs = sb_base_abs + (k_offsets[2] / 16);
                        vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_2);

                        vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl);
                        vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl);
                        vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl);

                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl);

                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl);

                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl);
                    }

                    // --- Correction Sub-block 3 ---
                    {
                        int sb_abs = sb_base_abs + (k_offsets[3] / 16);
                        vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_3);

                        vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl);
                        vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl);
                        vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl);

                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl);

                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl);

                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl);
                    }

                } // End Phase Loop

                // --- Apply Main Scales ---
                // sumf += (float)isum * d_rhs (per column) * d_lhs (per row).
                vfloat16m1_t v_g_all_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, vl);
                vfloat32m2_t v_g_all_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, vl);

                // Row 0
                {
                    vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[0], vl);
                    vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_0, vl);
                    v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl);
                    v_sumf_0 = __riscv_vfadd_vv_f32m2(v_sumf_0, v_sum, vl);
                }
                // Row 1
                {
                    vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[1], vl);
                    vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_1, vl);
                    v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl);
                    v_sumf_1 = __riscv_vfadd_vv_f32m2(v_sumf_1, v_sum, vl);
                }
                // Row 2
                {
                    vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[2], vl);
                    vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_2, vl);
                    v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl);
                    v_sumf_2 = __riscv_vfadd_vv_f32m2(v_sumf_2, v_sum, vl);
                }
                // Row 3
                {
                    vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[3], vl);
                    vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_3, vl);
                    v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl);
                    v_sumf_3 = __riscv_vfadd_vv_f32m2(v_sumf_3, v_sum, vl);
                }

            } // End K-Block

            // Store the 4x16 output tile.
            __riscv_vse32_v_f32m2(s + (row_tile + 0) * bs + col_tile, v_sumf_0, vl);
            __riscv_vse32_v_f32m2(s + (row_tile + 1) * bs + col_tile, v_sumf_1, vl);
            __riscv_vse32_v_f32m2(s + (row_tile + 2) * bs + col_tile, v_sumf_2, vl);
            __riscv_vse32_v_f32m2(s + (row_tile + 3) * bs + col_tile, v_sumf_3, vl);
        }
    }
}
|