whispercpp 1.3.5 → 1.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/LICENSE +1 -1
- data/README.md +133 -3
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -7
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +56 -46
- data/ext/ruby_whisper.h +165 -2
- data/ext/ruby_whisper_context.c +297 -126
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -66
- data/ext/ruby_whisper_segment.c +6 -7
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +46 -16
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +24 -19
- data/ext/sources/examples/cli/cli.cpp +51 -9
- data/ext/sources/examples/common-ggml.cpp +4 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +213 -163
- data/ext/sources/ggml/CMakeLists.txt +29 -15
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +73 -11
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +8 -3
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +155 -16
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +25 -5
- data/ext/sources/ggml/src/ggml-alloc.c +9 -10
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
- data/ext/sources/ggml/src/ggml-common.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
- data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
- data/ext/sources/ggml/src/ggml-impl.h +68 -1
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +385 -119
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
- data/ext/sources/ggml/src/ggml.c +268 -52
- data/ext/sources/ggml/src/gguf.cpp +377 -47
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +62 -40
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +445 -55
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_context_params.rb +82 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +44 -6
- data/whispercpp.gemspec +2 -2
- metadata +426 -280
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
- data/ext/sources/examples/talk-llama/llama-context.h +0 -360
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
- data/ext/sources/examples/talk-llama/llama-model.h +0 -544
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
- data/ext/sources/examples/talk-llama/llama.h +0 -1540
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -569
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
|
@@ -121,7 +121,8 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
|
|
|
121
121
|
#endif
|
|
122
122
|
|
|
123
123
|
#if defined(__MMA__)
|
|
124
|
-
|
|
124
|
+
typedef vector unsigned char vec_t;
|
|
125
|
+
typedef __vector_quad acc_t;
|
|
125
126
|
#endif
|
|
126
127
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
127
128
|
// VECTORIZED FUSED MULTIPLY ADD
|
|
@@ -179,44 +180,49 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
|
|
|
179
180
|
}
|
|
180
181
|
#endif
|
|
181
182
|
|
|
183
|
+
#if defined(__riscv_v_intrinsic)
|
|
184
|
+
template <> inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
|
|
185
|
+
return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
|
|
186
|
+
}
|
|
187
|
+
template <> inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
|
|
188
|
+
return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
|
|
189
|
+
}
|
|
190
|
+
template <> inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
|
|
191
|
+
return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
|
|
192
|
+
}
|
|
193
|
+
template <> inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
|
|
194
|
+
return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
|
|
195
|
+
}
|
|
196
|
+
#endif
|
|
197
|
+
|
|
182
198
|
#if defined(__riscv_zvfh)
|
|
183
|
-
template <>
|
|
184
|
-
inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
|
|
199
|
+
template <> inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
|
|
185
200
|
return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
|
|
186
201
|
}
|
|
187
|
-
inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
|
|
202
|
+
template <> inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
|
|
188
203
|
return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
|
|
189
204
|
}
|
|
190
|
-
inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
|
|
205
|
+
template <> inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
|
|
191
206
|
return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
|
|
192
207
|
}
|
|
193
|
-
inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
|
|
208
|
+
template <> inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
|
|
194
209
|
return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
|
|
195
210
|
}
|
|
196
|
-
inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
|
|
197
|
-
return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
|
|
198
|
-
}
|
|
199
|
-
inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
|
|
200
|
-
return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
|
|
201
|
-
}
|
|
202
|
-
inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
|
|
203
|
-
return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
|
|
204
|
-
}
|
|
205
|
-
inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
|
|
206
|
-
return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
|
|
207
|
-
}
|
|
208
211
|
#endif
|
|
209
212
|
|
|
210
213
|
#if defined(__riscv_zvfbfwma)
|
|
211
|
-
inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
|
|
214
|
+
template <> inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
|
|
212
215
|
return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
|
|
213
216
|
}
|
|
214
|
-
inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
|
|
217
|
+
template <> inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
|
|
215
218
|
return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
|
|
216
219
|
}
|
|
217
|
-
inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
|
|
220
|
+
template <> inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
|
|
218
221
|
return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
|
|
219
222
|
}
|
|
223
|
+
template <> inline vfloat32m8_t madd(vbfloat16m4_t a, vbfloat16m4_t b, vfloat32m8_t c) {
|
|
224
|
+
return __riscv_vfwmaccbf16_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
|
|
225
|
+
}
|
|
220
226
|
#endif
|
|
221
227
|
|
|
222
228
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
@@ -271,7 +277,7 @@ inline float hsum(__m512 x) {
|
|
|
271
277
|
}
|
|
272
278
|
#endif // __AVX512F__
|
|
273
279
|
|
|
274
|
-
#if defined(
|
|
280
|
+
#if defined(__riscv_v_intrinsic)
|
|
275
281
|
inline float hsum(vfloat32m1_t x) {
|
|
276
282
|
return __riscv_vfmv_f_s_f32m1_f32(
|
|
277
283
|
__riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1()));
|
|
@@ -378,6 +384,21 @@ template <> inline __m256bh load(const float *p) {
|
|
|
378
384
|
}
|
|
379
385
|
#endif
|
|
380
386
|
|
|
387
|
+
#if defined(__riscv_v_intrinsic)
|
|
388
|
+
template <> inline vfloat32m1_t load(const float *p) {
|
|
389
|
+
return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1());
|
|
390
|
+
}
|
|
391
|
+
template <> inline vfloat32m2_t load(const float *p) {
|
|
392
|
+
return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2());
|
|
393
|
+
}
|
|
394
|
+
template <> inline vfloat32m4_t load(const float *p) {
|
|
395
|
+
return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4());
|
|
396
|
+
}
|
|
397
|
+
template <> inline vfloat32m8_t load(const float *p) {
|
|
398
|
+
return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8());
|
|
399
|
+
}
|
|
400
|
+
#endif
|
|
401
|
+
|
|
381
402
|
#if defined(__riscv_zvfh)
|
|
382
403
|
template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
|
|
383
404
|
return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2());
|
|
@@ -391,18 +412,6 @@ template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
|
|
|
391
412
|
template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
|
|
392
413
|
return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());
|
|
393
414
|
}
|
|
394
|
-
template <> inline vfloat32m1_t load(const float *p) {
|
|
395
|
-
return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1());
|
|
396
|
-
}
|
|
397
|
-
template <> inline vfloat32m2_t load(const float *p) {
|
|
398
|
-
return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2());
|
|
399
|
-
}
|
|
400
|
-
template <> inline vfloat32m4_t load(const float *p) {
|
|
401
|
-
return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4());
|
|
402
|
-
}
|
|
403
|
-
template <> inline vfloat32m8_t load(const float *p) {
|
|
404
|
-
return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8());
|
|
405
|
-
}
|
|
406
415
|
#endif
|
|
407
416
|
|
|
408
417
|
#if defined(__riscv_zvfbfwma)
|
|
@@ -415,23 +424,14 @@ template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) {
|
|
|
415
424
|
template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) {
|
|
416
425
|
return __riscv_vle16_v_bf16m2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m2());
|
|
417
426
|
}
|
|
427
|
+
template <> inline vbfloat16m4_t load(const ggml_bf16_t *p) {
|
|
428
|
+
return __riscv_vle16_v_bf16m4(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m4());
|
|
429
|
+
}
|
|
418
430
|
#endif
|
|
419
431
|
|
|
420
|
-
#if defined(
|
|
432
|
+
#if defined(__riscv_v_intrinsic)
|
|
421
433
|
template <typename T> T set_zero();
|
|
422
434
|
|
|
423
|
-
template <> inline vfloat16mf2_t set_zero() {
|
|
424
|
-
return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2());
|
|
425
|
-
}
|
|
426
|
-
template <> inline vfloat16m1_t set_zero() {
|
|
427
|
-
return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1());
|
|
428
|
-
}
|
|
429
|
-
template <> inline vfloat16m2_t set_zero() {
|
|
430
|
-
return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2());
|
|
431
|
-
}
|
|
432
|
-
template <> inline vfloat16m4_t set_zero() {
|
|
433
|
-
return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4());
|
|
434
|
-
}
|
|
435
435
|
template <> inline vfloat32m1_t set_zero() {
|
|
436
436
|
return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1());
|
|
437
437
|
}
|
|
@@ -448,14 +448,22 @@ template <> inline vfloat32m8_t set_zero() {
|
|
|
448
448
|
|
|
449
449
|
#if defined(__riscv_v_intrinsic)
|
|
450
450
|
template <typename T> size_t vlmax() {
|
|
451
|
-
if constexpr (std::is_same_v<T,
|
|
452
|
-
else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
|
|
453
|
-
else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
|
|
454
|
-
else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
|
|
455
|
-
else if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
|
|
451
|
+
if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
|
|
456
452
|
else if constexpr (std::is_same_v<T, vfloat32m2_t>) { return __riscv_vsetvlmax_e32m2(); }
|
|
457
453
|
else if constexpr (std::is_same_v<T, vfloat32m4_t>) { return __riscv_vsetvlmax_e32m4(); }
|
|
458
454
|
else if constexpr (std::is_same_v<T, vfloat32m8_t>) { return __riscv_vsetvlmax_e32m8(); }
|
|
455
|
+
#if defined (__riscv_zvfh)
|
|
456
|
+
else if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
|
|
457
|
+
else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
|
|
458
|
+
else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
|
|
459
|
+
else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
|
|
460
|
+
#endif
|
|
461
|
+
#if defined (__riscv_zvfbfwma)
|
|
462
|
+
else if constexpr (std::is_same_v<T, vbfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
|
|
463
|
+
else if constexpr (std::is_same_v<T, vbfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
|
|
464
|
+
else if constexpr (std::is_same_v<T, vbfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
|
|
465
|
+
else if constexpr (std::is_same_v<T, vbfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
|
|
466
|
+
#endif
|
|
459
467
|
return 0;
|
|
460
468
|
}
|
|
461
469
|
#endif
|
|
@@ -532,7 +540,7 @@ class tinyBLAS {
|
|
|
532
540
|
if constexpr (RN > 1) {
|
|
533
541
|
return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
|
|
534
542
|
} else {
|
|
535
|
-
GGML_LOG_ERROR("mnpack<%d, %d>
|
|
543
|
+
GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N);
|
|
536
544
|
GGML_ASSERT(false); // we have miss something.
|
|
537
545
|
}
|
|
538
546
|
}
|
|
@@ -710,7 +718,7 @@ class tinyBLAS_RVV {
|
|
|
710
718
|
if constexpr (RN > 1) {
|
|
711
719
|
return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
|
|
712
720
|
} else {
|
|
713
|
-
GGML_LOG_ERROR("mnpack<%d, %d>
|
|
721
|
+
GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N);
|
|
714
722
|
GGML_ASSERT(false); // we have miss something.
|
|
715
723
|
}
|
|
716
724
|
}
|
|
@@ -1797,10 +1805,27 @@ class tinyBLAS_Q0_AVX {
|
|
|
1797
1805
|
} \
|
|
1798
1806
|
} \
|
|
1799
1807
|
|
|
1808
|
+
template<typename T>
|
|
1809
|
+
struct mma_instr;
|
|
1810
|
+
|
|
1811
|
+
template<>
|
|
1812
|
+
struct mma_instr<ggml_bf16_t> {
|
|
1813
|
+
static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
|
|
1814
|
+
__builtin_mma_xvbf16ger2pp(acc, a, b);
|
|
1815
|
+
}
|
|
1816
|
+
};
|
|
1817
|
+
|
|
1818
|
+
template<>
|
|
1819
|
+
struct mma_instr<ggml_fp16_t> {
|
|
1820
|
+
static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
|
|
1821
|
+
__builtin_mma_xvf16ger2pp(acc, a, b);
|
|
1822
|
+
}
|
|
1823
|
+
};
|
|
1824
|
+
|
|
1800
1825
|
template <typename TA, typename TB, typename TC>
|
|
1801
|
-
class
|
|
1826
|
+
class tinyBLAS_HP16_PPC {
|
|
1802
1827
|
public:
|
|
1803
|
-
|
|
1828
|
+
tinyBLAS_HP16_PPC(int64_t k,
|
|
1804
1829
|
const TA *A, int64_t lda,
|
|
1805
1830
|
const TB *B, int64_t ldb,
|
|
1806
1831
|
TC *C, int64_t ldc,
|
|
@@ -2118,8 +2143,8 @@ class tinyBLAS_BF16_PPC {
|
|
|
2118
2143
|
packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
|
|
2119
2144
|
packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
|
|
2120
2145
|
for (int x = 0; x < 4; x++) {
|
|
2121
|
-
|
|
2122
|
-
|
|
2146
|
+
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
|
2147
|
+
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
|
|
2123
2148
|
}
|
|
2124
2149
|
}
|
|
2125
2150
|
SAVE_ACC(&acc_0, ii, jj);
|
|
@@ -2135,8 +2160,8 @@ class tinyBLAS_BF16_PPC {
|
|
|
2135
2160
|
packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
|
|
2136
2161
|
packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
|
|
2137
2162
|
for (int x = 0; x < 4; x++) {
|
|
2138
|
-
|
|
2139
|
-
|
|
2163
|
+
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
|
2164
|
+
mma_instr<TA>::outer_product(&acc_1, vec_A[x+4], vec_B[x]);
|
|
2140
2165
|
}
|
|
2141
2166
|
}
|
|
2142
2167
|
SAVE_ACC(&acc_0, ii, jj);
|
|
@@ -2155,10 +2180,10 @@ class tinyBLAS_BF16_PPC {
|
|
|
2155
2180
|
packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
|
|
2156
2181
|
packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
|
|
2157
2182
|
for (int x = 0; x < 4; x++) {
|
|
2158
|
-
|
|
2159
|
-
|
|
2160
|
-
|
|
2161
|
-
|
|
2183
|
+
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
|
2184
|
+
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
|
|
2185
|
+
mma_instr<TA>::outer_product(&acc_2, vec_A[x+4], vec_B[x]);
|
|
2186
|
+
mma_instr<TA>::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]);
|
|
2162
2187
|
}
|
|
2163
2188
|
}
|
|
2164
2189
|
|
|
@@ -2189,7 +2214,7 @@ class tinyBLAS_BF16_PPC {
|
|
|
2189
2214
|
packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
|
|
2190
2215
|
packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
|
|
2191
2216
|
for (int x = 0; x<2; x++) {
|
|
2192
|
-
|
|
2217
|
+
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
|
2193
2218
|
}
|
|
2194
2219
|
}
|
|
2195
2220
|
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
|
@@ -2224,8 +2249,8 @@ class tinyBLAS_BF16_PPC {
|
|
|
2224
2249
|
packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
|
|
2225
2250
|
packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
|
|
2226
2251
|
for (int x = 0; x<4; x++) {
|
|
2227
|
-
|
|
2228
|
-
|
|
2252
|
+
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
|
2253
|
+
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
|
|
2229
2254
|
}
|
|
2230
2255
|
}
|
|
2231
2256
|
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
|
@@ -2284,43 +2309,302 @@ class tinyBLAS_BF16_PPC {
|
|
|
2284
2309
|
const int nth;
|
|
2285
2310
|
};
|
|
2286
2311
|
|
|
2287
|
-
|
|
2288
|
-
|
|
2289
|
-
|
|
2290
|
-
|
|
2291
|
-
|
|
2292
|
-
|
|
2312
|
+
template <typename TA>
|
|
2313
|
+
class tinyBLAS_Q0_PPC {
|
|
2314
|
+
public:
|
|
2315
|
+
tinyBLAS_Q0_PPC(int64_t k,
|
|
2316
|
+
const TA * A, int64_t lda,
|
|
2317
|
+
const block_q8_0 * B, int64_t ldb,
|
|
2318
|
+
float * C, int64_t ldc,
|
|
2319
|
+
int ith, int nth)
|
|
2293
2320
|
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
|
2294
|
-
kc = 64;
|
|
2295
2321
|
}
|
|
2296
2322
|
|
|
2297
|
-
|
|
2298
|
-
|
|
2299
|
-
|
|
2300
|
-
|
|
2301
|
-
|
|
2302
|
-
|
|
2303
|
-
|
|
2304
|
-
|
|
2305
|
-
|
|
2306
|
-
|
|
2307
|
-
|
|
2323
|
+
void matmul(int64_t m, int64_t n) {
|
|
2324
|
+
#if defined(_AIX) || defined(__BIG_ENDIAN__)
|
|
2325
|
+
mnpack(0, m, 0, n);
|
|
2326
|
+
#else
|
|
2327
|
+
const int64_t mc = 64;
|
|
2328
|
+
const int64_t kc = 64;
|
|
2329
|
+
int64_t nc = 64;
|
|
2330
|
+
int64_t n_aligned = 0;
|
|
2331
|
+
if (n % 64 == 0) {
|
|
2332
|
+
n_aligned = n;
|
|
2333
|
+
} else if (n == 4) {
|
|
2334
|
+
n_aligned = 4;
|
|
2335
|
+
} else if (n < 64) {
|
|
2336
|
+
n_aligned = (n / 8) * 8;
|
|
2337
|
+
} else {
|
|
2338
|
+
n_aligned = (n / 64) * 64;
|
|
2339
|
+
}
|
|
2340
|
+
if (n_aligned > 0) {
|
|
2341
|
+
if (n_aligned % 64 == 0) nc = 64;
|
|
2342
|
+
else if (n_aligned == n) nc = n;
|
|
2343
|
+
else if (n_aligned % 32 == 0) nc = 32;
|
|
2344
|
+
else if (n_aligned % 24 == 0) nc = 24;
|
|
2345
|
+
else if (n_aligned % 16 == 0) nc = 16;
|
|
2346
|
+
else nc = 8;
|
|
2347
|
+
}
|
|
2348
|
+
bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0);
|
|
2349
|
+
if (can_use_tiled) {
|
|
2350
|
+
matmul_tiled(m, n_aligned, mc, nc, kc);
|
|
2351
|
+
if (n > n_aligned) {
|
|
2352
|
+
mnpack(0, m, n_aligned, n);
|
|
2353
|
+
}
|
|
2308
2354
|
} else {
|
|
2309
2355
|
mnpack(0, m, 0, n);
|
|
2310
2356
|
}
|
|
2357
|
+
#endif
|
|
2358
|
+
}
|
|
2359
|
+
|
|
2360
|
+
private:
|
|
2361
|
+
inline void save_res(int ii, int jj, int idx, vector float * fin_res, int RM = 4, int RN = 4) {
|
|
2362
|
+
for (int I = 0; I < RM; I++) {
|
|
2363
|
+
for (int J = 0; J < RN; J++) {
|
|
2364
|
+
*((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&fin_res[idx + I] + J);
|
|
2365
|
+
}
|
|
2366
|
+
}
|
|
2311
2367
|
}
|
|
2312
2368
|
|
|
2313
|
-
|
|
2314
|
-
|
|
2315
|
-
|
|
2369
|
+
inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
|
2370
|
+
vec_t vec_C[4];
|
|
2371
|
+
__builtin_mma_disassemble_acc(vec_C, ACC);
|
|
2372
|
+
for (int I = 0; I < 4; I++) {
|
|
2373
|
+
for (int J = 0; J < 4; J++) {
|
|
2374
|
+
*((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&vec_C[I] + J);
|
|
2375
|
+
}
|
|
2376
|
+
}
|
|
2377
|
+
}
|
|
2378
|
+
|
|
2379
|
+
inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
|
2380
|
+
vec_t vec_C[4];
|
|
2381
|
+
__builtin_mma_disassemble_acc(vec_C, ACC);
|
|
2382
|
+
for (int I = 0; I < 4; I++) {
|
|
2383
|
+
for (int J = 0; J < 4; J++) {
|
|
2384
|
+
float * c_ptr = (float *)(C + ii+ ((jj + J) * ldc) + I);
|
|
2385
|
+
*c_ptr += *((float *)&vec_C[I] + J);
|
|
2386
|
+
}
|
|
2387
|
+
}
|
|
2388
|
+
}
|
|
2389
|
+
|
|
2390
|
+
template<typename ArrayType>
|
|
2391
|
+
inline void compute(acc_t * ACC, int c_idx, int s_idx, ArrayType & comparray, vector float * vs, vector float * fin_res) {
|
|
2392
|
+
vector signed int vec_C[4];
|
|
2393
|
+
vector float CA[4] = {0};
|
|
2394
|
+
vector float res[4] = {0};
|
|
2395
|
+
__builtin_mma_disassemble_acc(vec_C, ACC);
|
|
2396
|
+
for (int i = 0; i < 4; i++) {
|
|
2397
|
+
CA[i] = vec_splats((float)(((double)comparray[c_idx + i]) * -128.0));
|
|
2398
|
+
res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
|
|
2399
|
+
fin_res[s_idx + i] = vec_madd(res[i], vs[s_idx + i], fin_res[s_idx + i]);
|
|
2400
|
+
}
|
|
2401
|
+
}
|
|
2402
|
+
|
|
2403
|
+
inline void process_q4_elements(vector signed char (&c)[2], int * ca) {
|
|
2404
|
+
const vector signed char lowMask = vec_splats((signed char)0xF);
|
|
2405
|
+
const vector unsigned char v4 = vec_splats((unsigned char)0x4);
|
|
2406
|
+
const vector signed char v8 = vec_splats((signed char)0x8);
|
|
2407
|
+
vector signed int vsum = {0};
|
|
2408
|
+
vector signed int vsum2 = {0};
|
|
2409
|
+
c[0] = vec_and(c[1], lowMask);
|
|
2410
|
+
c[1] = vec_sr(c[1], v4);
|
|
2411
|
+
c[0] = vec_sub(c[0], v8);
|
|
2412
|
+
c[1] = vec_sub(c[1], v8);
|
|
2413
|
+
vsum = vec_sum4s(c[0], vsum);
|
|
2414
|
+
vsum2 = vec_sum4s(c[1], vsum2);
|
|
2415
|
+
vsum = vec_add(vsum, vsum2);
|
|
2416
|
+
*(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
|
2417
|
+
}
|
|
2418
|
+
|
|
2419
|
+
template <typename V1, typename V2>
|
|
2420
|
+
inline void vector_permute_store(V2 & s1, V2 & s2, V2 & s3, V2 & s4, V1 * vecOffset, bool flip) {
|
|
2421
|
+
vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
|
2422
|
+
vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
|
2423
|
+
vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
|
|
2424
|
+
vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
|
|
2425
|
+
V2 t1, t2, t3, t4, t5, t6, t7, t8;
|
|
2426
|
+
vector unsigned char xor_vector;
|
|
2427
|
+
uint8_t flip_vec = 0x80;
|
|
2428
|
+
xor_vector = vec_splats(flip_vec);
|
|
2429
|
+
t1 = vec_perm(s1, s2, swiz1);
|
|
2430
|
+
t2 = vec_perm(s1, s2, swiz2);
|
|
2431
|
+
t3 = vec_perm(s3, s4, swiz1);
|
|
2432
|
+
t4 = vec_perm(s3, s4, swiz2);
|
|
2433
|
+
t5 = vec_perm(t1, t3, swiz3);
|
|
2434
|
+
t6 = vec_perm(t1, t3, swiz4);
|
|
2435
|
+
t7 = vec_perm(t2, t4, swiz3);
|
|
2436
|
+
t8 = vec_perm(t2, t4, swiz4);
|
|
2437
|
+
if (flip == true) {
|
|
2438
|
+
t5 = vec_xor(t5, xor_vector);
|
|
2439
|
+
t6 = vec_xor(t6, xor_vector);
|
|
2440
|
+
t7 = vec_xor(t7, xor_vector);
|
|
2441
|
+
t8 = vec_xor(t8, xor_vector);
|
|
2442
|
+
}
|
|
2443
|
+
vec_xst(t5, 0, vecOffset);
|
|
2444
|
+
vec_xst(t6, 0, vecOffset + 16);
|
|
2445
|
+
vec_xst(t7, 0, vecOffset + 32);
|
|
2446
|
+
vec_xst(t8, 0, vecOffset + 48);
|
|
2447
|
+
}
|
|
2448
|
+
|
|
2449
|
+
inline void unpack_q4_to_q8(vector signed char packed, vector signed char & lo, vector signed char & hi) {
|
|
2450
|
+
const vector signed char lowMask = vec_splats((signed char)0x0F);
|
|
2451
|
+
const vector signed char v8 = vec_splats((signed char)0x08);
|
|
2452
|
+
const vector unsigned char v4 = vec_splats((unsigned char)4);
|
|
2453
|
+
lo = vec_and(packed, lowMask);
|
|
2454
|
+
hi = vec_sr(packed, v4);
|
|
2455
|
+
lo = vec_sub(lo, v8);
|
|
2456
|
+
hi = vec_sub(hi, v8);
|
|
2457
|
+
}
|
|
2458
|
+
|
|
2459
|
+
inline void vector_permute_store_fp16(vec_t * c, unsigned char * vecOffset) {
|
|
2460
|
+
vec_t t[8], s[8];
|
|
2461
|
+
vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
|
|
2462
|
+
vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
|
|
2463
|
+
vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
|
2464
|
+
vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
|
2465
|
+
for (int i = 0; i < 4; i += 2) {
|
|
2466
|
+
t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
|
|
2467
|
+
t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
|
|
2468
|
+
}
|
|
2469
|
+
for (int i = 4; i < 8; i += 2) {
|
|
2470
|
+
t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
|
|
2471
|
+
t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
|
|
2472
|
+
}
|
|
2473
|
+
s[0] = vec_perm(t[0], t[2], swiz3);
|
|
2474
|
+
s[1] = vec_perm(t[0], t[2], swiz4);
|
|
2475
|
+
s[2] = vec_perm(t[1], t[3], swiz3);
|
|
2476
|
+
s[3] = vec_perm(t[1], t[3], swiz4);
|
|
2477
|
+
s[4] = vec_perm(t[4], t[6], swiz3);
|
|
2478
|
+
s[5] = vec_perm(t[4], t[6], swiz4);
|
|
2479
|
+
s[6] = vec_perm(t[5], t[7], swiz3);
|
|
2480
|
+
s[7] = vec_perm(t[5], t[7], swiz4);
|
|
2481
|
+
for (int i = 0; i < 8; ++i) {
|
|
2482
|
+
vec_xst(s[i], 0, (vec_t *)(vecOffset + i * 16));
|
|
2483
|
+
}
|
|
2484
|
+
}
|
|
2485
|
+
|
|
2486
|
+
static inline void convert_and_scale_q8(vector signed char raw, vector float v_scale, vector unsigned short & out_hi, vector unsigned short & out_lo) {
|
|
2487
|
+
vector signed short i16_hi = vec_unpackh(raw);
|
|
2488
|
+
vector signed short i16_lo = vec_unpackl(raw);
|
|
2489
|
+
|
|
2490
|
+
vector float f_hi_h = vec_ctf(vec_unpackh(i16_hi), 0);
|
|
2491
|
+
vector float f_hi_l = vec_ctf(vec_unpackl(i16_hi), 0);
|
|
2492
|
+
vector float f_lo_h = vec_ctf(vec_unpackh(i16_lo), 0);
|
|
2493
|
+
vector float f_lo_l = vec_ctf(vec_unpackl(i16_lo), 0);
|
|
2494
|
+
out_hi = vec_pack_to_short_fp32(vec_mul(f_hi_h, v_scale), vec_mul(f_hi_l, v_scale));
|
|
2495
|
+
out_lo = vec_pack_to_short_fp32(vec_mul(f_lo_h, v_scale), vec_mul(f_lo_l, v_scale));
|
|
2496
|
+
}
|
|
2497
|
+
|
|
2498
|
+
void packNormal_q4_fp16(const block_q4_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
|
|
2499
|
+
unsigned char * vecOffset = vec;
|
|
2500
|
+
for (int i = 0; i < rows; i += 8) {
|
|
2501
|
+
const block_q4_0 * rows_base[8];
|
|
2502
|
+
for (int r = 0; r < 8; r++) {
|
|
2503
|
+
rows_base[r] = a + (i + r) * lda;
|
|
2504
|
+
}
|
|
2505
|
+
for (int blk = 0; blk < blocks; blk++) {
|
|
2506
|
+
vector unsigned short hp_res[8][4];
|
|
2507
|
+
for (int r = 0; r < 8; r++) {
|
|
2508
|
+
const block_q4_0 * current_blk = rows_base[r] + blk;
|
|
2509
|
+
vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d));
|
|
2510
|
+
vector signed char v_qs = vec_xl(0, (const vector signed char *)current_blk->qs);
|
|
2511
|
+
vector signed char c1, c2;
|
|
2512
|
+
unpack_q4_to_q8(v_qs, c1, c2);
|
|
2513
|
+
convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]);
|
|
2514
|
+
convert_and_scale_q8(c2, v_scale, hp_res[r][2], hp_res[r][3]);
|
|
2515
|
+
}
|
|
2516
|
+
for (int c = 0; c < 4; c++) {
|
|
2517
|
+
vector unsigned char c_arr[8];
|
|
2518
|
+
for (int r = 0; r < 8; r++) {
|
|
2519
|
+
c_arr[r] = (vector unsigned char)hp_res[r][c];
|
|
2520
|
+
}
|
|
2521
|
+
vector_permute_store_fp16((vec_t *)c_arr, vecOffset);
|
|
2522
|
+
vecOffset += 128;
|
|
2523
|
+
}
|
|
2524
|
+
}
|
|
2525
|
+
}
|
|
2526
|
+
}
|
|
2527
|
+
|
|
2528
|
+
template <int chunk_size>
|
|
2529
|
+
static inline void pack_q8_block(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
|
|
2530
|
+
unsigned char * vecOffset = vec;
|
|
2531
|
+
const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
|
|
2532
|
+
const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
|
|
2533
|
+
const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
|
2534
|
+
const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
|
2535
|
+
|
|
2536
|
+
for (int i = 0; i < rows; i += chunk_size) {
|
|
2537
|
+
const block_q8_0 * rows_base[chunk_size];
|
|
2538
|
+
for (int r = 0; r < chunk_size; r++) {
|
|
2539
|
+
rows_base[r] = a + (i + r) * lda;
|
|
2540
|
+
}
|
|
2541
|
+
for (int blk = 0; blk < blocks; blk++) {
|
|
2542
|
+
vector unsigned short hp_res[chunk_size][4];
|
|
2543
|
+
for (int r = 0; r < chunk_size; r++) {
|
|
2544
|
+
const block_q8_0 * b = rows_base[r] + blk;
|
|
2545
|
+
vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(b->d));
|
|
2546
|
+
vector signed char c[2];
|
|
2547
|
+
__vector_pair pair = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs);
|
|
2548
|
+
__builtin_vsx_disassemble_pair(c, & pair);
|
|
2549
|
+
convert_and_scale_q8(c[0], v_scale, hp_res[r][0], hp_res[r][1]);
|
|
2550
|
+
convert_and_scale_q8(c[1], v_scale, hp_res[r][2], hp_res[r][3]);
|
|
2551
|
+
}
|
|
2552
|
+
for (int col = 0; col < 4; col++) {
|
|
2553
|
+
if constexpr (chunk_size == 8) {
|
|
2554
|
+
vec_t t[8];
|
|
2555
|
+
t[0] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
|
|
2556
|
+
t[1] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
|
|
2557
|
+
t[2] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
|
|
2558
|
+
t[3] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
|
|
2559
|
+
t[4] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz1);
|
|
2560
|
+
t[5] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz2);
|
|
2561
|
+
t[6] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz1);
|
|
2562
|
+
t[7] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz2);
|
|
2563
|
+
|
|
2564
|
+
vec_xst(vec_perm(t[0], t[2], swiz3), 0, (vec_t *)(vecOffset + 0));
|
|
2565
|
+
vec_xst(vec_perm(t[0], t[2], swiz4), 0, (vec_t *)(vecOffset + 16));
|
|
2566
|
+
vec_xst(vec_perm(t[1], t[3], swiz3), 0, (vec_t *)(vecOffset + 32));
|
|
2567
|
+
vec_xst(vec_perm(t[1], t[3], swiz4), 0, (vec_t *)(vecOffset + 48));
|
|
2568
|
+
vec_xst(vec_perm(t[4], t[6], swiz3), 0, (vec_t *)(vecOffset + 64));
|
|
2569
|
+
vec_xst(vec_perm(t[4], t[6], swiz4), 0, (vec_t *)(vecOffset + 80));
|
|
2570
|
+
vec_xst(vec_perm(t[5], t[7], swiz3), 0, (vec_t *)(vecOffset + 96));
|
|
2571
|
+
vec_xst(vec_perm(t[5], t[7], swiz4), 0, (vec_t *)(vecOffset + 112));
|
|
2572
|
+
vecOffset += 128;
|
|
2573
|
+
} else {
|
|
2574
|
+
vec_t t0 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
|
|
2575
|
+
vec_t t1 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
|
|
2576
|
+
vec_t t2 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
|
|
2577
|
+
vec_t t3 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
|
|
2578
|
+
|
|
2579
|
+
vec_xst(vec_perm(t0, t2, swiz3), 0, (vec_t *)(vecOffset + 0));
|
|
2580
|
+
vec_xst(vec_perm(t0, t2, swiz4), 0, (vec_t *)(vecOffset + 16));
|
|
2581
|
+
vec_xst(vec_perm(t1, t3, swiz3), 0, (vec_t *)(vecOffset + 32));
|
|
2582
|
+
vec_xst(vec_perm(t1, t3, swiz4), 0, (vec_t *)(vecOffset + 48));
|
|
2583
|
+
vecOffset += 64;
|
|
2584
|
+
}
|
|
2585
|
+
}
|
|
2586
|
+
}
|
|
2587
|
+
}
|
|
2588
|
+
}
|
|
2589
|
+
|
|
2590
|
+
void packNormal_q8_fp16(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
|
|
2591
|
+
if (rows == 4) {
|
|
2592
|
+
pack_q8_block<4>(a, lda, rows, blocks, vec);
|
|
2593
|
+
} else {
|
|
2594
|
+
pack_q8_block<8>(a, lda, rows, blocks, vec);
|
|
2595
|
+
}
|
|
2596
|
+
}
|
|
2597
|
+
|
|
2598
|
+
template<int size>
|
|
2599
|
+
void packNormalInt4(const TA * a, int64_t lda, int rows, int cols, int8_t * vec, std::array<int, size> & comparray) {
|
|
2316
2600
|
int64_t i, j;
|
|
2317
|
-
TA *aoffset = NULL;
|
|
2318
|
-
int8_t *vecOffset = NULL;
|
|
2319
|
-
TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
|
|
2320
|
-
TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
|
|
2601
|
+
TA * aoffset = NULL;
|
|
2602
|
+
int8_t * vecOffset = NULL;
|
|
2603
|
+
TA * aoffset1 = NULL, * aoffset2 = NULL, * aoffset3 = NULL, * aoffset4 = NULL;
|
|
2604
|
+
TA * aoffset5 = NULL, * aoffset6 = NULL, * aoffset7 = NULL, * aoffset8 = NULL;
|
|
2321
2605
|
vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
|
|
2322
2606
|
vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
|
|
2323
|
-
aoffset = const_cast<TA*>(a);
|
|
2607
|
+
aoffset = const_cast<TA *>(a);
|
|
2324
2608
|
vecOffset = vec;
|
|
2325
2609
|
j = (rows >> 3);
|
|
2326
2610
|
if (j > 0) {
|
|
@@ -2337,27 +2621,27 @@ class tinyBLAS_BF16_PPC {
|
|
|
2337
2621
|
i = (cols >> 2);
|
|
2338
2622
|
if (i > 0) {
|
|
2339
2623
|
do {
|
|
2340
|
-
c1[1] =
|
|
2341
|
-
c2[1] =
|
|
2342
|
-
c3[1] =
|
|
2343
|
-
c4[1] =
|
|
2344
|
-
c5[1] =
|
|
2345
|
-
c6[1] =
|
|
2346
|
-
c7[1] =
|
|
2347
|
-
c8[1] =
|
|
2348
|
-
|
|
2349
|
-
process_q4_elements(c1, &comparray[0]);
|
|
2350
|
-
process_q4_elements(c2, &comparray[1]);
|
|
2351
|
-
process_q4_elements(c3, &comparray[2]);
|
|
2352
|
-
process_q4_elements(c4, &comparray[3]);
|
|
2353
|
-
process_q4_elements(c5, &comparray[4]);
|
|
2354
|
-
process_q4_elements(c6, &comparray[5]);
|
|
2355
|
-
process_q4_elements(c7, &comparray[6]);
|
|
2356
|
-
process_q4_elements(c8, &comparray[7]);
|
|
2624
|
+
c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
|
|
2625
|
+
c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
|
|
2626
|
+
c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
|
|
2627
|
+
c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs);
|
|
2628
|
+
c5[1] = vec_xl(0, (const vector signed char *)aoffset5->qs);
|
|
2629
|
+
c6[1] = vec_xl(0, (const vector signed char *)aoffset6->qs);
|
|
2630
|
+
c7[1] = vec_xl(0, (const vector signed char *)aoffset7->qs);
|
|
2631
|
+
c8[1] = vec_xl(0, (const vector signed char *)aoffset8->qs);
|
|
2632
|
+
|
|
2633
|
+
process_q4_elements(c1, & comparray[0]);
|
|
2634
|
+
process_q4_elements(c2, & comparray[1]);
|
|
2635
|
+
process_q4_elements(c3, & comparray[2]);
|
|
2636
|
+
process_q4_elements(c4, & comparray[3]);
|
|
2637
|
+
process_q4_elements(c5, & comparray[4]);
|
|
2638
|
+
process_q4_elements(c6, & comparray[5]);
|
|
2639
|
+
process_q4_elements(c7, & comparray[6]);
|
|
2640
|
+
process_q4_elements(c8, & comparray[7]);
|
|
2357
2641
|
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
|
|
2358
|
-
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
|
|
2359
|
-
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
|
|
2360
|
-
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
|
|
2642
|
+
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
|
|
2643
|
+
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset + 128, false);
|
|
2644
|
+
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset + 192, false);
|
|
2361
2645
|
aoffset1 += lda;
|
|
2362
2646
|
aoffset2 += lda;
|
|
2363
2647
|
aoffset3 += lda;
|
|
@@ -2383,17 +2667,17 @@ class tinyBLAS_BF16_PPC {
|
|
|
2383
2667
|
i = (cols >> 2);
|
|
2384
2668
|
if (i > 0) {
|
|
2385
2669
|
do {
|
|
2386
|
-
c1[1] =
|
|
2387
|
-
c2[1] =
|
|
2388
|
-
c3[1] =
|
|
2389
|
-
c4[1] =
|
|
2390
|
-
|
|
2391
|
-
process_q4_elements(c1, &comparray[0]);
|
|
2392
|
-
process_q4_elements(c2, &comparray[1]);
|
|
2393
|
-
process_q4_elements(c3, &comparray[2]);
|
|
2394
|
-
process_q4_elements(c4, &comparray[3]);
|
|
2670
|
+
c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
|
|
2671
|
+
c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
|
|
2672
|
+
c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
|
|
2673
|
+
c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs);
|
|
2674
|
+
|
|
2675
|
+
process_q4_elements(c1, & comparray[0]);
|
|
2676
|
+
process_q4_elements(c2, & comparray[1]);
|
|
2677
|
+
process_q4_elements(c3, & comparray[2]);
|
|
2678
|
+
process_q4_elements(c4, & comparray[3]);
|
|
2395
2679
|
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
|
|
2396
|
-
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
|
|
2680
|
+
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
|
|
2397
2681
|
aoffset1 += lda;
|
|
2398
2682
|
aoffset2 += lda;
|
|
2399
2683
|
aoffset3 += lda;
|
|
@@ -2412,17 +2696,17 @@ class tinyBLAS_BF16_PPC {
|
|
|
2412
2696
|
if (i > 0) {
|
|
2413
2697
|
do {
|
|
2414
2698
|
switch(rows) {
|
|
2415
|
-
case 3: c3[1] =
|
|
2416
|
-
case 2: c2[1] =
|
|
2417
|
-
case 1: c1[1] =
|
|
2699
|
+
case 3: c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
|
|
2700
|
+
case 2: c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
|
|
2701
|
+
case 1: c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
|
|
2418
2702
|
break;
|
|
2419
2703
|
}
|
|
2420
|
-
process_q4_elements(c1, &comparray[0]);
|
|
2421
|
-
process_q4_elements(c2, &comparray[1]);
|
|
2422
|
-
process_q4_elements(c3, &comparray[2]);
|
|
2423
|
-
process_q4_elements(c4, &comparray[3]);
|
|
2704
|
+
process_q4_elements(c1, & comparray[0]);
|
|
2705
|
+
process_q4_elements(c2, & comparray[1]);
|
|
2706
|
+
process_q4_elements(c3, & comparray[2]);
|
|
2707
|
+
process_q4_elements(c4, & comparray[3]);
|
|
2424
2708
|
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
|
|
2425
|
-
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
|
|
2709
|
+
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
|
|
2426
2710
|
aoffset1 += lda;
|
|
2427
2711
|
aoffset2 += lda;
|
|
2428
2712
|
aoffset3 += lda;
|
|
@@ -2433,39 +2717,38 @@ class tinyBLAS_BF16_PPC {
|
|
|
2433
2717
|
}
|
|
2434
2718
|
}
|
|
2435
2719
|
|
|
2436
|
-
template<typename TA>
|
|
2437
2720
|
template<typename VA, typename VB>
|
|
2438
|
-
void
|
|
2721
|
+
void packNormal(const block_q8_0 * a, int64_t lda, int rows, int cols, VA * vec, bool flip) {
|
|
2439
2722
|
int64_t i, j;
|
|
2440
|
-
block_q8_0 *aoffset = NULL;
|
|
2441
|
-
VA *vecOffset = NULL;
|
|
2442
|
-
block_q8_0* aoffsets[8];
|
|
2723
|
+
block_q8_0 * aoffset = NULL;
|
|
2724
|
+
VA * vecOffset = NULL;
|
|
2725
|
+
block_q8_0 * aoffsets[8];
|
|
2443
2726
|
__vector_pair arr[8];
|
|
2444
2727
|
VB c[8][2] = {0};
|
|
2445
2728
|
VB c1[8] = {0}; VB c2[8] = {0};
|
|
2446
|
-
aoffset = const_cast<block_q8_0*>(a);
|
|
2729
|
+
aoffset = const_cast<block_q8_0 *>(a);
|
|
2447
2730
|
vecOffset = vec;
|
|
2448
2731
|
j = (rows >> 3);
|
|
2449
2732
|
if (j > 0) {
|
|
2450
2733
|
do {
|
|
2451
2734
|
aoffsets[0] = aoffset;
|
|
2452
2735
|
for (int it = 1; it < 8; it++)
|
|
2453
|
-
aoffsets[it] = aoffsets[it-1] + lda;
|
|
2736
|
+
aoffsets[it] = aoffsets[it - 1] + lda;
|
|
2454
2737
|
aoffset += 8 * lda;
|
|
2455
2738
|
|
|
2456
2739
|
i = (cols >> 3);
|
|
2457
2740
|
if (i > 0) {
|
|
2458
2741
|
do {
|
|
2459
2742
|
for (int it = 0; it < 8; it++) {
|
|
2460
|
-
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
|
|
2461
|
-
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
|
|
2743
|
+
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
|
|
2744
|
+
__builtin_vsx_disassemble_pair(c[it], & arr[it]);
|
|
2462
2745
|
c1[it] = c[it][0];
|
|
2463
2746
|
c2[it] = c[it][1];
|
|
2464
2747
|
}
|
|
2465
2748
|
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
|
|
2466
|
-
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
|
|
2467
|
-
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
|
|
2468
|
-
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
|
|
2749
|
+
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
|
|
2750
|
+
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset + 128, flip);
|
|
2751
|
+
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset + 192, flip);
|
|
2469
2752
|
for (int it = 0; it < 8; it++)
|
|
2470
2753
|
aoffsets[it] += lda;
|
|
2471
2754
|
vecOffset += 256;
|
|
@@ -2484,13 +2767,13 @@ class tinyBLAS_BF16_PPC {
|
|
|
2484
2767
|
if (i > 0) {
|
|
2485
2768
|
do {
|
|
2486
2769
|
for (int it = 0; it < 4; it++) {
|
|
2487
|
-
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
|
|
2488
|
-
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
|
|
2770
|
+
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
|
|
2771
|
+
__builtin_vsx_disassemble_pair(c[it], & arr[it]);
|
|
2489
2772
|
c1[it] = c[it][0];
|
|
2490
2773
|
c2[it] = c[it][1];
|
|
2491
2774
|
}
|
|
2492
2775
|
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
|
|
2493
|
-
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
|
|
2776
|
+
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
|
|
2494
2777
|
for (int it = 0; it < 4; it++) {
|
|
2495
2778
|
aoffsets[it] += lda;
|
|
2496
2779
|
}
|
|
@@ -2503,24 +2786,24 @@ class tinyBLAS_BF16_PPC {
|
|
|
2503
2786
|
if (rows & 3) {
|
|
2504
2787
|
aoffsets[0] = aoffset;
|
|
2505
2788
|
for (int it = 1; it < 3; it++ )
|
|
2506
|
-
aoffsets[it] = aoffsets[it-1] + lda;
|
|
2789
|
+
aoffsets[it] = aoffsets[it - 1] + lda;
|
|
2507
2790
|
i = (cols >> 3);
|
|
2508
2791
|
if (i > 0) {
|
|
2509
2792
|
do {
|
|
2510
2793
|
switch(rows) {
|
|
2511
|
-
case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
|
|
2512
|
-
__builtin_vsx_disassemble_pair(c[2], &arr[2]);
|
|
2794
|
+
case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[2]->qs);
|
|
2795
|
+
__builtin_vsx_disassemble_pair(c[2], & arr[2]);
|
|
2513
2796
|
c1[2] = c[2][0]; c2[2] = c[2][1];
|
|
2514
|
-
case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
|
|
2515
|
-
__builtin_vsx_disassemble_pair(c[1], &arr[1]);
|
|
2797
|
+
case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[1]->qs);
|
|
2798
|
+
__builtin_vsx_disassemble_pair(c[1], & arr[1]);
|
|
2516
2799
|
c1[1] = c[1][0]; c2[1] = c[1][1];
|
|
2517
|
-
case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
|
|
2518
|
-
__builtin_vsx_disassemble_pair(c[0], &arr[0]);
|
|
2800
|
+
case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[0]->qs);
|
|
2801
|
+
__builtin_vsx_disassemble_pair(c[0], & arr[0]);
|
|
2519
2802
|
c1[0] = c[0][0]; c2[0] = c[0][1];
|
|
2520
2803
|
break;
|
|
2521
2804
|
}
|
|
2522
2805
|
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
|
|
2523
|
-
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
|
|
2806
|
+
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
|
|
2524
2807
|
for (int it = 0; it < 3; it++)
|
|
2525
2808
|
aoffsets[it] += lda;
|
|
2526
2809
|
vecOffset += 128;
|
|
@@ -2530,8 +2813,7 @@ class tinyBLAS_BF16_PPC {
|
|
|
2530
2813
|
}
|
|
2531
2814
|
}
|
|
2532
2815
|
|
|
2533
|
-
|
|
2534
|
-
void tinyBLAS_Q0_PPC<TA>::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
|
2816
|
+
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
|
2535
2817
|
int m_rem = MIN(m - m0, 16);
|
|
2536
2818
|
int n_rem = MIN(n - n0, 16);
|
|
2537
2819
|
|
|
@@ -2568,8 +2850,7 @@ class tinyBLAS_BF16_PPC {
|
|
|
2568
2850
|
}
|
|
2569
2851
|
|
|
2570
2852
|
|
|
2571
|
-
|
|
2572
|
-
void tinyBLAS_Q0_PPC<TA>::KERNEL_4x8(int64_t ii, int64_t jj) {
|
|
2853
|
+
void KERNEL_4x8(int64_t ii, int64_t jj) {
|
|
2573
2854
|
vec_t vec_A[8], vec_B[16] = {0};
|
|
2574
2855
|
acc_t acc_0, acc_1;
|
|
2575
2856
|
std::array<int, 4> comparray {};
|
|
@@ -2577,26 +2858,26 @@ class tinyBLAS_BF16_PPC {
|
|
|
2577
2858
|
vector float vs[8] = {0};
|
|
2578
2859
|
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
|
2579
2860
|
for (int l = 0; l < k; l++) {
|
|
2580
|
-
__builtin_mma_xxsetaccz(&acc_0);
|
|
2581
|
-
__builtin_mma_xxsetaccz(&acc_1);
|
|
2861
|
+
__builtin_mma_xxsetaccz(& acc_0);
|
|
2862
|
+
__builtin_mma_xxsetaccz(& acc_1);
|
|
2582
2863
|
if (std::is_same_v<TA, block_q4_0>) {
|
|
2583
|
-
packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
|
|
2864
|
+
packNormalInt4<4>((A + (ii * lda) + l), lda, 4, 4, (int8_t *)vec_A, comparray);
|
|
2584
2865
|
} else {
|
|
2585
|
-
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
|
|
2866
|
+
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 4, 8, (int8_t *)vec_A, false);
|
|
2586
2867
|
}
|
|
2587
|
-
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
|
|
2868
|
+
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
|
|
2588
2869
|
for(int x = 0; x < 8; x++) {
|
|
2589
|
-
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
|
2590
|
-
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
|
|
2870
|
+
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
|
2871
|
+
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x], vec_B[x+8]);
|
|
2591
2872
|
}
|
|
2592
2873
|
for (int I = 0; I<4; I++) {
|
|
2593
2874
|
for (int J = 0; J<4; J++) {
|
|
2594
|
-
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
|
2595
|
-
*((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
|
|
2875
|
+
*((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
|
2876
|
+
*((float *)& vs[I + 4] + J) = (unhalf((A +((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
|
|
2596
2877
|
}
|
|
2597
2878
|
}
|
|
2598
2879
|
if (!isAblock_q4) {
|
|
2599
|
-
auto aoffset = A+(ii*lda)+l;
|
|
2880
|
+
auto aoffset = A + (ii * lda) + l;
|
|
2600
2881
|
for (int i = 0; i < 4; i++) {
|
|
2601
2882
|
comparray[i] = 0;
|
|
2602
2883
|
int ca = 0;
|
|
@@ -2607,15 +2888,14 @@ class tinyBLAS_BF16_PPC {
|
|
|
2607
2888
|
aoffset += lda;
|
|
2608
2889
|
}
|
|
2609
2890
|
}
|
|
2610
|
-
compute(&acc_0, 0, 0, comparray, vs, fin_res);
|
|
2611
|
-
compute(&acc_1, 0, 4, comparray, vs, fin_res);
|
|
2891
|
+
compute(& acc_0, 0, 0, comparray, vs, fin_res);
|
|
2892
|
+
compute(& acc_1, 0, 4, comparray, vs, fin_res);
|
|
2612
2893
|
}
|
|
2613
2894
|
save_res(ii, jj, 0, fin_res);
|
|
2614
|
-
save_res(ii, jj+4, 4, fin_res);
|
|
2895
|
+
save_res(ii, jj + 4, 4, fin_res);
|
|
2615
2896
|
}
|
|
2616
2897
|
|
|
2617
|
-
|
|
2618
|
-
void tinyBLAS_Q0_PPC<TA>::KERNEL_8x4(int64_t ii, int64_t jj) {
|
|
2898
|
+
void KERNEL_8x4(int64_t ii, int64_t jj) {
|
|
2619
2899
|
vec_t vec_A[16], vec_B[8] = {0};
|
|
2620
2900
|
acc_t acc_0, acc_1;
|
|
2621
2901
|
std::array<int, 8> comparray {};
|
|
@@ -2623,25 +2903,25 @@ class tinyBLAS_BF16_PPC {
|
|
|
2623
2903
|
vector float vs[8] = {0};
|
|
2624
2904
|
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
|
2625
2905
|
for (int l = 0; l < k; l++) {
|
|
2626
|
-
__builtin_mma_xxsetaccz(&acc_0);
|
|
2627
|
-
__builtin_mma_xxsetaccz(&acc_1);
|
|
2906
|
+
__builtin_mma_xxsetaccz(& acc_0);
|
|
2907
|
+
__builtin_mma_xxsetaccz(& acc_1);
|
|
2628
2908
|
if (std::is_same_v<TA, block_q4_0>) {
|
|
2629
|
-
packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
|
|
2909
|
+
packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
|
|
2630
2910
|
} else {
|
|
2631
|
-
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
|
|
2911
|
+
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
|
|
2632
2912
|
}
|
|
2633
|
-
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
|
|
2913
|
+
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 4, 8, (uint8_t *)vec_B, true);
|
|
2634
2914
|
for(int x = 0; x < 8; x++) {
|
|
2635
|
-
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
|
2636
|
-
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
|
|
2915
|
+
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
|
2916
|
+
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
|
|
2637
2917
|
}
|
|
2638
|
-
for (int I = 0; I<8; I++) {
|
|
2639
|
-
for (int J = 0; J<4; J++) {
|
|
2640
|
-
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
|
2918
|
+
for (int I = 0; I < 8; I++) {
|
|
2919
|
+
for (int J = 0; J < 4; J++) {
|
|
2920
|
+
*((float *)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
|
2641
2921
|
}
|
|
2642
2922
|
}
|
|
2643
2923
|
if (!isAblock_q4) {
|
|
2644
|
-
auto aoffset = A+(ii*lda)+l;
|
|
2924
|
+
auto aoffset = A + (ii * lda) + l;
|
|
2645
2925
|
for (int i = 0; i < 8; i++) {
|
|
2646
2926
|
comparray[i] = 0;
|
|
2647
2927
|
int ca = 0;
|
|
@@ -2652,15 +2932,14 @@ class tinyBLAS_BF16_PPC {
|
|
|
2652
2932
|
aoffset += lda;
|
|
2653
2933
|
}
|
|
2654
2934
|
}
|
|
2655
|
-
compute(&acc_0, 0, 0, comparray, vs, fin_res);
|
|
2656
|
-
compute(&acc_1, 4, 4, comparray, vs, fin_res);
|
|
2935
|
+
compute(& acc_0, 0, 0, comparray, vs, fin_res);
|
|
2936
|
+
compute(& acc_1, 4, 4, comparray, vs, fin_res);
|
|
2657
2937
|
}
|
|
2658
2938
|
save_res(ii, jj, 0, fin_res);
|
|
2659
|
-
save_res(ii+4, jj, 4, fin_res);
|
|
2939
|
+
save_res(ii + 4, jj, 4, fin_res);
|
|
2660
2940
|
}
|
|
2661
2941
|
|
|
2662
|
-
|
|
2663
|
-
void tinyBLAS_Q0_PPC<TA>::KERNEL_8x8(int64_t ii, int64_t jj) {
|
|
2942
|
+
void KERNEL_8x8(int64_t ii, int64_t jj) {
|
|
2664
2943
|
vec_t vec_A[16], vec_B[16] = {0};
|
|
2665
2944
|
acc_t acc_0, acc_1, acc_2, acc_3;
|
|
2666
2945
|
acc_t acc_4, acc_5, acc_6, acc_7;
|
|
@@ -2669,30 +2948,30 @@ class tinyBLAS_BF16_PPC {
|
|
|
2669
2948
|
vector float vs[16] = {0};
|
|
2670
2949
|
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
|
2671
2950
|
for (int l = 0; l < k; l++) {
|
|
2672
|
-
__builtin_mma_xxsetaccz(&acc_0);
|
|
2673
|
-
__builtin_mma_xxsetaccz(&acc_1);
|
|
2674
|
-
__builtin_mma_xxsetaccz(&acc_2);
|
|
2675
|
-
__builtin_mma_xxsetaccz(&acc_3);
|
|
2951
|
+
__builtin_mma_xxsetaccz(& acc_0);
|
|
2952
|
+
__builtin_mma_xxsetaccz(& acc_1);
|
|
2953
|
+
__builtin_mma_xxsetaccz(& acc_2);
|
|
2954
|
+
__builtin_mma_xxsetaccz(& acc_3);
|
|
2676
2955
|
if (std::is_same_v<TA, block_q4_0>) {
|
|
2677
|
-
packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
|
|
2956
|
+
packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
|
|
2678
2957
|
} else {
|
|
2679
|
-
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
|
|
2958
|
+
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
|
|
2680
2959
|
}
|
|
2681
|
-
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
|
|
2960
|
+
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
|
|
2682
2961
|
for(int x = 0; x < 8; x++) {
|
|
2683
|
-
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
|
2684
|
-
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
|
|
2685
|
-
__builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
|
|
2686
|
-
__builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
|
|
2962
|
+
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
|
2963
|
+
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
|
|
2964
|
+
__builtin_mma_xvi8ger4pp(& acc_2, vec_A[x], vec_B[x + 8]);
|
|
2965
|
+
__builtin_mma_xvi8ger4pp(& acc_3, vec_A[x + 8], vec_B[x + 8]);
|
|
2687
2966
|
}
|
|
2688
|
-
for (int I = 0; I<8; I++) {
|
|
2689
|
-
for (int J = 0; J<4; J++) {
|
|
2690
|
-
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
|
2691
|
-
*((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
|
|
2967
|
+
for (int I = 0; I < 8 ; I++) {
|
|
2968
|
+
for (int J = 0; J < 4; J++) {
|
|
2969
|
+
*((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
|
2970
|
+
*((float *)& vs[I + 8] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
|
|
2692
2971
|
}
|
|
2693
2972
|
}
|
|
2694
2973
|
if (!isAblock_q4) {
|
|
2695
|
-
auto aoffset = A+(ii*lda)+l;
|
|
2974
|
+
auto aoffset = A + (ii * lda) + l;
|
|
2696
2975
|
for (int i = 0; i < 8; i++) {
|
|
2697
2976
|
comparray[i] = 0;
|
|
2698
2977
|
int ca = 0;
|
|
@@ -2703,19 +2982,99 @@ class tinyBLAS_BF16_PPC {
|
|
|
2703
2982
|
aoffset += lda;
|
|
2704
2983
|
}
|
|
2705
2984
|
}
|
|
2706
|
-
compute(&acc_0, 0, 0, comparray, vs, fin_res);
|
|
2707
|
-
compute(&acc_1, 4, 4, comparray, vs, fin_res);
|
|
2708
|
-
compute(&acc_2, 0, 8, comparray, vs, fin_res);
|
|
2709
|
-
compute(&acc_3, 4, 12, comparray, vs, fin_res);
|
|
2985
|
+
compute(& acc_0, 0, 0, comparray, vs, fin_res);
|
|
2986
|
+
compute(& acc_1, 4, 4, comparray, vs, fin_res);
|
|
2987
|
+
compute(& acc_2, 0, 8, comparray, vs, fin_res);
|
|
2988
|
+
compute(& acc_3, 4, 12, comparray, vs, fin_res);
|
|
2710
2989
|
}
|
|
2711
2990
|
save_res(ii, jj, 0, fin_res);
|
|
2712
|
-
save_res(ii+4, jj, 4, fin_res);
|
|
2713
|
-
save_res(ii, jj+4, 8, fin_res);
|
|
2714
|
-
save_res(ii+4, jj+4, 12, fin_res);
|
|
2991
|
+
save_res(ii + 4, jj, 4, fin_res);
|
|
2992
|
+
save_res(ii, jj + 4, 8, fin_res);
|
|
2993
|
+
save_res(ii + 4, jj + 4, 12, fin_res);
|
|
2994
|
+
}
|
|
2995
|
+
|
|
2996
|
+
void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t * vec_A, vec_t * vec_B) {
|
|
2997
|
+
acc_t acc[8];
|
|
2998
|
+
for (int i = 0; i < mc ; i += 16) {
|
|
2999
|
+
for (int j = 0; j < nc; j += 8) {
|
|
3000
|
+
int A0_base = (i / 16) * (2 * 32 * kc);
|
|
3001
|
+
int B0_base = (j / 8) * (32 * kc);
|
|
3002
|
+
for (int x = 0; x < 8; x++) {
|
|
3003
|
+
__builtin_mma_xxsetaccz(&acc[x]);
|
|
3004
|
+
}
|
|
3005
|
+
for (int64_t kk = 0; kk < kc; kk++) {
|
|
3006
|
+
int A0_block_idx = A0_base + kk * 32;
|
|
3007
|
+
int B0_block_idx = B0_base + kk * 32;
|
|
3008
|
+
int A1_block_idx = A0_block_idx + 32 * kc;
|
|
3009
|
+
int B1_block_idx = B0_block_idx + 32 * kc;
|
|
3010
|
+
vec_t * A0_block = & vec_A[A0_block_idx];
|
|
3011
|
+
vec_t * B0_block = & vec_B[B0_block_idx];
|
|
3012
|
+
vec_t * A1_block = & vec_A[A1_block_idx];
|
|
3013
|
+
for (int it = 0; it < 4; it++) {
|
|
3014
|
+
for (int x = 0; x < 4; x++) {
|
|
3015
|
+
__builtin_mma_xvf16ger2pp(& acc[0], A0_block[8 * it + x], B0_block[8 * it + x]);
|
|
3016
|
+
__builtin_mma_xvf16ger2pp(& acc[1], A0_block[8 * it + x], B0_block[8 * it + x + 4]);
|
|
3017
|
+
__builtin_mma_xvf16ger2pp(& acc[2], A0_block[8 * it + x + 4], B0_block[8 * it + x]);
|
|
3018
|
+
__builtin_mma_xvf16ger2pp(& acc[3], A0_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
|
|
3019
|
+
__builtin_mma_xvf16ger2pp(& acc[4], A1_block[8 * it + x], B0_block[8 * it + x]);
|
|
3020
|
+
__builtin_mma_xvf16ger2pp(& acc[5], A1_block[8 * it + x], B0_block[8 * it+ x + 4]);
|
|
3021
|
+
__builtin_mma_xvf16ger2pp(& acc[6], A1_block[8 * it + x + 4], B0_block[8 * it + x]);
|
|
3022
|
+
__builtin_mma_xvf16ger2pp(& acc[7], A1_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
|
|
3023
|
+
}
|
|
3024
|
+
}
|
|
3025
|
+
}
|
|
3026
|
+
if (l == 0) {
|
|
3027
|
+
save_acc(& acc[0], ii + i, jj + j);
|
|
3028
|
+
save_acc(& acc[1], ii + i, jj + j + 4);
|
|
3029
|
+
save_acc(& acc[2], ii + i + 4, jj + j);
|
|
3030
|
+
save_acc(& acc[3], ii + i + 4, jj + j + 4);
|
|
3031
|
+
save_acc(& acc[4], ii + i + 8, jj + j);
|
|
3032
|
+
save_acc(& acc[5], ii + i + 8, jj + j + 4);
|
|
3033
|
+
save_acc(& acc[6], ii + i + 12, jj + j);
|
|
3034
|
+
save_acc(& acc[7], ii + i + 12, jj + j + 4);
|
|
3035
|
+
} else {
|
|
3036
|
+
add_save_acc(& acc[0], ii + i, jj + j);
|
|
3037
|
+
add_save_acc(& acc[1], ii + i, jj + j + 4);
|
|
3038
|
+
add_save_acc(& acc[2], ii + i + 4, jj + j);
|
|
3039
|
+
add_save_acc(& acc[3], ii + i + 4, jj + j + 4);
|
|
3040
|
+
add_save_acc(& acc[4], ii + i + 8, jj + j);
|
|
3041
|
+
add_save_acc(& acc[5], ii + i + 8, jj + j + 4);
|
|
3042
|
+
add_save_acc(& acc[6], ii + i + 12, jj + j);
|
|
3043
|
+
add_save_acc(& acc[7], ii + i + 12, jj + j + 4);
|
|
3044
|
+
}
|
|
3045
|
+
}
|
|
3046
|
+
}
|
|
3047
|
+
}
|
|
3048
|
+
|
|
3049
|
+
void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
|
|
3050
|
+
vec_t A_pack[mc * kc * 4];
|
|
3051
|
+
vec_t B_pack[nc * kc * 4];
|
|
3052
|
+
constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
|
|
3053
|
+
int64_t ytiles = m / mc;
|
|
3054
|
+
int64_t xtiles = n / nc;
|
|
3055
|
+
int64_t tiles = xtiles * ytiles;
|
|
3056
|
+
int64_t duty = (tiles + nth - 1) / nth;
|
|
3057
|
+
int64_t start = duty * ith;
|
|
3058
|
+
int64_t end = start + duty;
|
|
3059
|
+
if (end > tiles) {
|
|
3060
|
+
end = tiles;
|
|
3061
|
+
}
|
|
3062
|
+
for (int64_t job = start; job < end; ++job) {
|
|
3063
|
+
int64_t ii = (job / xtiles) * mc;
|
|
3064
|
+
int64_t jj = (job % xtiles) * nc;
|
|
3065
|
+
for (int64_t kk = 0; kk < k; kk += kc) {
|
|
3066
|
+
if constexpr(is_Ablock_q4) {
|
|
3067
|
+
packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
|
|
3068
|
+
} else {
|
|
3069
|
+
packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
|
|
3070
|
+
}
|
|
3071
|
+
packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack);
|
|
3072
|
+
KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack);
|
|
3073
|
+
}
|
|
3074
|
+
}
|
|
2715
3075
|
}
|
|
2716
3076
|
|
|
2717
|
-
|
|
2718
|
-
void tinyBLAS_Q0_PPC<TA>::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
|
|
3077
|
+
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
|
|
2719
3078
|
int64_t ytiles = (m - m0) / RM;
|
|
2720
3079
|
int64_t xtiles = (n - n0) / RN;
|
|
2721
3080
|
int64_t tiles = xtiles * ytiles;
|
|
@@ -2737,32 +3096,32 @@ class tinyBLAS_BF16_PPC {
|
|
|
2737
3096
|
vector float fin_res[4] = {0};
|
|
2738
3097
|
vector float vs[4] = {0};
|
|
2739
3098
|
vector float CA[4] = {0};
|
|
2740
|
-
__builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
|
|
2741
|
-
__builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
|
|
3099
|
+
__builtin_prefetch((A + (ii * lda) + 0)->qs, 0, 1); // prefetch first value
|
|
3100
|
+
__builtin_prefetch((B + (jj * ldb) + 0)->qs, 0, 1); // prefetch first value
|
|
2742
3101
|
for (int l = 0; l < k; l++) {
|
|
2743
|
-
__builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
|
|
2744
|
-
__builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
|
|
2745
|
-
__builtin_mma_xxsetaccz(&acc_0);
|
|
3102
|
+
__builtin_prefetch((A + (ii * lda) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
|
|
3103
|
+
__builtin_prefetch((B + (jj * ldb) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
|
|
3104
|
+
__builtin_mma_xxsetaccz(& acc_0);
|
|
2746
3105
|
if (isAblock_q4) {
|
|
2747
|
-
|
|
3106
|
+
packNormalInt4<4>((A + (ii * lda) + l), lda, RM, 4, (int8_t *)vec_A, comparray);
|
|
2748
3107
|
} else {
|
|
2749
|
-
|
|
3108
|
+
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, RM, 8, (int8_t *)vec_A, false);
|
|
2750
3109
|
}
|
|
2751
|
-
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
|
|
2752
|
-
for(int x = 0; x < 8; x+=4) {
|
|
2753
|
-
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
|
2754
|
-
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
|
|
2755
|
-
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
|
|
2756
|
-
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
|
|
3110
|
+
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, RN, 8, (uint8_t *)vec_B, true);
|
|
3111
|
+
for (int x = 0; x < 8; x += 4) {
|
|
3112
|
+
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
|
3113
|
+
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 1], vec_B[x + 1]);
|
|
3114
|
+
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 2], vec_B[x + 2]);
|
|
3115
|
+
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 3], vec_B[x + 3]);
|
|
2757
3116
|
}
|
|
2758
|
-
for (int I = 0; I<RM; I++) {
|
|
2759
|
-
for (int J = 0; J<RN; J++) {
|
|
2760
|
-
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
|
3117
|
+
for (int I = 0; I < RM; I++) {
|
|
3118
|
+
for (int J = 0; J < RN; J++) {
|
|
3119
|
+
*((float*)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
|
2761
3120
|
}
|
|
2762
3121
|
}
|
|
2763
|
-
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
|
3122
|
+
__builtin_mma_disassemble_acc(vec_C, & acc_0);
|
|
2764
3123
|
if (!isAblock_q4) {
|
|
2765
|
-
auto aoffset = A+(ii*lda)+l;
|
|
3124
|
+
auto aoffset = A + (ii * lda) + l;
|
|
2766
3125
|
for (int i = 0; i < RM; i++) {
|
|
2767
3126
|
comparray[i] = 0;
|
|
2768
3127
|
int ca = 0;
|
|
@@ -2783,9 +3142,21 @@ class tinyBLAS_BF16_PPC {
|
|
|
2783
3142
|
}
|
|
2784
3143
|
}
|
|
2785
3144
|
|
|
2786
|
-
template<
|
|
3145
|
+
template<int RM, int RN>
|
|
3146
|
+
inline void kernel(int64_t ii, int64_t jj) {
|
|
3147
|
+
if constexpr(RM == 4 && RN == 8) {
|
|
3148
|
+
KERNEL_4x8(ii,jj);
|
|
3149
|
+
} else if constexpr(RM == 8 && RN == 4) {
|
|
3150
|
+
KERNEL_8x4(ii,jj);
|
|
3151
|
+
} else if constexpr(RM == 8 && RN == 8) {
|
|
3152
|
+
KERNEL_8x8(ii,jj);
|
|
3153
|
+
} else {
|
|
3154
|
+
assert(false && "RN/RM values not supported");
|
|
3155
|
+
}
|
|
3156
|
+
}
|
|
3157
|
+
|
|
2787
3158
|
template <int RM, int RN>
|
|
2788
|
-
NOINLINE void
|
|
3159
|
+
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
|
2789
3160
|
int64_t ytiles = (m - m0) / RM;
|
|
2790
3161
|
int64_t xtiles = (n - n0) / RN;
|
|
2791
3162
|
int64_t tiles = xtiles * ytiles;
|
|
@@ -2797,12 +3168,20 @@ class tinyBLAS_BF16_PPC {
|
|
|
2797
3168
|
for (int64_t job = start; job < end; ++job) {
|
|
2798
3169
|
int64_t ii = m0 + job / xtiles * RM;
|
|
2799
3170
|
int64_t jj = n0 + job % xtiles * RN;
|
|
2800
|
-
|
|
3171
|
+
kernel<RM, RN>(ii, jj);
|
|
2801
3172
|
}
|
|
2802
3173
|
}
|
|
2803
|
-
|
|
2804
|
-
|
|
2805
|
-
|
|
3174
|
+
const TA * const A;
|
|
3175
|
+
const block_q8_0 * const B;
|
|
3176
|
+
float * C;
|
|
3177
|
+
const int64_t k;
|
|
3178
|
+
int64_t kc;
|
|
3179
|
+
const int64_t lda;
|
|
3180
|
+
const int64_t ldb;
|
|
3181
|
+
const int64_t ldc;
|
|
3182
|
+
const int ith;
|
|
3183
|
+
const int nth;
|
|
3184
|
+
};
|
|
2806
3185
|
|
|
2807
3186
|
class tinyBLAS_PPC {
|
|
2808
3187
|
public:
|
|
@@ -2815,16 +3194,21 @@ class tinyBLAS_PPC {
|
|
|
2815
3194
|
}
|
|
2816
3195
|
|
|
2817
3196
|
void matmul(int64_t m, int64_t n) {
|
|
3197
|
+
#if defined(_AIX) || defined(__BIG_ENDIAN__)
|
|
3198
|
+
mnpack(0, m, 0, n);
|
|
3199
|
+
#else
|
|
2818
3200
|
int64_t mc = 256; int64_t nc = 256; int64_t kc = 256;
|
|
2819
3201
|
if (m % mc == 0 && n % nc == 0 && k % kc == 0) {
|
|
2820
3202
|
matmul_tiled(m, n, mc, nc, kc);
|
|
2821
3203
|
} else {
|
|
2822
3204
|
mnpack(0, m, 0, n);
|
|
2823
3205
|
}
|
|
3206
|
+
#endif
|
|
2824
3207
|
}
|
|
2825
3208
|
|
|
2826
3209
|
private:
|
|
2827
3210
|
|
|
3211
|
+
__attribute__((always_inline))
|
|
2828
3212
|
inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
|
2829
3213
|
vec_t vec_C[4];
|
|
2830
3214
|
__builtin_mma_disassemble_acc(vec_C, ACC);
|
|
@@ -2835,6 +3219,7 @@ class tinyBLAS_PPC {
|
|
|
2835
3219
|
}
|
|
2836
3220
|
}
|
|
2837
3221
|
|
|
3222
|
+
__attribute__((always_inline))
|
|
2838
3223
|
inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
|
2839
3224
|
vec_t vec_C[4];
|
|
2840
3225
|
__builtin_mma_disassemble_acc(vec_C, ACC);
|
|
@@ -3369,7 +3754,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
|
|
3369
3754
|
params->ith, params->nth};
|
|
3370
3755
|
tb.matmul(m, n);
|
|
3371
3756
|
return true;
|
|
3372
|
-
#elif defined(
|
|
3757
|
+
#elif defined(__riscv_v_intrinsic)
|
|
3373
3758
|
#if LMUL == 1
|
|
3374
3759
|
tinyBLAS_RVV<vfloat32m1_t, vfloat32m1_t, float, float, float> tb{ params,
|
|
3375
3760
|
k, (const float *)A, lda,
|
|
@@ -3418,35 +3803,40 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
|
|
3418
3803
|
return tb.matmul(m, n);
|
|
3419
3804
|
}
|
|
3420
3805
|
#elif defined(__MMA__)
|
|
3421
|
-
if (
|
|
3422
|
-
|
|
3423
|
-
if(Btype == GGML_TYPE_BF16) {
|
|
3424
|
-
tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
|
|
3425
|
-
(const ggml_bf16_t *)A, lda,
|
|
3426
|
-
(const ggml_bf16_t *)B, ldb,
|
|
3427
|
-
(float *)C, ldc,
|
|
3428
|
-
params->ith, params->nth};
|
|
3429
|
-
tb.matmul(m, n);
|
|
3430
|
-
return true;
|
|
3806
|
+
if (k % 8) {
|
|
3807
|
+
return false;
|
|
3431
3808
|
}
|
|
3432
|
-
|
|
3433
|
-
|
|
3434
|
-
|
|
3435
|
-
|
|
3436
|
-
(const ggml_bf16_t *)B, ldb,
|
|
3437
|
-
(float *)C, ldc};
|
|
3438
|
-
#elif LMUL == 2
|
|
3439
|
-
tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
|
3440
|
-
k, (const ggml_bf16_t *)A, lda,
|
|
3441
|
-
(const ggml_bf16_t *)B, ldb,
|
|
3442
|
-
(float *)C, ldc};
|
|
3443
|
-
#else // LMUL = 4
|
|
3444
|
-
tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
|
3445
|
-
k, (const ggml_bf16_t *)A, lda,
|
|
3809
|
+
|
|
3810
|
+
if (Btype == GGML_TYPE_BF16) {
|
|
3811
|
+
tinyBLAS_HP16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
|
|
3812
|
+
(const ggml_bf16_t *)A, lda,
|
|
3446
3813
|
(const ggml_bf16_t *)B, ldb,
|
|
3447
|
-
(float *)C, ldc
|
|
3448
|
-
|
|
3449
|
-
|
|
3814
|
+
(float *)C, ldc,
|
|
3815
|
+
params->ith, params->nth };
|
|
3816
|
+
|
|
3817
|
+
tb.matmul(m, n);
|
|
3818
|
+
return true;
|
|
3819
|
+
}
|
|
3820
|
+
#elif defined(__riscv_zvfbfwma)
|
|
3821
|
+
if (Btype == GGML_TYPE_BF16) {
|
|
3822
|
+
#if LMUL == 1
|
|
3823
|
+
tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
|
3824
|
+
k, (const ggml_bf16_t *)A, lda,
|
|
3825
|
+
(const ggml_bf16_t *)B, ldb,
|
|
3826
|
+
(float *)C, ldc};
|
|
3827
|
+
#elif LMUL == 2
|
|
3828
|
+
tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
|
3829
|
+
k, (const ggml_bf16_t *)A, lda,
|
|
3830
|
+
(const ggml_bf16_t *)B, ldb,
|
|
3831
|
+
(float *)C, ldc};
|
|
3832
|
+
#else // LMUL = 4
|
|
3833
|
+
tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
|
3834
|
+
k, (const ggml_bf16_t *)A, lda,
|
|
3835
|
+
(const ggml_bf16_t *)B, ldb,
|
|
3836
|
+
(float *)C, ldc};
|
|
3837
|
+
#endif
|
|
3838
|
+
return tb.matmul(m, n);
|
|
3839
|
+
}
|
|
3450
3840
|
#endif
|
|
3451
3841
|
return false;
|
|
3452
3842
|
}
|
|
@@ -3516,6 +3906,21 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
|
|
3516
3906
|
#endif
|
|
3517
3907
|
return tb.matmul(m, n);
|
|
3518
3908
|
}
|
|
3909
|
+
#elif defined(__MMA__)
|
|
3910
|
+
if (k % 8) {
|
|
3911
|
+
return false;
|
|
3912
|
+
}
|
|
3913
|
+
|
|
3914
|
+
if (Btype == GGML_TYPE_F16) {
|
|
3915
|
+
tinyBLAS_HP16_PPC<ggml_fp16_t, ggml_fp16_t, float> tb{ k,
|
|
3916
|
+
(const ggml_fp16_t *)A, lda,
|
|
3917
|
+
(const ggml_fp16_t *)B, ldb,
|
|
3918
|
+
(float *)C, ldc,
|
|
3919
|
+
params->ith, params->nth };
|
|
3920
|
+
|
|
3921
|
+
tb.matmul(m, n);
|
|
3922
|
+
return true;
|
|
3923
|
+
}
|
|
3519
3924
|
#endif
|
|
3520
3925
|
return false;
|
|
3521
3926
|
}
|