whispercpp 1.3.6 → 1.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/README.md +38 -5
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -8
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +36 -42
- data/ext/ruby_whisper.h +135 -0
- data/ext/ruby_whisper_context.c +107 -28
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -65
- data/ext/ruby_whisper_segment.c +6 -6
- data/ext/ruby_whisper_transcribe.cpp +42 -15
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +1 -1
- data/ext/sources/examples/cli/cli.cpp +43 -9
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +199 -163
- data/ext/sources/ggml/CMakeLists.txt +21 -13
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +72 -10
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-rpc.h +3 -3
- data/ext/sources/ggml/include/ggml.h +101 -9
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +22 -5
- data/ext/sources/ggml/src/ggml-alloc.c +5 -1
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
- data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
- data/ext/sources/ggml/src/ggml-impl.h +6 -1
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
- data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +289 -114
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
- data/ext/sources/ggml/src/ggml.c +110 -28
- data/ext/sources/ggml/src/gguf.cpp +173 -28
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +56 -12
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +411 -62
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +24 -6
- data/whispercpp.gemspec +2 -2
- metadata +215 -281
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
- data/ext/sources/examples/talk-llama/llama-context.h +0 -359
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
- data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
- data/ext/sources/examples/talk-llama/llama-model.h +0 -597
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
- data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
- data/ext/sources/examples/talk-llama/llama.h +0 -1573
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -704
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
- /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
|
@@ -1,7 +1,12 @@
|
|
|
1
1
|
#include "common.cuh"
|
|
2
2
|
|
|
3
3
|
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
|
|
4
|
-
|
|
4
|
+
|
|
5
|
+
bool ggml_cuda_should_use_mmvq(enum ggml_type type, int cc, int64_t ne11);
|
|
6
|
+
|
|
7
|
+
// Returns the maximum batch size for which MMVQ should be used for MUL_MAT_ID,
|
|
8
|
+
// based on the quantization type and GPU architecture (compute capability).
|
|
9
|
+
int get_mmvq_mmid_max_batch(ggml_type type, int cc);
|
|
5
10
|
|
|
6
11
|
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
|
|
7
12
|
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
|
|
@@ -18,6 +18,7 @@ static __global__ void norm_f32(
|
|
|
18
18
|
|
|
19
19
|
float2 mean_var = make_float2(0.0f, 0.0f);
|
|
20
20
|
|
|
21
|
+
ggml_cuda_pdl_sync();
|
|
21
22
|
for (int col = tid; col < ncols; col += block_size) {
|
|
22
23
|
const float xi = x[col];
|
|
23
24
|
mean_var.x += xi;
|
|
@@ -46,6 +47,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
|
|
|
46
47
|
|
|
47
48
|
float tmp = 0.0f; // partial sum for thread in warp
|
|
48
49
|
|
|
50
|
+
ggml_cuda_pdl_sync();
|
|
49
51
|
for (int j = start; j < end; j += block_size) {
|
|
50
52
|
tmp += x[j];
|
|
51
53
|
}
|
|
@@ -95,6 +97,7 @@ static __global__ void rms_norm_f32(const float * x,
|
|
|
95
97
|
const uint3 add_nrows_packed = make_uint3(0, 0, 0),
|
|
96
98
|
const uint3 add_nchannels_packed = make_uint3(0, 0, 0),
|
|
97
99
|
const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) {
|
|
100
|
+
ggml_cuda_pdl_lc();
|
|
98
101
|
const int nrows = gridDim.x;
|
|
99
102
|
const int nchannels = gridDim.y;
|
|
100
103
|
|
|
@@ -124,6 +127,7 @@ static __global__ void rms_norm_f32(const float * x,
|
|
|
124
127
|
|
|
125
128
|
float tmp = 0.0f; // partial sum for thread in warp
|
|
126
129
|
|
|
130
|
+
ggml_cuda_pdl_sync();
|
|
127
131
|
for (int col = tid; col < ncols; col += block_size) {
|
|
128
132
|
const float xi = x[col];
|
|
129
133
|
tmp += xi * xi;
|
|
@@ -163,6 +167,7 @@ static __global__ void rms_norm_back_f32(
|
|
|
163
167
|
float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass
|
|
164
168
|
float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs
|
|
165
169
|
|
|
170
|
+
ggml_cuda_pdl_sync();
|
|
166
171
|
for (int col = tid; col < ncols; col += block_size) {
|
|
167
172
|
const float xfi = xf[col];
|
|
168
173
|
sum_xx += xfi * xfi;
|
|
@@ -253,6 +258,7 @@ static __global__ void l2_norm_f32(
|
|
|
253
258
|
|
|
254
259
|
float tmp = 0.0f; // partial sum for thread in warp
|
|
255
260
|
|
|
261
|
+
ggml_cuda_pdl_sync();
|
|
256
262
|
for (int col = tid; col < ncols; col += block_size) {
|
|
257
263
|
const float xi = x[col];
|
|
258
264
|
tmp += xi * xi;
|
|
@@ -261,6 +267,7 @@ static __global__ void l2_norm_f32(
|
|
|
261
267
|
// sum up partial sums
|
|
262
268
|
extern __shared__ float s_sum[];
|
|
263
269
|
tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
|
|
270
|
+
ggml_cuda_pdl_lc();
|
|
264
271
|
|
|
265
272
|
// from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
|
|
266
273
|
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
|
|
@@ -300,10 +307,19 @@ static void rms_norm_f32_cuda(
|
|
|
300
307
|
const dim3 blocks_num(nrows, nchannels, nsamples);
|
|
301
308
|
if (ncols < 1024) {
|
|
302
309
|
const dim3 block_dims(256, 1, 1);
|
|
303
|
-
|
|
310
|
+
const ggml_cuda_kernel_launch_params launch_params = {blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
|
|
311
|
+
ggml_cuda_kernel_launch(rms_norm_f32<256, false>, launch_params,
|
|
312
|
+
x, dst, ncols, stride_row, stride_channel, stride_sample, eps,
|
|
313
|
+
// underlying cudaLaunchKernelEx does not support default params
|
|
314
|
+
nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0),
|
|
315
|
+
nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0));
|
|
304
316
|
} else {
|
|
305
317
|
const dim3 block_dims(1024, 1, 1);
|
|
306
|
-
|
|
318
|
+
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
|
|
319
|
+
ggml_cuda_kernel_launch(rms_norm_f32<1024, false>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps,
|
|
320
|
+
// underlying cudaLaunchKernelEx does not support default params
|
|
321
|
+
nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0),
|
|
322
|
+
nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0));
|
|
307
323
|
}
|
|
308
324
|
}
|
|
309
325
|
|
|
@@ -346,14 +362,20 @@ static void rms_norm_mul_f32_cuda(const float * x,
|
|
|
346
362
|
const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
|
|
347
363
|
if (ncols < 1024) {
|
|
348
364
|
const dim3 block_dims(256, 1, 1);
|
|
349
|
-
|
|
365
|
+
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
|
|
366
|
+
ggml_cuda_kernel_launch(rms_norm_f32<256, true>, launch_params,
|
|
350
367
|
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
|
351
|
-
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed
|
|
368
|
+
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed,
|
|
369
|
+
// underlying cudaLaunchKernelEx does not support default params
|
|
370
|
+
nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0));
|
|
352
371
|
} else {
|
|
353
372
|
const dim3 block_dims(1024, 1, 1);
|
|
354
|
-
|
|
373
|
+
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
|
|
374
|
+
ggml_cuda_kernel_launch(rms_norm_f32<1024, true>, launch_params,
|
|
355
375
|
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
|
356
|
-
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed
|
|
376
|
+
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed,
|
|
377
|
+
// underlying cudaLaunchKernelEx does not support default params
|
|
378
|
+
nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0));
|
|
357
379
|
}
|
|
358
380
|
} else {
|
|
359
381
|
const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
|
|
@@ -367,14 +389,16 @@ static void rms_norm_mul_f32_cuda(const float * x,
|
|
|
367
389
|
const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
|
|
368
390
|
if (ncols < 1024) {
|
|
369
391
|
const dim3 block_dims(256, 1, 1);
|
|
370
|
-
|
|
392
|
+
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims,block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
|
|
393
|
+
ggml_cuda_kernel_launch(rms_norm_f32<256, true, true>, launch_params,
|
|
371
394
|
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
|
372
395
|
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
|
|
373
396
|
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
|
|
374
397
|
add_nchannels_packed, add_nsamples_packed);
|
|
375
398
|
} else {
|
|
376
399
|
const dim3 block_dims(1024, 1, 1);
|
|
377
|
-
|
|
400
|
+
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
|
|
401
|
+
ggml_cuda_kernel_launch(rms_norm_f32<1024, true, true>, launch_params,
|
|
378
402
|
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
|
379
403
|
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
|
|
380
404
|
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
|
|
@@ -399,10 +423,12 @@ static void l2_norm_f32_cuda(
|
|
|
399
423
|
const dim3 blocks_num(nrows, nchannels, nsamples);
|
|
400
424
|
if (ncols < 1024) {
|
|
401
425
|
const dim3 block_dims(WARP_SIZE, 1, 1);
|
|
402
|
-
|
|
426
|
+
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, 0, stream};
|
|
427
|
+
ggml_cuda_kernel_launch(l2_norm_f32<WARP_SIZE>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
|
403
428
|
} else {
|
|
404
429
|
const dim3 block_dims(1024, 1, 1);
|
|
405
|
-
|
|
430
|
+
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream};
|
|
431
|
+
ggml_cuda_kernel_launch(l2_norm_f32<1024>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
|
406
432
|
}
|
|
407
433
|
}
|
|
408
434
|
|
|
@@ -54,15 +54,31 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
54
54
|
const int64_t dps2 = ne2 / ne02;
|
|
55
55
|
const int64_t dps3 = ne3 / ne03;
|
|
56
56
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
57
|
+
if (dps2 == 1 && ne2 > 1) {
|
|
58
|
+
// src0 has uniform stride s02 along dim 2; batch the inner loop with a strided GEMM
|
|
59
|
+
GGML_ASSERT(ne2 <= std::numeric_limits<int>::max());
|
|
60
|
+
const int batch_count = (int) ne2;
|
|
61
|
+
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
|
60
62
|
CUBLAS_CHECK(
|
|
61
|
-
|
|
63
|
+
cublasSgemmStridedBatched(handle, CUBLAS_OP_N, src1_cublas_op,
|
|
62
64
|
ne0, ne1, ne01,
|
|
63
|
-
&alpha, src0_d + (i3/dps3)*s03
|
|
64
|
-
src1_d + i3
|
|
65
|
-
&beta, dst_d + i3
|
|
65
|
+
&alpha, src0_d + (i3/dps3)*s03, lda, s02,
|
|
66
|
+
src1_d + i3 *s13, ldb, s12,
|
|
67
|
+
&beta, dst_d + i3 *s3, ldc, s2,
|
|
68
|
+
batch_count));
|
|
69
|
+
}
|
|
70
|
+
} else {
|
|
71
|
+
// Fallback: ne2 == 1 (no batching benefit) or dps2 > 1 (src0 broadcast along dim 2
|
|
72
|
+
// with non-uniform stride; would need cublasSgemmBatched with pointer arrays).
|
|
73
|
+
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
|
74
|
+
for (int64_t i2 = 0; i2 < ne2; ++i2) {
|
|
75
|
+
CUBLAS_CHECK(
|
|
76
|
+
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
|
|
77
|
+
ne0, ne1, ne01,
|
|
78
|
+
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
|
|
79
|
+
src1_d + i3 *s13 + i2 *s12, ldb,
|
|
80
|
+
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
|
|
81
|
+
}
|
|
66
82
|
}
|
|
67
83
|
}
|
|
68
84
|
}
|
|
@@ -3,9 +3,12 @@
|
|
|
3
3
|
|
|
4
4
|
__launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1)
|
|
5
5
|
static __global__ void quantize_q8_1(
|
|
6
|
-
const float *
|
|
6
|
+
const float * x_ptr, void * vy_ptr,
|
|
7
7
|
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
|
|
8
8
|
const int64_t ne0, const uint32_t ne1, const uint3 ne2) {
|
|
9
|
+
ggml_cuda_pdl_lc();
|
|
10
|
+
const float * GGML_CUDA_RESTRICT x = x_ptr;
|
|
11
|
+
void * GGML_CUDA_RESTRICT vy = vy_ptr;
|
|
9
12
|
const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
|
|
10
13
|
|
|
11
14
|
if (i0 >= ne0) {
|
|
@@ -28,6 +31,7 @@ static __global__ void quantize_q8_1(
|
|
|
28
31
|
const int64_t ib = i_cont / QK8_1; // block index
|
|
29
32
|
const int64_t iqs = i_cont % QK8_1; // quant index
|
|
30
33
|
|
|
34
|
+
ggml_cuda_pdl_sync();
|
|
31
35
|
const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f;
|
|
32
36
|
float amax = fabsf(xi);
|
|
33
37
|
float sum = xi;
|
|
@@ -70,6 +74,102 @@ __device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
|
|
|
70
74
|
return static_cast<uint8_t>(biased);
|
|
71
75
|
}
|
|
72
76
|
|
|
77
|
+
|
|
78
|
+
// Quantizes f32 values from x into block_fp4_mmq records in vy (NVFP4 activation path).
//
// Each thread handles one run of QK_NVFP4_SUB consecutive values along dimension 0:
//   - blockIdx.x selects the row (optionally remapped through ids),
//   - blockIdx.y / threadIdx.x select the run within the row,
//   - blockIdx.z encodes (i3, i2) as i3*ne2 + i2.
// Values at positions >= ne00 are treated as 0.0f (padding).
// For each run the thread picks a UE4M3 scale code by trying the code derived from amax/6 and its
// +/-1, +/-2 neighbours, keeping the one with the smallest squared error, and then stores the 16
// packed 4-bit codes plus the scale code into the destination block.
//
// Only compiled when BLACKWELL_MMA_AVAILABLE is defined; otherwise the kernel is NO_DEVICE_CODE.
static __global__ void quantize_mmq_nvfp4(
        const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
        const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
        const int64_t ne0, const int64_t ne1, const int64_t ne2) {
#if defined(BLACKWELL_MMA_AVAILABLE)

    // First element (along dim 0) of the run owned by this thread.
    const int64_t i0_base = ((int64_t) blockDim.x * blockIdx.y + threadIdx.x) * QK_NVFP4_SUB;
    if (i0_base >= ne0) {
        return;
    }

    const int64_t i1  = blockIdx.x;
    const int64_t i2  = blockIdx.z % ne2;
    const int64_t i3  = blockIdx.z / ne2;
    const int64_t i01 = ids ? ids[i1] : i1;

    // i0_base < ne0 (checked above) already implies k_block < blocks_per_col,
    // so no second bounds check is needed here.
    const int64_t k_block        = i0_base / QK_K;
    const int64_t blocks_per_col = (ne0 + QK_K - 1) / QK_K;

    // Destination block: blocks are ordered by blockIdx.z, then by k_block, then by row.
    const int64_t   ib = blockIdx.z * ((int64_t) blocks_per_col * ne1) + k_block * ne1 + blockIdx.x;
    block_fp4_mmq * y  = (block_fp4_mmq *) vy;
    block_fp4_mmq * yb = y + ib;

    // Index of this thread's run inside its QK_K-sized block.
    const int sub = (i0_base % QK_K) / QK_NVFP4_SUB;

    // Load the run and track its absolute maximum.
    float         vals_raw[QK_NVFP4_SUB];
    float         amax_raw = 0.0f;
    const int64_t base_idx = i3 * s03 + i2 * s02 + i01 * s01;
#pragma unroll
    for (int k = 0; k < QK_NVFP4_SUB; k++) {
        const int64_t i00 = i0_base + k;
        if (i00 < ne00) {
            const float v = x[base_idx + i00];
            vals_raw[k]   = v;
            amax_raw      = fmaxf(amax_raw, fabsf(v));
        } else {
            vals_raw[k] = 0.0f;
        }
    }

    static constexpr int test_offsets[5] = { 0, -1, 1, -2, 2};
    // NOTE(review): 6.0f is presumably the largest FP4 E2M1 magnitude - confirm.
    const int first_fp8_code = (int) ggml_cuda_fp32_to_ue4m3(amax_raw / 6.0f);

    float   best_err       = FLT_MAX;
    uint8_t fp8_code       = 0;
    float   subblock_scale = 0.0f;

#pragma unroll // Check +/- 2 to find best code to reduce NVFP4 activation loss. Negligible overhead on Blackwell.
    for (int i = 0; i < 5; i++) {
        const int test_code = first_fp8_code + test_offsets[i];
        if (test_code < 0 || test_code > 0x7e) {
            continue;
        }
        const uint8_t code           = (uint8_t) test_code;
        const float   test_scale     = ggml_cuda_ue4m3_to_fp32(code);
        const float   test_inv_scale = test_scale > 0.0f ? 0.5f / test_scale : 0.0f;

        // Sum of squared magnitude errors for this candidate scale.
        float cur_err = 0.0f;
#pragma unroll
        for (int k = 0; k < QK_NVFP4_SUB; ++k) {
            const float   v        = vals_raw[k];
            const uint8_t q        = ggml_cuda_float_to_fp4_e2m1(v, test_inv_scale);
            const float   err_diff = fabsf(v) - fabsf(kvalues_mxfp4[q & 0x7]) * test_scale;
            cur_err                = fmaf(err_diff, err_diff, cur_err);
        }

        if (cur_err < best_err) {
            best_err       = cur_err;
            fp8_code       = code; // already range-checked and narrowed above
            subblock_scale = test_scale;
        }
    }

    // Pack the 16 4-bit codes into two 32-bit words: byte k of q0 holds values k (low nibble) and
    // k + 8 (high nibble); byte k of q1 holds values k + 4 and k + 12.
    const float inv_scale = subblock_scale > 0.0f ? 0.5f / subblock_scale : 0.0f;
    uint32_t    q0        = 0;
    uint32_t    q1        = 0;
#pragma unroll // this is faster than the previous __nv_fp4x4_e2m1
    for (int k = 0; k < QK_NVFP4_SUB / 4; ++k) {
        q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 0], inv_scale) << (8 * k);
        q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 8], inv_scale) << (8 * k + 4);
        q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 4], inv_scale) << (8 * k);
        q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 12], inv_scale) << (8 * k + 4);
    }

    uint32_t * yqs = reinterpret_cast<uint32_t *>(yb->qs);
    yqs[2 * sub + 0] = q0;
    yqs[2 * sub + 1] = q1;
    reinterpret_cast<uint8_t *>(yb->d4)[sub] = fp8_code;
#else
    NO_DEVICE_CODE; // This is for Blackwell NVFP4 activations only.
#endif // defined(BLACKWELL_MMA_AVAILABLE)

}
|
|
172
|
+
|
|
73
173
|
// quantize values in the format mxfp4 is stored which is interleaved nibbles
|
|
74
174
|
// i.e. a block a0-a31 is represented as a0a16,a1a17 ...a15a31
|
|
75
175
|
static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
|
|
@@ -100,6 +200,7 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
|
|
|
100
200
|
const int64_t i2 = blockIdx.z % ne2;
|
|
101
201
|
const int64_t i3 = blockIdx.z / ne2;
|
|
102
202
|
|
|
203
|
+
ggml_cuda_pdl_sync();
|
|
103
204
|
const int64_t i01 = ids ? ids[i1] : i1;
|
|
104
205
|
const int64_t i02 = i2;
|
|
105
206
|
const int64_t i03 = i3;
|
|
@@ -192,6 +293,7 @@ static __global__ void quantize_mmq_q8_1(
|
|
|
192
293
|
const int64_t i3 = blockIdx.z / ne2;
|
|
193
294
|
|
|
194
295
|
const int64_t i00 = i0;
|
|
296
|
+
ggml_cuda_pdl_sync();
|
|
195
297
|
const int64_t i01 = ids ? ids[i1] : i1;
|
|
196
298
|
const int64_t i02 = i2;
|
|
197
299
|
const int64_t i03 = i3;
|
|
@@ -282,7 +384,8 @@ void quantize_row_q8_1_cuda(
|
|
|
282
384
|
const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
|
|
283
385
|
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
|
|
284
386
|
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
|
|
285
|
-
|
|
387
|
+
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, block_size, 0, stream);
|
|
388
|
+
ggml_cuda_kernel_launch(quantize_q8_1, launch_params, x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv);
|
|
286
389
|
GGML_UNUSED(type_src0);
|
|
287
390
|
}
|
|
288
391
|
|
|
@@ -316,28 +419,32 @@ void quantize_mmq_q8_1_cuda(
|
|
|
316
419
|
}
|
|
317
420
|
}
|
|
318
421
|
|
|
319
|
-
void
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
422
|
+
void quantize_mmq_fp4_cuda(
|
|
423
|
+
const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
|
|
424
|
+
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
|
|
425
|
+
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
|
|
426
|
+
GGML_ASSERT(type_src0 == GGML_TYPE_MXFP4 || type_src0 == GGML_TYPE_NVFP4);
|
|
427
|
+
GGML_ASSERT(ne0 > 0);
|
|
428
|
+
|
|
429
|
+
if (type_src0 == GGML_TYPE_NVFP4) {
|
|
430
|
+
GGML_ASSERT(ne00 % QK_NVFP4 == 0);
|
|
431
|
+
constexpr int nvfp4_block_size = 128;
|
|
432
|
+
const int64_t block_num_y = (ne0 + QK_NVFP4_SUB * nvfp4_block_size - 1) / (QK_NVFP4_SUB * nvfp4_block_size);
|
|
433
|
+
const dim3 block_size(nvfp4_block_size, 1, 1);
|
|
434
|
+
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
|
|
435
|
+
quantize_mmq_nvfp4<<<num_blocks, block_size, 0, stream>>>(
|
|
436
|
+
x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
|
|
437
|
+
} else {
|
|
438
|
+
GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
|
|
439
|
+
|
|
440
|
+
constexpr int nwarps = 8;
|
|
441
|
+
constexpr int vals_per_warp = 2 * QK_MXFP4;
|
|
442
|
+
constexpr int vals_per_block = nwarps * vals_per_warp;
|
|
443
|
+
|
|
444
|
+
const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
|
|
445
|
+
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
|
|
446
|
+
const dim3 block_size(WARP_SIZE, nwarps, 1);
|
|
447
|
+
|
|
448
|
+
quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
|
|
449
|
+
}
|
|
343
450
|
}
|
|
@@ -26,7 +26,7 @@ void quantize_mmq_q8_1_cuda(
|
|
|
26
26
|
ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
|
|
27
27
|
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
|
|
28
28
|
|
|
29
|
-
void
|
|
29
|
+
void quantize_mmq_fp4_cuda(const float * x,
|
|
30
30
|
const int32_t * ids,
|
|
31
31
|
void * vy,
|
|
32
32
|
ggml_type type_src0,
|
|
@@ -2,7 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
|
|
4
4
|
template <bool norm>
|
|
5
|
-
static __global__ void reduce_rows_f32(const float *
|
|
5
|
+
static __global__ void reduce_rows_f32(const float * x_ptr, float * dst_ptr, const int ncols) {
|
|
6
|
+
const float * GGML_CUDA_RESTRICT x = x_ptr;
|
|
7
|
+
float * GGML_CUDA_RESTRICT dst = dst_ptr;
|
|
6
8
|
const int row = blockIdx.x;
|
|
7
9
|
const int col = threadIdx.x;
|
|
8
10
|
|
|
@@ -10,6 +12,8 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r
|
|
|
10
12
|
const int num_unroll = 8;
|
|
11
13
|
float temp[num_unroll];
|
|
12
14
|
float sum_temp[num_unroll] = { 0.0f };
|
|
15
|
+
|
|
16
|
+
ggml_cuda_pdl_sync();
|
|
13
17
|
for (int i = col; i < ncols;) {
|
|
14
18
|
for (int j = 0; j < num_unroll; ++j) {
|
|
15
19
|
if (i < ncols) {
|
|
@@ -134,6 +134,7 @@ static __global__ void rope_neox(const T * x,
|
|
|
134
134
|
const float * freq_factors,
|
|
135
135
|
const int64_t * row_indices,
|
|
136
136
|
const int set_rows_stride) {
|
|
137
|
+
ggml_cuda_pdl_lc();
|
|
137
138
|
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
|
138
139
|
|
|
139
140
|
if (i0 >= ne00) {
|
|
@@ -148,6 +149,7 @@ static __global__ void rope_neox(const T * x,
|
|
|
148
149
|
|
|
149
150
|
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
|
150
151
|
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
|
152
|
+
ggml_cuda_pdl_sync();
|
|
151
153
|
|
|
152
154
|
// Fusion optimization: ROPE + VIEW + SET_ROWS.
|
|
153
155
|
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
|
|
@@ -216,6 +218,7 @@ static __global__ void rope_multi(const T * x,
|
|
|
216
218
|
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
|
217
219
|
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
|
218
220
|
|
|
221
|
+
ggml_cuda_pdl_sync();
|
|
219
222
|
if (i0 >= n_dims) {
|
|
220
223
|
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
|
|
221
224
|
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
|
|
@@ -300,6 +303,7 @@ static __global__ void rope_vision(const T * x,
|
|
|
300
303
|
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
|
301
304
|
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
|
302
305
|
|
|
306
|
+
ggml_cuda_pdl_sync();
|
|
303
307
|
const int sect_dims = sections.v[0] + sections.v[1];
|
|
304
308
|
const int sec_w = sections.v[1] + sections.v[0];
|
|
305
309
|
const int sector = (i0 / 2) % sect_dims;
|
|
@@ -399,13 +403,14 @@ static void rope_neox_cuda(const T * x,
|
|
|
399
403
|
const dim3 block_nums(nr, n_blocks_x, 1);
|
|
400
404
|
|
|
401
405
|
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
|
406
|
+
const ggml_cuda_kernel_launch_params launch_params = {block_nums, block_dims, 0, stream};
|
|
402
407
|
|
|
403
408
|
if (freq_factors == nullptr) {
|
|
404
|
-
rope_neox<forward, false
|
|
409
|
+
ggml_cuda_kernel_launch(rope_neox<forward, false, T, D>, launch_params,
|
|
405
410
|
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
|
406
411
|
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
|
|
407
412
|
} else {
|
|
408
|
-
rope_neox<forward, true
|
|
413
|
+
ggml_cuda_kernel_launch(rope_neox<forward, true, T, D>, launch_params,
|
|
409
414
|
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
|
410
415
|
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
|
|
411
416
|
}
|
|
@@ -443,11 +448,13 @@ static void rope_multi_cuda(const T * x,
|
|
|
443
448
|
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
|
444
449
|
|
|
445
450
|
if (freq_factors == nullptr) {
|
|
446
|
-
|
|
451
|
+
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
|
|
452
|
+
ggml_cuda_kernel_launch(rope_multi<forward, false, T>, launch_params,
|
|
447
453
|
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
|
448
454
|
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
|
|
449
455
|
} else {
|
|
450
|
-
|
|
456
|
+
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
|
|
457
|
+
ggml_cuda_kernel_launch(rope_multi<forward, true, T>, launch_params,
|
|
451
458
|
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
|
452
459
|
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
|
|
453
460
|
}
|
|
@@ -3,9 +3,11 @@
|
|
|
3
3
|
#define MAX_GRIDDIM_X 0x7FFFFFFF
|
|
4
4
|
|
|
5
5
|
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) {
|
|
6
|
+
ggml_cuda_pdl_lc();
|
|
6
7
|
int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
|
|
7
8
|
int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x;
|
|
8
9
|
|
|
10
|
+
ggml_cuda_pdl_sync();
|
|
9
11
|
for (int64_t i = tid; i < nelements; i += stride) {
|
|
10
12
|
dst[i] = scale * x[i] + bias;
|
|
11
13
|
}
|
|
@@ -13,7 +15,8 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
|
|
|
13
15
|
|
|
14
16
|
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) {
|
|
15
17
|
const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
|
|
16
|
-
|
|
18
|
+
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream);
|
|
19
|
+
ggml_cuda_kernel_launch(scale_f32, launch_params, x, dst, scale, bias, nelements);
|
|
17
20
|
}
|
|
18
21
|
|
|
19
22
|
void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
@@ -53,6 +53,7 @@ static __global__ void k_set_rows_quant(const float * __restrict__ src0,
|
|
|
53
53
|
const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
|
|
54
54
|
const int64_t i10 = i01;
|
|
55
55
|
|
|
56
|
+
ggml_cuda_pdl_sync();
|
|
56
57
|
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
|
|
57
58
|
|
|
58
59
|
const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
|
|
@@ -110,9 +111,9 @@ static void set_rows_cuda_quant(
|
|
|
110
111
|
}
|
|
111
112
|
|
|
112
113
|
template <typename src_t, typename idx_t, typename dst_t>
|
|
113
|
-
static __global__ void k_set_rows(const src_t *
|
|
114
|
-
const idx_t *
|
|
115
|
-
dst_t *
|
|
114
|
+
static __global__ void k_set_rows(const src_t * src0_ptr,
|
|
115
|
+
const idx_t * src1_ptr,
|
|
116
|
+
dst_t * dst_ptr,
|
|
116
117
|
const int64_t ne_total,
|
|
117
118
|
const int64_t ne10,
|
|
118
119
|
const int64_t ne11,
|
|
@@ -132,6 +133,9 @@ static __global__ void k_set_rows(const src_t * __restrict__ src0,
|
|
|
132
133
|
const uint3 ne02,
|
|
133
134
|
const uint3 ne11_fd,
|
|
134
135
|
const uint3 ne12_fd) {
|
|
136
|
+
const src_t * GGML_CUDA_RESTRICT src0 = src0_ptr;
|
|
137
|
+
const idx_t * GGML_CUDA_RESTRICT src1 = src1_ptr;
|
|
138
|
+
dst_t * GGML_CUDA_RESTRICT dst = dst_ptr;
|
|
135
139
|
const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
|
|
136
140
|
|
|
137
141
|
if (i >= ne_total) {
|
|
@@ -157,7 +161,9 @@ static __global__ void k_set_rows(const src_t * __restrict__ src0,
|
|
|
157
161
|
const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
|
|
158
162
|
const int64_t i10 = i01;
|
|
159
163
|
|
|
164
|
+
ggml_cuda_pdl_sync();
|
|
160
165
|
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
|
|
166
|
+
ggml_cuda_pdl_lc();
|
|
161
167
|
|
|
162
168
|
const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
|
|
163
169
|
dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3;
|
|
@@ -203,9 +209,11 @@ static void set_rows_cuda(
|
|
|
203
209
|
const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
|
|
204
210
|
const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
|
|
205
211
|
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
212
|
+
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_size, block_size, 0, stream);
|
|
213
|
+
ggml_cuda_kernel_launch(k_set_rows<src_t, idx_t, dst_t>, launch_params,
|
|
214
|
+
src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
|
|
215
|
+
s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
|
|
216
|
+
ne11_fd, ne12_fd);
|
|
209
217
|
}
|
|
210
218
|
}
|
|
211
219
|
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
#include "snake.cuh"
|
|
2
|
+
#include "convert.cuh"
|
|
3
|
+
|
|
4
|
+
// Fused Snake activation: y = x + sin^2(a * x) * inv_b
|
|
5
|
+
// x: [T, C] (T contiguous), a: [1, C], inv_b: [1, C]
|
|
6
|
+
// Supports F32, F16, BF16 data with F32 compute.
|
|
7
|
+
|
|
8
|
+
template <typename T>
|
|
9
|
+
static __global__ void snake_kernel(
|
|
10
|
+
const T * __restrict__ x,
|
|
11
|
+
const float * __restrict__ a,
|
|
12
|
+
const float * __restrict__ inv_b,
|
|
13
|
+
T * __restrict__ dst,
|
|
14
|
+
const int total,
|
|
15
|
+
const uint3 T_len_fastdiv) {
|
|
16
|
+
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
|
17
|
+
if (idx >= total) return;
|
|
18
|
+
|
|
19
|
+
const int c = (int) fastdiv((uint32_t) idx, T_len_fastdiv);
|
|
20
|
+
|
|
21
|
+
const float xi = ggml_cuda_cast<float>(x[idx]);
|
|
22
|
+
const float s = sinf(a[c] * xi);
|
|
23
|
+
dst[idx] = ggml_cuda_cast<T>(xi + s * s * inv_b[c]);
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
// Internal launcher with explicit x/a/inv_b/dst tensors.
|
|
27
|
+
// Shared by the public op (reads dst->src) and the fusion path (explicit args).
|
|
28
|
+
// Launches snake_kernel over the first two dimensions of x (ne[0] * ne[1] elements) on the
// context's stream, dispatching on x->type (F32, F16 or BF16; anything else aborts).
//
// NOTE(review): a and inv_b are read as float and dst is written with x's element type; none of
// these tensor types are checked here - presumably the callers guarantee them.
// NOTE(review): x->ne[2] and x->ne[3] are not consulted - presumably x is 2D; confirm.
static void launch_snake(ggml_backend_cuda_context & ctx,
                         const ggml_tensor * x,
                         const ggml_tensor * a,
                         const ggml_tensor * inv_b,
                         ggml_tensor * dst) {
    const float * a_d     = (const float *)a->data;
    const float * inv_b_d = (const float *)inv_b->data;

    constexpr int block_size = 256;

    // The kernel indexes elements with a 32-bit int, and the grid is rounded up to a whole number
    // of blocks, so the element count plus the rounding slack must stay within INT_MAX.
    constexpr int64_t max_total = 0x7FFFFFFF - (block_size - 1);

    // Do the size arithmetic in 64 bits so that ne[0] * ne[1] cannot overflow an int.
    const int64_t T64 = x->ne[0];
    const int64_t C64 = x->ne[1];
    GGML_ASSERT(T64 >= 0 && T64 <= max_total);
    GGML_ASSERT(C64 >= 0 && C64 <= max_total);
    GGML_ASSERT(T64 * C64 <= max_total);

    const int T     = (int) T64;
    const int total = (int) (T64 * C64);
    if (total == 0) {
        // Nothing to compute; a launch with a zero-sized grid would be an invalid configuration.
        return;
    }

    const uint3 T_len_fastdiv = init_fastdiv_values((uint64_t) T);

    const int grid_size = (total + block_size - 1) / block_size;

    cudaStream_t stream = ctx.stream();

    switch (x->type) {
        case GGML_TYPE_F32: {
            snake_kernel<<<grid_size, block_size, 0, stream>>>(
                (const float *)x->data, a_d, inv_b_d, (float *)dst->data, total, T_len_fastdiv);
        } break;
        case GGML_TYPE_F16: {
            snake_kernel<<<grid_size, block_size, 0, stream>>>(
                (const half *)x->data, a_d, inv_b_d, (half *)dst->data, total, T_len_fastdiv);
        } break;
        case GGML_TYPE_BF16: {
            snake_kernel<<<grid_size, block_size, 0, stream>>>(
                (const nv_bfloat16 *)x->data, a_d, inv_b_d, (nv_bfloat16 *)dst->data, total, T_len_fastdiv);
        } break;
        default:
            GGML_ABORT("snake: unsupported type");
    }
}
|
|
63
|
+
|
|
64
|
+
// Fusion entry: caller supplies x/a/inv_b explicitly from the matched
|
|
65
|
+
// mul -> sin -> sqr -> mul -> add pattern. The dst is the trailing add output.
|
|
66
|
+
// Public entry point for the fused Snake activation: forwards the explicitly supplied
// x / a / inv_b / dst tensors unchanged to launch_snake.
void ggml_cuda_op_snake_fused(
        ggml_backend_cuda_context & ctx,
        const ggml_tensor *         x,
        const ggml_tensor *         a,
        const ggml_tensor *         inv_b,
        ggml_tensor *               dst) {
    launch_snake(ctx, x, a, inv_b, dst);
}
|