whispercpp 1.3.5 → 1.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/LICENSE +1 -1
- data/README.md +133 -3
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -7
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +56 -46
- data/ext/ruby_whisper.h +165 -2
- data/ext/ruby_whisper_context.c +297 -126
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -66
- data/ext/ruby_whisper_segment.c +6 -7
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +46 -16
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +24 -19
- data/ext/sources/examples/cli/cli.cpp +51 -9
- data/ext/sources/examples/common-ggml.cpp +4 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +213 -163
- data/ext/sources/ggml/CMakeLists.txt +29 -15
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +73 -11
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +8 -3
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +155 -16
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +25 -5
- data/ext/sources/ggml/src/ggml-alloc.c +9 -10
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
- data/ext/sources/ggml/src/ggml-common.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
- data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
- data/ext/sources/ggml/src/ggml-impl.h +68 -1
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +385 -119
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
- data/ext/sources/ggml/src/ggml.c +268 -52
- data/ext/sources/ggml/src/gguf.cpp +377 -47
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +62 -40
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +445 -55
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_context_params.rb +82 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +44 -6
- data/whispercpp.gemspec +2 -2
- metadata +426 -280
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
- data/ext/sources/examples/talk-llama/llama-context.h +0 -360
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
- data/ext/sources/examples/talk-llama/llama-model.h +0 -544
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
- data/ext/sources/examples/talk-llama/llama.h +0 -1540
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -569
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
|
@@ -3,19 +3,32 @@
|
|
|
3
3
|
|
|
4
4
|
#include "ime.h"
|
|
5
5
|
|
|
6
|
+
#include "binary-ops.h"
|
|
7
|
+
#include "common.h"
|
|
6
8
|
#include "ggml-backend-impl.h"
|
|
7
9
|
#include "ggml-common.h"
|
|
8
10
|
#include "ggml-cpu.h"
|
|
11
|
+
#include "ime_env.h"
|
|
9
12
|
#include "ime_kernels.h"
|
|
13
|
+
#include "ops.h"
|
|
14
|
+
#include "repack.h"
|
|
15
|
+
#include "rvv_kernels.h"
|
|
16
|
+
#include "spine_mem_pool.h"
|
|
10
17
|
#include "traits.h"
|
|
18
|
+
#include "vec.h"
|
|
19
|
+
|
|
20
|
+
#include <fcntl.h>
|
|
21
|
+
#include <sys/mman.h>
|
|
22
|
+
#include <unistd.h>
|
|
11
23
|
|
|
12
24
|
#include <algorithm>
|
|
25
|
+
#include <atomic>
|
|
13
26
|
#include <cassert>
|
|
27
|
+
#include <cerrno>
|
|
14
28
|
#include <cmath>
|
|
15
29
|
#include <cstdio> // for GGML_ASSERT
|
|
16
30
|
#include <stdexcept>
|
|
17
31
|
#include <thread>
|
|
18
|
-
|
|
19
32
|
// clang-format off
|
|
20
33
|
#if defined(__riscv)
|
|
21
34
|
|
|
@@ -25,13 +38,17 @@
|
|
|
25
38
|
#include <riscv_vector.h>
|
|
26
39
|
#endif
|
|
27
40
|
|
|
28
|
-
#if !defined(__riscv_zfh)
|
|
29
|
-
#error "riscv zfh extension not enabled"
|
|
41
|
+
#if !defined(__riscv_zfh) || !defined(__riscv_zvfh)
|
|
42
|
+
#error "riscv zfh extension not enabled, GGML_RV_ZFH and GGML_RV_ZVFH must be defined to 1"
|
|
30
43
|
#endif
|
|
31
44
|
|
|
32
|
-
#if defined(
|
|
45
|
+
#if !defined(__riscv_zba)
|
|
46
|
+
#error "riscv zba extension not enabled, GGML_RV_ZBA must be defined to 1"
|
|
47
|
+
#endif
|
|
48
|
+
|
|
49
|
+
#if defined(RISCV64_SPACEMIT_IME1) || defined(RISCV64_SPACEMIT_IME2)
|
|
33
50
|
#else
|
|
34
|
-
#error "RISCV64_SPACEMIT_IME1 not defined"
|
|
51
|
+
#error "RISCV64_SPACEMIT_IME1 or RISCV64_SPACEMIT_IME2 not defined"
|
|
35
52
|
#endif
|
|
36
53
|
|
|
37
54
|
#else
|
|
@@ -46,382 +63,490 @@
|
|
|
46
63
|
#pragma GCC diagnostic ignored "-Wunused-parameter"
|
|
47
64
|
#endif
|
|
48
65
|
|
|
49
|
-
#if defined(RISCV64_SPACEMIT_IME1)
|
|
50
|
-
#define QGEMM_STRIDEN_THREAD_ALIGN 16
|
|
51
|
-
#else
|
|
52
|
-
#define QGEMM_STRIDEN_THREAD_ALIGN 32
|
|
53
|
-
#endif
|
|
54
|
-
|
|
55
66
|
// clang-format on
|
|
56
67
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
const std::byte * packed_quant_b_data = nullptr;
|
|
61
|
-
const float * quant_b_scale = nullptr;
|
|
62
|
-
const void * quant_b_zp = nullptr;
|
|
63
|
-
const float * quant_b_blksum = nullptr;
|
|
64
|
-
const float * bias = nullptr;
|
|
65
|
-
float * c_ptr = nullptr;
|
|
66
|
-
size_t ldc = 0;
|
|
67
|
-
};
|
|
68
|
-
|
|
69
|
-
constexpr size_t div_round_up(size_t up, size_t down) {
|
|
70
|
-
return (up + down - 1) / down;
|
|
71
|
-
}
|
|
72
|
-
|
|
73
|
-
constexpr size_t q8_blk_size(size_t blk_len) {
|
|
74
|
-
const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t);
|
|
75
|
-
// Currently, the strictest alignment requirement of a block is for a float.
|
|
76
|
-
// Ensure contiguous blocks are suitably aligned.
|
|
77
|
-
assert(blk_size % alignof(float) == 0);
|
|
78
|
-
return blk_size;
|
|
68
|
+
extern "C" {
|
|
69
|
+
extern void ggml_threadpool_chunk_set(struct ggml_threadpool * tp, int value);
|
|
70
|
+
extern int ggml_threadpool_chunk_add(struct ggml_threadpool * tp, int value);
|
|
79
71
|
}
|
|
80
72
|
|
|
81
73
|
namespace ggml::cpu::riscv64_spacemit {
|
|
82
74
|
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
75
|
+
struct TLSContext {
|
|
76
|
+
int cpu_id{ -1 };
|
|
77
|
+
cpu_set_t cpuset;
|
|
78
|
+
void * tcm_buffer{ nullptr };
|
|
79
|
+
size_t tcm_buffer_size{ 0 };
|
|
80
|
+
};
|
|
86
81
|
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
constexpr
|
|
82
|
+
thread_local TLSContext tls_context;
|
|
83
|
+
|
|
84
|
+
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> constexpr size_t get_repacked_block_type_size() {
|
|
85
|
+
if constexpr (std::is_same_v<BLOC_TYPE, block_q6_K> || std::is_same_v<BLOC_TYPE, block_q8_0>) {
|
|
86
|
+
return sizeof(block_q8_0);
|
|
87
|
+
} else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0>) {
|
|
88
|
+
return sizeof(block_q4_0) * INTER_SIZE / QK4_0;
|
|
89
|
+
} else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_1> || std::is_same_v<BLOC_TYPE, block_q4_K>) {
|
|
90
|
+
return (sizeof(block_q4_0) + sizeof(uint8_t)) * INTER_SIZE / QK4_1;
|
|
91
|
+
} else if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K>) {
|
|
92
|
+
return sizeof(spacemit_kernels::nrow_block_q2_k<1>);
|
|
93
|
+
} else if constexpr (std::is_same_v<BLOC_TYPE, block_q3_K>) {
|
|
94
|
+
return sizeof(spacemit_kernels::nrow_block_q3_k<1>);
|
|
95
|
+
} else if constexpr (std::is_same_v<BLOC_TYPE, block_mxfp4>) {
|
|
96
|
+
return sizeof(spacemit_kernels::nrow_block_mxfp4<1>);
|
|
97
|
+
} else if constexpr (std::is_same_v<BLOC_TYPE, block_q5_1> || std::is_same_v<BLOC_TYPE, block_q5_K>) {
|
|
98
|
+
return sizeof(spacemit_kernels::nrow_block_q5_1<1>);
|
|
99
|
+
} else if constexpr (std::is_same_v<BLOC_TYPE, block_q5_0>) {
|
|
100
|
+
return sizeof(spacemit_kernels::nrow_block_q5_0<1>);
|
|
101
|
+
} else {
|
|
102
|
+
assert(false);
|
|
103
|
+
return 0;
|
|
104
|
+
}
|
|
105
|
+
}
|
|
97
106
|
|
|
98
|
-
|
|
107
|
+
template <typename BLOC_TYPE> constexpr bool block_type_has_zp() {
|
|
108
|
+
if constexpr (std::is_same_v<BLOC_TYPE, block_q6_K> || std::is_same_v<BLOC_TYPE, block_q8_0> ||
|
|
109
|
+
std::is_same_v<BLOC_TYPE, block_q3_K> || std::is_same_v<BLOC_TYPE, block_q4_0> ||
|
|
110
|
+
std::is_same_v<BLOC_TYPE, block_mxfp4> || std::is_same_v<BLOC_TYPE, block_q5_0>) {
|
|
111
|
+
return false;
|
|
112
|
+
} else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_1> || std::is_same_v<BLOC_TYPE, block_q4_K> ||
|
|
113
|
+
std::is_same_v<BLOC_TYPE, block_q2_K> || std::is_same_v<BLOC_TYPE, block_q5_1> ||
|
|
114
|
+
std::is_same_v<BLOC_TYPE, block_q5_K>) {
|
|
115
|
+
return true;
|
|
116
|
+
} else {
|
|
117
|
+
assert(false);
|
|
118
|
+
return false;
|
|
119
|
+
}
|
|
120
|
+
}
|
|
99
121
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
122
|
+
class tensor_traits_base : public ggml::cpu::tensor_traits {
|
|
123
|
+
public:
|
|
124
|
+
virtual int repack(ggml_tensor * t, const void * data, size_t data_size) = 0;
|
|
125
|
+
};
|
|
104
126
|
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
127
|
+
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
|
|
128
|
+
bool work_size(int /* n_threads */, const ggml_tensor * op, size_t & size) override {
|
|
129
|
+
switch (op->op) {
|
|
130
|
+
case GGML_OP_MUL_MAT:
|
|
131
|
+
{
|
|
132
|
+
int64_t src1_nelements = ggml_nelements(op->src[1]);
|
|
133
|
+
|
|
134
|
+
if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K> || std::is_same_v<BLOC_TYPE, block_q3_K>) {
|
|
135
|
+
size =
|
|
136
|
+
spacemit_kernels::div_round_up(src1_nelements, QK_K) * spacemit_kernels::q8k_blk_size(QK_K);
|
|
137
|
+
} else if constexpr (INTER_SIZE == QK4_0) {
|
|
138
|
+
size = spacemit_kernels::div_round_up(src1_nelements, QK4_0) *
|
|
139
|
+
spacemit_kernels::q8_blk_size(QK4_0, true);
|
|
140
|
+
} else if constexpr (INTER_SIZE == 256) {
|
|
141
|
+
size = spacemit_kernels::div_round_up(src1_nelements, 256) *
|
|
142
|
+
spacemit_kernels::q8_hp_blk_size(256, true, true);
|
|
143
|
+
} else {
|
|
144
|
+
GGML_ABORT("unsupported block type");
|
|
145
|
+
}
|
|
108
146
|
|
|
109
|
-
|
|
147
|
+
size = GGML_PAD(size, sizeof(int64_t));
|
|
110
148
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
149
|
+
return true;
|
|
150
|
+
}
|
|
151
|
+
case GGML_OP_MUL_MAT_ID:
|
|
152
|
+
{
|
|
153
|
+
int64_t src1_nelements = ggml_nelements(op->src[1]);
|
|
154
|
+
|
|
155
|
+
if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K> || std::is_same_v<BLOC_TYPE, block_q3_K>) {
|
|
156
|
+
size =
|
|
157
|
+
spacemit_kernels::div_round_up(src1_nelements, QK_K) * spacemit_kernels::q8k_blk_size(QK_K);
|
|
158
|
+
} else if constexpr (INTER_SIZE == QK4_0) {
|
|
159
|
+
size = spacemit_kernels::div_round_up(src1_nelements, QK4_0) *
|
|
160
|
+
spacemit_kernels::q8_blk_size(QK4_0, true);
|
|
161
|
+
} else if constexpr (INTER_SIZE == 256) {
|
|
162
|
+
size = spacemit_kernels::div_round_up(src1_nelements, 256) *
|
|
163
|
+
spacemit_kernels::q8_hp_blk_size(256, true, true);
|
|
164
|
+
} else {
|
|
165
|
+
GGML_ABORT("unsupported block type");
|
|
166
|
+
}
|
|
115
167
|
|
|
116
|
-
|
|
117
|
-
const std::byte * b_col = packed_quant_b_data + n * packed_b_stride;
|
|
118
|
-
const std::byte * b_col_zp = (zero_point_stride != 0) ? b_col : nullptr;
|
|
119
|
-
float * c_blk = c_ptr + n;
|
|
168
|
+
size = GGML_PAD(size, sizeof(int64_t));
|
|
120
169
|
|
|
121
|
-
|
|
170
|
+
const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert
|
|
171
|
+
const int64_t ne12 = op->src[1]->ne[2]; // n_tokens
|
|
122
172
|
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
blk_len, a_row, b_col, nullptr, b_col_zp, c_blk, rows_remaining, count_n, gemm_k, k_blks, ldc, nullptr,
|
|
126
|
-
scale_stride);
|
|
173
|
+
const size_t sizeof_mmid_row_mapping = sizeof(int64_t);
|
|
174
|
+
size += sizeof_mmid_row_mapping * ne02 * (ne12 + 1) + (ne02 + 1) * sizeof(int64_t);
|
|
127
175
|
|
|
128
|
-
|
|
129
|
-
a_row += rows_handled * lda;
|
|
176
|
+
size = GGML_PAD(size, sizeof(int64_t));
|
|
130
177
|
|
|
131
|
-
|
|
178
|
+
return true;
|
|
179
|
+
}
|
|
180
|
+
default:
|
|
181
|
+
// GGML_ABORT("fatal error");
|
|
182
|
+
break;
|
|
132
183
|
}
|
|
184
|
+
return false;
|
|
133
185
|
}
|
|
134
|
-
}
|
|
135
186
|
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
187
|
+
bool compute_forward(ggml_compute_params * params, ggml_tensor * op) override {
|
|
188
|
+
switch (op->op) {
|
|
189
|
+
case GGML_OP_MUL_MAT:
|
|
190
|
+
switch (op->src[0]->type) {
|
|
191
|
+
case GGML_TYPE_Q2_K:
|
|
192
|
+
case GGML_TYPE_Q3_K:
|
|
193
|
+
case GGML_TYPE_Q4_0:
|
|
194
|
+
case GGML_TYPE_Q4_1:
|
|
195
|
+
case GGML_TYPE_Q4_K:
|
|
196
|
+
case GGML_TYPE_Q6_K:
|
|
197
|
+
case GGML_TYPE_Q8_0:
|
|
198
|
+
case GGML_TYPE_Q5_1:
|
|
199
|
+
case GGML_TYPE_Q5_K:
|
|
200
|
+
//case GGML_TYPE_MXFP4:
|
|
201
|
+
forward_mul_mat(params, op);
|
|
202
|
+
return true;
|
|
203
|
+
default:
|
|
204
|
+
// GGML_ABORT("fatal error: unsupported type for src0 in MUL_MAT");
|
|
205
|
+
return false;
|
|
206
|
+
}
|
|
207
|
+
break;
|
|
208
|
+
case GGML_OP_MUL_MAT_ID:
|
|
209
|
+
switch (op->src[0]->type) {
|
|
210
|
+
case GGML_TYPE_Q2_K:
|
|
211
|
+
case GGML_TYPE_Q3_K:
|
|
212
|
+
case GGML_TYPE_Q4_0:
|
|
213
|
+
case GGML_TYPE_Q4_1:
|
|
214
|
+
case GGML_TYPE_Q4_K:
|
|
215
|
+
case GGML_TYPE_Q6_K:
|
|
216
|
+
case GGML_TYPE_Q8_0:
|
|
217
|
+
case GGML_TYPE_Q5_1:
|
|
218
|
+
case GGML_TYPE_Q5_K:
|
|
219
|
+
//case GGML_TYPE_MXFP4:
|
|
220
|
+
forward_mul_mat_id(params, op);
|
|
221
|
+
return true;
|
|
222
|
+
default:
|
|
223
|
+
// GGML_ABORT("fatal error: unsupported type for src0 in MUL_MAT_ID");
|
|
224
|
+
return false;
|
|
225
|
+
}
|
|
226
|
+
break;
|
|
227
|
+
default:
|
|
228
|
+
// GGML_ABORT("fatal error");
|
|
229
|
+
break;
|
|
230
|
+
}
|
|
231
|
+
return false;
|
|
142
232
|
}
|
|
143
|
-
return -1;
|
|
144
|
-
}
|
|
145
233
|
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
};
|
|
234
|
+
void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
|
|
235
|
+
constexpr size_t a_blk_len = INTER_SIZE;
|
|
236
|
+
constexpr size_t b_blk_len = INTER_SIZE;
|
|
150
237
|
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
uint8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_1 blocks
|
|
155
|
-
};
|
|
238
|
+
const ggml_tensor * src0 = op->src[0];
|
|
239
|
+
const ggml_tensor * src1 = op->src[1];
|
|
240
|
+
ggml_tensor * dst = op;
|
|
156
241
|
|
|
157
|
-
|
|
158
|
-
static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding");
|
|
159
|
-
static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t),
|
|
160
|
-
"wrong block_with_zp<4,16> size/padding");
|
|
161
|
-
static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding");
|
|
242
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
|
162
243
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
using block_q8_0x16 = block<8, 16>;
|
|
244
|
+
int ith = params->ith;
|
|
245
|
+
int nth = params->nth;
|
|
166
246
|
|
|
167
|
-
|
|
168
|
-
block_q4_0x16 out;
|
|
169
|
-
GGML_ASSERT(QK4_0 / blck_size_interleave == 2);
|
|
247
|
+
[[maybe_unused]] const enum ggml_type type = src0->type;
|
|
170
248
|
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
249
|
+
void * w_data = (void *) src0->data;
|
|
250
|
+
const float * feature = (const float *) src1->data;
|
|
251
|
+
float * output = (float *) dst->data;
|
|
174
252
|
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
253
|
+
const int64_t gemm_m = ne11 * ne12 * ne13;
|
|
254
|
+
const int64_t gemm_k = ne10;
|
|
255
|
+
const int64_t gemm_n = ne01;
|
|
256
|
+
|
|
257
|
+
spacemit_kernels::quantize_a_row_def quantize_a_row_i8;
|
|
258
|
+
spacemit_kernels::quantize_a_row_def quantize_a_4row_i8;
|
|
259
|
+
spacemit_kernels::gemm_kernel_quantize_def gemm_kernel;
|
|
260
|
+
bool set_kernel_impl = false;
|
|
261
|
+
|
|
262
|
+
int64_t block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len);
|
|
263
|
+
|
|
264
|
+
#if defined(RISCV64_SPACEMIT_IME2)
|
|
265
|
+
if (!set_kernel_impl && (global_spine_env_info.use_ime2)) {
|
|
266
|
+
quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8;
|
|
267
|
+
quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8;
|
|
268
|
+
block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true);
|
|
269
|
+
|
|
270
|
+
if constexpr (std::is_same_v<BLOC_TYPE, block_q6_K> || std::is_same_v<BLOC_TYPE, block_q8_0>) {
|
|
271
|
+
gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i8;
|
|
272
|
+
set_kernel_impl = true;
|
|
273
|
+
} else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0> || std::is_same_v<BLOC_TYPE, block_q4_1> ||
|
|
274
|
+
std::is_same_v<BLOC_TYPE, block_q4_K>) {
|
|
275
|
+
if constexpr (INTER_SIZE == 256) {
|
|
276
|
+
gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4_hp;
|
|
277
|
+
quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8_hp;
|
|
278
|
+
quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8_hp;
|
|
279
|
+
block_stride_a = spacemit_kernels::q8_hp_blk_size(a_blk_len, true, true);
|
|
280
|
+
set_kernel_impl = true;
|
|
281
|
+
} else {
|
|
282
|
+
gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4;
|
|
283
|
+
quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8;
|
|
284
|
+
quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8;
|
|
285
|
+
block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true);
|
|
286
|
+
set_kernel_impl = true;
|
|
287
|
+
}
|
|
288
|
+
} else if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K>) {
|
|
289
|
+
quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k;
|
|
290
|
+
quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8k;
|
|
291
|
+
block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len);
|
|
292
|
+
|
|
293
|
+
gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i2k;
|
|
294
|
+
set_kernel_impl = true;
|
|
295
|
+
} else if constexpr (std::is_same_v<BLOC_TYPE, block_q3_K>) {
|
|
296
|
+
quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k;
|
|
297
|
+
quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8k;
|
|
298
|
+
block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len);
|
|
299
|
+
|
|
300
|
+
gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i3k;
|
|
301
|
+
set_kernel_impl = true;
|
|
302
|
+
} else if constexpr (std::is_same_v<BLOC_TYPE, block_mxfp4>) {
|
|
303
|
+
gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8mxfp4;
|
|
304
|
+
set_kernel_impl = true;
|
|
305
|
+
} else if constexpr (std::is_same_v<BLOC_TYPE, block_q5_1> || std::is_same_v<BLOC_TYPE, block_q5_K> ||
|
|
306
|
+
std::is_same_v<BLOC_TYPE, block_q5_0>) {
|
|
307
|
+
gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i5;
|
|
308
|
+
set_kernel_impl = true;
|
|
309
|
+
}
|
|
181
310
|
}
|
|
182
|
-
|
|
311
|
+
#endif
|
|
183
312
|
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
313
|
+
#if defined(RISCV64_SPACEMIT_IME1)
|
|
314
|
+
if (!set_kernel_impl && (global_spine_env_info.use_ime1)) {
|
|
315
|
+
quantize_a_row_i8 = spacemit_kernels::ime1::quantize_a_row_i8;
|
|
316
|
+
quantize_a_4row_i8 = spacemit_kernels::ime1::quantize_a_4row_i8;
|
|
317
|
+
|
|
318
|
+
if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0> || std::is_same_v<BLOC_TYPE, block_q4_1> ||
|
|
319
|
+
std::is_same_v<BLOC_TYPE, block_q4_K>) {
|
|
320
|
+
gemm_kernel = spacemit_kernels::ime1::gemm_kernel_i8i4;
|
|
321
|
+
set_kernel_impl = true;
|
|
322
|
+
}
|
|
323
|
+
}
|
|
324
|
+
#endif
|
|
325
|
+
if (!set_kernel_impl) {
|
|
326
|
+
GGML_ABORT("no kernel implementation found for the block type");
|
|
190
327
|
}
|
|
191
|
-
}
|
|
192
328
|
|
|
193
|
-
|
|
194
|
-
|
|
329
|
+
const int64_t a_k_blks = spacemit_kernels::div_round_up(gemm_k, a_blk_len);
|
|
330
|
+
const int64_t b_k_blks = spacemit_kernels::div_round_up(gemm_k, b_blk_len);
|
|
195
331
|
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
GGML_ASSERT(QK4_1 / blck_size_interleave == 2);
|
|
199
|
-
|
|
200
|
-
for (int i = 0; i < 16; i++) {
|
|
201
|
-
float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d);
|
|
202
|
-
float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m);
|
|
203
|
-
float mid = -std::nearbyintf(m / d);
|
|
204
|
-
mid = std::min(15.0f, std::max(0.0f, mid));
|
|
205
|
-
out.d[i] = GGML_FP32_TO_FP16(d);
|
|
206
|
-
out.zp[i] = static_cast<uint8_t>(mid);
|
|
207
|
-
}
|
|
332
|
+
const int64_t row_stride_a = a_k_blks * block_stride_a;
|
|
333
|
+
const int64_t gemm_workspace_size = GGML_PAD(gemm_m * row_stride_a, alignof(int64_t));
|
|
208
334
|
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
for (int j = 0; j < QK4_1 / 4; j++) {
|
|
212
|
-
//src [b0 b16] ......... [b8 b24] ......... [b15 b31]
|
|
213
|
-
//dst [b0 b8] ......... [b7 b15]
|
|
214
|
-
out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4);
|
|
335
|
+
if (ith == 0 && params->wsize < gemm_workspace_size) {
|
|
336
|
+
GGML_ABORT("wsize less than gemm_workspace_size");
|
|
215
337
|
}
|
|
216
|
-
}
|
|
217
338
|
|
|
218
|
-
|
|
219
|
-
// [16, 31], in.d & 0xF0
|
|
220
|
-
for (int j = 0; j < QK4_1 / 4; j++) {
|
|
221
|
-
//src [b0 b16] ......... [b8 b24] ......... [b15 b31]
|
|
222
|
-
//dst [b16 b24] ......... [b23 b31]
|
|
223
|
-
out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0);
|
|
224
|
-
}
|
|
225
|
-
}
|
|
339
|
+
uintptr_t ws_ptr = reinterpret_cast<uintptr_t>(params->wdata);
|
|
226
340
|
|
|
227
|
-
|
|
228
|
-
|
|
341
|
+
void * tcm_buffer = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer;
|
|
342
|
+
const int64_t tcm_buffer_size = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size;
|
|
229
343
|
|
|
230
|
-
|
|
231
|
-
int interleave_block,
|
|
232
|
-
const void * GGML_RESTRICT data,
|
|
233
|
-
size_t data_size) {
|
|
234
|
-
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
|
|
235
|
-
GGML_ASSERT(interleave_block == 16);
|
|
344
|
+
auto * quant_a_buffer = reinterpret_cast<uint8_t *>(ws_ptr);
|
|
236
345
|
|
|
237
|
-
|
|
346
|
+
constexpr int64_t row_align = 4;
|
|
347
|
+
const int64_t row_blks = spacemit_kernels::div_round_up(gemm_m, row_align);
|
|
238
348
|
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
int nrow = ggml_nrows(t);
|
|
243
|
-
int nblocks = t->ne[0] / QK4_0;
|
|
349
|
+
const int64_t row_stride_b = b_k_blks * get_repacked_block_type_size<BLOC_TYPE, INTER_SIZE, NB_COLS>();
|
|
350
|
+
const int64_t per_mb_rows_wsize = row_align * row_stride_a;
|
|
351
|
+
const int64_t per_nb_cols_wsize = NB_COLS * row_stride_b;
|
|
244
352
|
|
|
245
|
-
|
|
353
|
+
const int64_t barrier_idx = static_cast<int64_t>(ith / 2);
|
|
246
354
|
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
355
|
+
GGML_ASSERT(global_spine_env_info.init_barrier != nullptr);
|
|
356
|
+
GGML_ASSERT(barrier_idx < spine_init_barrier_count);
|
|
357
|
+
spine_barrier_t * cur_barrier = &global_spine_env_info.init_barrier[barrier_idx];
|
|
250
358
|
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
359
|
+
if (gemm_m == 1) {
|
|
360
|
+
int task_per_thread = spacemit_kernels::div_round_up(a_k_blks, nth);
|
|
361
|
+
int a_blk_start = ith * task_per_thread;
|
|
362
|
+
int a_blk_end = std::min(a_blk_start + task_per_thread, (int) a_k_blks);
|
|
363
|
+
if (a_blk_start < a_blk_end) {
|
|
364
|
+
quantize_a_row_i8(a_blk_len, feature + a_blk_start * a_blk_len, (a_blk_end - a_blk_start) * a_blk_len,
|
|
365
|
+
quant_a_buffer + a_blk_start * block_stride_a);
|
|
366
|
+
}
|
|
367
|
+
} else {
|
|
368
|
+
int task_per_thread = spacemit_kernels::div_round_up(row_blks, nth);
|
|
369
|
+
int m_row_blk_start = ith * task_per_thread;
|
|
370
|
+
int m_row_blk_end = std::min(m_row_blk_start + task_per_thread, (int) row_blks);
|
|
371
|
+
for (int m_row_blk = m_row_blk_start; m_row_blk < m_row_blk_end; m_row_blk++) {
|
|
372
|
+
int m_idx = m_row_blk * row_align;
|
|
373
|
+
int rows_tobe_handled = (gemm_m - m_idx) > row_align ? row_align : (gemm_m - m_idx);
|
|
374
|
+
|
|
375
|
+
if (rows_tobe_handled == row_align && quantize_a_4row_i8 != nullptr) {
|
|
376
|
+
const float * a_row_ptr = feature + m_idx * gemm_k;
|
|
377
|
+
auto * quant_a_row_ptr = quant_a_buffer + m_idx * row_stride_a;
|
|
378
|
+
quantize_a_4row_i8(a_blk_len, a_row_ptr, gemm_k, quant_a_row_ptr);
|
|
379
|
+
} else {
|
|
380
|
+
while (rows_tobe_handled) {
|
|
381
|
+
const float * a_row_ptr = feature + m_idx * gemm_k;
|
|
382
|
+
auto * quant_a_row_ptr = quant_a_buffer + m_idx * row_stride_a;
|
|
383
|
+
quantize_a_row_i8(a_blk_len, a_row_ptr, gemm_k, quant_a_row_ptr);
|
|
384
|
+
rows_tobe_handled -= 1;
|
|
385
|
+
m_idx += 1;
|
|
386
|
+
}
|
|
387
|
+
}
|
|
255
388
|
}
|
|
256
|
-
*dst++ = make_block_q4_0x16(dst_tmp, interleave_block);
|
|
257
389
|
}
|
|
258
|
-
src += nrows_interleaved * nblocks;
|
|
259
|
-
}
|
|
260
|
-
return 0;
|
|
261
390
|
|
|
262
|
-
|
|
263
|
-
}
|
|
391
|
+
ggml_barrier(params->threadpool);
|
|
264
392
|
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
size_t data_size) {
|
|
269
|
-
GGML_ASSERT(t->type == GGML_TYPE_Q4_1);
|
|
270
|
-
GGML_ASSERT(interleave_block == 16);
|
|
393
|
+
const int64_t gemm_m_stride = gemm_n / gemm_m > 64 ? gemm_m : 16;
|
|
394
|
+
const int64_t gemm_m_blocked = spacemit_kernels::div_round_up(gemm_m, gemm_m_stride);
|
|
395
|
+
const int64_t max_gemm_n_stride = spacemit_kernels::div_round_up(gemm_n * gemm_m_blocked, nth);
|
|
271
396
|
|
|
272
|
-
|
|
397
|
+
int64_t gemm_n_stride = gemm_n;
|
|
398
|
+
if (max_gemm_n_stride < gemm_n) {
|
|
399
|
+
gemm_n_stride =
|
|
400
|
+
std::min(gemm_n_stride, spacemit_kernels::div_round_up(max_gemm_n_stride, NB_COLS) * NB_COLS);
|
|
401
|
+
}
|
|
273
402
|
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
int nblocks = t->ne[0] / QK4_1;
|
|
403
|
+
if (gemm_n_stride == gemm_n && tcm_buffer != nullptr && per_mb_rows_wsize <= tcm_buffer_size) {
|
|
404
|
+
for (int64_t m_start = ith * row_align; m_start < gemm_m; m_start += row_align * nth) {
|
|
405
|
+
uint8_t * b_col = reinterpret_cast<uint8_t *>(w_data);
|
|
406
|
+
uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? b_col : nullptr;
|
|
279
407
|
|
|
280
|
-
|
|
408
|
+
int64_t m_row_real = std::min(gemm_m - m_start, row_align);
|
|
281
409
|
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
}
|
|
410
|
+
spacemit_kernels::rvv::memcpy1d(tcm_buffer, quant_a_buffer + m_start * row_stride_a,
|
|
411
|
+
m_row_real * row_stride_a);
|
|
285
412
|
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
413
|
+
int64_t n_blk_real = 0;
|
|
414
|
+
for (int64_t ni = 0; ni < gemm_n; ni += n_blk_real, b_col += n_blk_real * row_stride_b) {
|
|
415
|
+
n_blk_real = std::min(gemm_n - ni, (int64_t) NB_COLS);
|
|
416
|
+
|
|
417
|
+
uint8_t * a_row_ptr = (uint8_t *) tcm_buffer;
|
|
418
|
+
float * c_blk = output + m_start * gemm_n + ni;
|
|
419
|
+
|
|
420
|
+
int32_t rows_remaining = m_row_real;
|
|
421
|
+
|
|
422
|
+
while (rows_remaining > 0) {
|
|
423
|
+
auto rows_handled = gemm_kernel(b_blk_len, a_row_ptr, b_col, b_col_zp, c_blk, rows_remaining,
|
|
424
|
+
n_blk_real, b_k_blks, gemm_n);
|
|
425
|
+
|
|
426
|
+
c_blk += rows_handled * gemm_n;
|
|
427
|
+
a_row_ptr += rows_handled * row_stride_a;
|
|
428
|
+
|
|
429
|
+
rows_remaining -= rows_handled;
|
|
430
|
+
}
|
|
431
|
+
}
|
|
290
432
|
}
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
433
|
+
} else if (tcm_buffer != nullptr && per_nb_cols_wsize <= tcm_buffer_size) {
|
|
434
|
+
uint8_t * a_row = quant_a_buffer;
|
|
435
|
+
uint8_t * b_col = reinterpret_cast<uint8_t *>(tcm_buffer);
|
|
436
|
+
if ((gemm_workspace_size + per_nb_cols_wsize) <= tcm_buffer_size) {
|
|
437
|
+
a_row = (uint8_t *) tcm_buffer;
|
|
438
|
+
b_col = reinterpret_cast<uint8_t *>(tcm_buffer) + gemm_workspace_size;
|
|
439
|
+
}
|
|
440
|
+
uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? b_col : nullptr;
|
|
296
441
|
|
|
297
|
-
|
|
298
|
-
|
|
442
|
+
int64_t ni = ith * NB_COLS;
|
|
443
|
+
int64_t nb_real = std::min(gemm_n - ni, NB_COLS);
|
|
299
444
|
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
} else {
|
|
308
|
-
*d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
|
|
309
|
-
*m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
|
|
310
|
-
}
|
|
311
|
-
}
|
|
445
|
+
if (ith % 2 == 0 && nb_real > 0) {
|
|
446
|
+
spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(w_data) + ni * row_stride_b,
|
|
447
|
+
nb_real * row_stride_b);
|
|
448
|
+
if (a_row != quant_a_buffer) {
|
|
449
|
+
spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size);
|
|
450
|
+
}
|
|
451
|
+
}
|
|
312
452
|
|
|
313
|
-
|
|
314
|
-
int interleave_block,
|
|
315
|
-
const void * GGML_RESTRICT data,
|
|
316
|
-
size_t data_size) {
|
|
317
|
-
GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
|
|
318
|
-
GGML_ASSERT(interleave_block == 16);
|
|
319
|
-
GGML_ASSERT(QK_K / QK4_1 == 8);
|
|
453
|
+
spine_barrier_wait(cur_barrier);
|
|
320
454
|
|
|
321
|
-
|
|
455
|
+
if (ith % 2 != 0 && nb_real > 0) {
|
|
456
|
+
if (a_row != quant_a_buffer) {
|
|
457
|
+
spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size);
|
|
458
|
+
}
|
|
459
|
+
spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(w_data) + ni * row_stride_b,
|
|
460
|
+
nb_real * row_stride_b);
|
|
461
|
+
}
|
|
322
462
|
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
int nblocks = t->ne[0] / QK_K;
|
|
463
|
+
for (; ni < gemm_n; ni += NB_COLS * nth) {
|
|
464
|
+
int64_t rows_remaining = gemm_m;
|
|
465
|
+
float * c_blk = output + ni;
|
|
466
|
+
auto * a_row_cur = a_row;
|
|
328
467
|
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
468
|
+
if (ith % 2 != 0) {
|
|
469
|
+
spine_barrier_wait(cur_barrier);
|
|
470
|
+
}
|
|
332
471
|
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4);
|
|
353
|
-
}
|
|
354
|
-
} else {
|
|
355
|
-
for (int ii = 0; ii < 16; ii++) {
|
|
356
|
-
dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0);
|
|
357
|
-
}
|
|
358
|
-
}
|
|
472
|
+
while (rows_remaining > 0) {
|
|
473
|
+
auto rows_handled = gemm_kernel(b_blk_len, a_row_cur, b_col, b_col_zp, c_blk, rows_remaining,
|
|
474
|
+
nb_real, b_k_blks, gemm_n);
|
|
475
|
+
|
|
476
|
+
c_blk += rows_handled * gemm_n;
|
|
477
|
+
a_row_cur += rows_handled * row_stride_a;
|
|
478
|
+
|
|
479
|
+
rows_remaining -= rows_handled;
|
|
480
|
+
}
|
|
481
|
+
|
|
482
|
+
if (ith % 2 == 0) {
|
|
483
|
+
spine_barrier_wait(cur_barrier);
|
|
484
|
+
}
|
|
485
|
+
|
|
486
|
+
const int64_t next_ni = ni + NB_COLS * nth;
|
|
487
|
+
if (next_ni < gemm_n) {
|
|
488
|
+
nb_real = std::min(gemm_n - next_ni, NB_COLS);
|
|
489
|
+
spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(w_data) + next_ni * row_stride_b,
|
|
490
|
+
nb_real * row_stride_b);
|
|
359
491
|
}
|
|
360
|
-
*dst++ = make_block_q4_1x16(dst_tmp, interleave_block);
|
|
361
492
|
}
|
|
362
|
-
}
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
return 0;
|
|
493
|
+
} else {
|
|
494
|
+
const int64_t task_count_m = spacemit_kernels::div_round_up(gemm_m, gemm_m_stride);
|
|
495
|
+
const int64_t task_count_n = spacemit_kernels::div_round_up(gemm_n, gemm_n_stride);
|
|
366
496
|
|
|
367
|
-
|
|
368
|
-
|
|
497
|
+
int64_t task_count = task_count_m * task_count_n;
|
|
498
|
+
int64_t task_per_thread = (task_count + nth - 1) / nth;
|
|
499
|
+
int64_t start = ith * task_per_thread;
|
|
500
|
+
int64_t end = std::min((ith + 1) * task_per_thread, task_count);
|
|
501
|
+
for (int64_t compute_idx = start; compute_idx < end; compute_idx++) {
|
|
502
|
+
const auto tid_n = compute_idx / task_count_m;
|
|
503
|
+
const auto tid_m = compute_idx % task_count_m;
|
|
369
504
|
|
|
370
|
-
|
|
505
|
+
const int64_t m_start = tid_m * gemm_m_stride;
|
|
506
|
+
const int64_t m_count = std::min(gemm_m - m_start, (int64_t) gemm_m_stride);
|
|
371
507
|
|
|
372
|
-
|
|
373
|
-
|
|
508
|
+
const int64_t n_start = tid_n * gemm_n_stride;
|
|
509
|
+
const int64_t n_count = std::min(gemm_n - n_start, (int64_t) gemm_n_stride);
|
|
374
510
|
|
|
375
|
-
|
|
376
|
-
return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size);
|
|
377
|
-
}
|
|
511
|
+
const int64_t n_blk = m_count == 1 ? n_count : NB_COLS;
|
|
378
512
|
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
}
|
|
513
|
+
uint8_t * b_col = reinterpret_cast<uint8_t *>(w_data) + n_start * row_stride_b;
|
|
514
|
+
uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? b_col : nullptr;
|
|
382
515
|
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
516
|
+
int64_t n_blk_real = 0;
|
|
517
|
+
for (int64_t ni = 0; ni < n_count; ni += n_blk_real, b_col += n_blk_real * row_stride_b) {
|
|
518
|
+
n_blk_real = std::min(n_count - ni, n_blk);
|
|
386
519
|
|
|
387
|
-
|
|
388
|
-
public:
|
|
389
|
-
virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
|
|
390
|
-
};
|
|
520
|
+
uint8_t * a_row = quant_a_buffer + m_start * row_stride_a;
|
|
391
521
|
|
|
392
|
-
|
|
393
|
-
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
|
|
394
|
-
switch (op->op) {
|
|
395
|
-
case GGML_OP_MUL_MAT:
|
|
396
|
-
size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])) * 4;
|
|
397
|
-
size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float));
|
|
398
|
-
return true;
|
|
399
|
-
default:
|
|
400
|
-
// GGML_ABORT("fatal error");
|
|
401
|
-
break;
|
|
402
|
-
}
|
|
403
|
-
return false;
|
|
404
|
-
}
|
|
522
|
+
float * c_blk = output + m_start * gemm_n + n_start + ni;
|
|
405
523
|
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
524
|
+
int64_t rows_remaining = m_count;
|
|
525
|
+
|
|
526
|
+
uint8_t * b_col_cur = b_col;
|
|
527
|
+
uint8_t * b_col_zp_cur = b_col_zp;
|
|
528
|
+
|
|
529
|
+
while (rows_remaining > 0) {
|
|
530
|
+
auto rows_handled = gemm_kernel(b_blk_len, a_row, b_col_cur, b_col_zp_cur, c_blk,
|
|
531
|
+
rows_remaining, n_blk_real, b_k_blks, gemm_n);
|
|
532
|
+
|
|
533
|
+
c_blk += rows_handled * gemm_n;
|
|
534
|
+
a_row += rows_handled * row_stride_a;
|
|
535
|
+
|
|
536
|
+
rows_remaining -= rows_handled;
|
|
537
|
+
}
|
|
414
538
|
}
|
|
415
|
-
|
|
416
|
-
// GGML_ABORT("fatal error");
|
|
417
|
-
break;
|
|
539
|
+
}
|
|
418
540
|
}
|
|
419
|
-
return false;
|
|
420
541
|
}
|
|
421
542
|
|
|
422
|
-
void
|
|
543
|
+
void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) {
|
|
544
|
+
constexpr size_t a_blk_len = INTER_SIZE;
|
|
545
|
+
constexpr size_t b_blk_len = INTER_SIZE;
|
|
546
|
+
|
|
423
547
|
const ggml_tensor * src0 = op->src[0];
|
|
424
548
|
const ggml_tensor * src1 = op->src[1];
|
|
549
|
+
const ggml_tensor * ids = op->src[2];
|
|
425
550
|
ggml_tensor * dst = op;
|
|
426
551
|
|
|
427
552
|
GGML_TENSOR_BINARY_OP_LOCALS
|
|
@@ -429,133 +554,381 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|
|
429
554
|
int ith = params->ith;
|
|
430
555
|
int nth = params->nth;
|
|
431
556
|
|
|
432
|
-
|
|
557
|
+
// row groups
|
|
558
|
+
const int n_ids = ids->ne[0]; // n_expert_used
|
|
559
|
+
const int n_as = ne02; // n_expert
|
|
560
|
+
|
|
561
|
+
struct mmid_row_mapping {
|
|
562
|
+
int32_t i1;
|
|
563
|
+
int32_t i2;
|
|
564
|
+
};
|
|
565
|
+
|
|
566
|
+
spacemit_kernels::quantize_a_row_def quantize_a_row_i8;
|
|
567
|
+
spacemit_kernels::gemm_kernel_quantize_def gemm_kernel;
|
|
568
|
+
spacemit_kernels::moe_gemm_kernel_quantize_def moe_gemm_kernel_m2;
|
|
569
|
+
bool set_kernel_impl = false;
|
|
570
|
+
size_t block_stride_a = spacemit_kernels::q8_blk_size(QK4_0);
|
|
571
|
+
|
|
572
|
+
#if defined(RISCV64_SPACEMIT_IME2)
|
|
573
|
+
if (!set_kernel_impl && (global_spine_env_info.use_ime2)) {
|
|
574
|
+
quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8;
|
|
575
|
+
block_stride_a = spacemit_kernels::q8_blk_size(QK4_0, true);
|
|
576
|
+
|
|
577
|
+
if constexpr (std::is_same_v<BLOC_TYPE, block_q6_K> || std::is_same_v<BLOC_TYPE, block_q8_0>) {
|
|
578
|
+
gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i8;
|
|
579
|
+
set_kernel_impl = true;
|
|
580
|
+
} else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0> || std::is_same_v<BLOC_TYPE, block_q4_1> ||
|
|
581
|
+
std::is_same_v<BLOC_TYPE, block_q4_K>) {
|
|
582
|
+
if constexpr (INTER_SIZE == 256) {
|
|
583
|
+
gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4_hp;
|
|
584
|
+
quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8_hp;
|
|
585
|
+
block_stride_a = spacemit_kernels::q8_hp_blk_size(a_blk_len, true, true);
|
|
586
|
+
set_kernel_impl = true;
|
|
587
|
+
} else {
|
|
588
|
+
gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4;
|
|
589
|
+
moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8i4;
|
|
590
|
+
quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8;
|
|
591
|
+
block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true);
|
|
592
|
+
set_kernel_impl = true;
|
|
593
|
+
}
|
|
594
|
+
} else if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K>) {
|
|
595
|
+
quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k;
|
|
596
|
+
block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len);
|
|
597
|
+
gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i2k;
|
|
598
|
+
set_kernel_impl = true;
|
|
599
|
+
} else if constexpr (std::is_same_v<BLOC_TYPE, block_q3_K>) {
|
|
600
|
+
quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k;
|
|
601
|
+
block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len);
|
|
602
|
+
gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i3k;
|
|
603
|
+
set_kernel_impl = true;
|
|
604
|
+
} else if constexpr (std::is_same_v<BLOC_TYPE, block_mxfp4>) {
|
|
605
|
+
gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8mxfp4;
|
|
606
|
+
moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8mxfp4;
|
|
607
|
+
set_kernel_impl = true;
|
|
608
|
+
} else if constexpr (std::is_same_v<BLOC_TYPE, block_q5_1> || std::is_same_v<BLOC_TYPE, block_q5_K> ||
|
|
609
|
+
std::is_same_v<BLOC_TYPE, block_q5_0>) {
|
|
610
|
+
gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i5;
|
|
611
|
+
moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8i5;
|
|
612
|
+
set_kernel_impl = true;
|
|
613
|
+
}
|
|
614
|
+
}
|
|
615
|
+
#endif
|
|
433
616
|
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
617
|
+
#if defined(RISCV64_SPACEMIT_IME1)
|
|
618
|
+
if (!set_kernel_impl && (global_spine_env_info.use_ime1)) {
|
|
619
|
+
quantize_a_row_i8 = spacemit_kernels::ime1::quantize_a_row_i8;
|
|
620
|
+
|
|
621
|
+
if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0> || std::is_same_v<BLOC_TYPE, block_q4_1> ||
|
|
622
|
+
std::is_same_v<BLOC_TYPE, block_q4_K>) {
|
|
623
|
+
gemm_kernel = spacemit_kernels::ime1::gemm_kernel_i8i4;
|
|
624
|
+
set_kernel_impl = true;
|
|
625
|
+
}
|
|
626
|
+
}
|
|
627
|
+
#endif
|
|
628
|
+
if (!set_kernel_impl) {
|
|
629
|
+
GGML_ABORT("no kernel implementation found for the block type");
|
|
630
|
+
}
|
|
437
631
|
|
|
438
|
-
const size_t
|
|
439
|
-
|
|
440
|
-
const size_t gemm_m = ne11;
|
|
441
|
-
const size_t gemm_k = ne10;
|
|
442
|
-
const size_t gemm_n = ne01;
|
|
632
|
+
const size_t a_k_blks = spacemit_kernels::div_round_up(ne10, a_blk_len);
|
|
633
|
+
const size_t b_k_blks = spacemit_kernels::div_round_up(ne10, b_blk_len);
|
|
443
634
|
|
|
444
|
-
|
|
635
|
+
const size_t nbw1 = a_k_blks * block_stride_a;
|
|
636
|
+
const size_t nbw2 = ne11 * nbw1;
|
|
637
|
+
const size_t nbw3 = nbw2 * ne12;
|
|
638
|
+
const size_t gemm_workspace_size = GGML_PAD(nbw3, alignof(int64_t));
|
|
445
639
|
|
|
446
|
-
const
|
|
447
|
-
|
|
448
|
-
const size_t per_gemm_workspace_stride =
|
|
449
|
-
div_round_up(per_gemm_workspace_size, alignof(uint64_t)) * alignof(uint64_t);
|
|
450
|
-
const size_t gemm_workspace_size = batch_feature * per_gemm_workspace_stride;
|
|
451
|
-
const size_t desired_wsize = gemm_workspace_size + alignof(uint64_t) - 1;
|
|
640
|
+
const uintptr_t ws_ptr = reinterpret_cast<uintptr_t>(params->wdata);
|
|
641
|
+
auto * quant_a_buffer = reinterpret_cast<uint8_t *>(ws_ptr);
|
|
452
642
|
|
|
453
|
-
if (
|
|
454
|
-
|
|
643
|
+
if (ne11 == 1) {
|
|
644
|
+
for (int64_t ii = ith; ii < ne12 * a_k_blks; ii += nth) {
|
|
645
|
+
int64_t i12 = ii / a_k_blks;
|
|
646
|
+
int64_t ak_blk_id = ii % a_k_blks;
|
|
647
|
+
quantize_a_row_i8(a_blk_len, (float *) ((char *) src1->data + i12 * nb12) + ak_blk_id * a_blk_len,
|
|
648
|
+
a_blk_len, quant_a_buffer + i12 * nbw2 + ak_blk_id * block_stride_a);
|
|
649
|
+
}
|
|
650
|
+
} else {
|
|
651
|
+
for (int64_t ii = ith; ii < ne12 * ne11; ii += nth) {
|
|
652
|
+
int64_t i12 = ii / ne11;
|
|
653
|
+
int64_t i11 = ii % ne11;
|
|
654
|
+
quantize_a_row_i8(a_blk_len, (float *) ((char *) src1->data + i12 * nb12 + i11 * nb11), ne10,
|
|
655
|
+
quant_a_buffer + i12 * nbw2 + i11 * nbw1);
|
|
656
|
+
}
|
|
455
657
|
}
|
|
456
658
|
|
|
457
|
-
|
|
659
|
+
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) *ne12 + (i1)]
|
|
458
660
|
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
661
|
+
int64_t * matrix_row_counts = (int64_t *) (ws_ptr + gemm_workspace_size);
|
|
662
|
+
int32_t * valid_ep_count = (int32_t *) (matrix_row_counts + n_as);
|
|
663
|
+
int32_t * valid_act_count = (int32_t *) (valid_ep_count + 1);
|
|
664
|
+
int64_t * valid_matrix_row_counts = (int64_t *) (valid_act_count + 1);
|
|
665
|
+
mmid_row_mapping * matrix_rows = (mmid_row_mapping *) (valid_matrix_row_counts + n_as);
|
|
464
666
|
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
667
|
+
if (ith == 0) {
|
|
668
|
+
// initialize matrix_row_counts
|
|
669
|
+
memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
|
|
670
|
+
|
|
671
|
+
// group rows by src0 matrix
|
|
672
|
+
for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
|
|
673
|
+
for (int32_t id = 0; id < n_ids; ++id) {
|
|
674
|
+
const int32_t i02 =
|
|
675
|
+
*(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
|
|
676
|
+
|
|
677
|
+
GGML_ASSERT(i02 >= 0 && i02 < n_as);
|
|
678
|
+
|
|
679
|
+
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
|
|
680
|
+
matrix_row_counts[i02] += 1;
|
|
681
|
+
}
|
|
469
682
|
}
|
|
470
683
|
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
684
|
+
int32_t valid_ep_count_t = 0;
|
|
685
|
+
int32_t valid_act_count_t = 0;
|
|
686
|
+
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
|
|
687
|
+
const int64_t cne1 = matrix_row_counts[cur_a];
|
|
688
|
+
if (cne1 == 0) {
|
|
689
|
+
continue;
|
|
690
|
+
}
|
|
691
|
+
valid_matrix_row_counts[valid_ep_count_t] = cur_a;
|
|
692
|
+
valid_act_count_t += cne1;
|
|
693
|
+
valid_ep_count_t += 1;
|
|
694
|
+
}
|
|
695
|
+
valid_ep_count[0] = valid_ep_count_t;
|
|
696
|
+
valid_act_count[0] = valid_act_count_t;
|
|
474
697
|
}
|
|
475
698
|
|
|
476
|
-
const
|
|
477
|
-
void * ws = reinterpret_cast<void *>((ws_ptr + alignof(uint64_t) - 1) & (~(alignof(uint64_t) - 1)));
|
|
478
|
-
const size_t quant_a_stride = block_count_k * q8_blk_size(QK4_0);
|
|
699
|
+
const int64_t barrier_idx = static_cast<int64_t>(ith / 2);
|
|
479
700
|
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
701
|
+
GGML_ASSERT(global_spine_env_info.init_barrier != nullptr);
|
|
702
|
+
GGML_ASSERT(barrier_idx < spine_init_barrier_count);
|
|
703
|
+
spine_barrier_t * cur_barrier = &global_spine_env_info.init_barrier[barrier_idx];
|
|
704
|
+
|
|
705
|
+
ggml_barrier(params->threadpool);
|
|
706
|
+
|
|
707
|
+
const size_t row_stride_b = b_k_blks * get_repacked_block_type_size<BLOC_TYPE, INTER_SIZE, NB_COLS>();
|
|
708
|
+
const size_t expert_b_stride = ne01 * row_stride_b;
|
|
709
|
+
const size_t per_nb_cols_wsize = NB_COLS * row_stride_b;
|
|
710
|
+
|
|
711
|
+
std::array<const uint8_t *, 2> src_workspaces;
|
|
712
|
+
std::array<float *, 2> dst_workspaces;
|
|
713
|
+
|
|
714
|
+
auto * tcm_buffer = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer;
|
|
715
|
+
const auto tcm_buffer_size = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size;
|
|
716
|
+
|
|
717
|
+
const auto valid_ep_count_t = valid_ep_count[0];
|
|
718
|
+
const auto valid_act_count_t = valid_act_count[0];
|
|
719
|
+
|
|
720
|
+
int nth_es = 1;
|
|
721
|
+
int nth_n = nth;
|
|
722
|
+
|
|
723
|
+
int ith_es = ith % nth_es;
|
|
724
|
+
int ith_n = (ith / nth_es) % nth_n;
|
|
725
|
+
|
|
726
|
+
if (valid_ep_count_t % nth == 0 && tcm_buffer != nullptr && valid_ep_count_t == n_as &&
|
|
727
|
+
valid_act_count_t == n_as && per_nb_cols_wsize <= tcm_buffer_size) {
|
|
728
|
+
for (int64_t valid_id = ith; valid_id < valid_ep_count_t; valid_id += nth) {
|
|
729
|
+
const int64_t cur_a = valid_matrix_row_counts[valid_id];
|
|
730
|
+
|
|
731
|
+
auto * src0_cur = (uint8_t *) src0->data + cur_a * expert_b_stride;
|
|
732
|
+
|
|
733
|
+
mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, 0);
|
|
734
|
+
const int id = row_mapping.i1;
|
|
735
|
+
const int64_t i11 = id % ne11;
|
|
736
|
+
const int64_t i12 = row_mapping.i2;
|
|
737
|
+
const int64_t i1 = id;
|
|
738
|
+
const int64_t i2 = i12;
|
|
739
|
+
|
|
740
|
+
auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2);
|
|
741
|
+
float * c_blk = (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2));
|
|
742
|
+
|
|
743
|
+
uint8_t * a_row = src1_col;
|
|
744
|
+
uint8_t * b_col = reinterpret_cast<uint8_t *>(tcm_buffer);
|
|
745
|
+
if ((nbw1 + per_nb_cols_wsize) <= tcm_buffer_size) {
|
|
746
|
+
a_row = (uint8_t *) tcm_buffer;
|
|
747
|
+
b_col = reinterpret_cast<uint8_t *>(tcm_buffer) + nbw1;
|
|
748
|
+
}
|
|
749
|
+
uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? b_col : nullptr;
|
|
750
|
+
|
|
751
|
+
if (ith % 2 == 0) {
|
|
752
|
+
spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(src0_cur), per_nb_cols_wsize);
|
|
753
|
+
|
|
754
|
+
if (a_row != src1_col) {
|
|
755
|
+
spacemit_kernels::rvv::memcpy1d(a_row, src1_col, nbw1);
|
|
756
|
+
}
|
|
757
|
+
}
|
|
758
|
+
|
|
759
|
+
spine_barrier_wait(cur_barrier);
|
|
760
|
+
|
|
761
|
+
if (ith % 2 != 0) {
|
|
762
|
+
if (a_row != src1_col) {
|
|
763
|
+
spacemit_kernels::rvv::memcpy1d(a_row, src1_col, nbw1);
|
|
764
|
+
}
|
|
765
|
+
|
|
766
|
+
spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(src0_cur), per_nb_cols_wsize);
|
|
767
|
+
}
|
|
768
|
+
|
|
769
|
+
int64_t nb_real = std::min(ne01, NB_COLS);
|
|
770
|
+
for (int64_t ni = 0; ni < ne01; ni += NB_COLS) {
|
|
771
|
+
if (ith % 2 != 0) {
|
|
772
|
+
spine_barrier_wait(cur_barrier);
|
|
773
|
+
}
|
|
774
|
+
|
|
775
|
+
gemm_kernel(b_blk_len, a_row, b_col, b_col_zp, c_blk + ni, 1, nb_real, b_k_blks, ne01);
|
|
776
|
+
|
|
777
|
+
if (ith % 2 == 0) {
|
|
778
|
+
spine_barrier_wait(cur_barrier);
|
|
779
|
+
}
|
|
780
|
+
|
|
781
|
+
const int64_t next_ni = ni + NB_COLS;
|
|
782
|
+
if (next_ni < ne01) {
|
|
783
|
+
nb_real = std::min(ne01 - next_ni, NB_COLS);
|
|
784
|
+
spacemit_kernels::rvv::memcpy1d(
|
|
785
|
+
b_col, reinterpret_cast<uint8_t *>(src0_cur) + next_ni * row_stride_b, per_nb_cols_wsize);
|
|
507
786
|
}
|
|
508
787
|
}
|
|
509
788
|
}
|
|
510
|
-
}
|
|
789
|
+
} else {
|
|
790
|
+
for (int64_t valid_id = ith_es; valid_id < valid_ep_count_t; valid_id += nth_es) {
|
|
791
|
+
const int64_t cur_a = valid_matrix_row_counts[valid_id];
|
|
792
|
+
const int64_t cne1 = matrix_row_counts[cur_a];
|
|
511
793
|
|
|
512
|
-
|
|
794
|
+
int64_t src1_cur_start = 0;
|
|
795
|
+
int64_t src1_cur_end = cne1;
|
|
513
796
|
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
}
|
|
517
|
-
nth = std::min(nth, int{ ggml::cpu::riscv64_spacemit::num_ai_cores });
|
|
518
|
-
|
|
519
|
-
size_t threads_per_gemm = nth / batch_feature;
|
|
520
|
-
constexpr size_t gemm_m_stride = 128;
|
|
521
|
-
size_t nc = gemm_n;
|
|
522
|
-
const size_t gemm_m_blocked = div_round_up(gemm_m, gemm_m_stride);
|
|
523
|
-
const size_t max_nc = div_round_up(gemm_n * gemm_m_blocked, threads_per_gemm);
|
|
524
|
-
if (max_nc < nc) {
|
|
525
|
-
nc = std::min(nc, div_round_up(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * QGEMM_STRIDEN_THREAD_ALIGN);
|
|
526
|
-
}
|
|
527
|
-
const size_t gemm_n_stride = nc;
|
|
528
|
-
const size_t thread_count_m = div_round_up(gemm_m, gemm_m_stride);
|
|
529
|
-
const size_t thread_count_n = div_round_up(gemm_n, gemm_n_stride);
|
|
530
|
-
threads_per_gemm = thread_count_m * thread_count_n;
|
|
797
|
+
int64_t src0_cur_start = (ith_n * ne01) / nth_n;
|
|
798
|
+
int64_t src0_cur_end = MIN(((ith_n + 1) * ne01) / nth_n, ne01);
|
|
531
799
|
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
800
|
+
if (src1_cur_start >= src1_cur_end || src0_cur_start >= src0_cur_end) {
|
|
801
|
+
continue;
|
|
802
|
+
}
|
|
803
|
+
|
|
804
|
+
src0_cur_start =
|
|
805
|
+
(src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
|
|
806
|
+
src0_cur_end =
|
|
807
|
+
(src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
|
|
808
|
+
|
|
809
|
+
auto * src0_cur = (uint8_t *) src0->data + cur_a * expert_b_stride + src0_cur_start * row_stride_b;
|
|
810
|
+
uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? src0_cur : nullptr;
|
|
811
|
+
|
|
812
|
+
size_t extra_tcm_buffer_size = tcm_buffer_size;
|
|
813
|
+
void * extra_tcm_buffer = tcm_buffer;
|
|
814
|
+
if (tcm_buffer != nullptr && (src1_cur_end - src1_cur_start) >= 4 &&
|
|
815
|
+
(src0_cur_end - src0_cur_start) * row_stride_b <= tcm_buffer_size) {
|
|
816
|
+
spacemit_kernels::rvv::memcpy1d(tcm_buffer, src0_cur,
|
|
817
|
+
(src0_cur_end - src0_cur_start) * row_stride_b);
|
|
818
|
+
src0_cur = reinterpret_cast<uint8_t *>(tcm_buffer);
|
|
819
|
+
b_col_zp = block_type_has_zp<BLOC_TYPE>() ? src0_cur : nullptr;
|
|
820
|
+
extra_tcm_buffer_size -= (src0_cur_end - src0_cur_start) * row_stride_b;
|
|
821
|
+
extra_tcm_buffer = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(tcm_buffer) +
|
|
822
|
+
(src0_cur_end - src0_cur_start) * row_stride_b);
|
|
823
|
+
}
|
|
541
824
|
|
|
542
|
-
|
|
543
|
-
const auto tid_m = blk_i % thread_count_m;
|
|
825
|
+
int ir1 = src1_cur_start;
|
|
544
826
|
|
|
545
|
-
|
|
546
|
-
|
|
827
|
+
if (extra_tcm_buffer_size >= nbw1 && extra_tcm_buffer != nullptr) {
|
|
828
|
+
int64_t quant_a_tile_size = extra_tcm_buffer_size / nbw1;
|
|
829
|
+
do {
|
|
830
|
+
quant_a_tile_size = MIN(quant_a_tile_size, src1_cur_end - ir1);
|
|
547
831
|
|
|
548
|
-
|
|
549
|
-
const size_t n_count = std::min(gemm_n - n_start, (size_t) gemm_n_stride);
|
|
832
|
+
uint8_t * quant_a_tile_buffer = reinterpret_cast<uint8_t *>(extra_tcm_buffer);
|
|
550
833
|
|
|
551
|
-
|
|
834
|
+
int iir1 = ir1;
|
|
835
|
+
for (; iir1 < (ir1 + quant_a_tile_size); ++iir1) {
|
|
836
|
+
mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, iir1);
|
|
552
837
|
|
|
553
|
-
|
|
838
|
+
const int id = row_mapping.i1; // selected expert index
|
|
839
|
+
|
|
840
|
+
const int64_t i11 = id % ne11;
|
|
841
|
+
const int64_t i12 = row_mapping.i2; // row index in src1
|
|
842
|
+
|
|
843
|
+
auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2);
|
|
844
|
+
spacemit_kernels::rvv::memcpy1d(quant_a_tile_buffer, src1_col, nbw1);
|
|
845
|
+
quant_a_tile_buffer = quant_a_tile_buffer + nbw1;
|
|
846
|
+
}
|
|
847
|
+
|
|
848
|
+
quant_a_tile_buffer = reinterpret_cast<uint8_t *>(extra_tcm_buffer);
|
|
849
|
+
iir1 = ir1;
|
|
850
|
+
|
|
851
|
+
if (moe_gemm_kernel_m2 != nullptr) {
|
|
852
|
+
for (; iir1 < (ir1 + quant_a_tile_size - 1); iir1 += 2, quant_a_tile_buffer += 2 * nbw1) {
|
|
853
|
+
mmid_row_mapping row_mapping_0 = MMID_MATRIX_ROW(cur_a, iir1);
|
|
854
|
+
mmid_row_mapping row_mapping_1 = MMID_MATRIX_ROW(cur_a, iir1 + 1);
|
|
855
|
+
|
|
856
|
+
src_workspaces[0] = quant_a_tile_buffer;
|
|
857
|
+
src_workspaces[1] = quant_a_tile_buffer + nbw1;
|
|
858
|
+
|
|
859
|
+
dst_workspaces[0] =
|
|
860
|
+
(float *) ((char *) dst->data + (row_mapping_0.i1 * nb1 + row_mapping_0.i2 * nb2)) +
|
|
861
|
+
src0_cur_start;
|
|
862
|
+
dst_workspaces[1] = (float *) ((char *) dst->data +
|
|
863
|
+
((row_mapping_1.i1) * nb1 + (row_mapping_1.i2) * nb2)) +
|
|
864
|
+
src0_cur_start;
|
|
865
|
+
moe_gemm_kernel_m2(b_blk_len, src_workspaces.data(), src0_cur, b_col_zp,
|
|
866
|
+
dst_workspaces.data(), 1, src0_cur_end - src0_cur_start, b_k_blks,
|
|
867
|
+
ne01);
|
|
868
|
+
}
|
|
869
|
+
}
|
|
870
|
+
|
|
871
|
+
for (; iir1 < (ir1 + quant_a_tile_size); iir1++, quant_a_tile_buffer += nbw1) {
|
|
872
|
+
mmid_row_mapping row_mapping_0 = MMID_MATRIX_ROW(cur_a, iir1);
|
|
873
|
+
|
|
874
|
+
gemm_kernel(
|
|
875
|
+
b_blk_len, quant_a_tile_buffer, src0_cur, b_col_zp,
|
|
876
|
+
(float *) ((char *) dst->data + (row_mapping_0.i1 * nb1 + row_mapping_0.i2 * nb2)) +
|
|
877
|
+
src0_cur_start,
|
|
878
|
+
1, src0_cur_end - src0_cur_start, b_k_blks, ne01);
|
|
879
|
+
}
|
|
880
|
+
|
|
881
|
+
ir1 += quant_a_tile_size;
|
|
882
|
+
} while (ir1 < src1_cur_end);
|
|
883
|
+
} else {
|
|
884
|
+
if (moe_gemm_kernel_m2 != nullptr) {
|
|
885
|
+
for (; ir1 < src1_cur_end - 1; ir1 += 2) {
|
|
886
|
+
for (int iir1 = 0; iir1 < 2; ++iir1) {
|
|
887
|
+
mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1 + iir1);
|
|
888
|
+
|
|
889
|
+
const int id = row_mapping.i1; // selected expert index
|
|
890
|
+
|
|
891
|
+
const int64_t i11 = id % ne11;
|
|
892
|
+
const int64_t i12 = row_mapping.i2; // row index in src1
|
|
893
|
+
|
|
894
|
+
const int64_t i1 = id; // selected expert index
|
|
895
|
+
const int64_t i2 = i12; // row
|
|
896
|
+
|
|
897
|
+
src_workspaces[iir1] = quant_a_buffer + (i11 * nbw1 + i12 * nbw2);
|
|
898
|
+
|
|
899
|
+
dst_workspaces[iir1] =
|
|
900
|
+
(float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start;
|
|
901
|
+
}
|
|
902
|
+
|
|
903
|
+
moe_gemm_kernel_m2(b_blk_len, src_workspaces.data(), src0_cur, b_col_zp,
|
|
904
|
+
dst_workspaces.data(), 1, src0_cur_end - src0_cur_start, b_k_blks, ne01);
|
|
905
|
+
}
|
|
906
|
+
}
|
|
907
|
+
|
|
908
|
+
for (; ir1 < src1_cur_end; ir1++) {
|
|
909
|
+
mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
|
|
910
|
+
|
|
911
|
+
const int id = row_mapping.i1; // selected expert index
|
|
912
|
+
|
|
913
|
+
const int64_t i11 = id % ne11;
|
|
914
|
+
const int64_t i12 = row_mapping.i2; // row index in src1
|
|
915
|
+
|
|
916
|
+
const int64_t i1 = id; // selected expert index
|
|
917
|
+
const int64_t i2 = i12; // row
|
|
918
|
+
|
|
919
|
+
auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2);
|
|
920
|
+
|
|
921
|
+
gemm_kernel(b_blk_len, src1_col, src0_cur, b_col_zp,
|
|
922
|
+
(float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, 1,
|
|
923
|
+
src0_cur_end - src0_cur_start, b_k_blks, ne01);
|
|
924
|
+
}
|
|
925
|
+
}
|
|
554
926
|
}
|
|
555
927
|
}
|
|
928
|
+
#undef MMID_MATRIX_ROW
|
|
556
929
|
}
|
|
557
930
|
|
|
558
|
-
int repack(
|
|
931
|
+
int repack(ggml_tensor * t, const void * data, size_t data_size) override {
|
|
559
932
|
GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
|
|
560
933
|
(int) NB_COLS, (int) INTER_SIZE);
|
|
561
934
|
return ggml::cpu::riscv64_spacemit::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
|
|
@@ -563,309 +936,464 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|
|
563
936
|
};
|
|
564
937
|
|
|
565
938
|
class tensor_traits_common : public tensor_traits_base {
|
|
566
|
-
bool work_size(int
|
|
939
|
+
bool work_size(int n_threads, const ggml_tensor * op, size_t & size) override {
|
|
567
940
|
switch (op->op) {
|
|
568
|
-
case
|
|
569
|
-
|
|
570
|
-
|
|
941
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
|
942
|
+
{
|
|
943
|
+
const int n_tasks = n_threads;
|
|
944
|
+
const int64_t neq2 = op->src[0]->ne[2]; // number of query heads
|
|
945
|
+
const int64_t DK = op->src[1]->ne[0];
|
|
946
|
+
const int64_t DV = op->src[2]->ne[0]; // DV
|
|
947
|
+
|
|
948
|
+
// Tiled flash attention scratch (tile sizes defined in common.h)
|
|
949
|
+
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding
|
|
950
|
+
size_t prefill = sizeof(float) *
|
|
951
|
+
(GGML_FA_TILE_Q * DK + 2 * GGML_FA_TILE_Q * GGML_FA_TILE_KV + GGML_FA_TILE_Q * DV +
|
|
952
|
+
GGML_FA_TILE_KV * DV + GGML_FA_TILE_KV * DK) *
|
|
953
|
+
n_tasks;
|
|
954
|
+
|
|
955
|
+
// Decode path: n_kv_chunks = n_tasks (one chunk per thread)
|
|
956
|
+
// Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
|
|
957
|
+
size_t n_chunks = n_tasks;
|
|
958
|
+
size_t decode = sizeof(float) * (neq2 * n_chunks * (2 + DV) + n_tasks * (DK + 2 * DV));
|
|
959
|
+
|
|
960
|
+
size = MAX(prefill, decode);
|
|
961
|
+
}
|
|
571
962
|
return true;
|
|
572
963
|
default:
|
|
573
|
-
// GGML_ABORT("fatal error");
|
|
574
964
|
break;
|
|
575
965
|
}
|
|
576
966
|
return false;
|
|
577
967
|
}
|
|
578
968
|
|
|
579
|
-
bool compute_forward(
|
|
969
|
+
bool compute_forward(ggml_compute_params * params, ggml_tensor * op) override {
|
|
580
970
|
switch (op->op) {
|
|
581
971
|
case GGML_OP_NORM:
|
|
582
|
-
|
|
583
|
-
|
|
972
|
+
switch (op->src[0]->type) {
|
|
973
|
+
case GGML_TYPE_F32:
|
|
974
|
+
spacemit_kernels::rvv::forward_norm_f32(params, op);
|
|
975
|
+
return true;
|
|
976
|
+
default:
|
|
977
|
+
GGML_ABORT("fatal error");
|
|
978
|
+
}
|
|
584
979
|
case GGML_OP_RMS_NORM:
|
|
585
|
-
|
|
980
|
+
switch (op->src[0]->type) {
|
|
981
|
+
case GGML_TYPE_F32:
|
|
982
|
+
spacemit_kernels::rvv::forward_rms_norm_f32(params, op);
|
|
983
|
+
return true;
|
|
984
|
+
default:
|
|
985
|
+
GGML_ABORT("fatal error");
|
|
986
|
+
}
|
|
987
|
+
case GGML_OP_ADD:
|
|
988
|
+
switch (op->src[0]->type) {
|
|
989
|
+
case GGML_TYPE_F32:
|
|
990
|
+
spacemit_kernels::rvv::forward_binary<GGML_OP_ADD, float>(params, op);
|
|
991
|
+
return true;
|
|
992
|
+
case GGML_TYPE_F16:
|
|
993
|
+
spacemit_kernels::rvv::forward_binary<GGML_OP_ADD, _Float16>(params, op);
|
|
994
|
+
return true;
|
|
995
|
+
default:
|
|
996
|
+
ggml_compute_forward_add(params, op);
|
|
997
|
+
return true;
|
|
998
|
+
}
|
|
999
|
+
case GGML_OP_SUB:
|
|
1000
|
+
switch (op->src[0]->type) {
|
|
1001
|
+
case GGML_TYPE_F32:
|
|
1002
|
+
spacemit_kernels::rvv::forward_binary<GGML_OP_SUB, float>(params, op);
|
|
1003
|
+
return true;
|
|
1004
|
+
case GGML_TYPE_F16:
|
|
1005
|
+
spacemit_kernels::rvv::forward_binary<GGML_OP_SUB, _Float16>(params, op);
|
|
1006
|
+
return true;
|
|
1007
|
+
default:
|
|
1008
|
+
ggml_compute_forward_sub(params, op);
|
|
1009
|
+
return true;
|
|
1010
|
+
}
|
|
1011
|
+
case GGML_OP_MUL:
|
|
1012
|
+
switch (op->src[0]->type) {
|
|
1013
|
+
case GGML_TYPE_F32:
|
|
1014
|
+
spacemit_kernels::rvv::forward_binary<GGML_OP_MUL, float>(params, op);
|
|
1015
|
+
return true;
|
|
1016
|
+
case GGML_TYPE_F16:
|
|
1017
|
+
spacemit_kernels::rvv::forward_binary<GGML_OP_MUL, _Float16>(params, op);
|
|
1018
|
+
return true;
|
|
1019
|
+
default:
|
|
1020
|
+
ggml_compute_forward_mul(params, op);
|
|
1021
|
+
return true;
|
|
1022
|
+
}
|
|
1023
|
+
case GGML_OP_DIV:
|
|
1024
|
+
switch (op->src[0]->type) {
|
|
1025
|
+
case GGML_TYPE_F32:
|
|
1026
|
+
spacemit_kernels::rvv::forward_binary<GGML_OP_DIV, float>(params, op);
|
|
1027
|
+
return true;
|
|
1028
|
+
case GGML_TYPE_F16:
|
|
1029
|
+
spacemit_kernels::rvv::forward_binary<GGML_OP_DIV, _Float16>(params, op);
|
|
1030
|
+
return true;
|
|
1031
|
+
default:
|
|
1032
|
+
ggml_compute_forward_div(params, op);
|
|
1033
|
+
return true;
|
|
1034
|
+
}
|
|
1035
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
|
1036
|
+
forward_flash_attn_ext_f16(params, op);
|
|
1037
|
+
return true;
|
|
1038
|
+
case GGML_OP_CONT:
|
|
1039
|
+
{
|
|
1040
|
+
const ggml_tensor * src0 = op->src[0];
|
|
1041
|
+
if (op->type == src0->type && op->nb[0] != src0->nb[0] && op->nb[0] == src0->nb[1] &&
|
|
1042
|
+
op->ne[3] * op->ne[2] * op->nb[2] == src0->ne[3] * src0->ne[2] * src0->nb[2]) {
|
|
1043
|
+
spacemit_kernels::rvv::forward_cont_with_permute(params, op);
|
|
1044
|
+
} else {
|
|
1045
|
+
ggml_compute_forward_cont(params, op);
|
|
1046
|
+
}
|
|
1047
|
+
return true;
|
|
1048
|
+
}
|
|
1049
|
+
case GGML_OP_CPY:
|
|
1050
|
+
{
|
|
1051
|
+
const ggml_tensor * src0 = op->src[0];
|
|
1052
|
+
if (op->type == src0->type && op->nb[0] == src0->nb[1] && src0->nb[0] != src0->nb[1] &&
|
|
1053
|
+
ggml_nelements(src0) == ggml_nelements(op)) {
|
|
1054
|
+
spacemit_kernels::rvv::forward_cpy_with_permute(params, op);
|
|
1055
|
+
} else {
|
|
1056
|
+
ggml_compute_forward_cpy(params, op);
|
|
1057
|
+
}
|
|
1058
|
+
return true;
|
|
1059
|
+
}
|
|
1060
|
+
case GGML_OP_REPEAT:
|
|
1061
|
+
{
|
|
1062
|
+
const bool rows_equal = ggml_nrows(op->src[0]) == ggml_nrows(op);
|
|
1063
|
+
const bool broadcast_or_equal = op->src[0]->ne[0] == 1 || op->src[0]->ne[0] == op->ne[0];
|
|
1064
|
+
|
|
1065
|
+
if (rows_equal && broadcast_or_equal) {
|
|
1066
|
+
switch (op->src[0]->type) {
|
|
1067
|
+
case GGML_TYPE_F32:
|
|
1068
|
+
spacemit_kernels::rvv::forward_repeat_nrows<int32_t>(params, op);
|
|
1069
|
+
return true;
|
|
1070
|
+
case GGML_TYPE_F16:
|
|
1071
|
+
spacemit_kernels::rvv::forward_repeat_nrows<int16_t>(params, op);
|
|
1072
|
+
return true;
|
|
1073
|
+
default:
|
|
1074
|
+
break;
|
|
1075
|
+
}
|
|
1076
|
+
}
|
|
1077
|
+
|
|
1078
|
+
if (op->src[0]->ne[1] == 1 && op->src[0]->ne[0] == op->ne[0]) {
|
|
1079
|
+
switch (op->src[0]->type) {
|
|
1080
|
+
case GGML_TYPE_F32:
|
|
1081
|
+
spacemit_kernels::rvv::forward_repeat_dim1<int32_t>(params, op);
|
|
1082
|
+
return true;
|
|
1083
|
+
case GGML_TYPE_F16:
|
|
1084
|
+
spacemit_kernels::rvv::forward_repeat_dim1<int16_t>(params, op);
|
|
1085
|
+
return true;
|
|
1086
|
+
default:
|
|
1087
|
+
break;
|
|
1088
|
+
}
|
|
1089
|
+
}
|
|
1090
|
+
|
|
1091
|
+
ggml_compute_forward_repeat(params, op);
|
|
1092
|
+
}
|
|
1093
|
+
return true;
|
|
1094
|
+
case GGML_OP_SUM_ROWS:
|
|
1095
|
+
{
|
|
1096
|
+
if (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) {
|
|
1097
|
+
spacemit_kernels::rvv::forward_sum_rows<float>(params, op);
|
|
1098
|
+
} else {
|
|
1099
|
+
ggml_compute_forward_sum_rows(params, op);
|
|
1100
|
+
}
|
|
1101
|
+
}
|
|
1102
|
+
return true;
|
|
1103
|
+
case GGML_OP_GET_ROWS:
|
|
1104
|
+
{
|
|
1105
|
+
if (op->src[0]->type == op->type) {
|
|
1106
|
+
switch (op->src[0]->type) {
|
|
1107
|
+
case GGML_TYPE_F32:
|
|
1108
|
+
spacemit_kernels::rvv::forward_get_rows<int32_t>(params, op);
|
|
1109
|
+
return true;
|
|
1110
|
+
case GGML_TYPE_F16:
|
|
1111
|
+
spacemit_kernels::rvv::forward_get_rows<int16_t>(params, op);
|
|
1112
|
+
return true;
|
|
1113
|
+
default:
|
|
1114
|
+
break;
|
|
1115
|
+
}
|
|
1116
|
+
}
|
|
1117
|
+
|
|
1118
|
+
ggml_compute_forward_get_rows(params, op);
|
|
1119
|
+
}
|
|
586
1120
|
return true;
|
|
1121
|
+
case GGML_OP_CONCAT:
|
|
1122
|
+
{
|
|
1123
|
+
const int32_t dim = ggml_get_op_params_i32(op, 0);
|
|
1124
|
+
if (dim == 0 && op->type == op->src[0]->type) {
|
|
1125
|
+
switch (op->src[0]->type) {
|
|
1126
|
+
case GGML_TYPE_F32:
|
|
1127
|
+
spacemit_kernels::rvv::forward_concat<int32_t>(params, op);
|
|
1128
|
+
return true;
|
|
1129
|
+
case GGML_TYPE_F16:
|
|
1130
|
+
spacemit_kernels::rvv::forward_concat<int16_t>(params, op);
|
|
1131
|
+
return true;
|
|
1132
|
+
default:
|
|
1133
|
+
break;
|
|
1134
|
+
}
|
|
1135
|
+
}
|
|
1136
|
+
|
|
1137
|
+
ggml_compute_forward_concat(params, op);
|
|
1138
|
+
}
|
|
1139
|
+
return true;
|
|
1140
|
+
// TODO For GGML_OP_GATED_DELTA_NET
|
|
1141
|
+
// case GGML_OP_GATED_DELTA_NET:
|
|
1142
|
+
// return true;
|
|
587
1143
|
default:
|
|
588
|
-
// GGML_ABORT("fatal error");
|
|
589
1144
|
break;
|
|
590
1145
|
}
|
|
591
1146
|
return false;
|
|
592
1147
|
}
|
|
593
1148
|
|
|
594
|
-
void
|
|
595
|
-
const ggml_tensor *
|
|
596
|
-
ggml_tensor *
|
|
597
|
-
|
|
598
|
-
|
|
1149
|
+
void forward_flash_attn_ext_f16(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
1150
|
+
const ggml_tensor * q = dst->src[0];
|
|
1151
|
+
const ggml_tensor * k = dst->src[1];
|
|
1152
|
+
const ggml_tensor * v = dst->src[2];
|
|
1153
|
+
|
|
1154
|
+
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
|
1155
|
+
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
|
1156
|
+
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
|
1157
|
+
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
|
1158
|
+
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
|
1159
|
+
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
|
1160
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
1161
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
1162
|
+
|
|
1163
|
+
const int64_t DK = nek0;
|
|
1164
|
+
const int64_t DV = nev0;
|
|
1165
|
+
|
|
1166
|
+
const bool supported_prec = (dst->op_params[3] == GGML_PREC_F32 || dst->op_params[3] == GGML_PREC_DEFAULT);
|
|
1167
|
+
const bool supported_types = (q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16);
|
|
1168
|
+
const bool supported_shape = (DK > 0 && DK <= 128 && DV > 0 && DV <= 128);
|
|
1169
|
+
const bool supported_vlen = (__riscv_vlenb() == 128);
|
|
1170
|
+
|
|
1171
|
+
if (!(supported_prec && supported_types && supported_shape && supported_vlen)) {
|
|
1172
|
+
ggml_compute_forward_flash_attn_ext(params, dst);
|
|
1173
|
+
return;
|
|
1174
|
+
}
|
|
1175
|
+
|
|
1176
|
+
// total rows in q
|
|
1177
|
+
const int64_t nr = neq1 * neq2 * neq3;
|
|
599
1178
|
|
|
1179
|
+
// rows per thread
|
|
600
1180
|
const int ith = params->ith;
|
|
601
1181
|
const int nth = params->nth;
|
|
602
1182
|
|
|
603
|
-
|
|
1183
|
+
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
|
|
1184
|
+
const bool use_tiled = !params->use_ref && (neq1 >= Q_TILE_SZ);
|
|
604
1185
|
|
|
605
|
-
|
|
606
|
-
|
|
1186
|
+
// 4x chunks per thread
|
|
1187
|
+
// int nth_scaled = nth * 4;
|
|
1188
|
+
// int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
|
1189
|
+
// int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
|
607
1190
|
|
|
608
|
-
|
|
1191
|
+
// if (nth == 1 || nchunk < nth) {
|
|
1192
|
+
// nchunk = nth;
|
|
1193
|
+
// }
|
|
609
1194
|
|
|
610
|
-
|
|
611
|
-
auto * output = (float *) dst->data;
|
|
1195
|
+
int64_t nchunk = nth;
|
|
612
1196
|
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
const auto task_begin = ith * task_per_thread;
|
|
618
|
-
const auto task_end = std::min((ith + 1) * task_per_thread, task_count);
|
|
1197
|
+
if (ith == 0) {
|
|
1198
|
+
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
|
|
1199
|
+
ggml_threadpool_chunk_set(params->threadpool, nth);
|
|
1200
|
+
}
|
|
619
1201
|
|
|
620
|
-
|
|
621
|
-
auto offset = task_idx * hidden_size;
|
|
622
|
-
auto * p_input = const_cast<float *>(input + offset);
|
|
1202
|
+
ggml_barrier(params->threadpool);
|
|
623
1203
|
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
auto * p_gamma_data = (const float *) nullptr;
|
|
627
|
-
auto * p_beta_data = (const float *) nullptr;
|
|
628
|
-
size_t gvl = __riscv_vsetvlmax_e32m4();
|
|
629
|
-
vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl);
|
|
630
|
-
vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl);
|
|
631
|
-
int64_t length = hidden_size;
|
|
632
|
-
while (length > 0) {
|
|
633
|
-
gvl = __riscv_vsetvl_e32m4(length);
|
|
634
|
-
// load data
|
|
635
|
-
vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl);
|
|
1204
|
+
// The number of elements in each chunk
|
|
1205
|
+
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
|
636
1206
|
|
|
637
|
-
|
|
638
|
-
|
|
1207
|
+
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
|
1208
|
+
int current_chunk = ith;
|
|
639
1209
|
|
|
640
|
-
|
|
1210
|
+
while (current_chunk < nchunk) {
|
|
1211
|
+
const int64_t ir0 = dr * current_chunk;
|
|
1212
|
+
const int64_t ir1 = MIN(ir0 + dr, nr);
|
|
641
1213
|
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
1214
|
+
if (use_tiled) {
|
|
1215
|
+
spacemit_kernels::rvv::forward_flash_attn_ext_f16_tiled_vlen1024_vf16(
|
|
1216
|
+
params, dst, ir0, ir1, ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer,
|
|
1217
|
+
ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size);
|
|
1218
|
+
} else {
|
|
1219
|
+
spacemit_kernels::rvv::forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16(
|
|
1220
|
+
params, dst, ir0, ir1, ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer,
|
|
1221
|
+
ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size);
|
|
645
1222
|
}
|
|
646
1223
|
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
float mean = 0.f;
|
|
650
|
-
vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl);
|
|
651
|
-
vfloat32m1_t mean_v =
|
|
652
|
-
__riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl);
|
|
653
|
-
mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl);
|
|
654
|
-
mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl);
|
|
655
|
-
mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl);
|
|
656
|
-
mean = __riscv_vfmv_f_s_f32m1_f32(mean_v);
|
|
657
|
-
mean /= hidden_size;
|
|
658
|
-
|
|
659
|
-
vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0),
|
|
660
|
-
__riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl);
|
|
661
|
-
mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl);
|
|
662
|
-
mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl);
|
|
663
|
-
mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl);
|
|
664
|
-
|
|
665
|
-
float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v);
|
|
666
|
-
mean_square /= hidden_size;
|
|
667
|
-
mean_square = sqrt(mean_square - mean * mean + epsilon);
|
|
668
|
-
|
|
669
|
-
mean_square = 1.0f / mean_square;
|
|
670
|
-
length = hidden_size;
|
|
671
|
-
p_temp_output = p_output;
|
|
672
|
-
|
|
673
|
-
if (p_gamma_data == nullptr && p_beta_data == nullptr) {
|
|
674
|
-
while (length > 0) {
|
|
675
|
-
gvl = __riscv_vsetvl_e32m4(length);
|
|
676
|
-
vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
|
|
677
|
-
src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
|
|
678
|
-
src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
|
|
679
|
-
__riscv_vse32_v_f32m4(p_output, src_data, gvl);
|
|
680
|
-
p_temp_output += gvl;
|
|
681
|
-
p_output += gvl;
|
|
682
|
-
length -= gvl;
|
|
683
|
-
}
|
|
684
|
-
} else if (p_beta_data == nullptr) {
|
|
685
|
-
while (length > 0) {
|
|
686
|
-
gvl = __riscv_vsetvl_e32m4(length);
|
|
687
|
-
vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
|
|
688
|
-
vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
|
|
689
|
-
src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
|
|
690
|
-
src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
|
|
691
|
-
src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
|
|
692
|
-
__riscv_vse32_v_f32m4(p_output, src_data, gvl);
|
|
693
|
-
p_temp_output += gvl;
|
|
694
|
-
p_output += gvl;
|
|
695
|
-
p_gamma_data += gvl;
|
|
696
|
-
length -= gvl;
|
|
697
|
-
}
|
|
698
|
-
} else if (p_gamma_data != nullptr) {
|
|
699
|
-
while (length > 0) {
|
|
700
|
-
gvl = __riscv_vsetvl_e32m4(length);
|
|
701
|
-
vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
|
|
702
|
-
vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
|
|
703
|
-
src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
|
|
704
|
-
src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
|
|
705
|
-
src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
|
|
706
|
-
vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl);
|
|
707
|
-
src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl);
|
|
708
|
-
p_beta_data += gvl;
|
|
709
|
-
__riscv_vse32_v_f32m4(p_output, src_data, gvl);
|
|
710
|
-
p_temp_output += gvl;
|
|
711
|
-
p_output += gvl;
|
|
712
|
-
p_gamma_data += gvl;
|
|
713
|
-
length -= gvl;
|
|
714
|
-
}
|
|
715
|
-
}
|
|
1224
|
+
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
|
716
1225
|
}
|
|
717
1226
|
}
|
|
718
1227
|
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
const int ith = params->ith;
|
|
726
|
-
const int nth = params->nth;
|
|
727
|
-
|
|
728
|
-
GGML_TENSOR_UNARY_OP_LOCALS
|
|
729
|
-
|
|
730
|
-
float epsilon;
|
|
731
|
-
memcpy(&epsilon, dst->op_params, sizeof(float));
|
|
732
|
-
|
|
733
|
-
GGML_ASSERT(epsilon > 0.0f);
|
|
734
|
-
|
|
735
|
-
auto * input = (float *) src0->data;
|
|
736
|
-
auto * output = (float *) dst->data;
|
|
737
|
-
|
|
738
|
-
const auto hidden_size = ne00;
|
|
739
|
-
const auto task_count = ne01 * ne02 * ne03;
|
|
740
|
-
const auto task_per_thread = (task_count + nth - 1) / nth;
|
|
741
|
-
|
|
742
|
-
const auto task_begin = ith * task_per_thread;
|
|
743
|
-
const auto task_end = std::min((ith + 1) * task_per_thread, task_count);
|
|
744
|
-
|
|
745
|
-
for (auto task_idx = task_begin; task_idx < task_end; task_idx++) {
|
|
746
|
-
auto offset = task_idx * hidden_size;
|
|
747
|
-
auto * p_input = const_cast<float *>(input + offset);
|
|
748
|
-
auto * p_output = output + offset;
|
|
749
|
-
auto * p_temp_output = p_output;
|
|
750
|
-
auto * p_gamma_data = (const float *) nullptr;
|
|
751
|
-
auto * p_beta_data = (const float *) nullptr;
|
|
752
|
-
|
|
753
|
-
size_t gvl = __riscv_vsetvlmax_e32m4();
|
|
754
|
-
// vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl);
|
|
755
|
-
vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl);
|
|
756
|
-
int64_t length = hidden_size;
|
|
757
|
-
while (length > 0) {
|
|
758
|
-
gvl = __riscv_vsetvl_e32m4(length);
|
|
759
|
-
// load data
|
|
760
|
-
vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl);
|
|
1228
|
+
int repack(ggml_tensor * t, const void * data, size_t data_size) override {
|
|
1229
|
+
memcpy(t->data, data, data_size);
|
|
1230
|
+
return 0;
|
|
1231
|
+
}
|
|
1232
|
+
};
|
|
761
1233
|
|
|
762
|
-
|
|
1234
|
+
// Impl By IME1
|
|
1235
|
+
static const tensor_traits<block_q4_0, 32, 16> q4_0_16x32_q8_0;
|
|
1236
|
+
static const tensor_traits<block_q4_1, 32, 16> q4_1_16x32_q8_0;
|
|
1237
|
+
static const tensor_traits<block_q4_K, 32, 16> q4_k_16x32_q8_0;
|
|
1238
|
+
// Impl By IME2
|
|
1239
|
+
static const tensor_traits<block_q2_K, 256, 32> q2_k_32x256_q8_0;
|
|
1240
|
+
static const tensor_traits<block_q3_K, 256, 32> q3_k_32x256_q8_0;
|
|
1241
|
+
static const tensor_traits<block_q4_0, 32, 32> q4_0_32x32_q8_0;
|
|
1242
|
+
static const tensor_traits<block_q4_1, 32, 32> q4_1_32x32_q8_0;
|
|
1243
|
+
static const tensor_traits<block_q4_0, 256, 32> q4_0_32x256_q8_0;
|
|
1244
|
+
static const tensor_traits<block_q4_1, 256, 32> q4_1_32x256_q8_0;
|
|
1245
|
+
static const tensor_traits<block_q4_K, 32, 32> q4_k_32x32_q8_0;
|
|
1246
|
+
static const tensor_traits<block_q6_K, 32, 32> q6_k_32x32_q8_0;
|
|
1247
|
+
static const tensor_traits<block_q8_0, 32, 32> q8_0_32x32_q8_0;
|
|
1248
|
+
static const tensor_traits<block_mxfp4, 32, 32> mxfp4_32x32_q8_0;
|
|
1249
|
+
static const tensor_traits<block_q5_K, 32, 32> q5_k_32x32_q8_0;
|
|
1250
|
+
static const tensor_traits<block_q5_1, 32, 32> q5_1_32x32_q8_0;
|
|
1251
|
+
static const tensor_traits<block_q5_0, 32, 32> q5_0_32x32_q8_0;
|
|
1252
|
+
// Impl By RVV
|
|
1253
|
+
static const tensor_traits_common rvv_impl;
|
|
763
1254
|
|
|
764
|
-
|
|
1255
|
+
} // namespace ggml::cpu::riscv64_spacemit
|
|
765
1256
|
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
1257
|
+
static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const ggml_tensor * cur) {
|
|
1258
|
+
switch (cur->type) {
|
|
1259
|
+
case GGML_TYPE_Q2_K:
|
|
1260
|
+
{
|
|
1261
|
+
#if defined(RISCV64_SPACEMIT_IME2)
|
|
1262
|
+
if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
|
|
1263
|
+
return &ggml::cpu::riscv64_spacemit::q2_k_32x256_q8_0;
|
|
1264
|
+
}
|
|
1265
|
+
#endif
|
|
769
1266
|
}
|
|
1267
|
+
break;
|
|
1268
|
+
case GGML_TYPE_Q3_K:
|
|
1269
|
+
{
|
|
1270
|
+
#if defined(RISCV64_SPACEMIT_IME2)
|
|
1271
|
+
if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
|
|
1272
|
+
return &ggml::cpu::riscv64_spacemit::q3_k_32x256_q8_0;
|
|
1273
|
+
}
|
|
1274
|
+
#endif
|
|
1275
|
+
}
|
|
1276
|
+
break;
|
|
1277
|
+
case GGML_TYPE_Q4_0:
|
|
1278
|
+
{
|
|
1279
|
+
#if defined(RISCV64_SPACEMIT_IME2)
|
|
1280
|
+
if (cur->ne[1] % 32 == 0 && cur->ne[0] % 256 == 0 &&
|
|
1281
|
+
(ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
|
|
1282
|
+
return &ggml::cpu::riscv64_spacemit::q4_0_32x256_q8_0;
|
|
1283
|
+
}
|
|
770
1284
|
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0),
|
|
777
|
-
__riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl);
|
|
778
|
-
mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl);
|
|
779
|
-
mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl);
|
|
780
|
-
mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl);
|
|
781
|
-
|
|
782
|
-
float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v);
|
|
783
|
-
mean_square /= hidden_size;
|
|
1285
|
+
if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
|
|
1286
|
+
return &ggml::cpu::riscv64_spacemit::q4_0_32x32_q8_0;
|
|
1287
|
+
}
|
|
1288
|
+
#endif
|
|
784
1289
|
|
|
785
|
-
|
|
1290
|
+
#if defined(RISCV64_SPACEMIT_IME1)
|
|
1291
|
+
if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) {
|
|
1292
|
+
return &ggml::cpu::riscv64_spacemit::q4_0_16x32_q8_0;
|
|
1293
|
+
}
|
|
1294
|
+
#endif
|
|
1295
|
+
}
|
|
1296
|
+
break;
|
|
1297
|
+
case GGML_TYPE_Q4_1:
|
|
1298
|
+
{
|
|
1299
|
+
#if defined(RISCV64_SPACEMIT_IME2)
|
|
1300
|
+
// TODO
|
|
1301
|
+
// if (cur->ne[1] % 32 == 0 && cur->ne[0] % 256 == 0 &&
|
|
1302
|
+
// (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
|
|
1303
|
+
// return &ggml::cpu::riscv64_spacemit::q4_1_32x256_q8_0;
|
|
1304
|
+
// }
|
|
1305
|
+
|
|
1306
|
+
if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
|
|
1307
|
+
return &ggml::cpu::riscv64_spacemit::q4_1_32x32_q8_0;
|
|
1308
|
+
}
|
|
1309
|
+
#endif
|
|
786
1310
|
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
1311
|
+
#if defined(RISCV64_SPACEMIT_IME1)
|
|
1312
|
+
if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) {
|
|
1313
|
+
return &ggml::cpu::riscv64_spacemit::q4_1_16x32_q8_0;
|
|
1314
|
+
}
|
|
1315
|
+
#endif
|
|
1316
|
+
}
|
|
1317
|
+
break;
|
|
1318
|
+
case GGML_TYPE_Q4_K:
|
|
1319
|
+
{
|
|
1320
|
+
#if defined(RISCV64_SPACEMIT_IME2)
|
|
1321
|
+
if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
|
|
1322
|
+
return &ggml::cpu::riscv64_spacemit::q4_k_32x32_q8_0;
|
|
1323
|
+
}
|
|
1324
|
+
#endif
|
|
790
1325
|
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
|
|
795
|
-
src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
|
|
796
|
-
__riscv_vse32_v_f32m4(p_output, src_data, gvl);
|
|
797
|
-
p_temp_output += gvl;
|
|
798
|
-
p_output += gvl;
|
|
799
|
-
length -= gvl;
|
|
1326
|
+
#if defined(RISCV64_SPACEMIT_IME1)
|
|
1327
|
+
if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) {
|
|
1328
|
+
return &ggml::cpu::riscv64_spacemit::q4_k_16x32_q8_0;
|
|
800
1329
|
}
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
p_temp_output += gvl;
|
|
810
|
-
p_output += gvl;
|
|
811
|
-
p_gamma_data += gvl;
|
|
812
|
-
length -= gvl;
|
|
1330
|
+
#endif
|
|
1331
|
+
}
|
|
1332
|
+
break;
|
|
1333
|
+
case GGML_TYPE_Q6_K:
|
|
1334
|
+
{
|
|
1335
|
+
#if defined(RISCV64_SPACEMIT_IME2)
|
|
1336
|
+
if ((ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
|
|
1337
|
+
return &ggml::cpu::riscv64_spacemit::q6_k_32x32_q8_0;
|
|
813
1338
|
}
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl);
|
|
823
|
-
p_beta_data += gvl;
|
|
824
|
-
__riscv_vse32_v_f32m4(p_output, src_data, gvl);
|
|
825
|
-
p_temp_output += gvl;
|
|
826
|
-
p_output += gvl;
|
|
827
|
-
p_gamma_data += gvl;
|
|
828
|
-
length -= gvl;
|
|
1339
|
+
#endif
|
|
1340
|
+
}
|
|
1341
|
+
break;
|
|
1342
|
+
case GGML_TYPE_Q8_0:
|
|
1343
|
+
{
|
|
1344
|
+
#if defined(RISCV64_SPACEMIT_IME2)
|
|
1345
|
+
if ((ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
|
|
1346
|
+
return &ggml::cpu::riscv64_spacemit::q8_0_32x32_q8_0;
|
|
829
1347
|
}
|
|
1348
|
+
#endif
|
|
830
1349
|
}
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
}
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
1350
|
+
break;
|
|
1351
|
+
case GGML_TYPE_MXFP4:
|
|
1352
|
+
{
|
|
1353
|
+
#if defined(RISCV64_SPACEMIT_IME2)
|
|
1354
|
+
// TODO
|
|
1355
|
+
// if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
|
|
1356
|
+
// return &ggml::cpu::riscv64_spacemit::mxfp4_32x32_q8_0;
|
|
1357
|
+
// }
|
|
1358
|
+
#endif
|
|
1359
|
+
}
|
|
1360
|
+
break;
|
|
1361
|
+
case GGML_TYPE_Q5_K:
|
|
1362
|
+
{
|
|
1363
|
+
#if defined(RISCV64_SPACEMIT_IME2)
|
|
1364
|
+
if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
|
|
1365
|
+
return &ggml::cpu::riscv64_spacemit::q5_k_32x32_q8_0;
|
|
1366
|
+
}
|
|
1367
|
+
#endif
|
|
1368
|
+
}
|
|
1369
|
+
break;
|
|
1370
|
+
case GGML_TYPE_Q5_1:
|
|
1371
|
+
{
|
|
1372
|
+
#if defined(RISCV64_SPACEMIT_IME2)
|
|
1373
|
+
if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
|
|
1374
|
+
return &ggml::cpu::riscv64_spacemit::q5_1_32x32_q8_0;
|
|
1375
|
+
}
|
|
1376
|
+
#endif
|
|
1377
|
+
}
|
|
1378
|
+
break;
|
|
1379
|
+
case GGML_TYPE_Q5_0:
|
|
1380
|
+
{
|
|
1381
|
+
#if defined(RISCV64_SPACEMIT_IME2)
|
|
1382
|
+
if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
|
|
1383
|
+
return &ggml::cpu::riscv64_spacemit::q5_0_32x32_q8_0;
|
|
1384
|
+
}
|
|
1385
|
+
#endif
|
|
1386
|
+
}
|
|
1387
|
+
break;
|
|
1388
|
+
default:
|
|
1389
|
+
break;
|
|
862
1390
|
}
|
|
863
1391
|
|
|
864
1392
|
return nullptr;
|
|
865
1393
|
}
|
|
866
1394
|
|
|
867
1395
|
static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer,
|
|
868
|
-
|
|
1396
|
+
ggml_tensor * tensor) {
|
|
869
1397
|
tensor->extra =
|
|
870
1398
|
(void *) const_cast<ggml::cpu::tensor_traits *>(ggml_riscv64_spacemit_get_optimal_repack_type(tensor));
|
|
871
1399
|
|
|
@@ -874,8 +1402,46 @@ static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_ba
|
|
|
874
1402
|
return GGML_STATUS_SUCCESS;
|
|
875
1403
|
}
|
|
876
1404
|
|
|
1405
|
+
static void ggml_backend_riscv64_spacemit_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
|
1406
|
+
GGML_ASSERT(buffer);
|
|
1407
|
+
|
|
1408
|
+
void * base = buffer->context;
|
|
1409
|
+
if (base == nullptr) {
|
|
1410
|
+
return;
|
|
1411
|
+
}
|
|
1412
|
+
|
|
1413
|
+
ggml::cpu::riscv64_spacemit::spine_mem_pool_free(base);
|
|
1414
|
+
}
|
|
1415
|
+
|
|
1416
|
+
static void * ggml_backend_riscv64_spacemit_buffer_get_base(ggml_backend_buffer_t buffer) {
|
|
1417
|
+
GGML_ASSERT(buffer);
|
|
1418
|
+
|
|
1419
|
+
void * base = buffer->context;
|
|
1420
|
+
GGML_ASSERT(base != nullptr);
|
|
1421
|
+
return base;
|
|
1422
|
+
}
|
|
1423
|
+
|
|
1424
|
+
static void ggml_backend_riscv64_spacemit_buffer_memset_tensor(ggml_backend_buffer_t buffer,
|
|
1425
|
+
ggml_tensor * tensor,
|
|
1426
|
+
uint8_t value,
|
|
1427
|
+
size_t offset,
|
|
1428
|
+
size_t size) {
|
|
1429
|
+
GGML_ASSERT(tensor);
|
|
1430
|
+
memset((char *) tensor->data + offset, value, size);
|
|
1431
|
+
|
|
1432
|
+
GGML_UNUSED(buffer);
|
|
1433
|
+
}
|
|
1434
|
+
|
|
1435
|
+
static void ggml_backend_riscv64_spacemit_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
|
1436
|
+
GGML_ASSERT(buffer);
|
|
1437
|
+
|
|
1438
|
+
void * base = buffer->context;
|
|
1439
|
+
GGML_ASSERT(base != nullptr);
|
|
1440
|
+
memset(base, value, buffer->size);
|
|
1441
|
+
}
|
|
1442
|
+
|
|
877
1443
|
static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
|
878
|
-
|
|
1444
|
+
ggml_tensor * tensor,
|
|
879
1445
|
const void * data,
|
|
880
1446
|
size_t offset,
|
|
881
1447
|
size_t size) {
|
|
@@ -891,6 +1457,20 @@ static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_
|
|
|
891
1457
|
GGML_UNUSED(buffer);
|
|
892
1458
|
}
|
|
893
1459
|
|
|
1460
|
+
static const ggml_backend_buffer_i ggml_backend_riscv64_spacemit_buffer_i = {
|
|
1461
|
+
/* .free_buffer = */ ggml_backend_riscv64_spacemit_buffer_free_buffer,
|
|
1462
|
+
/* .get_base = */ ggml_backend_riscv64_spacemit_buffer_get_base,
|
|
1463
|
+
/* .init_tensor = */ ggml_backend_riscv64_spacemit_buffer_init_tensor,
|
|
1464
|
+
/* .memset_tensor = */ ggml_backend_riscv64_spacemit_buffer_memset_tensor,
|
|
1465
|
+
/* .set_tensor = */ ggml_backend_riscv64_spacemit_buffer_set_tensor,
|
|
1466
|
+
/* .get_tensor = */ nullptr,
|
|
1467
|
+
/* .set_tensor_2d = */ nullptr,
|
|
1468
|
+
/* .get_tensor_2d = */ nullptr,
|
|
1469
|
+
/* .cpy_tensor = */ nullptr,
|
|
1470
|
+
/* .clear = */ ggml_backend_riscv64_spacemit_buffer_clear,
|
|
1471
|
+
/* .reset = */ nullptr,
|
|
1472
|
+
};
|
|
1473
|
+
|
|
894
1474
|
static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
|
|
895
1475
|
return "CPU_RISCV64_SPACEMIT";
|
|
896
1476
|
|
|
@@ -899,18 +1479,12 @@ static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_
|
|
|
899
1479
|
|
|
900
1480
|
static ggml_backend_buffer_t ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
|
|
901
1481
|
size_t size) {
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
if (buffer == nullptr) {
|
|
1482
|
+
void * base = ggml::cpu::riscv64_spacemit::spine_mem_pool_alloc(size, 64);
|
|
1483
|
+
if (base == nullptr) {
|
|
905
1484
|
return nullptr;
|
|
906
1485
|
}
|
|
907
1486
|
|
|
908
|
-
|
|
909
|
-
buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor;
|
|
910
|
-
buffer->iface.set_tensor = ggml_backend_riscv64_spacemit_buffer_set_tensor;
|
|
911
|
-
buffer->iface.get_tensor = nullptr;
|
|
912
|
-
buffer->iface.cpy_tensor = nullptr;
|
|
913
|
-
return buffer;
|
|
1487
|
+
return ggml_backend_buffer_init(buft, ggml_backend_riscv64_spacemit_buffer_i, base, size);
|
|
914
1488
|
}
|
|
915
1489
|
|
|
916
1490
|
static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
|
@@ -919,44 +1493,91 @@ static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_b
|
|
|
919
1493
|
GGML_UNUSED(buft);
|
|
920
1494
|
}
|
|
921
1495
|
|
|
922
|
-
static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft,
|
|
923
|
-
const struct ggml_tensor * tensor) {
|
|
1496
|
+
static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
|
924
1497
|
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
|
|
925
1498
|
if (tensor->ne[i] <= 0) {
|
|
926
1499
|
return 0;
|
|
927
1500
|
}
|
|
928
1501
|
}
|
|
929
1502
|
|
|
930
|
-
|
|
1503
|
+
GGML_UNUSED(buft);
|
|
1504
|
+
|
|
1505
|
+
const auto plain_nbytes = [&]() {
|
|
1506
|
+
size_t total = ggml_type_size(tensor->type);
|
|
1507
|
+
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
|
|
1508
|
+
total += (tensor->ne[i] - 1) * tensor->nb[i];
|
|
1509
|
+
}
|
|
1510
|
+
return total;
|
|
1511
|
+
};
|
|
1512
|
+
|
|
931
1513
|
const size_t blck_size = ggml_blck_size(tensor->type);
|
|
932
1514
|
if (blck_size == 1) {
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
1515
|
+
return plain_nbytes();
|
|
1516
|
+
}
|
|
1517
|
+
|
|
1518
|
+
const size_t row_nbytes = tensor->ne[0] * tensor->nb[0] / blck_size;
|
|
1519
|
+
|
|
1520
|
+
const auto add_strided_nbytes = [&](size_t total, size_t src_block_size, size_t dst_block_size) {
|
|
1521
|
+
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
|
|
1522
|
+
total += (tensor->ne[i] - 1) * (tensor->nb[i] / src_block_size) * dst_block_size;
|
|
936
1523
|
}
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
}
|
|
1524
|
+
return total;
|
|
1525
|
+
};
|
|
1526
|
+
|
|
1527
|
+
const auto remap_block_nbytes = [&](size_t src_block_size, size_t dst_block_size, int64_t padded_rows = 0) {
|
|
1528
|
+
GGML_ASSERT(row_nbytes % src_block_size == 0);
|
|
1529
|
+
|
|
1530
|
+
size_t total =
|
|
1531
|
+
add_strided_nbytes((row_nbytes / src_block_size) * dst_block_size, src_block_size, dst_block_size);
|
|
1532
|
+
|
|
1533
|
+
if (padded_rows > 0 && tensor->ne[1] % padded_rows != 0) {
|
|
1534
|
+
total += (padded_rows - tensor->ne[1] % padded_rows) * (tensor->nb[1] / src_block_size) * dst_block_size;
|
|
949
1535
|
}
|
|
1536
|
+
|
|
1537
|
+
return total;
|
|
1538
|
+
};
|
|
1539
|
+
|
|
1540
|
+
size_t nbytes = row_nbytes;
|
|
1541
|
+
switch (tensor->type) {
|
|
1542
|
+
case GGML_TYPE_Q4_K:
|
|
1543
|
+
nbytes = remap_block_nbytes(sizeof(block_q4_K), sizeof(block_q4_1) * 8);
|
|
1544
|
+
break;
|
|
1545
|
+
case GGML_TYPE_Q6_K:
|
|
1546
|
+
nbytes = remap_block_nbytes(sizeof(block_q6_K), sizeof(block_q8_0) * 8, 32);
|
|
1547
|
+
break;
|
|
1548
|
+
case GGML_TYPE_Q8_0:
|
|
1549
|
+
nbytes = remap_block_nbytes(sizeof(block_q8_0), sizeof(block_q8_0), 32);
|
|
1550
|
+
break;
|
|
1551
|
+
case GGML_TYPE_Q2_K:
|
|
1552
|
+
nbytes = remap_block_nbytes(sizeof(block_q2_K), sizeof(spacemit_kernels::nrow_block_q2_k<1>));
|
|
1553
|
+
break;
|
|
1554
|
+
case GGML_TYPE_Q3_K:
|
|
1555
|
+
nbytes = remap_block_nbytes(sizeof(block_q3_K), sizeof(spacemit_kernels::nrow_block_q3_k<1>));
|
|
1556
|
+
break;
|
|
1557
|
+
case GGML_TYPE_MXFP4:
|
|
1558
|
+
nbytes = remap_block_nbytes(sizeof(block_mxfp4), sizeof(spacemit_kernels::nrow_block_mxfp4<1>));
|
|
1559
|
+
break;
|
|
1560
|
+
case GGML_TYPE_Q5_K:
|
|
1561
|
+
nbytes = remap_block_nbytes(sizeof(block_q5_K), sizeof(spacemit_kernels::nrow_block_q5_1<1>) * 8);
|
|
1562
|
+
break;
|
|
1563
|
+
case GGML_TYPE_Q5_1:
|
|
1564
|
+
nbytes = remap_block_nbytes(sizeof(block_q5_1), sizeof(spacemit_kernels::nrow_block_q5_1<1>));
|
|
1565
|
+
break;
|
|
1566
|
+
case GGML_TYPE_Q5_0:
|
|
1567
|
+
nbytes = remap_block_nbytes(sizeof(block_q5_0), sizeof(spacemit_kernels::nrow_block_q5_0<1>));
|
|
1568
|
+
break;
|
|
1569
|
+
default:
|
|
1570
|
+
nbytes = add_strided_nbytes(row_nbytes, 1, 1);
|
|
1571
|
+
break;
|
|
950
1572
|
}
|
|
951
1573
|
|
|
952
|
-
GGML_UNUSED(buft);
|
|
953
1574
|
return nbytes;
|
|
954
1575
|
}
|
|
955
1576
|
|
|
956
1577
|
namespace ggml::cpu::riscv64_spacemit {
|
|
957
1578
|
|
|
958
1579
|
class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|
959
|
-
bool supports_op(ggml_backend_dev_t, const
|
|
1580
|
+
bool supports_op(ggml_backend_dev_t, const ggml_tensor * op) override {
|
|
960
1581
|
switch (op->op) {
|
|
961
1582
|
case GGML_OP_MUL_MAT:
|
|
962
1583
|
if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) &&
|
|
@@ -970,10 +1591,16 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|
|
970
1591
|
}
|
|
971
1592
|
}
|
|
972
1593
|
break;
|
|
973
|
-
case
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
|
|
1594
|
+
case GGML_OP_MUL_MAT_ID:
|
|
1595
|
+
if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 3) &&
|
|
1596
|
+
op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() &&
|
|
1597
|
+
ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) {
|
|
1598
|
+
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
|
1599
|
+
return false;
|
|
1600
|
+
}
|
|
1601
|
+
if (op->src[1]->type == GGML_TYPE_F32) {
|
|
1602
|
+
return true;
|
|
1603
|
+
}
|
|
977
1604
|
}
|
|
978
1605
|
break;
|
|
979
1606
|
default:
|
|
@@ -983,15 +1610,28 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|
|
983
1610
|
return false;
|
|
984
1611
|
}
|
|
985
1612
|
|
|
986
|
-
ggml::cpu::tensor_traits * get_tensor_traits(const
|
|
1613
|
+
ggml::cpu::tensor_traits * get_tensor_traits(const ggml_tensor * op) override {
|
|
987
1614
|
switch (op->op) {
|
|
988
1615
|
case GGML_OP_MUL_MAT:
|
|
1616
|
+
case GGML_OP_MUL_MAT_ID:
|
|
989
1617
|
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) {
|
|
990
1618
|
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
|
991
1619
|
}
|
|
992
1620
|
break;
|
|
993
1621
|
case GGML_OP_NORM:
|
|
994
1622
|
case GGML_OP_RMS_NORM:
|
|
1623
|
+
case GGML_OP_ADD:
|
|
1624
|
+
case GGML_OP_SUB:
|
|
1625
|
+
case GGML_OP_MUL:
|
|
1626
|
+
case GGML_OP_DIV:
|
|
1627
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
|
1628
|
+
case GGML_OP_CONT:
|
|
1629
|
+
case GGML_OP_CPY:
|
|
1630
|
+
case GGML_OP_REPEAT:
|
|
1631
|
+
case GGML_OP_SUM_ROWS:
|
|
1632
|
+
case GGML_OP_GET_ROWS:
|
|
1633
|
+
case GGML_OP_CONCAT:
|
|
1634
|
+
// case GGML_OP_GATED_DELTA_NET:
|
|
995
1635
|
return (ggml::cpu::tensor_traits *) (&ggml::cpu::riscv64_spacemit::rvv_impl);
|
|
996
1636
|
default:
|
|
997
1637
|
// GGML_ABORT("fatal error");
|
|
@@ -1005,7 +1645,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|
|
1005
1645
|
} // namespace ggml::cpu::riscv64_spacemit
|
|
1006
1646
|
|
|
1007
1647
|
ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) {
|
|
1008
|
-
static
|
|
1648
|
+
static ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = {
|
|
1009
1649
|
/* .iface = */
|
|
1010
1650
|
{
|
|
1011
1651
|
/* .get_name = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name,
|
|
@@ -1023,3 +1663,78 @@ ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) {
|
|
|
1023
1663
|
|
|
1024
1664
|
return &ggml_backend_cpu_buffer_type_riscv64_spacemit;
|
|
1025
1665
|
}
|
|
1666
|
+
|
|
1667
|
+
extern "C" {
|
|
1668
|
+
static int bind_ai_thread() {
|
|
1669
|
+
int fd, bytes;
|
|
1670
|
+
char str[32];
|
|
1671
|
+
|
|
1672
|
+
fd = open("/proc/set_ai_thread", O_WRONLY);
|
|
1673
|
+
if (fd < 0) {
|
|
1674
|
+
GGML_LOG_ERROR("try open /proc/set_ai_thread failed\n");
|
|
1675
|
+
return -1;
|
|
1676
|
+
}
|
|
1677
|
+
|
|
1678
|
+
snprintf(str, 16, "%d", 0);
|
|
1679
|
+
bytes = write(fd, str, strlen(str));
|
|
1680
|
+
if (bytes < 0) {
|
|
1681
|
+
GGML_LOG_ERROR("try write /proc/set_ai_thread failed\n");
|
|
1682
|
+
close(fd);
|
|
1683
|
+
return -1;
|
|
1684
|
+
}
|
|
1685
|
+
|
|
1686
|
+
close(fd);
|
|
1687
|
+
return 0;
|
|
1688
|
+
}
|
|
1689
|
+
|
|
1690
|
+
void ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(int thread_n) {
|
|
1691
|
+
int cpu_id = sched_getcpu();
|
|
1692
|
+
if (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2 &&
|
|
1693
|
+
!((1 << cpu_id) & ggml::cpu::riscv64_spacemit::global_spine_env_info.cpu_mask)) {
|
|
1694
|
+
GGML_PRINT_DEBUG("bind_ai_thread for thread %d, pid %d\n", thread_n, getpid());
|
|
1695
|
+
bind_ai_thread();
|
|
1696
|
+
}
|
|
1697
|
+
|
|
1698
|
+
if (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_tcm &&
|
|
1699
|
+
ggml::cpu::riscv64_spacemit::tls_context.cpu_id == -1) {
|
|
1700
|
+
CPU_ZERO(&(ggml::cpu::riscv64_spacemit::tls_context.cpuset));
|
|
1701
|
+
pthread_t main_thread = pthread_self();
|
|
1702
|
+
const auto & perfer_core_ids = ggml::cpu::riscv64_spacemit::global_spine_env_info.perfer_core_ids;
|
|
1703
|
+
if (thread_n < 0 || static_cast<size_t>(thread_n) >= perfer_core_ids.size()) {
|
|
1704
|
+
GGML_ABORT("thread_n %d exceeds perfer_core_ids size %zu\n", thread_n, perfer_core_ids.size());
|
|
1705
|
+
}
|
|
1706
|
+
auto perfer_cpu_id = perfer_core_ids[static_cast<size_t>(thread_n)];
|
|
1707
|
+
CPU_SET(perfer_cpu_id, &(ggml::cpu::riscv64_spacemit::tls_context.cpuset));
|
|
1708
|
+
int s =
|
|
1709
|
+
pthread_setaffinity_np(main_thread, sizeof(cpu_set_t), &(ggml::cpu::riscv64_spacemit::tls_context.cpuset));
|
|
1710
|
+
if (s != 0) {
|
|
1711
|
+
GGML_ABORT("set thread affinity error for thread_n %d, cpu_id %d\n", thread_n, perfer_cpu_id);
|
|
1712
|
+
}
|
|
1713
|
+
|
|
1714
|
+
int ai_cpu_id = perfer_cpu_id - ggml::cpu::riscv64_spacemit::global_spine_env_info.aicpu_id_offset;
|
|
1715
|
+
ggml::cpu::riscv64_spacemit::tls_context.cpu_id = ai_cpu_id;
|
|
1716
|
+
ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer =
|
|
1717
|
+
ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_get(ai_cpu_id);
|
|
1718
|
+
ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size =
|
|
1719
|
+
ggml::cpu::riscv64_spacemit::global_spine_env_info.tcm_blk_size;
|
|
1720
|
+
}
|
|
1721
|
+
|
|
1722
|
+
if (ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer != nullptr) {
|
|
1723
|
+
void * rt =
|
|
1724
|
+
ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_wait(ggml::cpu::riscv64_spacemit::tls_context.cpu_id);
|
|
1725
|
+
if (rt == nullptr) {
|
|
1726
|
+
GGML_ABORT("wait tcm buffer failed for cpu_id: %d", ggml::cpu::riscv64_spacemit::tls_context.cpu_id);
|
|
1727
|
+
}
|
|
1728
|
+
}
|
|
1729
|
+
}
|
|
1730
|
+
|
|
1731
|
+
void ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(int thread_n) {
|
|
1732
|
+
if (ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer != nullptr) {
|
|
1733
|
+
auto rt = ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_release(
|
|
1734
|
+
ggml::cpu::riscv64_spacemit::tls_context.cpu_id);
|
|
1735
|
+
if (rt != 0) {
|
|
1736
|
+
GGML_ABORT("release tcm buffer failed for cpu_id: %d", ggml::cpu::riscv64_spacemit::tls_context.cpu_id);
|
|
1737
|
+
}
|
|
1738
|
+
}
|
|
1739
|
+
}
|
|
1740
|
+
}
|