whispercpp 1.3.5 → 1.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/LICENSE +1 -1
- data/README.md +133 -3
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -7
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +56 -46
- data/ext/ruby_whisper.h +165 -2
- data/ext/ruby_whisper_context.c +297 -126
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -66
- data/ext/ruby_whisper_segment.c +6 -7
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +46 -16
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +24 -19
- data/ext/sources/examples/cli/cli.cpp +51 -9
- data/ext/sources/examples/common-ggml.cpp +4 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +213 -163
- data/ext/sources/ggml/CMakeLists.txt +29 -15
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +73 -11
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +8 -3
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +155 -16
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +25 -5
- data/ext/sources/ggml/src/ggml-alloc.c +9 -10
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
- data/ext/sources/ggml/src/ggml-common.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
- data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
- data/ext/sources/ggml/src/ggml-impl.h +68 -1
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +385 -119
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
- data/ext/sources/ggml/src/ggml.c +268 -52
- data/ext/sources/ggml/src/gguf.cpp +377 -47
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +62 -40
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +445 -55
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_context_params.rb +82 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +44 -6
- data/whispercpp.gemspec +2 -2
- metadata +426 -280
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
- data/ext/sources/examples/talk-llama/llama-context.h +0 -360
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
- data/ext/sources/examples/talk-llama/llama-model.h +0 -544
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
- data/ext/sources/examples/talk-llama/llama.h +0 -1540
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -569
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
|
@@ -3,14 +3,14 @@
|
|
|
3
3
|
#include "ggml-cpu.h"
|
|
4
4
|
#include "ggml-impl.h"
|
|
5
5
|
#include "binary-ops.h"
|
|
6
|
+
#include "simd-gemm.h"
|
|
6
7
|
#include "ggml.h"
|
|
7
8
|
#include "unary-ops.h"
|
|
8
9
|
#include "vec.h"
|
|
9
10
|
|
|
10
|
-
#include <cfloat>
|
|
11
11
|
#include <algorithm>
|
|
12
|
+
#include <cfloat>
|
|
12
13
|
#include <cmath>
|
|
13
|
-
#include <functional>
|
|
14
14
|
|
|
15
15
|
// ggml_compute_forward_dup
|
|
16
16
|
|
|
@@ -375,7 +375,7 @@ static void ggml_compute_forward_dup_bytes(
|
|
|
375
375
|
const size_t rs = ne00 * type_size;
|
|
376
376
|
|
|
377
377
|
if (nb00 == type_size) {
|
|
378
|
-
// src0 is
|
|
378
|
+
// src0 is contiguous on first dimension, copy by rows
|
|
379
379
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
380
380
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
381
381
|
id += rs * ir0;
|
|
@@ -664,12 +664,14 @@ void ggml_compute_forward_add(
|
|
|
664
664
|
{
|
|
665
665
|
ggml_compute_forward_add_non_quantized(params, dst);
|
|
666
666
|
} break;
|
|
667
|
+
case GGML_TYPE_Q1_0:
|
|
667
668
|
case GGML_TYPE_Q4_0:
|
|
668
669
|
case GGML_TYPE_Q4_1:
|
|
669
670
|
case GGML_TYPE_Q5_0:
|
|
670
671
|
case GGML_TYPE_Q5_1:
|
|
671
672
|
case GGML_TYPE_Q8_0:
|
|
672
673
|
case GGML_TYPE_MXFP4:
|
|
674
|
+
case GGML_TYPE_NVFP4:
|
|
673
675
|
case GGML_TYPE_Q2_K:
|
|
674
676
|
case GGML_TYPE_Q3_K:
|
|
675
677
|
case GGML_TYPE_Q4_K:
|
|
@@ -1112,6 +1114,7 @@ void ggml_compute_forward_add1(
|
|
|
1112
1114
|
GGML_ABORT("fatal error");
|
|
1113
1115
|
}
|
|
1114
1116
|
} break;
|
|
1117
|
+
case GGML_TYPE_Q1_0:
|
|
1115
1118
|
case GGML_TYPE_Q4_0:
|
|
1116
1119
|
case GGML_TYPE_Q4_1:
|
|
1117
1120
|
case GGML_TYPE_Q5_0:
|
|
@@ -1119,6 +1122,7 @@ void ggml_compute_forward_add1(
|
|
|
1119
1122
|
case GGML_TYPE_Q8_0:
|
|
1120
1123
|
case GGML_TYPE_Q8_1:
|
|
1121
1124
|
case GGML_TYPE_MXFP4:
|
|
1125
|
+
case GGML_TYPE_NVFP4:
|
|
1122
1126
|
case GGML_TYPE_Q2_K:
|
|
1123
1127
|
case GGML_TYPE_Q3_K:
|
|
1124
1128
|
case GGML_TYPE_Q4_K:
|
|
@@ -1240,6 +1244,7 @@ void ggml_compute_forward_acc(
|
|
|
1240
1244
|
} break;
|
|
1241
1245
|
case GGML_TYPE_F16:
|
|
1242
1246
|
case GGML_TYPE_BF16:
|
|
1247
|
+
case GGML_TYPE_Q1_0:
|
|
1243
1248
|
case GGML_TYPE_Q4_0:
|
|
1244
1249
|
case GGML_TYPE_Q4_1:
|
|
1245
1250
|
case GGML_TYPE_Q5_0:
|
|
@@ -1247,6 +1252,7 @@ void ggml_compute_forward_acc(
|
|
|
1247
1252
|
case GGML_TYPE_Q8_0:
|
|
1248
1253
|
case GGML_TYPE_Q8_1:
|
|
1249
1254
|
case GGML_TYPE_MXFP4:
|
|
1255
|
+
case GGML_TYPE_NVFP4:
|
|
1250
1256
|
case GGML_TYPE_Q2_K:
|
|
1251
1257
|
case GGML_TYPE_Q3_K:
|
|
1252
1258
|
case GGML_TYPE_Q4_K:
|
|
@@ -1795,7 +1801,7 @@ void ggml_compute_forward_repeat(
|
|
|
1795
1801
|
{
|
|
1796
1802
|
ggml_compute_forward_repeat_f32(params, dst);
|
|
1797
1803
|
} break;
|
|
1798
|
-
// TODO: templateify the
|
|
1804
|
+
// TODO: templateify the implementation and support for I64
|
|
1799
1805
|
// ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
|
|
1800
1806
|
//case GGML_TYPE_I64:
|
|
1801
1807
|
// {
|
|
@@ -2097,10 +2103,14 @@ static void ggml_compute_forward_gelu_f32(
|
|
|
2097
2103
|
|
|
2098
2104
|
const ggml_tensor * src0 = dst->src[0];
|
|
2099
2105
|
|
|
2100
|
-
assert(
|
|
2101
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2106
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2102
2107
|
assert(ggml_are_same_shape(src0, dst));
|
|
2103
2108
|
|
|
2109
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2110
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2111
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2112
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2113
|
+
|
|
2104
2114
|
const int ith = params->ith;
|
|
2105
2115
|
const int nth = params->nth;
|
|
2106
2116
|
|
|
@@ -2114,19 +2124,23 @@ static void ggml_compute_forward_gelu_f32(
|
|
|
2114
2124
|
const int ir0 = dr*ith;
|
|
2115
2125
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2116
2126
|
|
|
2117
|
-
for (int
|
|
2127
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2128
|
+
const int i3 = ir/(ne02*ne01);
|
|
2129
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2130
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2131
|
+
|
|
2118
2132
|
ggml_vec_gelu_f32(nc,
|
|
2119
|
-
(float *) ((char *) dst->data + i1*
|
|
2120
|
-
(float *) ((char *) src0->data + i1*
|
|
2133
|
+
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2134
|
+
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2121
2135
|
|
|
2122
2136
|
#ifndef NDEBUG
|
|
2123
2137
|
for (int k = 0; k < nc; k++) {
|
|
2124
|
-
const float x = ((float *) ((char *) dst->data + i1*(
|
|
2138
|
+
const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
|
|
2125
2139
|
GGML_UNUSED(x);
|
|
2126
2140
|
assert(!isnan(x));
|
|
2127
2141
|
assert(!isinf(x));
|
|
2128
2142
|
}
|
|
2129
|
-
#endif
|
|
2143
|
+
#endif // NDEBUG
|
|
2130
2144
|
}
|
|
2131
2145
|
}
|
|
2132
2146
|
|
|
@@ -2136,10 +2150,14 @@ static void ggml_compute_forward_gelu_f16(
|
|
|
2136
2150
|
|
|
2137
2151
|
const ggml_tensor * src0 = dst->src[0];
|
|
2138
2152
|
|
|
2139
|
-
assert(
|
|
2140
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2153
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2141
2154
|
assert(ggml_are_same_shape(src0, dst));
|
|
2142
2155
|
|
|
2156
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2157
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2158
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2159
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2160
|
+
|
|
2143
2161
|
const int ith = params->ith;
|
|
2144
2162
|
const int nth = params->nth;
|
|
2145
2163
|
|
|
@@ -2153,20 +2171,24 @@ static void ggml_compute_forward_gelu_f16(
|
|
|
2153
2171
|
const int ir0 = dr*ith;
|
|
2154
2172
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2155
2173
|
|
|
2156
|
-
for (int
|
|
2174
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2175
|
+
const int i3 = ir/(ne02*ne01);
|
|
2176
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2177
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2178
|
+
|
|
2157
2179
|
ggml_vec_gelu_f16(nc,
|
|
2158
|
-
(ggml_fp16_t *) ((char *) dst->data + i1*
|
|
2159
|
-
(ggml_fp16_t *) ((char *) src0->data + i1*
|
|
2180
|
+
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2181
|
+
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2160
2182
|
|
|
2161
2183
|
#ifndef NDEBUG
|
|
2162
2184
|
for (int k = 0; k < nc; k++) {
|
|
2163
|
-
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
2185
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
|
|
2164
2186
|
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
2165
2187
|
GGML_UNUSED(v);
|
|
2166
2188
|
assert(!isnan(v));
|
|
2167
2189
|
assert(!isinf(v));
|
|
2168
2190
|
}
|
|
2169
|
-
#endif
|
|
2191
|
+
#endif // NDEBUG
|
|
2170
2192
|
}
|
|
2171
2193
|
}
|
|
2172
2194
|
|
|
@@ -2213,8 +2235,42 @@ static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, gg
|
|
|
2213
2235
|
}
|
|
2214
2236
|
}
|
|
2215
2237
|
|
|
2238
|
+
static void ggml_compute_forward_fill_f16(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
2239
|
+
const ggml_fp16_t c = GGML_CPU_FP32_TO_FP16(ggml_get_op_params_f32(dst, 0));
|
|
2240
|
+
|
|
2241
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
|
|
2242
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
|
|
2243
|
+
|
|
2244
|
+
const auto [ir0, ir1] = get_thread_range(params, dst);
|
|
2245
|
+
|
|
2246
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
|
2247
|
+
const int64_t i03 = ir/(ne2*ne1);
|
|
2248
|
+
const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
|
|
2249
|
+
const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
|
|
2250
|
+
|
|
2251
|
+
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
|
|
2252
|
+
|
|
2253
|
+
ggml_vec_set_f16(ne0, dst_ptr, c);
|
|
2254
|
+
}
|
|
2255
|
+
}
|
|
2256
|
+
|
|
2216
2257
|
void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
2217
|
-
|
|
2258
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
2259
|
+
|
|
2260
|
+
switch (src0->type) {
|
|
2261
|
+
case GGML_TYPE_F32:
|
|
2262
|
+
{
|
|
2263
|
+
ggml_compute_forward_fill_f32(params, dst);
|
|
2264
|
+
} break;
|
|
2265
|
+
case GGML_TYPE_F16:
|
|
2266
|
+
{
|
|
2267
|
+
ggml_compute_forward_fill_f16(params, dst);
|
|
2268
|
+
} break;
|
|
2269
|
+
default:
|
|
2270
|
+
{
|
|
2271
|
+
GGML_ABORT("unsupported type for ggml_compute_forward_fill: %s", ggml_type_name(src0->type));
|
|
2272
|
+
}
|
|
2273
|
+
}
|
|
2218
2274
|
}
|
|
2219
2275
|
|
|
2220
2276
|
// ggml_compute_tri
|
|
@@ -2277,10 +2333,14 @@ static void ggml_compute_forward_gelu_erf_f32(
|
|
|
2277
2333
|
|
|
2278
2334
|
const ggml_tensor * src0 = dst->src[0];
|
|
2279
2335
|
|
|
2280
|
-
assert(
|
|
2281
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2336
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2282
2337
|
assert(ggml_are_same_shape(src0, dst));
|
|
2283
2338
|
|
|
2339
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2340
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2341
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2342
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2343
|
+
|
|
2284
2344
|
const int ith = params->ith;
|
|
2285
2345
|
const int nth = params->nth;
|
|
2286
2346
|
|
|
@@ -2294,19 +2354,23 @@ static void ggml_compute_forward_gelu_erf_f32(
|
|
|
2294
2354
|
const int ir0 = dr*ith;
|
|
2295
2355
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2296
2356
|
|
|
2297
|
-
for (int
|
|
2357
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2358
|
+
const int i3 = ir/(ne02*ne01);
|
|
2359
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2360
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2361
|
+
|
|
2298
2362
|
ggml_vec_gelu_erf_f32(nc,
|
|
2299
|
-
(float *) ((char *) dst->data + i1*
|
|
2300
|
-
(float *) ((char *) src0->data + i1*
|
|
2363
|
+
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2364
|
+
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2301
2365
|
|
|
2302
2366
|
#ifndef NDEBUG
|
|
2303
2367
|
for (int k = 0; k < nc; k++) {
|
|
2304
|
-
const float x = ((float *) ((char *) dst->data + i1*(
|
|
2368
|
+
const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
|
|
2305
2369
|
GGML_UNUSED(x);
|
|
2306
2370
|
assert(!isnan(x));
|
|
2307
2371
|
assert(!isinf(x));
|
|
2308
2372
|
}
|
|
2309
|
-
#endif
|
|
2373
|
+
#endif // NDEBUG
|
|
2310
2374
|
}
|
|
2311
2375
|
}
|
|
2312
2376
|
|
|
@@ -2316,10 +2380,14 @@ static void ggml_compute_forward_gelu_erf_f16(
|
|
|
2316
2380
|
|
|
2317
2381
|
const ggml_tensor * src0 = dst->src[0];
|
|
2318
2382
|
|
|
2319
|
-
assert(
|
|
2320
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2383
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2321
2384
|
assert(ggml_are_same_shape(src0, dst));
|
|
2322
2385
|
|
|
2386
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2387
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2388
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2389
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2390
|
+
|
|
2323
2391
|
const int ith = params->ith;
|
|
2324
2392
|
const int nth = params->nth;
|
|
2325
2393
|
|
|
@@ -2333,20 +2401,24 @@ static void ggml_compute_forward_gelu_erf_f16(
|
|
|
2333
2401
|
const int ir0 = dr*ith;
|
|
2334
2402
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2335
2403
|
|
|
2336
|
-
for (int
|
|
2404
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2405
|
+
const int i3 = ir/(ne02*ne01);
|
|
2406
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2407
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2408
|
+
|
|
2337
2409
|
ggml_vec_gelu_erf_f16(nc,
|
|
2338
|
-
(ggml_fp16_t *) ((char *) dst->data + i1*
|
|
2339
|
-
(ggml_fp16_t *) ((char *) src0->data + i1*
|
|
2410
|
+
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2411
|
+
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2340
2412
|
|
|
2341
2413
|
#ifndef NDEBUG
|
|
2342
2414
|
for (int k = 0; k < nc; k++) {
|
|
2343
|
-
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
2415
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
|
|
2344
2416
|
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
2345
2417
|
GGML_UNUSED(v);
|
|
2346
2418
|
assert(!isnan(v));
|
|
2347
2419
|
assert(!isinf(v));
|
|
2348
2420
|
}
|
|
2349
|
-
#endif
|
|
2421
|
+
#endif // NDEBUG
|
|
2350
2422
|
}
|
|
2351
2423
|
}
|
|
2352
2424
|
|
|
@@ -2380,10 +2452,14 @@ static void ggml_compute_forward_gelu_quick_f32(
|
|
|
2380
2452
|
|
|
2381
2453
|
const ggml_tensor * src0 = dst->src[0];
|
|
2382
2454
|
|
|
2383
|
-
assert(
|
|
2384
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2455
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2385
2456
|
assert(ggml_are_same_shape(src0, dst));
|
|
2386
2457
|
|
|
2458
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2459
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2460
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2461
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2462
|
+
|
|
2387
2463
|
const int ith = params->ith;
|
|
2388
2464
|
const int nth = params->nth;
|
|
2389
2465
|
|
|
@@ -2397,19 +2473,23 @@ static void ggml_compute_forward_gelu_quick_f32(
|
|
|
2397
2473
|
const int ir0 = dr*ith;
|
|
2398
2474
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2399
2475
|
|
|
2400
|
-
for (int
|
|
2476
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2477
|
+
const int i3 = ir/(ne02*ne01);
|
|
2478
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2479
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2480
|
+
|
|
2401
2481
|
ggml_vec_gelu_quick_f32(nc,
|
|
2402
|
-
(float *) ((char *) dst->data + i1*
|
|
2403
|
-
(float *) ((char *) src0->data + i1*
|
|
2482
|
+
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2483
|
+
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2404
2484
|
|
|
2405
2485
|
#ifndef NDEBUG
|
|
2406
2486
|
for (int k = 0; k < nc; k++) {
|
|
2407
|
-
const float x = ((float *) ((char *) dst->data + i1*(
|
|
2487
|
+
const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
|
|
2408
2488
|
GGML_UNUSED(x);
|
|
2409
2489
|
assert(!isnan(x));
|
|
2410
2490
|
assert(!isinf(x));
|
|
2411
2491
|
}
|
|
2412
|
-
#endif
|
|
2492
|
+
#endif // NDEBUG
|
|
2413
2493
|
}
|
|
2414
2494
|
}
|
|
2415
2495
|
|
|
@@ -2419,10 +2499,14 @@ static void ggml_compute_forward_gelu_quick_f16(
|
|
|
2419
2499
|
|
|
2420
2500
|
const ggml_tensor * src0 = dst->src[0];
|
|
2421
2501
|
|
|
2422
|
-
assert(
|
|
2423
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2502
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2424
2503
|
assert(ggml_are_same_shape(src0, dst));
|
|
2425
2504
|
|
|
2505
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2506
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2507
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2508
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2509
|
+
|
|
2426
2510
|
const int ith = params->ith;
|
|
2427
2511
|
const int nth = params->nth;
|
|
2428
2512
|
|
|
@@ -2436,20 +2520,24 @@ static void ggml_compute_forward_gelu_quick_f16(
|
|
|
2436
2520
|
const int ir0 = dr*ith;
|
|
2437
2521
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2438
2522
|
|
|
2439
|
-
for (int
|
|
2523
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2524
|
+
const int i3 = ir/(ne02*ne01);
|
|
2525
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2526
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2527
|
+
|
|
2440
2528
|
ggml_vec_gelu_quick_f16(nc,
|
|
2441
|
-
(ggml_fp16_t *) ((char *) dst->data + i1*
|
|
2442
|
-
(ggml_fp16_t *) ((char *) src0->data + i1*
|
|
2529
|
+
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2530
|
+
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2443
2531
|
|
|
2444
2532
|
#ifndef NDEBUG
|
|
2445
2533
|
for (int k = 0; k < nc; k++) {
|
|
2446
|
-
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
2534
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
|
|
2447
2535
|
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
2448
2536
|
GGML_UNUSED(v);
|
|
2449
2537
|
assert(!isnan(v));
|
|
2450
2538
|
assert(!isinf(v));
|
|
2451
2539
|
}
|
|
2452
|
-
#endif
|
|
2540
|
+
#endif // NDEBUG
|
|
2453
2541
|
}
|
|
2454
2542
|
}
|
|
2455
2543
|
|
|
@@ -2483,10 +2571,14 @@ static void ggml_compute_forward_silu_f32(
|
|
|
2483
2571
|
|
|
2484
2572
|
const ggml_tensor * src0 = dst->src[0];
|
|
2485
2573
|
|
|
2486
|
-
assert(
|
|
2487
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2574
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2488
2575
|
assert(ggml_are_same_shape(src0, dst));
|
|
2489
2576
|
|
|
2577
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2578
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2579
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2580
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2581
|
+
|
|
2490
2582
|
const int ith = params->ith;
|
|
2491
2583
|
const int nth = params->nth;
|
|
2492
2584
|
|
|
@@ -2500,19 +2592,23 @@ static void ggml_compute_forward_silu_f32(
|
|
|
2500
2592
|
const int ir0 = dr*ith;
|
|
2501
2593
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2502
2594
|
|
|
2503
|
-
for (int
|
|
2595
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2596
|
+
const int i3 = ir/(ne02*ne01);
|
|
2597
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2598
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2599
|
+
|
|
2504
2600
|
ggml_vec_silu_f32(nc,
|
|
2505
|
-
(float *) ((char *) dst->data + i1*
|
|
2506
|
-
(float *) ((char *) src0->data + i1*
|
|
2601
|
+
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2602
|
+
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2507
2603
|
|
|
2508
2604
|
#ifndef NDEBUG
|
|
2509
2605
|
for (int k = 0; k < nc; k++) {
|
|
2510
|
-
const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
|
|
2606
|
+
const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
|
|
2511
2607
|
GGML_UNUSED(x);
|
|
2512
2608
|
assert(!isnan(x));
|
|
2513
2609
|
assert(!isinf(x));
|
|
2514
2610
|
}
|
|
2515
|
-
#endif
|
|
2611
|
+
#endif // NDEBUG
|
|
2516
2612
|
}
|
|
2517
2613
|
}
|
|
2518
2614
|
|
|
@@ -2522,10 +2618,14 @@ static void ggml_compute_forward_silu_f16(
|
|
|
2522
2618
|
|
|
2523
2619
|
const ggml_tensor * src0 = dst->src[0];
|
|
2524
2620
|
|
|
2525
|
-
assert(
|
|
2526
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2621
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2527
2622
|
assert(ggml_are_same_shape(src0, dst));
|
|
2528
2623
|
|
|
2624
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2625
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2626
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2627
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2628
|
+
|
|
2529
2629
|
const int ith = params->ith;
|
|
2530
2630
|
const int nth = params->nth;
|
|
2531
2631
|
|
|
@@ -2539,20 +2639,24 @@ static void ggml_compute_forward_silu_f16(
|
|
|
2539
2639
|
const int ir0 = dr*ith;
|
|
2540
2640
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2541
2641
|
|
|
2542
|
-
for (int
|
|
2642
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2643
|
+
const int i3 = ir/(ne02*ne01);
|
|
2644
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2645
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2646
|
+
|
|
2543
2647
|
ggml_vec_silu_f16(nc,
|
|
2544
|
-
(ggml_fp16_t *) ((char *) dst->data + i1*
|
|
2545
|
-
(ggml_fp16_t *) ((char *) src0->data + i1*
|
|
2648
|
+
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2649
|
+
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2546
2650
|
|
|
2547
2651
|
#ifndef NDEBUG
|
|
2548
2652
|
for (int k = 0; k < nc; k++) {
|
|
2549
|
-
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
|
|
2653
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
|
|
2550
2654
|
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
2551
2655
|
GGML_UNUSED(v);
|
|
2552
2656
|
assert(!isnan(v));
|
|
2553
2657
|
assert(!isinf(v));
|
|
2554
2658
|
}
|
|
2555
|
-
#endif
|
|
2659
|
+
#endif // NDEBUG
|
|
2556
2660
|
}
|
|
2557
2661
|
}
|
|
2558
2662
|
|
|
@@ -2702,7 +2806,7 @@ static void ggml_compute_forward_silu_back_f32(
|
|
|
2702
2806
|
assert(!isnan(x));
|
|
2703
2807
|
assert(!isinf(x));
|
|
2704
2808
|
}
|
|
2705
|
-
#endif
|
|
2809
|
+
#endif // NDEBUG
|
|
2706
2810
|
}
|
|
2707
2811
|
}
|
|
2708
2812
|
|
|
@@ -2738,7 +2842,7 @@ static void ggml_compute_forward_silu_back_f16(
|
|
|
2738
2842
|
(ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
|
|
2739
2843
|
(ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
|
|
2740
2844
|
|
|
2741
|
-
|
|
2845
|
+
#ifndef NDEBUG
|
|
2742
2846
|
for (int k = 0; k < nc; k++) {
|
|
2743
2847
|
const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
2744
2848
|
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
@@ -2746,7 +2850,7 @@ static void ggml_compute_forward_silu_back_f16(
|
|
|
2746
2850
|
assert(!isnan(v));
|
|
2747
2851
|
assert(!isinf(v));
|
|
2748
2852
|
}
|
|
2749
|
-
|
|
2853
|
+
#endif // NDEBUG
|
|
2750
2854
|
}
|
|
2751
2855
|
}
|
|
2752
2856
|
|
|
@@ -2829,7 +2933,7 @@ static void ggml_compute_forward_reglu_f32(
|
|
|
2829
2933
|
assert(!isnan(x));
|
|
2830
2934
|
assert(!isinf(x));
|
|
2831
2935
|
}
|
|
2832
|
-
#endif
|
|
2936
|
+
#endif // NDEBUG
|
|
2833
2937
|
}
|
|
2834
2938
|
}
|
|
2835
2939
|
|
|
@@ -2889,7 +2993,7 @@ static void ggml_compute_forward_reglu_f16(
|
|
|
2889
2993
|
assert(!isnan(v));
|
|
2890
2994
|
assert(!isinf(v));
|
|
2891
2995
|
}
|
|
2892
|
-
#endif
|
|
2996
|
+
#endif // NDEBUG
|
|
2893
2997
|
}
|
|
2894
2998
|
}
|
|
2895
2999
|
|
|
@@ -2972,7 +3076,7 @@ static void ggml_compute_forward_geglu_f32(
|
|
|
2972
3076
|
assert(!isnan(x));
|
|
2973
3077
|
assert(!isinf(x));
|
|
2974
3078
|
}
|
|
2975
|
-
#endif
|
|
3079
|
+
#endif // NDEBUG
|
|
2976
3080
|
}
|
|
2977
3081
|
}
|
|
2978
3082
|
|
|
@@ -3032,7 +3136,7 @@ static void ggml_compute_forward_geglu_f16(
|
|
|
3032
3136
|
assert(!isnan(v));
|
|
3033
3137
|
assert(!isinf(v));
|
|
3034
3138
|
}
|
|
3035
|
-
#endif
|
|
3139
|
+
#endif // NDEBUG
|
|
3036
3140
|
}
|
|
3037
3141
|
}
|
|
3038
3142
|
|
|
@@ -3115,7 +3219,7 @@ static void ggml_compute_forward_swiglu_f32(
|
|
|
3115
3219
|
assert(!isnan(x));
|
|
3116
3220
|
assert(!isinf(x));
|
|
3117
3221
|
}
|
|
3118
|
-
#endif
|
|
3222
|
+
#endif // NDEBUG
|
|
3119
3223
|
}
|
|
3120
3224
|
}
|
|
3121
3225
|
|
|
@@ -3175,7 +3279,7 @@ static void ggml_compute_forward_swiglu_f16(
|
|
|
3175
3279
|
assert(!isnan(v));
|
|
3176
3280
|
assert(!isinf(v));
|
|
3177
3281
|
}
|
|
3178
|
-
#endif
|
|
3282
|
+
#endif // NDEBUG
|
|
3179
3283
|
}
|
|
3180
3284
|
}
|
|
3181
3285
|
|
|
@@ -3266,7 +3370,7 @@ static void ggml_compute_forward_swiglu_oai_f32(
|
|
|
3266
3370
|
assert(!isnan(x));
|
|
3267
3371
|
assert(!isinf(x));
|
|
3268
3372
|
}
|
|
3269
|
-
#endif
|
|
3373
|
+
#endif // NDEBUG
|
|
3270
3374
|
}
|
|
3271
3375
|
}
|
|
3272
3376
|
|
|
@@ -3345,7 +3449,7 @@ static void ggml_compute_forward_geglu_erf_f32(
|
|
|
3345
3449
|
assert(!isnan(x));
|
|
3346
3450
|
assert(!isinf(x));
|
|
3347
3451
|
}
|
|
3348
|
-
#endif
|
|
3452
|
+
#endif // NDEBUG
|
|
3349
3453
|
}
|
|
3350
3454
|
}
|
|
3351
3455
|
|
|
@@ -3405,7 +3509,7 @@ static void ggml_compute_forward_geglu_erf_f16(
|
|
|
3405
3509
|
assert(!isnan(v));
|
|
3406
3510
|
assert(!isinf(v));
|
|
3407
3511
|
}
|
|
3408
|
-
#endif
|
|
3512
|
+
#endif // NDEBUG
|
|
3409
3513
|
}
|
|
3410
3514
|
}
|
|
3411
3515
|
|
|
@@ -3488,7 +3592,7 @@ static void ggml_compute_forward_geglu_quick_f32(
|
|
|
3488
3592
|
assert(!isnan(x));
|
|
3489
3593
|
assert(!isinf(x));
|
|
3490
3594
|
}
|
|
3491
|
-
#endif
|
|
3595
|
+
#endif // NDEBUG
|
|
3492
3596
|
}
|
|
3493
3597
|
}
|
|
3494
3598
|
|
|
@@ -3548,7 +3652,7 @@ static void ggml_compute_forward_geglu_quick_f16(
|
|
|
3548
3652
|
assert(!isnan(v));
|
|
3549
3653
|
assert(!isinf(v));
|
|
3550
3654
|
}
|
|
3551
|
-
#endif
|
|
3655
|
+
#endif // NDEBUG
|
|
3552
3656
|
}
|
|
3553
3657
|
}
|
|
3554
3658
|
|
|
@@ -3643,11 +3747,27 @@ void ggml_compute_forward_norm(
|
|
|
3643
3747
|
|
|
3644
3748
|
// ggml_compute_forward_group_rms_norm
|
|
3645
3749
|
|
|
3750
|
+
// fusion kinds that can be combined with the rms_norm computation in a single pass.
|
|
3751
|
+
// extend this enum when adding new fused variants (e.g. FUSE_ADD, FUSE_MUL_ADD, ...).
|
|
3752
|
+
enum ggml_rms_norm_fuse_op {
|
|
3753
|
+
GGML_RMS_NORM_FUSE_OP_NONE,
|
|
3754
|
+
GGML_RMS_NORM_FUSE_OP_MUL,
|
|
3755
|
+
};
|
|
3756
|
+
|
|
3757
|
+
template <ggml_rms_norm_fuse_op FUSE_OP>
|
|
3646
3758
|
static void ggml_compute_forward_rms_norm_f32(
|
|
3647
3759
|
const ggml_compute_params * params,
|
|
3648
|
-
ggml_tensor *
|
|
3760
|
+
ggml_tensor * dst_rms_norm,
|
|
3761
|
+
ggml_tensor * dst_fused = nullptr) {
|
|
3649
3762
|
|
|
3650
|
-
const ggml_tensor * src0 =
|
|
3763
|
+
const ggml_tensor * src0 = dst_rms_norm->src[0];
|
|
3764
|
+
const ggml_tensor * src1 = nullptr;
|
|
3765
|
+
ggml_tensor * dst = dst_rms_norm;
|
|
3766
|
+
|
|
3767
|
+
if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) {
|
|
3768
|
+
src1 = (dst_fused->src[0] == dst_rms_norm) ? dst_fused->src[1] : dst_fused->src[0];
|
|
3769
|
+
dst = dst_fused;
|
|
3770
|
+
}
|
|
3651
3771
|
|
|
3652
3772
|
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
|
3653
3773
|
|
|
@@ -3656,11 +3776,10 @@ static void ggml_compute_forward_rms_norm_f32(
|
|
|
3656
3776
|
const int ith = params->ith;
|
|
3657
3777
|
const int nth = params->nth;
|
|
3658
3778
|
|
|
3659
|
-
|
|
3779
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
|
3660
3780
|
|
|
3661
3781
|
float eps;
|
|
3662
|
-
memcpy(&eps,
|
|
3663
|
-
|
|
3782
|
+
memcpy(&eps, dst_rms_norm->op_params, sizeof(float));
|
|
3664
3783
|
GGML_ASSERT(eps >= 0.0f);
|
|
3665
3784
|
|
|
3666
3785
|
// TODO: optimize
|
|
@@ -3670,25 +3789,32 @@ static void ggml_compute_forward_rms_norm_f32(
|
|
|
3670
3789
|
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
3671
3790
|
|
|
3672
3791
|
ggml_float sum = 0.0;
|
|
3792
|
+
// worth switching to explicit SIMD?
|
|
3673
3793
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
3674
3794
|
sum += (ggml_float)(x[i00] * x[i00]);
|
|
3675
3795
|
}
|
|
3676
3796
|
|
|
3677
|
-
const float mean
|
|
3678
|
-
|
|
3679
|
-
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
|
3680
|
-
|
|
3681
|
-
memcpy(y, x, ne00 * sizeof(float));
|
|
3682
|
-
// for (int i00 = 0; i00 < ne00; i00++) {
|
|
3683
|
-
// y[i00] = x[i00];
|
|
3684
|
-
// }
|
|
3685
|
-
|
|
3797
|
+
const float mean = sum/ne00;
|
|
3686
3798
|
const float scale = 1.0f/sqrtf(mean + eps);
|
|
3687
3799
|
|
|
3688
3800
|
// if you hit this, likely you got an inf somewhere earlier
|
|
3689
3801
|
assert(scale > 0.0f);
|
|
3690
3802
|
|
|
3691
|
-
|
|
3803
|
+
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
|
3804
|
+
|
|
3805
|
+
if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) {
|
|
3806
|
+
const int64_t i11 = i01 % ne11;
|
|
3807
|
+
const int64_t i12 = i02 % ne12;
|
|
3808
|
+
const int64_t i13 = i03 % ne13;
|
|
3809
|
+
const float * w = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
|
|
3810
|
+
|
|
3811
|
+
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
3812
|
+
y[i00] = x[i00] * scale * w[i00];
|
|
3813
|
+
}
|
|
3814
|
+
} else {
|
|
3815
|
+
memcpy(y, x, ne00 * sizeof(float));
|
|
3816
|
+
ggml_vec_scale_f32(ne00, y, scale);
|
|
3817
|
+
}
|
|
3692
3818
|
}
|
|
3693
3819
|
}
|
|
3694
3820
|
}
|
|
@@ -3703,7 +3829,31 @@ void ggml_compute_forward_rms_norm(
|
|
|
3703
3829
|
switch (src0->type) {
|
|
3704
3830
|
case GGML_TYPE_F32:
|
|
3705
3831
|
{
|
|
3706
|
-
ggml_compute_forward_rms_norm_f32(params, dst);
|
|
3832
|
+
ggml_compute_forward_rms_norm_f32<GGML_RMS_NORM_FUSE_OP_NONE>(params, dst);
|
|
3833
|
+
} break;
|
|
3834
|
+
default:
|
|
3835
|
+
{
|
|
3836
|
+
GGML_ABORT("fatal error");
|
|
3837
|
+
}
|
|
3838
|
+
}
|
|
3839
|
+
}
|
|
3840
|
+
|
|
3841
|
+
// Fused RMS_NORM + MUL: computes dst = rms_norm(src0) * src1 in a single pass.
|
|
3842
|
+
// This avoids materializing the intermediate rms_norm result in memory.
|
|
3843
|
+
void ggml_compute_forward_rms_norm_mul_fused(
|
|
3844
|
+
const ggml_compute_params * params,
|
|
3845
|
+
ggml_tensor * dst_rms_norm,
|
|
3846
|
+
ggml_tensor * dst_mul) {
|
|
3847
|
+
|
|
3848
|
+
GGML_ASSERT(dst_mul != nullptr);
|
|
3849
|
+
GGML_ASSERT(dst_mul->src[0] == dst_rms_norm || dst_mul->src[1] == dst_rms_norm);
|
|
3850
|
+
|
|
3851
|
+
const ggml_tensor * src0 = dst_rms_norm->src[0];
|
|
3852
|
+
|
|
3853
|
+
switch (src0->type) {
|
|
3854
|
+
case GGML_TYPE_F32:
|
|
3855
|
+
{
|
|
3856
|
+
ggml_compute_forward_rms_norm_f32<GGML_RMS_NORM_FUSE_OP_MUL>(params, dst_rms_norm, dst_mul);
|
|
3707
3857
|
} break;
|
|
3708
3858
|
default:
|
|
3709
3859
|
{
|
|
@@ -3858,12 +4008,12 @@ static void ggml_compute_forward_rms_norm_back_f32(
|
|
|
3858
4008
|
// dx := scale(dx, rrms)
|
|
3859
4009
|
float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
|
3860
4010
|
|
|
3861
|
-
// dx[i00] = (x*(-sum_xdz/sum_eps)
|
|
3862
|
-
|
|
3863
|
-
|
|
3864
|
-
|
|
3865
|
-
|
|
3866
|
-
|
|
4011
|
+
// dx[i00] = (dz + x*(-sum_xdz/sum_eps)) * rrms
|
|
4012
|
+
// note: https://github.com/ggml-org/ggml/issues/1491
|
|
4013
|
+
const float scale_x = (float) (-sum_xdz) / sum_eps;
|
|
4014
|
+
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
4015
|
+
dx[i00] = (dz[i00] + x[i00] * scale_x) * rrms;
|
|
4016
|
+
}
|
|
3867
4017
|
}
|
|
3868
4018
|
}
|
|
3869
4019
|
}
|
|
@@ -4264,12 +4414,14 @@ void ggml_compute_forward_out_prod(
|
|
|
4264
4414
|
const ggml_tensor * src0 = dst->src[0];
|
|
4265
4415
|
|
|
4266
4416
|
switch (src0->type) {
|
|
4417
|
+
case GGML_TYPE_Q1_0:
|
|
4267
4418
|
case GGML_TYPE_Q4_0:
|
|
4268
4419
|
case GGML_TYPE_Q4_1:
|
|
4269
4420
|
case GGML_TYPE_Q5_0:
|
|
4270
4421
|
case GGML_TYPE_Q5_1:
|
|
4271
4422
|
case GGML_TYPE_Q8_0:
|
|
4272
4423
|
case GGML_TYPE_MXFP4:
|
|
4424
|
+
case GGML_TYPE_NVFP4:
|
|
4273
4425
|
case GGML_TYPE_Q2_K:
|
|
4274
4426
|
case GGML_TYPE_Q3_K:
|
|
4275
4427
|
case GGML_TYPE_Q4_K:
|
|
@@ -4538,6 +4690,7 @@ void ggml_compute_forward_set(
|
|
|
4538
4690
|
} break;
|
|
4539
4691
|
case GGML_TYPE_F16:
|
|
4540
4692
|
case GGML_TYPE_BF16:
|
|
4693
|
+
case GGML_TYPE_Q1_0:
|
|
4541
4694
|
case GGML_TYPE_Q4_0:
|
|
4542
4695
|
case GGML_TYPE_Q4_1:
|
|
4543
4696
|
case GGML_TYPE_Q5_0:
|
|
@@ -4545,6 +4698,7 @@ void ggml_compute_forward_set(
|
|
|
4545
4698
|
case GGML_TYPE_Q8_0:
|
|
4546
4699
|
case GGML_TYPE_Q8_1:
|
|
4547
4700
|
case GGML_TYPE_MXFP4:
|
|
4701
|
+
case GGML_TYPE_NVFP4:
|
|
4548
4702
|
case GGML_TYPE_Q2_K:
|
|
4549
4703
|
case GGML_TYPE_Q3_K:
|
|
4550
4704
|
case GGML_TYPE_Q4_K:
|
|
@@ -4760,6 +4914,7 @@ void ggml_compute_forward_get_rows(
|
|
|
4760
4914
|
const ggml_tensor * src0 = dst->src[0];
|
|
4761
4915
|
|
|
4762
4916
|
switch (src0->type) {
|
|
4917
|
+
case GGML_TYPE_Q1_0:
|
|
4763
4918
|
case GGML_TYPE_Q4_0:
|
|
4764
4919
|
case GGML_TYPE_Q4_1:
|
|
4765
4920
|
case GGML_TYPE_Q5_0:
|
|
@@ -4767,6 +4922,7 @@ void ggml_compute_forward_get_rows(
|
|
|
4767
4922
|
case GGML_TYPE_Q8_0:
|
|
4768
4923
|
case GGML_TYPE_Q8_1:
|
|
4769
4924
|
case GGML_TYPE_MXFP4:
|
|
4925
|
+
case GGML_TYPE_NVFP4:
|
|
4770
4926
|
case GGML_TYPE_Q2_K:
|
|
4771
4927
|
case GGML_TYPE_Q3_K:
|
|
4772
4928
|
case GGML_TYPE_Q4_K:
|
|
@@ -5239,7 +5395,7 @@ static void ggml_compute_forward_soft_max_f32(
|
|
|
5239
5395
|
//printf("p[%d] = %f\n", i, p[i]);
|
|
5240
5396
|
assert(!isnan(wp[i]));
|
|
5241
5397
|
}
|
|
5242
|
-
#endif
|
|
5398
|
+
#endif // NDEBUG
|
|
5243
5399
|
|
|
5244
5400
|
float max = -INFINITY;
|
|
5245
5401
|
ggml_vec_max_f32(ne00, &max, wp);
|
|
@@ -5264,7 +5420,7 @@ static void ggml_compute_forward_soft_max_f32(
|
|
|
5264
5420
|
assert(!isnan(dp[i]));
|
|
5265
5421
|
assert(!isinf(dp[i]));
|
|
5266
5422
|
}
|
|
5267
|
-
#endif
|
|
5423
|
+
#endif // NDEBUG
|
|
5268
5424
|
}
|
|
5269
5425
|
}
|
|
5270
5426
|
}
|
|
@@ -5338,7 +5494,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32(
|
|
|
5338
5494
|
assert(!isnan(dy[i]));
|
|
5339
5495
|
assert(!isnan(y[i]));
|
|
5340
5496
|
}
|
|
5341
|
-
#endif
|
|
5497
|
+
#endif // NDEBUG
|
|
5342
5498
|
// Jii = yi - yi*yi
|
|
5343
5499
|
// Jij = -yi*yj
|
|
5344
5500
|
// J = diag(y)-y.T*y
|
|
@@ -5371,7 +5527,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32(
|
|
|
5371
5527
|
assert(!isnan(dx[i]));
|
|
5372
5528
|
assert(!isinf(dx[i]));
|
|
5373
5529
|
}
|
|
5374
|
-
#endif
|
|
5530
|
+
#endif // NDEBUG
|
|
5375
5531
|
}
|
|
5376
5532
|
}
|
|
5377
5533
|
|
|
@@ -5484,6 +5640,7 @@ void ggml_compute_forward_clamp(
|
|
|
5484
5640
|
ggml_compute_forward_clamp_f16(params, dst);
|
|
5485
5641
|
} break;
|
|
5486
5642
|
case GGML_TYPE_BF16:
|
|
5643
|
+
case GGML_TYPE_Q1_0:
|
|
5487
5644
|
case GGML_TYPE_Q4_0:
|
|
5488
5645
|
case GGML_TYPE_Q4_1:
|
|
5489
5646
|
case GGML_TYPE_Q5_0:
|
|
@@ -5491,6 +5648,7 @@ void ggml_compute_forward_clamp(
|
|
|
5491
5648
|
case GGML_TYPE_Q8_0:
|
|
5492
5649
|
case GGML_TYPE_Q8_1:
|
|
5493
5650
|
case GGML_TYPE_MXFP4:
|
|
5651
|
+
case GGML_TYPE_NVFP4:
|
|
5494
5652
|
case GGML_TYPE_Q2_K:
|
|
5495
5653
|
case GGML_TYPE_Q3_K:
|
|
5496
5654
|
case GGML_TYPE_Q4_K:
|
|
@@ -5739,28 +5897,33 @@ static void ggml_compute_forward_rope_flt(
|
|
|
5739
5897
|
|
|
5740
5898
|
const int32_t * pos = (const int32_t *) src1->data;
|
|
5741
5899
|
|
|
5900
|
+
int64_t last_i2 = -1;
|
|
5901
|
+
|
|
5742
5902
|
for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
|
|
5743
5903
|
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
|
5744
|
-
|
|
5745
|
-
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
|
5746
|
-
if (!mrope_used) {
|
|
5747
|
-
const int64_t p = pos[i2];
|
|
5748
|
-
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5749
|
-
}
|
|
5750
|
-
else {
|
|
5751
|
-
const int64_t p_t = pos[i2];
|
|
5752
|
-
const int64_t p_h = pos[i2 + ne2];
|
|
5753
|
-
const int64_t p_w = pos[i2 + ne2 * 2];
|
|
5754
|
-
const int64_t p_e = pos[i2 + ne2 * 3];
|
|
5755
|
-
ggml_mrope_cache_init(
|
|
5756
|
-
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
|
|
5757
|
-
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5758
|
-
}
|
|
5759
|
-
|
|
5760
5904
|
for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
|
|
5761
|
-
if (ir++ < ir0) continue;
|
|
5905
|
+
if (ir++ < ir0) continue; // skip rows mapped to other threads
|
|
5762
5906
|
if (ir > ir1) break;
|
|
5763
5907
|
|
|
5908
|
+
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
|
5909
|
+
if (last_i2 != i2) {
|
|
5910
|
+
if (!mrope_used) {
|
|
5911
|
+
const int64_t p = pos[i2];
|
|
5912
|
+
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5913
|
+
}
|
|
5914
|
+
else {
|
|
5915
|
+
const int64_t p_t = pos[i2];
|
|
5916
|
+
const int64_t p_h = pos[i2 + ne2];
|
|
5917
|
+
const int64_t p_w = pos[i2 + ne2 * 2];
|
|
5918
|
+
const int64_t p_e = pos[i2 + ne2 * 3];
|
|
5919
|
+
ggml_mrope_cache_init(
|
|
5920
|
+
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
|
|
5921
|
+
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5922
|
+
}
|
|
5923
|
+
|
|
5924
|
+
last_i2 = i2;
|
|
5925
|
+
}
|
|
5926
|
+
|
|
5764
5927
|
T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
|
|
5765
5928
|
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
|
|
5766
5929
|
|
|
@@ -6129,7 +6292,7 @@ static void ggml_compute_forward_im2col_f16(
|
|
|
6129
6292
|
const ggml_tensor * src1 = dst->src[1];
|
|
6130
6293
|
|
|
6131
6294
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
|
6132
|
-
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
6295
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
|
6133
6296
|
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
|
6134
6297
|
|
|
6135
6298
|
GGML_TENSOR_BINARY_OP_LOCALS;
|
|
@@ -6160,7 +6323,7 @@ static void ggml_compute_forward_im2col_f16(
|
|
|
6160
6323
|
int ofs1 = is_2D ? nb12 : nb11;
|
|
6161
6324
|
|
|
6162
6325
|
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
|
6163
|
-
GGML_ASSERT(nb10 ==
|
|
6326
|
+
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
|
|
6164
6327
|
|
|
6165
6328
|
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
|
|
6166
6329
|
{
|
|
@@ -6173,7 +6336,12 @@ static void ggml_compute_forward_im2col_f16(
|
|
|
6173
6336
|
|
|
6174
6337
|
// micro kernel
|
|
6175
6338
|
ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
|
|
6176
|
-
const float * const
|
|
6339
|
+
const float * const src_data_f32 = src1->type == GGML_TYPE_F32
|
|
6340
|
+
? (const float *)((const char *) src1->data + in*ofs0 + iic*ofs1)
|
|
6341
|
+
: nullptr; // [IH, IW]
|
|
6342
|
+
const ggml_fp16_t * const src_data_f16 = src1->type == GGML_TYPE_F16
|
|
6343
|
+
? (const ggml_fp16_t *)((const char *) src1->data + in*ofs0 + iic*ofs1)
|
|
6344
|
+
: nullptr; // [IH, IW]
|
|
6177
6345
|
|
|
6178
6346
|
for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
|
|
6179
6347
|
for (int64_t ikw = 0; ikw < KW; ikw++) {
|
|
@@ -6183,7 +6351,11 @@ static void ggml_compute_forward_im2col_f16(
|
|
|
6183
6351
|
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
|
6184
6352
|
dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
|
|
6185
6353
|
} else {
|
|
6186
|
-
|
|
6354
|
+
if (src_data_f32 != nullptr) {
|
|
6355
|
+
dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data_f32[iih*IW + iiw]);
|
|
6356
|
+
} else {
|
|
6357
|
+
dst_data[iic*(KH*KW) + ikh*KW + ikw] = src_data_f16[iih*IW + iiw];
|
|
6358
|
+
}
|
|
6187
6359
|
}
|
|
6188
6360
|
}
|
|
6189
6361
|
}
|
|
@@ -6558,6 +6730,78 @@ static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
|
|
|
6558
6730
|
return (coord + size) % size; // adding size avoids negative number weirdness
|
|
6559
6731
|
}
|
|
6560
6732
|
|
|
6733
|
+
// ggml_compute_forward_col2im_1d
|
|
6734
|
+
//
|
|
6735
|
+
// Scatter-add columns [K*OC, T_in] -> signal [T_out, OC]
|
|
6736
|
+
// where T_out = (T_in - 1)*s + K - 2*p. Gather approach: each output reads ceil(K/s) inputs.
|
|
6737
|
+
// Parallelized over the time axis so the split stays balanced whatever OC is.
|
|
6738
|
+
// Supports F32, F16, BF16 input/output (same type), F32 accumulator.
|
|
6739
|
+
|
|
6740
|
+
template <typename elem_t>
|
|
6741
|
+
static void ggml_compute_forward_col2im_1d_impl(
|
|
6742
|
+
const ggml_compute_params * params,
|
|
6743
|
+
ggml_tensor * dst) {
|
|
6744
|
+
|
|
6745
|
+
const ggml_tensor * src = dst->src[0]; // [K*OC, T_in]
|
|
6746
|
+
|
|
6747
|
+
GGML_ASSERT(ggml_is_contiguous(src));
|
|
6748
|
+
GGML_ASSERT(ggml_is_contiguous(dst));
|
|
6749
|
+
|
|
6750
|
+
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
|
6751
|
+
const int32_t OC = ((const int32_t *)(dst->op_params))[1];
|
|
6752
|
+
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
|
|
6753
|
+
|
|
6754
|
+
const int64_t K_OC = src->ne[0];
|
|
6755
|
+
const int64_t T_in = src->ne[1];
|
|
6756
|
+
const int64_t K = K_OC / OC;
|
|
6757
|
+
const int64_t T_out = dst->ne[0];
|
|
6758
|
+
|
|
6759
|
+
const elem_t * col_data = (const elem_t *) src->data;
|
|
6760
|
+
elem_t * dst_data = (elem_t *) dst->data;
|
|
6761
|
+
|
|
6762
|
+
const int ith = params->ith;
|
|
6763
|
+
const int nth = params->nth;
|
|
6764
|
+
|
|
6765
|
+
// Parallelize over the time axis: the split stays balanced whatever OC is,
|
|
6766
|
+
// down to OC = 1 for mono audio, and threads read disjoint column bands
|
|
6767
|
+
const int64_t dr = (T_out + nth - 1) / nth;
|
|
6768
|
+
const int64_t it0 = dr * ith;
|
|
6769
|
+
const int64_t it1 = it0 + dr < T_out ? it0 + dr : T_out;
|
|
6770
|
+
|
|
6771
|
+
for (int64_t oc = 0; oc < OC; oc++) {
|
|
6772
|
+
for (int64_t t_out = it0; t_out < it1; t_out++) {
|
|
6773
|
+
const int64_t t_abs = t_out + p0; // absolute position in uncropped signal
|
|
6774
|
+
// Gather: find all (t_in, k) where t_in * s + k == t_abs, 0 <= k < K
|
|
6775
|
+
int64_t t_in_min = (t_abs - K + 1 + s0 - 1) / s0; // ceil((t_abs-K+1)/s)
|
|
6776
|
+
if (t_in_min < 0) t_in_min = 0;
|
|
6777
|
+
int64_t t_in_max = t_abs / s0;
|
|
6778
|
+
if (t_in_max >= T_in) t_in_max = T_in - 1;
|
|
6779
|
+
|
|
6780
|
+
float sum = 0.0f;
|
|
6781
|
+
for (int64_t t_in = t_in_min; t_in <= t_in_max; t_in++) {
|
|
6782
|
+
int64_t k = t_abs - t_in * s0;
|
|
6783
|
+
if (k >= 0 && k < K) {
|
|
6784
|
+
// col layout: [K*OC, T_in], element (oc*K+k, t_in)
|
|
6785
|
+
sum += type_conversion_table<elem_t>::to_f32(col_data[(oc * K + k) + t_in * K_OC]);
|
|
6786
|
+
}
|
|
6787
|
+
}
|
|
6788
|
+
// dst layout: [T_out, OC], element (t_out, oc)
|
|
6789
|
+
dst_data[t_out + oc * T_out] = type_conversion_table<elem_t>::from_f32(sum);
|
|
6790
|
+
}
|
|
6791
|
+
}
|
|
6792
|
+
}
|
|
6793
|
+
|
|
6794
|
+
void ggml_compute_forward_col2im_1d(
|
|
6795
|
+
const ggml_compute_params * params,
|
|
6796
|
+
ggml_tensor * dst) {
|
|
6797
|
+
switch (dst->src[0]->type) {
|
|
6798
|
+
case GGML_TYPE_F32: ggml_compute_forward_col2im_1d_impl<float> (params, dst); break;
|
|
6799
|
+
case GGML_TYPE_F16: ggml_compute_forward_col2im_1d_impl<ggml_fp16_t>(params, dst); break;
|
|
6800
|
+
case GGML_TYPE_BF16: ggml_compute_forward_col2im_1d_impl<ggml_bf16_t>(params, dst); break;
|
|
6801
|
+
default: GGML_ABORT("col2im_1d: unsupported type %d", dst->src[0]->type);
|
|
6802
|
+
}
|
|
6803
|
+
}
|
|
6804
|
+
|
|
6561
6805
|
// ggml_compute_forward_conv_2d
|
|
6562
6806
|
|
|
6563
6807
|
|
|
@@ -6838,16 +7082,15 @@ void ggml_compute_forward_conv_3d(
|
|
|
6838
7082
|
ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
|
|
6839
7083
|
}
|
|
6840
7084
|
|
|
6841
|
-
|
|
6842
|
-
|
|
6843
|
-
|
|
6844
|
-
|
|
6845
|
-
ggml_tensor * dst) {
|
|
7085
|
+
template <typename kernel_t>
|
|
7086
|
+
static void ggml_compute_forward_conv_transpose_2d_impl(
|
|
7087
|
+
const ggml_compute_params * params,
|
|
7088
|
+
ggml_tensor * dst) {
|
|
6846
7089
|
|
|
6847
7090
|
const ggml_tensor * src0 = dst->src[0];
|
|
6848
7091
|
const ggml_tensor * src1 = dst->src[1];
|
|
6849
7092
|
|
|
6850
|
-
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
|
7093
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
|
|
6851
7094
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
6852
7095
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
6853
7096
|
|
|
@@ -6858,7 +7101,7 @@ void ggml_compute_forward_conv_transpose_2d(
|
|
|
6858
7101
|
|
|
6859
7102
|
const int nk = ne00*ne01*ne02*ne03;
|
|
6860
7103
|
|
|
6861
|
-
GGML_ASSERT(nb00 ==
|
|
7104
|
+
GGML_ASSERT(nb00 == ggml_type_size(src0->type));
|
|
6862
7105
|
GGML_ASSERT(nb10 == sizeof(float));
|
|
6863
7106
|
|
|
6864
7107
|
if (ith == 0) {
|
|
@@ -6866,12 +7109,12 @@ void ggml_compute_forward_conv_transpose_2d(
|
|
|
6866
7109
|
|
|
6867
7110
|
// permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
|
|
6868
7111
|
{
|
|
6869
|
-
|
|
7112
|
+
kernel_t * const wdata = (kernel_t *) params->wdata + 0;
|
|
6870
7113
|
|
|
6871
7114
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
6872
7115
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
6873
|
-
const
|
|
6874
|
-
|
|
7116
|
+
const kernel_t * const src = (kernel_t *)((char *) src0->data + i03*nb03 + i02*nb02);
|
|
7117
|
+
kernel_t * dst_data = wdata + i02*ne01*ne00*ne03;
|
|
6875
7118
|
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
|
6876
7119
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
6877
7120
|
dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
|
|
@@ -6883,13 +7126,17 @@ void ggml_compute_forward_conv_transpose_2d(
|
|
|
6883
7126
|
|
|
6884
7127
|
// permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
|
|
6885
7128
|
{
|
|
6886
|
-
|
|
7129
|
+
kernel_t * const wdata = (kernel_t *) params->wdata + nk;
|
|
6887
7130
|
for (int i12 = 0; i12 < ne12; i12++) {
|
|
6888
7131
|
for (int i11 = 0; i11 < ne11; i11++) {
|
|
6889
7132
|
const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
|
|
6890
|
-
|
|
7133
|
+
kernel_t * dst_data = wdata + i11*ne10*ne12;
|
|
6891
7134
|
for (int i10 = 0; i10 < ne10; i10++) {
|
|
6892
|
-
|
|
7135
|
+
if constexpr (std::is_same_v<kernel_t, ggml_fp16_t>) {
|
|
7136
|
+
dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
|
|
7137
|
+
} else {
|
|
7138
|
+
dst_data[i10*ne12 + i12] = src[i10];
|
|
7139
|
+
}
|
|
6893
7140
|
}
|
|
6894
7141
|
}
|
|
6895
7142
|
}
|
|
@@ -6911,21 +7158,27 @@ void ggml_compute_forward_conv_transpose_2d(
|
|
|
6911
7158
|
const int ip0 = dp*ith;
|
|
6912
7159
|
const int ip1 = MIN(ip0 + dp, np);
|
|
6913
7160
|
|
|
6914
|
-
|
|
6915
|
-
|
|
7161
|
+
kernel_t * const wdata = (kernel_t *) params->wdata + 0;
|
|
7162
|
+
kernel_t * const wdata_src = wdata + nk;
|
|
6916
7163
|
|
|
6917
7164
|
for (int i2 = ip0; i2 < ip1; i2++) { // Cout
|
|
6918
7165
|
float * dst_data = (float *)((char *) dst->data + i2*nb2);
|
|
6919
|
-
|
|
7166
|
+
kernel_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
|
|
6920
7167
|
for (int i11 = 0; i11 < ne11; i11++) {
|
|
6921
7168
|
for (int i10 = 0; i10 < ne10; i10++) {
|
|
6922
7169
|
const int i1n = i11*ne10*ne12 + i10*ne12;
|
|
6923
7170
|
for (int i01 = 0; i01 < ne01; i01++) {
|
|
6924
7171
|
for (int i00 = 0; i00 < ne00; i00++) {
|
|
6925
7172
|
float v = 0;
|
|
6926
|
-
|
|
6927
|
-
|
|
6928
|
-
|
|
7173
|
+
if constexpr (std::is_same_v<kernel_t, ggml_fp16_t>) {
|
|
7174
|
+
ggml_vec_dot_f16(ne03, &v, 0,
|
|
7175
|
+
wdata_src + i1n, 0,
|
|
7176
|
+
wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
|
|
7177
|
+
} else {
|
|
7178
|
+
ggml_vec_dot_f32(ne03, &v, 0,
|
|
7179
|
+
wdata_src + i1n, 0,
|
|
7180
|
+
wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
|
|
7181
|
+
}
|
|
6929
7182
|
dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
|
|
6930
7183
|
}
|
|
6931
7184
|
}
|
|
@@ -6934,19 +7187,41 @@ void ggml_compute_forward_conv_transpose_2d(
|
|
|
6934
7187
|
}
|
|
6935
7188
|
}
|
|
6936
7189
|
|
|
6937
|
-
|
|
7190
|
+
void ggml_compute_forward_conv_transpose_2d(
|
|
7191
|
+
const ggml_compute_params * params,
|
|
7192
|
+
ggml_tensor * dst) {
|
|
6938
7193
|
|
|
6939
|
-
|
|
6940
|
-
|
|
6941
|
-
|
|
6942
|
-
|
|
6943
|
-
|
|
6944
|
-
|
|
6945
|
-
|
|
6946
|
-
|
|
6947
|
-
|
|
6948
|
-
|
|
6949
|
-
|
|
7194
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
7195
|
+
|
|
7196
|
+
switch (src0->type) {
|
|
7197
|
+
case GGML_TYPE_F16:
|
|
7198
|
+
{
|
|
7199
|
+
ggml_compute_forward_conv_transpose_2d_impl<ggml_fp16_t>(params, dst);
|
|
7200
|
+
} break;
|
|
7201
|
+
case GGML_TYPE_F32:
|
|
7202
|
+
{
|
|
7203
|
+
ggml_compute_forward_conv_transpose_2d_impl<float>(params, dst);
|
|
7204
|
+
} break;
|
|
7205
|
+
default:
|
|
7206
|
+
{
|
|
7207
|
+
GGML_ABORT("fatal error");
|
|
7208
|
+
}
|
|
7209
|
+
}
|
|
7210
|
+
}
|
|
7211
|
+
|
|
7212
|
+
// ggml_compute_forward_conv_2d_dw
|
|
7213
|
+
|
|
7214
|
+
struct ggml_conv_2d_dw_params {
|
|
7215
|
+
int64_t channels;
|
|
7216
|
+
int64_t batch;
|
|
7217
|
+
int64_t src_w;
|
|
7218
|
+
int64_t src_h;
|
|
7219
|
+
int64_t dst_w;
|
|
7220
|
+
int64_t dst_h;
|
|
7221
|
+
int64_t knl_w;
|
|
7222
|
+
int64_t knl_h;
|
|
7223
|
+
int stride_x;
|
|
7224
|
+
int stride_y;
|
|
6950
7225
|
int pad_x;
|
|
6951
7226
|
int pad_y;
|
|
6952
7227
|
int dilation_x;
|
|
@@ -7110,12 +7385,13 @@ void ggml_compute_forward_conv_2d_dw(
|
|
|
7110
7385
|
}
|
|
7111
7386
|
}
|
|
7112
7387
|
|
|
7113
|
-
//
|
|
7114
|
-
|
|
7115
|
-
static void ggml_compute_forward_pool_1d_sk_p0(
|
|
7388
|
+
// ggml_compute_forward_pool_1d_ksp
|
|
7389
|
+
static void ggml_compute_forward_pool_1d_ksp(
|
|
7116
7390
|
const ggml_compute_params * params,
|
|
7117
7391
|
const ggml_op_pool op,
|
|
7118
7392
|
const int k,
|
|
7393
|
+
const int s,
|
|
7394
|
+
const int p,
|
|
7119
7395
|
ggml_tensor * dst) {
|
|
7120
7396
|
|
|
7121
7397
|
const ggml_tensor * src = dst->src[0];
|
|
@@ -7126,39 +7402,56 @@ static void ggml_compute_forward_pool_1d_sk_p0(
|
|
|
7126
7402
|
return;
|
|
7127
7403
|
}
|
|
7128
7404
|
|
|
7129
|
-
const
|
|
7130
|
-
const
|
|
7131
|
-
float * drow = (float *)dst->data;
|
|
7405
|
+
const int64_t IW = src->ne[0];
|
|
7406
|
+
const int64_t OW = dst->ne[0];
|
|
7132
7407
|
|
|
7133
|
-
const int64_t
|
|
7408
|
+
const int64_t nr = ggml_nrows(src);
|
|
7134
7409
|
|
|
7135
|
-
|
|
7136
|
-
const
|
|
7137
|
-
|
|
7138
|
-
|
|
7410
|
+
for (int64_t ir = 0; ir < nr; ++ir) {
|
|
7411
|
+
const char * srow_bytes = (const char *) src->data + ir * src->nb[1];
|
|
7412
|
+
float * drow = (float *) (( char *) dst->data + ir * dst->nb[1]);
|
|
7413
|
+
|
|
7414
|
+
for (int64_t ow = 0; ow < OW; ++ow) {
|
|
7415
|
+
float res = 0;
|
|
7139
7416
|
switch (op) {
|
|
7140
|
-
case GGML_OP_POOL_AVG:
|
|
7141
|
-
case GGML_OP_POOL_MAX:
|
|
7417
|
+
case GGML_OP_POOL_AVG: res = 0.0f; break;
|
|
7418
|
+
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
|
|
7142
7419
|
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
|
7143
7420
|
}
|
|
7421
|
+
|
|
7422
|
+
int count = 0;
|
|
7423
|
+
const int base = (int) ow * s - p;
|
|
7424
|
+
|
|
7144
7425
|
for (int ki = 0; ki < k; ++ki) {
|
|
7145
|
-
const
|
|
7426
|
+
const int j = base + ki;
|
|
7427
|
+
if (j < 0 || j >= (int) IW) {
|
|
7428
|
+
continue;
|
|
7429
|
+
}
|
|
7430
|
+
|
|
7431
|
+
float v;
|
|
7432
|
+
if (src->type == GGML_TYPE_F32) {
|
|
7433
|
+
v = ((const float *) srow_bytes)[j];
|
|
7434
|
+
} else {
|
|
7435
|
+
v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);
|
|
7436
|
+
}
|
|
7437
|
+
|
|
7146
7438
|
switch (op) {
|
|
7147
|
-
case GGML_OP_POOL_AVG:
|
|
7148
|
-
case GGML_OP_POOL_MAX:
|
|
7149
|
-
case GGML_OP_POOL_COUNT:
|
|
7439
|
+
case GGML_OP_POOL_AVG: res += v; break;
|
|
7440
|
+
case GGML_OP_POOL_MAX: res = std::max(v, res); break;
|
|
7441
|
+
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
|
7150
7442
|
}
|
|
7151
|
-
|
|
7443
|
+
|
|
7444
|
+
++count;
|
|
7152
7445
|
}
|
|
7446
|
+
|
|
7153
7447
|
switch (op) {
|
|
7154
|
-
case GGML_OP_POOL_AVG:
|
|
7155
|
-
case GGML_OP_POOL_MAX:
|
|
7448
|
+
case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;
|
|
7449
|
+
case GGML_OP_POOL_MAX: break;
|
|
7156
7450
|
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
|
7157
7451
|
}
|
|
7158
|
-
}
|
|
7159
7452
|
|
|
7160
|
-
|
|
7161
|
-
|
|
7453
|
+
drow[ow] = res;
|
|
7454
|
+
}
|
|
7162
7455
|
}
|
|
7163
7456
|
}
|
|
7164
7457
|
|
|
@@ -7173,10 +7466,8 @@ void ggml_compute_forward_pool_1d(
|
|
|
7173
7466
|
const int k0 = opts[1];
|
|
7174
7467
|
const int s0 = opts[2];
|
|
7175
7468
|
const int p0 = opts[3];
|
|
7176
|
-
GGML_ASSERT(p0 == 0); // padding not supported
|
|
7177
|
-
GGML_ASSERT(k0 == s0); // only s = k supported
|
|
7178
7469
|
|
|
7179
|
-
|
|
7470
|
+
ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);
|
|
7180
7471
|
}
|
|
7181
7472
|
|
|
7182
7473
|
// ggml_compute_forward_pool_2d
|
|
@@ -7194,6 +7485,7 @@ void ggml_compute_forward_pool_2d(
|
|
|
7194
7485
|
}
|
|
7195
7486
|
|
|
7196
7487
|
const int32_t * opts = (const int32_t *)dst->op_params;
|
|
7488
|
+
|
|
7197
7489
|
ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
|
|
7198
7490
|
const int k0 = opts[1];
|
|
7199
7491
|
const int k1 = opts[2];
|
|
@@ -7217,11 +7509,13 @@ void ggml_compute_forward_pool_2d(
|
|
|
7217
7509
|
while (cdata < data_end) {
|
|
7218
7510
|
for (int oy = 0; oy < py; ++oy) {
|
|
7219
7511
|
float * const drow = dplane + oy * px;
|
|
7512
|
+
float * const out = drow;
|
|
7513
|
+
|
|
7220
7514
|
for (int ox = 0; ox < px; ++ox) {
|
|
7221
|
-
float
|
|
7515
|
+
float res = 0;
|
|
7222
7516
|
switch (op) {
|
|
7223
|
-
case GGML_OP_POOL_AVG:
|
|
7224
|
-
case GGML_OP_POOL_MAX:
|
|
7517
|
+
case GGML_OP_POOL_AVG: res = 0; break;
|
|
7518
|
+
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
|
|
7225
7519
|
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
|
7226
7520
|
}
|
|
7227
7521
|
|
|
@@ -7229,24 +7523,32 @@ void ggml_compute_forward_pool_2d(
|
|
|
7229
7523
|
const int iy = offset1 + oy * s1;
|
|
7230
7524
|
|
|
7231
7525
|
for (int ky = 0; ky < k1; ++ky) {
|
|
7232
|
-
if (iy + ky < 0 || iy + ky >= src->ne[1])
|
|
7526
|
+
if (iy + ky < 0 || iy + ky >= src->ne[1]) {
|
|
7527
|
+
continue;
|
|
7528
|
+
}
|
|
7529
|
+
|
|
7233
7530
|
const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
|
|
7234
7531
|
for (int kx = 0; kx < k0; ++kx) {
|
|
7235
7532
|
int j = ix + kx;
|
|
7236
|
-
if (j < 0 || j >= src->ne[0])
|
|
7533
|
+
if (j < 0 || j >= src->ne[0]) {
|
|
7534
|
+
continue;
|
|
7535
|
+
}
|
|
7536
|
+
|
|
7237
7537
|
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
|
|
7238
7538
|
switch (op) {
|
|
7239
|
-
case GGML_OP_POOL_AVG:
|
|
7240
|
-
case GGML_OP_POOL_MAX:
|
|
7539
|
+
case GGML_OP_POOL_AVG: res += srow_j; break;
|
|
7540
|
+
case GGML_OP_POOL_MAX: res = std::max(srow_j, res); break;
|
|
7241
7541
|
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
|
7242
7542
|
}
|
|
7243
7543
|
}
|
|
7244
7544
|
}
|
|
7245
7545
|
switch (op) {
|
|
7246
|
-
case GGML_OP_POOL_AVG:
|
|
7247
|
-
case GGML_OP_POOL_MAX:
|
|
7546
|
+
case GGML_OP_POOL_AVG: res /= ka; break;
|
|
7547
|
+
case GGML_OP_POOL_MAX: break;
|
|
7248
7548
|
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
|
7249
7549
|
}
|
|
7550
|
+
|
|
7551
|
+
out[ox] = res;
|
|
7250
7552
|
}
|
|
7251
7553
|
}
|
|
7252
7554
|
|
|
@@ -7603,8 +7905,7 @@ static void ggml_compute_forward_pad_f32(
|
|
|
7603
7905
|
|
|
7604
7906
|
const ggml_tensor * src0 = dst->src[0];
|
|
7605
7907
|
|
|
7606
|
-
|
|
7607
|
-
GGML_ASSERT( dst->nb[0] == sizeof(float));
|
|
7908
|
+
assert(dst->nb[0] == sizeof(float));
|
|
7608
7909
|
|
|
7609
7910
|
const int ith = params->ith;
|
|
7610
7911
|
const int nth = params->nth;
|
|
@@ -8016,12 +8317,14 @@ void ggml_compute_forward_top_k(
|
|
|
8016
8317
|
}
|
|
8017
8318
|
}
|
|
8018
8319
|
|
|
8019
|
-
// ggml_compute_forward_flash_attn_ext
|
|
8020
|
-
|
|
8021
8320
|
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|
8022
8321
|
const ggml_compute_params * params,
|
|
8023
8322
|
ggml_tensor * dst,
|
|
8024
|
-
int ir0, int ir1
|
|
8323
|
+
int ir0, int ir1,
|
|
8324
|
+
int64_t ic_start, int64_t ic_end,
|
|
8325
|
+
float * partials, int64_t partial_stride) {
|
|
8326
|
+
|
|
8327
|
+
const bool write_partials = (partials != nullptr);
|
|
8025
8328
|
const ggml_tensor * q = dst->src[0];
|
|
8026
8329
|
const ggml_tensor * k = dst->src[1];
|
|
8027
8330
|
const ggml_tensor * v = dst->src[2];
|
|
@@ -8098,7 +8401,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|
|
8098
8401
|
|
|
8099
8402
|
int ith = params->ith;
|
|
8100
8403
|
|
|
8101
|
-
// loop over n_batch and n_head
|
|
8102
8404
|
for (int ir = ir0; ir < ir1; ++ir) {
|
|
8103
8405
|
// q indices
|
|
8104
8406
|
const int iq3 = ir/(neq2*neq1);
|
|
@@ -8138,7 +8440,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|
|
8138
8440
|
// online softmax / attention
|
|
8139
8441
|
// loop over n_kv and n_head_kv
|
|
8140
8442
|
// ref: https://arxiv.org/pdf/2112.05682.pdf
|
|
8141
|
-
|
|
8443
|
+
|
|
8444
|
+
for (int64_t ic = ic_start; ic < ic_end; ++ic) {
|
|
8142
8445
|
const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
|
|
8143
8446
|
if (mv == -INFINITY) {
|
|
8144
8447
|
continue;
|
|
@@ -8211,8 +8514,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|
|
8211
8514
|
}
|
|
8212
8515
|
}
|
|
8213
8516
|
|
|
8214
|
-
// sinks
|
|
8215
|
-
if (sinks) {
|
|
8517
|
+
// sinks - apply only on the first kv-chunk
|
|
8518
|
+
if (sinks && ic_start == 0) {
|
|
8216
8519
|
const float s = ((float *)((char *) sinks->data))[h];
|
|
8217
8520
|
|
|
8218
8521
|
float ms = 1.0f;
|
|
@@ -8220,6 +8523,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|
|
8220
8523
|
|
|
8221
8524
|
if (s > M) {
|
|
8222
8525
|
ms = expf(M - s);
|
|
8526
|
+
M = s;
|
|
8223
8527
|
ggml_vec_scale_f32(DV, VKQ32, ms);
|
|
8224
8528
|
} else {
|
|
8225
8529
|
vs = expf(s - M);
|
|
@@ -8228,20 +8532,386 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|
|
8228
8532
|
S = S*ms + vs;
|
|
8229
8533
|
}
|
|
8230
8534
|
|
|
8231
|
-
|
|
8232
|
-
|
|
8233
|
-
|
|
8535
|
+
if (write_partials) {
|
|
8536
|
+
// Write M, S, VKQ to partials for later reduction
|
|
8537
|
+
// partials layout: [M, S, VKQ[DV]] per query head
|
|
8538
|
+
float * partial = partials + ir * partial_stride;
|
|
8539
|
+
partial[0] = M;
|
|
8540
|
+
partial[1] = S;
|
|
8541
|
+
memcpy(partial + 2, VKQ32, DV * sizeof(float));
|
|
8542
|
+
} else {
|
|
8543
|
+
// V /= S
|
|
8544
|
+
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
|
|
8545
|
+
ggml_vec_scale_f32(DV, VKQ32, S_inv);
|
|
8234
8546
|
|
|
8235
|
-
|
|
8236
|
-
|
|
8237
|
-
|
|
8238
|
-
|
|
8547
|
+
// dst indices
|
|
8548
|
+
const int i1 = iq1;
|
|
8549
|
+
const int i2 = iq2;
|
|
8550
|
+
const int i3 = iq3;
|
|
8551
|
+
|
|
8552
|
+
// permute(0, 2, 1, 3)
|
|
8553
|
+
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
|
|
8554
|
+
}
|
|
8555
|
+
}
|
|
8556
|
+
}
|
|
8557
|
+
|
|
8558
|
+
static void ggml_compute_forward_flash_attn_ext_tiled(
|
|
8559
|
+
const ggml_compute_params * params,
|
|
8560
|
+
ggml_tensor * dst,
|
|
8561
|
+
int ir0, int ir1) {
|
|
8562
|
+
const ggml_tensor * q = dst->src[0];
|
|
8563
|
+
const ggml_tensor * k = dst->src[1];
|
|
8564
|
+
const ggml_tensor * v = dst->src[2];
|
|
8565
|
+
const ggml_tensor * mask = dst->src[3];
|
|
8566
|
+
const ggml_tensor * sinks = dst->src[4];
|
|
8567
|
+
|
|
8568
|
+
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
|
8569
|
+
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
|
8570
|
+
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
|
8571
|
+
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
|
8572
|
+
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
|
8573
|
+
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
|
8574
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
8575
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
8576
|
+
|
|
8577
|
+
const int64_t DK = nek0;
|
|
8578
|
+
const int64_t DV = nev0;
|
|
8579
|
+
const int64_t N = neq1;
|
|
8580
|
+
|
|
8581
|
+
GGML_ASSERT(ne0 == DV);
|
|
8582
|
+
GGML_ASSERT(ne2 == N);
|
|
8583
|
+
|
|
8584
|
+
// input tensor rows must be contiguous
|
|
8585
|
+
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
|
|
8586
|
+
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
|
|
8587
|
+
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
|
|
8588
|
+
|
|
8589
|
+
GGML_ASSERT(neq0 == DK);
|
|
8590
|
+
GGML_ASSERT(nek0 == DK);
|
|
8591
|
+
GGML_ASSERT(nev0 == DV);
|
|
8592
|
+
|
|
8593
|
+
GGML_ASSERT(neq1 == N);
|
|
8594
|
+
|
|
8595
|
+
// dst cannot be transposed or permuted
|
|
8596
|
+
GGML_ASSERT(nb0 == sizeof(float));
|
|
8597
|
+
GGML_ASSERT(nb0 <= nb1);
|
|
8598
|
+
GGML_ASSERT(nb1 <= nb2);
|
|
8599
|
+
GGML_ASSERT(nb2 <= nb3);
|
|
8600
|
+
|
|
8601
|
+
GGML_ASSERT(k->type == v->type);
|
|
8602
|
+
const ggml_type kv_type = k->type;
|
|
8603
|
+
|
|
8604
|
+
|
|
8605
|
+
// broadcast factors
|
|
8606
|
+
const int64_t rk2 = neq2/nek2;
|
|
8607
|
+
const int64_t rk3 = neq3/nek3;
|
|
8608
|
+
|
|
8609
|
+
const int64_t rv2 = neq2/nev2;
|
|
8610
|
+
const int64_t rv3 = neq3/nev3;
|
|
8611
|
+
|
|
8612
|
+
float scale = 1.0f;
|
|
8613
|
+
float max_bias = 0.0f;
|
|
8614
|
+
float logit_softcap = 0.0f;
|
|
8615
|
+
|
|
8616
|
+
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
|
8617
|
+
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
|
8618
|
+
memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
|
|
8619
|
+
|
|
8620
|
+
if (logit_softcap != 0) {
|
|
8621
|
+
scale /= logit_softcap;
|
|
8622
|
+
}
|
|
8623
|
+
|
|
8624
|
+
const uint32_t n_head = neq2;
|
|
8625
|
+
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
|
8626
|
+
|
|
8627
|
+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
8628
|
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
8629
|
+
|
|
8630
|
+
int ith = params->ith;
|
|
8631
|
+
|
|
8632
|
+
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
|
|
8633
|
+
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
|
|
8634
|
+
|
|
8635
|
+
int ir = ir0;
|
|
8636
|
+
while (ir < ir1) {
|
|
8637
|
+
// q indices for the start of this tile
|
|
8638
|
+
const int iq3 = ir/(neq2*neq1);
|
|
8639
|
+
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
|
|
8640
|
+
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
|
|
8641
|
+
|
|
8642
|
+
// Number of valid rows in this tile:
|
|
8643
|
+
// - limited by tile size (Q_TILE_SZ)
|
|
8644
|
+
// - limited by chunk boundary (ir1 - ir)
|
|
8645
|
+
// - limited by head boundary (neq1 - iq1) to avoid crossing into next head
|
|
8646
|
+
const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
|
|
8647
|
+
GGML_ASSERT(tile_rows > 0);
|
|
8648
|
+
|
|
8649
|
+
const uint32_t h = iq2; // head index
|
|
8650
|
+
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
|
8651
|
+
|
|
8652
|
+
float S[Q_TILE_SZ];
|
|
8653
|
+
float M[Q_TILE_SZ];
|
|
8654
|
+
|
|
8655
|
+
for (int i = 0 ; i < Q_TILE_SZ; ++i) {
|
|
8656
|
+
S[i] = 0.;
|
|
8657
|
+
M[i] = -INFINITY;
|
|
8658
|
+
}
|
|
8659
|
+
|
|
8660
|
+
// Per-thread scratch layout:
|
|
8661
|
+
// Q_q: Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar)
|
|
8662
|
+
// KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
|
|
8663
|
+
// mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
|
|
8664
|
+
// VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
|
|
8665
|
+
// V32: KV_TILE_SZ * DV (F32 buffer for V tile)
|
|
8666
|
+
// K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
|
|
8667
|
+
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);
|
|
8668
|
+
|
|
8669
|
+
void * Q_q = base;
|
|
8670
|
+
float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
|
|
8671
|
+
float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
|
|
8672
|
+
float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
|
|
8673
|
+
float * V32 = VKQ32 + Q_TILE_SZ * DV;
|
|
8674
|
+
float * K_f32 = V32 + KV_TILE_SZ * DV;
|
|
8675
|
+
|
|
8676
|
+
memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
|
|
8677
|
+
memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
|
|
8678
|
+
|
|
8679
|
+
// k indices
|
|
8680
|
+
const int ik3 = iq3 / rk3;
|
|
8681
|
+
const int ik2 = iq2 / rk2;
|
|
8682
|
+
|
|
8683
|
+
// v indices
|
|
8684
|
+
const int iv3 = iq3 / rv3;
|
|
8685
|
+
const int iv2 = iq2 / rv2;
|
|
8686
|
+
|
|
8687
|
+
{
|
|
8688
|
+
float * Q_f32 = (float *)Q_q;
|
|
8689
|
+
for (int tq = 0; tq < tile_rows; tq++) {
|
|
8690
|
+
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
|
|
8691
|
+
memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
|
|
8692
|
+
}
|
|
8693
|
+
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
|
|
8694
|
+
memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
|
|
8695
|
+
}
|
|
8696
|
+
}
|
|
8697
|
+
|
|
8698
|
+
memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
|
|
8699
|
+
memset(V32, 0, KV_TILE_SZ * DV * sizeof(float));
|
|
8700
|
+
|
|
8701
|
+
for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
|
|
8702
|
+
const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);
|
|
8703
|
+
|
|
8704
|
+
// skip the tile entirely if all the masks are -inf
|
|
8705
|
+
if (mask) {
|
|
8706
|
+
bool can_skip = true;
|
|
8707
|
+
for (int tq = 0; tq < tile_rows; tq++) {
|
|
8708
|
+
const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
|
|
8709
|
+
for (int tk = 0; tk < kv_tile; tk++) {
|
|
8710
|
+
mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
|
|
8711
|
+
if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
|
|
8712
|
+
can_skip = false;
|
|
8713
|
+
}
|
|
8714
|
+
}
|
|
8715
|
+
// Pad remaining mask entries with -inf
|
|
8716
|
+
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
|
|
8717
|
+
mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
|
|
8718
|
+
}
|
|
8719
|
+
}
|
|
8720
|
+
|
|
8721
|
+
if (can_skip) {
|
|
8722
|
+
continue;
|
|
8723
|
+
}
|
|
8724
|
+
}
|
|
8725
|
+
|
|
8726
|
+
// Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
|
|
8727
|
+
// Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
|
|
8728
|
+
for (int tk = 0; tk < kv_tile; tk++) {
|
|
8729
|
+
const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
|
|
8730
|
+
if (kv_type == GGML_TYPE_F16) {
|
|
8731
|
+
const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
|
|
8732
|
+
for (int64_t dk = 0; dk < DK; dk++) {
|
|
8733
|
+
K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
|
|
8734
|
+
}
|
|
8735
|
+
} else {
|
|
8736
|
+
const float * k_f32_src = (const float *)k_data;
|
|
8737
|
+
for (int64_t dk = 0; dk < DK; dk++) {
|
|
8738
|
+
K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
|
|
8739
|
+
}
|
|
8740
|
+
}
|
|
8741
|
+
}
|
|
8742
|
+
memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
|
|
8743
|
+
simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
|
|
8744
|
+
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);
|
|
8745
|
+
|
|
8746
|
+
// Set padded KQ entries to -inf so softmax gives them zero weight
|
|
8747
|
+
if (kv_tile < KV_TILE_SZ) {
|
|
8748
|
+
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
|
8749
|
+
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
|
|
8750
|
+
KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
|
|
8751
|
+
}
|
|
8752
|
+
}
|
|
8753
|
+
}
|
|
8754
|
+
|
|
8755
|
+
if (logit_softcap != 0.0f) {
|
|
8756
|
+
ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
|
|
8757
|
+
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
|
|
8758
|
+
}
|
|
8759
|
+
|
|
8760
|
+
if (mask) {
|
|
8761
|
+
ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
|
|
8762
|
+
}
|
|
8763
|
+
|
|
8764
|
+
bool skip[Q_TILE_SZ] = {};
|
|
8765
|
+
|
|
8766
|
+
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
|
8767
|
+
float * kq_row = KQ + tq * KV_TILE_SZ;
|
|
8768
|
+
|
|
8769
|
+
float tile_max;
|
|
8770
|
+
ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);
|
|
8771
|
+
|
|
8772
|
+
if (tile_max == -INFINITY) {
|
|
8773
|
+
skip[tq] = true;
|
|
8774
|
+
continue;
|
|
8775
|
+
}
|
|
8776
|
+
|
|
8777
|
+
const float Mold = M[tq];
|
|
8778
|
+
const float Mnew = fmaxf(Mold, tile_max);
|
|
8779
|
+
|
|
8780
|
+
if (Mnew > Mold) {
|
|
8781
|
+
const float ms = expf(Mold - Mnew);
|
|
8782
|
+
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
|
|
8783
|
+
S[tq] *= ms;
|
|
8784
|
+
}
|
|
8785
|
+
M[tq] = Mnew;
|
|
8786
|
+
|
|
8787
|
+
|
|
8788
|
+
S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
|
|
8789
|
+
}
|
|
8790
|
+
|
|
8791
|
+
// V accumulation: VKQ32 += softmax(KQ) * V
|
|
8792
|
+
// Pack V tile to contiguous F32, zero-padded
|
|
8793
|
+
for (int tk = 0; tk < kv_tile; tk++) {
|
|
8794
|
+
const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
|
|
8795
|
+
if (kv_type == GGML_TYPE_F16) {
|
|
8796
|
+
ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
|
|
8797
|
+
} else {
|
|
8798
|
+
memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
|
|
8799
|
+
}
|
|
8800
|
+
}
|
|
8801
|
+
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
|
8802
|
+
if (skip[tq]) {
|
|
8803
|
+
memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));
|
|
8804
|
+
}
|
|
8805
|
+
}
|
|
8806
|
+
simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
|
|
8807
|
+
}
|
|
8808
|
+
|
|
8809
|
+
// sinks (apply only to valid rows in the tile)
|
|
8810
|
+
if (sinks) {
|
|
8811
|
+
const float s = ((float *)((char *) sinks->data))[h];
|
|
8812
|
+
|
|
8813
|
+
for (int tq = 0; tq < tile_rows; tq++) {
|
|
8814
|
+
float ms = 1.0f;
|
|
8815
|
+
float vs = 1.0f;
|
|
8816
|
+
|
|
8817
|
+
if (s > M[tq]) {
|
|
8818
|
+
ms = expf(M[tq] - s);
|
|
8819
|
+
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
|
|
8820
|
+
} else {
|
|
8821
|
+
vs = expf(s - M[tq]);
|
|
8822
|
+
}
|
|
8823
|
+
|
|
8824
|
+
S[tq] = S[tq] * ms + vs;
|
|
8825
|
+
}
|
|
8826
|
+
}
|
|
8827
|
+
|
|
8828
|
+
for (int tq = 0; tq < tile_rows; tq++) {
|
|
8829
|
+
// V /= S
|
|
8830
|
+
const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
|
|
8831
|
+
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);
|
|
8832
|
+
|
|
8833
|
+
// dst indices
|
|
8834
|
+
const int i1 = iq1 + tq;
|
|
8835
|
+
const int i2 = iq2;
|
|
8836
|
+
const int i3 = iq3;
|
|
8837
|
+
|
|
8838
|
+
// permute(0, 2, 1, 3)
|
|
8839
|
+
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
|
|
8840
|
+
}
|
|
8841
|
+
|
|
8842
|
+
ir += tile_rows;
|
|
8843
|
+
}
|
|
8844
|
+
}
|
|
8845
|
+
|
|
8846
|
+
// Reduction function: combines partial results across KV chunks
|
|
8847
|
+
// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV]
|
|
8848
|
+
static void ggml_flash_attn_ext_reduce_partials(
|
|
8849
|
+
const ggml_compute_params * params,
|
|
8850
|
+
ggml_tensor * dst,
|
|
8851
|
+
const int64_t n_chunks,
|
|
8852
|
+
const int64_t chunk_size) {
|
|
8853
|
+
|
|
8854
|
+
const ggml_tensor * q = dst->src[0];
|
|
8855
|
+
const ggml_tensor * k = dst->src[1];
|
|
8856
|
+
const ggml_tensor * v = dst->src[2];
|
|
8857
|
+
|
|
8858
|
+
const int64_t DK = k->ne[0];
|
|
8859
|
+
const int64_t DV = v->ne[0];
|
|
8860
|
+
const int64_t nek1 = k->ne[1];
|
|
8861
|
+
const int64_t n_q_heads = q->ne[2];
|
|
8862
|
+
|
|
8863
|
+
const int ith = params->ith;
|
|
8864
|
+
const int nth = params->nth;
|
|
8865
|
+
|
|
8866
|
+
const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32;
|
|
8867
|
+
float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread;
|
|
8868
|
+
|
|
8869
|
+
const int64_t partials_offset = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
|
|
8870
|
+
const int64_t partial_size = 2 + DV;
|
|
8871
|
+
const float * partials_base = (const float *) params->wdata + partials_offset;
|
|
8872
|
+
|
|
8873
|
+
// Output layout
|
|
8874
|
+
const int64_t ne1 = dst->ne[1];
|
|
8875
|
+
const int64_t ne2 = dst->ne[2];
|
|
8876
|
+
const size_t nb1 = dst->nb[1];
|
|
8239
8877
|
|
|
8240
|
-
|
|
8241
|
-
|
|
8878
|
+
// Each thread reduces a subset of query heads
|
|
8879
|
+
for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
|
|
8880
|
+
float M_final = -INFINITY;
|
|
8881
|
+
float S_final = 0.0f;
|
|
8882
|
+
float * VKQ_final = thread_wdata;
|
|
8883
|
+
memset(VKQ_final, 0, DV * sizeof(float));
|
|
8242
8884
|
|
|
8243
|
-
//
|
|
8244
|
-
|
|
8885
|
+
// Combine partials from all chunks
|
|
8886
|
+
for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
|
|
8887
|
+
const int64_t ic_start = chunk_idx * chunk_size;
|
|
8888
|
+
if (ic_start >= nek1) continue;
|
|
8889
|
+
|
|
8890
|
+
const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
|
|
8891
|
+
const float M_chunk = partial[0];
|
|
8892
|
+
const float S_chunk = partial[1];
|
|
8893
|
+
const float * VKQ_chunk = partial + 2;
|
|
8894
|
+
|
|
8895
|
+
if (S_chunk == 0.0f) continue;
|
|
8896
|
+
|
|
8897
|
+
const float M_new = fmaxf(M_final, M_chunk);
|
|
8898
|
+
const float scale_old = expf(M_final - M_new);
|
|
8899
|
+
const float scale_new = expf(M_chunk - M_new);
|
|
8900
|
+
|
|
8901
|
+
for (int64_t d = 0; d < DV; ++d) {
|
|
8902
|
+
VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new;
|
|
8903
|
+
}
|
|
8904
|
+
S_final = S_final * scale_old + S_chunk * scale_new;
|
|
8905
|
+
M_final = M_new;
|
|
8906
|
+
}
|
|
8907
|
+
|
|
8908
|
+
// Normalize and write to output
|
|
8909
|
+
if (S_final != 0.0f) {
|
|
8910
|
+
const float S_inv = 1.0f / S_final;
|
|
8911
|
+
ggml_vec_scale_f32(DV, VKQ_final, S_inv);
|
|
8912
|
+
}
|
|
8913
|
+
// iq1=0, iq3=0 for decode
|
|
8914
|
+
memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1);
|
|
8245
8915
|
}
|
|
8246
8916
|
}
|
|
8247
8917
|
|
|
@@ -8266,6 +8936,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
8266
8936
|
const int64_t DV = nev0;
|
|
8267
8937
|
const int64_t N = neq1;
|
|
8268
8938
|
|
|
8939
|
+
|
|
8269
8940
|
GGML_ASSERT(ne0 == DV);
|
|
8270
8941
|
GGML_ASSERT(ne2 == N);
|
|
8271
8942
|
|
|
@@ -8286,47 +8957,97 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
8286
8957
|
GGML_ASSERT(nb1 <= nb2);
|
|
8287
8958
|
GGML_ASSERT(nb2 <= nb3);
|
|
8288
8959
|
|
|
8289
|
-
// parallelize by q rows using ggml_vec_dot_f32
|
|
8290
|
-
|
|
8291
|
-
// total rows in q
|
|
8292
|
-
const int64_t nr = neq1*neq2*neq3;
|
|
8293
|
-
|
|
8294
|
-
// rows per thread
|
|
8295
8960
|
const int ith = params->ith;
|
|
8296
8961
|
const int nth = params->nth;
|
|
8297
8962
|
|
|
8298
|
-
//
|
|
8299
|
-
const bool
|
|
8963
|
+
// When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)
|
|
8964
|
+
const bool use_ref = params->use_ref;
|
|
8300
8965
|
|
|
8301
|
-
|
|
8302
|
-
|
|
8303
|
-
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
|
8304
|
-
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
|
8966
|
+
const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
|
|
8967
|
+
const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;
|
|
8305
8968
|
|
|
8306
|
-
if (
|
|
8307
|
-
|
|
8308
|
-
}
|
|
8969
|
+
if (use_split_kv_path) {
|
|
8970
|
+
const int64_t chunk_size = (nek1 + nth - 1) / nth;
|
|
8309
8971
|
|
|
8310
|
-
|
|
8311
|
-
|
|
8312
|
-
|
|
8313
|
-
}
|
|
8972
|
+
// Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
|
|
8973
|
+
const int64_t partial_size = 2 + DV;
|
|
8974
|
+
float * partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
|
|
8314
8975
|
|
|
8315
|
-
|
|
8976
|
+
const int64_t ic_start = ith * chunk_size;
|
|
8977
|
+
const int64_t ic_end = std::min(ic_start + chunk_size, nek1);
|
|
8316
8978
|
|
|
8317
|
-
|
|
8318
|
-
|
|
8979
|
+
const int64_t partial_stride = nth * partial_size;
|
|
8980
|
+
float * chunk_partials = partials_base + ith * partial_size;
|
|
8319
8981
|
|
|
8320
|
-
|
|
8321
|
-
|
|
8982
|
+
if (ic_start < nek1) {
|
|
8983
|
+
for (int64_t q_head = 0; q_head < neq2; q_head++) {
|
|
8984
|
+
ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|
8985
|
+
params, dst, q_head, q_head + 1, ic_start, ic_end,
|
|
8986
|
+
chunk_partials, partial_stride);
|
|
8987
|
+
}
|
|
8988
|
+
} else {
|
|
8989
|
+
for (int64_t q_head = 0; q_head < neq2; q_head++) {
|
|
8990
|
+
float * q_partials = chunk_partials + q_head * partial_stride;
|
|
8991
|
+
q_partials[0] = -INFINITY; // M
|
|
8992
|
+
q_partials[1] = 0.0f; // S
|
|
8993
|
+
}
|
|
8994
|
+
}
|
|
8322
8995
|
|
|
8323
|
-
|
|
8324
|
-
|
|
8325
|
-
|
|
8996
|
+
ggml_barrier(params->threadpool);
|
|
8997
|
+
ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size);
|
|
8998
|
+
} else {
|
|
8326
8999
|
|
|
8327
|
-
|
|
9000
|
+
// total rows in q
|
|
9001
|
+
const int64_t nr = neq1*neq2*neq3;
|
|
8328
9002
|
|
|
8329
|
-
|
|
9003
|
+
// disable for NUMA
|
|
9004
|
+
const bool disable_chunking = ggml_is_numa();
|
|
9005
|
+
|
|
9006
|
+
// 4x chunks per thread
|
|
9007
|
+
int nth_scaled = nth * 4;
|
|
9008
|
+
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
|
9009
|
+
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
|
9010
|
+
|
|
9011
|
+
if (nth == 1 || nchunk < nth || disable_chunking) {
|
|
9012
|
+
nchunk = nth;
|
|
9013
|
+
}
|
|
9014
|
+
|
|
9015
|
+
if (ith == 0) {
|
|
9016
|
+
ggml_threadpool_chunk_set(params->threadpool, nth);
|
|
9017
|
+
}
|
|
9018
|
+
|
|
9019
|
+
ggml_barrier(params->threadpool);
|
|
9020
|
+
|
|
9021
|
+
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
|
9022
|
+
|
|
9023
|
+
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
|
|
9024
|
+
bool use_tiled = !use_ref &&
|
|
9025
|
+
(q->type == GGML_TYPE_F32 &&
|
|
9026
|
+
kv_is_f32_or_f16 &&
|
|
9027
|
+
k->type == v->type &&
|
|
9028
|
+
neq1 >= Q_TILE_SZ);
|
|
9029
|
+
#ifdef GGML_SIMD
|
|
9030
|
+
#if defined(__ARM_FEATURE_SVE)
|
|
9031
|
+
const int64_t f32_epr = svcntw();
|
|
9032
|
+
#else
|
|
9033
|
+
const int64_t f32_epr = GGML_F32_EPR;
|
|
9034
|
+
#endif
|
|
9035
|
+
use_tiled &= (DV % f32_epr == 0);
|
|
9036
|
+
#endif
|
|
9037
|
+
int current_chunk = ith;
|
|
9038
|
+
|
|
9039
|
+
while (current_chunk < nchunk) {
|
|
9040
|
+
const int64_t ir0 = dr * current_chunk;
|
|
9041
|
+
const int64_t ir1 = MIN(ir0 + dr, nr);
|
|
9042
|
+
|
|
9043
|
+
if (use_tiled) {
|
|
9044
|
+
ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
|
|
9045
|
+
} else {
|
|
9046
|
+
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0);
|
|
9047
|
+
}
|
|
9048
|
+
|
|
9049
|
+
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
|
9050
|
+
}
|
|
8330
9051
|
}
|
|
8331
9052
|
}
|
|
8332
9053
|
|
|
@@ -9107,7 +9828,7 @@ void ggml_compute_forward_win_unpart(
|
|
|
9107
9828
|
}
|
|
9108
9829
|
}
|
|
9109
9830
|
|
|
9110
|
-
//
|
|
9831
|
+
//ggml_compute_forward_unary
|
|
9111
9832
|
|
|
9112
9833
|
void ggml_compute_forward_unary(
|
|
9113
9834
|
const ggml_compute_params * params,
|
|
@@ -9396,13 +10117,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
|
|
9396
10117
|
const int ith = params->ith;
|
|
9397
10118
|
const int nth = params->nth;
|
|
9398
10119
|
|
|
9399
|
-
|
|
9400
|
-
|
|
9401
|
-
|
|
9402
|
-
|
|
9403
|
-
const int h_start = (HEADS * ith) / nth;
|
|
9404
|
-
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
|
9405
|
-
(HEADS * (ith + 1)) / nth : HEADS;
|
|
10120
|
+
const int h_start = (HEADS * (ith )) / nth;
|
|
10121
|
+
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
|
10122
|
+
(HEADS * (ith + 1)) / nth : HEADS;
|
|
9406
10123
|
|
|
9407
10124
|
float * k = (float *) dst->src[0]->data;
|
|
9408
10125
|
float * v = (float *) dst->src[1]->data;
|
|
@@ -9613,13 +10330,9 @@ static void ggml_compute_forward_gla_f32(
|
|
|
9613
10330
|
const int ith = params->ith;
|
|
9614
10331
|
const int nth = params->nth;
|
|
9615
10332
|
|
|
9616
|
-
|
|
9617
|
-
|
|
9618
|
-
|
|
9619
|
-
|
|
9620
|
-
const int h_start = (HEADS * ith) / nth;
|
|
9621
|
-
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
|
9622
|
-
(HEADS * (ith + 1)) / nth : HEADS;
|
|
10333
|
+
const int h_start = (HEADS * (ith )) / nth;
|
|
10334
|
+
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
|
10335
|
+
(HEADS * (ith + 1)) / nth : HEADS;
|
|
9623
10336
|
|
|
9624
10337
|
float * k = (float *) dst->src[0]->data;
|
|
9625
10338
|
float * v = (float *) dst->src[1]->data;
|
|
@@ -9870,6 +10583,219 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s
|
|
|
9870
10583
|
}
|
|
9871
10584
|
}
|
|
9872
10585
|
|
|
10586
|
+
// ggml_compute_forward_gated_delta_net
|
|
10587
|
+
static void ggml_compute_forward_gated_delta_net_one_chunk(
|
|
10588
|
+
const ggml_compute_params * params,
|
|
10589
|
+
ggml_tensor * dst,
|
|
10590
|
+
int64_t ir0,
|
|
10591
|
+
int64_t ir1) {
|
|
10592
|
+
|
|
10593
|
+
ggml_tensor * src_q = dst->src[0];
|
|
10594
|
+
ggml_tensor * src_k = dst->src[1];
|
|
10595
|
+
ggml_tensor * src_v = dst->src[2];
|
|
10596
|
+
ggml_tensor * src_g = dst->src[3];
|
|
10597
|
+
ggml_tensor * src_beta = dst->src[4];
|
|
10598
|
+
ggml_tensor * src_state = dst->src[5];
|
|
10599
|
+
|
|
10600
|
+
const int64_t S_v = src_v->ne[0];
|
|
10601
|
+
const int64_t H = src_v->ne[1];
|
|
10602
|
+
const int64_t n_tokens = src_v->ne[2];
|
|
10603
|
+
const int64_t n_seqs = src_v->ne[3];
|
|
10604
|
+
|
|
10605
|
+
GGML_ASSERT(ggml_is_contiguous_rows(src_q));
|
|
10606
|
+
GGML_ASSERT(ggml_is_contiguous_rows(src_k));
|
|
10607
|
+
GGML_ASSERT(ggml_is_contiguous_rows(src_v));
|
|
10608
|
+
GGML_ASSERT(ggml_is_contiguous(src_g));
|
|
10609
|
+
GGML_ASSERT(ggml_is_contiguous(src_beta));
|
|
10610
|
+
GGML_ASSERT(ggml_is_contiguous(src_state));
|
|
10611
|
+
|
|
10612
|
+
GGML_ASSERT(src_g->ne[0] == 1 || src_g->ne[0] == S_v);
|
|
10613
|
+
GGML_ASSERT(src_beta->ne[0] == 1);
|
|
10614
|
+
|
|
10615
|
+
GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
|
|
10616
|
+
GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
|
|
10617
|
+
GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
|
|
10618
|
+
GGML_TENSOR_LOCALS(size_t, nbk, src_k, nb);
|
|
10619
|
+
GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
|
|
10620
|
+
GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
|
|
10621
|
+
GGML_TENSOR_LOCALS(int64_t, neg, src_g, ne);
|
|
10622
|
+
GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb);
|
|
10623
|
+
GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
|
|
10624
|
+
|
|
10625
|
+
const bool kda = (neg0 == S_v);
|
|
10626
|
+
|
|
10627
|
+
// K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs].
|
|
10628
|
+
const int64_t K = ggml_get_op_params_i32(dst, 0);
|
|
10629
|
+
GGML_ASSERT(K >= 1);
|
|
10630
|
+
// per-seq stride in floats (seq s starts at state + s * seq_stride)
|
|
10631
|
+
const int64_t state_seq_stride = src_state->nb[3] / sizeof(float);
|
|
10632
|
+
|
|
10633
|
+
const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0);
|
|
10634
|
+
const int ith = params->ith;
|
|
10635
|
+
|
|
10636
|
+
float * delta = (float *)params->wdata + ith * per_thread + CACHE_LINE_SIZE_F32;
|
|
10637
|
+
float * state_work = K > 1 ? (delta + S_v) : nullptr;
|
|
10638
|
+
|
|
10639
|
+
// output layout: [attn_scores | new_states]
|
|
10640
|
+
// attn_scores: S_v * H * n_tokens * n_seqs floats
|
|
10641
|
+
// new_states: S_v * S_v * H * n_seqs * K floats (K snapshot slots; last min(n_tokens, K))
|
|
10642
|
+
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
|
|
10643
|
+
const int64_t state_size_per_snap = S_v * S_v * H * n_seqs;
|
|
10644
|
+
float * attn_out_base = (float *)dst->data;
|
|
10645
|
+
float * state_out_base = (float *)dst->data + attn_score_elems;
|
|
10646
|
+
|
|
10647
|
+
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
|
|
10648
|
+
// When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned.
|
|
10649
|
+
|
|
10650
|
+
const float * state_in_base = (const float *)src_state->data;
|
|
10651
|
+
|
|
10652
|
+
//const int64_t rq1 = nev1 / neq1;
|
|
10653
|
+
//const int64_t rk1 = nev1 / nek1;
|
|
10654
|
+
const int64_t rq3 = nev3 / neq3;
|
|
10655
|
+
const int64_t rk3 = nev3 / nek3;
|
|
10656
|
+
|
|
10657
|
+
const float scale = 1.0f / sqrtf((float) S_v);
|
|
10658
|
+
|
|
10659
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
|
10660
|
+
const int64_t iv1 = ir % H; // head_index
|
|
10661
|
+
const int64_t iv3 = ir / H; // sequence
|
|
10662
|
+
|
|
10663
|
+
const int64_t iq1 = iv1 % neq1;
|
|
10664
|
+
const int64_t ik1 = iv1 % nek1;
|
|
10665
|
+
|
|
10666
|
+
const int64_t iq3 = iv3 / rq3;
|
|
10667
|
+
const int64_t ik3 = iv3 / rk3;
|
|
10668
|
+
|
|
10669
|
+
// For K=1, write directly to the single output slot to avoid an extra memcpy at the end.
|
|
10670
|
+
// For K>1, work in scratch and copy out per-token when the slot is in range.
|
|
10671
|
+
float * s_out = (K > 1)
|
|
10672
|
+
? state_work
|
|
10673
|
+
: state_out_base + (iv3 * H + iv1) * S_v * S_v;
|
|
10674
|
+
|
|
10675
|
+
// copy input state into the working buffer and operate in-place
|
|
10676
|
+
// state layout [S_v, S_v, H, n_seqs]: seq iv3 starts at iv3 * state_seq_stride.
|
|
10677
|
+
const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v;
|
|
10678
|
+
memcpy(s_out, s_in, S_v * S_v * sizeof(float));
|
|
10679
|
+
|
|
10680
|
+
// attn output pointer for first token of this (head, seq)
|
|
10681
|
+
float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v;
|
|
10682
|
+
|
|
10683
|
+
for (int64_t t = 0; t < n_tokens; t++) {
|
|
10684
|
+
const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1);
|
|
10685
|
+
const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1);
|
|
10686
|
+
const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);
|
|
10687
|
+
|
|
10688
|
+
const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
|
|
10689
|
+
const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
|
|
10690
|
+
|
|
10691
|
+
// state is stored transposed: s_out[j*S_v + i] = S[i][j]
|
|
10692
|
+
// so row j of s_out = column j of S (contiguous access)
|
|
10693
|
+
|
|
10694
|
+
if (kda) {
|
|
10695
|
+
// precompute exp(g) into delta scratch (reused below)
|
|
10696
|
+
for (int64_t i = 0; i < S_v; ++i) {
|
|
10697
|
+
delta[i] = expf(g_d[i]);
|
|
10698
|
+
}
|
|
10699
|
+
// S[i][:] *= exp(g[i]) => for each row j of M: M[j][i] *= exp(g[i])
|
|
10700
|
+
for (int64_t j = 0; j < S_v; ++j) {
|
|
10701
|
+
ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta);
|
|
10702
|
+
}
|
|
10703
|
+
} else {
|
|
10704
|
+
ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0]));
|
|
10705
|
+
}
|
|
10706
|
+
|
|
10707
|
+
// delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k)
|
|
10708
|
+
for (int64_t j = 0; j < S_v; ++j) {
|
|
10709
|
+
float sum = 0.0f;
|
|
10710
|
+
ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1);
|
|
10711
|
+
delta[j] = (v_d[j] - sum) * beta_val;
|
|
10712
|
+
}
|
|
10713
|
+
|
|
10714
|
+
// outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i]
|
|
10715
|
+
for (int64_t j = 0; j < S_v; ++j) {
|
|
10716
|
+
ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]);
|
|
10717
|
+
}
|
|
10718
|
+
|
|
10719
|
+
// attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q)
|
|
10720
|
+
for (int64_t j = 0; j < S_v; ++j) {
|
|
10721
|
+
float sum = 0.0f;
|
|
10722
|
+
ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1);
|
|
10723
|
+
attn_data[j] = sum * scale;
|
|
10724
|
+
}
|
|
10725
|
+
|
|
10726
|
+
attn_data += S_v * H; // advance to next token
|
|
10727
|
+
|
|
10728
|
+
if (K > 1) {
|
|
10729
|
+
const int64_t target_slot = n_tokens - 1 - t;
|
|
10730
|
+
if (target_slot >= 0 && target_slot < K) {
|
|
10731
|
+
float * curr_state_o = state_out_base + target_slot * state_size_per_snap +
|
|
10732
|
+
(iv3 * H + iv1) * S_v * S_v;
|
|
10733
|
+
memcpy(curr_state_o, s_out, S_v * S_v * sizeof(float));
|
|
10734
|
+
}
|
|
10735
|
+
}
|
|
10736
|
+
}
|
|
10737
|
+
}
|
|
10738
|
+
}
|
|
10739
|
+
|
|
10740
|
+
|
|
10741
|
+
static void ggml_compute_forward_gated_delta_net_f32(
|
|
10742
|
+
const ggml_compute_params * params,
|
|
10743
|
+
ggml_tensor * dst) {
|
|
10744
|
+
|
|
10745
|
+
ggml_tensor * V = dst->src[2];
|
|
10746
|
+
int64_t nr = V->ne[1] * V->ne[3];
|
|
10747
|
+
|
|
10748
|
+
// disable for NUMA
|
|
10749
|
+
const bool disable_chunking = ggml_is_numa();
|
|
10750
|
+
|
|
10751
|
+
int nth = params->nth;
|
|
10752
|
+
int ith = params->ith;
|
|
10753
|
+
|
|
10754
|
+
// 4x chunks per thread
|
|
10755
|
+
int nth_scaled = nth * 4;
|
|
10756
|
+
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
|
10757
|
+
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
|
10758
|
+
|
|
10759
|
+
if (nth == 1 || nchunk < nth || disable_chunking) {
|
|
10760
|
+
nchunk = nth;
|
|
10761
|
+
}
|
|
10762
|
+
|
|
10763
|
+
if (ith == 0) {
|
|
10764
|
+
ggml_threadpool_chunk_set(params->threadpool, nth);
|
|
10765
|
+
}
|
|
10766
|
+
|
|
10767
|
+
ggml_barrier(params->threadpool);
|
|
10768
|
+
|
|
10769
|
+
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
|
10770
|
+
|
|
10771
|
+
int current_chunk = ith;
|
|
10772
|
+
|
|
10773
|
+
while (current_chunk < nchunk) {
|
|
10774
|
+
const int64_t ir0 = dr * current_chunk;
|
|
10775
|
+
const int64_t ir1 = MIN(ir0 + dr, nr);
|
|
10776
|
+
|
|
10777
|
+
ggml_compute_forward_gated_delta_net_one_chunk(params, dst, ir0, ir1);
|
|
10778
|
+
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
|
10779
|
+
}
|
|
10780
|
+
}
|
|
10781
|
+
|
|
10782
|
+
void ggml_compute_forward_gated_delta_net(
|
|
10783
|
+
const ggml_compute_params * params,
|
|
10784
|
+
ggml_tensor * dst) {
|
|
10785
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
10786
|
+
|
|
10787
|
+
switch (src0->type) {
|
|
10788
|
+
case GGML_TYPE_F32:
|
|
10789
|
+
{
|
|
10790
|
+
ggml_compute_forward_gated_delta_net_f32(params, dst);
|
|
10791
|
+
} break;
|
|
10792
|
+
default:
|
|
10793
|
+
{
|
|
10794
|
+
GGML_ABORT("fatal error");
|
|
10795
|
+
}
|
|
10796
|
+
}
|
|
10797
|
+
}
|
|
10798
|
+
|
|
9873
10799
|
// ggml_compute_forward_rwkv_wkv7
|
|
9874
10800
|
|
|
9875
10801
|
static void ggml_compute_forward_rwkv_wkv7_f32(
|
|
@@ -9887,13 +10813,9 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
|
|
|
9887
10813
|
const int ith = params->ith;
|
|
9888
10814
|
const int nth = params->nth;
|
|
9889
10815
|
|
|
9890
|
-
|
|
9891
|
-
|
|
9892
|
-
|
|
9893
|
-
|
|
9894
|
-
const int h_start = (HEADS * ith) / nth;
|
|
9895
|
-
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
|
9896
|
-
(HEADS * (ith + 1)) / nth : HEADS;
|
|
10816
|
+
const int h_start = (HEADS * (ith )) / nth;
|
|
10817
|
+
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
|
10818
|
+
(HEADS * (ith + 1)) / nth : HEADS;
|
|
9897
10819
|
|
|
9898
10820
|
float * r = (float *) dst->src[0]->data;
|
|
9899
10821
|
float * w = (float *) dst->src[1]->data;
|
|
@@ -10195,7 +11117,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
|
|
|
10195
11117
|
assert(!isnan(s0[i]));
|
|
10196
11118
|
assert(!isnan(s1[i]));
|
|
10197
11119
|
}
|
|
10198
|
-
#endif
|
|
11120
|
+
#endif // NDEBUG
|
|
10199
11121
|
|
|
10200
11122
|
float max = -INFINITY;
|
|
10201
11123
|
ggml_vec_max_f32(nc, &max, s0);
|
|
@@ -10214,7 +11136,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
|
|
|
10214
11136
|
assert(!isnan(st[i]));
|
|
10215
11137
|
assert(!isinf(st[i]));
|
|
10216
11138
|
}
|
|
10217
|
-
#endif
|
|
11139
|
+
#endif // NDEBUG
|
|
10218
11140
|
}
|
|
10219
11141
|
sums[ith] = sum_thread;
|
|
10220
11142
|
ggml_barrier(params->threadpool);
|
|
@@ -10287,7 +11209,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
|
|
|
10287
11209
|
assert(!isnan(s0[i]));
|
|
10288
11210
|
assert(!isnan(s1[i]));
|
|
10289
11211
|
}
|
|
10290
|
-
#endif
|
|
11212
|
+
#endif // NDEBUG
|
|
10291
11213
|
|
|
10292
11214
|
// soft_max
|
|
10293
11215
|
float max = -INFINITY;
|
|
@@ -10305,7 +11227,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
|
|
|
10305
11227
|
assert(!isnan(ds0[i]));
|
|
10306
11228
|
assert(!isinf(ds0[i]));
|
|
10307
11229
|
}
|
|
10308
|
-
#endif
|
|
11230
|
+
#endif // NDEBUG
|
|
10309
11231
|
}
|
|
10310
11232
|
}
|
|
10311
11233
|
|
|
@@ -10471,3 +11393,95 @@ void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_
|
|
|
10471
11393
|
}
|
|
10472
11394
|
}
|
|
10473
11395
|
}
|
|
11396
|
+
|
|
11397
|
+
static void ggml_compute_forward_fwht_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
11398
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
11399
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
11400
|
+
|
|
11401
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
11402
|
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
11403
|
+
|
|
11404
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
|
11405
|
+
|
|
11406
|
+
const int ith = params->ith;
|
|
11407
|
+
const int nth = params->nth;
|
|
11408
|
+
|
|
11409
|
+
const int64_t n = ne10;
|
|
11410
|
+
GGML_ASSERT((n & (n - 1)) == 0); // must be power of 2
|
|
11411
|
+
|
|
11412
|
+
const int64_t nr = ne11 * ne12 * ne13;
|
|
11413
|
+
const int64_t rows_per_thread = (nr + nth - 1) / nth;
|
|
11414
|
+
const int64_t start_row = ith * rows_per_thread;
|
|
11415
|
+
const int64_t end_row = MIN(start_row + rows_per_thread, nr);
|
|
11416
|
+
|
|
11417
|
+
const float scale = 1.0f / sqrtf((float)n);
|
|
11418
|
+
|
|
11419
|
+
#if defined(GGML_SIMD)
|
|
11420
|
+
const GGML_F32_VEC v_minus_one = GGML_F32_VEC_SET1(-1.0f);
|
|
11421
|
+
#endif
|
|
11422
|
+
|
|
11423
|
+
for (int64_t r = start_row; r < end_row; r++) {
|
|
11424
|
+
const int64_t i13 = r / (ne11 * ne12);
|
|
11425
|
+
const int64_t i12 = (r - i13 * ne11 * ne12) / ne11;
|
|
11426
|
+
const int64_t i11 = r - i13 * ne11 * ne12 - i12 * ne11;
|
|
11427
|
+
|
|
11428
|
+
const float * src_row = (const float *) ((const char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13);
|
|
11429
|
+
float * dst_row = (float *) ((char *) dst->data + i11 * nb1 + i12 * nb2 + i13 * nb3);
|
|
11430
|
+
|
|
11431
|
+
for (int64_t j = 0; j < n; j++) {
|
|
11432
|
+
dst_row[j] = src_row[j] * scale;
|
|
11433
|
+
}
|
|
11434
|
+
|
|
11435
|
+
// Scalar passes
|
|
11436
|
+
#if defined(GGML_SIMD)
|
|
11437
|
+
#if defined(__ARM_FEATURE_SVE)
|
|
11438
|
+
const int step = svcntw();
|
|
11439
|
+
#else
|
|
11440
|
+
const int step = GGML_F32_EPR;
|
|
11441
|
+
#endif
|
|
11442
|
+
#else
|
|
11443
|
+
const int step = n;
|
|
11444
|
+
#endif
|
|
11445
|
+
for (int64_t len = 1; len < step && len < n; len <<= 1) {
|
|
11446
|
+
for (int64_t i = 0; i < n; i += 2 * len) {
|
|
11447
|
+
for (int64_t j = 0; j < len; j++) {
|
|
11448
|
+
float u = dst_row[i + j];
|
|
11449
|
+
float v = dst_row[i + len + j];
|
|
11450
|
+
dst_row[i + j] = u + v;
|
|
11451
|
+
dst_row[i + len + j] = u - v;
|
|
11452
|
+
}
|
|
11453
|
+
}
|
|
11454
|
+
}
|
|
11455
|
+
|
|
11456
|
+
// SIMD passes using GGML_F32_VEC_* macros for multi-architecture support
|
|
11457
|
+
#if defined(GGML_SIMD)
|
|
11458
|
+
for (int64_t len = step; len < n; len <<= 1) {
|
|
11459
|
+
for (int64_t i = 0; i < n; i += 2 * len) {
|
|
11460
|
+
for (int64_t j = 0; j < len; j += step) {
|
|
11461
|
+
GGML_F32_VEC u = GGML_F32_VEC_LOAD(dst_row + i + j);
|
|
11462
|
+
GGML_F32_VEC v = GGML_F32_VEC_LOAD(dst_row + i + len + j);
|
|
11463
|
+
|
|
11464
|
+
GGML_F32_VEC_STORE(dst_row + i + j, GGML_F32_VEC_ADD(u, v));
|
|
11465
|
+
GGML_F32_VEC_STORE(dst_row + i + len + j, GGML_F32_VEC_FMA(u, v, v_minus_one));
|
|
11466
|
+
}
|
|
11467
|
+
}
|
|
11468
|
+
}
|
|
11469
|
+
#endif
|
|
11470
|
+
}
|
|
11471
|
+
}
|
|
11472
|
+
|
|
11473
|
+
void ggml_compute_forward_fwht(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
11474
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
11475
|
+
|
|
11476
|
+
switch (src1->type) {
|
|
11477
|
+
case GGML_TYPE_F32:
|
|
11478
|
+
{
|
|
11479
|
+
ggml_compute_forward_fwht_f32(params, dst);
|
|
11480
|
+
}
|
|
11481
|
+
break;
|
|
11482
|
+
default:
|
|
11483
|
+
{
|
|
11484
|
+
GGML_ABORT("fatal error - fwht is F32 only");
|
|
11485
|
+
}
|
|
11486
|
+
}
|
|
11487
|
+
}
|