whispercpp 1.3.6 → 1.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/README.md +38 -5
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -8
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +36 -42
- data/ext/ruby_whisper.h +135 -0
- data/ext/ruby_whisper_context.c +107 -28
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -65
- data/ext/ruby_whisper_segment.c +6 -6
- data/ext/ruby_whisper_transcribe.cpp +42 -15
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +1 -1
- data/ext/sources/examples/cli/cli.cpp +43 -9
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +199 -163
- data/ext/sources/ggml/CMakeLists.txt +21 -13
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +72 -10
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-rpc.h +3 -3
- data/ext/sources/ggml/include/ggml.h +101 -9
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +22 -5
- data/ext/sources/ggml/src/ggml-alloc.c +5 -1
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
- data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
- data/ext/sources/ggml/src/ggml-impl.h +6 -1
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
- data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +289 -114
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
- data/ext/sources/ggml/src/ggml.c +110 -28
- data/ext/sources/ggml/src/gguf.cpp +173 -28
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +56 -12
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +411 -62
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +24 -6
- data/whispercpp.gemspec +2 -2
- metadata +215 -281
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
- data/ext/sources/examples/talk-llama/llama-context.h +0 -359
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
- data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
- data/ext/sources/examples/talk-llama/llama-model.h +0 -597
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
- data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
- data/ext/sources/examples/talk-llama/llama.h +0 -1573
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -704
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
- /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
|
@@ -16,8 +16,9 @@
|
|
|
16
16
|
#define GGML_COMMON_DECL_C
|
|
17
17
|
#include "ggml-common.h"
|
|
18
18
|
#include "htp-ctx.h"
|
|
19
|
-
#include "htp-msg.h"
|
|
20
19
|
#include "htp-ops.h"
|
|
20
|
+
#include "htp-ops.h"
|
|
21
|
+
#include "hmx-ops.h"
|
|
21
22
|
|
|
22
23
|
#define MM_SPAD_SRC0_NROWS 16
|
|
23
24
|
#define MM_SPAD_SRC1_NROWS 16
|
|
@@ -39,6 +40,11 @@ struct htp_matmul_context {
|
|
|
39
40
|
const void * restrict vx0, const void * restrict vx1,
|
|
40
41
|
const void * restrict vy0, const void * restrict vy1);
|
|
41
42
|
|
|
43
|
+
void (*vec_dot_4x1)(const int n, float * restrict s0,
|
|
44
|
+
const void * restrict vx0, const void * restrict vx1,
|
|
45
|
+
const void * restrict vx2, const void * restrict vx3,
|
|
46
|
+
const void * restrict vy0);
|
|
47
|
+
|
|
42
48
|
// Precomputed values
|
|
43
49
|
uint32_t src0_nrows_per_thread;
|
|
44
50
|
uint32_t src1_nrows_per_thread;
|
|
@@ -47,6 +53,11 @@ struct htp_matmul_context {
|
|
|
47
53
|
struct fastdiv_values mm_div_ne1;
|
|
48
54
|
struct fastdiv_values mm_div_r2;
|
|
49
55
|
struct fastdiv_values mm_div_r3;
|
|
56
|
+
|
|
57
|
+
// Fields for scattered mapping & HMX support in MUL_MAT_ID
|
|
58
|
+
const uint32_t * matrix_row_counts;
|
|
59
|
+
const struct mmid_row_mapping * matrix_rows;
|
|
60
|
+
bool hmx_eligible;
|
|
50
61
|
};
|
|
51
62
|
|
|
52
63
|
// vdelta control to expand first 32 e8m0 values into 32 uint32 elements
|
|
@@ -60,6 +71,16 @@ static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
|
|
|
60
71
|
0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20,
|
|
61
72
|
};
|
|
62
73
|
|
|
74
|
+
// IQ4_NL dequantization LUT: maps 4-bit index (0-15) to int8 kvalue
|
|
75
|
+
// kvalues: -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
|
|
76
|
+
static const uint8_t __attribute__((aligned(VLEN))) kvalues_iq4nl_lut[] = {
|
|
77
|
+
0x81, 0, 0x98, 0, 0xAD, 0, 0xBF, 0, 0xCF, 0, 0xDD, 0, 0xEA, 0, 0xF6, 0, 0x01, 0, 0x0D, 0, 0x19, 0, 0x26, 0,
|
|
78
|
+
0x35, 0, 0x45, 0, 0x59, 0, 0x71, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
79
|
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
80
|
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
81
|
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
82
|
+
};
|
|
83
|
+
|
|
63
84
|
static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = {
|
|
64
85
|
0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0,
|
|
65
86
|
0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
@@ -68,6 +89,73 @@ static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = {
|
|
|
68
89
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
69
90
|
};
|
|
70
91
|
|
|
92
|
+
static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_full(const uint8_t * restrict ptr) {
|
|
93
|
+
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
|
|
94
|
+
|
|
95
|
+
HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
|
|
96
|
+
HVX_Vector v2_3 = vptr[1]; // ...
|
|
97
|
+
HVX_Vector v4_5 = vptr[2]; // ...
|
|
98
|
+
HVX_Vector v6_7 = vptr[3]; // ...
|
|
99
|
+
|
|
100
|
+
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
|
101
|
+
const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut;
|
|
102
|
+
|
|
103
|
+
HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
|
|
104
|
+
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
|
|
105
|
+
HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F
|
|
106
|
+
HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4
|
|
107
|
+
HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F
|
|
108
|
+
HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4
|
|
109
|
+
HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F
|
|
110
|
+
HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
|
|
111
|
+
|
|
112
|
+
v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
|
|
113
|
+
v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
|
|
114
|
+
v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
|
|
115
|
+
v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
|
|
116
|
+
v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
|
|
117
|
+
v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
|
|
118
|
+
v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
|
|
119
|
+
v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
|
|
120
|
+
|
|
121
|
+
HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
|
|
122
|
+
return r;
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_partial(const uint8_t * restrict ptr, uint32_t n) {
|
|
126
|
+
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
|
|
127
|
+
|
|
128
|
+
const uint32_t qk = QK_Q4_0x4x2; // 256
|
|
129
|
+
const uint32_t nb = n / qk;
|
|
130
|
+
const uint32_t nloe = n % qk;
|
|
131
|
+
|
|
132
|
+
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
|
133
|
+
const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut;
|
|
134
|
+
|
|
135
|
+
HVX_Vector_x8 r;
|
|
136
|
+
uint32_t i = 0;
|
|
137
|
+
|
|
138
|
+
#pragma unroll(2)
|
|
139
|
+
for (i = 0; i < nb; i++) {
|
|
140
|
+
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
|
|
141
|
+
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements
|
|
142
|
+
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements
|
|
143
|
+
r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
|
|
144
|
+
r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
if (nloe) {
|
|
148
|
+
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
|
|
149
|
+
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements
|
|
150
|
+
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements
|
|
151
|
+
HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:...
|
|
152
|
+
r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0);
|
|
153
|
+
r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0);
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
return r;
|
|
157
|
+
}
|
|
158
|
+
|
|
71
159
|
// q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales
|
|
72
160
|
|
|
73
161
|
static inline size_t q8x4x2_row_size(uint32_t ne) {
|
|
@@ -77,6 +165,13 @@ static inline size_t q8x4x2_row_size(uint32_t ne) {
|
|
|
77
165
|
return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128);
|
|
78
166
|
}
|
|
79
167
|
|
|
168
|
+
static inline size_t q8_1x4x2_row_size(uint32_t ne) {
|
|
169
|
+
// ensures perfect alignment of quants and full row
|
|
170
|
+
const uint32_t qk = QK_Q8_0x4x2;
|
|
171
|
+
const uint32_t nb = (ne + qk - 1) / qk;
|
|
172
|
+
return hex_round_up(ne + nb * 8 * 2 * sizeof(__fp16), 128);
|
|
173
|
+
}
|
|
174
|
+
|
|
80
175
|
static inline HVX_Vector_x8 hvx_vec_load_q4x4x8_full(const uint8_t * restrict ptr) {
|
|
81
176
|
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
|
|
82
177
|
|
|
@@ -145,6 +240,62 @@ static HVX_Vector_x8 hvx_vec_load_q4x4x8_partial(const uint8_t * restrict ptr, u
|
|
|
145
240
|
return r;
|
|
146
241
|
}
|
|
147
242
|
|
|
243
|
+
static inline HVX_Vector_x8 hvx_vec_load_q4_1x4x8_full(const uint8_t * restrict ptr) {
|
|
244
|
+
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
|
|
245
|
+
|
|
246
|
+
HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
|
|
247
|
+
HVX_Vector v2_3 = vptr[1]; // ...
|
|
248
|
+
HVX_Vector v4_5 = vptr[2]; // ...
|
|
249
|
+
HVX_Vector v6_7 = vptr[3]; // ...
|
|
250
|
+
|
|
251
|
+
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
|
252
|
+
|
|
253
|
+
HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements
|
|
254
|
+
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements
|
|
255
|
+
HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ...
|
|
256
|
+
HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4
|
|
257
|
+
HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F
|
|
258
|
+
HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4
|
|
259
|
+
HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F
|
|
260
|
+
HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
|
|
261
|
+
|
|
262
|
+
HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
|
|
263
|
+
return r;
|
|
264
|
+
}
|
|
265
|
+
|
|
266
|
+
static HVX_Vector_x8 hvx_vec_load_q4_1x4x8_partial(const uint8_t * restrict ptr, uint32_t n) {
|
|
267
|
+
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
|
|
268
|
+
|
|
269
|
+
const uint32_t qk = QK_Q4_0x4x2; // 256
|
|
270
|
+
const uint32_t nb = n / qk;
|
|
271
|
+
const uint32_t nloe = n % qk;
|
|
272
|
+
|
|
273
|
+
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
|
274
|
+
|
|
275
|
+
HVX_Vector_x8 r;
|
|
276
|
+
uint32_t i = 0;
|
|
277
|
+
|
|
278
|
+
#pragma unroll(2)
|
|
279
|
+
for (i=0; i < nb; i++) {
|
|
280
|
+
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
|
|
281
|
+
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements
|
|
282
|
+
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements
|
|
283
|
+
r.v[i*2+0] = v0;
|
|
284
|
+
r.v[i*2+1] = v1;
|
|
285
|
+
}
|
|
286
|
+
|
|
287
|
+
if (nloe) {
|
|
288
|
+
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
|
|
289
|
+
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements
|
|
290
|
+
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements
|
|
291
|
+
HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:...
|
|
292
|
+
r.v[i*2+0] = Q6_V_lo_W(v0_1_p);
|
|
293
|
+
r.v[i*2+1] = Q6_V_hi_W(v0_1_p);
|
|
294
|
+
}
|
|
295
|
+
|
|
296
|
+
return r;
|
|
297
|
+
}
|
|
298
|
+
|
|
148
299
|
static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_full(const uint8_t * restrict ptr) {
|
|
149
300
|
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
|
|
150
301
|
|
|
@@ -323,82 +474,96 @@ static inline HVX_Vector hvx_vec_rmpy_x8_partial(HVX_Vector_x8 x, HVX_Vector_x8
|
|
|
323
474
|
return hvx_vec_rmpy_x8_partial(x, y, 512);
|
|
324
475
|
}
|
|
325
476
|
|
|
326
|
-
static void
|
|
477
|
+
static void vec_dot_q4_1x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
|
|
327
478
|
assert(n % 32 == 0); // min sub-block size
|
|
328
479
|
assert((unsigned long) vx0 % 128 == 0);
|
|
329
480
|
assert((unsigned long) vy0 % 128 == 0);
|
|
330
481
|
|
|
331
482
|
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
332
483
|
|
|
333
|
-
const uint32_t x_dblk_size = 8 * 4 * 2;
|
|
484
|
+
const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes
|
|
334
485
|
const uint32_t x_qblk_size = qk / 2; // int4
|
|
335
486
|
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
|
336
487
|
|
|
337
|
-
const uint32_t y_dblk_size = 8 * 4 *
|
|
488
|
+
const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes
|
|
338
489
|
const uint32_t y_qblk_size = qk; // int8
|
|
339
490
|
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
340
491
|
|
|
341
492
|
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
|
|
342
|
-
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
|
|
493
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales/offsets
|
|
343
494
|
|
|
344
495
|
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
|
345
|
-
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
|
496
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums
|
|
346
497
|
|
|
347
498
|
// Row sum (sf)
|
|
348
499
|
HVX_Vector r0_sum = Q6_V_vzero();
|
|
349
500
|
|
|
350
|
-
// Multiply and accumulate into int32.
|
|
351
|
-
// Compute combined scale (fp32).
|
|
352
|
-
// Apply scale to acc and accumulate into the row sum (qf32).
|
|
353
|
-
|
|
354
501
|
const uint32_t nb = n / qk; // num full blocks
|
|
355
502
|
const uint32_t nloe = n % qk; // num leftover elemements
|
|
356
503
|
|
|
357
504
|
uint32_t i = 0;
|
|
358
505
|
for (; i < nb; i++) {
|
|
359
506
|
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
|
360
|
-
HVX_Vector_x8 r0_q =
|
|
507
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size);
|
|
361
508
|
|
|
362
509
|
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
363
510
|
|
|
364
|
-
HVX_Vector
|
|
365
|
-
|
|
511
|
+
HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size);
|
|
512
|
+
HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2);
|
|
513
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal));
|
|
514
|
+
HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal));
|
|
515
|
+
|
|
516
|
+
HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
|
|
517
|
+
HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2);
|
|
518
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal));
|
|
519
|
+
HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal));
|
|
366
520
|
|
|
367
521
|
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
522
|
+
HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s)));
|
|
368
523
|
|
|
369
524
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
525
|
+
HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms);
|
|
370
526
|
|
|
371
|
-
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(
|
|
527
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum));
|
|
372
528
|
}
|
|
373
529
|
|
|
374
530
|
// Process leftovers
|
|
375
531
|
if (nloe) {
|
|
376
532
|
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
|
377
|
-
HVX_Vector_x8 r0_q =
|
|
533
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
378
534
|
|
|
379
535
|
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
|
380
536
|
|
|
381
|
-
HVX_Vector
|
|
382
|
-
|
|
537
|
+
HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size);
|
|
538
|
+
HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2);
|
|
539
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal));
|
|
540
|
+
HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal));
|
|
541
|
+
|
|
542
|
+
HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
|
|
543
|
+
HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2);
|
|
544
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal));
|
|
545
|
+
HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal));
|
|
383
546
|
|
|
384
547
|
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
548
|
+
HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s)));
|
|
385
549
|
|
|
386
550
|
// Zero out unused elements
|
|
387
551
|
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
388
552
|
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
|
553
|
+
r0_ms = Q6_V_vand_QV(bmask, r0_ms);
|
|
389
554
|
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
|
390
555
|
|
|
391
556
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
557
|
+
HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms);
|
|
392
558
|
|
|
393
|
-
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(
|
|
559
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum));
|
|
394
560
|
}
|
|
395
561
|
|
|
396
562
|
r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
|
|
397
|
-
|
|
398
563
|
hvx_vec_store_u(s0, 4, r0_sum);
|
|
399
564
|
}
|
|
400
565
|
|
|
401
|
-
static void
|
|
566
|
+
static void vec_dot_q4_1x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
402
567
|
const void * restrict vx0, const void * restrict vx1,
|
|
403
568
|
const void * restrict vy0) {
|
|
404
569
|
assert(n % 32 == 0); // min sub-block size
|
|
@@ -408,11 +573,11 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
|
408
573
|
|
|
409
574
|
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
410
575
|
|
|
411
|
-
const uint32_t x_dblk_size = 8 * 4 * 2;
|
|
576
|
+
const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes
|
|
412
577
|
const uint32_t x_qblk_size = qk / 2; // int4
|
|
413
578
|
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
|
414
579
|
|
|
415
|
-
const uint32_t y_dblk_size = 8 * 4 *
|
|
580
|
+
const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes
|
|
416
581
|
const uint32_t y_qblk_size = qk; // int8
|
|
417
582
|
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
418
583
|
|
|
@@ -422,77 +587,306 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
|
422
587
|
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
|
|
423
588
|
|
|
424
589
|
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
|
425
|
-
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
|
590
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums
|
|
426
591
|
|
|
427
592
|
// Row sum (sf)
|
|
428
593
|
HVX_Vector r0_sum = Q6_V_vzero();
|
|
429
594
|
HVX_Vector r1_sum = Q6_V_vzero();
|
|
430
595
|
|
|
431
|
-
// Multiply and accumulate into int32.
|
|
432
|
-
// Compute combined scale (fp32).
|
|
433
|
-
// Apply scale to acc and accumulate into the row sum (qf32).
|
|
434
|
-
|
|
435
596
|
const uint32_t nb = n / qk; // num full blocks
|
|
436
597
|
const uint32_t nloe = n % qk; // num leftover elemements
|
|
437
598
|
|
|
438
599
|
uint32_t i = 0;
|
|
439
600
|
for (; i < nb; i++) {
|
|
440
601
|
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
|
441
|
-
HVX_Vector_x8 r0_q =
|
|
442
|
-
HVX_Vector_x8 r1_q =
|
|
602
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size);
|
|
603
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size);
|
|
443
604
|
|
|
444
605
|
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
445
606
|
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
|
|
446
607
|
|
|
447
|
-
HVX_Vector
|
|
448
|
-
|
|
449
|
-
HVX_Vector
|
|
608
|
+
HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size);
|
|
609
|
+
HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2);
|
|
610
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal));
|
|
611
|
+
HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal));
|
|
612
|
+
|
|
613
|
+
HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
|
|
614
|
+
HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2);
|
|
615
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal));
|
|
616
|
+
HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal));
|
|
617
|
+
|
|
618
|
+
HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
|
|
619
|
+
HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2);
|
|
620
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal));
|
|
621
|
+
HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal));
|
|
450
622
|
|
|
451
623
|
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
624
|
+
HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s)));
|
|
625
|
+
|
|
452
626
|
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
|
627
|
+
HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s)));
|
|
453
628
|
|
|
454
629
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
630
|
+
HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms);
|
|
631
|
+
|
|
455
632
|
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
633
|
+
HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms);
|
|
456
634
|
|
|
457
|
-
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(
|
|
458
|
-
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(
|
|
635
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum));
|
|
636
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum));
|
|
459
637
|
}
|
|
460
638
|
|
|
461
639
|
// Process leftovers
|
|
462
640
|
if (nloe) {
|
|
463
641
|
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
|
464
|
-
HVX_Vector_x8 r0_q =
|
|
465
|
-
HVX_Vector_x8 r1_q =
|
|
642
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
643
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
466
644
|
|
|
467
645
|
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
|
468
646
|
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
|
|
469
647
|
|
|
470
|
-
HVX_Vector
|
|
471
|
-
|
|
472
|
-
HVX_Vector
|
|
648
|
+
HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size);
|
|
649
|
+
HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2);
|
|
650
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal));
|
|
651
|
+
HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal));
|
|
652
|
+
|
|
653
|
+
HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
|
|
654
|
+
HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2);
|
|
655
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal));
|
|
656
|
+
HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal));
|
|
657
|
+
|
|
658
|
+
HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
|
|
659
|
+
HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2);
|
|
660
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal));
|
|
661
|
+
HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal));
|
|
473
662
|
|
|
474
663
|
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
664
|
+
HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s)));
|
|
665
|
+
|
|
475
666
|
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
|
667
|
+
HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s)));
|
|
476
668
|
|
|
477
669
|
// Zero out unused elements
|
|
478
670
|
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
479
671
|
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
|
672
|
+
r0_ms = Q6_V_vand_QV(bmask, r0_ms);
|
|
480
673
|
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
|
|
674
|
+
r1_ms = Q6_V_vand_QV(bmask, r1_ms);
|
|
481
675
|
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
|
482
676
|
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
|
|
483
677
|
|
|
484
678
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
679
|
+
HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms);
|
|
680
|
+
|
|
485
681
|
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
682
|
+
HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms);
|
|
486
683
|
|
|
487
|
-
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(
|
|
488
|
-
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(
|
|
684
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum));
|
|
685
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum));
|
|
489
686
|
}
|
|
490
687
|
|
|
491
688
|
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
|
|
492
689
|
hvx_vec_store_u(s0, 8, rsum);
|
|
493
690
|
}
|
|
494
691
|
|
|
495
|
-
static void
|
|
692
|
+
static void vec_dot_q4_1x4x2_q8x4x2_4x1(const int n, float * restrict s0,
|
|
693
|
+
const void * restrict vx0, const void * restrict vx1,
|
|
694
|
+
const void * restrict vx2, const void * restrict vx3,
|
|
695
|
+
const void * restrict vy0) {
|
|
696
|
+
assert(n % 32 == 0); // min sub-block size
|
|
697
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
698
|
+
assert((unsigned long) vx1 % 128 == 0);
|
|
699
|
+
assert((unsigned long) vx2 % 128 == 0);
|
|
700
|
+
assert((unsigned long) vx3 % 128 == 0);
|
|
701
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
702
|
+
|
|
703
|
+
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
704
|
+
|
|
705
|
+
const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes
|
|
706
|
+
const uint32_t x_qblk_size = qk / 2; // int4
|
|
707
|
+
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
|
708
|
+
|
|
709
|
+
const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes
|
|
710
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
711
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
712
|
+
|
|
713
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
|
|
714
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
|
|
715
|
+
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
|
|
716
|
+
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
|
|
717
|
+
const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first
|
|
718
|
+
const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales
|
|
719
|
+
const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first
|
|
720
|
+
const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales
|
|
721
|
+
|
|
722
|
+
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
|
723
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums
|
|
724
|
+
|
|
725
|
+
// Row sum (sf)
|
|
726
|
+
HVX_Vector r0_sum = Q6_V_vzero();
|
|
727
|
+
HVX_Vector r1_sum = Q6_V_vzero();
|
|
728
|
+
HVX_Vector r2_sum = Q6_V_vzero();
|
|
729
|
+
HVX_Vector r3_sum = Q6_V_vzero();
|
|
730
|
+
|
|
731
|
+
const uint32_t nb = n / qk; // num full blocks
|
|
732
|
+
const uint32_t nloe = n % qk; // num leftover elements
|
|
733
|
+
|
|
734
|
+
uint32_t i = 0;
|
|
735
|
+
for (; i < nb; i++) {
|
|
736
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
|
737
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size);
|
|
738
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size);
|
|
739
|
+
HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_full(r2_x_q + i * x_qblk_size);
|
|
740
|
+
HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_full(r3_x_q + i * x_qblk_size);
|
|
741
|
+
|
|
742
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
743
|
+
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
|
|
744
|
+
HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q));
|
|
745
|
+
HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q));
|
|
746
|
+
|
|
747
|
+
HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size);
|
|
748
|
+
HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2);
|
|
749
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal));
|
|
750
|
+
HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal));
|
|
751
|
+
|
|
752
|
+
HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
|
|
753
|
+
HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2);
|
|
754
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal));
|
|
755
|
+
HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal));
|
|
756
|
+
|
|
757
|
+
HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
|
|
758
|
+
HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2);
|
|
759
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal));
|
|
760
|
+
HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal));
|
|
761
|
+
|
|
762
|
+
HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size);
|
|
763
|
+
HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2);
|
|
764
|
+
HVX_Vector r2_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal));
|
|
765
|
+
HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal));
|
|
766
|
+
|
|
767
|
+
HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size);
|
|
768
|
+
HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2);
|
|
769
|
+
HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal));
|
|
770
|
+
HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal));
|
|
771
|
+
|
|
772
|
+
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
773
|
+
HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s)));
|
|
774
|
+
|
|
775
|
+
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
|
776
|
+
HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s)));
|
|
777
|
+
|
|
778
|
+
HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d)));
|
|
779
|
+
HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s)));
|
|
780
|
+
|
|
781
|
+
HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d)));
|
|
782
|
+
HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s)));
|
|
783
|
+
|
|
784
|
+
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
785
|
+
HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms);
|
|
786
|
+
|
|
787
|
+
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
788
|
+
HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms);
|
|
789
|
+
|
|
790
|
+
HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
|
|
791
|
+
HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms);
|
|
792
|
+
|
|
793
|
+
HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
|
|
794
|
+
HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms);
|
|
795
|
+
|
|
796
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum));
|
|
797
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum));
|
|
798
|
+
r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum));
|
|
799
|
+
r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum));
|
|
800
|
+
}
|
|
801
|
+
|
|
802
|
+
if (nloe) {
|
|
803
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
|
804
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
805
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
806
|
+
HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_partial(r2_x_q + i * x_qblk_size, nloe);
|
|
807
|
+
HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_partial(r3_x_q + i * x_qblk_size, nloe);
|
|
808
|
+
|
|
809
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
|
810
|
+
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
|
|
811
|
+
HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe));
|
|
812
|
+
HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe));
|
|
813
|
+
|
|
814
|
+
HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size);
|
|
815
|
+
HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2);
|
|
816
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal));
|
|
817
|
+
HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal));
|
|
818
|
+
|
|
819
|
+
HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
|
|
820
|
+
HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2);
|
|
821
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal));
|
|
822
|
+
HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal));
|
|
823
|
+
|
|
824
|
+
HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
|
|
825
|
+
HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2);
|
|
826
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal));
|
|
827
|
+
HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal));
|
|
828
|
+
|
|
829
|
+
HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size);
|
|
830
|
+
HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2);
|
|
831
|
+
HVX_Vector r2_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal));
|
|
832
|
+
HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal));
|
|
833
|
+
|
|
834
|
+
HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size);
|
|
835
|
+
HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2);
|
|
836
|
+
HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal));
|
|
837
|
+
HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal));
|
|
838
|
+
|
|
839
|
+
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
840
|
+
HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s)));
|
|
841
|
+
|
|
842
|
+
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
|
843
|
+
HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s)));
|
|
844
|
+
|
|
845
|
+
HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d)));
|
|
846
|
+
HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s)));
|
|
847
|
+
|
|
848
|
+
HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d)));
|
|
849
|
+
HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s)));
|
|
850
|
+
|
|
851
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
852
|
+
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
|
853
|
+
r0_ms = Q6_V_vand_QV(bmask, r0_ms);
|
|
854
|
+
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
|
|
855
|
+
r1_ms = Q6_V_vand_QV(bmask, r1_ms);
|
|
856
|
+
r2_dd = Q6_V_vand_QV(bmask, r2_dd);
|
|
857
|
+
r2_ms = Q6_V_vand_QV(bmask, r2_ms);
|
|
858
|
+
r3_dd = Q6_V_vand_QV(bmask, r3_dd);
|
|
859
|
+
r3_ms = Q6_V_vand_QV(bmask, r3_ms);
|
|
860
|
+
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
|
861
|
+
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
|
|
862
|
+
r2_ia = Q6_V_vand_QV(bmask, r2_ia);
|
|
863
|
+
r3_ia = Q6_V_vand_QV(bmask, r3_ia);
|
|
864
|
+
|
|
865
|
+
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
866
|
+
HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms);
|
|
867
|
+
|
|
868
|
+
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
869
|
+
HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms);
|
|
870
|
+
|
|
871
|
+
HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
|
|
872
|
+
HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms);
|
|
873
|
+
|
|
874
|
+
HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
|
|
875
|
+
HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms);
|
|
876
|
+
|
|
877
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum));
|
|
878
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum));
|
|
879
|
+
r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum));
|
|
880
|
+
r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum));
|
|
881
|
+
}
|
|
882
|
+
|
|
883
|
+
HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } };
|
|
884
|
+
HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in);
|
|
885
|
+
hvx_vec_store_u(s0, 16, rsum);
|
|
886
|
+
}
|
|
887
|
+
|
|
888
|
+
|
|
889
|
+
static void vec_dot_q4_1x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
|
|
496
890
|
const void * restrict vx0, const void * restrict vx1,
|
|
497
891
|
const void * restrict vy0, const void * restrict vy1) {
|
|
498
892
|
assert(n % 32 == 0);
|
|
@@ -503,11 +897,11 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
|
|
|
503
897
|
|
|
504
898
|
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
505
899
|
|
|
506
|
-
const uint32_t x_dblk_size = 8 * 4 * 2;
|
|
900
|
+
const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes
|
|
507
901
|
const uint32_t x_qblk_size = qk / 2; // int4
|
|
508
902
|
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
|
509
903
|
|
|
510
|
-
const uint32_t y_dblk_size = 8 * 4 *
|
|
904
|
+
const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes
|
|
511
905
|
const uint32_t y_qblk_size = qk; // int8
|
|
512
906
|
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
513
907
|
|
|
@@ -517,9 +911,9 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
|
|
|
517
911
|
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
|
|
518
912
|
|
|
519
913
|
const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first
|
|
520
|
-
const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
|
|
914
|
+
const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales/sums
|
|
521
915
|
const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first
|
|
522
|
-
const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
|
|
916
|
+
const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales/sums
|
|
523
917
|
|
|
524
918
|
// Row sums (sf) - 4 accumulators for 2×2 tile
|
|
525
919
|
HVX_Vector r0_c0_sum = Q6_V_vzero();
|
|
@@ -532,13 +926,13 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
|
|
|
532
926
|
|
|
533
927
|
uint32_t i = 0;
|
|
534
928
|
for (; i < nb; i++) {
|
|
535
|
-
// Load src1 columns
|
|
929
|
+
// Load src1 columns
|
|
536
930
|
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
|
|
537
931
|
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
|
|
538
932
|
|
|
539
|
-
// Load src0 rows
|
|
540
|
-
HVX_Vector_x8 r0_q =
|
|
541
|
-
HVX_Vector_x8 r1_q =
|
|
933
|
+
// Load src0 rows
|
|
934
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size);
|
|
935
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size);
|
|
542
936
|
|
|
543
937
|
// Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
|
|
544
938
|
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
|
|
@@ -547,16 +941,38 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
|
|
|
547
941
|
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
|
|
548
942
|
|
|
549
943
|
// Load scales
|
|
550
|
-
HVX_Vector
|
|
551
|
-
|
|
552
|
-
HVX_Vector
|
|
553
|
-
HVX_Vector
|
|
944
|
+
HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size);
|
|
945
|
+
HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2);
|
|
946
|
+
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal));
|
|
947
|
+
HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal));
|
|
948
|
+
|
|
949
|
+
HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size);
|
|
950
|
+
HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2);
|
|
951
|
+
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal));
|
|
952
|
+
HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal));
|
|
953
|
+
|
|
954
|
+
HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
|
|
955
|
+
HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2);
|
|
956
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal));
|
|
957
|
+
HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal));
|
|
958
|
+
|
|
959
|
+
HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
|
|
960
|
+
HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2);
|
|
961
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal));
|
|
962
|
+
HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal));
|
|
554
963
|
|
|
555
964
|
// Compute combined scales
|
|
556
965
|
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
|
|
966
|
+
HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s)));
|
|
967
|
+
|
|
557
968
|
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
|
|
969
|
+
HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s)));
|
|
970
|
+
|
|
558
971
|
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
|
|
972
|
+
HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s)));
|
|
973
|
+
|
|
559
974
|
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
|
|
975
|
+
HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s)));
|
|
560
976
|
|
|
561
977
|
// Apply scales and accumulate
|
|
562
978
|
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
|
|
@@ -564,40 +980,72 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
|
|
|
564
980
|
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
|
565
981
|
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
|
566
982
|
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
983
|
+
HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms);
|
|
984
|
+
HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms);
|
|
985
|
+
HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms);
|
|
986
|
+
HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms);
|
|
987
|
+
|
|
988
|
+
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum));
|
|
989
|
+
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum));
|
|
990
|
+
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum));
|
|
991
|
+
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum));
|
|
571
992
|
}
|
|
572
993
|
|
|
573
994
|
// Process leftovers
|
|
574
995
|
if (nloe) {
|
|
575
996
|
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
|
|
576
997
|
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
|
|
577
|
-
HVX_Vector_x8 r0_q =
|
|
578
|
-
HVX_Vector_x8 r1_q =
|
|
998
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
999
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
579
1000
|
|
|
580
1001
|
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
|
|
581
1002
|
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
|
|
582
1003
|
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
|
|
583
1004
|
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
|
|
584
1005
|
|
|
585
|
-
HVX_Vector
|
|
586
|
-
|
|
587
|
-
HVX_Vector
|
|
588
|
-
HVX_Vector
|
|
1006
|
+
HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size);
|
|
1007
|
+
HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2);
|
|
1008
|
+
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal));
|
|
1009
|
+
HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal));
|
|
1010
|
+
|
|
1011
|
+
HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size);
|
|
1012
|
+
HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2);
|
|
1013
|
+
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal));
|
|
1014
|
+
HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal));
|
|
1015
|
+
|
|
1016
|
+
HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
|
|
1017
|
+
HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2);
|
|
1018
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal));
|
|
1019
|
+
HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal));
|
|
1020
|
+
|
|
1021
|
+
HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
|
|
1022
|
+
HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2);
|
|
1023
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal));
|
|
1024
|
+
HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal));
|
|
589
1025
|
|
|
590
1026
|
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
|
|
1027
|
+
HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s)));
|
|
1028
|
+
|
|
591
1029
|
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
|
|
1030
|
+
HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s)));
|
|
1031
|
+
|
|
592
1032
|
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
|
|
1033
|
+
HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s)));
|
|
1034
|
+
|
|
593
1035
|
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
|
|
1036
|
+
HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s)));
|
|
594
1037
|
|
|
595
|
-
// Zero out unused
|
|
1038
|
+
// Zero out unused elements
|
|
596
1039
|
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
597
1040
|
r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
|
|
1041
|
+
r0_c0_ms = Q6_V_vand_QV(bmask, r0_c0_ms);
|
|
598
1042
|
r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
|
|
1043
|
+
r0_c1_ms = Q6_V_vand_QV(bmask, r0_c1_ms);
|
|
599
1044
|
r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
|
|
1045
|
+
r1_c0_ms = Q6_V_vand_QV(bmask, r1_c0_ms);
|
|
600
1046
|
r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
|
|
1047
|
+
r1_c1_ms = Q6_V_vand_QV(bmask, r1_c1_ms);
|
|
1048
|
+
|
|
601
1049
|
r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
|
|
602
1050
|
r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
|
|
603
1051
|
r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
|
|
@@ -608,10 +1056,15 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
|
|
|
608
1056
|
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
|
609
1057
|
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
|
610
1058
|
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
1059
|
+
HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms);
|
|
1060
|
+
HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms);
|
|
1061
|
+
HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms);
|
|
1062
|
+
HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms);
|
|
1063
|
+
|
|
1064
|
+
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum));
|
|
1065
|
+
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum));
|
|
1066
|
+
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum));
|
|
1067
|
+
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum));
|
|
615
1068
|
}
|
|
616
1069
|
|
|
617
1070
|
// Reduce and store results
|
|
@@ -622,26 +1075,26 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
|
|
|
622
1075
|
hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1
|
|
623
1076
|
}
|
|
624
1077
|
|
|
625
|
-
static void
|
|
1078
|
+
static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
|
|
626
1079
|
assert(n % 32 == 0); // min sub-block size
|
|
627
1080
|
assert((unsigned long) vx0 % 128 == 0);
|
|
628
1081
|
assert((unsigned long) vy0 % 128 == 0);
|
|
629
1082
|
|
|
630
1083
|
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
631
1084
|
|
|
632
|
-
const uint32_t x_dblk_size = 8 * 4 * 2;
|
|
633
|
-
const uint32_t x_qblk_size = qk;
|
|
634
|
-
const uint32_t x_qrow_size = n;
|
|
1085
|
+
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
1086
|
+
const uint32_t x_qblk_size = qk / 2; // int4
|
|
1087
|
+
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
|
635
1088
|
|
|
636
|
-
const uint32_t y_dblk_size = 8 * 4 * 2;
|
|
637
|
-
const uint32_t y_qblk_size = qk;
|
|
638
|
-
const uint32_t y_qrow_size = n;
|
|
1089
|
+
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
1090
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
1091
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
639
1092
|
|
|
640
|
-
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0);
|
|
641
|
-
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size);
|
|
1093
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
|
|
1094
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
|
|
642
1095
|
|
|
643
|
-
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);
|
|
644
|
-
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);
|
|
1096
|
+
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
|
1097
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
|
645
1098
|
|
|
646
1099
|
// Row sum (sf)
|
|
647
1100
|
HVX_Vector r0_sum = Q6_V_vzero();
|
|
@@ -651,12 +1104,12 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo
|
|
|
651
1104
|
// Apply scale to acc and accumulate into the row sum (qf32).
|
|
652
1105
|
|
|
653
1106
|
const uint32_t nb = n / qk; // num full blocks
|
|
654
|
-
|
|
1107
|
+
const uint32_t nloe = n % qk; // num leftover elemements
|
|
655
1108
|
|
|
656
1109
|
uint32_t i = 0;
|
|
657
1110
|
for (; i < nb; i++) {
|
|
658
1111
|
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
|
659
|
-
HVX_Vector_x8 r0_q =
|
|
1112
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
|
|
660
1113
|
|
|
661
1114
|
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
662
1115
|
|
|
@@ -673,7 +1126,7 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo
|
|
|
673
1126
|
// Process leftovers
|
|
674
1127
|
if (nloe) {
|
|
675
1128
|
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
|
676
|
-
HVX_Vector_x8 r0_q =
|
|
1129
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
677
1130
|
|
|
678
1131
|
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
|
679
1132
|
|
|
@@ -697,7 +1150,7 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo
|
|
|
697
1150
|
hvx_vec_store_u(s0, 4, r0_sum);
|
|
698
1151
|
}
|
|
699
1152
|
|
|
700
|
-
static void
|
|
1153
|
+
static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
701
1154
|
const void * restrict vx0, const void * restrict vx1,
|
|
702
1155
|
const void * restrict vy0) {
|
|
703
1156
|
assert(n % 32 == 0); // min sub-block size
|
|
@@ -708,8 +1161,8 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
|
708
1161
|
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
709
1162
|
|
|
710
1163
|
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
711
|
-
const uint32_t x_qblk_size = qk;
|
|
712
|
-
const uint32_t x_qrow_size = n;
|
|
1164
|
+
const uint32_t x_qblk_size = qk / 2; // int4
|
|
1165
|
+
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
|
713
1166
|
|
|
714
1167
|
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
715
1168
|
const uint32_t y_qblk_size = qk; // int8
|
|
@@ -723,7 +1176,7 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
|
723
1176
|
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
|
724
1177
|
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
|
725
1178
|
|
|
726
|
-
// Row sum (
|
|
1179
|
+
// Row sum (sf)
|
|
727
1180
|
HVX_Vector r0_sum = Q6_V_vzero();
|
|
728
1181
|
HVX_Vector r1_sum = Q6_V_vzero();
|
|
729
1182
|
|
|
@@ -732,13 +1185,13 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
|
732
1185
|
// Apply scale to acc and accumulate into the row sum (qf32).
|
|
733
1186
|
|
|
734
1187
|
const uint32_t nb = n / qk; // num full blocks
|
|
735
|
-
|
|
1188
|
+
const uint32_t nloe = n % qk; // num leftover elemements
|
|
736
1189
|
|
|
737
1190
|
uint32_t i = 0;
|
|
738
1191
|
for (; i < nb; i++) {
|
|
739
1192
|
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
|
740
|
-
HVX_Vector_x8 r0_q =
|
|
741
|
-
HVX_Vector_x8 r1_q =
|
|
1193
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
|
|
1194
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);
|
|
742
1195
|
|
|
743
1196
|
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
744
1197
|
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
|
|
@@ -760,13 +1213,13 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
|
760
1213
|
// Process leftovers
|
|
761
1214
|
if (nloe) {
|
|
762
1215
|
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
|
763
|
-
HVX_Vector_x8 r0_q =
|
|
764
|
-
HVX_Vector_x8 r1_q =
|
|
1216
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
1217
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
765
1218
|
|
|
766
1219
|
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
|
767
1220
|
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
|
|
768
1221
|
|
|
769
|
-
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d
|
|
1222
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
770
1223
|
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
771
1224
|
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
772
1225
|
|
|
@@ -791,7 +1244,134 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
|
791
1244
|
hvx_vec_store_u(s0, 8, rsum);
|
|
792
1245
|
}
|
|
793
1246
|
|
|
794
|
-
static void
|
|
1247
|
+
static void vec_dot_q4x4x2_q8x4x2_4x1(const int n, float * restrict s0,
|
|
1248
|
+
const void * restrict vx0, const void * restrict vx1,
|
|
1249
|
+
const void * restrict vx2, const void * restrict vx3,
|
|
1250
|
+
const void * restrict vy0) {
|
|
1251
|
+
assert(n % 32 == 0); // min sub-block size
|
|
1252
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
1253
|
+
assert((unsigned long) vx1 % 128 == 0);
|
|
1254
|
+
assert((unsigned long) vx2 % 128 == 0);
|
|
1255
|
+
assert((unsigned long) vx3 % 128 == 0);
|
|
1256
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
1257
|
+
|
|
1258
|
+
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
1259
|
+
|
|
1260
|
+
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
1261
|
+
const uint32_t x_qblk_size = qk / 2; // int4
|
|
1262
|
+
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
|
1263
|
+
|
|
1264
|
+
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
1265
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
1266
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
1267
|
+
|
|
1268
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;
|
|
1269
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;
|
|
1270
|
+
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;
|
|
1271
|
+
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;
|
|
1272
|
+
const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0;
|
|
1273
|
+
const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size;
|
|
1274
|
+
const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0;
|
|
1275
|
+
const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size;
|
|
1276
|
+
|
|
1277
|
+
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);
|
|
1278
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);
|
|
1279
|
+
|
|
1280
|
+
// Row sum (sf)
|
|
1281
|
+
HVX_Vector r0_sum = Q6_V_vzero();
|
|
1282
|
+
HVX_Vector r1_sum = Q6_V_vzero();
|
|
1283
|
+
HVX_Vector r2_sum = Q6_V_vzero();
|
|
1284
|
+
HVX_Vector r3_sum = Q6_V_vzero();
|
|
1285
|
+
|
|
1286
|
+
const uint32_t nb = n / qk; // num full blocks
|
|
1287
|
+
const uint32_t nloe = n % qk; // num leftover elements
|
|
1288
|
+
|
|
1289
|
+
uint32_t i = 0;
|
|
1290
|
+
for (; i < nb; i++) {
|
|
1291
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
|
1292
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
|
|
1293
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);
|
|
1294
|
+
HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_full(r2_x_q + i * x_qblk_size);
|
|
1295
|
+
HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_full(r3_x_q + i * x_qblk_size);
|
|
1296
|
+
|
|
1297
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
1298
|
+
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
|
|
1299
|
+
HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q));
|
|
1300
|
+
HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q));
|
|
1301
|
+
|
|
1302
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
1303
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
1304
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
1305
|
+
HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size));
|
|
1306
|
+
HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size));
|
|
1307
|
+
|
|
1308
|
+
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
1309
|
+
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
|
1310
|
+
HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d)));
|
|
1311
|
+
HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d)));
|
|
1312
|
+
|
|
1313
|
+
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
1314
|
+
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
1315
|
+
HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
|
|
1316
|
+
HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
|
|
1317
|
+
|
|
1318
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
1319
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
1320
|
+
r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum));
|
|
1321
|
+
r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum));
|
|
1322
|
+
}
|
|
1323
|
+
|
|
1324
|
+
if (nloe) {
|
|
1325
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
|
1326
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
1327
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
1328
|
+
HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_partial(r2_x_q + i * x_qblk_size, nloe);
|
|
1329
|
+
HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_partial(r3_x_q + i * x_qblk_size, nloe);
|
|
1330
|
+
|
|
1331
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
|
1332
|
+
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
|
|
1333
|
+
HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe));
|
|
1334
|
+
HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe));
|
|
1335
|
+
|
|
1336
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
1337
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
1338
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
1339
|
+
HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size));
|
|
1340
|
+
HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size));
|
|
1341
|
+
|
|
1342
|
+
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
1343
|
+
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
|
1344
|
+
HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d)));
|
|
1345
|
+
HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d)));
|
|
1346
|
+
|
|
1347
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
1348
|
+
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
|
1349
|
+
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
|
|
1350
|
+
r2_dd = Q6_V_vand_QV(bmask, r2_dd);
|
|
1351
|
+
r3_dd = Q6_V_vand_QV(bmask, r3_dd);
|
|
1352
|
+
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
|
1353
|
+
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
|
|
1354
|
+
r2_ia = Q6_V_vand_QV(bmask, r2_ia);
|
|
1355
|
+
r3_ia = Q6_V_vand_QV(bmask, r3_ia);
|
|
1356
|
+
|
|
1357
|
+
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
1358
|
+
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
1359
|
+
HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
|
|
1360
|
+
HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
|
|
1361
|
+
|
|
1362
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
1363
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
1364
|
+
r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum));
|
|
1365
|
+
r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum));
|
|
1366
|
+
}
|
|
1367
|
+
|
|
1368
|
+
HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } };
|
|
1369
|
+
HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in);
|
|
1370
|
+
hvx_vec_store_u(s0, 16, rsum);
|
|
1371
|
+
}
|
|
1372
|
+
|
|
1373
|
+
|
|
1374
|
+
static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
|
|
795
1375
|
const void * restrict vx0, const void * restrict vx1,
|
|
796
1376
|
const void * restrict vy0, const void * restrict vy1) {
|
|
797
1377
|
assert(n % 32 == 0);
|
|
@@ -800,11 +1380,11 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
|
|
|
800
1380
|
assert((unsigned long) vy0 % 128 == 0);
|
|
801
1381
|
assert((unsigned long) vy1 % 128 == 0);
|
|
802
1382
|
|
|
803
|
-
const uint32_t qk =
|
|
1383
|
+
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
804
1384
|
|
|
805
1385
|
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
806
|
-
const uint32_t x_qblk_size = qk;
|
|
807
|
-
const uint32_t x_qrow_size = n;
|
|
1386
|
+
const uint32_t x_qblk_size = qk / 2; // int4
|
|
1387
|
+
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
|
808
1388
|
|
|
809
1389
|
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
810
1390
|
const uint32_t y_qblk_size = qk; // int8
|
|
@@ -836,8 +1416,8 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
|
|
|
836
1416
|
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
|
|
837
1417
|
|
|
838
1418
|
// Load src0 rows (reused across both src1 columns)
|
|
839
|
-
HVX_Vector_x8 r0_q =
|
|
840
|
-
HVX_Vector_x8 r1_q =
|
|
1419
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
|
|
1420
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);
|
|
841
1421
|
|
|
842
1422
|
// Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
|
|
843
1423
|
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
|
|
@@ -873,8 +1453,8 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
|
|
|
873
1453
|
if (nloe) {
|
|
874
1454
|
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
|
|
875
1455
|
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
|
|
876
|
-
HVX_Vector_x8 r0_q =
|
|
877
|
-
HVX_Vector_x8 r1_q =
|
|
1456
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
1457
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
878
1458
|
|
|
879
1459
|
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
|
|
880
1460
|
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
|
|
@@ -891,63 +1471,1016 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
|
|
|
891
1471
|
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
|
|
892
1472
|
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
|
|
893
1473
|
|
|
894
|
-
// Zero out unused
|
|
1474
|
+
// Zero out unused scales
|
|
1475
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
1476
|
+
r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
|
|
1477
|
+
r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
|
|
1478
|
+
r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
|
|
1479
|
+
r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
|
|
1480
|
+
r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
|
|
1481
|
+
r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
|
|
1482
|
+
r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
|
|
1483
|
+
r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
|
|
1484
|
+
|
|
1485
|
+
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
|
|
1486
|
+
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
|
|
1487
|
+
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
|
1488
|
+
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
|
1489
|
+
|
|
1490
|
+
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
|
|
1491
|
+
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
|
|
1492
|
+
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
|
|
1493
|
+
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
|
|
1494
|
+
}
|
|
1495
|
+
|
|
1496
|
+
// Reduce and store results
|
|
1497
|
+
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
|
|
1498
|
+
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
|
|
1499
|
+
|
|
1500
|
+
hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0
|
|
1501
|
+
hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1
|
|
1502
|
+
}
|
|
1503
|
+
|
|
1504
|
+
static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
|
|
1505
|
+
assert(n % 32 == 0); // min sub-block size
|
|
1506
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
1507
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
1508
|
+
|
|
1509
|
+
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
1510
|
+
|
|
1511
|
+
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
1512
|
+
const uint32_t x_qblk_size = qk; // int8
|
|
1513
|
+
const uint32_t x_qrow_size = n; // int8 (not padded)
|
|
1514
|
+
|
|
1515
|
+
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
1516
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
1517
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
1518
|
+
|
|
1519
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
|
|
1520
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
|
|
1521
|
+
|
|
1522
|
+
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
|
1523
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
|
1524
|
+
|
|
1525
|
+
// Row sum (sf)
|
|
1526
|
+
HVX_Vector r0_sum = Q6_V_vzero();
|
|
1527
|
+
|
|
1528
|
+
// Multiply and accumulate into int32.
|
|
1529
|
+
// Compute combined scale (fp32).
|
|
1530
|
+
// Apply scale to acc and accumulate into the row sum (qf32).
|
|
1531
|
+
|
|
1532
|
+
const uint32_t nb = n / qk; // num full blocks
|
|
1533
|
+
int32_t nloe = n % qk; // num leftover elemements (must be signed)
|
|
1534
|
+
|
|
1535
|
+
uint32_t i = 0;
|
|
1536
|
+
for (; i < nb; i++) {
|
|
1537
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
|
1538
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
|
|
1539
|
+
|
|
1540
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
1541
|
+
|
|
1542
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
1543
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
1544
|
+
|
|
1545
|
+
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
1546
|
+
|
|
1547
|
+
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
1548
|
+
|
|
1549
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
1550
|
+
}
|
|
1551
|
+
|
|
1552
|
+
// Process leftovers
|
|
1553
|
+
if (nloe) {
|
|
1554
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
|
1555
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
1556
|
+
|
|
1557
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
|
1558
|
+
|
|
1559
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
1560
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
1561
|
+
|
|
1562
|
+
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
1563
|
+
|
|
1564
|
+
// Zero out unused elements
|
|
1565
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
1566
|
+
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
|
1567
|
+
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
|
1568
|
+
|
|
1569
|
+
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
1570
|
+
|
|
1571
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
1572
|
+
}
|
|
1573
|
+
|
|
1574
|
+
r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
|
|
1575
|
+
|
|
1576
|
+
hvx_vec_store_u(s0, 4, r0_sum);
|
|
1577
|
+
}
|
|
1578
|
+
|
|
1579
|
+
static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
1580
|
+
const void * restrict vx0, const void * restrict vx1,
|
|
1581
|
+
const void * restrict vy0) {
|
|
1582
|
+
assert(n % 32 == 0); // min sub-block size
|
|
1583
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
1584
|
+
assert((unsigned long) vx1 % 128 == 0);
|
|
1585
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
1586
|
+
|
|
1587
|
+
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
1588
|
+
|
|
1589
|
+
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
1590
|
+
const uint32_t x_qblk_size = qk; // int8
|
|
1591
|
+
const uint32_t x_qrow_size = n; // int8 (not padded)
|
|
1592
|
+
|
|
1593
|
+
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
1594
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
1595
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
1596
|
+
|
|
1597
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
|
|
1598
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
|
|
1599
|
+
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
|
|
1600
|
+
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
|
|
1601
|
+
|
|
1602
|
+
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
|
1603
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
|
1604
|
+
|
|
1605
|
+
// Row sum (qf32)
|
|
1606
|
+
HVX_Vector r0_sum = Q6_V_vzero();
|
|
1607
|
+
HVX_Vector r1_sum = Q6_V_vzero();
|
|
1608
|
+
|
|
1609
|
+
// Multiply and accumulate into int32.
|
|
1610
|
+
// Compute combined scale (fp32).
|
|
1611
|
+
// Apply scale to acc and accumulate into the row sum (qf32).
|
|
1612
|
+
|
|
1613
|
+
const uint32_t nb = n / qk; // num full blocks
|
|
1614
|
+
int32_t nloe = n % qk; // num leftover elemements (must be signed)
|
|
1615
|
+
|
|
1616
|
+
uint32_t i = 0;
|
|
1617
|
+
for (; i < nb; i++) {
|
|
1618
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
|
1619
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
|
|
1620
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);
|
|
1621
|
+
|
|
1622
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
1623
|
+
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
|
|
1624
|
+
|
|
1625
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
1626
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
1627
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
1628
|
+
|
|
1629
|
+
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
1630
|
+
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
|
1631
|
+
|
|
1632
|
+
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
1633
|
+
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
1634
|
+
|
|
1635
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
1636
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
1637
|
+
}
|
|
1638
|
+
|
|
1639
|
+
// Process leftovers
|
|
1640
|
+
if (nloe) {
|
|
1641
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
|
1642
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
1643
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
1644
|
+
|
|
1645
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
|
1646
|
+
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
|
|
1647
|
+
|
|
1648
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
1649
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
1650
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
1651
|
+
|
|
1652
|
+
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
1653
|
+
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
|
1654
|
+
|
|
1655
|
+
// Zero out unused elements
|
|
1656
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
1657
|
+
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
|
1658
|
+
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
|
|
1659
|
+
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
|
1660
|
+
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
|
|
1661
|
+
|
|
1662
|
+
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
1663
|
+
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
1664
|
+
|
|
1665
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
1666
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
1667
|
+
}
|
|
1668
|
+
|
|
1669
|
+
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
|
|
1670
|
+
hvx_vec_store_u(s0, 8, rsum);
|
|
1671
|
+
}
|
|
1672
|
+
|
|
1673
|
+
static void vec_dot_q8x4x2_q8x4x2_4x1(const int n, float * restrict s0,
|
|
1674
|
+
const void * restrict vx0, const void * restrict vx1,
|
|
1675
|
+
const void * restrict vx2, const void * restrict vx3,
|
|
1676
|
+
const void * restrict vy0) {
|
|
1677
|
+
assert(n % 32 == 0); // min sub-block size
|
|
1678
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
1679
|
+
assert((unsigned long) vx1 % 128 == 0);
|
|
1680
|
+
assert((unsigned long) vx2 % 128 == 0);
|
|
1681
|
+
assert((unsigned long) vx3 % 128 == 0);
|
|
1682
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
1683
|
+
|
|
1684
|
+
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
1685
|
+
|
|
1686
|
+
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
1687
|
+
const uint32_t x_qblk_size = qk; // int8
|
|
1688
|
+
const uint32_t x_qrow_size = n; // int8 (not padded)
|
|
1689
|
+
|
|
1690
|
+
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
1691
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
1692
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
1693
|
+
|
|
1694
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
|
|
1695
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
|
|
1696
|
+
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
|
|
1697
|
+
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
|
|
1698
|
+
const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first
|
|
1699
|
+
const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales
|
|
1700
|
+
const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first
|
|
1701
|
+
const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales
|
|
1702
|
+
|
|
1703
|
+
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
|
1704
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
|
1705
|
+
|
|
1706
|
+
// Row sum (qf32)
|
|
1707
|
+
HVX_Vector r0_sum = Q6_V_vzero();
|
|
1708
|
+
HVX_Vector r1_sum = Q6_V_vzero();
|
|
1709
|
+
HVX_Vector r2_sum = Q6_V_vzero();
|
|
1710
|
+
HVX_Vector r3_sum = Q6_V_vzero();
|
|
1711
|
+
|
|
1712
|
+
const uint32_t nb = n / qk; // num full blocks
|
|
1713
|
+
int32_t nloe = n % qk; // num leftover elemements (must be signed)
|
|
1714
|
+
|
|
1715
|
+
uint32_t i = 0;
|
|
1716
|
+
for (; i < nb; i++) {
|
|
1717
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
|
1718
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
|
|
1719
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);
|
|
1720
|
+
HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_full(r2_x_q + i * x_qblk_size);
|
|
1721
|
+
HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_full(r3_x_q + i * x_qblk_size);
|
|
1722
|
+
|
|
1723
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
1724
|
+
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
|
|
1725
|
+
HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q));
|
|
1726
|
+
HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q));
|
|
1727
|
+
|
|
1728
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
1729
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
1730
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
1731
|
+
HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size));
|
|
1732
|
+
HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size));
|
|
1733
|
+
|
|
1734
|
+
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
1735
|
+
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
|
1736
|
+
HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d)));
|
|
1737
|
+
HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d)));
|
|
1738
|
+
|
|
1739
|
+
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
1740
|
+
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
1741
|
+
HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
|
|
1742
|
+
HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
|
|
1743
|
+
|
|
1744
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
1745
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
1746
|
+
r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum));
|
|
1747
|
+
r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum));
|
|
1748
|
+
}
|
|
1749
|
+
|
|
1750
|
+
if (nloe) {
|
|
1751
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
|
1752
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
1753
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
1754
|
+
HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_partial(r2_x_q + i * x_qblk_size, nloe);
|
|
1755
|
+
HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_partial(r3_x_q + i * x_qblk_size, nloe);
|
|
1756
|
+
|
|
1757
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
|
1758
|
+
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
|
|
1759
|
+
HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe));
|
|
1760
|
+
HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe));
|
|
1761
|
+
|
|
1762
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
1763
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
1764
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
1765
|
+
HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size));
|
|
1766
|
+
HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size));
|
|
1767
|
+
|
|
1768
|
+
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
1769
|
+
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
|
1770
|
+
HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d)));
|
|
1771
|
+
HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d)));
|
|
1772
|
+
|
|
1773
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
1774
|
+
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
|
1775
|
+
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
|
|
1776
|
+
r2_dd = Q6_V_vand_QV(bmask, r2_dd);
|
|
1777
|
+
r3_dd = Q6_V_vand_QV(bmask, r3_dd);
|
|
1778
|
+
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
|
1779
|
+
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
|
|
1780
|
+
r2_ia = Q6_V_vand_QV(bmask, r2_ia);
|
|
1781
|
+
r3_ia = Q6_V_vand_QV(bmask, r3_ia);
|
|
1782
|
+
|
|
1783
|
+
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
1784
|
+
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
1785
|
+
HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
|
|
1786
|
+
HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
|
|
1787
|
+
|
|
1788
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
1789
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
1790
|
+
r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum));
|
|
1791
|
+
r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum));
|
|
1792
|
+
}
|
|
1793
|
+
|
|
1794
|
+
HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } };
|
|
1795
|
+
HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in);
|
|
1796
|
+
hvx_vec_store_u(s0, 16, rsum);
|
|
1797
|
+
}
|
|
1798
|
+
|
|
1799
|
+
|
|
1800
|
+
static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
|
|
1801
|
+
const void * restrict vx0, const void * restrict vx1,
|
|
1802
|
+
const void * restrict vy0, const void * restrict vy1) {
|
|
1803
|
+
assert(n % 32 == 0);
|
|
1804
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
1805
|
+
assert((unsigned long) vx1 % 128 == 0);
|
|
1806
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
1807
|
+
assert((unsigned long) vy1 % 128 == 0);
|
|
1808
|
+
|
|
1809
|
+
const uint32_t qk = QK_Q8_0x4x2 * 4;
|
|
1810
|
+
|
|
1811
|
+
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
1812
|
+
const uint32_t x_qblk_size = qk; // int8
|
|
1813
|
+
const uint32_t x_qrow_size = n; // int8 (not padded)
|
|
1814
|
+
|
|
1815
|
+
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
1816
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
1817
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
1818
|
+
|
|
1819
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
|
|
1820
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
|
|
1821
|
+
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
|
|
1822
|
+
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
|
|
1823
|
+
|
|
1824
|
+
const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first
|
|
1825
|
+
const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
|
|
1826
|
+
const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first
|
|
1827
|
+
const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
|
|
1828
|
+
|
|
1829
|
+
// Row sums (sf) - 4 accumulators for 2×2 tile
|
|
1830
|
+
HVX_Vector r0_c0_sum = Q6_V_vzero();
|
|
1831
|
+
HVX_Vector r0_c1_sum = Q6_V_vzero();
|
|
1832
|
+
HVX_Vector r1_c0_sum = Q6_V_vzero();
|
|
1833
|
+
HVX_Vector r1_c1_sum = Q6_V_vzero();
|
|
1834
|
+
|
|
1835
|
+
const uint32_t nb = n / qk; // num full blocks
|
|
1836
|
+
const uint32_t nloe = n % qk; // num leftover elements
|
|
1837
|
+
|
|
1838
|
+
uint32_t i = 0;
|
|
1839
|
+
for (; i < nb; i++) {
|
|
1840
|
+
// Load src1 columns (reused across both src0 rows)
|
|
1841
|
+
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
|
|
1842
|
+
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
|
|
1843
|
+
|
|
1844
|
+
// Load src0 rows (reused across both src1 columns)
|
|
1845
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
|
|
1846
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);
|
|
1847
|
+
|
|
1848
|
+
// Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
|
|
1849
|
+
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
|
|
1850
|
+
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
|
|
1851
|
+
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
|
|
1852
|
+
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
|
|
1853
|
+
|
|
1854
|
+
// Load scales
|
|
1855
|
+
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
|
|
1856
|
+
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
|
|
1857
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
1858
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
1859
|
+
|
|
1860
|
+
// Compute combined scales
|
|
1861
|
+
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
|
|
1862
|
+
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
|
|
1863
|
+
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
|
|
1864
|
+
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
|
|
1865
|
+
|
|
1866
|
+
// Apply scales and accumulate
|
|
1867
|
+
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
|
|
1868
|
+
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
|
|
1869
|
+
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
|
1870
|
+
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
|
1871
|
+
|
|
1872
|
+
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
|
|
1873
|
+
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
|
|
1874
|
+
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
|
|
1875
|
+
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
|
|
1876
|
+
}
|
|
1877
|
+
|
|
1878
|
+
// Process leftovers
|
|
1879
|
+
if (nloe) {
|
|
1880
|
+
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
|
|
1881
|
+
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
|
|
1882
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
1883
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
1884
|
+
|
|
1885
|
+
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
|
|
1886
|
+
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
|
|
1887
|
+
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
|
|
1888
|
+
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
|
|
1889
|
+
|
|
1890
|
+
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
|
|
1891
|
+
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
|
|
1892
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
1893
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
1894
|
+
|
|
1895
|
+
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
|
|
1896
|
+
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
|
|
1897
|
+
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
|
|
1898
|
+
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
|
|
1899
|
+
|
|
1900
|
+
// Zero out unused elements
|
|
1901
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
1902
|
+
r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
|
|
1903
|
+
r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
|
|
1904
|
+
r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
|
|
1905
|
+
r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
|
|
1906
|
+
r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
|
|
1907
|
+
r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
|
|
1908
|
+
r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
|
|
1909
|
+
r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
|
|
1910
|
+
|
|
1911
|
+
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
|
|
1912
|
+
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
|
|
1913
|
+
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
|
1914
|
+
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
|
1915
|
+
|
|
1916
|
+
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
|
|
1917
|
+
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
|
|
1918
|
+
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
|
|
1919
|
+
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
|
|
1920
|
+
}
|
|
1921
|
+
|
|
1922
|
+
// Reduce and store results
|
|
1923
|
+
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
|
|
1924
|
+
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
|
|
1925
|
+
|
|
1926
|
+
hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
|
|
1927
|
+
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
|
|
1928
|
+
}
|
|
1929
|
+
|
|
1930
|
+
// ======== IQ4_NL x Q8_0 vec_dot kernels ========
|
|
1931
|
+
// Same structure as Q4_0 vec_dot but uses IQ4_NL LUT-based load (4-bit index -> int8 kvalue).
|
|
1932
|
+
// Scale format is identical to Q4_0 (fp16 scales).
|
|
1933
|
+
|
|
1934
|
+
static void vec_dot_iq4nlx4x2_q8x4x2_1x1(const int n,
|
|
1935
|
+
float * restrict s0,
|
|
1936
|
+
const void * restrict vx0,
|
|
1937
|
+
const void * restrict vy0) {
|
|
1938
|
+
assert(n % 32 == 0);
|
|
1939
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
1940
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
1941
|
+
|
|
1942
|
+
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
1943
|
+
|
|
1944
|
+
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
1945
|
+
const uint32_t x_qblk_size = qk / 2; // int4
|
|
1946
|
+
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
|
1947
|
+
|
|
1948
|
+
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
1949
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
1950
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
1951
|
+
|
|
1952
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
|
|
1953
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
|
|
1954
|
+
|
|
1955
|
+
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
|
1956
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
|
1957
|
+
|
|
1958
|
+
HVX_Vector r0_sum = Q6_V_vzero();
|
|
1959
|
+
|
|
1960
|
+
const uint32_t nb = n / qk;
|
|
1961
|
+
const uint32_t nloe = n % qk;
|
|
1962
|
+
|
|
1963
|
+
uint32_t i = 0;
|
|
1964
|
+
for (; i < nb; i++) {
|
|
1965
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
|
1966
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
|
|
1967
|
+
|
|
1968
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
1969
|
+
|
|
1970
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
1971
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
1972
|
+
|
|
1973
|
+
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
1974
|
+
|
|
1975
|
+
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
1976
|
+
|
|
1977
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
1978
|
+
}
|
|
1979
|
+
|
|
1980
|
+
if (nloe) {
|
|
1981
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
|
1982
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
1983
|
+
|
|
1984
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
|
1985
|
+
|
|
1986
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
1987
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
1988
|
+
|
|
1989
|
+
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
1990
|
+
|
|
1991
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
1992
|
+
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
|
1993
|
+
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
|
1994
|
+
|
|
1995
|
+
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
1996
|
+
|
|
1997
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
1998
|
+
}
|
|
1999
|
+
|
|
2000
|
+
r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
|
|
2001
|
+
|
|
2002
|
+
hvx_vec_store_u(s0, 4, r0_sum);
|
|
2003
|
+
}
|
|
2004
|
+
|
|
2005
|
+
static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n,
|
|
2006
|
+
float * restrict s0,
|
|
2007
|
+
const void * restrict vx0,
|
|
2008
|
+
const void * restrict vx1,
|
|
2009
|
+
const void * restrict vy0) {
|
|
2010
|
+
assert(n % 32 == 0);
|
|
2011
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
2012
|
+
assert((unsigned long) vx1 % 128 == 0);
|
|
2013
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
2014
|
+
|
|
2015
|
+
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
2016
|
+
|
|
2017
|
+
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
2018
|
+
const uint32_t x_qblk_size = qk / 2; // int4
|
|
2019
|
+
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
|
2020
|
+
|
|
2021
|
+
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
2022
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
2023
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
2024
|
+
|
|
2025
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
|
|
2026
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
|
|
2027
|
+
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
|
|
2028
|
+
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
|
|
2029
|
+
|
|
2030
|
+
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
|
2031
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
|
2032
|
+
|
|
2033
|
+
HVX_Vector r0_sum = Q6_V_vzero();
|
|
2034
|
+
HVX_Vector r1_sum = Q6_V_vzero();
|
|
2035
|
+
|
|
2036
|
+
const uint32_t nb = n / qk;
|
|
2037
|
+
const uint32_t nloe = n % qk;
|
|
2038
|
+
|
|
2039
|
+
uint32_t i = 0;
|
|
2040
|
+
for (; i < nb; i++) {
|
|
2041
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
|
2042
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
|
|
2043
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size);
|
|
2044
|
+
|
|
2045
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
2046
|
+
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
|
|
2047
|
+
|
|
2048
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
2049
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
2050
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
2051
|
+
|
|
2052
|
+
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
2053
|
+
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
|
2054
|
+
|
|
2055
|
+
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
2056
|
+
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
2057
|
+
|
|
2058
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
2059
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
2060
|
+
}
|
|
2061
|
+
|
|
2062
|
+
if (nloe) {
|
|
2063
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
|
2064
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
2065
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
2066
|
+
|
|
2067
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
|
2068
|
+
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
|
|
2069
|
+
|
|
2070
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
2071
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
2072
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
2073
|
+
|
|
2074
|
+
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
2075
|
+
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
|
2076
|
+
|
|
2077
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
2078
|
+
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
|
2079
|
+
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
|
|
2080
|
+
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
|
2081
|
+
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
|
|
2082
|
+
|
|
2083
|
+
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
2084
|
+
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
2085
|
+
|
|
2086
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
2087
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
2088
|
+
}
|
|
2089
|
+
|
|
2090
|
+
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
|
|
2091
|
+
hvx_vec_store_u(s0, 8, rsum);
|
|
2092
|
+
}
|
|
2093
|
+
|
|
2094
|
+
static void vec_dot_iq4nlx4x2_q8x4x2_4x1(const int n,
|
|
2095
|
+
float * restrict s0,
|
|
2096
|
+
const void * restrict vx0,
|
|
2097
|
+
const void * restrict vx1,
|
|
2098
|
+
const void * restrict vx2,
|
|
2099
|
+
const void * restrict vx3,
|
|
2100
|
+
const void * restrict vy0) {
|
|
2101
|
+
assert(n % 32 == 0);
|
|
2102
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
2103
|
+
assert((unsigned long) vx1 % 128 == 0);
|
|
2104
|
+
assert((unsigned long) vx2 % 128 == 0);
|
|
2105
|
+
assert((unsigned long) vx3 % 128 == 0);
|
|
2106
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
2107
|
+
|
|
2108
|
+
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
2109
|
+
|
|
2110
|
+
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
2111
|
+
const uint32_t x_qblk_size = qk / 2; // int4
|
|
2112
|
+
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
|
2113
|
+
|
|
2114
|
+
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
2115
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
2116
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
2117
|
+
|
|
2118
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
|
|
2119
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
|
|
2120
|
+
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
|
|
2121
|
+
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
|
|
2122
|
+
const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first
|
|
2123
|
+
const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales
|
|
2124
|
+
const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first
|
|
2125
|
+
const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales
|
|
2126
|
+
|
|
2127
|
+
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
|
2128
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
|
2129
|
+
|
|
2130
|
+
HVX_Vector r0_sum = Q6_V_vzero();
|
|
2131
|
+
HVX_Vector r1_sum = Q6_V_vzero();
|
|
2132
|
+
HVX_Vector r2_sum = Q6_V_vzero();
|
|
2133
|
+
HVX_Vector r3_sum = Q6_V_vzero();
|
|
2134
|
+
|
|
2135
|
+
const uint32_t nb = n / qk;
|
|
2136
|
+
const uint32_t nloe = n % qk;
|
|
2137
|
+
|
|
2138
|
+
uint32_t i = 0;
|
|
2139
|
+
for (; i < nb; i++) {
|
|
2140
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
|
2141
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
|
|
2142
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size);
|
|
2143
|
+
HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_full(r2_x_q + i * x_qblk_size);
|
|
2144
|
+
HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_full(r3_x_q + i * x_qblk_size);
|
|
2145
|
+
|
|
2146
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
2147
|
+
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
|
|
2148
|
+
HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q));
|
|
2149
|
+
HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q));
|
|
2150
|
+
|
|
2151
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
2152
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
2153
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
2154
|
+
HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size));
|
|
2155
|
+
HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size));
|
|
2156
|
+
|
|
2157
|
+
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
2158
|
+
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
|
2159
|
+
HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d)));
|
|
2160
|
+
HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d)));
|
|
2161
|
+
|
|
2162
|
+
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
2163
|
+
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
2164
|
+
HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
|
|
2165
|
+
HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
|
|
2166
|
+
|
|
2167
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
2168
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
2169
|
+
r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum));
|
|
2170
|
+
r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum));
|
|
2171
|
+
}
|
|
2172
|
+
|
|
2173
|
+
if (nloe) {
|
|
2174
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
|
2175
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
2176
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
2177
|
+
HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_partial(r2_x_q + i * x_qblk_size, nloe);
|
|
2178
|
+
HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_partial(r3_x_q + i * x_qblk_size, nloe);
|
|
2179
|
+
|
|
2180
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
|
2181
|
+
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
|
|
2182
|
+
HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe));
|
|
2183
|
+
HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe));
|
|
2184
|
+
|
|
2185
|
+
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
|
2186
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
2187
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
2188
|
+
HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size));
|
|
2189
|
+
HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size));
|
|
2190
|
+
|
|
2191
|
+
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
|
2192
|
+
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
|
2193
|
+
HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d)));
|
|
2194
|
+
HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d)));
|
|
2195
|
+
|
|
2196
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
2197
|
+
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
|
2198
|
+
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
|
|
2199
|
+
r2_dd = Q6_V_vand_QV(bmask, r2_dd);
|
|
2200
|
+
r3_dd = Q6_V_vand_QV(bmask, r3_dd);
|
|
2201
|
+
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
|
2202
|
+
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
|
|
2203
|
+
r2_ia = Q6_V_vand_QV(bmask, r2_ia);
|
|
2204
|
+
r3_ia = Q6_V_vand_QV(bmask, r3_ia);
|
|
2205
|
+
|
|
2206
|
+
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
2207
|
+
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
2208
|
+
HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
|
|
2209
|
+
HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
|
|
2210
|
+
|
|
2211
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
2212
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
2213
|
+
r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum));
|
|
2214
|
+
r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum));
|
|
2215
|
+
}
|
|
2216
|
+
|
|
2217
|
+
HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } };
|
|
2218
|
+
HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in);
|
|
2219
|
+
hvx_vec_store_u(s0, 16, rsum);
|
|
2220
|
+
}
|
|
2221
|
+
|
|
2222
|
+
|
|
2223
|
+
static void vec_dot_iq4nlx4x2_q8x4x2_2x2(const int n,
|
|
2224
|
+
float * restrict s0,
|
|
2225
|
+
float * restrict s1,
|
|
2226
|
+
const void * restrict vx0,
|
|
2227
|
+
const void * restrict vx1,
|
|
2228
|
+
const void * restrict vy0,
|
|
2229
|
+
const void * restrict vy1) {
|
|
2230
|
+
assert(n % 32 == 0);
|
|
2231
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
2232
|
+
assert((unsigned long) vx1 % 128 == 0);
|
|
2233
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
2234
|
+
assert((unsigned long) vy1 % 128 == 0);
|
|
2235
|
+
|
|
2236
|
+
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
|
2237
|
+
|
|
2238
|
+
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
2239
|
+
const uint32_t x_qblk_size = qk / 2; // int4
|
|
2240
|
+
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
|
2241
|
+
|
|
2242
|
+
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
2243
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
2244
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
2245
|
+
|
|
2246
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;
|
|
2247
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;
|
|
2248
|
+
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;
|
|
2249
|
+
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;
|
|
2250
|
+
|
|
2251
|
+
const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0;
|
|
2252
|
+
const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size;
|
|
2253
|
+
const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0;
|
|
2254
|
+
const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size;
|
|
2255
|
+
|
|
2256
|
+
HVX_Vector r0_c0_sum = Q6_V_vzero();
|
|
2257
|
+
HVX_Vector r0_c1_sum = Q6_V_vzero();
|
|
2258
|
+
HVX_Vector r1_c0_sum = Q6_V_vzero();
|
|
2259
|
+
HVX_Vector r1_c1_sum = Q6_V_vzero();
|
|
2260
|
+
|
|
2261
|
+
const uint32_t nb = n / qk;
|
|
2262
|
+
const uint32_t nloe = n % qk;
|
|
2263
|
+
|
|
2264
|
+
uint32_t i = 0;
|
|
2265
|
+
for (; i < nb; i++) {
|
|
2266
|
+
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
|
|
2267
|
+
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
|
|
2268
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
|
|
2269
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size);
|
|
2270
|
+
|
|
2271
|
+
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
|
|
2272
|
+
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
|
|
2273
|
+
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
|
|
2274
|
+
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
|
|
2275
|
+
|
|
2276
|
+
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
|
|
2277
|
+
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
|
|
2278
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
2279
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
2280
|
+
|
|
2281
|
+
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
|
|
2282
|
+
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
|
|
2283
|
+
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
|
|
2284
|
+
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
|
|
2285
|
+
|
|
2286
|
+
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
|
|
2287
|
+
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
|
|
2288
|
+
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
|
2289
|
+
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
|
2290
|
+
|
|
2291
|
+
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
|
|
2292
|
+
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
|
|
2293
|
+
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
|
|
2294
|
+
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
|
|
2295
|
+
}
|
|
2296
|
+
|
|
2297
|
+
if (nloe) {
|
|
2298
|
+
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
|
|
2299
|
+
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
|
|
2300
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
2301
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
2302
|
+
|
|
2303
|
+
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
|
|
2304
|
+
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
|
|
2305
|
+
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
|
|
2306
|
+
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
|
|
2307
|
+
|
|
2308
|
+
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
|
|
2309
|
+
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
|
|
2310
|
+
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
|
2311
|
+
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
|
2312
|
+
|
|
2313
|
+
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
|
|
2314
|
+
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
|
|
2315
|
+
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
|
|
2316
|
+
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
|
|
2317
|
+
|
|
2318
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
2319
|
+
r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
|
|
2320
|
+
r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
|
|
2321
|
+
r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
|
|
2322
|
+
r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
|
|
2323
|
+
r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
|
|
2324
|
+
r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
|
|
2325
|
+
r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
|
|
2326
|
+
r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
|
|
2327
|
+
|
|
2328
|
+
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
|
|
2329
|
+
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
|
|
2330
|
+
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
|
2331
|
+
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
|
2332
|
+
|
|
2333
|
+
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
|
|
2334
|
+
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
|
|
2335
|
+
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
|
|
2336
|
+
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
|
|
2337
|
+
}
|
|
2338
|
+
|
|
2339
|
+
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
|
|
2340
|
+
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
|
|
2341
|
+
|
|
2342
|
+
hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum);
|
|
2343
|
+
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum);
|
|
2344
|
+
}
|
|
2345
|
+
|
|
2346
|
+
static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
|
|
2347
|
+
assert(n % 32 == 0); // min sub-block size
|
|
2348
|
+
assert((unsigned long) vx0 % 128 == 0);
|
|
2349
|
+
assert((unsigned long) vy0 % 128 == 0);
|
|
2350
|
+
|
|
2351
|
+
const uint32_t qk = QK_MXFP4x4x2 * 4;
|
|
2352
|
+
|
|
2353
|
+
const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
|
|
2354
|
+
const uint32_t x_qblk_size = qk / 2; // fp4
|
|
2355
|
+
const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
|
|
2356
|
+
|
|
2357
|
+
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
2358
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
2359
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
2360
|
+
|
|
2361
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
|
|
2362
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
|
|
2363
|
+
|
|
2364
|
+
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
|
2365
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
|
2366
|
+
|
|
2367
|
+
// Row sum (sf)
|
|
2368
|
+
HVX_Vector r0_sum = Q6_V_vzero();
|
|
2369
|
+
|
|
2370
|
+
// Multiply and accumulate into int32.
|
|
2371
|
+
// Compute combined scale (fp32).
|
|
2372
|
+
// Apply scale to acc and accumulate into the row sum (qf32).
|
|
2373
|
+
|
|
2374
|
+
const uint32_t nb = n / qk; // num full blocks
|
|
2375
|
+
int32_t nloe = n % qk; // num leftover elemements (must be signed)
|
|
2376
|
+
|
|
2377
|
+
uint32_t i = 0;
|
|
2378
|
+
for (; i < nb; i++) {
|
|
2379
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size);
|
|
2380
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
|
|
2381
|
+
|
|
2382
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
2383
|
+
|
|
2384
|
+
HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
|
|
2385
|
+
HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
|
|
2386
|
+
|
|
2387
|
+
// Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
|
|
2388
|
+
HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
|
|
2389
|
+
vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
|
|
2390
|
+
vy_d = Q6_Vsf_equals_Vqf32(vy_d);
|
|
2391
|
+
|
|
2392
|
+
// Convert rX_d scales from e8m0 to fp32
|
|
2393
|
+
// Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
|
|
2394
|
+
// Left shift with zero fill to create FP32
|
|
2395
|
+
// FIXME: might need to handle zero as a special case (see ggml-cpu code)
|
|
2396
|
+
HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
|
|
2397
|
+
HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
|
|
2398
|
+
r0_d = Q6_V_vdelta_VV(r0_d, expand);
|
|
2399
|
+
r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
|
|
2400
|
+
r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
|
|
2401
|
+
|
|
2402
|
+
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
|
|
2403
|
+
|
|
2404
|
+
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
2405
|
+
|
|
2406
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
2407
|
+
}
|
|
2408
|
+
|
|
2409
|
+
// Process leftovers
|
|
2410
|
+
if (nloe) {
|
|
2411
|
+
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe);
|
|
2412
|
+
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
2413
|
+
|
|
2414
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
|
2415
|
+
|
|
2416
|
+
HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
|
|
2417
|
+
HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
|
|
2418
|
+
|
|
2419
|
+
// Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
|
|
2420
|
+
HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
|
|
2421
|
+
vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
|
|
2422
|
+
vy_d = Q6_Vsf_equals_Vqf32(vy_d);
|
|
2423
|
+
|
|
2424
|
+
// Convert rX_d scales from e8m0 to fp32
|
|
2425
|
+
// Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
|
|
2426
|
+
// Left shift with zero fill to create FP32
|
|
2427
|
+
// FIXME: might need to handle zero as a special case (see ggml-cpu code)
|
|
2428
|
+
HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
|
|
2429
|
+
HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
|
|
2430
|
+
r0_d = Q6_V_vdelta_VV(r0_d, expand);
|
|
2431
|
+
r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
|
|
2432
|
+
r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
|
|
2433
|
+
|
|
2434
|
+
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
|
|
2435
|
+
|
|
2436
|
+
// Zero-out unused scales
|
|
895
2437
|
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
|
|
899
|
-
r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
|
|
900
|
-
r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
|
|
901
|
-
r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
|
|
902
|
-
r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
|
|
903
|
-
r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
|
|
2438
|
+
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
|
2439
|
+
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
|
904
2440
|
|
|
905
|
-
HVX_Vector
|
|
906
|
-
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
|
|
907
|
-
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
|
908
|
-
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
|
2441
|
+
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
909
2442
|
|
|
910
|
-
|
|
911
|
-
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
|
|
912
|
-
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
|
|
913
|
-
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
|
|
2443
|
+
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
914
2444
|
}
|
|
915
2445
|
|
|
916
|
-
|
|
917
|
-
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
|
|
918
|
-
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
|
|
2446
|
+
r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
|
|
919
2447
|
|
|
920
|
-
hvx_vec_store_u(
|
|
921
|
-
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
|
|
2448
|
+
hvx_vec_store_u(s0, 4, r0_sum);
|
|
922
2449
|
}
|
|
923
2450
|
|
|
924
|
-
static void
|
|
2451
|
+
static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
2452
|
+
const void * restrict vx0, const void * restrict vx1,
|
|
2453
|
+
const void * restrict vy0) {
|
|
925
2454
|
assert(n % 32 == 0); // min sub-block size
|
|
926
2455
|
assert((unsigned long) vx0 % 128 == 0);
|
|
2456
|
+
assert((unsigned long) vx1 % 128 == 0);
|
|
927
2457
|
assert((unsigned long) vy0 % 128 == 0);
|
|
928
2458
|
|
|
929
2459
|
const uint32_t qk = QK_MXFP4x4x2 * 4;
|
|
930
2460
|
|
|
931
|
-
const uint32_t x_dblk_size = 8 * 4 * 1;
|
|
932
|
-
const uint32_t x_qblk_size = qk / 2;
|
|
933
|
-
const uint32_t x_qrow_size = n / 2;
|
|
2461
|
+
const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
|
|
2462
|
+
const uint32_t x_qblk_size = qk / 2; // fp4
|
|
2463
|
+
const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
|
|
934
2464
|
|
|
935
|
-
const uint32_t y_dblk_size = 8 * 4 * 2;
|
|
936
|
-
const uint32_t y_qblk_size = qk;
|
|
937
|
-
const uint32_t y_qrow_size = n;
|
|
2465
|
+
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
|
2466
|
+
const uint32_t y_qblk_size = qk; // int8
|
|
2467
|
+
const uint32_t y_qrow_size = n; // int8 (not padded)
|
|
938
2468
|
|
|
939
|
-
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0
|
|
940
|
-
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size
|
|
2469
|
+
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
|
|
2470
|
+
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
|
|
2471
|
+
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
|
|
2472
|
+
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
|
|
941
2473
|
|
|
942
|
-
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0
|
|
943
|
-
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size
|
|
2474
|
+
const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first
|
|
2475
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
|
|
944
2476
|
|
|
945
2477
|
// Row sum (sf)
|
|
946
2478
|
HVX_Vector r0_sum = Q6_V_vzero();
|
|
2479
|
+
HVX_Vector r1_sum = Q6_V_vzero();
|
|
947
2480
|
|
|
948
2481
|
// Multiply and accumulate into int32.
|
|
949
2482
|
// Compute combined scale (fp32).
|
|
950
|
-
// Apply scale to acc and accumulate into the row sum (
|
|
2483
|
+
// Apply scale to acc and accumulate into the row sum (f32).
|
|
951
2484
|
|
|
952
2485
|
const uint32_t nb = n / qk; // num full blocks
|
|
953
2486
|
int32_t nloe = n % qk; // num leftover elemements (must be signed)
|
|
@@ -956,11 +2489,14 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const
|
|
|
956
2489
|
for (; i < nb; i++) {
|
|
957
2490
|
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size);
|
|
958
2491
|
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
|
|
2492
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size);
|
|
959
2493
|
|
|
960
2494
|
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
2495
|
+
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
|
|
961
2496
|
|
|
962
2497
|
HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
|
|
963
2498
|
HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
|
|
2499
|
+
HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
|
|
964
2500
|
|
|
965
2501
|
// Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
|
|
966
2502
|
HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
|
|
@@ -976,23 +2512,32 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const
|
|
|
976
2512
|
r0_d = Q6_V_vdelta_VV(r0_d, expand);
|
|
977
2513
|
r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
|
|
978
2514
|
r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
|
|
2515
|
+
r1_d = Q6_V_vdelta_VV(r1_d, expand);
|
|
2516
|
+
r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
|
|
2517
|
+
r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
|
|
979
2518
|
|
|
980
2519
|
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
|
|
2520
|
+
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
|
|
981
2521
|
|
|
982
2522
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
2523
|
+
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
983
2524
|
|
|
984
2525
|
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
2526
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
985
2527
|
}
|
|
986
2528
|
|
|
987
2529
|
// Process leftovers
|
|
988
2530
|
if (nloe) {
|
|
989
2531
|
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe);
|
|
990
2532
|
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
2533
|
+
HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
991
2534
|
|
|
992
|
-
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(
|
|
2535
|
+
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
2536
|
+
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
|
|
993
2537
|
|
|
994
2538
|
HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
|
|
995
2539
|
HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
|
|
2540
|
+
HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
|
|
996
2541
|
|
|
997
2542
|
// Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
|
|
998
2543
|
HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
|
|
@@ -1008,30 +2553,40 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const
|
|
|
1008
2553
|
r0_d = Q6_V_vdelta_VV(r0_d, expand);
|
|
1009
2554
|
r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
|
|
1010
2555
|
r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
|
|
2556
|
+
r1_d = Q6_V_vdelta_VV(r1_d, expand);
|
|
2557
|
+
r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
|
|
2558
|
+
r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
|
|
1011
2559
|
|
|
1012
2560
|
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
|
|
2561
|
+
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
|
|
1013
2562
|
|
|
1014
|
-
// Zero-out unused
|
|
2563
|
+
// Zero-out unused values
|
|
1015
2564
|
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
1016
2565
|
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
|
2566
|
+
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
|
|
1017
2567
|
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
|
2568
|
+
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
|
|
1018
2569
|
|
|
1019
2570
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
2571
|
+
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
1020
2572
|
|
|
1021
2573
|
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
2574
|
+
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
1022
2575
|
}
|
|
1023
2576
|
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
hvx_vec_store_u(s0, 4, r0_sum);
|
|
2577
|
+
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
|
|
2578
|
+
hvx_vec_store_u(s0, 8, rsum);
|
|
1027
2579
|
}
|
|
1028
2580
|
|
|
1029
|
-
static void
|
|
2581
|
+
static void vec_dot_mxfp4x4x2_q8x4x2_4x1(const int n, float * restrict s0,
|
|
1030
2582
|
const void * restrict vx0, const void * restrict vx1,
|
|
2583
|
+
const void * restrict vx2, const void * restrict vx3,
|
|
1031
2584
|
const void * restrict vy0) {
|
|
1032
2585
|
assert(n % 32 == 0); // min sub-block size
|
|
1033
2586
|
assert((unsigned long) vx0 % 128 == 0);
|
|
1034
2587
|
assert((unsigned long) vx1 % 128 == 0);
|
|
2588
|
+
assert((unsigned long) vx2 % 128 == 0);
|
|
2589
|
+
assert((unsigned long) vx3 % 128 == 0);
|
|
1035
2590
|
assert((unsigned long) vy0 % 128 == 0);
|
|
1036
2591
|
|
|
1037
2592
|
const uint32_t qk = QK_MXFP4x4x2 * 4;
|
|
@@ -1048,17 +2603,19 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
|
1048
2603
|
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
|
|
1049
2604
|
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
|
|
1050
2605
|
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
|
|
2606
|
+
const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first
|
|
2607
|
+
const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales
|
|
2608
|
+
const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first
|
|
2609
|
+
const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales
|
|
1051
2610
|
|
|
1052
2611
|
const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first
|
|
1053
|
-
const uint8_t * restrict y_d = ((const uint8_t *) vy0
|
|
2612
|
+
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
|
1054
2613
|
|
|
1055
2614
|
// Row sum (sf)
|
|
1056
2615
|
HVX_Vector r0_sum = Q6_V_vzero();
|
|
1057
2616
|
HVX_Vector r1_sum = Q6_V_vzero();
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
// Compute combined scale (fp32).
|
|
1061
|
-
// Apply scale to acc and accumulate into the row sum (f32).
|
|
2617
|
+
HVX_Vector r2_sum = Q6_V_vzero();
|
|
2618
|
+
HVX_Vector r3_sum = Q6_V_vzero();
|
|
1062
2619
|
|
|
1063
2620
|
const uint32_t nb = n / qk; // num full blocks
|
|
1064
2621
|
int32_t nloe = n % qk; // num leftover elemements (must be signed)
|
|
@@ -1068,13 +2625,19 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
|
1068
2625
|
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size);
|
|
1069
2626
|
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
|
|
1070
2627
|
HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size);
|
|
2628
|
+
HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_full(r2_x_q + i * x_qblk_size);
|
|
2629
|
+
HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_full(r3_x_q + i * x_qblk_size);
|
|
1071
2630
|
|
|
1072
2631
|
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
1073
2632
|
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
|
|
2633
|
+
HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q));
|
|
2634
|
+
HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q));
|
|
1074
2635
|
|
|
1075
|
-
HVX_Vector vy_d = *(const HVX_UVector *) (y_d
|
|
2636
|
+
HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
|
|
1076
2637
|
HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
|
|
1077
2638
|
HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
|
|
2639
|
+
HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size);
|
|
2640
|
+
HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size);
|
|
1078
2641
|
|
|
1079
2642
|
// Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
|
|
1080
2643
|
HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
|
|
@@ -1082,9 +2645,6 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
|
1082
2645
|
vy_d = Q6_Vsf_equals_Vqf32(vy_d);
|
|
1083
2646
|
|
|
1084
2647
|
// Convert rX_d scales from e8m0 to fp32
|
|
1085
|
-
// Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
|
|
1086
|
-
// Left shift with zero fill to create FP32
|
|
1087
|
-
// FIXME: might need to handle zero as a special case (see ggml-cpu code)
|
|
1088
2648
|
HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
|
|
1089
2649
|
HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
|
|
1090
2650
|
r0_d = Q6_V_vdelta_VV(r0_d, expand);
|
|
@@ -1093,29 +2653,46 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
|
1093
2653
|
r1_d = Q6_V_vdelta_VV(r1_d, expand);
|
|
1094
2654
|
r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
|
|
1095
2655
|
r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
|
|
2656
|
+
r2_d = Q6_V_vdelta_VV(r2_d, expand);
|
|
2657
|
+
r2_d = Q6_V_vand_VV(r2_d, e8m0_mask);
|
|
2658
|
+
r2_d = Q6_Vw_vasl_VwR(r2_d, 23);
|
|
2659
|
+
r3_d = Q6_V_vdelta_VV(r3_d, expand);
|
|
2660
|
+
r3_d = Q6_V_vand_VV(r3_d, e8m0_mask);
|
|
2661
|
+
r3_d = Q6_Vw_vasl_VwR(r3_d, 23);
|
|
1096
2662
|
|
|
1097
2663
|
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
|
|
1098
2664
|
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
|
|
2665
|
+
HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d));
|
|
2666
|
+
HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d));
|
|
1099
2667
|
|
|
1100
2668
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
1101
2669
|
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
2670
|
+
HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
|
|
2671
|
+
HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
|
|
1102
2672
|
|
|
1103
2673
|
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
1104
2674
|
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
2675
|
+
r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum));
|
|
2676
|
+
r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum));
|
|
1105
2677
|
}
|
|
1106
2678
|
|
|
1107
|
-
// Process leftovers
|
|
1108
2679
|
if (nloe) {
|
|
1109
2680
|
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe);
|
|
1110
2681
|
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
|
1111
2682
|
HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
|
2683
|
+
HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_partial(r2_x_q + i * x_qblk_size, nloe);
|
|
2684
|
+
HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_partial(r3_x_q + i * x_qblk_size, nloe);
|
|
1112
2685
|
|
|
1113
2686
|
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
|
1114
2687
|
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
|
|
2688
|
+
HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q));
|
|
2689
|
+
HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q));
|
|
1115
2690
|
|
|
1116
2691
|
HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
|
|
1117
2692
|
HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
|
|
1118
2693
|
HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
|
|
2694
|
+
HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size);
|
|
2695
|
+
HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size);
|
|
1119
2696
|
|
|
1120
2697
|
// Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
|
|
1121
2698
|
HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
|
|
@@ -1123,9 +2700,6 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
|
1123
2700
|
vy_d = Q6_Vsf_equals_Vqf32(vy_d);
|
|
1124
2701
|
|
|
1125
2702
|
// Convert rX_d scales from e8m0 to fp32
|
|
1126
|
-
// Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
|
|
1127
|
-
// Left shift with zero fill to create FP32
|
|
1128
|
-
// FIXME: might need to handle zero as a special case (see ggml-cpu code)
|
|
1129
2703
|
HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
|
|
1130
2704
|
HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
|
|
1131
2705
|
r0_d = Q6_V_vdelta_VV(r0_d, expand);
|
|
@@ -1134,28 +2708,46 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
|
|
|
1134
2708
|
r1_d = Q6_V_vdelta_VV(r1_d, expand);
|
|
1135
2709
|
r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
|
|
1136
2710
|
r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
|
|
2711
|
+
r2_d = Q6_V_vdelta_VV(r2_d, expand);
|
|
2712
|
+
r2_d = Q6_V_vand_VV(r2_d, e8m0_mask);
|
|
2713
|
+
r2_d = Q6_Vw_vasl_VwR(r2_d, 23);
|
|
2714
|
+
r3_d = Q6_V_vdelta_VV(r3_d, expand);
|
|
2715
|
+
r3_d = Q6_V_vand_VV(r3_d, e8m0_mask);
|
|
2716
|
+
r3_d = Q6_Vw_vasl_VwR(r3_d, 23);
|
|
1137
2717
|
|
|
1138
2718
|
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
|
|
1139
2719
|
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
|
|
2720
|
+
HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d));
|
|
2721
|
+
HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d));
|
|
1140
2722
|
|
|
1141
2723
|
// Zero-out unused values
|
|
1142
2724
|
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
|
1143
2725
|
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
|
1144
2726
|
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
|
|
2727
|
+
r2_dd = Q6_V_vand_QV(bmask, r2_dd);
|
|
2728
|
+
r3_dd = Q6_V_vand_QV(bmask, r3_dd);
|
|
1145
2729
|
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
|
1146
2730
|
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
|
|
2731
|
+
r2_ia = Q6_V_vand_QV(bmask, r2_ia);
|
|
2732
|
+
r3_ia = Q6_V_vand_QV(bmask, r3_ia);
|
|
1147
2733
|
|
|
1148
2734
|
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
|
1149
2735
|
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
|
2736
|
+
HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd);
|
|
2737
|
+
HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd);
|
|
1150
2738
|
|
|
1151
2739
|
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
|
1152
2740
|
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
|
2741
|
+
r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum));
|
|
2742
|
+
r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum));
|
|
1153
2743
|
}
|
|
1154
2744
|
|
|
1155
|
-
|
|
1156
|
-
|
|
2745
|
+
HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } };
|
|
2746
|
+
HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in);
|
|
2747
|
+
hvx_vec_store_u(s0, 16, rsum);
|
|
1157
2748
|
}
|
|
1158
2749
|
|
|
2750
|
+
|
|
1159
2751
|
static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
|
|
1160
2752
|
const void * restrict vx0, const void * restrict vx1,
|
|
1161
2753
|
const void * restrict vy0, const void * restrict vy1) {
|
|
@@ -1326,6 +2918,176 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float
|
|
|
1326
2918
|
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
|
|
1327
2919
|
}
|
|
1328
2920
|
|
|
2921
|
+
// Lane-wise fp32 (sf) add / multiply used by the f32 x f32 dot-product kernels.
// From HVX arch 79 on, the direct sf intrinsics are used; older targets compute
// in the qf32 intermediate format and convert the result back to sf each time.
// (An undefined __HVX_ARCH__ evaluates to 0 and therefore takes the qf32 path.)
#if __HVX_ARCH__ >= 79
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#else
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#endif
|
|
2928
|
+
|
|
2929
|
+
static void vec_dot_f32_f32_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
|
2930
|
+
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
|
|
2931
|
+
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
|
|
2932
|
+
|
|
2933
|
+
uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors
|
|
2934
|
+
uint32_t nloe = n % VLEN_FP32; // leftover elements
|
|
2935
|
+
|
|
2936
|
+
HVX_Vector rsum = Q6_V_vzero();
|
|
2937
|
+
|
|
2938
|
+
uint32_t i = 0;
|
|
2939
|
+
|
|
2940
|
+
#pragma unroll(4)
|
|
2941
|
+
for (i = 0; i < nvec; i++) {
|
|
2942
|
+
HVX_Vector prod = HVX_OP_MUL_F32(x[i], y[i]);
|
|
2943
|
+
rsum = HVX_OP_ADD_F32(rsum, prod);
|
|
2944
|
+
}
|
|
2945
|
+
|
|
2946
|
+
if (nloe) {
|
|
2947
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
2948
|
+
HVX_Vector x_sf = Q6_V_vand_QV(bmask, x[i]);
|
|
2949
|
+
HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
|
|
2950
|
+
HVX_Vector prod = HVX_OP_MUL_F32(x_sf, y_sf);
|
|
2951
|
+
rsum = HVX_OP_ADD_F32(rsum, prod);
|
|
2952
|
+
}
|
|
2953
|
+
|
|
2954
|
+
*s = hvx_vec_get_f32(hvx_vec_reduce_sum_f32(rsum));
|
|
2955
|
+
}
|
|
2956
|
+
|
|
2957
|
+
static void vec_dot_f32_f32_aa_2x1(const int n, float * restrict s0,
|
|
2958
|
+
const void * restrict vx0, const void * restrict vx1,
|
|
2959
|
+
const void * restrict vy0) {
|
|
2960
|
+
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
|
|
2961
|
+
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
|
|
2962
|
+
const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
|
|
2963
|
+
|
|
2964
|
+
uint32_t nvec = n / VLEN_FP32;
|
|
2965
|
+
uint32_t nloe = n % VLEN_FP32;
|
|
2966
|
+
|
|
2967
|
+
HVX_Vector rsum0 = Q6_V_vzero();
|
|
2968
|
+
HVX_Vector rsum1 = Q6_V_vzero();
|
|
2969
|
+
|
|
2970
|
+
uint32_t i = 0;
|
|
2971
|
+
|
|
2972
|
+
#pragma unroll(2)
|
|
2973
|
+
for (i = 0; i < nvec; i++) {
|
|
2974
|
+
HVX_Vector y_sf = y[i];
|
|
2975
|
+
HVX_Vector prod0 = HVX_OP_MUL_F32(x0[i], y_sf);
|
|
2976
|
+
HVX_Vector prod1 = HVX_OP_MUL_F32(x1[i], y_sf);
|
|
2977
|
+
rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
|
|
2978
|
+
rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
|
|
2979
|
+
}
|
|
2980
|
+
|
|
2981
|
+
if (nloe) {
|
|
2982
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
2983
|
+
HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
|
|
2984
|
+
HVX_Vector x0_sf = Q6_V_vand_QV(bmask, x0[i]);
|
|
2985
|
+
HVX_Vector x1_sf = Q6_V_vand_QV(bmask, x1[i]);
|
|
2986
|
+
HVX_Vector prod0 = HVX_OP_MUL_F32(x0_sf, y_sf);
|
|
2987
|
+
HVX_Vector prod1 = HVX_OP_MUL_F32(x1_sf, y_sf);
|
|
2988
|
+
rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
|
|
2989
|
+
rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
|
|
2990
|
+
}
|
|
2991
|
+
|
|
2992
|
+
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
|
|
2993
|
+
HVX_VectorAlias va;
|
|
2994
|
+
va.v = rsum;
|
|
2995
|
+
s0[0] = va.fp32[0];
|
|
2996
|
+
s0[1] = va.fp32[1];
|
|
2997
|
+
}
|
|
2998
|
+
|
|
2999
|
+
static void vec_dot_f32_f32_aa_2x2(const int n, float * restrict s0, float * restrict s1,
|
|
3000
|
+
const void * restrict vx0, const void * restrict vx1,
|
|
3001
|
+
const void * restrict vy0, const void * restrict vy1) {
|
|
3002
|
+
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
|
|
3003
|
+
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
|
|
3004
|
+
const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
|
|
3005
|
+
const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
|
|
3006
|
+
|
|
3007
|
+
uint32_t nvec = n / VLEN_FP32;
|
|
3008
|
+
uint32_t nloe = n % VLEN_FP32;
|
|
3009
|
+
|
|
3010
|
+
HVX_Vector r0_c0_sum = Q6_V_vzero();
|
|
3011
|
+
HVX_Vector r0_c1_sum = Q6_V_vzero();
|
|
3012
|
+
HVX_Vector r1_c0_sum = Q6_V_vzero();
|
|
3013
|
+
HVX_Vector r1_c1_sum = Q6_V_vzero();
|
|
3014
|
+
|
|
3015
|
+
uint32_t i = 0;
|
|
3016
|
+
|
|
3017
|
+
#pragma unroll(2)
|
|
3018
|
+
for (i = 0; i < nvec; i++) {
|
|
3019
|
+
HVX_Vector r0_sf = x0[i];
|
|
3020
|
+
HVX_Vector r1_sf = x1[i];
|
|
3021
|
+
HVX_Vector c0_sf = y0[i];
|
|
3022
|
+
HVX_Vector c1_sf = y1[i];
|
|
3023
|
+
|
|
3024
|
+
r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
|
|
3025
|
+
r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
|
|
3026
|
+
r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
|
|
3027
|
+
r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
|
|
3028
|
+
}
|
|
3029
|
+
|
|
3030
|
+
if (nloe) {
|
|
3031
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
3032
|
+
|
|
3033
|
+
HVX_Vector r0_sf = Q6_V_vand_QV(bmask, x0[i]);
|
|
3034
|
+
HVX_Vector r1_sf = Q6_V_vand_QV(bmask, x1[i]);
|
|
3035
|
+
HVX_Vector c0_sf = Q6_V_vand_QV(bmask, y0[i]);
|
|
3036
|
+
HVX_Vector c1_sf = Q6_V_vand_QV(bmask, y1[i]);
|
|
3037
|
+
|
|
3038
|
+
r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
|
|
3039
|
+
r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
|
|
3040
|
+
r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
|
|
3041
|
+
r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
|
|
3042
|
+
}
|
|
3043
|
+
|
|
3044
|
+
// Reduce and store results
|
|
3045
|
+
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
|
|
3046
|
+
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
|
|
3047
|
+
|
|
3048
|
+
HVX_VectorAlias va0, va1;
|
|
3049
|
+
va0.v = r0_r1_c0_sum;
|
|
3050
|
+
va1.v = r0_r1_c1_sum;
|
|
3051
|
+
s0[0] = va0.fp32[0];
|
|
3052
|
+
s0[1] = va0.fp32[1];
|
|
3053
|
+
s1[0] = va1.fp32[0];
|
|
3054
|
+
s1[1] = va1.fp32[1];
|
|
3055
|
+
}
|
|
3056
|
+
|
|
3057
|
+
static void vec_dot_f32_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
|
|
3058
|
+
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
|
|
3059
|
+
const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
|
|
3060
|
+
|
|
3061
|
+
uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors
|
|
3062
|
+
uint32_t nloe = n % VLEN_FP32; // leftover elements
|
|
3063
|
+
|
|
3064
|
+
HVX_Vector rsum = Q6_V_vzero();
|
|
3065
|
+
|
|
3066
|
+
uint32_t i = 0;
|
|
3067
|
+
|
|
3068
|
+
#pragma unroll(2)
|
|
3069
|
+
for (i = 0; i < nvec; i++) {
|
|
3070
|
+
HVX_Vector x_sf = vx[i];
|
|
3071
|
+
HVX_Vector y_sf = vy[i];
|
|
3072
|
+
|
|
3073
|
+
rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
|
|
3074
|
+
}
|
|
3075
|
+
|
|
3076
|
+
if (nloe) {
|
|
3077
|
+
HVX_Vector x_sf = vx[i];
|
|
3078
|
+
HVX_Vector y_sf = vy[i];
|
|
3079
|
+
|
|
3080
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
|
3081
|
+
x_sf = Q6_V_vand_QV(bmask, x_sf);
|
|
3082
|
+
y_sf = Q6_V_vand_QV(bmask, y_sf);
|
|
3083
|
+
|
|
3084
|
+
rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
|
|
3085
|
+
}
|
|
3086
|
+
|
|
3087
|
+
rsum = hvx_vec_reduce_sum_f32(rsum);
|
|
3088
|
+
hvx_vec_store_u(&s[0], 4, rsum);
|
|
3089
|
+
}
|
|
3090
|
+
|
|
1329
3091
|
static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
|
1330
3092
|
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
|
|
1331
3093
|
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
|
|
@@ -1533,11 +3295,11 @@ static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void *
|
|
|
1533
3295
|
hvx_vec_store_u(&s[0], 4, rsum);
|
|
1534
3296
|
}
|
|
1535
3297
|
|
|
1536
|
-
#define htp_matmul_tensors_preamble
|
|
1537
|
-
struct htp_tensor * restrict src0
|
|
1538
|
-
struct htp_tensor * restrict src1
|
|
1539
|
-
struct htp_tensor * restrict src2
|
|
1540
|
-
struct htp_tensor * restrict
|
|
3298
|
+
#define htp_matmul_tensors_preamble \
|
|
3299
|
+
const struct htp_tensor * restrict src0 = octx->src[0]; \
|
|
3300
|
+
const struct htp_tensor * restrict src1 = octx->src[1]; \
|
|
3301
|
+
const struct htp_tensor * restrict src2 = octx->src[2]; \
|
|
3302
|
+
const struct htp_tensor * restrict dst = octx->dst; \
|
|
1541
3303
|
struct htp_spad * restrict src0_spad = &octx->src0_spad; \
|
|
1542
3304
|
struct htp_spad * restrict src1_spad = &octx->src1_spad; \
|
|
1543
3305
|
struct htp_spad * restrict dst_spad = &octx->dst_spad; \
|
|
@@ -1744,7 +3506,7 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
|
|
|
1744
3506
|
// Process the last row (if any)
|
|
1745
3507
|
if (src0_end_row != src0_end_row_x2) {
|
|
1746
3508
|
uint32_t ir0 = src0_end_row_x2;
|
|
1747
|
-
const int is0 = (ir0 - src0_start_row);
|
|
3509
|
+
const int is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
|
|
1748
3510
|
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
|
|
1749
3511
|
src0_stride, src0_row_size, 1);
|
|
1750
3512
|
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
|
@@ -1773,7 +3535,6 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
|
|
|
1773
3535
|
|
|
1774
3536
|
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
|
1775
3537
|
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
|
1776
|
-
const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
|
|
1777
3538
|
|
|
1778
3539
|
// no work for this thread
|
|
1779
3540
|
if (src0_start_row >= src0_end_row) {
|
|
@@ -1803,39 +3564,89 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
|
|
|
1803
3564
|
const uint8_t * restrict src1_col = (const uint8_t *) src1_data;
|
|
1804
3565
|
float * restrict dst_col = (float *) dst->data;
|
|
1805
3566
|
|
|
1806
|
-
|
|
1807
|
-
|
|
1808
|
-
|
|
1809
|
-
|
|
1810
|
-
|
|
1811
|
-
|
|
3567
|
+
if (mmctx->vec_dot_4x1 != NULL) {
|
|
3568
|
+
const uint32_t src0_end_row_x4 = src0_start_row + ((src0_end_row - src0_start_row) & ~3U);
|
|
3569
|
+
|
|
3570
|
+
// Prefill spad with 4x src0 rows
|
|
3571
|
+
#pragma unroll(4)
|
|
3572
|
+
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) {
|
|
3573
|
+
const uint32_t is0 = (ir0 - src0_start_row);
|
|
3574
|
+
if (is0 >= MM_SPAD_SRC0_NROWS) {
|
|
3575
|
+
break;
|
|
3576
|
+
}
|
|
3577
|
+
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
|
|
3578
|
+
src0_stride, src0_row_size, 4);
|
|
1812
3579
|
}
|
|
1813
|
-
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
|
|
1814
|
-
src0_stride, src0_row_size, 2);
|
|
1815
|
-
}
|
|
1816
3580
|
|
|
1817
|
-
|
|
1818
|
-
|
|
1819
|
-
|
|
1820
|
-
|
|
3581
|
+
// Process src0 rows
|
|
3582
|
+
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) {
|
|
3583
|
+
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
|
3584
|
+
mmctx->vec_dot_4x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, ss0 + 2 * src0_stride, ss0 + 3 * src0_stride, src1_col);
|
|
1821
3585
|
|
|
1822
|
-
|
|
1823
|
-
|
|
1824
|
-
|
|
1825
|
-
|
|
1826
|
-
|
|
3586
|
+
// Prefetch next (n + spad_nrows) row
|
|
3587
|
+
const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
|
|
3588
|
+
const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
|
|
3589
|
+
if (pr0 < src0_end_row_x4) {
|
|
3590
|
+
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),
|
|
3591
|
+
src0_stride, src0_row_size, 4);
|
|
3592
|
+
}
|
|
3593
|
+
}
|
|
3594
|
+
|
|
3595
|
+
// Process leftovers
|
|
3596
|
+
uint32_t ir0 = src0_end_row_x4;
|
|
3597
|
+
if (ir0 + 2 <= src0_end_row) {
|
|
3598
|
+
const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
|
|
3599
|
+
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
|
|
3600
|
+
src0_stride, src0_row_size, 2);
|
|
3601
|
+
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
|
3602
|
+
mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
|
|
3603
|
+
ir0 += 2;
|
|
3604
|
+
}
|
|
3605
|
+
if (ir0 < src0_end_row) {
|
|
3606
|
+
const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
|
|
3607
|
+
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
|
|
3608
|
+
src0_stride, src0_row_size, 1);
|
|
3609
|
+
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
|
3610
|
+
mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
|
|
3611
|
+
ir0 += 1;
|
|
3612
|
+
}
|
|
3613
|
+
} else {
|
|
3614
|
+
const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
|
|
3615
|
+
|
|
3616
|
+
// Prefill spad with 2x src0 rows
|
|
3617
|
+
#pragma unroll(2)
|
|
3618
|
+
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
|
|
3619
|
+
const uint32_t is0 = (ir0 - src0_start_row);
|
|
3620
|
+
if (is0 >= MM_SPAD_SRC0_NROWS) {
|
|
3621
|
+
break;
|
|
3622
|
+
}
|
|
3623
|
+
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
|
|
1827
3624
|
src0_stride, src0_row_size, 2);
|
|
1828
3625
|
}
|
|
1829
|
-
}
|
|
1830
3626
|
|
|
1831
|
-
|
|
1832
|
-
|
|
1833
|
-
|
|
1834
|
-
|
|
1835
|
-
|
|
1836
|
-
|
|
1837
|
-
|
|
1838
|
-
|
|
3627
|
+
// Process src0 rows
|
|
3628
|
+
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
|
|
3629
|
+
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
|
3630
|
+
mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
|
|
3631
|
+
|
|
3632
|
+
// Prefetch next (n + spad_nrows) row
|
|
3633
|
+
const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
|
|
3634
|
+
const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
|
|
3635
|
+
if (pr0 < src0_end_row_x2) {
|
|
3636
|
+
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),
|
|
3637
|
+
src0_stride, src0_row_size, 2);
|
|
3638
|
+
}
|
|
3639
|
+
}
|
|
3640
|
+
|
|
3641
|
+
// Process the last row (if any)
|
|
3642
|
+
if (src0_end_row != src0_end_row_x2) {
|
|
3643
|
+
const uint32_t ir0 = src0_end_row_x2;
|
|
3644
|
+
const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
|
|
3645
|
+
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
|
|
3646
|
+
src0_stride, src0_row_size, 1);
|
|
3647
|
+
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
|
3648
|
+
mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
|
|
3649
|
+
}
|
|
1839
3650
|
}
|
|
1840
3651
|
|
|
1841
3652
|
hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row);
|
|
@@ -1859,8 +3670,8 @@ struct mmid_row_mapping {
|
|
|
1859
3670
|
static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
|
1860
3671
|
htp_matmul_preamble;
|
|
1861
3672
|
|
|
1862
|
-
struct htp_tensor * restrict
|
|
1863
|
-
struct htp_spad * restrict
|
|
3673
|
+
const struct htp_tensor * restrict ids = octx->src[2];
|
|
3674
|
+
struct htp_spad * restrict src2_spad = &octx->src2_spad;
|
|
1864
3675
|
|
|
1865
3676
|
uint64_t t1, t2;
|
|
1866
3677
|
t1 = HAP_perf_get_qtimer_count();
|
|
@@ -1880,11 +3691,8 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
|
|
1880
3691
|
const uint32_t n_ids = ids->ne[0]; // n_expert_used
|
|
1881
3692
|
const uint32_t n_as = ne02; // n_expert
|
|
1882
3693
|
|
|
1883
|
-
const
|
|
1884
|
-
const
|
|
1885
|
-
|
|
1886
|
-
const uint32_t * matrix_row_counts = (const uint32_t *) src2_spad->data + 0;
|
|
1887
|
-
const struct mmid_row_mapping * matrix_rows = (const void *) src2_spad->data + matrix_row_counts_size;
|
|
3694
|
+
const uint32_t * matrix_row_counts = mmctx->matrix_row_counts;
|
|
3695
|
+
const struct mmid_row_mapping * matrix_rows = mmctx->matrix_rows;
|
|
1888
3696
|
|
|
1889
3697
|
const size_t dst_row_size = nb1;
|
|
1890
3698
|
const size_t src0_row_size = nb01;
|
|
@@ -1906,6 +3714,10 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
|
|
1906
3714
|
continue;
|
|
1907
3715
|
}
|
|
1908
3716
|
|
|
3717
|
+
if (mmctx->hmx_eligible) {
|
|
3718
|
+
continue;
|
|
3719
|
+
}
|
|
3720
|
+
|
|
1909
3721
|
const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0);
|
|
1910
3722
|
|
|
1911
3723
|
// Prefill spad with src0 rows
|
|
@@ -1947,7 +3759,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
|
|
1947
3759
|
// Process the last row (if any)
|
|
1948
3760
|
if (src0_end_row != src0_end_row_x2) {
|
|
1949
3761
|
uint32_t ir0 = src0_end_row_x2;
|
|
1950
|
-
const uint32_t is0 = (ir0 - src0_start_row);
|
|
3762
|
+
const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
|
|
1951
3763
|
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
|
|
1952
3764
|
src0_row_size_padded, src0_row_size, 1);
|
|
1953
3765
|
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
|
@@ -1978,8 +3790,8 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
|
|
1978
3790
|
static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
|
|
1979
3791
|
htp_matmul_preamble;
|
|
1980
3792
|
|
|
1981
|
-
struct htp_tensor * restrict
|
|
1982
|
-
struct htp_spad * restrict
|
|
3793
|
+
const struct htp_tensor * restrict ids = octx->src[2];
|
|
3794
|
+
struct htp_spad * restrict src2_spad = &octx->src2_spad;
|
|
1983
3795
|
|
|
1984
3796
|
uint64_t t1, t2;
|
|
1985
3797
|
t1 = HAP_perf_get_qtimer_count();
|
|
@@ -2049,7 +3861,7 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
|
|
|
2049
3861
|
// Process the last row (if any)
|
|
2050
3862
|
if (src0_end_row != src0_end_row_x2) {
|
|
2051
3863
|
uint32_t ir0 = src0_end_row_x2;
|
|
2052
|
-
const uint32_t is0 = (ir0 - src0_start_row);
|
|
3864
|
+
const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
|
|
2053
3865
|
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
|
|
2054
3866
|
src0_row_size_padded, src0_row_size, 1);
|
|
2055
3867
|
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
|
@@ -2067,6 +3879,94 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
|
|
|
2067
3879
|
|
|
2068
3880
|
// *** dynamic quant
|
|
2069
3881
|
|
|
3882
|
+
static inline void quantize_block_f32_q8_1x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
|
|
3883
|
+
assert((unsigned long) x % 128 == 0);
|
|
3884
|
+
assert((unsigned long) y_q % 128 == 0);
|
|
3885
|
+
|
|
3886
|
+
HVX_Vector * vx = (HVX_Vector *) x;
|
|
3887
|
+
HVX_Vector zero = Q6_V_vzero();
|
|
3888
|
+
|
|
3889
|
+
// Use reduce max fp32 to find max(abs(e)) first
|
|
3890
|
+
HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0]));
|
|
3891
|
+
HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1]));
|
|
3892
|
+
HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2]));
|
|
3893
|
+
HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3]));
|
|
3894
|
+
|
|
3895
|
+
// Load and convert into QF32
|
|
3896
|
+
HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
|
|
3897
|
+
HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
|
|
3898
|
+
HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
|
|
3899
|
+
HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
|
|
3900
|
+
|
|
3901
|
+
// Convert to QF32
|
|
3902
|
+
HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
|
|
3903
|
+
HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
|
|
3904
|
+
HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
|
|
3905
|
+
HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
|
|
3906
|
+
|
|
3907
|
+
// Combine and convert to fp16
|
|
3908
|
+
HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
|
|
3909
|
+
HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf)));
|
|
3910
|
+
|
|
3911
|
+
// Convert into fp16
|
|
3912
|
+
HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
|
|
3913
|
+
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
|
|
3914
|
+
|
|
3915
|
+
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
|
3916
|
+
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
|
3917
|
+
HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
|
|
3918
|
+
HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16);
|
|
3919
|
+
|
|
3920
|
+
// Divide input by the scale
|
|
3921
|
+
HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);
|
|
3922
|
+
HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);
|
|
3923
|
+
vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
|
|
3924
|
+
vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
|
|
3925
|
+
|
|
3926
|
+
// Convert to int8
|
|
3927
|
+
HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
|
|
3928
|
+
HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
|
|
3929
|
+
HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
|
|
3930
|
+
|
|
3931
|
+
*(HVX_Vector *) y_q = vx_i8;
|
|
3932
|
+
|
|
3933
|
+
// --- Sum calculation ---
|
|
3934
|
+
const HVX_Vector ones = Q6_Vb_vsplat_R(1);
|
|
3935
|
+
HVX_Vector v_sums = Q6_Vw_vrmpy_VbVb(vx_i8, ones); // sum every 4 consecutive elements
|
|
3936
|
+
// Sum 8 elements:
|
|
3937
|
+
v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 4));
|
|
3938
|
+
v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 8));
|
|
3939
|
+
v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 16));
|
|
3940
|
+
|
|
3941
|
+
// Copy to stack to extract sums and vmaxes
|
|
3942
|
+
float vmax0[32] __attribute__((aligned(128)));
|
|
3943
|
+
float vmax1[32] __attribute__((aligned(128)));
|
|
3944
|
+
float vmax2[32] __attribute__((aligned(128)));
|
|
3945
|
+
float vmax3[32] __attribute__((aligned(128)));
|
|
3946
|
+
int32_t sums[32] __attribute__((aligned(128)));
|
|
3947
|
+
|
|
3948
|
+
hvx_vec_store_u(vmax0, 128, vmax0_sf);
|
|
3949
|
+
hvx_vec_store_u(vmax1, 128, vmax1_sf);
|
|
3950
|
+
hvx_vec_store_u(vmax2, 128, vmax2_sf);
|
|
3951
|
+
hvx_vec_store_u(vmax3, 128, vmax3_sf);
|
|
3952
|
+
hvx_vec_store_u(sums, 128, v_sums);
|
|
3953
|
+
|
|
3954
|
+
float d0 = vmax0[0] / 127.0f;
|
|
3955
|
+
float d1 = vmax1[0] / 127.0f;
|
|
3956
|
+
float d2 = vmax2[0] / 127.0f;
|
|
3957
|
+
float d3 = vmax3[0] / 127.0f;
|
|
3958
|
+
|
|
3959
|
+
__fp16 * y_d_half = (__fp16 *) y_d;
|
|
3960
|
+
y_d_half[0] = d0;
|
|
3961
|
+
y_d_half[1] = (float) sums[0] * d0;
|
|
3962
|
+
y_d_half[2] = d1;
|
|
3963
|
+
y_d_half[3] = (float) sums[8] * d1;
|
|
3964
|
+
y_d_half[4] = d2;
|
|
3965
|
+
y_d_half[5] = (float) sums[16] * d2;
|
|
3966
|
+
y_d_half[6] = d3;
|
|
3967
|
+
y_d_half[7] = (float) sums[24] * d3;
|
|
3968
|
+
}
|
|
3969
|
+
|
|
2070
3970
|
static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
|
|
2071
3971
|
assert((unsigned long) x % 128 == 0);
|
|
2072
3972
|
assert((unsigned long) y_q % 128 == 0);
|
|
@@ -2248,7 +4148,7 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data)
|
|
|
2248
4148
|
struct htp_matmul_context * mmctx = data;
|
|
2249
4149
|
struct htp_ops_context * octx = mmctx->octx;
|
|
2250
4150
|
|
|
2251
|
-
const struct htp_tensor * src =
|
|
4151
|
+
const struct htp_tensor * src = octx->src[1];
|
|
2252
4152
|
uint8_t * restrict dst = octx->src1_spad.data;
|
|
2253
4153
|
struct htp_spad * spad = &octx->src0_spad;
|
|
2254
4154
|
uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
|
|
@@ -2291,11 +4191,123 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data)
|
|
|
2291
4191
|
ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
|
2292
4192
|
}
|
|
2293
4193
|
|
|
4194
|
+
static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
|
|
4195
|
+
assert(k % 32 == 0);
|
|
4196
|
+
const uint32_t qk = QK_Q8_0x4x2;
|
|
4197
|
+
const uint32_t nb = (k + qk - 1) / qk;
|
|
4198
|
+
|
|
4199
|
+
const uint32_t qrow_size = k; // int8
|
|
4200
|
+
|
|
4201
|
+
const uint32_t dblk_size = 8 * 4; // 8x (d, s) __fp16 = 32 bytes
|
|
4202
|
+
const uint32_t qblk_size = QK_Q8_0x4x2; // int8
|
|
4203
|
+
|
|
4204
|
+
uint8_t * restrict y_q = (y + 0); // quants first
|
|
4205
|
+
uint8_t * restrict y_d = (y + qrow_size); // then scales/sums
|
|
4206
|
+
|
|
4207
|
+
// Temp scales override input since we're working off of the aligned temp buffer in VTCM
|
|
4208
|
+
uint8_t * restrict t_d = (uint8_t *) x;
|
|
4209
|
+
|
|
4210
|
+
for (uint32_t i = 0; i < nb; i++) {
|
|
4211
|
+
quantize_block_f32_q8_1x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
|
|
4212
|
+
quantize_block_f32_q8_1x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
|
|
4213
|
+
}
|
|
4214
|
+
|
|
4215
|
+
// now copy the scales/sums into final location
|
|
4216
|
+
hvx_copy_f16_ua(y_d, t_d, nb * 16);
|
|
4217
|
+
}
|
|
4218
|
+
|
|
4219
|
+
static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * data) {
|
|
4220
|
+
struct htp_matmul_context * mmctx = data;
|
|
4221
|
+
struct htp_ops_context * octx = mmctx->octx;
|
|
4222
|
+
|
|
4223
|
+
const struct htp_tensor * src = octx->src[1];
|
|
4224
|
+
uint8_t * restrict dst = octx->src1_spad.data;
|
|
4225
|
+
struct htp_spad * spad = &octx->src0_spad;
|
|
4226
|
+
uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
|
|
4227
|
+
|
|
4228
|
+
uint64_t t1 = HAP_perf_get_qtimer_count();
|
|
4229
|
+
|
|
4230
|
+
const uint32_t ne0 = src->ne[0];
|
|
4231
|
+
const uint32_t ne1 = src->ne[1];
|
|
4232
|
+
const uint32_t ne2 = src->ne[2];
|
|
4233
|
+
const uint32_t ne3 = src->ne[3];
|
|
4234
|
+
|
|
4235
|
+
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
|
|
4236
|
+
|
|
4237
|
+
const uint32_t ir_first = nrows_per_thread * ith; // first row
|
|
4238
|
+
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
|
|
4239
|
+
|
|
4240
|
+
const size_t src_row_size = src->nb[1];
|
|
4241
|
+
const size_t dst_row_size = q8_1x4x2_row_size(ne0);
|
|
4242
|
+
|
|
4243
|
+
uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first);
|
|
4244
|
+
uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first);
|
|
4245
|
+
uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith);
|
|
4246
|
+
|
|
4247
|
+
const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float));
|
|
4248
|
+
memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding
|
|
4249
|
+
|
|
4250
|
+
for (uint32_t i = ir_first; i < ir_last; ++i) {
|
|
4251
|
+
hex_l2fetch(src_data, src_row_size, src_row_size, 2);
|
|
4252
|
+
hvx_copy_f32_aa(tmp_data, src_data, ne0);
|
|
4253
|
+
|
|
4254
|
+
quantize_row_f32_q8_1x4x2((float *) tmp_data, dst_data, ne0);
|
|
4255
|
+
dst_data += dst_row_size;
|
|
4256
|
+
src_data += src_row_size;
|
|
4257
|
+
}
|
|
4258
|
+
|
|
4259
|
+
uint64_t t2 = HAP_perf_get_qtimer_count();
|
|
4260
|
+
|
|
4261
|
+
FARF(HIGH, "quantize-f32-q8_1x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
|
|
4262
|
+
ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
|
4263
|
+
}
|
|
4264
|
+
|
|
4265
|
+
static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) {
|
|
4266
|
+
struct htp_matmul_context * mmctx = data;
|
|
4267
|
+
struct htp_ops_context * octx = mmctx->octx;
|
|
4268
|
+
|
|
4269
|
+
const struct htp_tensor * src = octx->src[1];
|
|
4270
|
+
uint8_t * restrict dst = octx->src1_spad.data;
|
|
4271
|
+
uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
|
|
4272
|
+
uint32_t dst_stride = octx->src1_spad.stride;
|
|
4273
|
+
|
|
4274
|
+
uint64_t t1 = HAP_perf_get_qtimer_count();
|
|
4275
|
+
|
|
4276
|
+
const uint32_t ne0 = src->ne[0];
|
|
4277
|
+
const uint32_t ne1 = src->ne[1];
|
|
4278
|
+
const uint32_t ne2 = src->ne[2];
|
|
4279
|
+
const uint32_t ne3 = src->ne[3];
|
|
4280
|
+
|
|
4281
|
+
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
|
|
4282
|
+
|
|
4283
|
+
const uint32_t ir_first = nrows_per_thread * ith; // first row
|
|
4284
|
+
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
|
|
4285
|
+
|
|
4286
|
+
const size_t src_row_size = ne0 * sizeof(float);
|
|
4287
|
+
const size_t src_stride = src->nb[1];
|
|
4288
|
+
|
|
4289
|
+
uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);
|
|
4290
|
+
uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
|
|
4291
|
+
|
|
4292
|
+
for (uint32_t i = ir_first; i < ir_last; ++i) {
|
|
4293
|
+
hex_l2fetch(src_data, src_row_size, src_stride, 2);
|
|
4294
|
+
hvx_copy_f32_au(dst_data, src_data, ne0);
|
|
4295
|
+
|
|
4296
|
+
dst_data += dst_stride;
|
|
4297
|
+
src_data += src_stride;
|
|
4298
|
+
}
|
|
4299
|
+
|
|
4300
|
+
uint64_t t2 = HAP_perf_get_qtimer_count();
|
|
4301
|
+
|
|
4302
|
+
FARF(HIGH, "quantize-f32-f32: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
|
|
4303
|
+
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
|
4304
|
+
}
|
|
4305
|
+
|
|
2294
4306
|
static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
|
|
2295
4307
|
struct htp_matmul_context * mmctx = data;
|
|
2296
4308
|
struct htp_ops_context * octx = mmctx->octx;
|
|
2297
4309
|
|
|
2298
|
-
const struct htp_tensor * src =
|
|
4310
|
+
const struct htp_tensor * src = octx->src[1];
|
|
2299
4311
|
uint8_t * restrict dst = octx->src1_spad.data;
|
|
2300
4312
|
uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
|
|
2301
4313
|
uint32_t dst_stride = octx->src1_spad.stride;
|
|
@@ -2337,7 +4349,7 @@ static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
|
|
|
2337
4349
|
struct htp_matmul_context * mmctx = data;
|
|
2338
4350
|
struct htp_ops_context * octx = mmctx->octx;
|
|
2339
4351
|
|
|
2340
|
-
const struct htp_tensor * src =
|
|
4352
|
+
const struct htp_tensor * src = octx->src[1];
|
|
2341
4353
|
uint8_t * restrict dst = octx->src1_spad.data;
|
|
2342
4354
|
uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
|
|
2343
4355
|
uint32_t dst_stride = octx->src1_spad.stride;
|
|
@@ -2386,18 +4398,35 @@ static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_t
|
|
|
2386
4398
|
mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1;
|
|
2387
4399
|
mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1;
|
|
2388
4400
|
mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2;
|
|
4401
|
+
mmctx->vec_dot_4x1 = vec_dot_q4x4x2_q8x4x2_4x1;
|
|
4402
|
+
return 0;
|
|
4403
|
+
case HTP_TYPE_Q4_1:
|
|
4404
|
+
mmctx->type = "q4_1x4x2-f32";
|
|
4405
|
+
mmctx->vec_dot_1x1 = vec_dot_q4_1x4x2_q8x4x2_1x1;
|
|
4406
|
+
mmctx->vec_dot_2x1 = vec_dot_q4_1x4x2_q8x4x2_2x1;
|
|
4407
|
+
mmctx->vec_dot_2x2 = vec_dot_q4_1x4x2_q8x4x2_2x2;
|
|
4408
|
+
mmctx->vec_dot_4x1 = vec_dot_q4_1x4x2_q8x4x2_4x1;
|
|
2389
4409
|
return 0;
|
|
2390
4410
|
case HTP_TYPE_Q8_0:
|
|
2391
4411
|
mmctx->type = "q8x4x2-f32";
|
|
2392
4412
|
mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1;
|
|
2393
4413
|
mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1;
|
|
2394
4414
|
mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2;
|
|
4415
|
+
mmctx->vec_dot_4x1 = vec_dot_q8x4x2_q8x4x2_4x1;
|
|
4416
|
+
return 0;
|
|
4417
|
+
case HTP_TYPE_IQ4_NL:
|
|
4418
|
+
mmctx->type = "iq4nlx4x2-f32";
|
|
4419
|
+
mmctx->vec_dot_1x1 = vec_dot_iq4nlx4x2_q8x4x2_1x1;
|
|
4420
|
+
mmctx->vec_dot_2x1 = vec_dot_iq4nlx4x2_q8x4x2_2x1;
|
|
4421
|
+
mmctx->vec_dot_2x2 = vec_dot_iq4nlx4x2_q8x4x2_2x2;
|
|
4422
|
+
mmctx->vec_dot_4x1 = vec_dot_iq4nlx4x2_q8x4x2_4x1;
|
|
2395
4423
|
return 0;
|
|
2396
4424
|
case HTP_TYPE_MXFP4:
|
|
2397
4425
|
mmctx->type = "mxfp4x4x2-f32";
|
|
2398
4426
|
mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1;
|
|
2399
4427
|
mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1;
|
|
2400
4428
|
mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2;
|
|
4429
|
+
mmctx->vec_dot_4x1 = vec_dot_mxfp4x4x2_q8x4x2_4x1;
|
|
2401
4430
|
return 0;
|
|
2402
4431
|
default:
|
|
2403
4432
|
return -1;
|
|
@@ -2430,7 +4459,7 @@ static void htp_mminit_spad(struct htp_ops_context * octx,
|
|
|
2430
4459
|
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
|
|
2431
4460
|
}
|
|
2432
4461
|
|
|
2433
|
-
int
|
|
4462
|
+
static int op_matmul_hvx(struct htp_ops_context * octx) {
|
|
2434
4463
|
htp_matmul_tensors_preamble;
|
|
2435
4464
|
|
|
2436
4465
|
struct htp_matmul_context mmctx_struct = {0};
|
|
@@ -2454,7 +4483,7 @@ int op_matmul(struct htp_ops_context * octx) {
|
|
|
2454
4483
|
worker_callback_t quant_job_func;
|
|
2455
4484
|
worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d;
|
|
2456
4485
|
|
|
2457
|
-
bool need_quant =
|
|
4486
|
+
bool need_quant = true;
|
|
2458
4487
|
|
|
2459
4488
|
if (src0->type == HTP_TYPE_F16) {
|
|
2460
4489
|
// Try optimized f16-f16 path first (src1 in VTCM)
|
|
@@ -2468,7 +4497,7 @@ int op_matmul(struct htp_ops_context * octx) {
|
|
|
2468
4497
|
// Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
|
|
2469
4498
|
// It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
|
|
2470
4499
|
const bool is_batched = (ne02 > 1) || (ne03 > 1);
|
|
2471
|
-
const bool is_permuted = htp_is_permuted(
|
|
4500
|
+
const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]);
|
|
2472
4501
|
|
|
2473
4502
|
if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
|
|
2474
4503
|
// Optimized path
|
|
@@ -2516,6 +4545,60 @@ int op_matmul(struct htp_ops_context * octx) {
|
|
|
2516
4545
|
mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
|
|
2517
4546
|
mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
|
|
2518
4547
|
|
|
4548
|
+
need_quant = false;
|
|
4549
|
+
}
|
|
4550
|
+
} else if (src0->type == HTP_TYPE_F32) {
|
|
4551
|
+
// Try optimized f32-f32 path first (src1 in VTCM)
|
|
4552
|
+
const size_t f32_src1_row_size = hex_round_up(ne10 * 4, 128);
|
|
4553
|
+
const size_t f32_src1_spad_size = hex_round_up(f32_src1_row_size * src1_nrows, 256);
|
|
4554
|
+
const size_t f32_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
|
|
4555
|
+
const size_t f32_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
|
|
4556
|
+
|
|
4557
|
+
const size_t f32_total_size = f32_src1_spad_size + f32_src0_spad_size + f32_dst_spad_size;
|
|
4558
|
+
|
|
4559
|
+
const bool is_batched = (ne02 > 1) || (ne03 > 1);
|
|
4560
|
+
const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]);
|
|
4561
|
+
|
|
4562
|
+
if (!is_batched && !is_permuted && f32_total_size <= octx->ctx->vtcm_size) {
|
|
4563
|
+
// Optimized path
|
|
4564
|
+
quant_job_func = quantize_f32_f32;
|
|
4565
|
+
mmctx->type = "f32-f32";
|
|
4566
|
+
mmctx->vec_dot_1x1 = vec_dot_f32_f32_aa_1x1;
|
|
4567
|
+
mmctx->vec_dot_2x1 = vec_dot_f32_f32_aa_2x1;
|
|
4568
|
+
mmctx->vec_dot_2x2 = vec_dot_f32_f32_aa_2x2;
|
|
4569
|
+
|
|
4570
|
+
src1_row_size = f32_src1_row_size;
|
|
4571
|
+
|
|
4572
|
+
octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
|
4573
|
+
octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
|
4574
|
+
octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
|
|
4575
|
+
|
|
4576
|
+
octx->src1_spad.size = octx->src1_spad.size_per_thread;
|
|
4577
|
+
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
|
|
4578
|
+
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
|
|
4579
|
+
} else {
|
|
4580
|
+
// Fallback to DDR / broadcasting
|
|
4581
|
+
quant_job_func = NULL;
|
|
4582
|
+
mmctx->type = "f32-f32";
|
|
4583
|
+
mmctx->vec_dot_1x1 = vec_dot_f32_f32_uu_1x1;
|
|
4584
|
+
matmul_job_func = matmul_4d;
|
|
4585
|
+
|
|
4586
|
+
src1_row_size = nb11;
|
|
4587
|
+
|
|
4588
|
+
octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
|
4589
|
+
octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
|
|
4590
|
+
octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
|
|
4591
|
+
|
|
4592
|
+
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
|
|
4593
|
+
octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
|
|
4594
|
+
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
|
|
4595
|
+
|
|
4596
|
+
// Init fastdiv for matmul_4d (supports broadcasting)
|
|
4597
|
+
mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
|
|
4598
|
+
mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]);
|
|
4599
|
+
mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
|
|
4600
|
+
mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
|
|
4601
|
+
|
|
2519
4602
|
need_quant = false;
|
|
2520
4603
|
}
|
|
2521
4604
|
} else {
|
|
@@ -2523,8 +4606,13 @@ int op_matmul(struct htp_ops_context * octx) {
|
|
|
2523
4606
|
return HTP_STATUS_NO_SUPPORT;
|
|
2524
4607
|
}
|
|
2525
4608
|
|
|
2526
|
-
|
|
2527
|
-
|
|
4609
|
+
if (src0->type == HTP_TYPE_Q4_1) {
|
|
4610
|
+
quant_job_func = quantize_f32_q8_1x4x2;
|
|
4611
|
+
src1_row_size = q8_1x4x2_row_size(ne10);
|
|
4612
|
+
} else {
|
|
4613
|
+
quant_job_func = quantize_f32_q8x4x2;
|
|
4614
|
+
src1_row_size = q8x4x2_row_size(ne10);
|
|
4615
|
+
}
|
|
2528
4616
|
htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0);
|
|
2529
4617
|
}
|
|
2530
4618
|
|
|
@@ -2545,27 +4633,148 @@ int op_matmul(struct htp_ops_context * octx) {
|
|
|
2545
4633
|
return HTP_STATUS_VTCM_TOO_SMALL;
|
|
2546
4634
|
}
|
|
2547
4635
|
|
|
2548
|
-
|
|
2549
|
-
octx->src1_spad.data = octx->
|
|
2550
|
-
octx->
|
|
4636
|
+
// Place src1 spad first. We use it for dyn.quant and may reuse between ops
|
|
4637
|
+
octx->src1_spad.data = octx->ctx->vtcm_base;
|
|
4638
|
+
octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
|
4639
|
+
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
|
4640
|
+
|
|
4641
|
+
octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL;
|
|
4642
|
+
octx->src0_spad.src = NULL;
|
|
4643
|
+
octx->dst_spad.src = NULL;
|
|
2551
4644
|
|
|
2552
4645
|
octx->src0_spad.stride = src0_row_size_padded;
|
|
2553
4646
|
octx->src1_spad.stride = src1_row_size;
|
|
2554
4647
|
|
|
2555
|
-
if (
|
|
4648
|
+
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)
|
|
4649
|
+
return HTP_STATUS_OK;
|
|
4650
|
+
|
|
4651
|
+
if (need_quant && !octx->src1_spad.src) {
|
|
2556
4652
|
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
|
|
2557
4653
|
mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
|
|
2558
4654
|
worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
|
|
4655
|
+
octx->src1_spad.src = src1;
|
|
2559
4656
|
}
|
|
2560
4657
|
|
|
2561
|
-
|
|
2562
|
-
|
|
2563
|
-
worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs);
|
|
2564
|
-
}
|
|
4658
|
+
const uint32_t n_matmul_jobs = octx->n_threads;
|
|
4659
|
+
worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs);
|
|
2565
4660
|
|
|
2566
4661
|
return HTP_STATUS_OK;
|
|
2567
4662
|
}
|
|
2568
4663
|
|
|
4664
|
+
int op_matmul(struct htp_ops_context * octx) {
|
|
4665
|
+
htp_matmul_tensors_preamble;
|
|
4666
|
+
|
|
4667
|
+
#ifndef HTP_HAS_HMX
|
|
4668
|
+
return op_matmul_hvx(octx);
|
|
4669
|
+
#else
|
|
4670
|
+
if (!octx->ctx->hmx_enabled) {
|
|
4671
|
+
return op_matmul_hvx(octx);
|
|
4672
|
+
}
|
|
4673
|
+
|
|
4674
|
+
// HMX weight tile requires N to be 32-aligned.
|
|
4675
|
+
if (src0->ne[1] % 32 != 0) {
|
|
4676
|
+
return op_matmul_hvx(octx);
|
|
4677
|
+
}
|
|
4678
|
+
|
|
4679
|
+
// HMX supports F16, F32, Q4_0, Q8_0, IQ4_NL, MXFP4 weights.
|
|
4680
|
+
// Other types fall back to HVX.
|
|
4681
|
+
uint32_t wtype = src0->type;
|
|
4682
|
+
if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) {
|
|
4683
|
+
return op_matmul_hvx(octx);
|
|
4684
|
+
}
|
|
4685
|
+
|
|
4686
|
+
// Quantised HMX path requires K aligned to 256 (x4x2 super-block).
|
|
4687
|
+
// F16 and F32 HMX paths require K aligned to 32 (tile width).
|
|
4688
|
+
if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && src0->ne[0] % 256 != 0) {
|
|
4689
|
+
return op_matmul_hvx(octx);
|
|
4690
|
+
}
|
|
4691
|
+
|
|
4692
|
+
if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && src0->ne[0] % 32 != 0) {
|
|
4693
|
+
return op_matmul_hvx(octx);
|
|
4694
|
+
}
|
|
4695
|
+
|
|
4696
|
+
const bool is_batched = (src0->ne[2] * src0->ne[3] > 1 || src1->ne[2] * src1->ne[3] > 1);
|
|
4697
|
+
|
|
4698
|
+
// Quantised HMX kernels only handle flat 2D matmul (host already rejects
|
|
4699
|
+
// batched quantised, but guard here too). F16 batched matmul is handled
|
|
4700
|
+
// by the dedicated wrapper in hmx-matmul-ops.c.
|
|
4701
|
+
if (is_batched && src0->type != HTP_TYPE_F16) {
|
|
4702
|
+
return op_matmul_hvx(octx);
|
|
4703
|
+
}
|
|
4704
|
+
|
|
4705
|
+
// HMX assumes contiguous row-major layout. Fall back for permuted
|
|
4706
|
+
// tensors where strides are non-monotonic (e.g. transposed KV cache).
|
|
4707
|
+
if (src0->nb[0] > src0->nb[1] || src1->nb[0] > src1->nb[1]) {
|
|
4708
|
+
return op_matmul_hvx(octx);
|
|
4709
|
+
}
|
|
4710
|
+
|
|
4711
|
+
// M alignment: Use HMX when M >= 32, the last partial tile (m_total % 32 rows)
|
|
4712
|
+
// is handled by HMX itself; when M < 32 fall back to HVX.
|
|
4713
|
+
const int m_total = (int) src1->ne[1];
|
|
4714
|
+
const int m_hmx = m_total & ~31; // 0 when M < 32
|
|
4715
|
+
if (m_hmx == 0) {
|
|
4716
|
+
return op_matmul_hvx(octx);
|
|
4717
|
+
}
|
|
4718
|
+
|
|
4719
|
+
// Always re-quantize src1 since HMX kernel overwrites vtcm/spad,
|
|
4720
|
+
// so any previously cached quantized data is invalid.
|
|
4721
|
+
octx->src1_spad.src = NULL;
|
|
4722
|
+
|
|
4723
|
+
int k = (int) src0->ne[0]; // inner dimension
|
|
4724
|
+
int n = (int) src0->ne[1]; // weight columns
|
|
4725
|
+
|
|
4726
|
+
int ret = -1;
|
|
4727
|
+
|
|
4728
|
+
// Row strides in elements. For compact tensors these equal k; for
|
|
4729
|
+
// permuted attention views they can be larger, so pass the real stride.
|
|
4730
|
+
const int act_stride = (int)(src1->nb[1] / sizeof(float));
|
|
4731
|
+
const int wgt_stride = (int)(src0->nb[1] / sizeof(__fp16));
|
|
4732
|
+
|
|
4733
|
+
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
|
|
4734
|
+
return HTP_STATUS_OK;
|
|
4735
|
+
}
|
|
4736
|
+
|
|
4737
|
+
if (is_batched) {
|
|
4738
|
+
if (src0->type == HTP_TYPE_F16) {
|
|
4739
|
+
hmx_matmul_f16_f32_batched_params_t batch_params = {
|
|
4740
|
+
.dst = (float *) dst->data,
|
|
4741
|
+
.activation = (float *) src1->data,
|
|
4742
|
+
.permuted_weight = (const __fp16 *) src0->data,
|
|
4743
|
+
.m = m_total,
|
|
4744
|
+
.k = k,
|
|
4745
|
+
.n = n,
|
|
4746
|
+
.act_stride = act_stride,
|
|
4747
|
+
.weight_stride = wgt_stride,
|
|
4748
|
+
.dst_stride = (int) (dst->nb[1] / sizeof(float)),
|
|
4749
|
+
.ne02 = ne02,
|
|
4750
|
+
.ne03 = ne03,
|
|
4751
|
+
.ne12 = ne12,
|
|
4752
|
+
.ne13 = ne13,
|
|
4753
|
+
.src0_nb2 = src0->nb[2],
|
|
4754
|
+
.src0_nb3 = src0->nb[3],
|
|
4755
|
+
.src1_nb2 = src1->nb[2],
|
|
4756
|
+
.src1_nb3 = src1->nb[3],
|
|
4757
|
+
.dst_nb2 = dst->nb[2],
|
|
4758
|
+
.dst_nb3 = dst->nb[3],
|
|
4759
|
+
};
|
|
4760
|
+
ret = hmx_matmul_f16_f32_batched(octx->ctx, &batch_params);
|
|
4761
|
+
} else {
|
|
4762
|
+
return op_matmul_hvx(octx);
|
|
4763
|
+
}
|
|
4764
|
+
} else {
|
|
4765
|
+
ret = hmx_matmul_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data,
|
|
4766
|
+
m_total, k, n, act_stride, (int) src0->nb[1], (int) src0->type);
|
|
4767
|
+
}
|
|
4768
|
+
|
|
4769
|
+
if (ret != 0) {
|
|
4770
|
+
FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret);
|
|
4771
|
+
return op_matmul(octx);
|
|
4772
|
+
}
|
|
4773
|
+
|
|
4774
|
+
return 0;
|
|
4775
|
+
#endif // HTP_HAS_HMX
|
|
4776
|
+
}
|
|
4777
|
+
|
|
2569
4778
|
int op_matmul_id(struct htp_ops_context * octx) {
|
|
2570
4779
|
htp_matmul_tensors_preamble;
|
|
2571
4780
|
|
|
@@ -2573,7 +4782,7 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
|
|
2573
4782
|
struct htp_matmul_context * mmctx = &mmctx_struct;
|
|
2574
4783
|
mmctx->octx = octx;
|
|
2575
4784
|
|
|
2576
|
-
struct htp_tensor * restrict ids =
|
|
4785
|
+
const struct htp_tensor * restrict ids = octx->src[2];
|
|
2577
4786
|
|
|
2578
4787
|
const size_t src0_row_size = nb01;
|
|
2579
4788
|
const size_t dst_row_size = nb1;
|
|
@@ -2599,15 +4808,42 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
|
|
2599
4808
|
|
|
2600
4809
|
size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
|
|
2601
4810
|
size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);
|
|
4811
|
+
const size_t total_map_size = matrix_row_counts_size + matrix_row_map_size;
|
|
4812
|
+
|
|
4813
|
+
void * mapping_buf = NULL;
|
|
4814
|
+
bool must_free_mapping = false;
|
|
4815
|
+
|
|
4816
|
+
if (octx->ctx->ddr_spad_base && total_map_size <= octx->ctx->ddr_spad_size) {
|
|
4817
|
+
mapping_buf = octx->ctx->ddr_spad_base;
|
|
4818
|
+
} else {
|
|
4819
|
+
mapping_buf = memalign(128, total_map_size);
|
|
4820
|
+
if (mapping_buf) {
|
|
4821
|
+
must_free_mapping = true;
|
|
4822
|
+
} else {
|
|
4823
|
+
return HTP_STATUS_INTERNAL_ERR;
|
|
4824
|
+
}
|
|
4825
|
+
}
|
|
4826
|
+
|
|
4827
|
+
uint32_t * matrix_row_counts = (uint32_t *) mapping_buf;
|
|
4828
|
+
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) ((uint8_t *) mapping_buf + matrix_row_counts_size);
|
|
4829
|
+
|
|
4830
|
+
mmctx->matrix_row_counts = matrix_row_counts;
|
|
4831
|
+
mmctx->matrix_rows = matrix_rows;
|
|
2602
4832
|
|
|
2603
4833
|
if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {
|
|
4834
|
+
if (must_free_mapping) free(mapping_buf);
|
|
2604
4835
|
return HTP_STATUS_NO_SUPPORT;
|
|
2605
4836
|
}
|
|
2606
4837
|
|
|
2607
|
-
|
|
2608
|
-
|
|
4838
|
+
if (src0->type == HTP_TYPE_Q4_1) {
|
|
4839
|
+
quant_job_func = quantize_f32_q8_1x4x2;
|
|
4840
|
+
src1_row_size = q8_1x4x2_row_size(ne10);
|
|
4841
|
+
} else {
|
|
4842
|
+
quant_job_func = quantize_f32_q8x4x2;
|
|
4843
|
+
src1_row_size = q8x4x2_row_size(ne10);
|
|
4844
|
+
}
|
|
2609
4845
|
|
|
2610
|
-
const size_t src2_spad_size_per_thread =
|
|
4846
|
+
const size_t src2_spad_size_per_thread = 0; // We moved the mapping to DDR!
|
|
2611
4847
|
htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread);
|
|
2612
4848
|
|
|
2613
4849
|
size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
|
|
@@ -2623,22 +4859,26 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
|
|
2623
4859
|
// Make sure the reserved vtcm size is sufficient
|
|
2624
4860
|
if (octx->ctx->vtcm_size < spad_size) {
|
|
2625
4861
|
FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size);
|
|
4862
|
+
if (must_free_mapping) free(mapping_buf);
|
|
2626
4863
|
return HTP_STATUS_VTCM_TOO_SMALL;
|
|
2627
4864
|
}
|
|
2628
4865
|
|
|
2629
|
-
|
|
2630
|
-
octx->src1_spad.data = octx->
|
|
2631
|
-
octx->
|
|
4866
|
+
// Place src1 spad first. We use it for dyn.quant and may reuse in subseq ops.
|
|
4867
|
+
octx->src1_spad.data = octx->ctx->vtcm_base;
|
|
4868
|
+
octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
|
4869
|
+
octx->src2_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
|
2632
4870
|
octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size;
|
|
2633
4871
|
|
|
4872
|
+
octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL;
|
|
4873
|
+
octx->src0_spad.src = NULL;
|
|
4874
|
+
octx->src2_spad.src = NULL;
|
|
4875
|
+
octx->dst_spad.src = NULL;
|
|
4876
|
+
|
|
2634
4877
|
octx->src0_spad.stride = src0_row_size_padded;
|
|
2635
4878
|
octx->src1_spad.stride = src1_row_size;
|
|
2636
4879
|
|
|
2637
4880
|
if (src1_nrows > 1) {
|
|
2638
4881
|
// initialize matrix_row_counts and map
|
|
2639
|
-
uint32_t * matrix_row_counts = (uint32_t *) octx->src2_spad.data + 0;
|
|
2640
|
-
struct mmid_row_mapping * matrix_rows = (void *) octx->src2_spad.data + matrix_row_counts_size;
|
|
2641
|
-
|
|
2642
4882
|
memset(matrix_row_counts, 0, n_as * sizeof(uint32_t));
|
|
2643
4883
|
|
|
2644
4884
|
// group rows by src0 matrix
|
|
@@ -2648,23 +4888,71 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
|
|
2648
4888
|
|
|
2649
4889
|
assert(i02 >= 0 && i02 < n_as);
|
|
2650
4890
|
|
|
2651
|
-
|
|
4891
|
+
matrix_rows[i02 * n_ids * ids->ne[1] + matrix_row_counts[i02]] = (struct mmid_row_mapping) { id, iid1 };
|
|
2652
4892
|
matrix_row_counts[i02] += 1;
|
|
2653
4893
|
}
|
|
2654
4894
|
}
|
|
2655
4895
|
}
|
|
2656
4896
|
|
|
2657
|
-
|
|
2658
|
-
|
|
4897
|
+
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
|
|
4898
|
+
if (must_free_mapping) free(mapping_buf);
|
|
4899
|
+
return HTP_STATUS_OK;
|
|
4900
|
+
}
|
|
4901
|
+
|
|
4902
|
+
bool hmx_eligible = false;
|
|
4903
|
+
#ifdef HTP_HAS_HMX
|
|
4904
|
+
if (octx->ctx->hmx_enabled && src1_nrows > 1) {
|
|
4905
|
+
uint32_t wtype = src0->type;
|
|
4906
|
+
if (ne01 % 32 == 0 &&
|
|
4907
|
+
(wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32 || wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 || wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL || wtype == HTP_TYPE_MXFP4)) {
|
|
4908
|
+
if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && ne00 % 32 == 0) {
|
|
4909
|
+
hmx_eligible = true;
|
|
4910
|
+
} else if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && ne00 % 256 == 0) {
|
|
4911
|
+
hmx_eligible = true;
|
|
4912
|
+
}
|
|
4913
|
+
}
|
|
4914
|
+
}
|
|
4915
|
+
#endif
|
|
4916
|
+
|
|
4917
|
+
mmctx->hmx_eligible = hmx_eligible;
|
|
4918
|
+
|
|
4919
|
+
if (hmx_eligible) {
|
|
4920
|
+
for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) {
|
|
4921
|
+
const int32_t cne1 = matrix_row_counts[cur_a];
|
|
4922
|
+
if (cne1 == 0) continue;
|
|
4923
|
+
|
|
4924
|
+
int ret = hmx_matmul_id_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data,
|
|
4925
|
+
(const uint8_t *) src0->data + cur_a * nb02,
|
|
4926
|
+
cne1, ne00, ne01,
|
|
4927
|
+
ne11,
|
|
4928
|
+
nb11, nb12,
|
|
4929
|
+
nb1, nb2,
|
|
4930
|
+
(int) src0->nb[1], (int) src0->type,
|
|
4931
|
+
matrix_rows, cur_a, n_ids * ids->ne[1]);
|
|
4932
|
+
if (ret != 0) {
|
|
4933
|
+
FARF(ERROR, "HMX matmul failed for expert %u, error %d\n", cur_a, ret);
|
|
4934
|
+
if (must_free_mapping) free(mapping_buf);
|
|
4935
|
+
return HTP_STATUS_NO_SUPPORT;
|
|
4936
|
+
}
|
|
4937
|
+
}
|
|
4938
|
+
|
|
4939
|
+
// HMX has overwritten VTCM, so force dynamic quantization cache to clear
|
|
4940
|
+
octx->src1_spad.src = NULL;
|
|
4941
|
+
|
|
4942
|
+
if (must_free_mapping) free(mapping_buf);
|
|
4943
|
+
return HTP_STATUS_OK;
|
|
4944
|
+
}
|
|
4945
|
+
|
|
4946
|
+
if (octx->src1_spad.src != src1) {
|
|
2659
4947
|
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
|
|
2660
4948
|
mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
|
|
2661
4949
|
worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
|
|
4950
|
+
octx->src1_spad.src = src1;
|
|
2662
4951
|
}
|
|
2663
4952
|
|
|
2664
|
-
|
|
2665
|
-
|
|
2666
|
-
worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs);
|
|
2667
|
-
}
|
|
4953
|
+
const uint32_t n_matmul_jobs = octx->n_threads;
|
|
4954
|
+
worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs);
|
|
2668
4955
|
|
|
4956
|
+
if (must_free_mapping) free(mapping_buf);
|
|
2669
4957
|
return HTP_STATUS_OK;
|
|
2670
4958
|
}
|