whispercpp 1.3.5 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +99 -2
- data/ext/extconf.rb +1 -0
- data/ext/ruby_whisper.c +20 -4
- data/ext/ruby_whisper.h +30 -2
- data/ext/ruby_whisper_context.c +216 -124
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_params.c +0 -1
- data/ext/ruby_whisper_segment.c +0 -1
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +4 -1
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +8 -0
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/server/server.cpp +18 -4
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
- data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
- data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
- data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
- data/ext/sources/examples/talk-llama/llama-context.h +27 -28
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
- data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
- data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
- data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
- data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
- data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
- data/ext/sources/examples/talk-llama/llama-model.h +72 -19
- data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
- data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
- data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -22
- data/ext/sources/examples/talk-llama/llama.h +63 -30
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
- data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
- data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
- data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
- data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
- data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/models.h +181 -46
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
- data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
- data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
- data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
- data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
- data/ext/sources/ggml/CMakeLists.txt +9 -3
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +6 -1
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +56 -9
- data/ext/sources/ggml/src/CMakeLists.txt +3 -0
- data/ext/sources/ggml/src/ggml-alloc.c +4 -9
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
- data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
- data/ext/sources/ggml/src/ggml-impl.h +62 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
- data/ext/sources/ggml/src/ggml.c +167 -33
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/src/whisper.cpp +6 -28
- data/sig/whisper.rbs +43 -2
- data/test/test_context_params.rb +82 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_whisper.rb +20 -0
- data/whispercpp.gemspec +1 -1
- metadata +240 -28
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
|
@@ -3,105 +3,50 @@
|
|
|
3
3
|
#pragma clang diagnostic ignored "-Wunused-variable"
|
|
4
4
|
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
|
5
5
|
|
|
6
|
-
#ifdef HTP_DEBUG
|
|
7
|
-
# define FARF_HIGH 1
|
|
8
|
-
#endif
|
|
9
|
-
|
|
10
6
|
#include <HAP_farf.h>
|
|
11
|
-
#include <HAP_mem.h>
|
|
12
7
|
#include <HAP_perf.h>
|
|
13
|
-
|
|
14
|
-
#include <hexagon_protos.h>
|
|
15
|
-
#include <hexagon_types.h>
|
|
8
|
+
|
|
16
9
|
#include <math.h>
|
|
17
|
-
#include <qurt_thread.h>
|
|
18
10
|
#include <string.h>
|
|
19
11
|
|
|
12
|
+
#include "hex-dma.h"
|
|
13
|
+
#include "hvx-utils.h"
|
|
14
|
+
#include "hvx-dump.h"
|
|
15
|
+
|
|
20
16
|
#define GGML_COMMON_DECL_C
|
|
21
17
|
#include "ggml-common.h"
|
|
22
18
|
#include "htp-ctx.h"
|
|
23
|
-
#include "htp-dma.h"
|
|
24
19
|
#include "htp-msg.h"
|
|
25
20
|
#include "htp-ops.h"
|
|
26
|
-
#include "hvx-utils.h"
|
|
27
|
-
#include "ops-utils.h"
|
|
28
21
|
|
|
29
22
|
#define MM_SPAD_SRC0_NROWS 16
|
|
30
23
|
#define MM_SPAD_SRC1_NROWS 16
|
|
31
24
|
#define MM_SPAD_DST_NROWS 2
|
|
32
25
|
|
|
33
|
-
struct
|
|
26
|
+
struct htp_matmul_context {
|
|
34
27
|
const char * type;
|
|
35
|
-
|
|
36
|
-
void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy);
|
|
37
|
-
};
|
|
28
|
+
struct htp_ops_context * octx;
|
|
38
29
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
typedef struct {
|
|
44
|
-
HVX_Vector v[4];
|
|
45
|
-
} HVX_Vector_x4;
|
|
46
|
-
|
|
47
|
-
typedef struct {
|
|
48
|
-
HVX_Vector v[8];
|
|
49
|
-
} HVX_Vector_x8;
|
|
50
|
-
|
|
51
|
-
// vdelta control to replicate first 4x fp32 values across lanes
|
|
52
|
-
static const uint8_t __attribute__((aligned(128))) repl_4x_fp32[128] = {
|
|
53
|
-
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
|
|
54
|
-
0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
|
|
55
|
-
0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04,
|
|
56
|
-
0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
|
|
57
|
-
0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04,
|
|
58
|
-
0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
|
|
59
|
-
0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10,
|
|
60
|
-
};
|
|
30
|
+
void (*vec_dot_1x1)(const int n, float * restrict s0,
|
|
31
|
+
const void * restrict vx0,
|
|
32
|
+
const void * restrict vy0);
|
|
61
33
|
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
|
|
66
|
-
0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04,
|
|
67
|
-
0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
|
|
68
|
-
0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44,
|
|
69
|
-
0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
|
|
70
|
-
0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20,
|
|
71
|
-
};
|
|
34
|
+
void (*vec_dot_2x1)(const int n, float * restrict s0,
|
|
35
|
+
const void * restrict vx0, const void * restrict vx1,
|
|
36
|
+
const void * restrict vy0);
|
|
72
37
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
|
|
77
|
-
0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08,
|
|
78
|
-
0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08,
|
|
79
|
-
0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04,
|
|
80
|
-
0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10,
|
|
81
|
-
0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
|
82
|
-
};
|
|
38
|
+
void (*vec_dot_2x2)(const int n, float * restrict s0, float * restrict s1,
|
|
39
|
+
const void * restrict vx0, const void * restrict vx1,
|
|
40
|
+
const void * restrict vy0, const void * restrict vy1);
|
|
83
41
|
|
|
84
|
-
//
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04,
|
|
88
|
-
0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08,
|
|
89
|
-
0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02,
|
|
90
|
-
0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02,
|
|
91
|
-
0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10,
|
|
92
|
-
0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
|
93
|
-
};
|
|
42
|
+
// Precomputed values
|
|
43
|
+
uint32_t src0_nrows_per_thread;
|
|
44
|
+
uint32_t src1_nrows_per_thread;
|
|
94
45
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
|
100
|
-
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
|
101
|
-
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
|
102
|
-
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
|
103
|
-
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
|
104
|
-
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
|
46
|
+
struct fastdiv_values mm_div_ne12_ne1;
|
|
47
|
+
struct fastdiv_values mm_div_ne1;
|
|
48
|
+
struct fastdiv_values mm_div_r2;
|
|
49
|
+
struct fastdiv_values mm_div_r3;
|
|
105
50
|
};
|
|
106
51
|
|
|
107
52
|
// vdelta control to expand first 32 e8m0 values into 32 uint32 elements
|
|
@@ -129,10 +74,10 @@ static inline size_t q8x4x2_row_size(uint32_t ne) {
|
|
|
129
74
|
// ensures perfect alignment of quants and full row
|
|
130
75
|
const uint32_t qk = QK_Q8_0x4x2;
|
|
131
76
|
const uint32_t nb = (ne + qk - 1) / qk;
|
|
132
|
-
return
|
|
77
|
+
return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128);
|
|
133
78
|
}
|
|
134
79
|
|
|
135
|
-
static inline HVX_Vector_x8
|
|
80
|
+
static inline HVX_Vector_x8 hvx_vec_load_q4x4x8_full(const uint8_t * restrict ptr) {
|
|
136
81
|
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
|
|
137
82
|
|
|
138
83
|
HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
|
|
@@ -141,10 +86,11 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
|
|
|
141
86
|
HVX_Vector v6_7 = vptr[3]; // ...
|
|
142
87
|
|
|
143
88
|
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
|
89
|
+
const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
|
|
144
90
|
|
|
145
|
-
HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
|
|
146
|
-
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
|
|
147
|
-
HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F
|
|
91
|
+
HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements
|
|
92
|
+
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements
|
|
93
|
+
HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ...
|
|
148
94
|
HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4
|
|
149
95
|
HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F
|
|
150
96
|
HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4
|
|
@@ -152,21 +98,54 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
|
|
|
152
98
|
HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
|
|
153
99
|
|
|
154
100
|
// Convert uint4 to int4 (i.e. x - 8)
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
v7 = Q6_Vb_vsub_VbVb(v7, i8);
|
|
101
|
+
v0 = Q6_Vb_vsub_VbVb(v0, i8);
|
|
102
|
+
v1 = Q6_Vb_vsub_VbVb(v1, i8);
|
|
103
|
+
v2 = Q6_Vb_vsub_VbVb(v2, i8);
|
|
104
|
+
v3 = Q6_Vb_vsub_VbVb(v3, i8);
|
|
105
|
+
v4 = Q6_Vb_vsub_VbVb(v4, i8);
|
|
106
|
+
v5 = Q6_Vb_vsub_VbVb(v5, i8);
|
|
107
|
+
v6 = Q6_Vb_vsub_VbVb(v6, i8);
|
|
108
|
+
v7 = Q6_Vb_vsub_VbVb(v7, i8);
|
|
164
109
|
|
|
165
110
|
HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
|
|
166
111
|
return r;
|
|
167
112
|
}
|
|
168
113
|
|
|
169
|
-
static
|
|
114
|
+
static HVX_Vector_x8 hvx_vec_load_q4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) {
|
|
115
|
+
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
|
|
116
|
+
|
|
117
|
+
const uint32_t qk = QK_Q4_0x4x2; // 256
|
|
118
|
+
const uint32_t nb = n / qk;
|
|
119
|
+
const uint32_t nloe = n % qk;
|
|
120
|
+
|
|
121
|
+
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
|
122
|
+
const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
|
|
123
|
+
|
|
124
|
+
HVX_Vector_x8 r;
|
|
125
|
+
uint32_t i = 0;
|
|
126
|
+
|
|
127
|
+
#pragma unroll(2)
|
|
128
|
+
for (i=0; i < nb; i++) {
|
|
129
|
+
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
|
|
130
|
+
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements
|
|
131
|
+
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements
|
|
132
|
+
r.v[i*2+0] = Q6_Vb_vsub_VbVb(v0, i8);
|
|
133
|
+
r.v[i*2+1] = Q6_Vb_vsub_VbVb(v1, i8);
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
if (nloe) {
|
|
137
|
+
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
|
|
138
|
+
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements
|
|
139
|
+
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements
|
|
140
|
+
HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:...
|
|
141
|
+
r.v[i*2+0] = Q6_Vb_vsub_VbVb(Q6_V_lo_W(v0_1_p), i8);
|
|
142
|
+
r.v[i*2+1] = Q6_Vb_vsub_VbVb(Q6_V_hi_W(v0_1_p), i8);
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
return r;
|
|
146
|
+
}
|
|
147
|
+
|
|
148
|
+
static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_full(const uint8_t * restrict ptr) {
|
|
170
149
|
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
|
|
171
150
|
|
|
172
151
|
HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
|
|
@@ -175,6 +154,7 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr)
|
|
|
175
154
|
HVX_Vector v6_7 = vptr[3]; // ...
|
|
176
155
|
|
|
177
156
|
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
|
157
|
+
const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
|
|
178
158
|
|
|
179
159
|
HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
|
|
180
160
|
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
|
|
@@ -185,21 +165,54 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr)
|
|
|
185
165
|
HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F
|
|
186
166
|
HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
|
|
187
167
|
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
|
|
168
|
+
v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
|
|
169
|
+
v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
|
|
170
|
+
v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
|
|
171
|
+
v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
|
|
172
|
+
v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
|
|
173
|
+
v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
|
|
174
|
+
v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
|
|
175
|
+
v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
|
|
197
176
|
|
|
198
177
|
HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
|
|
199
178
|
return r;
|
|
200
179
|
}
|
|
201
180
|
|
|
202
|
-
static inline HVX_Vector_x8
|
|
181
|
+
static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) {
|
|
182
|
+
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
|
|
183
|
+
|
|
184
|
+
const uint32_t qk = QK_Q4_0x4x2; // 256
|
|
185
|
+
const uint32_t nb = n / qk;
|
|
186
|
+
const uint32_t nloe = n % qk;
|
|
187
|
+
|
|
188
|
+
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
|
189
|
+
const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
|
|
190
|
+
|
|
191
|
+
HVX_Vector_x8 r;
|
|
192
|
+
uint32_t i = 0;
|
|
193
|
+
|
|
194
|
+
#pragma unroll(2)
|
|
195
|
+
for (i=0; i < nb; i++) {
|
|
196
|
+
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
|
|
197
|
+
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements
|
|
198
|
+
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements
|
|
199
|
+
r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
|
|
200
|
+
r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
|
|
201
|
+
}
|
|
202
|
+
|
|
203
|
+
if (nloe) {
|
|
204
|
+
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
|
|
205
|
+
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements
|
|
206
|
+
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements
|
|
207
|
+
HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:...
|
|
208
|
+
r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0);
|
|
209
|
+
r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0);
|
|
210
|
+
}
|
|
211
|
+
|
|
212
|
+
return r;
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_full(const uint8_t * restrict ptr) {
|
|
203
216
|
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
|
|
204
217
|
|
|
205
218
|
HVX_Vector v0 = vptr[0]; // first 128 vals
|
|
@@ -215,44 +228,8 @@ static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) {
|
|
|
215
228
|
return r;
|
|
216
229
|
}
|
|
217
230
|
|
|
218
|
-
static inline
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
HVX_Vector v0 = vptr[0]; // first 64 vals
|
|
222
|
-
HVX_Vector v1 = vptr[1]; // second 64 vals
|
|
223
|
-
HVX_Vector v2 = vptr[2]; // third 64 vals
|
|
224
|
-
HVX_Vector v3 = vptr[3]; // forth 64 vals
|
|
225
|
-
|
|
226
|
-
HVX_Vector_x4 r = { v0, v1, v2, v3 };
|
|
227
|
-
return r;
|
|
228
|
-
}
|
|
229
|
-
|
|
230
|
-
static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict ptr) {
|
|
231
|
-
const HVX_VectorPair * restrict vptr = (const HVX_VectorPair *) ptr;
|
|
232
|
-
|
|
233
|
-
HVX_VectorPair v0 = vptr[0]; // first 64 vals
|
|
234
|
-
HVX_VectorPair v1 = vptr[1]; // second 64 vals
|
|
235
|
-
HVX_VectorPair v2 = vptr[2]; // third 64 vals
|
|
236
|
-
HVX_VectorPair v3 = vptr[3]; // forth 64 vals
|
|
237
|
-
|
|
238
|
-
HVX_Vector vq0_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v0), Q6_V_vzero());
|
|
239
|
-
HVX_Vector vq0_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v0), Q6_V_vzero());
|
|
240
|
-
HVX_Vector vq1_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v1), Q6_V_vzero());
|
|
241
|
-
HVX_Vector vq1_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v1), Q6_V_vzero());
|
|
242
|
-
HVX_Vector vq2_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v2), Q6_V_vzero());
|
|
243
|
-
HVX_Vector vq2_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v2), Q6_V_vzero());
|
|
244
|
-
HVX_Vector vq3_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v3), Q6_V_vzero());
|
|
245
|
-
HVX_Vector vq3_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v3), Q6_V_vzero());
|
|
246
|
-
|
|
247
|
-
HVX_Vector vh0 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq0_hi, vq0_lo));
|
|
248
|
-
HVX_Vector vh1 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq1_hi, vq1_lo));
|
|
249
|
-
HVX_Vector vh2 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq2_hi, vq2_lo));
|
|
250
|
-
HVX_Vector vh3 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq3_hi, vq3_lo));
|
|
251
|
-
|
|
252
|
-
// vcombine does a shuffle, use vdeal to undo
|
|
253
|
-
|
|
254
|
-
HVX_Vector_x4 r = { Q6_Vh_vdeal_Vh(vh0), Q6_Vh_vdeal_Vh(vh1), Q6_Vh_vdeal_Vh(vh2), Q6_Vh_vdeal_Vh(vh3) };
|
|
255
|
-
return r;
|
|
231
|
+
static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_partial(const uint8_t * restrict ptr, uint32_t nloe) {
|
|
232
|
+
return hvx_vec_load_q8x4x8_full(ptr);
|
|
256
233
|
}
|
|
257
234
|
|
|
258
235
|
// Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors).
|
|
@@ -262,14 +239,14 @@ static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict
|
|
|
262
239
|
// if() checks are optimized out at compile time -- make sure to pass N as a constexpr.
|
|
263
240
|
|
|
264
241
|
static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
|
|
265
|
-
HVX_Vector r0 =
|
|
266
|
-
HVX_Vector r1 =
|
|
267
|
-
HVX_Vector r2 =
|
|
268
|
-
HVX_Vector r3 =
|
|
269
|
-
HVX_Vector r4 =
|
|
270
|
-
HVX_Vector r5 =
|
|
271
|
-
HVX_Vector r6 =
|
|
272
|
-
HVX_Vector r7 =
|
|
242
|
+
HVX_Vector r0 = Q6_V_vzero();
|
|
243
|
+
HVX_Vector r1 = Q6_V_vzero();
|
|
244
|
+
HVX_Vector r2 = Q6_V_vzero();
|
|
245
|
+
HVX_Vector r3 = Q6_V_vzero();
|
|
246
|
+
HVX_Vector r4 = Q6_V_vzero();
|
|
247
|
+
HVX_Vector r5 = Q6_V_vzero();
|
|
248
|
+
HVX_Vector r6 = Q6_V_vzero();
|
|
249
|
+
HVX_Vector r7 = Q6_V_vzero();
|
|
273
250
|
|
|
274
251
|
HVX_VectorPair p3;
|
|
275
252
|
HVX_VectorPair p2;
|
|
@@ -308,40 +285,67 @@ static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, uns
|
|
|
308
285
|
}
|
|
309
286
|
|
|
310
287
|
static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) {
|
|
311
|
-
|
|
288
|
+
HVX_Vector r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]);
|
|
289
|
+
HVX_Vector r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]);
|
|
290
|
+
HVX_Vector r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]);
|
|
291
|
+
HVX_Vector r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]);
|
|
292
|
+
HVX_Vector r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]);
|
|
293
|
+
HVX_Vector r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]);
|
|
294
|
+
HVX_Vector r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]);
|
|
295
|
+
HVX_Vector r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]);
|
|
296
|
+
|
|
297
|
+
HVX_VectorPair p0 = Q6_W_vdeal_VVR(r1, r0, -4);
|
|
298
|
+
HVX_VectorPair p1 = Q6_W_vdeal_VVR(r3, r2, -4);
|
|
299
|
+
HVX_VectorPair p2 = Q6_W_vdeal_VVR(r5, r4, -4);
|
|
300
|
+
HVX_VectorPair p3 = Q6_W_vdeal_VVR(r7, r6, -4);
|
|
301
|
+
|
|
302
|
+
r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
|
|
303
|
+
r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1));
|
|
304
|
+
r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2));
|
|
305
|
+
r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3));
|
|
306
|
+
|
|
307
|
+
p0 = Q6_W_vdeal_VVR(r1, r0, -4);
|
|
308
|
+
p1 = Q6_W_vdeal_VVR(r3, r2, -4);
|
|
309
|
+
|
|
310
|
+
r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
|
|
311
|
+
r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1));
|
|
312
|
+
|
|
313
|
+
p0 = Q6_W_vdeal_VVR(r1, r0, -4);
|
|
314
|
+
r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
|
|
315
|
+
|
|
316
|
+
return r0;
|
|
312
317
|
}
|
|
313
318
|
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
return hvx_vec_rmpy_x8_n(x, y, 1024);
|
|
319
|
+
static inline HVX_Vector hvx_vec_rmpy_x8_partial(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
|
|
320
|
+
if (n >= 512)
|
|
321
|
+
return hvx_vec_rmpy_x8_full(x, y);
|
|
322
|
+
|
|
323
|
+
return hvx_vec_rmpy_x8_partial(x, y, 512);
|
|
320
324
|
}
|
|
321
325
|
|
|
322
|
-
static void
|
|
326
|
+
static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
|
|
323
327
|
assert(n % 32 == 0); // min sub-block size
|
|
324
|
-
assert((unsigned long)
|
|
325
|
-
assert((unsigned long)
|
|
328
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
329
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
326
330
|
|
|
327
331
|
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
328
332
|
|
|
329
|
-
const uint32_t x_dblk_size = 8 * 4 * 2;
|
|
330
|
-
const uint32_t x_qblk_size = qk / 2;
|
|
331
|
-
const uint32_t x_qrow_size = n / 2;
|
|
333
|
+
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
334
|
+
const uint32_t x_qblk_size = qk / 2; // int4
|
|
335
|
+
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
|
332
336
|
|
|
333
|
-
const uint32_t y_dblk_size = 8 * 4 * 2;
|
|
334
|
-
const uint32_t y_qblk_size = qk;
|
|
335
|
-
const uint32_t y_qrow_size = n;
|
|
337
|
+
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
338
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
339
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
336
340
|
|
|
337
|
-
const uint8_t * restrict r0_x_q = ((const uint8_t *)
|
|
338
|
-
const uint8_t * restrict r0_x_d = ((const uint8_t *)
|
|
341
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
|
|
342
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
|
|
339
343
|
|
|
340
|
-
const uint8_t * restrict y_q = ((const uint8_t *)
|
|
341
|
-
const uint8_t * restrict y_d = ((const uint8_t *)
|
|
344
|
+
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
|
345
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
|
342
346
|
|
|
343
|
-
// Row sum (
|
|
344
|
-
HVX_Vector r0_sum =
|
|
347
|
+
// Row sum (sf)
|
|
348
|
+
HVX_Vector r0_sum = Q6_V_vzero();
|
|
345
349
|
|
|
346
350
|
// Multiply and accumulate into int32.
|
|
347
351
|
// Compute combined scale (fp32).
|
|
@@ -352,79 +356,77 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void *
|
|
|
352
356
|
|
|
353
357
|
uint32_t i = 0;
|
|
354
358
|
for (; i < nb; i++) {
|
|
355
|
-
HVX_Vector_x8 vy_q =
|
|
356
|
-
HVX_Vector_x8 r0_q =
|
|
359
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
|
360
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
|
|
357
361
|
|
|
358
362
|
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
359
363
|
|
|
360
|
-
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d
|
|
364
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
361
365
|
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
362
366
|
|
|
363
367
|
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
364
368
|
|
|
365
369
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
366
370
|
|
|
367
|
-
r0_sum =
|
|
371
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
368
372
|
}
|
|
369
373
|
|
|
370
|
-
// Process leftovers
|
|
374
|
+
// Process leftovers
|
|
371
375
|
if (nloe) {
|
|
372
|
-
HVX_Vector_x8 vy_q =
|
|
373
|
-
HVX_Vector_x8 r0_q =
|
|
376
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
|
377
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
374
378
|
|
|
375
|
-
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(
|
|
379
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
|
376
380
|
|
|
377
|
-
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d
|
|
381
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
378
382
|
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
379
383
|
|
|
380
384
|
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
381
385
|
|
|
382
|
-
// Zero out unused
|
|
386
|
+
// Zero out unused elements
|
|
383
387
|
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
384
388
|
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
|
389
|
+
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
|
385
390
|
|
|
386
391
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
387
392
|
|
|
388
|
-
r0_sum =
|
|
393
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
389
394
|
}
|
|
390
395
|
|
|
391
|
-
|
|
392
|
-
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
|
|
396
|
+
r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
|
|
393
397
|
|
|
394
|
-
hvx_vec_store_u(
|
|
398
|
+
hvx_vec_store_u(s0, 4, r0_sum);
|
|
395
399
|
}
|
|
396
400
|
|
|
397
|
-
static void
|
|
398
|
-
|
|
399
|
-
const void * restrict
|
|
400
|
-
uint32_t vx_row_size,
|
|
401
|
-
const void * restrict vy) {
|
|
401
|
+
static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
402
|
+
const void * restrict vx0, const void * restrict vx1,
|
|
403
|
+
const void * restrict vy0) {
|
|
402
404
|
assert(n % 32 == 0); // min sub-block size
|
|
403
|
-
assert((unsigned long)
|
|
404
|
-
assert((unsigned long)
|
|
405
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
406
|
+
assert((unsigned long) vx1 % 128 == 0);
|
|
407
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
405
408
|
|
|
406
409
|
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
407
410
|
|
|
408
|
-
const uint32_t x_dblk_size = 8 * 4 * 2;
|
|
409
|
-
const uint32_t x_qblk_size = qk / 2;
|
|
410
|
-
const uint32_t x_qrow_size = n / 2;
|
|
411
|
-
|
|
412
|
-
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
413
|
-
const uint32_t y_qblk_size = qk; // int8
|
|
414
|
-
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
411
|
+
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
412
|
+
const uint32_t x_qblk_size = qk / 2; // int4
|
|
413
|
+
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
|
415
414
|
|
|
416
|
-
const
|
|
417
|
-
const
|
|
415
|
+
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
416
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
417
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
418
418
|
|
|
419
|
-
const uint8_t * restrict
|
|
420
|
-
const uint8_t * restrict
|
|
419
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
|
|
420
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
|
|
421
|
+
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
|
|
422
|
+
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
|
|
421
423
|
|
|
422
|
-
const uint8_t * restrict y_q = ((const uint8_t *)
|
|
423
|
-
const uint8_t * restrict y_d = ((const uint8_t *)
|
|
424
|
+
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
|
425
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
|
424
426
|
|
|
425
|
-
// Row sum (
|
|
426
|
-
HVX_Vector r0_sum =
|
|
427
|
-
HVX_Vector r1_sum =
|
|
427
|
+
// Row sum (sf)
|
|
428
|
+
HVX_Vector r0_sum = Q6_V_vzero();
|
|
429
|
+
HVX_Vector r1_sum = Q6_V_vzero();
|
|
428
430
|
|
|
429
431
|
// Multiply and accumulate into int32.
|
|
430
432
|
// Compute combined scale (fp32).
|
|
@@ -435,14 +437,14 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
|
|
|
435
437
|
|
|
436
438
|
uint32_t i = 0;
|
|
437
439
|
for (; i < nb; i++) {
|
|
438
|
-
HVX_Vector_x8 vy_q =
|
|
439
|
-
HVX_Vector_x8 r0_q =
|
|
440
|
-
HVX_Vector_x8 r1_q =
|
|
440
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
|
441
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
|
|
442
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);
|
|
441
443
|
|
|
442
444
|
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
443
445
|
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
|
|
444
446
|
|
|
445
|
-
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d
|
|
447
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
446
448
|
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
447
449
|
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
448
450
|
|
|
@@ -452,50 +454,178 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
|
|
|
452
454
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
453
455
|
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
454
456
|
|
|
455
|
-
r0_sum =
|
|
456
|
-
r1_sum =
|
|
457
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
458
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
457
459
|
}
|
|
458
460
|
|
|
459
|
-
// Process leftovers
|
|
461
|
+
// Process leftovers
|
|
460
462
|
if (nloe) {
|
|
461
|
-
HVX_Vector_x8 vy_q =
|
|
462
|
-
HVX_Vector_x8 r0_q =
|
|
463
|
-
HVX_Vector_x8 r1_q =
|
|
463
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
|
464
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
465
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
464
466
|
|
|
465
|
-
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(
|
|
466
|
-
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(
|
|
467
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
|
468
|
+
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
|
|
467
469
|
|
|
468
|
-
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d
|
|
470
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
469
471
|
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
470
472
|
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
471
473
|
|
|
472
474
|
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
473
475
|
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
|
474
476
|
|
|
475
|
-
// Zero out unused
|
|
477
|
+
// Zero out unused elements
|
|
476
478
|
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
477
479
|
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
|
478
480
|
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
|
|
481
|
+
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
|
482
|
+
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
|
|
479
483
|
|
|
480
484
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
481
485
|
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
482
486
|
|
|
483
|
-
r0_sum =
|
|
484
|
-
r1_sum =
|
|
487
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
488
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
485
489
|
}
|
|
486
490
|
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
+
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
|
|
492
|
+
hvx_vec_store_u(s0, 8, rsum);
|
|
493
|
+
}
|
|
494
|
+
|
|
495
|
+
static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
|
|
496
|
+
const void * restrict vx0, const void * restrict vx1,
|
|
497
|
+
const void * restrict vy0, const void * restrict vy1) {
|
|
498
|
+
assert(n % 32 == 0);
|
|
499
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
500
|
+
assert((unsigned long) vx1 % 128 == 0);
|
|
501
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
502
|
+
assert((unsigned long) vy1 % 128 == 0);
|
|
503
|
+
|
|
504
|
+
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
505
|
+
|
|
506
|
+
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
507
|
+
const uint32_t x_qblk_size = qk / 2; // int4
|
|
508
|
+
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
|
509
|
+
|
|
510
|
+
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
511
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
512
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
513
|
+
|
|
514
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
|
|
515
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
|
|
516
|
+
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
|
|
517
|
+
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
|
|
518
|
+
|
|
519
|
+
const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first
|
|
520
|
+
const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
|
|
521
|
+
const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first
|
|
522
|
+
const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
|
|
523
|
+
|
|
524
|
+
// Row sums (sf) - 4 accumulators for 2×2 tile
|
|
525
|
+
HVX_Vector r0_c0_sum = Q6_V_vzero();
|
|
526
|
+
HVX_Vector r0_c1_sum = Q6_V_vzero();
|
|
527
|
+
HVX_Vector r1_c0_sum = Q6_V_vzero();
|
|
528
|
+
HVX_Vector r1_c1_sum = Q6_V_vzero();
|
|
529
|
+
|
|
530
|
+
const uint32_t nb = n / qk; // num full blocks
|
|
531
|
+
const uint32_t nloe = n % qk; // num leftover elements
|
|
532
|
+
|
|
533
|
+
uint32_t i = 0;
|
|
534
|
+
for (; i < nb; i++) {
|
|
535
|
+
// Load src1 columns (reused across both src0 rows)
|
|
536
|
+
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
|
|
537
|
+
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
|
|
538
|
+
|
|
539
|
+
// Load src0 rows (reused across both src1 columns)
|
|
540
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
|
|
541
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);
|
|
542
|
+
|
|
543
|
+
// Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
|
|
544
|
+
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
|
|
545
|
+
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
|
|
546
|
+
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
|
|
547
|
+
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
|
|
548
|
+
|
|
549
|
+
// Load scales
|
|
550
|
+
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
|
|
551
|
+
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
|
|
552
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
553
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
554
|
+
|
|
555
|
+
// Compute combined scales
|
|
556
|
+
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
|
|
557
|
+
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
|
|
558
|
+
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
|
|
559
|
+
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
|
|
560
|
+
|
|
561
|
+
// Apply scales and accumulate
|
|
562
|
+
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
|
|
563
|
+
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
|
|
564
|
+
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
|
565
|
+
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
|
566
|
+
|
|
567
|
+
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
|
|
568
|
+
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
|
|
569
|
+
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
|
|
570
|
+
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
|
|
571
|
+
}
|
|
572
|
+
|
|
573
|
+
// Process leftovers
|
|
574
|
+
if (nloe) {
|
|
575
|
+
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
|
|
576
|
+
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
|
|
577
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
578
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
579
|
+
|
|
580
|
+
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
|
|
581
|
+
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
|
|
582
|
+
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
|
|
583
|
+
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
|
|
584
|
+
|
|
585
|
+
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
|
|
586
|
+
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
|
|
587
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
588
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
589
|
+
|
|
590
|
+
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
|
|
591
|
+
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
|
|
592
|
+
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
|
|
593
|
+
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
|
|
594
|
+
|
|
595
|
+
// Zero out unused scales
|
|
596
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
597
|
+
r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
|
|
598
|
+
r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
|
|
599
|
+
r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
|
|
600
|
+
r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
|
|
601
|
+
r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
|
|
602
|
+
r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
|
|
603
|
+
r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
|
|
604
|
+
r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
|
|
605
|
+
|
|
606
|
+
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
|
|
607
|
+
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
|
|
608
|
+
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
|
609
|
+
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
|
610
|
+
|
|
611
|
+
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
|
|
612
|
+
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
|
|
613
|
+
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
|
|
614
|
+
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
|
|
615
|
+
}
|
|
616
|
+
|
|
617
|
+
// Reduce and store results
|
|
618
|
+
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
|
|
619
|
+
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
|
|
491
620
|
|
|
492
|
-
hvx_vec_store_u(
|
|
621
|
+
hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0
|
|
622
|
+
hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1
|
|
493
623
|
}
|
|
494
624
|
|
|
495
|
-
static void
|
|
625
|
+
static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
|
|
496
626
|
assert(n % 32 == 0); // min sub-block size
|
|
497
|
-
assert((unsigned long)
|
|
498
|
-
assert((unsigned long)
|
|
627
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
628
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
499
629
|
|
|
500
630
|
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
501
631
|
|
|
@@ -507,14 +637,14 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
|
|
|
507
637
|
const uint32_t y_qblk_size = qk; // int8
|
|
508
638
|
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
509
639
|
|
|
510
|
-
const uint8_t * restrict r0_x_q = ((const uint8_t *)
|
|
511
|
-
const uint8_t * restrict r0_x_d = ((const uint8_t *)
|
|
640
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
|
|
641
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
|
|
512
642
|
|
|
513
|
-
const uint8_t * restrict y_q = ((const uint8_t *)
|
|
514
|
-
const uint8_t * restrict y_d = ((const uint8_t *)
|
|
643
|
+
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
|
644
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
|
515
645
|
|
|
516
|
-
// Row sum (
|
|
517
|
-
HVX_Vector r0_sum =
|
|
646
|
+
// Row sum (sf)
|
|
647
|
+
HVX_Vector r0_sum = Q6_V_vzero();
|
|
518
648
|
|
|
519
649
|
// Multiply and accumulate into int32.
|
|
520
650
|
// Compute combined scale (fp32).
|
|
@@ -525,79 +655,77 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
|
|
|
525
655
|
|
|
526
656
|
uint32_t i = 0;
|
|
527
657
|
for (; i < nb; i++) {
|
|
528
|
-
HVX_Vector_x8 vy_q =
|
|
529
|
-
HVX_Vector_x8 r0_q =
|
|
658
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
|
659
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
|
|
530
660
|
|
|
531
661
|
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
532
662
|
|
|
533
|
-
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d
|
|
663
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
534
664
|
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
535
665
|
|
|
536
666
|
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
537
667
|
|
|
538
668
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
539
669
|
|
|
540
|
-
r0_sum =
|
|
670
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
541
671
|
}
|
|
542
672
|
|
|
543
|
-
// Process leftovers
|
|
673
|
+
// Process leftovers
|
|
544
674
|
if (nloe) {
|
|
545
|
-
HVX_Vector_x8 vy_q =
|
|
546
|
-
HVX_Vector_x8 r0_q =
|
|
675
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
|
676
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
547
677
|
|
|
548
|
-
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(
|
|
678
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
|
549
679
|
|
|
550
|
-
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d
|
|
680
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
551
681
|
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
552
682
|
|
|
553
683
|
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
554
684
|
|
|
555
|
-
// Zero out unused
|
|
685
|
+
// Zero out unused elements
|
|
556
686
|
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
557
687
|
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
|
688
|
+
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
|
558
689
|
|
|
559
690
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
560
691
|
|
|
561
|
-
r0_sum =
|
|
692
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
562
693
|
}
|
|
563
694
|
|
|
564
|
-
|
|
565
|
-
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
|
|
695
|
+
r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
|
|
566
696
|
|
|
567
|
-
hvx_vec_store_u(
|
|
697
|
+
hvx_vec_store_u(s0, 4, r0_sum);
|
|
568
698
|
}
|
|
569
699
|
|
|
570
|
-
static void
|
|
571
|
-
|
|
572
|
-
const void * restrict
|
|
573
|
-
uint32_t vx_row_size,
|
|
574
|
-
const void * restrict vy) {
|
|
700
|
+
static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
701
|
+
const void * restrict vx0, const void * restrict vx1,
|
|
702
|
+
const void * restrict vy0) {
|
|
575
703
|
assert(n % 32 == 0); // min sub-block size
|
|
576
|
-
assert((unsigned long)
|
|
577
|
-
assert((unsigned long)
|
|
704
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
705
|
+
assert((unsigned long) vx1 % 128 == 0);
|
|
706
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
578
707
|
|
|
579
708
|
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
580
709
|
|
|
581
|
-
const uint32_t x_dblk_size = 8 * 4 * 2;
|
|
582
|
-
const uint32_t x_qblk_size = qk;
|
|
583
|
-
const uint32_t x_qrow_size = n;
|
|
584
|
-
|
|
585
|
-
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
586
|
-
const uint32_t y_qblk_size = qk; // int8
|
|
587
|
-
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
710
|
+
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
711
|
+
const uint32_t x_qblk_size = qk; // int8
|
|
712
|
+
const uint32_t x_qrow_size = n; // int8 (not padded)
|
|
588
713
|
|
|
589
|
-
const
|
|
590
|
-
const
|
|
714
|
+
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
715
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
716
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
591
717
|
|
|
592
|
-
const uint8_t * restrict
|
|
593
|
-
const uint8_t * restrict
|
|
718
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
|
|
719
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
|
|
720
|
+
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
|
|
721
|
+
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
|
|
594
722
|
|
|
595
|
-
const uint8_t * restrict y_q = ((const uint8_t *)
|
|
596
|
-
const uint8_t * restrict y_d = ((const uint8_t *)
|
|
723
|
+
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
|
724
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
|
597
725
|
|
|
598
726
|
// Row sum (qf32)
|
|
599
|
-
HVX_Vector r0_sum =
|
|
600
|
-
HVX_Vector r1_sum =
|
|
727
|
+
HVX_Vector r0_sum = Q6_V_vzero();
|
|
728
|
+
HVX_Vector r1_sum = Q6_V_vzero();
|
|
601
729
|
|
|
602
730
|
// Multiply and accumulate into int32.
|
|
603
731
|
// Compute combined scale (fp32).
|
|
@@ -608,14 +736,14 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
|
|
|
608
736
|
|
|
609
737
|
uint32_t i = 0;
|
|
610
738
|
for (; i < nb; i++) {
|
|
611
|
-
HVX_Vector_x8 vy_q =
|
|
612
|
-
HVX_Vector_x8 r0_q =
|
|
613
|
-
HVX_Vector_x8 r1_q =
|
|
739
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
|
740
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
|
|
741
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);
|
|
614
742
|
|
|
615
743
|
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
616
744
|
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
|
|
617
745
|
|
|
618
|
-
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d
|
|
746
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
619
747
|
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
620
748
|
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
621
749
|
|
|
@@ -625,18 +753,18 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
|
|
|
625
753
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
626
754
|
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
627
755
|
|
|
628
|
-
r0_sum =
|
|
629
|
-
r1_sum =
|
|
756
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
757
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
630
758
|
}
|
|
631
759
|
|
|
632
|
-
// Process leftovers
|
|
760
|
+
// Process leftovers
|
|
633
761
|
if (nloe) {
|
|
634
|
-
HVX_Vector_x8 vy_q =
|
|
635
|
-
HVX_Vector_x8 r0_q =
|
|
636
|
-
HVX_Vector_x8 r1_q =
|
|
762
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
|
763
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
764
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
637
765
|
|
|
638
|
-
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(
|
|
639
|
-
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(
|
|
766
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
|
767
|
+
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
|
|
640
768
|
|
|
641
769
|
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
642
770
|
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
@@ -645,33 +773,158 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
|
|
|
645
773
|
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
646
774
|
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
|
647
775
|
|
|
648
|
-
// Zero out unused
|
|
776
|
+
// Zero out unused elements
|
|
649
777
|
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
650
778
|
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
|
651
779
|
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
|
|
780
|
+
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
|
781
|
+
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
|
|
652
782
|
|
|
653
783
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
654
784
|
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
655
785
|
|
|
656
|
-
r0_sum =
|
|
657
|
-
r1_sum =
|
|
786
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
787
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
658
788
|
}
|
|
659
789
|
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
790
|
+
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
|
|
791
|
+
hvx_vec_store_u(s0, 8, rsum);
|
|
792
|
+
}
|
|
793
|
+
|
|
794
|
+
static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
|
|
795
|
+
const void * restrict vx0, const void * restrict vx1,
|
|
796
|
+
const void * restrict vy0, const void * restrict vy1) {
|
|
797
|
+
assert(n % 32 == 0);
|
|
798
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
799
|
+
assert((unsigned long) vx1 % 128 == 0);
|
|
800
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
801
|
+
assert((unsigned long) vy1 % 128 == 0);
|
|
802
|
+
|
|
803
|
+
const uint32_t qk = QK_Q8_0x4x2 * 4;
|
|
804
|
+
|
|
805
|
+
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
806
|
+
const uint32_t x_qblk_size = qk; // int8
|
|
807
|
+
const uint32_t x_qrow_size = n; // int8 (not padded)
|
|
808
|
+
|
|
809
|
+
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
810
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
811
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
812
|
+
|
|
813
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
|
|
814
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
|
|
815
|
+
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
|
|
816
|
+
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
|
|
817
|
+
|
|
818
|
+
const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first
|
|
819
|
+
const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
|
|
820
|
+
const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first
|
|
821
|
+
const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
|
|
822
|
+
|
|
823
|
+
// Row sums (sf) - 4 accumulators for 2×2 tile
|
|
824
|
+
HVX_Vector r0_c0_sum = Q6_V_vzero();
|
|
825
|
+
HVX_Vector r0_c1_sum = Q6_V_vzero();
|
|
826
|
+
HVX_Vector r1_c0_sum = Q6_V_vzero();
|
|
827
|
+
HVX_Vector r1_c1_sum = Q6_V_vzero();
|
|
828
|
+
|
|
829
|
+
const uint32_t nb = n / qk; // num full blocks
|
|
830
|
+
const uint32_t nloe = n % qk; // num leftover elements
|
|
831
|
+
|
|
832
|
+
uint32_t i = 0;
|
|
833
|
+
for (; i < nb; i++) {
|
|
834
|
+
// Load src1 columns (reused across both src0 rows)
|
|
835
|
+
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
|
|
836
|
+
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
|
|
837
|
+
|
|
838
|
+
// Load src0 rows (reused across both src1 columns)
|
|
839
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
|
|
840
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);
|
|
841
|
+
|
|
842
|
+
// Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
|
|
843
|
+
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
|
|
844
|
+
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
|
|
845
|
+
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
|
|
846
|
+
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
|
|
847
|
+
|
|
848
|
+
// Load scales
|
|
849
|
+
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
|
|
850
|
+
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
|
|
851
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
852
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
853
|
+
|
|
854
|
+
// Compute combined scales
|
|
855
|
+
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
|
|
856
|
+
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
|
|
857
|
+
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
|
|
858
|
+
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
|
|
859
|
+
|
|
860
|
+
// Apply scales and accumulate
|
|
861
|
+
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
|
|
862
|
+
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
|
|
863
|
+
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
|
864
|
+
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
|
865
|
+
|
|
866
|
+
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
|
|
867
|
+
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
|
|
868
|
+
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
|
|
869
|
+
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
|
|
870
|
+
}
|
|
664
871
|
|
|
665
|
-
|
|
872
|
+
// Process leftovers
|
|
873
|
+
if (nloe) {
|
|
874
|
+
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
|
|
875
|
+
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
|
|
876
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
877
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
878
|
+
|
|
879
|
+
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
|
|
880
|
+
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
|
|
881
|
+
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
|
|
882
|
+
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
|
|
883
|
+
|
|
884
|
+
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
|
|
885
|
+
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
|
|
886
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
887
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
888
|
+
|
|
889
|
+
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
|
|
890
|
+
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
|
|
891
|
+
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
|
|
892
|
+
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
|
|
893
|
+
|
|
894
|
+
// Zero out unused elements
|
|
895
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
896
|
+
r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
|
|
897
|
+
r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
|
|
898
|
+
r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
|
|
899
|
+
r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
|
|
900
|
+
r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
|
|
901
|
+
r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
|
|
902
|
+
r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
|
|
903
|
+
r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
|
|
904
|
+
|
|
905
|
+
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
|
|
906
|
+
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
|
|
907
|
+
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
|
908
|
+
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
|
909
|
+
|
|
910
|
+
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
|
|
911
|
+
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
|
|
912
|
+
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
|
|
913
|
+
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
|
|
914
|
+
}
|
|
915
|
+
|
|
916
|
+
// Reduce and store results
|
|
917
|
+
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
|
|
918
|
+
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
|
|
919
|
+
|
|
920
|
+
hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
|
|
921
|
+
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
|
|
666
922
|
}
|
|
667
923
|
|
|
668
|
-
static void
|
|
669
|
-
float * restrict s,
|
|
670
|
-
const void * restrict vx,
|
|
671
|
-
const void * restrict vy) {
|
|
924
|
+
static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
|
|
672
925
|
assert(n % 32 == 0); // min sub-block size
|
|
673
|
-
assert((unsigned long)
|
|
674
|
-
assert((unsigned long)
|
|
926
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
927
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
675
928
|
|
|
676
929
|
const uint32_t qk = QK_MXFP4x4x2 * 4;
|
|
677
930
|
|
|
@@ -683,14 +936,14 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
|
|
|
683
936
|
const uint32_t y_qblk_size = qk; // int8
|
|
684
937
|
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
685
938
|
|
|
686
|
-
const uint8_t * restrict r0_x_q = ((const uint8_t *)
|
|
687
|
-
const uint8_t * restrict r0_x_d = ((const uint8_t *)
|
|
939
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
|
|
940
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
|
|
688
941
|
|
|
689
|
-
const uint8_t * restrict y_q = ((const uint8_t *)
|
|
690
|
-
const uint8_t * restrict y_d = ((const uint8_t *)
|
|
942
|
+
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
|
943
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
|
691
944
|
|
|
692
|
-
// Row sum (
|
|
693
|
-
HVX_Vector r0_sum =
|
|
945
|
+
// Row sum (sf)
|
|
946
|
+
HVX_Vector r0_sum = Q6_V_vzero();
|
|
694
947
|
|
|
695
948
|
// Multiply and accumulate into int32.
|
|
696
949
|
// Compute combined scale (fp32).
|
|
@@ -701,8 +954,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
|
|
|
701
954
|
|
|
702
955
|
uint32_t i = 0;
|
|
703
956
|
for (; i < nb; i++) {
|
|
704
|
-
HVX_Vector_x8 vy_q =
|
|
705
|
-
HVX_Vector_x8 r0_q =
|
|
957
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size);
|
|
958
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
|
|
706
959
|
|
|
707
960
|
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
708
961
|
|
|
@@ -728,17 +981,17 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
|
|
|
728
981
|
|
|
729
982
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
730
983
|
|
|
731
|
-
r0_sum =
|
|
984
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
732
985
|
}
|
|
733
986
|
|
|
734
987
|
// Process leftovers
|
|
735
988
|
if (nloe) {
|
|
736
|
-
HVX_Vector_x8 vy_q =
|
|
737
|
-
HVX_Vector_x8 r0_q =
|
|
989
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe);
|
|
990
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
738
991
|
|
|
739
|
-
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(
|
|
992
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
|
740
993
|
|
|
741
|
-
HVX_Vector vy_d = *(const HVX_UVector *) (y_d
|
|
994
|
+
HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
|
|
742
995
|
HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
|
|
743
996
|
|
|
744
997
|
// Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
|
|
@@ -761,62 +1014,60 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
|
|
|
761
1014
|
// Zero-out unused scales
|
|
762
1015
|
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
763
1016
|
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
|
1017
|
+
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
|
764
1018
|
|
|
765
1019
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
766
1020
|
|
|
767
|
-
r0_sum =
|
|
1021
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
768
1022
|
}
|
|
769
1023
|
|
|
770
|
-
|
|
771
|
-
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
|
|
1024
|
+
r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
|
|
772
1025
|
|
|
773
|
-
hvx_vec_store_u(
|
|
1026
|
+
hvx_vec_store_u(s0, 4, r0_sum);
|
|
774
1027
|
}
|
|
775
1028
|
|
|
776
|
-
static void
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
uint32_t vx_row_size,
|
|
780
|
-
const void * restrict vy) {
|
|
1029
|
+
static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
1030
|
+
const void * restrict vx0, const void * restrict vx1,
|
|
1031
|
+
const void * restrict vy0) {
|
|
781
1032
|
assert(n % 32 == 0); // min sub-block size
|
|
782
|
-
assert((unsigned long)
|
|
783
|
-
assert((unsigned long)
|
|
1033
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
1034
|
+
assert((unsigned long) vx1 % 128 == 0);
|
|
1035
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
784
1036
|
|
|
785
1037
|
const uint32_t qk = QK_MXFP4x4x2 * 4;
|
|
786
1038
|
|
|
787
|
-
const uint32_t x_dblk_size = 8 * 4 * 1;
|
|
788
|
-
const uint32_t x_qblk_size = qk / 2;
|
|
789
|
-
const uint32_t x_qrow_size = n / 2;
|
|
1039
|
+
const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
|
|
1040
|
+
const uint32_t x_qblk_size = qk / 2; // fp4
|
|
1041
|
+
const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
|
|
790
1042
|
|
|
791
|
-
const uint32_t y_dblk_size = 8 * 4 * 2;
|
|
792
|
-
const uint32_t y_qblk_size = qk;
|
|
793
|
-
const uint32_t y_qrow_size = n;
|
|
1043
|
+
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
1044
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
1045
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
794
1046
|
|
|
795
|
-
const uint8_t * restrict r0_x_q = ((const uint8_t *)
|
|
796
|
-
const uint8_t * restrict r0_x_d = ((const uint8_t *)
|
|
1047
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
|
|
1048
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
|
|
1049
|
+
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
|
|
1050
|
+
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
|
|
797
1051
|
|
|
798
|
-
const uint8_t * restrict
|
|
799
|
-
const uint8_t * restrict
|
|
1052
|
+
const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first
|
|
1053
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
|
|
800
1054
|
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
// Row sum (qf32)
|
|
805
|
-
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
|
|
806
|
-
HVX_Vector r1_sum = Q6_V_vsplat_R(0);
|
|
1055
|
+
// Row sum (sf)
|
|
1056
|
+
HVX_Vector r0_sum = Q6_V_vzero();
|
|
1057
|
+
HVX_Vector r1_sum = Q6_V_vzero();
|
|
807
1058
|
|
|
808
1059
|
// Multiply and accumulate into int32.
|
|
809
1060
|
// Compute combined scale (fp32).
|
|
810
|
-
// Apply scale to acc and accumulate into the row sum (
|
|
1061
|
+
// Apply scale to acc and accumulate into the row sum (f32).
|
|
811
1062
|
|
|
812
1063
|
const uint32_t nb = n / qk; // num full blocks
|
|
813
1064
|
int32_t nloe = n % qk; // num leftover elemements (must be signed)
|
|
814
1065
|
|
|
815
1066
|
uint32_t i = 0;
|
|
816
1067
|
for (; i < nb; i++) {
|
|
817
|
-
HVX_Vector_x8 vy_q =
|
|
818
|
-
HVX_Vector_x8 r0_q =
|
|
819
|
-
HVX_Vector_x8 r1_q =
|
|
1068
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size);
|
|
1069
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
|
|
1070
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size);
|
|
820
1071
|
|
|
821
1072
|
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
822
1073
|
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
|
|
@@ -849,20 +1100,20 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
|
|
|
849
1100
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
850
1101
|
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
851
1102
|
|
|
852
|
-
r0_sum =
|
|
853
|
-
r1_sum =
|
|
1103
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
1104
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
854
1105
|
}
|
|
855
1106
|
|
|
856
1107
|
// Process leftovers
|
|
857
1108
|
if (nloe) {
|
|
858
|
-
HVX_Vector_x8 vy_q =
|
|
859
|
-
HVX_Vector_x8 r0_q =
|
|
860
|
-
HVX_Vector_x8 r1_q =
|
|
1109
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe);
|
|
1110
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
1111
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
861
1112
|
|
|
862
1113
|
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
863
1114
|
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
|
|
864
1115
|
|
|
865
|
-
HVX_Vector vy_d = *(const HVX_UVector *) (y_d
|
|
1116
|
+
HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
|
|
866
1117
|
HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
|
|
867
1118
|
HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
|
|
868
1119
|
|
|
@@ -887,111 +1138,326 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
|
|
|
887
1138
|
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
|
|
888
1139
|
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
|
|
889
1140
|
|
|
890
|
-
// Zero-out unused
|
|
1141
|
+
// Zero-out unused values
|
|
891
1142
|
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
892
1143
|
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
|
893
1144
|
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
|
|
1145
|
+
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
|
1146
|
+
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
|
|
894
1147
|
|
|
895
1148
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
896
1149
|
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
897
1150
|
|
|
898
|
-
r0_sum =
|
|
899
|
-
r1_sum =
|
|
1151
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
1152
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
900
1153
|
}
|
|
901
1154
|
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
1155
|
+
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
|
|
1156
|
+
hvx_vec_store_u(s0, 8, rsum);
|
|
1157
|
+
}
|
|
1158
|
+
|
|
1159
|
+
static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
|
|
1160
|
+
const void * restrict vx0, const void * restrict vx1,
|
|
1161
|
+
const void * restrict vy0, const void * restrict vy1) {
|
|
1162
|
+
assert(n % 32 == 0);
|
|
1163
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
1164
|
+
assert((unsigned long) vx1 % 128 == 0);
|
|
1165
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
1166
|
+
assert((unsigned long) vy1 % 128 == 0);
|
|
1167
|
+
|
|
1168
|
+
const uint32_t qk = QK_MXFP4x4x2 * 4;
|
|
1169
|
+
|
|
1170
|
+
const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
|
|
1171
|
+
const uint32_t x_qblk_size = qk / 2; // fp4
|
|
1172
|
+
const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
|
|
1173
|
+
|
|
1174
|
+
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
1175
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
1176
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
1177
|
+
|
|
1178
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
|
|
1179
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
|
|
1180
|
+
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
|
|
1181
|
+
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
|
|
1182
|
+
|
|
1183
|
+
const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first
|
|
1184
|
+
const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
|
|
1185
|
+
const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first
|
|
1186
|
+
const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
|
|
1187
|
+
|
|
1188
|
+
// Row sums (sf) - 4 accumulators for 2×2 tile
|
|
1189
|
+
HVX_Vector r0_c0_sum = Q6_V_vzero();
|
|
1190
|
+
HVX_Vector r0_c1_sum = Q6_V_vzero();
|
|
1191
|
+
HVX_Vector r1_c0_sum = Q6_V_vzero();
|
|
1192
|
+
HVX_Vector r1_c1_sum = Q6_V_vzero();
|
|
1193
|
+
|
|
1194
|
+
const uint32_t nb = n / qk; // num full blocks
|
|
1195
|
+
const uint32_t nloe = n % qk; // num leftover elements
|
|
1196
|
+
|
|
1197
|
+
uint32_t i = 0;
|
|
1198
|
+
for (; i < nb; i++) {
|
|
1199
|
+
// Load src1 columns (reused across both src0 rows)
|
|
1200
|
+
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
|
|
1201
|
+
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
|
|
1202
|
+
|
|
1203
|
+
// Load src0 rows (reused across both src1 columns)
|
|
1204
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
|
|
1205
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size);
|
|
1206
|
+
|
|
1207
|
+
// Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
|
|
1208
|
+
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
|
|
1209
|
+
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
|
|
1210
|
+
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
|
|
1211
|
+
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
|
|
1212
|
+
|
|
1213
|
+
// Load scales
|
|
1214
|
+
HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size);
|
|
1215
|
+
HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size);
|
|
1216
|
+
HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
|
|
1217
|
+
HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
|
|
1218
|
+
|
|
1219
|
+
// Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
|
|
1220
|
+
HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
|
|
1221
|
+
vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));
|
|
1222
|
+
vy0_d = Q6_Vsf_equals_Vqf32(vy0_d);
|
|
1223
|
+
vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));
|
|
1224
|
+
vy1_d = Q6_Vsf_equals_Vqf32(vy1_d);
|
|
1225
|
+
|
|
1226
|
+
// Convert rX_d scales from e8m0 to fp32
|
|
1227
|
+
// Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
|
|
1228
|
+
// Left shift with zero fill to create FP32
|
|
1229
|
+
// FIXME: might need to handle zero as a special case (see ggml-cpu code)
|
|
1230
|
+
HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
|
|
1231
|
+
HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
|
|
1232
|
+
r0_d = Q6_V_vdelta_VV(r0_d, expand);
|
|
1233
|
+
r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
|
|
1234
|
+
r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
|
|
1235
|
+
r1_d = Q6_V_vdelta_VV(r1_d, expand);
|
|
1236
|
+
r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
|
|
1237
|
+
r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
|
|
1238
|
+
|
|
1239
|
+
// Compute combined scales
|
|
1240
|
+
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));
|
|
1241
|
+
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));
|
|
1242
|
+
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));
|
|
1243
|
+
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));
|
|
1244
|
+
|
|
1245
|
+
// Apply scales and accumulate
|
|
1246
|
+
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
|
|
1247
|
+
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
|
|
1248
|
+
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
|
1249
|
+
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
|
1250
|
+
|
|
1251
|
+
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
|
|
1252
|
+
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
|
|
1253
|
+
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
|
|
1254
|
+
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
|
|
1255
|
+
}
|
|
1256
|
+
|
|
1257
|
+
// Process leftovers
|
|
1258
|
+
if (nloe) {
|
|
1259
|
+
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial( y0_q + i * y_qblk_size, nloe);
|
|
1260
|
+
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial( y1_q + i * y_qblk_size, nloe);
|
|
1261
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
1262
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
1263
|
+
|
|
1264
|
+
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
|
|
1265
|
+
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
|
|
1266
|
+
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
|
|
1267
|
+
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
|
|
1268
|
+
|
|
1269
|
+
HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size);
|
|
1270
|
+
HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size);
|
|
1271
|
+
HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
|
|
1272
|
+
HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
|
|
906
1273
|
|
|
907
|
-
|
|
1274
|
+
// Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
|
|
1275
|
+
HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
|
|
1276
|
+
vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));
|
|
1277
|
+
vy0_d = Q6_Vsf_equals_Vqf32(vy0_d);
|
|
1278
|
+
vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));
|
|
1279
|
+
vy1_d = Q6_Vsf_equals_Vqf32(vy1_d);
|
|
1280
|
+
|
|
1281
|
+
// Convert rX_d scales from e8m0 to fp32
|
|
1282
|
+
// Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
|
|
1283
|
+
// Left shift with zero fill to create FP32
|
|
1284
|
+
// FIXME: might need to handle zero as a special case (see ggml-cpu code)
|
|
1285
|
+
HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
|
|
1286
|
+
HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
|
|
1287
|
+
r0_d = Q6_V_vdelta_VV(r0_d, expand);
|
|
1288
|
+
r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
|
|
1289
|
+
r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
|
|
1290
|
+
r1_d = Q6_V_vdelta_VV(r1_d, expand);
|
|
1291
|
+
r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
|
|
1292
|
+
r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
|
|
1293
|
+
|
|
1294
|
+
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));
|
|
1295
|
+
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));
|
|
1296
|
+
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));
|
|
1297
|
+
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));
|
|
1298
|
+
|
|
1299
|
+
// Zero out unused scales
|
|
1300
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
1301
|
+
r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
|
|
1302
|
+
r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
|
|
1303
|
+
r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
|
|
1304
|
+
r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
|
|
1305
|
+
r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
|
|
1306
|
+
r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
|
|
1307
|
+
r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
|
|
1308
|
+
r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
|
|
1309
|
+
|
|
1310
|
+
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
|
|
1311
|
+
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
|
|
1312
|
+
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
|
1313
|
+
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
|
1314
|
+
|
|
1315
|
+
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
|
|
1316
|
+
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
|
|
1317
|
+
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
|
|
1318
|
+
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
|
|
1319
|
+
}
|
|
1320
|
+
|
|
1321
|
+
// Reduce and store results
|
|
1322
|
+
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
|
|
1323
|
+
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
|
|
1324
|
+
|
|
1325
|
+
hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
|
|
1326
|
+
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
|
|
908
1327
|
}
|
|
909
1328
|
|
|
910
|
-
static void
|
|
1329
|
+
static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
|
911
1330
|
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
|
|
912
1331
|
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
|
|
913
1332
|
|
|
914
1333
|
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
|
915
1334
|
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
|
916
1335
|
|
|
917
|
-
|
|
1336
|
+
HVX_VectorPair rsum_p = Q6_W_vzero();
|
|
918
1337
|
|
|
919
1338
|
uint32_t i = 0;
|
|
920
1339
|
|
|
921
1340
|
#pragma unroll(4)
|
|
922
1341
|
for (i = 0; i < nvec; i++) {
|
|
923
|
-
|
|
924
|
-
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
|
1342
|
+
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x[i], y[i]);
|
|
925
1343
|
}
|
|
926
1344
|
|
|
927
1345
|
if (nloe) {
|
|
928
1346
|
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
|
929
1347
|
HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
|
|
930
1348
|
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
|
|
931
|
-
|
|
932
|
-
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
|
933
|
-
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
|
1349
|
+
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
|
|
934
1350
|
}
|
|
935
1351
|
|
|
936
|
-
rsum = Q6_Vsf_equals_Vqf32(
|
|
937
|
-
hvx_vec_store_u(
|
|
1352
|
+
HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));
|
|
1353
|
+
hvx_vec_store_u(s, 4, hvx_vec_reduce_sum_f32(rsum));
|
|
938
1354
|
}
|
|
939
1355
|
|
|
940
|
-
static void
|
|
941
|
-
|
|
942
|
-
const void * restrict
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
const HVX_Vector * restrict
|
|
946
|
-
const HVX_Vector * restrict x1 = (const HVX_Vector *) ((const uint8_t *) vx + vx_row_size);
|
|
947
|
-
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
|
|
1356
|
+
static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0,
|
|
1357
|
+
const void * restrict vx0, const void * restrict vx1,
|
|
1358
|
+
const void * restrict vy0) {
|
|
1359
|
+
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
|
|
1360
|
+
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
|
|
1361
|
+
const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
|
|
948
1362
|
|
|
949
1363
|
uint32_t nvec = n / VLEN_FP16;
|
|
950
1364
|
uint32_t nloe = n % VLEN_FP16;
|
|
951
1365
|
|
|
952
|
-
|
|
953
|
-
|
|
1366
|
+
HVX_VectorPair rsum0_p = Q6_W_vzero();
|
|
1367
|
+
HVX_VectorPair rsum1_p = Q6_W_vzero();
|
|
954
1368
|
|
|
955
1369
|
uint32_t i = 0;
|
|
956
1370
|
|
|
957
1371
|
#pragma unroll(2)
|
|
958
1372
|
for (i = 0; i < nvec; i++) {
|
|
959
1373
|
HVX_Vector y_hf = y[i];
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
|
|
964
|
-
rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
|
|
1374
|
+
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0[i], y_hf);
|
|
1375
|
+
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1[i], y_hf);
|
|
965
1376
|
}
|
|
966
1377
|
|
|
967
1378
|
if (nloe) {
|
|
968
1379
|
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
|
1380
|
+
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
|
|
969
1381
|
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]);
|
|
970
1382
|
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]);
|
|
971
|
-
|
|
1383
|
+
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
|
|
1384
|
+
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
|
|
1385
|
+
}
|
|
1386
|
+
|
|
1387
|
+
HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));
|
|
1388
|
+
HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));
|
|
1389
|
+
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
|
|
1390
|
+
hvx_vec_store_u(s0, 8, rsum);
|
|
1391
|
+
}
|
|
972
1392
|
|
|
973
|
-
|
|
974
|
-
|
|
1393
|
+
static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * restrict s1,
|
|
1394
|
+
const void * restrict vx0, const void * restrict vx1,
|
|
1395
|
+
const void * restrict vy0, const void * restrict vy1) {
|
|
1396
|
+
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
|
|
1397
|
+
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
|
|
1398
|
+
const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
|
|
1399
|
+
const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
|
|
1400
|
+
|
|
1401
|
+
uint32_t nvec = n / VLEN_FP16;
|
|
1402
|
+
uint32_t nloe = n % VLEN_FP16;
|
|
975
1403
|
|
|
976
|
-
|
|
977
|
-
|
|
1404
|
+
// Row sums (sf) - 4 accumulators for 2×2 tile
|
|
1405
|
+
HVX_VectorPair r0_c0_sum_p = Q6_W_vzero();
|
|
1406
|
+
HVX_VectorPair r0_c1_sum_p = Q6_W_vzero();
|
|
1407
|
+
HVX_VectorPair r1_c0_sum_p = Q6_W_vzero();
|
|
1408
|
+
HVX_VectorPair r1_c1_sum_p = Q6_W_vzero();
|
|
1409
|
+
|
|
1410
|
+
uint32_t i = 0;
|
|
1411
|
+
|
|
1412
|
+
#pragma unroll(2)
|
|
1413
|
+
for (i = 0; i < nvec; i++) {
|
|
1414
|
+
HVX_Vector r0_hf = x0[i];
|
|
1415
|
+
HVX_Vector r1_hf = x1[i];
|
|
1416
|
+
HVX_Vector c0_hf = y0[i];
|
|
1417
|
+
HVX_Vector c1_hf = y1[i];
|
|
1418
|
+
|
|
1419
|
+
// Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
|
|
1420
|
+
r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
|
|
1421
|
+
r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
|
|
1422
|
+
r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
|
|
1423
|
+
r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
|
|
1424
|
+
}
|
|
1425
|
+
|
|
1426
|
+
if (nloe) {
|
|
1427
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
|
1428
|
+
|
|
1429
|
+
HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]);
|
|
1430
|
+
HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]);
|
|
1431
|
+
HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]);
|
|
1432
|
+
HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]);
|
|
1433
|
+
|
|
1434
|
+
r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
|
|
1435
|
+
r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
|
|
1436
|
+
r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
|
|
1437
|
+
r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
|
|
978
1438
|
}
|
|
979
1439
|
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
1440
|
+
HVX_Vector r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c0_sum_p), Q6_V_hi_W(r0_c0_sum_p)));
|
|
1441
|
+
HVX_Vector r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c1_sum_p), Q6_V_hi_W(r0_c1_sum_p)));
|
|
1442
|
+
HVX_Vector r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c0_sum_p), Q6_V_hi_W(r1_c0_sum_p)));
|
|
1443
|
+
HVX_Vector r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c1_sum_p), Q6_V_hi_W(r1_c1_sum_p)));
|
|
1444
|
+
|
|
1445
|
+
// Reduce and store results
|
|
1446
|
+
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
|
|
1447
|
+
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
|
|
983
1448
|
|
|
984
|
-
hvx_vec_store_u(&
|
|
1449
|
+
hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
|
|
1450
|
+
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
|
|
985
1451
|
}
|
|
986
1452
|
|
|
987
|
-
static void
|
|
1453
|
+
static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
|
988
1454
|
const HVX_UVector * restrict x = (const HVX_UVector *) vx;
|
|
989
1455
|
const HVX_UVector * restrict y = (const HVX_UVector *) vy;
|
|
990
1456
|
|
|
991
1457
|
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
|
992
1458
|
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
|
993
1459
|
|
|
994
|
-
HVX_Vector rsum =
|
|
1460
|
+
HVX_Vector rsum = Q6_V_vzero();
|
|
995
1461
|
|
|
996
1462
|
uint32_t i = 0;
|
|
997
1463
|
|
|
@@ -1010,20 +1476,20 @@ static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * res
|
|
|
1010
1476
|
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
|
1011
1477
|
}
|
|
1012
1478
|
|
|
1013
|
-
rsum = Q6_Vsf_equals_Vqf32(
|
|
1479
|
+
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
|
|
1014
1480
|
hvx_vec_store_u(&s[0], 4, rsum);
|
|
1015
1481
|
}
|
|
1016
1482
|
|
|
1017
|
-
static void
|
|
1483
|
+
static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
|
|
1018
1484
|
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
|
|
1019
1485
|
const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
|
|
1020
1486
|
|
|
1021
1487
|
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
|
1022
1488
|
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
|
1023
1489
|
|
|
1024
|
-
const HVX_Vector zero =
|
|
1490
|
+
const HVX_Vector zero = Q6_V_vzero();
|
|
1025
1491
|
|
|
1026
|
-
HVX_Vector rsum =
|
|
1492
|
+
HVX_Vector rsum = Q6_V_vzero();
|
|
1027
1493
|
|
|
1028
1494
|
uint32_t i = 0;
|
|
1029
1495
|
|
|
@@ -1062,7 +1528,8 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res
|
|
|
1062
1528
|
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
|
1063
1529
|
}
|
|
1064
1530
|
|
|
1065
|
-
|
|
1531
|
+
// Convert into fp32 and reduce
|
|
1532
|
+
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
|
|
1066
1533
|
hvx_vec_store_u(&s[0], 4, rsum);
|
|
1067
1534
|
}
|
|
1068
1535
|
|
|
@@ -1110,14 +1577,16 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res
|
|
|
1110
1577
|
const uint32_t nb2 = dst->nb[2]; \
|
|
1111
1578
|
const uint32_t nb3 = dst->nb[3];
|
|
1112
1579
|
|
|
1113
|
-
#define htp_matmul_preamble
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
|
|
1580
|
+
#define htp_matmul_preamble \
|
|
1581
|
+
struct htp_matmul_context * mmctx = data; \
|
|
1582
|
+
struct htp_ops_context * octx = mmctx->octx; \
|
|
1583
|
+
htp_matmul_tensors_preamble; \
|
|
1584
|
+
dma_queue *dma_queue = octx->ctx->dma[ith]; \
|
|
1585
|
+
uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread;
|
|
1117
1586
|
|
|
1118
1587
|
// *** matmul with support for 4d tensors and full broadcasting
|
|
1119
1588
|
|
|
1120
|
-
static void matmul_4d(
|
|
1589
|
+
static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
|
|
1121
1590
|
htp_matmul_preamble;
|
|
1122
1591
|
|
|
1123
1592
|
uint64_t t1, t2;
|
|
@@ -1163,13 +1632,13 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx
|
|
|
1163
1632
|
for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
|
|
1164
1633
|
for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
|
|
1165
1634
|
for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {
|
|
1166
|
-
const uint32_t i13 = fastdiv(ir1, &
|
|
1167
|
-
const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &
|
|
1635
|
+
const uint32_t i13 = fastdiv(ir1, &mmctx->mm_div_ne12_ne1);
|
|
1636
|
+
const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &mmctx->mm_div_ne1);
|
|
1168
1637
|
const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
|
|
1169
1638
|
|
|
1170
1639
|
// broadcast src0 into src1
|
|
1171
|
-
const uint32_t i03 = fastdiv(i13, &
|
|
1172
|
-
const uint32_t i02 = fastdiv(i12, &
|
|
1640
|
+
const uint32_t i03 = fastdiv(i13, &mmctx->mm_div_r3);
|
|
1641
|
+
const uint32_t i02 = fastdiv(i12, &mmctx->mm_div_r2);
|
|
1173
1642
|
|
|
1174
1643
|
const uint32_t i1 = i11;
|
|
1175
1644
|
const uint32_t i2 = i12;
|
|
@@ -1182,7 +1651,7 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx
|
|
|
1182
1651
|
const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
|
|
1183
1652
|
for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
|
|
1184
1653
|
const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
|
|
1185
|
-
|
|
1654
|
+
mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col);
|
|
1186
1655
|
}
|
|
1187
1656
|
}
|
|
1188
1657
|
}
|
|
@@ -1197,7 +1666,7 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx
|
|
|
1197
1666
|
}
|
|
1198
1667
|
|
|
1199
1668
|
// src1 tensor is already in VTCM spad
|
|
1200
|
-
static void matmul_2d(
|
|
1669
|
+
static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
|
|
1201
1670
|
htp_matmul_preamble;
|
|
1202
1671
|
|
|
1203
1672
|
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
|
@@ -1222,7 +1691,7 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
|
|
|
1222
1691
|
// Per-thread VTCM scratchpads for all tensors
|
|
1223
1692
|
// Note that the entire src1 tensor is already in VTCM
|
|
1224
1693
|
// For other tensors we allocate N rows per thread, padded to HVX vector size
|
|
1225
|
-
uint8_t * restrict spad_dst = dst_spad->data
|
|
1694
|
+
uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
|
|
1226
1695
|
uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
|
|
1227
1696
|
uint8_t * restrict src1_data = src1_spad->data;
|
|
1228
1697
|
|
|
@@ -1246,11 +1715,21 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
|
|
|
1246
1715
|
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
|
|
1247
1716
|
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
|
1248
1717
|
|
|
1249
|
-
|
|
1250
|
-
|
|
1718
|
+
// Process src1 columns in pairs (2×2 tiling)
|
|
1719
|
+
uint32_t ir1 = 0;
|
|
1720
|
+
for (; ir1 + 1 < src1_nrows; ir1 += 2) {
|
|
1721
|
+
const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride);
|
|
1722
|
+
const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride);
|
|
1723
|
+
float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size));
|
|
1724
|
+
float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size));
|
|
1725
|
+
mmctx->vec_dot_2x2(ne00, &dst_row0[ir0], &dst_row1[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1);
|
|
1726
|
+
}
|
|
1727
|
+
|
|
1728
|
+
// Handle remaining src1 rows (fallback to 2×1)
|
|
1729
|
+
for (; ir1 < src1_nrows; ++ir1) {
|
|
1251
1730
|
const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
|
|
1252
1731
|
float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
|
|
1253
|
-
|
|
1732
|
+
mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col);
|
|
1254
1733
|
}
|
|
1255
1734
|
|
|
1256
1735
|
// Prefetch next (n + spad_nrows) row
|
|
@@ -1274,20 +1753,20 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
|
|
|
1274
1753
|
for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
|
|
1275
1754
|
const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
|
|
1276
1755
|
float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
|
|
1277
|
-
|
|
1756
|
+
mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
|
|
1278
1757
|
}
|
|
1279
1758
|
}
|
|
1280
1759
|
|
|
1281
1760
|
t2 = HAP_perf_get_qtimer_count();
|
|
1282
1761
|
|
|
1283
|
-
FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n",
|
|
1762
|
+
FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth,
|
|
1284
1763
|
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
|
|
1285
1764
|
src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
|
1286
1765
|
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
|
1287
1766
|
}
|
|
1288
1767
|
|
|
1289
1768
|
// q8x4x2 src1 tensor is already in VTCM spad
|
|
1290
|
-
static void matvec_2d(
|
|
1769
|
+
static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
|
|
1291
1770
|
htp_matmul_preamble;
|
|
1292
1771
|
|
|
1293
1772
|
const uint32_t src0_nrows = ne01;
|
|
@@ -1338,7 +1817,7 @@ static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
|
|
|
1338
1817
|
// Process src0 rows
|
|
1339
1818
|
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
|
|
1340
1819
|
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
|
1341
|
-
|
|
1820
|
+
mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
|
|
1342
1821
|
|
|
1343
1822
|
// Prefetch next (n + spad_nrows) row
|
|
1344
1823
|
const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
|
|
@@ -1356,14 +1835,14 @@ static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
|
|
|
1356
1835
|
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
|
|
1357
1836
|
src0_stride, src0_row_size, 1);
|
|
1358
1837
|
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
|
1359
|
-
|
|
1838
|
+
mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
|
|
1360
1839
|
}
|
|
1361
1840
|
|
|
1362
|
-
|
|
1841
|
+
hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row);
|
|
1363
1842
|
|
|
1364
1843
|
t2 = HAP_perf_get_qtimer_count();
|
|
1365
1844
|
|
|
1366
|
-
FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n",
|
|
1845
|
+
FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth,
|
|
1367
1846
|
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
|
|
1368
1847
|
src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
|
1369
1848
|
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
|
@@ -1377,7 +1856,7 @@ struct mmid_row_mapping {
|
|
|
1377
1856
|
};
|
|
1378
1857
|
|
|
1379
1858
|
// src1 tensor is already in VTCM spad
|
|
1380
|
-
static void matmul_id(
|
|
1859
|
+
static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
|
1381
1860
|
htp_matmul_preamble;
|
|
1382
1861
|
|
|
1383
1862
|
struct htp_tensor * restrict ids = &octx->src2;
|
|
@@ -1411,7 +1890,7 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
|
|
|
1411
1890
|
const size_t src0_row_size = nb01;
|
|
1412
1891
|
const size_t src1_row_size = q8x4x2_row_size(ne10);
|
|
1413
1892
|
|
|
1414
|
-
const size_t src0_row_size_padded =
|
|
1893
|
+
const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
|
|
1415
1894
|
|
|
1416
1895
|
// Per-thread VTCM scratchpads for all tensors
|
|
1417
1896
|
// Note that the entire src1 tensor is already in VTCM
|
|
@@ -1450,11 +1929,10 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
|
|
|
1450
1929
|
const int rm2 = row_mapping.i2; // token idx
|
|
1451
1930
|
|
|
1452
1931
|
const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx
|
|
1453
|
-
const uint8_t * restrict src1_col =
|
|
1454
|
-
(const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
|
|
1932
|
+
const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
|
|
1455
1933
|
float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
|
|
1456
1934
|
|
|
1457
|
-
|
|
1935
|
+
mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
|
|
1458
1936
|
}
|
|
1459
1937
|
|
|
1460
1938
|
// Prefetch next (n + spad_nrows) row
|
|
@@ -1480,25 +1958,24 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
|
|
|
1480
1958
|
const int rm2 = row_mapping.i2; // token idx
|
|
1481
1959
|
|
|
1482
1960
|
const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx
|
|
1483
|
-
const uint8_t * restrict src1_col =
|
|
1484
|
-
(const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
|
|
1961
|
+
const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
|
|
1485
1962
|
float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
|
|
1486
1963
|
|
|
1487
|
-
|
|
1964
|
+
mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
|
|
1488
1965
|
}
|
|
1489
1966
|
}
|
|
1490
1967
|
}
|
|
1491
1968
|
|
|
1492
1969
|
t2 = HAP_perf_get_qtimer_count();
|
|
1493
1970
|
|
|
1494
|
-
FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n",
|
|
1971
|
+
FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type,
|
|
1495
1972
|
ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
|
|
1496
1973
|
src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1],
|
|
1497
1974
|
dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
|
1498
1975
|
}
|
|
1499
1976
|
|
|
1500
1977
|
// src1 tensor is already in VTCM spad
|
|
1501
|
-
static void matvec_id(
|
|
1978
|
+
static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
|
|
1502
1979
|
htp_matmul_preamble;
|
|
1503
1980
|
|
|
1504
1981
|
struct htp_tensor * restrict ids = &octx->src2;
|
|
@@ -1524,7 +2001,7 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
|
|
|
1524
2001
|
const size_t src0_row_size = nb01;
|
|
1525
2002
|
const size_t src1_row_size = q8x4x2_row_size(ne10);
|
|
1526
2003
|
|
|
1527
|
-
const size_t src0_row_size_padded =
|
|
2004
|
+
const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
|
|
1528
2005
|
|
|
1529
2006
|
const uint32_t n_aids = src2->ne[0]; // num activated experts
|
|
1530
2007
|
const uint32_t n_ids = ne02; // num experts
|
|
@@ -1558,7 +2035,7 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
|
|
|
1558
2035
|
// Process src0 rows
|
|
1559
2036
|
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
|
|
1560
2037
|
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
|
1561
|
-
|
|
2038
|
+
mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
|
|
1562
2039
|
|
|
1563
2040
|
// Prefetch next (n + spad_nrows) row
|
|
1564
2041
|
const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
|
|
@@ -1576,13 +2053,13 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
|
|
|
1576
2053
|
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
|
|
1577
2054
|
src0_row_size_padded, src0_row_size, 1);
|
|
1578
2055
|
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
|
1579
|
-
|
|
2056
|
+
mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
|
|
1580
2057
|
}
|
|
1581
2058
|
}
|
|
1582
2059
|
|
|
1583
2060
|
t2 = HAP_perf_get_qtimer_count();
|
|
1584
2061
|
|
|
1585
|
-
FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n",
|
|
2062
|
+
FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type,
|
|
1586
2063
|
ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
|
|
1587
2064
|
src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0],
|
|
1588
2065
|
dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
|
@@ -1590,18 +2067,18 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
|
|
|
1590
2067
|
|
|
1591
2068
|
// *** dynamic quant
|
|
1592
2069
|
|
|
1593
|
-
static inline void
|
|
2070
|
+
static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
|
|
1594
2071
|
assert((unsigned long) x % 128 == 0);
|
|
1595
2072
|
assert((unsigned long) y_q % 128 == 0);
|
|
1596
2073
|
|
|
1597
2074
|
HVX_Vector * vx = (HVX_Vector *) x;
|
|
1598
|
-
HVX_Vector zero =
|
|
2075
|
+
HVX_Vector zero = Q6_V_vzero();
|
|
1599
2076
|
|
|
1600
2077
|
// Use reduce max fp32 to find max(abs(e)) first
|
|
1601
|
-
HVX_Vector vmax0_sf =
|
|
1602
|
-
HVX_Vector vmax1_sf =
|
|
1603
|
-
HVX_Vector vmax2_sf =
|
|
1604
|
-
HVX_Vector vmax3_sf =
|
|
2078
|
+
HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0]));
|
|
2079
|
+
HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1]));
|
|
2080
|
+
HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2]));
|
|
2081
|
+
HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3]));
|
|
1605
2082
|
// Load and convert into QF32
|
|
1606
2083
|
HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
|
|
1607
2084
|
HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
|
|
@@ -1609,10 +2086,10 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri
|
|
|
1609
2086
|
HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
|
|
1610
2087
|
|
|
1611
2088
|
// Convert to QF32
|
|
1612
|
-
HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
|
|
1613
|
-
HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
|
|
1614
|
-
HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
|
|
1615
|
-
HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
|
|
2089
|
+
HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); // replicated over all lanes
|
|
2090
|
+
HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); // replicated over all lanes
|
|
2091
|
+
HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); // replicated over all lanes
|
|
2092
|
+
HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); // replicated over all lanes
|
|
1616
2093
|
|
|
1617
2094
|
// Combine and convert to fp16
|
|
1618
2095
|
HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
|
|
@@ -1622,11 +2099,6 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri
|
|
|
1622
2099
|
HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
|
|
1623
2100
|
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
|
|
1624
2101
|
|
|
1625
|
-
// Replicate first fp16 scale across all lanes
|
|
1626
|
-
HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_fp16;
|
|
1627
|
-
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
|
|
1628
|
-
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
|
|
1629
|
-
|
|
1630
2102
|
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
|
1631
2103
|
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
|
1632
2104
|
HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
|
|
@@ -1641,8 +2113,8 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri
|
|
|
1641
2113
|
hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf);
|
|
1642
2114
|
|
|
1643
2115
|
// Divide input by the scale
|
|
1644
|
-
HVX_Vector vd01_inv_hf =
|
|
1645
|
-
HVX_Vector vd23_inv_hf =
|
|
2116
|
+
HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);
|
|
2117
|
+
HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);
|
|
1646
2118
|
vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
|
|
1647
2119
|
vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
|
|
1648
2120
|
|
|
@@ -1654,14 +2126,14 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri
|
|
|
1654
2126
|
*(HVX_Vector *) y_q = vx_i8;
|
|
1655
2127
|
}
|
|
1656
2128
|
|
|
1657
|
-
static inline void
|
|
2129
|
+
static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
|
|
1658
2130
|
assert((unsigned long) x % 128 == 0);
|
|
1659
2131
|
assert((unsigned long) y_q % 128 == 0);
|
|
1660
2132
|
|
|
1661
2133
|
HVX_Vector * vx = (HVX_Vector *) x;
|
|
1662
2134
|
|
|
1663
2135
|
// Load and convert into QF32
|
|
1664
|
-
HVX_Vector zero =
|
|
2136
|
+
HVX_Vector zero = Q6_V_vzero();
|
|
1665
2137
|
HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
|
|
1666
2138
|
HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
|
|
1667
2139
|
HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
|
|
@@ -1672,13 +2144,8 @@ static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restri
|
|
|
1672
2144
|
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
|
|
1673
2145
|
|
|
1674
2146
|
// Compute max and scale
|
|
1675
|
-
HVX_Vector vmax01_hf =
|
|
1676
|
-
HVX_Vector vmax23_hf =
|
|
1677
|
-
|
|
1678
|
-
// Replicate first fp16 scale across all lanes
|
|
1679
|
-
HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16;
|
|
1680
|
-
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
|
|
1681
|
-
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
|
|
2147
|
+
HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); // replicated over all lanes
|
|
2148
|
+
HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); // replicated over all lanes
|
|
1682
2149
|
|
|
1683
2150
|
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
|
1684
2151
|
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
|
@@ -1689,8 +2156,8 @@ static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restri
|
|
|
1689
2156
|
hvx_vec_store_u(y_d + 4, 4, vd23_hf);
|
|
1690
2157
|
|
|
1691
2158
|
// Divide input by the scale
|
|
1692
|
-
HVX_Vector vd01_inv_hf =
|
|
1693
|
-
HVX_Vector vd23_inv_hf =
|
|
2159
|
+
HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);
|
|
2160
|
+
HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);
|
|
1694
2161
|
vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
|
|
1695
2162
|
vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
|
|
1696
2163
|
|
|
@@ -1702,14 +2169,14 @@ static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restri
|
|
|
1702
2169
|
*(HVX_Vector *) y_q = vx_i8;
|
|
1703
2170
|
}
|
|
1704
2171
|
|
|
1705
|
-
static inline void
|
|
2172
|
+
static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
|
|
1706
2173
|
assert((unsigned long) x % 128 == 0);
|
|
1707
2174
|
assert((unsigned long) y_q % 128 == 0);
|
|
1708
2175
|
|
|
1709
2176
|
HVX_Vector * vx = (HVX_Vector *) x;
|
|
1710
2177
|
|
|
1711
2178
|
// Load and convert into QF32
|
|
1712
|
-
HVX_Vector zero =
|
|
2179
|
+
HVX_Vector zero = Q6_V_vzero();
|
|
1713
2180
|
HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
|
|
1714
2181
|
HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
|
|
1715
2182
|
HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
|
|
@@ -1720,12 +2187,8 @@ static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restri
|
|
|
1720
2187
|
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
|
|
1721
2188
|
|
|
1722
2189
|
// Compute max and scale
|
|
1723
|
-
HVX_Vector vmax_hf =
|
|
1724
|
-
vmax_hf =
|
|
1725
|
-
|
|
1726
|
-
// Replicate first fp16 scale across all lanes
|
|
1727
|
-
HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16;
|
|
1728
|
-
vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl);
|
|
2190
|
+
HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
|
|
2191
|
+
vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); // replicated over all lanes
|
|
1729
2192
|
|
|
1730
2193
|
HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
|
1731
2194
|
HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16);
|
|
@@ -1733,7 +2196,7 @@ static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restri
|
|
|
1733
2196
|
*(HVX_UVector *) y_d = vd_hf;
|
|
1734
2197
|
|
|
1735
2198
|
// Divide input by the scale
|
|
1736
|
-
HVX_Vector vd_inv_hf =
|
|
2199
|
+
HVX_Vector vd_inv_hf = hvx_vec_inverse_f16(vd_hf);
|
|
1737
2200
|
vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf));
|
|
1738
2201
|
vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf));
|
|
1739
2202
|
|
|
@@ -1746,7 +2209,7 @@ static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restri
|
|
|
1746
2209
|
}
|
|
1747
2210
|
|
|
1748
2211
|
// Overrides input x
|
|
1749
|
-
static void
|
|
2212
|
+
static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
|
|
1750
2213
|
assert(k % 32 == 0);
|
|
1751
2214
|
const uint32_t qk = QK_Q8_0x4x2;
|
|
1752
2215
|
const uint32_t nb = (k + qk - 1) / qk;
|
|
@@ -1764,29 +2227,31 @@ static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, u
|
|
|
1764
2227
|
|
|
1765
2228
|
for (uint32_t i = 0; i < nb; i++) {
|
|
1766
2229
|
#if FP32_QUANTIZE_GROUP_SIZE == 32
|
|
1767
|
-
|
|
1768
|
-
|
|
2230
|
+
quantize_block_f32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
|
|
2231
|
+
quantize_block_f32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
|
|
1769
2232
|
#elif FP32_QUANTIZE_GROUP_SIZE == 64
|
|
1770
|
-
|
|
1771
|
-
|
|
2233
|
+
quantize_block_f32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
|
|
2234
|
+
quantize_block_f32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
|
|
1772
2235
|
#elif FP32_QUANTIZE_GROUP_SIZE == 128
|
|
1773
|
-
|
|
1774
|
-
|
|
2236
|
+
quantize_block_f32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
|
|
2237
|
+
quantize_block_f32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
|
|
1775
2238
|
#else
|
|
1776
2239
|
#error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128"
|
|
1777
2240
|
#endif
|
|
1778
2241
|
}
|
|
1779
2242
|
|
|
1780
2243
|
// now copy the scales into final location
|
|
1781
|
-
|
|
2244
|
+
hvx_copy_f16_ua(y_d, t_d, nb * 8);
|
|
1782
2245
|
}
|
|
1783
2246
|
|
|
1784
|
-
static void
|
|
1785
|
-
|
|
1786
|
-
|
|
1787
|
-
|
|
1788
|
-
|
|
1789
|
-
|
|
2247
|
+
static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) {
|
|
2248
|
+
struct htp_matmul_context * mmctx = data;
|
|
2249
|
+
struct htp_ops_context * octx = mmctx->octx;
|
|
2250
|
+
|
|
2251
|
+
const struct htp_tensor * src = &octx->src1;
|
|
2252
|
+
uint8_t * restrict dst = octx->src1_spad.data;
|
|
2253
|
+
struct htp_spad * spad = &octx->src0_spad;
|
|
2254
|
+
uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
|
|
1790
2255
|
|
|
1791
2256
|
uint64_t t1 = HAP_perf_get_qtimer_count();
|
|
1792
2257
|
|
|
@@ -1807,27 +2272,33 @@ static void quantize_fp32_q8x4x2(const struct htp_tensor * src,
|
|
|
1807
2272
|
uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first);
|
|
1808
2273
|
uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith);
|
|
1809
2274
|
|
|
1810
|
-
const size_t src_row_size_padded =
|
|
2275
|
+
const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float));
|
|
1811
2276
|
memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding
|
|
1812
2277
|
|
|
1813
2278
|
for (uint32_t i = ir_first; i < ir_last; ++i) {
|
|
1814
|
-
|
|
1815
|
-
|
|
2279
|
+
hex_l2fetch(src_data, src_row_size, src_row_size, 2);
|
|
2280
|
+
hvx_copy_f32_aa(tmp_data, src_data, ne0);
|
|
1816
2281
|
|
|
1817
2282
|
// FARF(HIGH, "quantize-q8x4-row: %u\n", i);
|
|
1818
|
-
|
|
2283
|
+
quantize_row_f32_q8x4x2((float *) tmp_data, dst_data, ne0);
|
|
1819
2284
|
dst_data += dst_row_size;
|
|
1820
2285
|
src_data += src_row_size;
|
|
1821
2286
|
}
|
|
1822
2287
|
|
|
1823
2288
|
uint64_t t2 = HAP_perf_get_qtimer_count();
|
|
1824
2289
|
|
|
1825
|
-
FARF(HIGH, "quantize-
|
|
2290
|
+
FARF(HIGH, "quantize-f32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
|
|
1826
2291
|
ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
|
1827
2292
|
}
|
|
1828
2293
|
|
|
1829
|
-
static void
|
|
1830
|
-
|
|
2294
|
+
static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
|
|
2295
|
+
struct htp_matmul_context * mmctx = data;
|
|
2296
|
+
struct htp_ops_context * octx = mmctx->octx;
|
|
2297
|
+
|
|
2298
|
+
const struct htp_tensor * src = &octx->src1;
|
|
2299
|
+
uint8_t * restrict dst = octx->src1_spad.data;
|
|
2300
|
+
uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
|
|
2301
|
+
uint32_t dst_stride = octx->src1_spad.stride;
|
|
1831
2302
|
|
|
1832
2303
|
uint64_t t1 = HAP_perf_get_qtimer_count();
|
|
1833
2304
|
|
|
@@ -1848,8 +2319,8 @@ static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict
|
|
|
1848
2319
|
uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
|
|
1849
2320
|
|
|
1850
2321
|
for (uint32_t i = ir_first; i < ir_last; ++i) {
|
|
1851
|
-
|
|
1852
|
-
|
|
2322
|
+
hex_l2fetch(src_data, src_row_size, src_stride, 2);
|
|
2323
|
+
hvx_copy_f16_f32_au(dst_data, src_data, ne0);
|
|
1853
2324
|
|
|
1854
2325
|
dst_data += dst_stride;
|
|
1855
2326
|
src_data += src_stride;
|
|
@@ -1857,13 +2328,19 @@ static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict
|
|
|
1857
2328
|
|
|
1858
2329
|
uint64_t t2 = HAP_perf_get_qtimer_count();
|
|
1859
2330
|
|
|
1860
|
-
FARF(HIGH, "quantize-
|
|
2331
|
+
FARF(HIGH, "quantize-f32-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
|
|
1861
2332
|
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
|
1862
2333
|
}
|
|
1863
2334
|
|
|
1864
2335
|
// TODO just a plain copy that should be done via the DMA during the Op setup
|
|
1865
|
-
static void
|
|
1866
|
-
|
|
2336
|
+
static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
|
|
2337
|
+
struct htp_matmul_context * mmctx = data;
|
|
2338
|
+
struct htp_ops_context * octx = mmctx->octx;
|
|
2339
|
+
|
|
2340
|
+
const struct htp_tensor * src = &octx->src1;
|
|
2341
|
+
uint8_t * restrict dst = octx->src1_spad.data;
|
|
2342
|
+
uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
|
|
2343
|
+
uint32_t dst_stride = octx->src1_spad.stride;
|
|
1867
2344
|
|
|
1868
2345
|
uint64_t t1 = HAP_perf_get_qtimer_count();
|
|
1869
2346
|
|
|
@@ -1884,8 +2361,8 @@ static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict
|
|
|
1884
2361
|
uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
|
|
1885
2362
|
|
|
1886
2363
|
for (uint32_t i = ir_first; i < ir_last; ++i) {
|
|
1887
|
-
|
|
1888
|
-
|
|
2364
|
+
hex_l2fetch(src_data, src_row_size, src_stride, 2);
|
|
2365
|
+
hvx_copy_f16_au(dst_data, src_data, ne0);
|
|
1889
2366
|
|
|
1890
2367
|
dst_data += dst_stride;
|
|
1891
2368
|
src_data += src_stride;
|
|
@@ -1893,400 +2370,177 @@ static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict
|
|
|
1893
2370
|
|
|
1894
2371
|
uint64_t t2 = HAP_perf_get_qtimer_count();
|
|
1895
2372
|
|
|
1896
|
-
FARF(HIGH, "quantize-
|
|
2373
|
+
FARF(HIGH, "quantize-f16-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
|
|
1897
2374
|
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
|
1898
2375
|
}
|
|
1899
2376
|
|
|
1900
|
-
static void htp_quantize_fp32_q8x4x2(unsigned int n, unsigned int i, void * data) {
|
|
1901
|
-
struct htp_ops_context * octx = data;
|
|
1902
|
-
quantize_fp32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread);
|
|
1903
|
-
}
|
|
1904
|
-
|
|
1905
|
-
static void htp_quantize_fp32_fp16(unsigned int n, unsigned int i, void * data) {
|
|
1906
|
-
struct htp_ops_context * octx = data;
|
|
1907
|
-
quantize_fp32_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
|
|
1908
|
-
}
|
|
1909
|
-
|
|
1910
|
-
static void htp_quantize_fp16_fp16(unsigned int n, unsigned int i, void * data) {
|
|
1911
|
-
struct htp_ops_context * octx = data;
|
|
1912
|
-
quantize_fp16_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
|
|
1913
|
-
}
|
|
1914
|
-
|
|
1915
|
-
// ** matmul/matvec callbacks for worker_pool
|
|
1916
|
-
|
|
1917
|
-
static void htp_matvec_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
|
|
1918
|
-
struct htp_ops_context * octx = data;
|
|
1919
|
-
|
|
1920
|
-
struct htp_matmul_type mt;
|
|
1921
|
-
mt.type = "q4x4x2-q8x4x2";
|
|
1922
|
-
mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
|
|
1923
|
-
mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
|
|
1924
|
-
|
|
1925
|
-
matvec_2d(&mt, octx, n, i);
|
|
1926
|
-
}
|
|
1927
|
-
|
|
1928
|
-
static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
|
|
1929
|
-
struct htp_ops_context * octx = data;
|
|
1930
|
-
|
|
1931
|
-
struct htp_matmul_type mt;
|
|
1932
|
-
mt.type = "q4x4x2-q8x4x2";
|
|
1933
|
-
mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
|
|
1934
|
-
mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
|
|
1935
|
-
|
|
1936
|
-
matmul_2d(&mt, octx, n, i);
|
|
1937
|
-
}
|
|
1938
|
-
|
|
1939
|
-
static void htp_matvec_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
|
|
1940
|
-
struct htp_ops_context * octx = data;
|
|
1941
|
-
|
|
1942
|
-
struct htp_matmul_type mt;
|
|
1943
|
-
mt.type = "q8x4x2-q8x4x2";
|
|
1944
|
-
mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
|
|
1945
|
-
mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
|
|
1946
|
-
|
|
1947
|
-
matvec_2d(&mt, octx, n, i);
|
|
1948
|
-
}
|
|
1949
|
-
|
|
1950
|
-
static void htp_matmul_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
|
|
1951
|
-
struct htp_ops_context * octx = data;
|
|
1952
|
-
|
|
1953
|
-
struct htp_matmul_type mt;
|
|
1954
|
-
mt.type = "q8x4x2-q8x4x2";
|
|
1955
|
-
mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
|
|
1956
|
-
mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
|
|
1957
|
-
|
|
1958
|
-
matmul_2d(&mt, octx, n, i);
|
|
1959
|
-
}
|
|
1960
|
-
|
|
1961
|
-
static void htp_matvec_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
|
|
1962
|
-
struct htp_ops_context * octx = data;
|
|
1963
|
-
|
|
1964
|
-
struct htp_matmul_type mt;
|
|
1965
|
-
mt.type = "mxfp4x4x2-q8x4x2";
|
|
1966
|
-
mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
|
|
1967
|
-
mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
|
|
1968
|
-
|
|
1969
|
-
matvec_2d(&mt, octx, n, i);
|
|
1970
|
-
}
|
|
1971
|
-
|
|
1972
|
-
static void htp_matmul_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
|
|
1973
|
-
struct htp_ops_context * octx = data;
|
|
1974
|
-
|
|
1975
|
-
struct htp_matmul_type mt;
|
|
1976
|
-
mt.type = "mxfp4x4x2-q8x4x2";
|
|
1977
|
-
mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
|
|
1978
|
-
mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
|
|
1979
|
-
|
|
1980
|
-
matmul_2d(&mt, octx, n, i);
|
|
1981
|
-
}
|
|
1982
|
-
|
|
1983
|
-
static void htp_matvec_2d_f16_f16(unsigned int n, unsigned int i, void * data) {
|
|
1984
|
-
struct htp_ops_context * octx = data;
|
|
1985
|
-
|
|
1986
|
-
struct htp_matmul_type mt;
|
|
1987
|
-
mt.type = "f16-f16";
|
|
1988
|
-
mt.vec_dot = vec_dot_f16_f16_aa;
|
|
1989
|
-
mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2;
|
|
1990
|
-
|
|
1991
|
-
matvec_2d(&mt, octx, n, i);
|
|
1992
|
-
}
|
|
1993
|
-
|
|
1994
|
-
static void htp_matmul_2d_f16_f16(unsigned int n, unsigned int i, void * data) {
|
|
1995
|
-
struct htp_ops_context * octx = data;
|
|
1996
|
-
|
|
1997
|
-
struct htp_matmul_type mt;
|
|
1998
|
-
mt.type = "f16-f16";
|
|
1999
|
-
mt.vec_dot = vec_dot_f16_f16_aa;
|
|
2000
|
-
mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2;
|
|
2001
|
-
|
|
2002
|
-
matmul_2d(&mt, octx, n, i);
|
|
2003
|
-
}
|
|
2004
|
-
|
|
2005
|
-
static void htp_matmul_4d_f16_f32(unsigned int n, unsigned int i, void * data) {
|
|
2006
|
-
struct htp_ops_context * octx = data;
|
|
2007
|
-
|
|
2008
|
-
struct htp_matmul_type mt;
|
|
2009
|
-
mt.type = "f16-f32";
|
|
2010
|
-
mt.vec_dot = vec_dot_f16_f32_uu;
|
|
2011
|
-
|
|
2012
|
-
matmul_4d(&mt, octx, n, i);
|
|
2013
|
-
}
|
|
2014
|
-
|
|
2015
|
-
static void htp_matmul_4d_f16_f16(unsigned int n, unsigned int i, void * data) {
|
|
2016
|
-
struct htp_ops_context * octx = data;
|
|
2017
|
-
|
|
2018
|
-
struct htp_matmul_type mt;
|
|
2019
|
-
mt.type = "f16-f16";
|
|
2020
|
-
mt.vec_dot = vec_dot_f16_f16_uu;
|
|
2021
|
-
|
|
2022
|
-
matmul_4d(&mt, octx, n, i);
|
|
2023
|
-
}
|
|
2024
|
-
|
|
2025
|
-
// ** matmul-id callbacks for worker_pool
|
|
2026
|
-
|
|
2027
|
-
static void htp_matvec_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
|
|
2028
|
-
struct htp_ops_context * octx = data;
|
|
2029
|
-
|
|
2030
|
-
struct htp_matmul_type mt;
|
|
2031
|
-
mt.type = "q4x4x2-q8x4x2";
|
|
2032
|
-
mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
|
|
2033
|
-
mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
|
|
2034
|
-
|
|
2035
|
-
matvec_id(&mt, octx, n, i);
|
|
2036
|
-
}
|
|
2037
|
-
|
|
2038
|
-
static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
|
|
2039
|
-
struct htp_ops_context * octx = data;
|
|
2040
|
-
|
|
2041
|
-
struct htp_matmul_type mt;
|
|
2042
|
-
mt.type = "q4x4x2-q8x4x2";
|
|
2043
|
-
mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
|
|
2044
|
-
mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
|
|
2045
|
-
|
|
2046
|
-
matmul_id(&mt, octx, n, i);
|
|
2047
|
-
}
|
|
2048
|
-
|
|
2049
|
-
static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
|
|
2050
|
-
struct htp_ops_context * octx = data;
|
|
2051
|
-
|
|
2052
|
-
struct htp_matmul_type mt;
|
|
2053
|
-
mt.type = "q8x4x2-q8x4x2";
|
|
2054
|
-
mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
|
|
2055
|
-
mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
|
|
2056
|
-
|
|
2057
|
-
matvec_id(&mt, octx, n, i);
|
|
2058
|
-
}
|
|
2059
|
-
|
|
2060
|
-
static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
|
|
2061
|
-
struct htp_ops_context * octx = data;
|
|
2062
|
-
|
|
2063
|
-
struct htp_matmul_type mt;
|
|
2064
|
-
mt.type = "q8x4x2-q8x4x2";
|
|
2065
|
-
mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
|
|
2066
|
-
mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
|
|
2067
2377
|
|
|
2068
|
-
|
|
2378
|
+
static inline bool htp_is_permuted(const struct htp_tensor * t) {
|
|
2379
|
+
return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3];
|
|
2069
2380
|
}
|
|
2070
2381
|
|
|
2071
|
-
static
|
|
2072
|
-
|
|
2073
|
-
|
|
2074
|
-
|
|
2075
|
-
|
|
2076
|
-
|
|
2077
|
-
|
|
2078
|
-
|
|
2079
|
-
|
|
2382
|
+
static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_type type) {
|
|
2383
|
+
switch (type) {
|
|
2384
|
+
case HTP_TYPE_Q4_0:
|
|
2385
|
+
mmctx->type = "q4x4x2-f32";
|
|
2386
|
+
mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1;
|
|
2387
|
+
mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1;
|
|
2388
|
+
mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2;
|
|
2389
|
+
return 0;
|
|
2390
|
+
case HTP_TYPE_Q8_0:
|
|
2391
|
+
mmctx->type = "q8x4x2-f32";
|
|
2392
|
+
mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1;
|
|
2393
|
+
mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1;
|
|
2394
|
+
mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2;
|
|
2395
|
+
return 0;
|
|
2396
|
+
case HTP_TYPE_MXFP4:
|
|
2397
|
+
mmctx->type = "mxfp4x4x2-f32";
|
|
2398
|
+
mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1;
|
|
2399
|
+
mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1;
|
|
2400
|
+
mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2;
|
|
2401
|
+
return 0;
|
|
2402
|
+
default:
|
|
2403
|
+
return -1;
|
|
2404
|
+
}
|
|
2080
2405
|
}
|
|
2081
2406
|
|
|
2082
|
-
static void
|
|
2083
|
-
|
|
2084
|
-
|
|
2085
|
-
|
|
2086
|
-
|
|
2087
|
-
|
|
2088
|
-
|
|
2089
|
-
|
|
2090
|
-
|
|
2091
|
-
|
|
2407
|
+
static void htp_mminit_spad(struct htp_ops_context * octx,
|
|
2408
|
+
size_t dst_row_size,
|
|
2409
|
+
size_t src0_row_size_padded,
|
|
2410
|
+
size_t src1_row_size,
|
|
2411
|
+
uint32_t src1_nrows,
|
|
2412
|
+
size_t src2_spad_size_per_thread) {
|
|
2413
|
+
octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
|
2414
|
+
octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
|
2415
|
+
octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
|
|
2416
|
+
|
|
2417
|
+
if (src2_spad_size_per_thread > 0) {
|
|
2418
|
+
octx->src2_spad.size_per_thread = src2_spad_size_per_thread;
|
|
2419
|
+
octx->src2_spad.size = octx->src2_spad.size_per_thread;
|
|
2420
|
+
}
|
|
2092
2421
|
|
|
2093
|
-
//
|
|
2422
|
+
// src0 spad is also used in dynamic quantizer to store padded src1 rows
|
|
2423
|
+
size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
|
|
2424
|
+
if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
|
|
2425
|
+
octx->src0_spad.size_per_thread = src1_row_size_padded;
|
|
2426
|
+
}
|
|
2094
2427
|
|
|
2095
|
-
|
|
2096
|
-
|
|
2428
|
+
octx->src1_spad.size = octx->src1_spad.size_per_thread;
|
|
2429
|
+
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
|
|
2430
|
+
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
|
|
2097
2431
|
}
|
|
2098
2432
|
|
|
2099
2433
|
int op_matmul(struct htp_ops_context * octx) {
|
|
2100
2434
|
htp_matmul_tensors_preamble;
|
|
2101
2435
|
|
|
2102
|
-
|
|
2436
|
+
struct htp_matmul_context mmctx_struct = {0};
|
|
2437
|
+
struct htp_matmul_context * mmctx = &mmctx_struct;
|
|
2438
|
+
mmctx->octx = octx;
|
|
2103
2439
|
|
|
2104
2440
|
const uint32_t src0_nrows = ne01 * ne02 * ne03;
|
|
2105
2441
|
const uint32_t src1_nrows = ne11 * ne12 * ne13;
|
|
2106
2442
|
|
|
2443
|
+
// Compute src0_nrows_per_thread
|
|
2444
|
+
mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
|
|
2445
|
+
mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even
|
|
2446
|
+
|
|
2107
2447
|
const size_t src0_row_size = nb01;
|
|
2108
2448
|
const size_t dst_row_size = nb1;
|
|
2109
2449
|
size_t src1_row_size = nb11;
|
|
2110
2450
|
|
|
2111
|
-
const size_t src0_row_size_padded =
|
|
2451
|
+
const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
|
|
2112
2452
|
size_t src1_row_size_padded;
|
|
2113
2453
|
|
|
2114
2454
|
worker_callback_t quant_job_func;
|
|
2115
|
-
worker_callback_t matmul_job_func;
|
|
2455
|
+
worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d;
|
|
2116
2456
|
|
|
2117
2457
|
bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE);
|
|
2118
2458
|
|
|
2119
|
-
|
|
2120
|
-
|
|
2121
|
-
|
|
2122
|
-
|
|
2123
|
-
|
|
2124
|
-
|
|
2125
|
-
} else {
|
|
2126
|
-
matmul_job_func = htp_matvec_2d_q4x4x2_q8x4x2;
|
|
2127
|
-
}
|
|
2459
|
+
if (src0->type == HTP_TYPE_F16) {
|
|
2460
|
+
// Try optimized f16-f16 path first (src1 in VTCM)
|
|
2461
|
+
const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128);
|
|
2462
|
+
const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256);
|
|
2463
|
+
const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
|
|
2464
|
+
const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
|
|
2128
2465
|
|
|
2129
|
-
|
|
2466
|
+
const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;
|
|
2130
2467
|
|
|
2131
|
-
|
|
2132
|
-
|
|
2468
|
+
// Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
|
|
2469
|
+
// It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
|
|
2470
|
+
const bool is_batched = (ne02 > 1) || (ne03 > 1);
|
|
2471
|
+
const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
|
|
2133
2472
|
|
|
2134
|
-
|
|
2135
|
-
|
|
2136
|
-
|
|
2473
|
+
if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
|
|
2474
|
+
// Optimized path
|
|
2475
|
+
quant_job_func = (src1->type == HTP_TYPE_F32) ? quantize_f32_f16 : quantize_f16_f16;
|
|
2476
|
+
mmctx->type = "f16-f16";
|
|
2477
|
+
mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1;
|
|
2478
|
+
mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1;
|
|
2479
|
+
mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2;
|
|
2137
2480
|
|
|
2138
|
-
|
|
2139
|
-
src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
|
|
2140
|
-
if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
|
|
2141
|
-
octx->src0_spad.size_per_thread = src1_row_size_padded;
|
|
2142
|
-
}
|
|
2481
|
+
src1_row_size = f16_src1_row_size; // row size post quantization
|
|
2143
2482
|
|
|
2144
|
-
octx->
|
|
2145
|
-
octx->src0_spad.
|
|
2146
|
-
octx->
|
|
2147
|
-
break;
|
|
2148
|
-
|
|
2149
|
-
case HTP_TYPE_Q8_0:
|
|
2150
|
-
op_type = "q8x4x2-fp32";
|
|
2151
|
-
quant_job_func = htp_quantize_fp32_q8x4x2;
|
|
2152
|
-
if (src1_nrows > 1) {
|
|
2153
|
-
matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2;
|
|
2154
|
-
} else {
|
|
2155
|
-
matmul_job_func = htp_matvec_2d_q8x4x2_q8x4x2;
|
|
2156
|
-
}
|
|
2157
|
-
|
|
2158
|
-
src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
|
|
2159
|
-
|
|
2160
|
-
// Entire src1 tensor is placed into the VTCM
|
|
2161
|
-
// For other tensors we allocate N rows per thread, padded to HVX vector size
|
|
2162
|
-
|
|
2163
|
-
octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
|
2164
|
-
octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
|
2165
|
-
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
|
|
2166
|
-
|
|
2167
|
-
// src0 spad is also used in dynamic quantizer to store padded src1 rows
|
|
2168
|
-
src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
|
|
2169
|
-
if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
|
|
2170
|
-
octx->src0_spad.size_per_thread = src1_row_size_padded;
|
|
2171
|
-
}
|
|
2483
|
+
octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
|
2484
|
+
octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
|
2485
|
+
octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
|
|
2172
2486
|
|
|
2173
2487
|
octx->src1_spad.size = octx->src1_spad.size_per_thread;
|
|
2174
2488
|
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
|
|
2175
2489
|
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
|
|
2176
|
-
|
|
2177
|
-
|
|
2178
|
-
|
|
2179
|
-
|
|
2180
|
-
|
|
2181
|
-
|
|
2182
|
-
matmul_job_func
|
|
2490
|
+
} else {
|
|
2491
|
+
// Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required
|
|
2492
|
+
quant_job_func = NULL;
|
|
2493
|
+
if (src1->type == HTP_TYPE_F32) {
|
|
2494
|
+
mmctx->type = "f16-f32";
|
|
2495
|
+
mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1;
|
|
2496
|
+
matmul_job_func = matmul_4d;
|
|
2183
2497
|
} else {
|
|
2184
|
-
|
|
2498
|
+
mmctx->type = "f16-f16";
|
|
2499
|
+
mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1;
|
|
2500
|
+
matmul_job_func = matmul_4d;
|
|
2185
2501
|
}
|
|
2186
2502
|
|
|
2187
|
-
src1_row_size =
|
|
2503
|
+
src1_row_size = nb11; // original row size in DDR
|
|
2188
2504
|
|
|
2189
|
-
|
|
2190
|
-
|
|
2505
|
+
octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
|
2506
|
+
octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
|
|
2507
|
+
octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
|
|
2191
2508
|
|
|
2192
|
-
octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
|
2193
|
-
octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
|
2194
|
-
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
|
|
2195
|
-
|
|
2196
|
-
// src0 spad is also used in dynamic quantizer to store padded src1 rows
|
|
2197
|
-
src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
|
|
2198
|
-
if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
|
|
2199
|
-
octx->src0_spad.size_per_thread = src1_row_size_padded;
|
|
2200
|
-
}
|
|
2201
|
-
|
|
2202
|
-
octx->src1_spad.size = octx->src1_spad.size_per_thread;
|
|
2203
2509
|
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
|
|
2510
|
+
octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
|
|
2204
2511
|
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
|
|
2205
|
-
break;
|
|
2206
2512
|
|
|
2207
|
-
|
|
2208
|
-
|
|
2209
|
-
|
|
2210
|
-
|
|
2211
|
-
|
|
2212
|
-
const size_t f16_src0_spad_size = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
|
|
2213
|
-
const size_t f16_dst_spad_size = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
|
|
2214
|
-
|
|
2215
|
-
const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;
|
|
2216
|
-
|
|
2217
|
-
// Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
|
|
2218
|
-
// It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
|
|
2219
|
-
const bool is_batched = (ne02 > 1) || (ne03 > 1);
|
|
2220
|
-
const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
|
|
2221
|
-
|
|
2222
|
-
if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
|
|
2223
|
-
// Optimized path
|
|
2224
|
-
op_type = "f16-f16";
|
|
2225
|
-
quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_fp32_fp16 : htp_quantize_fp16_fp16;
|
|
2226
|
-
if (src1_nrows > 1) {
|
|
2227
|
-
matmul_job_func = htp_matmul_2d_f16_f16;
|
|
2228
|
-
} else {
|
|
2229
|
-
matmul_job_func = htp_matvec_2d_f16_f16;
|
|
2230
|
-
}
|
|
2231
|
-
|
|
2232
|
-
src1_row_size = f16_src1_row_size; // row size post quantization
|
|
2233
|
-
|
|
2234
|
-
octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
|
2235
|
-
octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
|
2236
|
-
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
|
|
2237
|
-
|
|
2238
|
-
octx->src1_spad.size = octx->src1_spad.size_per_thread;
|
|
2239
|
-
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
|
|
2240
|
-
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
|
|
2241
|
-
} else {
|
|
2242
|
-
// Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required
|
|
2243
|
-
quant_job_func = NULL;
|
|
2244
|
-
if (src1->type == HTP_TYPE_F32) {
|
|
2245
|
-
op_type = "f16-f32";
|
|
2246
|
-
matmul_job_func = htp_matmul_4d_f16_f32;
|
|
2247
|
-
} else {
|
|
2248
|
-
op_type = "f16-f16";
|
|
2249
|
-
matmul_job_func = htp_matmul_4d_f16_f16;
|
|
2250
|
-
}
|
|
2251
|
-
|
|
2252
|
-
src1_row_size = nb11; // original row size in DDR
|
|
2253
|
-
|
|
2254
|
-
octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
|
2255
|
-
octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
|
|
2256
|
-
octx->src1_spad.size_per_thread = htp_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
|
|
2257
|
-
|
|
2258
|
-
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
|
|
2259
|
-
octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
|
|
2260
|
-
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
|
|
2261
|
-
|
|
2262
|
-
// Init fastdiv for matmul_4d (supports broadcasting)
|
|
2263
|
-
octx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
|
|
2264
|
-
octx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]);
|
|
2265
|
-
octx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
|
|
2266
|
-
octx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
|
|
2267
|
-
|
|
2268
|
-
need_quant = false;
|
|
2269
|
-
}
|
|
2270
|
-
}
|
|
2271
|
-
break;
|
|
2513
|
+
// Init fastdiv for matmul_4d (supports broadcasting)
|
|
2514
|
+
mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
|
|
2515
|
+
mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]);
|
|
2516
|
+
mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
|
|
2517
|
+
mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
|
|
2272
2518
|
|
|
2273
|
-
|
|
2519
|
+
need_quant = false;
|
|
2520
|
+
}
|
|
2521
|
+
} else {
|
|
2522
|
+
if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {
|
|
2274
2523
|
return HTP_STATUS_NO_SUPPORT;
|
|
2524
|
+
}
|
|
2525
|
+
|
|
2526
|
+
quant_job_func = quantize_f32_q8x4x2;
|
|
2527
|
+
src1_row_size = q8x4x2_row_size(ne10);
|
|
2528
|
+
htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0);
|
|
2275
2529
|
}
|
|
2276
2530
|
|
|
2277
2531
|
// VTCM scratchpads for all tensors
|
|
2278
2532
|
size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
|
|
2279
2533
|
|
|
2280
|
-
FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n",
|
|
2534
|
+
FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type,
|
|
2281
2535
|
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size);
|
|
2282
2536
|
|
|
2283
|
-
FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n",
|
|
2537
|
+
FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0],
|
|
2284
2538
|
src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0],
|
|
2285
2539
|
dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data);
|
|
2286
2540
|
|
|
2287
2541
|
// Make sure the reserved vtcm size is sufficient
|
|
2288
2542
|
if (octx->ctx->vtcm_size < spad_size) {
|
|
2289
|
-
FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n",
|
|
2543
|
+
FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type,
|
|
2290
2544
|
octx->ctx->vtcm_size, spad_size);
|
|
2291
2545
|
return HTP_STATUS_VTCM_TOO_SMALL;
|
|
2292
2546
|
}
|
|
@@ -2295,48 +2549,47 @@ int op_matmul(struct htp_ops_context * octx) {
|
|
|
2295
2549
|
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
|
2296
2550
|
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
|
2297
2551
|
|
|
2298
|
-
octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
|
|
2299
|
-
octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even
|
|
2300
|
-
|
|
2301
2552
|
octx->src0_spad.stride = src0_row_size_padded;
|
|
2302
2553
|
octx->src1_spad.stride = src1_row_size;
|
|
2303
2554
|
|
|
2304
2555
|
if (need_quant) {
|
|
2305
|
-
|
|
2306
|
-
|
|
2307
|
-
octx->
|
|
2308
|
-
worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs);
|
|
2556
|
+
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
|
|
2557
|
+
mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
|
|
2558
|
+
worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
|
|
2309
2559
|
}
|
|
2310
2560
|
|
|
2311
2561
|
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
|
2312
|
-
// Run matmul jobs
|
|
2313
2562
|
const uint32_t n_matmul_jobs = octx->n_threads;
|
|
2314
|
-
worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func,
|
|
2563
|
+
worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs);
|
|
2315
2564
|
}
|
|
2316
2565
|
|
|
2317
2566
|
return HTP_STATUS_OK;
|
|
2318
2567
|
}
|
|
2319
2568
|
|
|
2320
|
-
// ** main matmul-id entry point
|
|
2321
|
-
|
|
2322
2569
|
int op_matmul_id(struct htp_ops_context * octx) {
|
|
2323
2570
|
htp_matmul_tensors_preamble;
|
|
2324
2571
|
|
|
2325
|
-
struct
|
|
2326
|
-
|
|
2327
|
-
|
|
2572
|
+
struct htp_matmul_context mmctx_struct = {0};
|
|
2573
|
+
struct htp_matmul_context * mmctx = &mmctx_struct;
|
|
2574
|
+
mmctx->octx = octx;
|
|
2328
2575
|
|
|
2329
|
-
|
|
2330
|
-
worker_callback_t matmul_id_job_func;
|
|
2576
|
+
struct htp_tensor * restrict ids = &octx->src2;
|
|
2331
2577
|
|
|
2332
2578
|
const size_t src0_row_size = nb01;
|
|
2333
2579
|
const size_t dst_row_size = nb1;
|
|
2334
2580
|
|
|
2335
|
-
const size_t src0_row_size_padded =
|
|
2581
|
+
const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
|
|
2336
2582
|
|
|
2337
2583
|
const uint32_t src0_nrows = ne01; // per expert
|
|
2338
2584
|
const uint32_t src1_nrows = ne11 * ne12 * ne13;
|
|
2339
2585
|
|
|
2586
|
+
worker_callback_t quant_job_func;
|
|
2587
|
+
worker_callback_t matmul_id_job_func = src1_nrows > 1 ? matmul_id : matvec_id;
|
|
2588
|
+
|
|
2589
|
+
// Compute src0_nrows_per_thread
|
|
2590
|
+
mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
|
|
2591
|
+
mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even
|
|
2592
|
+
|
|
2340
2593
|
size_t src1_row_size;
|
|
2341
2594
|
size_t src1_row_size_padded;
|
|
2342
2595
|
|
|
@@ -2347,112 +2600,29 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
|
|
2347
2600
|
size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
|
|
2348
2601
|
size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);
|
|
2349
2602
|
|
|
2350
|
-
|
|
2351
|
-
|
|
2352
|
-
|
|
2353
|
-
quant_job_func = htp_quantize_fp32_q8x4x2;
|
|
2354
|
-
src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
|
|
2355
|
-
if (src1_nrows > 1) {
|
|
2356
|
-
matmul_id_job_func = htp_matmul_id_q4x4x2_q8x4x2;
|
|
2357
|
-
} else {
|
|
2358
|
-
matmul_id_job_func = htp_matvec_id_q4x4x2_q8x4x2;
|
|
2359
|
-
}
|
|
2360
|
-
|
|
2361
|
-
// Entire src1 tensor is placed into the VTCM
|
|
2362
|
-
// For other tensors we allocate N rows per thread, padded to HVX vector size
|
|
2363
|
-
octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
|
2364
|
-
octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
|
2365
|
-
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
|
|
2366
|
-
octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
|
|
2367
|
-
|
|
2368
|
-
// src0 spad is also used in dynamic quantizer to store padded src1 rows
|
|
2369
|
-
src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
|
|
2370
|
-
if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
|
|
2371
|
-
octx->src0_spad.size_per_thread = src1_row_size_padded;
|
|
2372
|
-
}
|
|
2373
|
-
|
|
2374
|
-
octx->src2_spad.size = octx->src2_spad.size_per_thread;
|
|
2375
|
-
octx->src1_spad.size = octx->src1_spad.size_per_thread;
|
|
2376
|
-
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
|
|
2377
|
-
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
|
|
2378
|
-
break;
|
|
2379
|
-
|
|
2380
|
-
case HTP_TYPE_Q8_0:
|
|
2381
|
-
op_type = "q8x2x2-f32";
|
|
2382
|
-
quant_job_func = htp_quantize_fp32_q8x4x2;
|
|
2383
|
-
src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
|
|
2384
|
-
if (src1_nrows > 1) {
|
|
2385
|
-
matmul_id_job_func = htp_matmul_id_q8x4x2_q8x4x2;
|
|
2386
|
-
} else {
|
|
2387
|
-
matmul_id_job_func = htp_matvec_id_q8x4x2_q8x4x2;
|
|
2388
|
-
}
|
|
2389
|
-
|
|
2390
|
-
// Entire src1 tensor is placed into the VTCM
|
|
2391
|
-
// For other tensors we allocate N rows per thread, padded to HVX vector size
|
|
2392
|
-
octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
|
2393
|
-
octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
|
2394
|
-
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
|
|
2395
|
-
octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
|
|
2396
|
-
|
|
2397
|
-
// src0 spad is also used in dynamic quantizer to store padded src1 rows
|
|
2398
|
-
src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
|
|
2399
|
-
if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
|
|
2400
|
-
octx->src0_spad.size_per_thread = src1_row_size_padded;
|
|
2401
|
-
}
|
|
2402
|
-
|
|
2403
|
-
octx->src2_spad.size = octx->src2_spad.size_per_thread;
|
|
2404
|
-
octx->src1_spad.size = octx->src1_spad.size_per_thread;
|
|
2405
|
-
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
|
|
2406
|
-
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
|
|
2407
|
-
break;
|
|
2408
|
-
|
|
2409
|
-
case HTP_TYPE_MXFP4:
|
|
2410
|
-
op_type = "mxfp4x2x2-f32";
|
|
2411
|
-
quant_job_func = htp_quantize_fp32_q8x4x2;
|
|
2412
|
-
src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
|
|
2413
|
-
if (src1_nrows > 1) {
|
|
2414
|
-
matmul_id_job_func = htp_matmul_id_mxfp4x4x2_q8x4x2;
|
|
2415
|
-
} else {
|
|
2416
|
-
matmul_id_job_func = htp_matvec_id_mxfp4x4x2_q8x4x2;
|
|
2417
|
-
}
|
|
2418
|
-
|
|
2419
|
-
// Entire src1 tensor is placed into the VTCM
|
|
2420
|
-
// For other tensors we allocate N rows per thread, padded to HVX vector size
|
|
2421
|
-
octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
|
2422
|
-
octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
|
2423
|
-
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
|
|
2424
|
-
octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
|
|
2425
|
-
|
|
2426
|
-
// src0 spad is also used in dynamic quantizer to store padded src1 rows
|
|
2427
|
-
src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
|
|
2428
|
-
if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
|
|
2429
|
-
octx->src0_spad.size_per_thread = src1_row_size_padded;
|
|
2430
|
-
}
|
|
2603
|
+
if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {
|
|
2604
|
+
return HTP_STATUS_NO_SUPPORT;
|
|
2605
|
+
}
|
|
2431
2606
|
|
|
2432
|
-
|
|
2433
|
-
|
|
2434
|
-
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
|
|
2435
|
-
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
|
|
2436
|
-
break;
|
|
2607
|
+
quant_job_func = quantize_f32_q8x4x2;
|
|
2608
|
+
src1_row_size = q8x4x2_row_size(ne10);
|
|
2437
2609
|
|
|
2438
|
-
|
|
2439
|
-
|
|
2440
|
-
}
|
|
2610
|
+
const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
|
|
2611
|
+
htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread);
|
|
2441
2612
|
|
|
2442
2613
|
size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
|
|
2443
2614
|
|
|
2444
|
-
FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n",
|
|
2615
|
+
FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type,
|
|
2445
2616
|
octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size);
|
|
2446
2617
|
|
|
2447
|
-
FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n",
|
|
2618
|
+
FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type,
|
|
2448
2619
|
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
|
|
2449
2620
|
ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data,
|
|
2450
2621
|
src1->data, dst->data);
|
|
2451
2622
|
|
|
2452
2623
|
// Make sure the reserved vtcm size is sufficient
|
|
2453
2624
|
if (octx->ctx->vtcm_size < spad_size) {
|
|
2454
|
-
FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n",
|
|
2455
|
-
octx->ctx->vtcm_size, spad_size);
|
|
2625
|
+
FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size);
|
|
2456
2626
|
return HTP_STATUS_VTCM_TOO_SMALL;
|
|
2457
2627
|
}
|
|
2458
2628
|
|
|
@@ -2461,8 +2631,8 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
|
|
2461
2631
|
octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
|
2462
2632
|
octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size;
|
|
2463
2633
|
|
|
2464
|
-
octx->
|
|
2465
|
-
octx->
|
|
2634
|
+
octx->src0_spad.stride = src0_row_size_padded;
|
|
2635
|
+
octx->src1_spad.stride = src1_row_size;
|
|
2466
2636
|
|
|
2467
2637
|
if (src1_nrows > 1) {
|
|
2468
2638
|
// initialize matrix_row_counts and map
|
|
@@ -2474,8 +2644,7 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
|
|
2474
2644
|
// group rows by src0 matrix
|
|
2475
2645
|
for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { // token idx
|
|
2476
2646
|
for (uint32_t id = 0; id < n_ids; ++id) { // expert idx
|
|
2477
|
-
const uint32_t i02 =
|
|
2478
|
-
*(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
|
|
2647
|
+
const uint32_t i02 = *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
|
|
2479
2648
|
|
|
2480
2649
|
assert(i02 >= 0 && i02 < n_as);
|
|
2481
2650
|
|
|
@@ -2487,16 +2656,14 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
|
|
2487
2656
|
|
|
2488
2657
|
// Setup worker pool callbacks
|
|
2489
2658
|
if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) {
|
|
2490
|
-
// Run quant jobs
|
|
2491
2659
|
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
|
|
2492
|
-
|
|
2493
|
-
worker_pool_run_func(octx->ctx->worker_pool, quant_job_func,
|
|
2660
|
+
mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
|
|
2661
|
+
worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
|
|
2494
2662
|
}
|
|
2495
2663
|
|
|
2496
2664
|
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
|
2497
|
-
// Run matmul-id jobs
|
|
2498
2665
|
const uint32_t n_matmul_jobs = octx->n_threads;
|
|
2499
|
-
worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func,
|
|
2666
|
+
worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs);
|
|
2500
2667
|
}
|
|
2501
2668
|
|
|
2502
2669
|
return HTP_STATUS_OK;
|