whispercpp 1.3.6 → 1.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/README.md +38 -5
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -8
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +36 -42
- data/ext/ruby_whisper.h +135 -0
- data/ext/ruby_whisper_context.c +107 -28
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -65
- data/ext/ruby_whisper_segment.c +6 -6
- data/ext/ruby_whisper_transcribe.cpp +42 -15
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +1 -1
- data/ext/sources/examples/cli/cli.cpp +43 -9
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +199 -163
- data/ext/sources/ggml/CMakeLists.txt +21 -13
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +72 -10
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-rpc.h +3 -3
- data/ext/sources/ggml/include/ggml.h +101 -9
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +22 -5
- data/ext/sources/ggml/src/ggml-alloc.c +5 -1
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
- data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
- data/ext/sources/ggml/src/ggml-impl.h +6 -1
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
- data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +289 -114
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
- data/ext/sources/ggml/src/ggml.c +110 -28
- data/ext/sources/ggml/src/gguf.cpp +173 -28
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +56 -12
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +411 -62
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +24 -6
- data/whispercpp.gemspec +2 -2
- metadata +215 -281
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
- data/ext/sources/examples/talk-llama/llama-context.h +0 -359
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
- data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
- data/ext/sources/examples/talk-llama/llama-model.h +0 -597
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
- data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
- data/ext/sources/examples/talk-llama/llama.h +0 -1573
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -704
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
- /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
|
@@ -0,0 +1,2263 @@
|
|
|
1
|
+
#include "ggml.h"
|
|
2
|
+
#include "ggml-impl.h"
|
|
3
|
+
#include "ggml-backend.h"
|
|
4
|
+
#include "ggml-backend-impl.h"
|
|
5
|
+
#include "ggml-alloc.h"
|
|
6
|
+
#include "ggml-cpp.h"
|
|
7
|
+
|
|
8
|
+
#include <algorithm>
|
|
9
|
+
#include <cassert>
|
|
10
|
+
#include <cmath>
|
|
11
|
+
#include <cstddef>
|
|
12
|
+
#include <cstdint>
|
|
13
|
+
#include <cstring>
|
|
14
|
+
#include <map>
|
|
15
|
+
#include <memory>
|
|
16
|
+
#include <set>
|
|
17
|
+
#include <string>
|
|
18
|
+
#include <tuple>
|
|
19
|
+
#include <utility>
|
|
20
|
+
#include <vector>
|
|
21
|
+
|
|
22
|
+
struct ggml_backend_meta_device;
|
|
23
|
+
struct ggml_backend_meta_buffer_type;
|
|
24
|
+
struct ggml_backend_meta_buffer;
|
|
25
|
+
struct ggml_backend_meta;
|
|
26
|
+
|
|
27
|
+
const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis) {
|
|
28
|
+
switch (split_axis) {
|
|
29
|
+
case GGML_BACKEND_SPLIT_AXIS_0:
|
|
30
|
+
return "0";
|
|
31
|
+
case GGML_BACKEND_SPLIT_AXIS_1:
|
|
32
|
+
return "1";
|
|
33
|
+
case GGML_BACKEND_SPLIT_AXIS_2:
|
|
34
|
+
return "2";
|
|
35
|
+
case GGML_BACKEND_SPLIT_AXIS_3:
|
|
36
|
+
return "3";
|
|
37
|
+
case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
|
|
38
|
+
return "MIRRORED";
|
|
39
|
+
case GGML_BACKEND_SPLIT_AXIS_PARTIAL:
|
|
40
|
+
return "PARTIAL";
|
|
41
|
+
case GGML_BACKEND_SPLIT_AXIS_NONE:
|
|
42
|
+
return "NONE";
|
|
43
|
+
case GGML_BACKEND_SPLIT_AXIS_UNKNOWN:
|
|
44
|
+
return "UNKNOWN";
|
|
45
|
+
default:
|
|
46
|
+
GGML_ABORT("fatal error");
|
|
47
|
+
}
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
//
|
|
51
|
+
// meta backend device
|
|
52
|
+
//
|
|
53
|
+
|
|
54
|
+
struct ggml_backend_meta_device_context {
|
|
55
|
+
std::vector<ggml_backend_dev_t> simple_devs;
|
|
56
|
+
ggml_backend_meta_get_split_state_t get_split_state;
|
|
57
|
+
void * get_split_state_ud;
|
|
58
|
+
|
|
59
|
+
std::string name;
|
|
60
|
+
std::string description;
|
|
61
|
+
|
|
62
|
+
ggml_backend_meta_device_context(
|
|
63
|
+
std::vector<ggml_backend_dev_t> simple_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) :
|
|
64
|
+
simple_devs(std::move(simple_devs)), get_split_state(get_split_state), get_split_state_ud(get_split_state_ud) {
|
|
65
|
+
name = std::string("Meta(");
|
|
66
|
+
description = std::string("Meta(");
|
|
67
|
+
for (size_t i = 0; i < simple_devs.size(); i++) {
|
|
68
|
+
if (i > 0) {
|
|
69
|
+
name += ",";
|
|
70
|
+
description += ",";
|
|
71
|
+
}
|
|
72
|
+
name += ggml_backend_dev_name (simple_devs[i]);
|
|
73
|
+
description += ggml_backend_dev_description(simple_devs[i]);
|
|
74
|
+
}
|
|
75
|
+
name += ")";
|
|
76
|
+
description += ")";
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
bool operator<(const ggml_backend_meta_device_context & other) const {
|
|
80
|
+
return std::tie(simple_devs, get_split_state, get_split_state_ud)
|
|
81
|
+
< std::tie(other.simple_devs, other.get_split_state, other.get_split_state_ud);
|
|
82
|
+
}
|
|
83
|
+
};
|
|
84
|
+
|
|
85
|
+
static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev);
|
|
86
|
+
|
|
87
|
+
// Device-interface callback: return the `name` string cached in the meta device's
// context. The pointer stays valid for as long as that context lives.
static const char * ggml_backend_meta_device_get_name(ggml_backend_dev_t dev) {
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
    const auto * ctx = static_cast<const ggml_backend_meta_device_context *>(dev->context);
    const std::string & cached_name = ctx->name;
    return cached_name.c_str();
}
|
|
92
|
+
|
|
93
|
+
// Device-interface callback: return the `description` string cached in the meta
// device's context. The pointer stays valid for as long as that context lives.
static const char * ggml_backend_meta_device_get_description(ggml_backend_dev_t dev) {
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
    const auto * ctx = static_cast<const ggml_backend_meta_device_context *>(dev->context);
    const std::string & cached_description = ctx->description;
    return cached_description.c_str();
}
|
|
98
|
+
|
|
99
|
+
// Device-interface callback: report the memory of a meta device as the sum of the
// free and total memory of all of its simple devices.
//   dev   - must be a meta device (asserted)
//   free  - out: summed free memory, in the units reported by ggml_backend_dev_memory
//   total - out: summed total memory, same units
static void ggml_backend_meta_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
    const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
    *free  = 0;
    *total = 0;
    // The loop variable is named distinctly so it does not shadow the `dev` parameter.
    for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) {
        // Zero-initialized so a device that leaves an output untouched adds nothing.
        size_t tmp_free  = 0;
        size_t tmp_total = 0;
        ggml_backend_dev_memory(simple_dev, &tmp_free, &tmp_total);
        *free  += tmp_free;
        *total += tmp_total;
    }
}
|
|
111
|
+
|
|
112
|
+
// Device-interface callback: every meta device reports GGML_BACKEND_DEVICE_TYPE_META;
// the device handle itself is not inspected.
static enum ggml_backend_dev_type ggml_backend_meta_device_get_type(ggml_backend_dev_t dev) {
    GGML_UNUSED(dev);
    return GGML_BACKEND_DEVICE_TYPE_META;
}
|
|
117
|
+
|
|
118
|
+
// Device-interface callback: fill `props` for a meta device.
// Name, description, type and memory come from the meta device's own callbacks.
// Capabilities start from what the meta backend implements and are then ANDed with
// the capabilities of every simple device, so a capability is reported only if the
// meta backend and all wrapped devices support it.
static void ggml_backend_meta_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
    const auto * ctx = static_cast<const ggml_backend_meta_device_context *>(dev->context);

    // TODO replace placeholders
    props->name        = ggml_backend_meta_device_get_name(dev);
    props->description = ggml_backend_meta_device_get_description(dev);
    props->type        = ggml_backend_meta_device_get_type(dev);
    props->device_id   = 0;

    ggml_backend_meta_device_get_memory(dev, &props->memory_free, &props->memory_total);

    auto & caps = props->caps;
    caps = {
        /* .async                = */ true,
        /* .host_buffer          = */ false, // Not implemented.
        /* .buffer_from_host_ptr = */ false, // Not implemented.
        /* .events               = */ false, // Not implemented.
    };
    for (size_t idx = 0; idx < ctx->simple_devs.size(); ++idx) {
        ggml_backend_dev_props sub_props;
        ggml_backend_dev_get_props(ctx->simple_devs[idx], &sub_props);
        caps.async                = caps.async                && sub_props.caps.async;
        caps.host_buffer          = caps.host_buffer          && sub_props.caps.host_buffer;
        caps.buffer_from_host_ptr = caps.buffer_from_host_ptr && sub_props.caps.buffer_from_host_ptr;
        caps.events               = caps.events               && sub_props.caps.events;
    }
}
|
|
145
|
+
|
|
146
|
+
static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params);
|
|
147
|
+
|
|
148
|
+
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev);
|
|
149
|
+
|
|
150
|
+
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev);
|
|
151
|
+
|
|
152
|
+
// Device-interface callback: a meta device supports `op` only if every wrapped
// simple device supports it. Stops at the first device that does not; returns true
// when there are no simple devices.
static bool ggml_backend_meta_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
    const auto * ctx = static_cast<const ggml_backend_meta_device_context *>(dev->context);
    for (ggml_backend_dev_t simple_dev : ctx->simple_devs) {
        if (!ggml_backend_dev_supports_op(simple_dev, op)) {
            return false;
        }
    }
    return true;
}
|
|
158
|
+
|
|
159
|
+
// Device-interface callback: a meta device can use a buffer type only if that buffer
// type belongs to a meta device wrapping exactly the same simple devices, in the
// same order.
static bool ggml_backend_meta_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
    ggml_backend_dev_t owner = ggml_backend_buft_get_device(buft);
    if (!ggml_backend_dev_is_meta(owner)) {
        return false;
    }
    const auto * ctx       = static_cast<const ggml_backend_meta_device_context *>(dev->context);
    const auto * owner_ctx = static_cast<const ggml_backend_meta_device_context *>(owner->context);
    // vector equality: same length and element-wise equal device handles
    return ctx->simple_devs == owner_ctx->simple_devs;
}
|
|
177
|
+
|
|
178
|
+
// Device interface shared by all meta devices. Its get_name pointer is also what
// ggml_backend_dev_is_meta() compares against to recognize meta devices.
// Entries set to nullptr are not provided by the meta device.
static const ggml_backend_device_i ggml_backend_meta_device_iface = {
    ggml_backend_meta_device_get_name,              // get_name
    ggml_backend_meta_device_get_description,       // get_description
    ggml_backend_meta_device_get_memory,            // get_memory
    ggml_backend_meta_device_get_type,              // get_type
    ggml_backend_meta_device_get_props,             // get_props
    ggml_backend_meta_device_init_backend,          // init_backend
    ggml_backend_meta_device_get_buffer_type,       // get_buffer_type
    ggml_backend_meta_device_get_host_buffer_type,  // get_host_buffer_type
    nullptr,                                        // buffer_from_host_ptr
    ggml_backend_meta_device_supports_op,           // supports_op
    ggml_backend_meta_device_supports_buft,         // supports_buft
    nullptr,                                        // offload_op
    nullptr,                                        // event_new
    nullptr,                                        // event_free
    nullptr,                                        // event_synchronize
};
|
|
195
|
+
|
|
196
|
+
// Recognizes meta devices by interface identity: a non-null device is a meta device
// when its get_name callback is the one installed by ggml_backend_meta_device_iface.
static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev) {
    if (dev == nullptr) {
        return false;
    }
    return dev->iface.get_name == ggml_backend_meta_device_iface.get_name;
}
|
|
199
|
+
|
|
200
|
+
// Returns how many simple devices the meta device `meta_dev` wraps (aborts if it is not a meta device).
static size_t ggml_backend_meta_dev_n_devs(ggml_backend_dev_t meta_dev) {
    GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev));
    const auto & devs = ((const ggml_backend_meta_device_context *) meta_dev->context)->simple_devs;
    return devs.size();
}
|
|
205
|
+
|
|
206
|
+
// Returns the simple device at position `index` inside the meta device `meta_dev`.
// Aborts if `meta_dev` is not a meta device or if `index` is out of range.
static ggml_backend_dev_t ggml_backend_meta_dev_simple_dev(ggml_backend_dev_t meta_dev, size_t index) {
    GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev));
    const auto & devs = ((const ggml_backend_meta_device_context *) meta_dev->context)->simple_devs;
    GGML_ASSERT(index < devs.size());
    return devs[index];
}
|
|
212
|
+
|
|
213
|
+
// Returns the meta device that wraps `devs[0..n_devs)` with the given split-state callback.
// Devices are memoized: an identical configuration (as ordered by the context's operator<)
// returns the same device pointer. Contexts and devices live in function-local statics
// and are never released here.
ggml_backend_dev_t ggml_backend_meta_device(
        ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) {
    GGML_ASSERT(n_devs <= GGML_BACKEND_META_MAX_DEVICES);
    // TODO: this is not thread-safe - needs to be fixed
    static std::vector<std::unique_ptr<ggml_backend_meta_device_context>> ctxs;
    static std::map<ggml_backend_meta_device_context, struct ggml_backend_device> meta_devs;

    std::vector<ggml_backend_dev_t> simple_devs(devs, devs + n_devs);
    ggml_backend_meta_device_context ctx(simple_devs, get_split_state, get_split_state_ud);

    // Reuse the device previously created for an identical configuration.
    const auto existing = meta_devs.find(ctx);
    if (existing != meta_devs.end()) {
        return &existing->second;
    }

    // The device keeps a raw pointer to its context, so the context is owned by `ctxs`.
    ctxs.push_back(std::make_unique<ggml_backend_meta_device_context>(ctx));
    ggml_backend_meta_device_context * owned_ctx = ctxs.back().get();

    struct ggml_backend_device meta_dev = {
        /*iface   =*/ ggml_backend_meta_device_iface,
        /*reg     =*/ nullptr,
        /*context =*/ owned_ctx,
    };

    return &meta_devs.emplace(*owned_ctx, meta_dev).first->second;
}
|
|
244
|
+
|
|
245
|
+
//
|
|
246
|
+
// meta backend buffer type
|
|
247
|
+
//
|
|
248
|
+
|
|
249
|
+
struct ggml_backend_meta_buffer_type_context {
|
|
250
|
+
std::vector<ggml_backend_buffer_type_t> simple_bufts;
|
|
251
|
+
|
|
252
|
+
std::string name;
|
|
253
|
+
|
|
254
|
+
ggml_backend_meta_buffer_type_context(std::vector<ggml_backend_buffer_type_t> simple_bufts) : simple_bufts(std::move(simple_bufts)) {
|
|
255
|
+
name = "Meta(";
|
|
256
|
+
for (size_t i = 0; i < simple_bufts.size(); i++) {
|
|
257
|
+
if (i > 0) {
|
|
258
|
+
name += ",";
|
|
259
|
+
}
|
|
260
|
+
name += ggml_backend_buft_name(simple_bufts[i]);
|
|
261
|
+
}
|
|
262
|
+
name += ")";
|
|
263
|
+
}
|
|
264
|
+
|
|
265
|
+
bool operator<(const ggml_backend_meta_buffer_type_context & other) const {
|
|
266
|
+
return simple_bufts < other.simple_bufts;
|
|
267
|
+
}
|
|
268
|
+
};
|
|
269
|
+
|
|
270
|
+
// Returns the number of simple buffer types wrapped by the meta buffer type `meta_buft`.
static size_t ggml_backend_meta_buft_n_bufts(ggml_backend_buffer_type_t meta_buft) {
    GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft));
    const auto & bufts = ((const ggml_backend_meta_buffer_type_context *) meta_buft->context)->simple_bufts;
    return bufts.size();
}
|
|
275
|
+
|
|
276
|
+
// get_name callback: returns the name stored in the meta buffer type's context.
// The returned pointer stays valid as long as the context is alive.
static const char * ggml_backend_meta_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    GGML_ASSERT(ggml_backend_buft_is_meta(buft));
    return ((const ggml_backend_meta_buffer_type_context *) buft->context)->name.c_str();
}
|
|
281
|
+
|
|
282
|
+
// Returns the simple buffer type at position `index` of the meta buffer type `meta_buft`.
// Aborts on a non-meta buffer type or an out-of-range index.
static ggml_backend_buffer_type_t ggml_backend_meta_buft_simple_buft(ggml_backend_buffer_type_t meta_buft, size_t index) {
    GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft));
    const auto & bufts = ((const ggml_backend_meta_buffer_type_context *) meta_buft->context)->simple_bufts;
    GGML_ASSERT(index < bufts.size());
    return bufts[index];
}
|
|
288
|
+
|
|
289
|
+
// Declared ahead of its definition so that ggml_backend_meta_buffer_type_iface (below) can reference it.
static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
|
|
290
|
+
|
|
291
|
+
// get_alignment callback: the meta buffer type uses the largest alignment of its simple
// buffer types (1 if there are none).
// Every simple alignment must be non-zero and must divide that maximum; otherwise a single
// alignment cannot satisfy all simple buffer types, and the function aborts.
static size_t ggml_backend_meta_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);

    size_t max_alignment = 1;
    for (size_t i = 0; i < n_simple_bufts; i++) {
        const size_t alignment = ggml_backend_buft_get_alignment(ggml_backend_meta_buft_simple_buft(buft, i));
        max_alignment = std::max(max_alignment, alignment);
    }

    // Check against the final maximum. Checking inside the first loop would only compare
    // against the running maximum and miss earlier alignments that do not divide a later,
    // larger one. Also rejects 0, which would make the modulo undefined.
    for (size_t i = 0; i < n_simple_bufts; i++) {
        const size_t alignment = ggml_backend_buft_get_alignment(ggml_backend_meta_buft_simple_buft(buft, i));
        GGML_ASSERT(alignment != 0 && max_alignment % alignment == 0);
    }
    return max_alignment;
}
|
|
301
|
+
|
|
302
|
+
// get_max_size callback: a meta buffer can be no larger than its most restrictive simple
// buffer type allows. Returns SIZE_MAX when there are no simple buffer types.
static size_t ggml_backend_meta_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
    size_t limit = SIZE_MAX;
    for (size_t i = 0, n = ggml_backend_meta_buft_n_bufts(buft); i < n; i++) {
        const size_t simple_limit = ggml_backend_buft_get_max_size(ggml_backend_meta_buft_simple_buft(buft, i));
        if (simple_limit < limit) {
            limit = simple_limit;
        }
    }
    return limit;
}
|
|
310
|
+
|
|
311
|
+
// get_alloc_size callback: reports the largest allocation size any simple buffer type
// needs for `tensor` (0 if there are no simple buffer types).
static size_t ggml_backend_meta_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    size_t largest = 0;
    for (size_t i = 0, n = ggml_backend_meta_buft_n_bufts(buft); i < n; i++) {
        const size_t simple_size = ggml_backend_buft_get_alloc_size(ggml_backend_meta_buft_simple_buft(buft, i), tensor);
        if (simple_size > largest) {
            largest = simple_size;
        }
    }
    return largest;
}
|
|
320
|
+
|
|
321
|
+
// is_host callback: the meta buffer type counts as host memory only if every simple buffer
// type does. Stops at the first simple buffer type that is not host.
static bool ggml_backend_meta_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
    const size_t n = ggml_backend_meta_buft_n_bufts(buft);
    size_t i = 0;
    while (i < n && ggml_backend_buft_is_host(ggml_backend_meta_buft_simple_buft(buft, i))) {
        i++;
    }
    return i == n;
}
|
|
330
|
+
|
|
331
|
+
// Buffer type interface shared by all meta buffer types. Its get_name pointer is also what
// ggml_backend_buft_is_meta() compares against to recognize meta buffer types.
static const struct ggml_backend_buffer_type_i ggml_backend_meta_buffer_type_iface = {
    ggml_backend_meta_buffer_type_get_name,       // get_name
    ggml_backend_meta_buffer_type_alloc_buffer,   // alloc_buffer
    ggml_backend_meta_buffer_type_get_alignment,  // get_alignment
    ggml_backend_meta_buffer_type_get_max_size,   // get_max_size
    ggml_backend_meta_buffer_type_get_alloc_size, // get_alloc_size
    ggml_backend_meta_buffer_type_is_host,        // is_host
};
|
|
339
|
+
|
|
340
|
+
// Recognizes meta buffer types by interface identity: a non-null buffer type is "meta"
// when its get_name callback is the one installed by ggml_backend_meta_buffer_type_iface.
bool ggml_backend_buft_is_meta(ggml_backend_buffer_type_t buft) {
    if (buft == nullptr) {
        return false;
    }
    return buft->iface.get_name == ggml_backend_meta_buffer_type_iface.get_name;
}
|
|
343
|
+
|
|
344
|
+
// get_buffer_type callback: returns the meta buffer type for meta device `dev`.
// The buffer type bundles the default buffer type of each simple device, in device order.
// Results are cached per device in a function-local static map, so repeated calls return
// the same pointer. The context allocated with `new` is not freed here.
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev) {
    static std::map<ggml_backend_dev_t, struct ggml_backend_buffer_type> meta_bufts;
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));

    const auto cached = meta_bufts.find(dev);
    if (cached != meta_bufts.end()) {
        return &cached->second;
    }

    const size_t n_devs = ggml_backend_meta_dev_n_devs(dev);
    std::vector<ggml_backend_buffer_type_t> simple_bufts;
    simple_bufts.reserve(n_devs);
    for (size_t i = 0; i < n_devs; i++) {
        ggml_backend_dev_t simple_dev = ggml_backend_meta_dev_simple_dev(dev, i);
        simple_bufts.push_back(ggml_backend_dev_buffer_type(simple_dev));
    }

    struct ggml_backend_buffer_type meta_buft = {
        /*iface   =*/ ggml_backend_meta_buffer_type_iface,
        /*device  =*/ dev,
        /*context =*/ new ggml_backend_meta_buffer_type_context(simple_bufts),
    };
    return &meta_bufts.emplace(dev, meta_buft).first->second;
}
|
|
370
|
+
|
|
371
|
+
// get_host_buffer_type callback: a meta device can only expose a host buffer type if all of
// its simple devices report the very same one. Returns nullptr when any simple device has
// none, when two simple devices disagree, or when there are no simple devices.
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev) {
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
    const ggml_backend_meta_device_context * ctx = (const ggml_backend_meta_device_context *) dev->context;

    ggml_backend_buffer_type_t shared = nullptr;
    for (size_t i = 0; i < ctx->simple_devs.size(); i++) {
        ggml_backend_buffer_type_t candidate = ggml_backend_dev_host_buffer_type(ctx->simple_devs[i]);
        const bool missing    = candidate == nullptr;
        const bool mismatched = shared != nullptr && candidate != shared;
        if (missing || mismatched) {
            return nullptr;
        }
        shared = candidate;
    }
    return shared;
}
|
|
391
|
+
|
|
392
|
+
//
|
|
393
|
+
// meta backend buffer
|
|
394
|
+
//
|
|
395
|
+
|
|
396
|
+
// Container to hold the tensor slices per simple ggml backend buffer.
|
|
397
|
+
struct ggml_backend_meta_simple_tensor_container {
|
|
398
|
+
std::vector<ggml_context_ptr> ctxs;
|
|
399
|
+
std::map<const ggml_tensor *, std::vector<ggml_tensor *>> simple_tensors;
|
|
400
|
+
|
|
401
|
+
ggml_backend_meta_simple_tensor_container(const ggml_init_params & params, const int n_simple) {
|
|
402
|
+
ctxs.reserve(n_simple);
|
|
403
|
+
for (int i = 0; i < n_simple; i++) {
|
|
404
|
+
ctxs.emplace_back(ggml_init(params));
|
|
405
|
+
}
|
|
406
|
+
}
|
|
407
|
+
ggml_backend_meta_simple_tensor_container() {}
|
|
408
|
+
};
|
|
409
|
+
|
|
410
|
+
struct ggml_backend_meta_buffer_context {
|
|
411
|
+
// FIXME
|
|
412
|
+
// Most tensors can simply be stored statically in their own buffer.
|
|
413
|
+
// Externally created views however also need a mapping to simple tensors but they use the buffer of the view source.
|
|
414
|
+
// If external views are simply using that buffer they will slowly deplete its memory.
|
|
415
|
+
// Current solution: rotating set of 2 "compute" containers to hold external views, works correctly for llama.cpp.
|
|
416
|
+
// Long-term: tie the lifetime of external views to the meta backend executing the graph instead,
|
|
417
|
+
// currently not possible due to graph-external operations in the backend scheduler.
|
|
418
|
+
ggml_backend_meta_simple_tensor_container stc_static;
|
|
419
|
+
ggml_backend_meta_simple_tensor_container stc_compute[2];
|
|
420
|
+
int stc_compute_index = 0;
|
|
421
|
+
int stc_compute_index_next = 0;
|
|
422
|
+
std::vector<ggml_backend_buffer_ptr> bufs;
|
|
423
|
+
|
|
424
|
+
// FIXME
|
|
425
|
+
// The size of the split state cache is unbounded and can theoretically grow infinitely large.
|
|
426
|
+
// However, it is also expensive to build and clearing it on every rebuild in ggml_backend_meta_graph_compute is too expensive.
|
|
427
|
+
static constexpr size_t nbtc = GGML_TENSOR_SIZE - sizeof(ggml_tensor::padding);
|
|
428
|
+
std::map<std::pair<const ggml_tensor *, bool>, std::pair<ggml_backend_meta_split_state, char[nbtc]>> split_state_cache;
|
|
429
|
+
|
|
430
|
+
int debug;
|
|
431
|
+
|
|
432
|
+
ggml_backend_meta_buffer_context(
|
|
433
|
+
ggml_backend_meta_simple_tensor_container & stc_static,
|
|
434
|
+
ggml_backend_meta_simple_tensor_container & stc_compute_0,
|
|
435
|
+
ggml_backend_meta_simple_tensor_container & stc_compute_1,
|
|
436
|
+
const std::vector<ggml_backend_buffer_t> & bufs)
|
|
437
|
+
: stc_static(std::move(stc_static)), stc_compute{std::move(stc_compute_0), std::move(stc_compute_1)} {
|
|
438
|
+
this->bufs.reserve(bufs.size());
|
|
439
|
+
for (ggml_backend_buffer_t buf : bufs) {
|
|
440
|
+
this->bufs.emplace_back(buf);
|
|
441
|
+
}
|
|
442
|
+
const char * GGML_META_DEBUG = getenv("GGML_META_DEBUG");
|
|
443
|
+
debug = GGML_META_DEBUG ? atoi(GGML_META_DEBUG) : 0;
|
|
444
|
+
}
|
|
445
|
+
|
|
446
|
+
ggml_backend_meta_simple_tensor_container & get_simple_tensor_container(const ggml_tensor * tensor) {
|
|
447
|
+
if (stc_static.simple_tensors.find(tensor) != stc_static.simple_tensors.end()) {
|
|
448
|
+
return stc_static;
|
|
449
|
+
}
|
|
450
|
+
return stc_compute[stc_compute_index];
|
|
451
|
+
}
|
|
452
|
+
};
|
|
453
|
+
|
|
454
|
+
// free_buffer callback: destroys the meta buffer context. Its members (the wrapped simple
// buffer handles and tensor containers) are released by their own destructors.
static void ggml_backend_meta_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    GGML_ASSERT(ggml_backend_buffer_is_meta(buffer));
    delete (ggml_backend_meta_buffer_context *) buffer->context;
}
|
|
459
|
+
|
|
460
|
+
// Returns how many simple buffers back the meta buffer `meta_buf`.
static size_t ggml_backend_meta_buffer_n_bufs(ggml_backend_buffer_t meta_buf) {
    GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf));
    const auto & simple_bufs = ((ggml_backend_meta_buffer_context *) meta_buf->context)->bufs;
    return simple_bufs.size();
}
|
|
465
|
+
|
|
466
|
+
// Returns the simple buffer at position `index` of the meta buffer `meta_buf` (non-owning).
// Aborts on a non-meta buffer or an out-of-range index.
static ggml_backend_buffer_t ggml_backend_meta_buffer_simple_buffer(ggml_backend_buffer_t meta_buf, size_t index) {
    GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf));
    const auto & simple_bufs = ((ggml_backend_meta_buffer_context *) meta_buf->context)->bufs;
    GGML_ASSERT(index < simple_bufs.size());
    return simple_bufs[index].get();
}
|
|
472
|
+
|
|
473
|
+
// Looks up the slice of meta tensor `tensor` that lives in simple buffer `index`.
// Returns nullptr if the tensor has no entry in the container it resolves to
// (the static container if registered there, otherwise the active compute container).
static struct ggml_tensor * ggml_backend_meta_buffer_simple_tensor(const struct ggml_tensor * tensor, size_t index) {
    GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer));
    ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
    GGML_ASSERT(index < buf_ctx->bufs.size());

    const auto & mapping = buf_ctx->get_simple_tensor_container(tensor).simple_tensors;
    const auto found = mapping.find(tensor);
    return found == mapping.end() ? nullptr : found->second[index];
}
|
|
485
|
+
|
|
486
|
+
// Two-argument overload, declared ahead of its definition. It is distinct from the overload
// that follows, which additionally takes a simple-tensor container as its first parameter.
static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync);
|
|
487
|
+
|
|
488
|
+
static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(
|
|
489
|
+
ggml_backend_meta_simple_tensor_container & stc, const struct ggml_tensor * tensor, bool assume_sync) {
|
|
490
|
+
// FIXME Currently this function preserves/erases the information in n_segments and nr in an inconsistent way.
|
|
491
|
+
// Since the operations in question are developed specifically for llama.cpp this currently does not manifest as a bug there.
|
|
492
|
+
// However, in a broader ggml context with arbitrary ggml graphs this can lead to unexpected results.
|
|
493
|
+
const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer);
|
|
494
|
+
ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
|
|
495
|
+
|
|
496
|
+
auto split_states_equal = [&](const ggml_backend_meta_split_state & a, const ggml_backend_meta_split_state & b) -> bool {
|
|
497
|
+
if (a.axis != b.axis) {
|
|
498
|
+
return false;
|
|
499
|
+
}
|
|
500
|
+
for (size_t j = 0; j < n_bufs; j++) {
|
|
501
|
+
int64_t sum_a = 0;
|
|
502
|
+
for (size_t s = 0; s < a.n_segments; s++) {
|
|
503
|
+
sum_a += a.ne[s*n_bufs + j] * a.nr[s];
|
|
504
|
+
}
|
|
505
|
+
int64_t sum_b = 0;
|
|
506
|
+
for (size_t s = 0; s < b.n_segments; s++) {
|
|
507
|
+
sum_b += b.ne[s*n_bufs + j] * b.nr[s];
|
|
508
|
+
}
|
|
509
|
+
if (sum_a != sum_b) {
|
|
510
|
+
return false;
|
|
511
|
+
}
|
|
512
|
+
}
|
|
513
|
+
return true;
|
|
514
|
+
};
|
|
515
|
+
|
|
516
|
+
auto handle_generic = [&](const std::vector<ggml_backend_meta_split_state> & src_ss, bool scalar_only) -> ggml_backend_meta_split_state {
|
|
517
|
+
ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1};
|
|
518
|
+
for (size_t i = 0; i < GGML_MAX_SRC; i++) {
|
|
519
|
+
if (tensor->src[i] == nullptr || tensor->src[i] == tensor) {
|
|
520
|
+
continue;
|
|
521
|
+
}
|
|
522
|
+
if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) {
|
|
523
|
+
ret = src_ss[i];
|
|
524
|
+
} else if (!split_states_equal(src_ss[i], ret)) {
|
|
525
|
+
ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
|
|
526
|
+
break;
|
|
527
|
+
}
|
|
528
|
+
}
|
|
529
|
+
if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) {
|
|
530
|
+
ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
|
|
531
|
+
}
|
|
532
|
+
if (scalar_only && ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) {
|
|
533
|
+
ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
|
|
534
|
+
}
|
|
535
|
+
GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN);
|
|
536
|
+
return ret;
|
|
537
|
+
};
|
|
538
|
+
|
|
539
|
+
// Some ops process data on a per-row bases:
|
|
540
|
+
auto handle_per_row = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
|
|
541
|
+
GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_0);
|
|
542
|
+
return src_ss[0];
|
|
543
|
+
};
|
|
544
|
+
|
|
545
|
+
// Some ops broadcast the src1 data across src0:
|
|
546
|
+
auto handle_bin_bcast = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
|
|
547
|
+
if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS &&
|
|
548
|
+
tensor->src[1]->ne[src_ss[0].axis] == 1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
|
|
549
|
+
return src_ss[0];
|
|
550
|
+
}
|
|
551
|
+
if (src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[0].axis == src_ss[1].axis ||
|
|
552
|
+
(src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL)))) {
|
|
553
|
+
return src_ss[0]; // GGML_OP_ADD_ID
|
|
554
|
+
}
|
|
555
|
+
GGML_ASSERT(tensor->src[2] == nullptr || src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
|
|
556
|
+
return handle_generic(src_ss, /*scalar_only =*/ false);
|
|
557
|
+
};
|
|
558
|
+
|
|
559
|
+
auto handle_concat = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
|
|
560
|
+
const ggml_backend_meta_split_axis concat_axis = ggml_backend_meta_split_axis(ggml_get_op_params_i32(tensor, 0));
|
|
561
|
+
if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis >= 0 && src_ss[1].axis < GGML_MAX_DIMS) {
|
|
562
|
+
GGML_ASSERT(concat_axis != src_ss[1].axis);
|
|
563
|
+
return src_ss[1];
|
|
564
|
+
}
|
|
565
|
+
if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) {
|
|
566
|
+
GGML_ASSERT(concat_axis != src_ss[0].axis);
|
|
567
|
+
return src_ss[0];
|
|
568
|
+
}
|
|
569
|
+
if (src_ss[0].axis == src_ss[1].axis && src_ss[0].axis != concat_axis) {
|
|
570
|
+
return src_ss[0];
|
|
571
|
+
}
|
|
572
|
+
return handle_generic(src_ss, /*scalar_only =*/ true);
|
|
573
|
+
};
|
|
574
|
+
|
|
575
|
+
auto handle_mul_mat = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
|
|
576
|
+
if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
|
|
577
|
+
return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1};
|
|
578
|
+
}
|
|
579
|
+
if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
|
|
580
|
+
ggml_backend_meta_split_state ret = src_ss[0];
|
|
581
|
+
ret.axis = GGML_BACKEND_SPLIT_AXIS_0;
|
|
582
|
+
ret.nr[0] = 1;
|
|
583
|
+
ret.n_segments = 1;
|
|
584
|
+
return ret;
|
|
585
|
+
}
|
|
586
|
+
if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
|
|
587
|
+
return src_ss[1];
|
|
588
|
+
}
|
|
589
|
+
if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_0) {
|
|
590
|
+
GGML_ASSERT(split_states_equal(src_ss[0], src_ss[1]));
|
|
591
|
+
return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}, {1}, 1};
|
|
592
|
+
}
|
|
593
|
+
GGML_ABORT("fatal error");
|
|
594
|
+
//return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
|
|
595
|
+
};
|
|
596
|
+
|
|
597
|
+
auto handle_reshape = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
|
|
598
|
+
switch (src_ss[0].axis) {
|
|
599
|
+
case GGML_BACKEND_SPLIT_AXIS_0:
|
|
600
|
+
case GGML_BACKEND_SPLIT_AXIS_1:
|
|
601
|
+
case GGML_BACKEND_SPLIT_AXIS_2:
|
|
602
|
+
case GGML_BACKEND_SPLIT_AXIS_3: {
|
|
603
|
+
GGML_ASSERT(src_ss[0].n_segments == 1);
|
|
604
|
+
if (src_ss[0].axis == ggml_n_dims(tensor->src[0]) - 1 && src_ss[0].nr[0] == 1) {
|
|
605
|
+
return {ggml_backend_meta_split_axis(ggml_n_dims(tensor) - 1), {0}, {1}, 1};
|
|
606
|
+
}
|
|
607
|
+
int64_t base_ne_in = tensor->src[0]->ne[0];
|
|
608
|
+
for (int dim = 1; dim <= src_ss[0].axis; dim++) {
|
|
609
|
+
base_ne_in *= tensor->src[0]->ne[dim];
|
|
610
|
+
}
|
|
611
|
+
base_ne_in /= src_ss[0].nr[0];
|
|
612
|
+
int64_t base_ne_out = 1;
|
|
613
|
+
for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
|
|
614
|
+
const int64_t base_ne_out_next = base_ne_out *= tensor->ne[dim];
|
|
615
|
+
if (base_ne_out_next % base_ne_in == 0) {
|
|
616
|
+
return {ggml_backend_meta_split_axis(dim), {0}, {uint32_t(base_ne_out_next/base_ne_in)}, 1};
|
|
617
|
+
}
|
|
618
|
+
if (base_ne_out_next > base_ne_in) {
|
|
619
|
+
GGML_ASSERT(src_ss[0].n_segments == 1);
|
|
620
|
+
GGML_ASSERT(src_ss[0].nr[0] == 1);
|
|
621
|
+
return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1};
|
|
622
|
+
}
|
|
623
|
+
base_ne_out = base_ne_out_next;
|
|
624
|
+
}
|
|
625
|
+
GGML_ABORT("shape mismatch for %s", ggml_op_name(tensor->op));
|
|
626
|
+
}
|
|
627
|
+
case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
|
|
628
|
+
case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
|
|
629
|
+
return src_ss[0];
|
|
630
|
+
}
|
|
631
|
+
default: {
|
|
632
|
+
GGML_ABORT("fatal error");
|
|
633
|
+
//return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
|
|
634
|
+
}
|
|
635
|
+
}
|
|
636
|
+
};
|
|
637
|
+
|
|
638
|
+
auto handle_cpy = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
|
|
639
|
+
if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) {
|
|
640
|
+
return handle_reshape(src_ss);
|
|
641
|
+
}
|
|
642
|
+
return handle_generic(src_ss, /*scalar_only =*/ false);
|
|
643
|
+
};
|
|
644
|
+
|
|
645
|
+
auto handle_view = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
|
|
646
|
+
if (ggml_is_contiguous(tensor) && ggml_is_contiguous(tensor->src[0])) {
|
|
647
|
+
return handle_reshape(src_ss);
|
|
648
|
+
}
|
|
649
|
+
const int axis = src_ss[0].axis;
|
|
650
|
+
{
|
|
651
|
+
bool all_strides_the_same = true;
|
|
652
|
+
for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
|
|
653
|
+
if (tensor->ne[dim] == 1 && tensor->src[0]->ne[dim] == 1) {
|
|
654
|
+
continue;
|
|
655
|
+
}
|
|
656
|
+
if (tensor->nb[dim] != tensor->src[0]->nb[dim]) {
|
|
657
|
+
all_strides_the_same = false;
|
|
658
|
+
break;
|
|
659
|
+
}
|
|
660
|
+
}
|
|
661
|
+
if (all_strides_the_same) {
|
|
662
|
+
return src_ss[0];
|
|
663
|
+
}
|
|
664
|
+
}
|
|
665
|
+
if (!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0]) && axis >= 0 && axis < GGML_MAX_DIMS-1) {
|
|
666
|
+
for (int dim = 0; dim < GGML_MAX_DIMS-1; dim++) {
|
|
667
|
+
if (tensor->nb[dim+1] == tensor->src[0]->nb[axis+1]) {
|
|
668
|
+
return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1};
|
|
669
|
+
}
|
|
670
|
+
}
|
|
671
|
+
GGML_ABORT("fatal error");
|
|
672
|
+
}
|
|
673
|
+
if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED || src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) {
|
|
674
|
+
return src_ss[0];
|
|
675
|
+
}
|
|
676
|
+
GGML_ABORT("view of permuted tensor not implemented");
|
|
677
|
+
//return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
|
|
678
|
+
};
|
|
679
|
+
|
|
680
|
+
auto handle_permute = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
|
|
681
|
+
switch (src_ss[0].axis) {
|
|
682
|
+
case GGML_BACKEND_SPLIT_AXIS_0:
|
|
683
|
+
case GGML_BACKEND_SPLIT_AXIS_1:
|
|
684
|
+
case GGML_BACKEND_SPLIT_AXIS_2:
|
|
685
|
+
case GGML_BACKEND_SPLIT_AXIS_3: {
|
|
686
|
+
GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1);
|
|
687
|
+
return {ggml_backend_meta_split_axis(tensor->op_params[src_ss[0].axis]), {0}, {src_ss[0].nr[0]}, 1};
|
|
688
|
+
}
|
|
689
|
+
case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
|
|
690
|
+
case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
|
|
691
|
+
return src_ss[0];
|
|
692
|
+
}
|
|
693
|
+
default: {
|
|
694
|
+
GGML_ABORT("fatal error");
|
|
695
|
+
//return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
|
|
696
|
+
}
|
|
697
|
+
}
|
|
698
|
+
};
|
|
699
|
+
|
|
700
|
+
auto handle_transpose = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
|
|
701
|
+
switch (src_ss[0].axis) {
|
|
702
|
+
case GGML_BACKEND_SPLIT_AXIS_0:
|
|
703
|
+
case GGML_BACKEND_SPLIT_AXIS_1: {
|
|
704
|
+
GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1);
|
|
705
|
+
return {ggml_backend_meta_split_axis(int(src_ss[0].axis) ^ 1), {0}, {src_ss[0].nr[0]}, 1};
|
|
706
|
+
}
|
|
707
|
+
case GGML_BACKEND_SPLIT_AXIS_2:
|
|
708
|
+
case GGML_BACKEND_SPLIT_AXIS_3:
|
|
709
|
+
case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
|
|
710
|
+
case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
|
|
711
|
+
return src_ss[0];
|
|
712
|
+
}
|
|
713
|
+
default: {
|
|
714
|
+
GGML_ABORT("fatal error");
|
|
715
|
+
//return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
|
|
716
|
+
}
|
|
717
|
+
}
|
|
718
|
+
};
|
|
719
|
+
|
|
720
|
+
auto handle_get_rows = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
|
|
721
|
+
if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
|
|
722
|
+
return src_ss[0];
|
|
723
|
+
}
|
|
724
|
+
return handle_generic(src_ss, /*scalar_only =*/ true);
|
|
725
|
+
};
|
|
726
|
+
|
|
727
|
+
auto handle_set_rows = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
|
|
728
|
+
GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_1);
|
|
729
|
+
GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
|
|
730
|
+
GGML_ASSERT(split_states_equal(src_ss[0], src_ss[2]));
|
|
731
|
+
return src_ss[0];
|
|
732
|
+
};
|
|
733
|
+
|
|
734
|
+
auto handle_rope = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
|
|
735
|
+
GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
|
|
736
|
+
return src_ss[0];
|
|
737
|
+
};
|
|
738
|
+
|
|
739
|
+
auto handle_pad = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
|
|
740
|
+
if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) {
|
|
741
|
+
GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 0] == 0);
|
|
742
|
+
GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 1] == 0);
|
|
743
|
+
}
|
|
744
|
+
return src_ss[0];
|
|
745
|
+
};
|
|
746
|
+
|
|
747
|
+
auto handle_flash_attn_ext = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
|
|
748
|
+
GGML_ASSERT( src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_2);
|
|
749
|
+
GGML_ASSERT( src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_2);
|
|
750
|
+
GGML_ASSERT( src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_2);
|
|
751
|
+
GGML_ASSERT(tensor->src[4] == nullptr || src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
|
|
752
|
+
GGML_ASSERT(tensor->src[4] == nullptr || src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_0);
|
|
753
|
+
return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1};
|
|
754
|
+
};
|
|
755
|
+
|
|
756
|
+
auto handle_ssm_conv = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
|
|
757
|
+
if (src_ss[0].axis == src_ss[1].axis) {
|
|
758
|
+
if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0) {
|
|
759
|
+
return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1};
|
|
760
|
+
}
|
|
761
|
+
if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1) {
|
|
762
|
+
return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1};
|
|
763
|
+
}
|
|
764
|
+
}
|
|
765
|
+
return handle_generic(src_ss, /*scalar_only =*/ false);
|
|
766
|
+
};
|
|
767
|
+
|
|
768
|
+
auto handle_gated_delta_net = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
|
|
769
|
+
if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED &&
|
|
770
|
+
src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED &&
|
|
771
|
+
src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
|
|
772
|
+
return src_ss[0];
|
|
773
|
+
}
|
|
774
|
+
GGML_ASSERT(src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1);
|
|
775
|
+
GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1);
|
|
776
|
+
GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1);
|
|
777
|
+
GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1);
|
|
778
|
+
GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1);
|
|
779
|
+
// state shape is [S_v, S_v, H_v, n_seqs] (s0 only); the heads dim is its own axis 2,
|
|
780
|
+
// so a head-aligned split on the input cache lands on axis 2 here.
|
|
781
|
+
GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0);
|
|
782
|
+
return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1};
|
|
783
|
+
};
|
|
784
|
+
|
|
785
|
+
// Computes (without caching) the split state of `tensor` across the n_bufs simple buffers.
//
// - Empty tensors get AXIS_UNKNOWN.
// - Non-view tensors in non-compute buffers take their split state from the device context's
//   get_split_state callback, which is validated against tensor->ne.
// - All other tensors derive their state from their sources' states via the per-op handle_* rules.
//   If the result is split along a real dimension, the per-buffer slice sizes are rescaled from
//   the first source that is also split along a real dimension; later such sources must agree.
auto calculate_split_state = [&]() -> ggml_backend_meta_split_state {
    if (ggml_nelements(tensor) == 0) {
        return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
    }
    if (ggml_backend_buffer_get_usage(tensor->buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE && tensor->view_src == nullptr) {
        ggml_backend_dev_t dev = ggml_backend_buft_get_device(ggml_backend_buffer_get_type(tensor->buffer));
        const ggml_backend_meta_device_context * dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
        ggml_backend_meta_split_state ret = dev_ctx->get_split_state(tensor, dev_ctx->get_split_state_ud);
        // Only validate splits along a real tensor dimension. The upper bound must be exclusive:
        // tensor->ne has GGML_MAX_DIMS entries, so axis == GGML_MAX_DIMS must never be used as an index
        // (this check previously used `<=`, unlike every other range check in this file).
        if (ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) {
            // Splits along dim 0 must respect the quantization block size.
            const int64_t granularity = ret.axis == GGML_BACKEND_SPLIT_AXIS_0 ? ggml_blck_size(tensor->type) : 1;
            int64_t ne_sum = 0;
            for (size_t s = 0; s < ret.n_segments; s++) {
                for (size_t j = 0; j < n_bufs; j++) {
                    GGML_ASSERT(ret.ne[s*n_bufs + j] % granularity == 0);
                    ne_sum += ret.ne[s*n_bufs + j] * ret.nr[s];
                }
            }
            // The slices over all segments, repetitions and buffers must cover the split dimension exactly.
            GGML_ASSERT(ne_sum == tensor->ne[ret.axis]);
        }
        return ret;
    }

    // Collect the split states of all sources; absent or self-referencing sources are UNKNOWN.
    std::vector<ggml_backend_meta_split_state> src_ss(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1});
    for (size_t i = 0; i < GGML_MAX_SRC; i++) {
        if (tensor->src[i] == nullptr || tensor->src[i] == tensor) {
            src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
            continue;
        }
        src_ss[i] = ggml_backend_meta_get_split_state(stc, tensor->src[i], /*assume_sync =*/ true);
        GGML_ASSERT(src_ss[i].axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN);
    }

    ggml_backend_meta_split_state split_state;
    switch (tensor->op) {
        case GGML_OP_NONE: {
            split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1};
        } break;
        case GGML_OP_DUP: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ true);
        } break;
        case GGML_OP_ADD:
        case GGML_OP_ADD_ID: {
            split_state = handle_bin_bcast(src_ss);
        } break;
        case GGML_OP_ADD1:
        case GGML_OP_ACC: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ true);
        } break;
        case GGML_OP_SUB:
        case GGML_OP_MUL:
        case GGML_OP_DIV: {
            split_state = handle_bin_bcast(src_ss);
        } break;
        case GGML_OP_SQR:
        case GGML_OP_SQRT:
        case GGML_OP_LOG:
        case GGML_OP_SIN:
        case GGML_OP_COS: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ false);
        } break;
        case GGML_OP_SUM: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ true);
        } break;
        case GGML_OP_SUM_ROWS:
        case GGML_OP_CUMSUM:
        case GGML_OP_MEAN:
        case GGML_OP_ARGMAX:
        case GGML_OP_COUNT_EQUAL: {
            split_state = handle_per_row(src_ss);
        } break;
        case GGML_OP_REPEAT:
        case GGML_OP_REPEAT_BACK: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ false);
        } break;
        case GGML_OP_CONCAT: {
            split_state = handle_concat(src_ss);
        } break;
        case GGML_OP_SILU_BACK: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ false);
        } break;
        case GGML_OP_NORM:
        case GGML_OP_RMS_NORM:
        case GGML_OP_RMS_NORM_BACK:
        case GGML_OP_GROUP_NORM:
        case GGML_OP_L2_NORM: {
            split_state = handle_per_row(src_ss);
        } break;
        case GGML_OP_MUL_MAT:
        case GGML_OP_MUL_MAT_ID: {
            split_state = handle_mul_mat(src_ss);
        } break;
        case GGML_OP_OUT_PROD: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ true);
        } break;
        case GGML_OP_SCALE: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ false);
        } break;
        case GGML_OP_SET: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ true);
        } break;
        case GGML_OP_CPY: {
            split_state = handle_cpy(src_ss);
        } break;
        case GGML_OP_CONT:
        case GGML_OP_RESHAPE: {
            split_state = handle_reshape(src_ss);
        } break;
        case GGML_OP_VIEW: {
            split_state = handle_view(src_ss);
        } break;
        case GGML_OP_PERMUTE: {
            split_state = handle_permute(src_ss);
        } break;
        case GGML_OP_TRANSPOSE: {
            split_state = handle_transpose(src_ss);
        } break;
        case GGML_OP_GET_ROWS: {
            split_state = handle_get_rows(src_ss);
        } break;
        case GGML_OP_GET_ROWS_BACK: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ true);
        } break;
        case GGML_OP_SET_ROWS: {
            split_state = handle_set_rows(src_ss);
        } break;
        case GGML_OP_DIAG:
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_DIAG_MASK_ZERO: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ true);
        } break;
        case GGML_OP_SOFT_MAX:
        case GGML_OP_SOFT_MAX_BACK: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ false);
        } break;
        case GGML_OP_ROPE: {
            split_state = handle_rope(src_ss);
        } break;
        case GGML_OP_ROPE_BACK: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ true);
        } break;
        case GGML_OP_CLAMP: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ false);
        } break;
        case GGML_OP_CONV_TRANSPOSE_1D:
        case GGML_OP_IM2COL:
        case GGML_OP_IM2COL_BACK:
        case GGML_OP_IM2COL_3D:
        case GGML_OP_CONV_2D:
        case GGML_OP_CONV_3D:
        case GGML_OP_CONV_2D_DW:
        case GGML_OP_CONV_TRANSPOSE_2D:
        case GGML_OP_POOL_1D:
        case GGML_OP_POOL_2D:
        case GGML_OP_POOL_2D_BACK:
        case GGML_OP_UPSCALE: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ true);
        } break;
        case GGML_OP_PAD: {
            split_state = handle_pad(src_ss);
        } break;
        case GGML_OP_PAD_REFLECT_1D:
        case GGML_OP_ROLL:
        case GGML_OP_ARANGE:
        case GGML_OP_TIMESTEP_EMBEDDING: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ true);
        } break;
        case GGML_OP_ARGSORT:
        case GGML_OP_TOP_K: {
            split_state = handle_per_row(src_ss);
        } break;
        case GGML_OP_LEAKY_RELU: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ false);
        } break;
        case GGML_OP_TRI: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ true);
        } break;
        case GGML_OP_FILL: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ false);
        } break;
        case GGML_OP_FLASH_ATTN_EXT: {
            split_state = handle_flash_attn_ext(src_ss);
        } break;
        case GGML_OP_FLASH_ATTN_BACK: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ true);
        } break;
        case GGML_OP_SSM_CONV: {
            split_state = handle_ssm_conv(src_ss);
        } break;
        case GGML_OP_SSM_SCAN:
        case GGML_OP_WIN_PART:
        case GGML_OP_WIN_UNPART:
        case GGML_OP_GET_REL_POS:
        case GGML_OP_ADD_REL_POS:
        case GGML_OP_RWKV_WKV6:
        case GGML_OP_GATED_LINEAR_ATTN:
        case GGML_OP_RWKV_WKV7:
        case GGML_OP_SOLVE_TRI: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ true);
        } break;
        case GGML_OP_GATED_DELTA_NET: {
            split_state = handle_gated_delta_net(src_ss);
        } break;
        case GGML_OP_UNARY: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ false);
        } break;
        case GGML_OP_MAP_CUSTOM1:
        case GGML_OP_MAP_CUSTOM2:
        case GGML_OP_MAP_CUSTOM3:
        case GGML_OP_CUSTOM: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ true);
        } break;
        case GGML_OP_CROSS_ENTROPY_LOSS:
        case GGML_OP_CROSS_ENTROPY_LOSS_BACK: {
            split_state = handle_per_row(src_ss);
        } break;
        case GGML_OP_OPT_STEP_ADAMW:
        case GGML_OP_OPT_STEP_SGD:
        case GGML_OP_GLU: {
            split_state = handle_generic(src_ss, /*scalar_only =*/ false);
        } break;
        default: {
            GGML_ABORT("ggml op not implemented: %s", ggml_op_name(tensor->op));
            split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
        } break;
    }
    // For a result split along a real dimension, derive the per-buffer slice sizes from the sources:
    // the first source split along a real dimension sets the ratio, and any later ones must match it.
    if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
        bool first_src_split_by_axis = true;
        const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer);

        for (size_t i = 0; i < GGML_MAX_SRC; i++) {
            if (tensor->src[i] == nullptr || src_ss[i].axis < 0 || src_ss[i].axis >= GGML_MAX_DIMS) {
                continue;
            }
            if (first_src_split_by_axis) {
                for (size_t j = 0; j < n_bufs; j++) {
                    // Take over ratio from src:
                    for (size_t s = 0; s < src_ss[i].n_segments; s++) {
                        split_state.ne[s*n_bufs + j] = 0;
                    }
                    for (size_t s = 0; s < src_ss[i].n_segments; s++) {
                        split_state.ne[j] += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s];
                    }
                    // Rescale this buffer's share from the source's split dimension to ours.
                    split_state.ne[j] *= tensor->ne[split_state.axis];
                    if (split_state.ne[j] != 0 || tensor->src[i]->ne[src_ss[i].axis] != 0) {
                        const int64_t div = tensor->src[i]->ne[src_ss[i].axis] * split_state.nr[0];
                        GGML_ASSERT(split_state.ne[j] % div == 0);
                        split_state.ne[j] /= div;
                    }
                }
            } else {
                GGML_ASSERT(split_state.n_segments == 1);
                for (size_t j = 0; j < n_bufs; j++) {
                    // Assert that ratio is consistent:
                    int64_t sum = 0;
                    for (size_t s = 0; s < src_ss[i].n_segments; s++) {
                        sum += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s];
                    }
                    GGML_ASSERT(split_state.ne[j]*split_state.nr[0] * tensor->src[i]->ne[src_ss[i].axis]
                        == sum * tensor->ne[split_state.axis]);
                }
            }
            first_src_split_by_axis = false;
        }
        // A split result requires at least one source split along a real dimension.
        GGML_ASSERT(!first_src_split_by_axis);
    }
    return split_state;
};
|
|
1052
|
+
|
|
1053
|
+
const std::pair key = std::make_pair(tensor, assume_sync);
|
|
1054
|
+
auto it = buf_ctx->split_state_cache.find(key);
|
|
1055
|
+
if (it != buf_ctx->split_state_cache.end() && memcmp(it->second.second, (const char *) tensor, sizeof(it->second.second)) != 0) {
|
|
1056
|
+
buf_ctx->split_state_cache.clear();
|
|
1057
|
+
it = buf_ctx->split_state_cache.end();
|
|
1058
|
+
}
|
|
1059
|
+
|
|
1060
|
+
if (it == buf_ctx->split_state_cache.end()) {
|
|
1061
|
+
buf_ctx->split_state_cache[key].first = calculate_split_state();
|
|
1062
|
+
memcpy(buf_ctx->split_state_cache[key].second, tensor, sizeof(buf_ctx->split_state_cache[key].second));
|
|
1063
|
+
if (buf_ctx->debug > 0) {
|
|
1064
|
+
std::string srcs_info;
|
|
1065
|
+
for (size_t i = 0; i < GGML_MAX_SRC; i++) {
|
|
1066
|
+
if (tensor->src[i] == nullptr) {
|
|
1067
|
+
continue;
|
|
1068
|
+
}
|
|
1069
|
+
if (!srcs_info.empty()) {
|
|
1070
|
+
srcs_info += ", ";
|
|
1071
|
+
}
|
|
1072
|
+
const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor->src[0], true);
|
|
1073
|
+
GGML_ASSERT(split_state.n_segments == 1);
|
|
1074
|
+
const char * axis_name = ggml_backend_meta_split_axis_name(split_state.axis);
|
|
1075
|
+
std::string ne_info;
|
|
1076
|
+
for (size_t j = 0; j < n_bufs; j++) {
|
|
1077
|
+
if (!ne_info.empty()) {
|
|
1078
|
+
ne_info += ", ";
|
|
1079
|
+
}
|
|
1080
|
+
ne_info += std::to_string(split_state.ne[j]) + "x" + std::to_string(split_state.nr[0]);
|
|
1081
|
+
}
|
|
1082
|
+
srcs_info += std::string(tensor->src[i]->name) + "[" + ggml_op_name(tensor->src[i]->op) + ", " + axis_name + ", {" + ne_info + "}]";
|
|
1083
|
+
}
|
|
1084
|
+
std::string ne_info;
|
|
1085
|
+
for (size_t j = 0; j < n_bufs; j++) {
|
|
1086
|
+
if (!ne_info.empty()) {
|
|
1087
|
+
ne_info += ", ";
|
|
1088
|
+
}
|
|
1089
|
+
const ggml_backend_meta_split_state & ss = buf_ctx->split_state_cache[key].first;
|
|
1090
|
+
ne_info += std::to_string(ss.ne[j]) + "x" + std::to_string(ss.nr[0]);
|
|
1091
|
+
}
|
|
1092
|
+
GGML_LOG_DEBUG("SPLIT_STATE: {%s} -> %s[%s, %s, {%s}]\n", srcs_info.c_str(), tensor->name, ggml_op_name(tensor->op),
|
|
1093
|
+
ggml_backend_meta_split_axis_name(buf_ctx->split_state_cache[key].first.axis), ne_info.c_str());
|
|
1094
|
+
}
|
|
1095
|
+
}
|
|
1096
|
+
|
|
1097
|
+
ggml_backend_meta_split_state ret = buf_ctx->split_state_cache[key].first;
|
|
1098
|
+
GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_NONE);
|
|
1099
|
+
#ifndef NDEBUG
|
|
1100
|
+
if (ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) {
|
|
1101
|
+
int64_t ne_ret = 0;
|
|
1102
|
+
for (size_t s = 0; s < ret.n_segments; s++) {
|
|
1103
|
+
for (size_t j = 0; j < n_bufs; j++) {
|
|
1104
|
+
ne_ret += ret.ne[s*n_bufs + j] * ret.nr[s];
|
|
1105
|
+
}
|
|
1106
|
+
}
|
|
1107
|
+
assert(ne_ret == tensor->ne[int(ret.axis)]);
|
|
1108
|
+
}
|
|
1109
|
+
#endif // NDEBUG
|
|
1110
|
+
return ret;
|
|
1111
|
+
}
|
|
1112
|
+
|
|
1113
|
+
static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) {
|
|
1114
|
+
GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer));
|
|
1115
|
+
ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
|
|
1116
|
+
return ggml_backend_meta_get_split_state(buf_ctx->get_simple_tensor_container(tensor), tensor, assume_sync);
|
|
1117
|
+
}
|
|
1118
|
+
|
|
1119
|
+
// Returns a fixed, arbitrary base address for a meta buffer.
// The address is only meaningful as an origin: tensor initialisation subtracts it from tensor->data
// to get the tensor's offset inside the meta buffer, then adds that offset to each simple buffer's base.
static void * ggml_backend_meta_buffer_get_base(ggml_backend_buffer_t buffer) {
    void * const placeholder_base = (void *) 0x1000000000000000; // FIXME
    GGML_UNUSED(buffer);
    return placeholder_base;
}
|
|
1123
|
+
|
|
1124
|
+
// Creates one "simple" tensor per underlying buffer for the meta tensor `tensor` and registers them
// in `stc.simple_tensors[tensor]`.
//
// Each simple tensor copies type, op, op_params, flags, name and extra from `tensor`. If the tensor is
// split along a real dimension, that dimension is shrunk to the buffer's share and outer strides are
// rescaled. View sources and meta sources are remapped to the matching simple tensor of the same buffer.
// Always returns GGML_STATUS_SUCCESS; invariant violations abort through GGML_ASSERT.
static enum ggml_status ggml_backend_meta_buffer_init_tensor_impl(ggml_backend_meta_simple_tensor_container & stc, ggml_tensor * tensor) {
    GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer));
    ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
    const size_t n_simple_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer);

    const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(stc, tensor, /*assume_sync =*/ true);
    GGML_ASSERT(ggml_nelements(tensor) == 0 || split_state.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN);
    GGML_ASSERT(split_state.n_segments <= 16);

    // Start from the full tensor's shape and strides; they are adjusted per buffer below if split.
    int split_dim = split_state.axis;
    int64_t ne[GGML_MAX_DIMS];
    size_t nb[GGML_MAX_DIMS];
    for (size_t k = 0; k < GGML_MAX_DIMS; k++) {
        ne[k] = tensor->ne[k];
        nb[k] = tensor->nb[k];
    }

    std::vector<ggml_tensor *> simple_tensors;
    simple_tensors.reserve(n_simple_bufs);
    for (size_t j = 0; j < n_simple_bufs; j++) {
        ggml_context * simple_ctx = stc.ctxs[j].get();
        ggml_backend_buffer_t simple_buf = buf_ctx->bufs[j].get();

        if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) {
            // TODO: the following assert fails for llama-parallel even though the results are correct:
            // GGML_ASSERT(ggml_is_contiguously_allocated(tensor));
            // This buffer's extent along the split dim: its slices over all segments and repetitions.
            ne[split_dim] = 0;
            for (size_t s = 0; s < split_state.n_segments; s++) {
                ne[split_dim] += split_state.ne[s*n_simple_bufs + j] * split_state.nr[s];
            }
            // Dimensions with a larger stride than the split dim shrink their stride in proportion
            // to this buffer's share of the split dimension.
            for (int i = 0; i < GGML_MAX_DIMS; i++) {
                if (tensor->nb[i] > tensor->nb[split_dim]) {
                    nb[i] = tensor->nb[i] * ne[split_dim]/tensor->ne[split_dim];
                }
            }
        }

        ggml_tensor * t_ij = ggml_new_tensor(simple_ctx, tensor->type, GGML_MAX_DIMS, ne);
        t_ij->op = tensor->op;
        for (int i = 0; i < GGML_MAX_DIMS; i++) {
            t_ij->nb[i] = nb[i];
        }
        t_ij->flags = tensor->flags;
        memcpy(t_ij->op_params, tensor->op_params, sizeof(tensor->op_params));
        ggml_set_name(t_ij, tensor->name);
        t_ij->buffer = simple_buf;
        t_ij->view_src = tensor->view_src;
        t_ij->view_offs = tensor->view_offs;
        // A view of a meta tensor becomes a view of that tensor's j-th simple tensor; the byte offset
        // may need rescaling because the simple tensor only holds a slice of the data.
        if (t_ij->view_src != nullptr && ggml_backend_buffer_is_meta(t_ij->view_src->buffer)) {
            t_ij->view_src = ggml_backend_meta_buffer_simple_tensor(tensor->view_src, j);
            if (t_ij->view_offs > 0 && split_dim >= 0 && split_dim < GGML_MAX_DIMS) {
                GGML_ASSERT(tensor->ne[split_dim] != 0);
                const int split_dim_view_src = ggml_backend_meta_get_split_state(tensor->view_src, /*assume_sync =*/ true).axis;
                GGML_ASSERT(split_dim_view_src >= 0 && split_dim_view_src < GGML_MAX_DIMS);

                // The offset can be internal to the data split, in those cases the view offset should not be scaled.
                // If however, the offset is larger than the data split then it needs to be scaled proportionally.
                bool split_internal_offset = t_ij->view_offs <= tensor->view_src->nb[split_dim_view_src];
                for (int i = 0; i < GGML_MAX_DIMS; i++) {
                    const size_t dim_size = tensor->ne[i] * tensor->nb[i];
                    if (tensor->view_offs <= dim_size && dim_size < tensor->nb[split_dim]) {
                        split_internal_offset = true;
                        break;
                    }
                }
                if (!split_internal_offset) {
                    t_ij->view_offs = t_ij->view_offs * ne[split_dim]/tensor->ne[split_dim];
                }
            }
        }
        // Data pointer: views point into their (possibly remapped) source. Otherwise the tensor's offset
        // relative to the meta buffer's base is re-applied to the simple buffer's base.
        if (t_ij->view_src != nullptr) {
            t_ij->data = (char *) t_ij->view_src->data + t_ij->view_offs;
        } else if (simple_buf != nullptr) {
            t_ij->data = (char *) ggml_backend_buffer_get_base(simple_buf)
                + size_t(tensor->data) - size_t(ggml_backend_buffer_get_base(tensor->buffer));
        }
        t_ij->extra = tensor->extra;
        // Sources: a self-reference maps to t_ij, meta sources map to their j-th simple tensor, and
        // anything else is kept as is.
        for (int i = 0; i < GGML_MAX_SRC; i++) {
            t_ij->src[i] = tensor->src[i];
            if (tensor->src[i] == tensor) {
                t_ij->src[i] = t_ij;
            } else if (t_ij->src[i] != nullptr && ggml_backend_buffer_is_meta(t_ij->src[i]->buffer)) {
                t_ij->src[i] = ggml_backend_meta_buffer_simple_tensor(tensor->src[i], j);
            }
        }

        simple_tensors.push_back(t_ij);
    }

    // If one of the sources has a zero-sized slice, disable the computation:
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        if (tensor->src[i] == nullptr || !ggml_backend_buffer_is_meta(tensor->src[i]->buffer)) {
            continue;
        }

        const ggml_backend_meta_split_state split_state_src = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true);
        if (split_state_src.axis < 0 || split_state_src.axis >= GGML_MAX_DIMS) {
            continue;
        }
        for (size_t j = 0; j < n_simple_bufs; j++) {
            int64_t ne_sum = 0;
            for (size_t s = 0; s < split_state_src.n_segments; s++) {
                ne_sum += split_state_src.ne[s*n_simple_bufs + j] * split_state_src.nr[s];
            }
            if (ne_sum == 0) {
                simple_tensors[j]->flags &= ~GGML_TENSOR_FLAG_COMPUTE;
            }
        }
    }

    stc.simple_tensors[tensor] = simple_tensors;

    return GGML_STATUS_SUCCESS;
}
|
|
1238
|
+
|
|
1239
|
+
// Buffer-interface init_tensor callback for meta buffers.
// Copies stc_compute_index_next into stc_compute_index, then initialises `tensor` inside the
// simple-tensor container that the buffer context selects for it.
// NOTE(review): presumably get_simple_tensor_container depends on stc_compute_index, which is why the
// index is updated first — confirm.
static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
    GGML_ASSERT(ggml_backend_buffer_is_meta(buffer));
    ggml_backend_meta_buffer_context * meta_ctx = (ggml_backend_meta_buffer_context *) buffer->context;
    meta_ctx->stc_compute_index = meta_ctx->stc_compute_index_next;
    auto & container = meta_ctx->get_simple_tensor_container(tensor);
    return ggml_backend_meta_buffer_init_tensor_impl(container, tensor);
}
|
|
1245
|
+
|
|
1246
|
+
// Buffer-interface set_tensor callback for meta buffers.
// Writes `size` bytes from `data` into the byte range [offset, offset + size) of the contiguous meta
// tensor `tensor`, distributing them over the simple tensors according to the tensor's split state:
//   - multi-segment / repeated splits (axis 0 or 1 only): each row's per-buffer chunks are interleaved;
//   - simple splits on axis 0..2: each full chunk along axis+1 is the concatenation of the simple tensors' chunks;
//   - MIRRORED: the same bytes are written to every simple tensor;
//   - PARTIAL (F32 only): every simple tensor receives data / n_bufs.
// Misaligned ranges or unsupported layouts abort via GGML_ASSERT / GGML_ABORT.
static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer);
    GGML_ASSERT(ggml_is_contiguous(tensor));

    const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);

    if (split_state.n_segments != 1 || split_state.nr[0] != 1) {
        GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS);
        GGML_ASSERT(split_state.nr[0] != 0);
        GGML_ASSERT(tensor->ne[3] == 1);

        // offset_data: byte position within one source row; simple_offsets[j]: write position within
        // one row of simple tensor j.
        size_t offset_data = 0;
        std::vector<size_t> simple_offsets(n_bufs, 0);
        if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) {
            GGML_ASSERT(tensor->ne[2] == 1);

            const size_t row_stride = tensor->nb[1];
            GGML_ASSERT(offset % row_stride == 0);
            GGML_ASSERT(size % row_stride == 0);
            const int64_t row_start = offset / row_stride;
            const int64_t row_count = size / row_stride;
            GGML_ASSERT(row_start + row_count <= tensor->ne[1]);

            const int64_t blck_size = ggml_blck_size(tensor->type);
            for (size_t s = 0; s < split_state.n_segments; s++) {
                for (size_t r = 0; r < split_state.nr[s]; r++) {
                    for (size_t j = 0; j < n_bufs; j++) {
                        ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
                        GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0);
                        const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0];
                        ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data,
                            simple_offsets[j] + row_start * simple_tensor->nb[1], nbytes,
                            row_count, simple_tensor->nb[1], tensor->nb[1]);
                        offset_data += nbytes;
                        simple_offsets[j] += nbytes;
                    }
                }
            }
            GGML_ASSERT(offset_data*row_count == size);
            return;
        }
        GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1);

        const size_t row_stride = tensor->nb[2];
        GGML_ASSERT(offset % row_stride == 0);
        GGML_ASSERT(size % row_stride == 0);
        const int64_t row_start = offset / row_stride;
        const int64_t row_count = size / row_stride;
        GGML_ASSERT(row_start + row_count <= tensor->ne[2]);

        for (size_t s = 0; s < split_state.n_segments; s++) {
            for (size_t r = 0; r < split_state.nr[s]; r++) {
                for (size_t j = 0; j < n_bufs; j++) {
                    ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
                    const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1];
                    ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data,
                        simple_offsets[j] + row_start * simple_tensor->nb[2], nbytes,
                        row_count, simple_tensor->nb[2], tensor->nb[2]);
                    offset_data += nbytes;
                    simple_offsets[j] += nbytes;
                }
            }
        }
        GGML_ASSERT(offset_data*row_count == size);
        return;
    }

    switch (split_state.axis) {
        case GGML_BACKEND_SPLIT_AXIS_0:
        case GGML_BACKEND_SPLIT_AXIS_1:
        case GGML_BACKEND_SPLIT_AXIS_2: {
            // Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
            const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
            GGML_ASSERT(offset % chunk_size_full == 0);
            GGML_ASSERT(size % chunk_size_full == 0);
            const int64_t i_start = offset /chunk_size_full;
            const int64_t i_stop = (offset + size)/chunk_size_full;
            size_t offset_j = 0;
            for (size_t j = 0; j < n_bufs; j++) {
                ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
                const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
                if (chunk_size_j == 0) {
                    continue;
                }
                const size_t simple_offset = i_start * chunk_size_j;
                ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full);
                offset_j += chunk_size_j;
            }
            // The per-buffer chunks must exactly tile one full chunk.
            GGML_ASSERT(offset_j == chunk_size_full);
        } break;
        case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
            for (size_t j = 0; j < n_bufs; j++) {
                ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
                ggml_backend_tensor_set(simple_tensor, data, offset, size);
            }
        } break;
        case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
            // Every simple tensor receives data / n_bufs, presumably so that summing the per-buffer
            // partials reproduces `data`.
            GGML_ASSERT(tensor->type == GGML_TYPE_F32);
            // `data` only holds `size` bytes for the range starting at `offset`, so convert exactly that
            // many floats. The old code converted ggml_nelements(tensor) floats and over-read `data`
            // for any sub-range write.
            GGML_ASSERT(offset % sizeof(float) == 0);
            GGML_ASSERT(size % sizeof(float) == 0);
            const size_t n_vals = size / sizeof(float);
            std::vector<float> tmp(n_vals);
            for (size_t i = 0; i < n_vals; i++) {
                tmp[i] = ((const float *) data)[i] / n_bufs;
            }
            for (size_t j = 0; j < n_bufs; j++) {
                ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
                ggml_backend_tensor_set(simple_tensor, tmp.data(), offset, size);
            }
        } break;
        default: {
            GGML_ABORT("fatal error");
        }
    }
}
|
|
1360
|
+
|
|
1361
|
+
// Buffer-interface get_tensor callback for meta buffers.
// Reads the byte range [offset, offset + size) of the contiguous meta tensor `tensor` into `data`,
// gathering it from the simple tensors according to the tensor's split state. This is the inverse of the
// set_tensor layout for multi-segment splits, simple splits on axis 0..2, and MIRRORED tensors.
// Any other axis (e.g. PARTIAL) reaches GGML_ABORT.
static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer);
    GGML_ASSERT(ggml_is_contiguous(tensor));

    const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);

    // Multi-segment or repeated splits: each destination row is assembled from per-buffer chunks,
    // visiting segments, then repetitions, then buffers.
    if (split_state.n_segments != 1 || split_state.nr[0] != 1) {
        GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS);
        GGML_ASSERT(split_state.nr[0] != 0);
        GGML_ASSERT(tensor->ne[3] == 1);

        // offset_data: byte position within one destination row; simple_offsets[j]: read position
        // within one row of simple tensor j.
        size_t offset_data = 0;
        std::vector<size_t> simple_offsets(n_bufs, 0);
        if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) {
            GGML_ASSERT(tensor->ne[2] == 1);

            // Rows are dim-1 slices; the requested range must consist of whole rows.
            const size_t row_stride = tensor->nb[1];
            GGML_ASSERT(offset % row_stride == 0);
            GGML_ASSERT(size % row_stride == 0);
            const int64_t row_start = offset / row_stride;
            const int64_t row_count = size / row_stride;
            GGML_ASSERT(row_start + row_count <= tensor->ne[1]);

            const int64_t blck_size = ggml_blck_size(tensor->type);
            for (size_t s = 0; s < split_state.n_segments; s++) {
                for (size_t r = 0; r < split_state.nr[s]; r++) {
                    for (size_t j = 0; j < n_bufs; j++) {
                        const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
                        GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0);
                        // Bytes contributed by buffer j to each row for this segment/repetition.
                        const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0];
                        ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data,
                            simple_offsets[j] + row_start * simple_tensor->nb[1], nbytes,
                            row_count, simple_tensor->nb[1], tensor->nb[1]);
                        offset_data += nbytes;
                        simple_offsets[j] += nbytes;
                    }
                }
            }
            // offset_data now equals the width of one row in bytes.
            GGML_ASSERT(offset_data*row_count == size);
            return;
        }
        GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1);

        // Same scheme for a dim-1 split: "rows" are dim-2 slices, chunks are multiples of nb[1].
        const size_t row_stride = tensor->nb[2];
        GGML_ASSERT(offset % row_stride == 0);
        GGML_ASSERT(size % row_stride == 0);
        const int64_t row_start = offset / row_stride;
        const int64_t row_count = size / row_stride;
        GGML_ASSERT(row_start + row_count <= tensor->ne[2]);

        for (size_t s = 0; s < split_state.n_segments; s++) {
            for (size_t r = 0; r < split_state.nr[s]; r++) {
                for (size_t j = 0; j < n_bufs; j++) {
                    const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
                    const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1];
                    ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data,
                        simple_offsets[j] + row_start * simple_tensor->nb[2], nbytes,
                        row_count, simple_tensor->nb[2], tensor->nb[2]);
                    offset_data += nbytes;
                    simple_offsets[j] += nbytes;
                }
            }
        }
        GGML_ASSERT(offset_data*row_count == size);
        return;
    }

    switch (split_state.axis) {
        case GGML_BACKEND_SPLIT_AXIS_0:
        case GGML_BACKEND_SPLIT_AXIS_1:
        case GGML_BACKEND_SPLIT_AXIS_2: {
            // Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
            const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
            GGML_ASSERT(offset % chunk_size_full == 0);
            GGML_ASSERT(size % chunk_size_full == 0);
            const int64_t i_start = offset /chunk_size_full;
            const int64_t i_stop = (offset + size)/chunk_size_full;
            size_t offset_j = 0;
            for (size_t j = 0; j < n_bufs; j++){
                const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
                const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
                // Buffers holding no slice of this tensor contribute nothing.
                if (chunk_size_j == 0) {
                    continue;
                }
                const size_t simple_offset = i_start * chunk_size_j;
                ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full);
                offset_j += chunk_size_j;
            }
            // The per-buffer chunks must exactly tile one full chunk.
            GGML_ASSERT(offset_j == chunk_size_full);
        } break;
        case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
            // All copies hold the same data, so reading buffer 0 is sufficient.
            // TODO other simple backend may be better
            const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0);
            ggml_backend_tensor_get(simple_tensor, data, offset, size);
        } break;
        default: {
            GGML_ABORT("fatal error");
        }
    }
}
|
|
1461
|
+
|
|
1462
|
+
// Fill each simple (per-device) buffer that backs the meta buffer `buffer` with the byte `value`.
// The simple buffers are cleared one after another, in index order.
static void ggml_backend_meta_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    const size_t n_simple = ggml_backend_meta_buffer_n_bufs(buffer);
    size_t idx = 0;
    while (idx < n_simple) {
        ggml_backend_buffer_t simple_buf = ggml_backend_meta_buffer_simple_buffer(buffer, idx);
        ggml_backend_buffer_clear(simple_buf, value);
        ++idx;
    }
}
|
|
1468
|
+
|
|
1469
|
+
// Reset each simple buffer owned by the meta buffer `buffer`.
// Aborts if `buffer` is not a meta buffer.
static void ggml_backend_meta_buffer_reset(ggml_backend_buffer_t buffer) {
    GGML_ASSERT(ggml_backend_buffer_is_meta(buffer));
    const auto * meta_ctx = static_cast<const ggml_backend_meta_buffer_context *>(buffer->context);
    for (size_t idx = 0; idx < meta_ctx->bufs.size(); ++idx) {
        ggml_backend_buffer_t simple_buf = ggml_backend_meta_buffer_simple_buffer(buffer, idx);
        ggml_backend_buffer_reset(simple_buf);
    }
}
|
|
1476
|
+
|
|
1477
|
+
static const ggml_backend_buffer_i ggml_backend_meta_buffer_iface = {
|
|
1478
|
+
/* .free_buffer = */ ggml_backend_meta_buffer_free_buffer,
|
|
1479
|
+
/* .get_base = */ ggml_backend_meta_buffer_get_base,
|
|
1480
|
+
/* .init_tensor = */ ggml_backend_meta_buffer_init_tensor,
|
|
1481
|
+
/* .memset_tensor = */ nullptr, // TODO implement
|
|
1482
|
+
/* .set_tensor = */ ggml_backend_meta_buffer_set_tensor,
|
|
1483
|
+
/* .get_tensor = */ ggml_backend_meta_buffer_get_tensor,
|
|
1484
|
+
/* .set_tensor_2d = */ nullptr,
|
|
1485
|
+
/* .get_tensor_2d = */ nullptr,
|
|
1486
|
+
/* .cpy_tensor = */ nullptr,
|
|
1487
|
+
/* .clear = */ ggml_backend_meta_buffer_clear,
|
|
1488
|
+
/* .reset = */ ggml_backend_meta_buffer_reset,
|
|
1489
|
+
};
|
|
1490
|
+
|
|
1491
|
+
bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf) {
|
|
1492
|
+
return buf != nullptr && buf->iface.free_buffer == ggml_backend_meta_buffer_iface.free_buffer;
|
|
1493
|
+
}
|
|
1494
|
+
|
|
1495
|
+
// Allocate a meta buffer for the meta buffer type `buft`.
// One simple buffer of `size` bytes is allocated from each simple buffer type (aborting on failure).
// The resulting meta buffer reports the largest simple-buffer size as its own size.
// The static tensor container is default-constructed here. The two compute containers get metadata-only
// (no_alloc) contexts sized for 1024*1024 tensors.
static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    const size_t n_bufts = ggml_backend_meta_buft_n_bufts(buft);

    const ggml_init_params ctx_params = {
        /*.mem_size   =*/ 1024*1024*ggml_tensor_overhead(), // FIXME
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,
    };
    ggml_backend_meta_simple_tensor_container stc_static;
    ggml_backend_meta_simple_tensor_container stc_compute_0(ctx_params, n_bufts);
    ggml_backend_meta_simple_tensor_container stc_compute_1(ctx_params, n_bufts);

    std::vector<ggml_backend_buffer_t> simple_bufs;
    simple_bufs.reserve(n_bufts);
    size_t largest = 0;
    for (size_t k = 0; k < n_bufts; ++k) {
        ggml_backend_buffer_type_t simple_buft = ggml_backend_meta_buft_simple_buft(buft, k);
        ggml_backend_buffer_t simple_buf = ggml_backend_buft_alloc_buffer(simple_buft, size);
        simple_bufs.push_back(simple_buf);
        GGML_ASSERT(simple_buf != nullptr);
        largest = std::max(largest, ggml_backend_buffer_get_size(simple_buf));
    }

    auto * meta_ctx = new ggml_backend_meta_buffer_context(stc_static, stc_compute_0, stc_compute_1, simple_bufs);
    return ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, meta_ctx, largest);
}
|
|
1519
|
+
|
|
1520
|
+
struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
|
|
1521
|
+
const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
|
|
1522
|
+
|
|
1523
|
+
constexpr size_t compute_headroom = 16; // Maximum number of views per statically allocated tensor that can be created between evals.
|
|
1524
|
+
const ggml_init_params params_static = {
|
|
1525
|
+
/*.mem_size =*/ ggml_get_mem_size(ctx),
|
|
1526
|
+
/*.mem_buffer =*/ nullptr,
|
|
1527
|
+
/*.no_alloc =*/ true,
|
|
1528
|
+
};
|
|
1529
|
+
const ggml_init_params params_compute = {
|
|
1530
|
+
/*.mem_size =*/ compute_headroom*ggml_get_mem_size(ctx),
|
|
1531
|
+
/*.mem_buffer =*/ nullptr,
|
|
1532
|
+
/*.no_alloc =*/ true,
|
|
1533
|
+
};
|
|
1534
|
+
ggml_backend_meta_simple_tensor_container stc_static (params_static, n_simple_bufts);
|
|
1535
|
+
ggml_backend_meta_simple_tensor_container stc_compute_0(params_compute, n_simple_bufts);
|
|
1536
|
+
ggml_backend_meta_simple_tensor_container stc_compute_1(params_compute, n_simple_bufts);
|
|
1537
|
+
|
|
1538
|
+
std::vector<ggml_backend_buffer_t> bufs(n_simple_bufts, nullptr);
|
|
1539
|
+
ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context(stc_static, stc_compute_0, stc_compute_1, bufs);
|
|
1540
|
+
|
|
1541
|
+
ggml_backend_buffer_t meta_buf = ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, meta_buf_ctx, 0);
|
|
1542
|
+
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
|
|
1543
|
+
t->buffer = meta_buf;
|
|
1544
|
+
ggml_backend_meta_buffer_init_tensor_impl(meta_buf_ctx->stc_static, t);
|
|
1545
|
+
t->data = (void *) 0x2000000000000000; // FIXME
|
|
1546
|
+
}
|
|
1547
|
+
for (size_t i = 0; i < n_simple_bufts; i++) {
|
|
1548
|
+
ggml_context * ctx = meta_buf_ctx->stc_static.ctxs[i].get();
|
|
1549
|
+
ggml_backend_buffer_type_t simple_buft = ggml_backend_meta_buft_simple_buft(buft, i);
|
|
1550
|
+
|
|
1551
|
+
// If a ggml_context only has zero-sized tensors, ggml_backend_alloc_ctx_tensors_from_buft returns NULL.
|
|
1552
|
+
// For those edge cases, allocate a dummy buffer instead.
|
|
1553
|
+
bool any_nonzero_slice = false;
|
|
1554
|
+
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
|
|
1555
|
+
if (ggml_nelements(t) != 0) {
|
|
1556
|
+
any_nonzero_slice = true;
|
|
1557
|
+
break;
|
|
1558
|
+
}
|
|
1559
|
+
}
|
|
1560
|
+
if (any_nonzero_slice) {
|
|
1561
|
+
meta_buf_ctx->bufs[i].reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx, simple_buft));
|
|
1562
|
+
} else {
|
|
1563
|
+
meta_buf_ctx->bufs[i].reset(ggml_backend_buft_alloc_buffer(simple_buft, 0));
|
|
1564
|
+
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
|
|
1565
|
+
t->buffer = meta_buf_ctx->bufs[i].get();
|
|
1566
|
+
}
|
|
1567
|
+
}
|
|
1568
|
+
GGML_ASSERT(meta_buf_ctx->bufs[i]);
|
|
1569
|
+
meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->bufs[i].get()));
|
|
1570
|
+
}
|
|
1571
|
+
return meta_buf;
|
|
1572
|
+
}
|
|
1573
|
+
|
|
1574
|
+
//
|
|
1575
|
+
// meta backend
|
|
1576
|
+
//
|
|
1577
|
+
|
|
1578
|
+
// Return the fixed GUID of the meta backend.
// The pointer refers to a function-local static, so it stays valid for the lifetime of the program.
static ggml_guid_t ggml_backend_meta_guid() {
    static ggml_guid meta_guid = {
        0xf1, 0x0e, 0x34, 0xcf, 0x9c, 0x6f, 0x43, 0xcb,
        0x96, 0x92, 0xbe, 0x8e, 0xbb, 0x71, 0x3f, 0xda,
    };
    ggml_guid_t result = &meta_guid;
    return result;
}
|
|
1582
|
+
|
|
1583
|
+
struct ggml_backend_meta_context {
|
|
1584
|
+
struct cgraph_config {
|
|
1585
|
+
ggml_cgraph * cgraph_main = nullptr;
|
|
1586
|
+
int offset = 0; // Node offset vs. original graph
|
|
1587
|
+
|
|
1588
|
+
std::vector<ggml_cgraph *> cgraphs_aux;
|
|
1589
|
+
};
|
|
1590
|
+
struct backend_config {
|
|
1591
|
+
ggml_backend_t backend;
|
|
1592
|
+
|
|
1593
|
+
std::vector<cgraph_config> cgraphs;
|
|
1594
|
+
std::vector<ggml_tensor *> nodes;
|
|
1595
|
+
std::vector<ggml_backend_buffer_ptr> bufs;
|
|
1596
|
+
|
|
1597
|
+
backend_config(ggml_backend_t backend, const size_t n_reduce_steps) : backend(backend) {
|
|
1598
|
+
bufs.resize(n_reduce_steps);
|
|
1599
|
+
}
|
|
1600
|
+
};
|
|
1601
|
+
std::string name;
|
|
1602
|
+
std::vector<backend_config> backend_configs;
|
|
1603
|
+
ggml_context_ptr ctx;
|
|
1604
|
+
std::vector<ggml_cgraph *> cgraphs_aux;
|
|
1605
|
+
std::vector<ggml_tensor *> nodes_aux;
|
|
1606
|
+
size_t n_reduce_steps;
|
|
1607
|
+
int max_nnodes = 0;
|
|
1608
|
+
size_t max_tmp_size = 0;
|
|
1609
|
+
size_t max_subgraphs = 0;
|
|
1610
|
+
size_t n_subgraphs = 0;
|
|
1611
|
+
uint64_t uid = 0;
|
|
1612
|
+
|
|
1613
|
+
void * comm_ctx = nullptr;
|
|
1614
|
+
ggml_backend_comm_allreduce_tensor_t comm_allreduce = nullptr;
|
|
1615
|
+
|
|
1616
|
+
ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) {
|
|
1617
|
+
const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev);
|
|
1618
|
+
n_reduce_steps = std::ceil(std::log2(n_devs));
|
|
1619
|
+
name = "Meta(";
|
|
1620
|
+
std::vector<ggml_backend_t> simple_backends;
|
|
1621
|
+
backend_configs.reserve(n_devs);
|
|
1622
|
+
simple_backends.reserve(n_devs);
|
|
1623
|
+
for (size_t i = 0; i < n_devs; i++) {
|
|
1624
|
+
ggml_backend_dev_t simple_dev = ggml_backend_meta_dev_simple_dev(meta_dev, i);
|
|
1625
|
+
if (i > 0) {
|
|
1626
|
+
name += ",";
|
|
1627
|
+
}
|
|
1628
|
+
name += ggml_backend_dev_name(simple_dev);
|
|
1629
|
+
simple_backends.push_back(ggml_backend_dev_init(simple_dev, params));
|
|
1630
|
+
backend_configs.emplace_back(simple_backends.back(), n_reduce_steps);
|
|
1631
|
+
}
|
|
1632
|
+
name += ")";
|
|
1633
|
+
|
|
1634
|
+
if (n_devs > 1) {
|
|
1635
|
+
ggml_backend_comm_init_t comm_init = (ggml_backend_comm_init_t) ggml_backend_reg_get_proc_address(
|
|
1636
|
+
ggml_backend_dev_backend_reg(ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_init");
|
|
1637
|
+
if (comm_init != nullptr) {
|
|
1638
|
+
comm_ctx = comm_init(simple_backends.data(), simple_backends.size());
|
|
1639
|
+
}
|
|
1640
|
+
}
|
|
1641
|
+
if (comm_ctx != nullptr) {
|
|
1642
|
+
comm_allreduce = (ggml_backend_comm_allreduce_tensor_t)
|
|
1643
|
+
ggml_backend_reg_get_proc_address(ggml_backend_dev_backend_reg(
|
|
1644
|
+
ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_allreduce_tensor");
|
|
1645
|
+
GGML_ASSERT(comm_allreduce != nullptr);
|
|
1646
|
+
}
|
|
1647
|
+
}
|
|
1648
|
+
|
|
1649
|
+
~ggml_backend_meta_context() {
|
|
1650
|
+
if (comm_ctx != nullptr) {
|
|
1651
|
+
ggml_backend_comm_free_t comm_free = (ggml_backend_comm_free_t) ggml_backend_reg_get_proc_address(
|
|
1652
|
+
ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_configs[0].backend)), "ggml_backend_comm_free");
|
|
1653
|
+
GGML_ASSERT(comm_free != nullptr);
|
|
1654
|
+
comm_free(comm_ctx);
|
|
1655
|
+
}
|
|
1656
|
+
for (auto & bc : backend_configs) {
|
|
1657
|
+
ggml_backend_free(bc.backend);
|
|
1658
|
+
}
|
|
1659
|
+
}
|
|
1660
|
+
};
|
|
1661
|
+
|
|
1662
|
+
// Return the name of a meta backend, e.g. "Meta(<dev0>,<dev1>)".
// The string is owned by the backend context. Aborts if `backend` is not a meta backend.
static const char * ggml_backend_meta_get_name(ggml_backend_t backend) {
    GGML_ASSERT(ggml_backend_is_meta(backend));
    const auto * meta_ctx = static_cast<const ggml_backend_meta_context *>(backend->context);
    const std::string & backend_name = meta_ctx->name;
    return backend_name.c_str();
}
|
|
1667
|
+
|
|
1668
|
+
// Destroy a meta backend.
// This deletes its context first, which frees the simple backends and any comm context, then the handle itself.
// Aborts if `backend` is not a meta backend.
static void ggml_backend_meta_free(ggml_backend_t backend) {
    GGML_ASSERT(ggml_backend_is_meta(backend));
    delete static_cast<ggml_backend_meta_context *>(backend->context);
    delete backend;
}
|
|
1674
|
+
|
|
1675
|
+
// Asynchronously upload `size` bytes from host memory `data` into the meta tensor `tensor`, starting at byte `offset`.
// Only contiguous tensors with a single split segment are supported, and `offset` must currently be 0.
//  - AXIS_0/1/2: the host data is viewed as rows of nb[axis+1] bytes ("chunks"). Simple tensor j receives a
//    contiguous byte range of every row (range widths sum to the full row), via one strided 2D copy.
//    Simple tensors with zero-sized slices are skipped.
//  - MIRRORED: the full data is uploaded to every simple backend.
// Any other split axis aborts.
static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    const size_t n_backends = ggml_backend_meta_n_backends(backend);
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(ggml_is_contiguous(tensor));

    const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
    GGML_ASSERT(split_state.n_segments == 1);
    GGML_ASSERT(split_state.nr[0] == 1);

    switch (split_state.axis) {
        case GGML_BACKEND_SPLIT_AXIS_0:
        case GGML_BACKEND_SPLIT_AXIS_1:
        case GGML_BACKEND_SPLIT_AXIS_2: {
            // Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
            const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
            GGML_ASSERT(offset % chunk_size_full == 0);
            GGML_ASSERT(size % chunk_size_full == 0);
            const int64_t i_start = offset /chunk_size_full;
            const int64_t i_stop  = (offset + size)/chunk_size_full;
            size_t offset_j = 0; // byte position of slice j inside each full chunk of the host data
            for (size_t j = 0; j < n_backends; j++){
                ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j);
                ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
                const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
                if (chunk_size_j == 0) {
                    continue;
                }
                // Offset of the first transferred chunk inside the simple tensor.
                // The meta-tensor `offset` equals this only while offset == 0 is enforced, so compute it per slice
                // (as the synchronous path does).
                const size_t simple_offset = i_start * chunk_size_j;
                ggml_backend_tensor_set_2d_async(simple_backend, simple_tensor, (const char *) data + offset_j, simple_offset, chunk_size_j,
                    i_stop - i_start, chunk_size_j, chunk_size_full);
                offset_j += chunk_size_j;
            }
            GGML_ASSERT(offset_j == chunk_size_full);
        } break;
        case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
            for (size_t j = 0; j < n_backends; j++) {
                ggml_backend_tensor_set_async(
                    ggml_backend_meta_simple_backend(backend, j), ggml_backend_meta_buffer_simple_tensor(tensor, j), data, offset, size);
            }
        } break;
        default: {
            GGML_ABORT("fatal error");
        }
    }
}
|
|
1719
|
+
|
|
1720
|
+
// Asynchronously download `size` bytes of the meta tensor `tensor`, starting at byte `offset`, into host memory `data`.
// Only contiguous tensors with a single split segment are supported, and `offset` must currently be 0.
//  - AXIS_0/1/2: the host data is viewed as rows of nb[axis+1] bytes ("chunks"). Each simple tensor j contributes
//    a contiguous byte range of every row (range widths sum to the full row), via one strided 2D copy.
//    Simple tensors with zero-sized slices are skipped.
//  - MIRRORED: the data is read from simple backend 0 only.
// Any other split axis aborts.
static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    const size_t n_backends = ggml_backend_meta_n_backends(backend);
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(ggml_is_contiguous(tensor));

    const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
    GGML_ASSERT(split_state.n_segments == 1);
    GGML_ASSERT(split_state.nr[0] == 1);

    switch (split_state.axis) {
        case GGML_BACKEND_SPLIT_AXIS_0:
        case GGML_BACKEND_SPLIT_AXIS_1:
        case GGML_BACKEND_SPLIT_AXIS_2: {
            // Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
            const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
            GGML_ASSERT(offset % chunk_size_full == 0);
            GGML_ASSERT(size % chunk_size_full == 0);
            const int64_t i_start = offset /chunk_size_full;
            const int64_t i_stop  = (offset + size)/chunk_size_full;
            size_t offset_j = 0; // byte position of slice j inside each full chunk of the host data
            for (size_t j = 0; j < n_backends; j++){
                ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j);
                const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
                const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
                if (chunk_size_j == 0) {
                    continue;
                }
                // Offset of the first transferred chunk inside the simple tensor.
                // The meta-tensor `offset` equals this only while offset == 0 is enforced, so compute it per slice
                // (as the synchronous getter does).
                const size_t simple_offset = i_start * chunk_size_j;
                ggml_backend_tensor_get_2d_async(simple_backend, simple_tensor, (char *) data + offset_j, simple_offset, chunk_size_j,
                    i_stop - i_start, chunk_size_j, chunk_size_full);
                offset_j += chunk_size_j;
            }
            GGML_ASSERT(offset_j == chunk_size_full);
        } break;
        case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
            // TODO other simple backend may be better
            ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, 0);
            const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0);
            ggml_backend_tensor_get_async(simple_backend, simple_tensor, data, offset, size);
        } break;
        default: {
            GGML_ABORT("fatal error");
        }
    }
}
|
|
1764
|
+
|
|
1765
|
+
// Wait for outstanding work on the meta backend by synchronizing each simple backend in index order.
static void ggml_backend_meta_synchronize(ggml_backend_t backend) {
    const size_t n_simple = ggml_backend_meta_n_backends(backend);
    for (size_t j = 0; j < n_simple; ++j) {
        ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j);
        ggml_backend_synchronize(simple_backend);
    }
}
|
|
1771
|
+
|
|
1772
|
+
static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
|
1773
|
+
GGML_ASSERT(cgraph->grads == nullptr);
|
|
1774
|
+
const size_t n_backends = ggml_backend_meta_n_backends(backend);
|
|
1775
|
+
ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context;
|
|
1776
|
+
|
|
1777
|
+
// If the previous cgraph had a defined UID it can be used to skip rebuilding the subgraphs per simple backend.
|
|
1778
|
+
const bool needs_rebuild = (cgraph->uid == 0) || (cgraph->uid != backend_ctx->uid);
|
|
1779
|
+
|
|
1780
|
+
bool max_nnodes_raised = false;
|
|
1781
|
+
if (cgraph->n_nodes > backend_ctx->max_nnodes) {
|
|
1782
|
+
for (size_t j = 0; j < n_backends; j++) {
|
|
1783
|
+
auto & bcj = backend_ctx->backend_configs[j];
|
|
1784
|
+
bcj.nodes.resize(cgraph->n_nodes);
|
|
1785
|
+
bcj.cgraphs.resize(cgraph->n_nodes);
|
|
1786
|
+
}
|
|
1787
|
+
backend_ctx->max_nnodes = cgraph->n_nodes;
|
|
1788
|
+
max_nnodes_raised = true;
|
|
1789
|
+
assert(needs_rebuild);
|
|
1790
|
+
}
|
|
1791
|
+
|
|
1792
|
+
if (needs_rebuild) {
|
|
1793
|
+
std::set<ggml_backend_buffer_t> used_buffers;
|
|
1794
|
+
for (int i = 0; i < cgraph->n_leafs; i++) {
|
|
1795
|
+
if (ggml_backend_buffer_is_meta(cgraph->leafs[i]->buffer)) {
|
|
1796
|
+
used_buffers.emplace(cgraph->leafs[i]->buffer);
|
|
1797
|
+
}
|
|
1798
|
+
}
|
|
1799
|
+
for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
1800
|
+
if (ggml_backend_buffer_is_meta(cgraph->nodes[i]->buffer)) {
|
|
1801
|
+
used_buffers.emplace(cgraph->nodes[i]->buffer);
|
|
1802
|
+
}
|
|
1803
|
+
}
|
|
1804
|
+
for (ggml_backend_buffer_t buf : used_buffers) {
|
|
1805
|
+
ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buf->context;
|
|
1806
|
+
buf_ctx->stc_compute_index_next = buf_ctx->stc_compute_index ^ 1;
|
|
1807
|
+
ggml_backend_meta_simple_tensor_container & stc = buf_ctx->stc_compute[buf_ctx->stc_compute_index_next];
|
|
1808
|
+
for (ggml_context_ptr & ctx : stc.ctxs) {
|
|
1809
|
+
ggml_reset(ctx.get());
|
|
1810
|
+
}
|
|
1811
|
+
stc.simple_tensors.clear();
|
|
1812
|
+
}
|
|
1813
|
+
size_t n_subgraphs = 0;
|
|
1814
|
+
size_t max_tmp_size = 0;
|
|
1815
|
+
|
|
1816
|
+
for (size_t j = 0; j < n_backends; j++) {
|
|
1817
|
+
auto & bcj = backend_ctx->backend_configs[j];
|
|
1818
|
+
|
|
1819
|
+
for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
1820
|
+
ggml_tensor * node = cgraph->nodes[i];
|
|
1821
|
+
if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) {
|
|
1822
|
+
// FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes.
|
|
1823
|
+
// For regular usage this doesn't matter since it's a noop but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash.
|
|
1824
|
+
bcj.nodes[i] = node;
|
|
1825
|
+
continue;
|
|
1826
|
+
}
|
|
1827
|
+
bcj.nodes[i] = ggml_backend_meta_buffer_simple_tensor(node, j);
|
|
1828
|
+
GGML_ASSERT(bcj.nodes[i]);
|
|
1829
|
+
}
|
|
1830
|
+
}
|
|
1831
|
+
|
|
1832
|
+
{
|
|
1833
|
+
// For MoE models it may make sense to delay the AllReduce in order to reduce I/O:
|
|
1834
|
+
auto get_i_delayed = [&](const int i) -> int {
|
|
1835
|
+
int id = i; // i_delayed
|
|
1836
|
+
int idr = i; // i_delayed return, last safe return value
|
|
1837
|
+
|
|
1838
|
+
ggml_tensor * node = cgraph->nodes[id];
|
|
1839
|
+
int32_t n_used = ggml_node_get_use_count(cgraph, id);
|
|
1840
|
+
|
|
1841
|
+
// Skip MIRRORED nodes that don't consume node
|
|
1842
|
+
auto skip_unrelated = [&]() {
|
|
1843
|
+
while (id + 1 < cgraph->n_nodes) {
|
|
1844
|
+
ggml_tensor * next = cgraph->nodes[id+1];
|
|
1845
|
+
if (ggml_backend_meta_get_split_state(next, false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
|
|
1846
|
+
break;
|
|
1847
|
+
}
|
|
1848
|
+
bool safe = true;
|
|
1849
|
+
for (int s = 0; s < GGML_MAX_SRC; s++) {
|
|
1850
|
+
if (next->src[s] == nullptr) {
|
|
1851
|
+
continue;
|
|
1852
|
+
}
|
|
1853
|
+
if (next->src[s] == node) {
|
|
1854
|
+
safe = false;
|
|
1855
|
+
break;
|
|
1856
|
+
}
|
|
1857
|
+
if (ggml_backend_meta_get_split_state(next->src[s], false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
|
|
1858
|
+
safe = false;
|
|
1859
|
+
break;
|
|
1860
|
+
}
|
|
1861
|
+
}
|
|
1862
|
+
if (!safe) {
|
|
1863
|
+
break;
|
|
1864
|
+
}
|
|
1865
|
+
id++;
|
|
1866
|
+
}
|
|
1867
|
+
};
|
|
1868
|
+
|
|
1869
|
+
skip_unrelated();
|
|
1870
|
+
if (id + 1 >= cgraph->n_nodes) {
|
|
1871
|
+
return idr;
|
|
1872
|
+
}
|
|
1873
|
+
{
|
|
1874
|
+
ggml_tensor * next = cgraph->nodes[id+1];
|
|
1875
|
+
if (next->op == GGML_OP_ADD_ID && next->src[0] == node &&
|
|
1876
|
+
ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL &&
|
|
1877
|
+
ggml_backend_meta_get_split_state(next->src[2], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
|
|
1878
|
+
node = next;
|
|
1879
|
+
id++;
|
|
1880
|
+
idr = id;
|
|
1881
|
+
n_used = ggml_node_get_use_count(cgraph, id);
|
|
1882
|
+
}
|
|
1883
|
+
}
|
|
1884
|
+
// Chain of MULs with MIRRORED src[1]
|
|
1885
|
+
while (true) {
|
|
1886
|
+
skip_unrelated();
|
|
1887
|
+
if (id + 1 >= cgraph->n_nodes) {
|
|
1888
|
+
return idr;
|
|
1889
|
+
}
|
|
1890
|
+
ggml_tensor * next = cgraph->nodes[id+1];
|
|
1891
|
+
if (next->op == GGML_OP_MUL && next->src[0] == node &&
|
|
1892
|
+
ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
|
|
1893
|
+
node = next;
|
|
1894
|
+
id++;
|
|
1895
|
+
idr = id;
|
|
1896
|
+
n_used = ggml_node_get_use_count(cgraph, id);
|
|
1897
|
+
} else {
|
|
1898
|
+
break;
|
|
1899
|
+
}
|
|
1900
|
+
}
|
|
1901
|
+
|
|
1902
|
+
if (n_used != node->ne[1] || id + 2*n_used-1 >= cgraph->n_nodes) {
|
|
1903
|
+
return idr;
|
|
1904
|
+
}
|
|
1905
|
+
for (int32_t k = 0; k < n_used; k++) {
|
|
1906
|
+
ggml_tensor * next = cgraph->nodes[id+1];
|
|
1907
|
+
if (next->op != GGML_OP_VIEW || next->view_src != node || next->view_offs != k*node->nb[1] ||
|
|
1908
|
+
next->ne[0] != node->ne[0] || next->ne[1] != node->ne[2] || next->nb[1] != node->nb[2] ||
|
|
1909
|
+
ggml_node_get_use_count(cgraph, id+1) != 1) {
|
|
1910
|
+
return idr;
|
|
1911
|
+
}
|
|
1912
|
+
id++;
|
|
1913
|
+
}
|
|
1914
|
+
{
|
|
1915
|
+
ggml_tensor * next = cgraph->nodes[id+1];
|
|
1916
|
+
if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id - (n_used-1)] ||
|
|
1917
|
+
next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) {
|
|
1918
|
+
return idr;
|
|
1919
|
+
}
|
|
1920
|
+
id++;
|
|
1921
|
+
}
|
|
1922
|
+
for (int32_t k = 0; k < n_used - 2; k++) {
|
|
1923
|
+
ggml_tensor * next = cgraph->nodes[id+1];
|
|
1924
|
+
if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id] ||
|
|
1925
|
+
next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) {
|
|
1926
|
+
return idr;
|
|
1927
|
+
}
|
|
1928
|
+
id++;
|
|
1929
|
+
}
|
|
1930
|
+
idr = id;
|
|
1931
|
+
return idr;
|
|
1932
|
+
};
|
|
1933
|
+
|
|
1934
|
+
int i_start = 0;
|
|
1935
|
+
for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
1936
|
+
ggml_tensor * node = cgraph->nodes[i];
|
|
1937
|
+
if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) {
|
|
1938
|
+
continue;
|
|
1939
|
+
}
|
|
1940
|
+
const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false);
|
|
1941
|
+
if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) {
|
|
1942
|
+
max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node));
|
|
1943
|
+
}
|
|
1944
|
+
const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL;
|
|
1945
|
+
if (!new_subgraph) {
|
|
1946
|
+
continue;
|
|
1947
|
+
}
|
|
1948
|
+
|
|
1949
|
+
const int i_delayed = get_i_delayed(i);
|
|
1950
|
+
|
|
1951
|
+
// If we can delay the AllReduce we need to consider the interaction with zero-sized tensor slices.
|
|
1952
|
+
// A backend with such a slice would normally have valid data after participating in the AllReduce with a node that has
|
|
1953
|
+
// its compute flag disabled and thus gets its data zeroed out.
|
|
1954
|
+
// If the AllReduce is delayed then the nodes until that point also need to have their compute flag disabled.
|
|
1955
|
+
if (i_delayed > i) {
|
|
1956
|
+
for (size_t j = 0; j < n_backends; j++) {
|
|
1957
|
+
auto & bcj = backend_ctx->backend_configs[j];
|
|
1958
|
+
if ((bcj.nodes[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
|
1959
|
+
for (int ii = i + 1; ii <= i_delayed; ii++) {
|
|
1960
|
+
bcj.nodes[ii]->flags &= ~GGML_TENSOR_FLAG_COMPUTE;
|
|
1961
|
+
}
|
|
1962
|
+
}
|
|
1963
|
+
}
|
|
1964
|
+
}
|
|
1965
|
+
|
|
1966
|
+
i = i_delayed;
|
|
1967
|
+
|
|
1968
|
+
for (size_t j = 0; j < n_backends; j++) {
|
|
1969
|
+
auto & bcj = backend_ctx->backend_configs[j];
|
|
1970
|
+
bcj.cgraphs[n_subgraphs].offset = i_start;
|
|
1971
|
+
}
|
|
1972
|
+
n_subgraphs++;
|
|
1973
|
+
i_start = i + 1;
|
|
1974
|
+
}
|
|
1975
|
+
GGML_ASSERT(i_start == cgraph->n_nodes);
|
|
1976
|
+
}
|
|
1977
|
+
|
|
1978
|
+
backend_ctx->uid = cgraph->uid;
|
|
1979
|
+
backend_ctx->n_subgraphs = n_subgraphs;
|
|
1980
|
+
|
|
1981
|
+
if (max_tmp_size > backend_ctx->max_tmp_size) {
|
|
1982
|
+
for (size_t j = 0; j < n_backends; j++) {
|
|
1983
|
+
auto & bcj = backend_ctx->backend_configs[j];
|
|
1984
|
+
for (size_t i = 0; i < backend_ctx->n_reduce_steps; i++) {
|
|
1985
|
+
bcj.bufs[i].reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size));
|
|
1986
|
+
}
|
|
1987
|
+
}
|
|
1988
|
+
backend_ctx->max_tmp_size = max_tmp_size;
|
|
1989
|
+
}
|
|
1990
|
+
|
|
1991
|
+
if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) {
|
|
1992
|
+
backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs);
|
|
1993
|
+
const size_t n_nodes_per_device = 3 * backend_ctx->n_reduce_steps; // tmp + ADD (+zeroing) graph per step and device
|
|
1994
|
+
const size_t n_cgraphs_per_device = 2 * backend_ctx->n_reduce_steps; // ADD ( + zeroing) graph per step and device
|
|
1995
|
+
const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads);
|
|
1996
|
+
const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads);
|
|
1997
|
+
const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead();
|
|
1998
|
+
const ggml_init_params params = {
|
|
1999
|
+
/*.mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux),
|
|
2000
|
+
/*.mem_buffer =*/ nullptr,
|
|
2001
|
+
/*.no_alloc =*/ true,
|
|
2002
|
+
};
|
|
2003
|
+
backend_ctx->ctx.reset(ggml_init(params));
|
|
2004
|
+
for (size_t j = 0; j < n_backends; j++) {
|
|
2005
|
+
auto & bcj = backend_ctx->backend_configs[j];
|
|
2006
|
+
for (size_t i = 0; i < n_subgraphs; i++) {
|
|
2007
|
+
bcj.cgraphs[i].cgraph_main = ggml_new_graph_custom(backend_ctx->ctx.get(), cgraph->n_nodes, /*grads =*/ false);
|
|
2008
|
+
}
|
|
2009
|
+
}
|
|
2010
|
+
backend_ctx->cgraphs_aux.resize(n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs);
|
|
2011
|
+
for (size_t k = 0; k < backend_ctx->cgraphs_aux.size(); k++) {
|
|
2012
|
+
backend_ctx->cgraphs_aux[k] = ggml_new_graph_custom(backend_ctx->ctx.get(), 1, cgraph->grads);
|
|
2013
|
+
}
|
|
2014
|
+
backend_ctx->nodes_aux.resize(n_backends*n_nodes_per_device*backend_ctx->max_subgraphs);
|
|
2015
|
+
for (size_t k = 0; k < backend_ctx->nodes_aux.size(); k++) {
|
|
2016
|
+
backend_ctx->nodes_aux[k] = ggml_new_tensor_1d(backend_ctx->ctx.get(), GGML_TYPE_F32, 1);
|
|
2017
|
+
}
|
|
2018
|
+
}
|
|
2019
|
+
|
|
2020
|
+
for (size_t j = 0; j < n_backends; j++) {
|
|
2021
|
+
auto & bcj = backend_ctx->backend_configs[j];
|
|
2022
|
+
for (size_t i_graph = 0; i_graph < n_subgraphs; i_graph++) {
|
|
2023
|
+
ggml_cgraph * cgraph_ij = bcj.cgraphs[i_graph].cgraph_main;
|
|
2024
|
+
const size_t i_node_start = bcj.cgraphs[i_graph].offset;
|
|
2025
|
+
const size_t i_node_stop = i_graph + 1 < n_subgraphs ? bcj.cgraphs[i_graph + 1].offset : cgraph->n_nodes;
|
|
2026
|
+
cgraph_ij->n_nodes = i_node_stop - i_node_start;
|
|
2027
|
+
ggml_hash_set_reset(&cgraph_ij->visited_hash_set);
|
|
2028
|
+
for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) {
|
|
2029
|
+
ggml_tensor * node_ij = bcj.nodes[i_node];
|
|
2030
|
+
cgraph_ij->nodes[i_node - i_node_start] = node_ij;
|
|
2031
|
+
const size_t hash_pos_orig = ggml_hash_find(&cgraph->visited_hash_set, cgraph->nodes[i_node]);
|
|
2032
|
+
const size_t hash_pos_ij = ggml_hash_insert(&cgraph_ij->visited_hash_set, node_ij);
|
|
2033
|
+
cgraph_ij->use_counts[hash_pos_ij] = cgraph->use_counts[hash_pos_orig];
|
|
2034
|
+
}
|
|
2035
|
+
cgraph_ij->uid = ggml_graph_next_uid();
|
|
2036
|
+
}
|
|
2037
|
+
}
|
|
2038
|
+
}
|
|
2039
|
+
|
|
2040
|
+
size_t iga = 0; // i graph aux
|
|
2041
|
+
size_t ina = 0; // i node aux
|
|
2042
|
+
|
|
2043
|
+
auto get_node_aux = [&](ggml_tensor * t) -> ggml_tensor * {
|
|
2044
|
+
ggml_tensor * ret = backend_ctx->nodes_aux[ina++];
|
|
2045
|
+
memset(ret, 0, sizeof(ggml_tensor));
|
|
2046
|
+
ret->op = GGML_OP_NONE;
|
|
2047
|
+
ret->type = t->type;
|
|
2048
|
+
for (size_t k = 0; k < GGML_MAX_DIMS; k++) {
|
|
2049
|
+
ret->ne[k] = t->ne[k];
|
|
2050
|
+
ret->nb[k] = t->nb[k];
|
|
2051
|
+
}
|
|
2052
|
+
return ret;
|
|
2053
|
+
};
|
|
2054
|
+
auto set_tmp_data = [&](ggml_tensor * tensor, const size_t j, const size_t i_buf) {
|
|
2055
|
+
auto & bcj = backend_ctx->backend_configs[j];
|
|
2056
|
+
ggml_backend_buffer_ptr & buf_ptr = bcj.bufs[i_buf];
|
|
2057
|
+
if (!buf_ptr || ggml_backend_buffer_get_size(buf_ptr.get()) < backend_ctx->max_tmp_size) {
|
|
2058
|
+
buf_ptr.reset(ggml_backend_alloc_buffer(bcj.backend, backend_ctx->max_tmp_size));
|
|
2059
|
+
}
|
|
2060
|
+
tensor->buffer = buf_ptr.get();
|
|
2061
|
+
tensor->data = ggml_backend_buffer_get_base(buf_ptr.get());
|
|
2062
|
+
};
|
|
2063
|
+
// FIXME usage_counts
|
|
2064
|
+
auto get_cgraph_aux = [&]() -> ggml_cgraph * {
|
|
2065
|
+
ggml_cgraph * ret = backend_ctx->cgraphs_aux[iga++];
|
|
2066
|
+
return ret;
|
|
2067
|
+
};
|
|
2068
|
+
|
|
2069
|
+
// Preferentially use backend-specific allreduce_tensor_async (e.g. NCCL for CUDA), use a generic fallback if unavailable:
|
|
2070
|
+
auto allreduce_fallback = [&](size_t i) -> ggml_status {
|
|
2071
|
+
std::vector<ggml_cgraph *> step_cgraphs(n_backends, nullptr);
|
|
2072
|
+
|
|
2073
|
+
// Zero out nodes that were disabled due to having a zero-sized slice:
|
|
2074
|
+
for (size_t j = 0; j < n_backends; j++) {
|
|
2075
|
+
auto & bcj = backend_ctx->backend_configs[j];
|
|
2076
|
+
ggml_tensor * node = bcj.cgraphs[i].cgraph_main->nodes[bcj.cgraphs[i].cgraph_main->n_nodes - 1];
|
|
2077
|
+
if (node->flags & GGML_TENSOR_FLAG_COMPUTE) {
|
|
2078
|
+
continue;
|
|
2079
|
+
}
|
|
2080
|
+
ggml_tensor * node_zero = get_node_aux(node);
|
|
2081
|
+
node_zero->op = GGML_OP_SCALE; // FIXME 0.0f * NaN == NaN
|
|
2082
|
+
node_zero->src[0] = node;
|
|
2083
|
+
ggml_set_op_params_f32(node_zero, 0, 0.0f);
|
|
2084
|
+
node_zero->data = node->data;
|
|
2085
|
+
node_zero->buffer = node->buffer;
|
|
2086
|
+
node_zero->flags |= GGML_TENSOR_FLAG_COMPUTE;
|
|
2087
|
+
|
|
2088
|
+
step_cgraphs[j] = get_cgraph_aux();
|
|
2089
|
+
step_cgraphs[j]->nodes[0] = node_zero;
|
|
2090
|
+
step_cgraphs[j]->n_nodes = 1;
|
|
2091
|
+
const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]);
|
|
2092
|
+
if (status != GGML_STATUS_SUCCESS) {
|
|
2093
|
+
return status;
|
|
2094
|
+
}
|
|
2095
|
+
}
|
|
2096
|
+
std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr);
|
|
2097
|
+
|
|
2098
|
+
auto push_data = [&](const size_t j_src, const size_t j_dst, const size_t i_buf) {
|
|
2099
|
+
assert(step_cgraphs[j_dst] == nullptr);
|
|
2100
|
+
auto & bcj_src = backend_ctx->backend_configs[j_src];
|
|
2101
|
+
auto & bcj_dst = backend_ctx->backend_configs[j_dst];
|
|
2102
|
+
|
|
2103
|
+
ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1];
|
|
2104
|
+
ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1];
|
|
2105
|
+
GGML_ASSERT(ggml_is_contiguous(node_src));
|
|
2106
|
+
GGML_ASSERT(ggml_is_contiguous(node_dst));
|
|
2107
|
+
|
|
2108
|
+
ggml_tensor * node_tmp = get_node_aux(node_dst);
|
|
2109
|
+
set_tmp_data(node_tmp, j_dst, i_buf);
|
|
2110
|
+
|
|
2111
|
+
ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_tmp);
|
|
2112
|
+
|
|
2113
|
+
ggml_tensor * node_red = get_node_aux(node_dst);
|
|
2114
|
+
node_red->view_src = node_dst->view_src == nullptr ? node_dst : node_dst->view_src;
|
|
2115
|
+
node_red->view_offs = node_dst->view_offs;
|
|
2116
|
+
node_red->op = GGML_OP_ADD;
|
|
2117
|
+
node_red->src[0] = node_dst;
|
|
2118
|
+
node_red->src[1] = node_tmp;
|
|
2119
|
+
node_red->flags |= GGML_TENSOR_FLAG_COMPUTE;
|
|
2120
|
+
ggml_backend_view_init(node_red);
|
|
2121
|
+
|
|
2122
|
+
ggml_cgraph * cgraph_aux = get_cgraph_aux();
|
|
2123
|
+
cgraph_aux->nodes[0] = node_red;
|
|
2124
|
+
cgraph_aux->n_nodes = 1;
|
|
2125
|
+
step_cgraphs[j_dst] = cgraph_aux;
|
|
2126
|
+
};
|
|
2127
|
+
|
|
2128
|
+
size_t offset_j = n_backends/2;
|
|
2129
|
+
while ((offset_j & (offset_j - 1)) != 0) {
|
|
2130
|
+
offset_j--;
|
|
2131
|
+
}
|
|
2132
|
+
const size_t offset_j_max = offset_j;
|
|
2133
|
+
size_t i_buf = 0;
|
|
2134
|
+
|
|
2135
|
+
// If n_backends is not a power of 2, fold in the excess prior to butterfly reduction:
|
|
2136
|
+
for (size_t j_src = 2*offset_j_max; j_src < n_backends; j_src++) {
|
|
2137
|
+
const size_t j_dst = j_src - 2*offset_j_max;
|
|
2138
|
+
push_data(j_src, j_dst, i_buf);
|
|
2139
|
+
const ggml_status status = ggml_backend_graph_compute_async(backend_ctx->backend_configs[j_dst].backend, step_cgraphs[j_dst]);
|
|
2140
|
+
if (status != GGML_STATUS_SUCCESS) {
|
|
2141
|
+
return status;
|
|
2142
|
+
}
|
|
2143
|
+
i_buf = 1;
|
|
2144
|
+
}
|
|
2145
|
+
|
|
2146
|
+
// Butterfly reduction:
|
|
2147
|
+
for (; offset_j >= 1; offset_j /= 2) {
|
|
2148
|
+
std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr);
|
|
2149
|
+
|
|
2150
|
+
for (size_t j = 0; j < 2*offset_j_max; j++) {
|
|
2151
|
+
const size_t j_other = j ^ offset_j;
|
|
2152
|
+
if (j_other >= n_backends) {
|
|
2153
|
+
continue;
|
|
2154
|
+
}
|
|
2155
|
+
push_data(j, j_other, i_buf);
|
|
2156
|
+
}
|
|
2157
|
+
|
|
2158
|
+
for (size_t j = 0; j < 2*offset_j_max; j++) {
|
|
2159
|
+
if (step_cgraphs[j] == nullptr) {
|
|
2160
|
+
continue;
|
|
2161
|
+
}
|
|
2162
|
+
auto & bcj = backend_ctx->backend_configs[j];
|
|
2163
|
+
const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]);
|
|
2164
|
+
if (status != GGML_STATUS_SUCCESS) {
|
|
2165
|
+
return status;
|
|
2166
|
+
}
|
|
2167
|
+
}
|
|
2168
|
+
i_buf++;
|
|
2169
|
+
}
|
|
2170
|
+
assert(i_buf == backend_ctx->n_reduce_steps);
|
|
2171
|
+
|
|
2172
|
+
// If n_backends is not a power of 2, copy back the reduced tensors to the excess:
|
|
2173
|
+
for (size_t j = 2*offset_j_max; j < n_backends; j++) {
|
|
2174
|
+
auto & bcj_src = backend_ctx->backend_configs[j - 2*offset_j_max];
|
|
2175
|
+
auto & bcj_dst = backend_ctx->backend_configs[j];
|
|
2176
|
+
|
|
2177
|
+
ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1];
|
|
2178
|
+
ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1];
|
|
2179
|
+
ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_dst);
|
|
2180
|
+
}
|
|
2181
|
+
|
|
2182
|
+
return GGML_STATUS_SUCCESS;
|
|
2183
|
+
};
|
|
2184
|
+
|
|
2185
|
+
|
|
2186
|
+
for (size_t i = 0; i < backend_ctx->n_subgraphs; i++) {
|
|
2187
|
+
for (size_t j = 0; j < n_backends; j++) {
|
|
2188
|
+
auto & bcj = backend_ctx->backend_configs[j];
|
|
2189
|
+
const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, bcj.cgraphs[i].cgraph_main);
|
|
2190
|
+
if (status != GGML_STATUS_SUCCESS) {
|
|
2191
|
+
return status;
|
|
2192
|
+
}
|
|
2193
|
+
}
|
|
2194
|
+
|
|
2195
|
+
if (n_backends > 1 && i < backend_ctx->n_subgraphs - 1) {
|
|
2196
|
+
bool backend_allreduce_success = false;
|
|
2197
|
+
if (backend_ctx->comm_ctx) {
|
|
2198
|
+
std::vector<ggml_tensor *> nodes;
|
|
2199
|
+
nodes.reserve(n_backends);
|
|
2200
|
+
for (size_t j = 0; j < n_backends; j++) {
|
|
2201
|
+
auto & bcj = backend_ctx->backend_configs[j];
|
|
2202
|
+
ggml_cgraph * cgraph_ij = bcj.cgraphs[i].cgraph_main;
|
|
2203
|
+
nodes.push_back(cgraph_ij->nodes[cgraph_ij->n_nodes-1]);
|
|
2204
|
+
}
|
|
2205
|
+
backend_allreduce_success = backend_ctx->comm_allreduce(backend_ctx->comm_ctx, nodes.data());
|
|
2206
|
+
}
|
|
2207
|
+
|
|
2208
|
+
if (!backend_allreduce_success) {
|
|
2209
|
+
const ggml_status status = allreduce_fallback(i);
|
|
2210
|
+
if (status != GGML_STATUS_SUCCESS) {
|
|
2211
|
+
return status;
|
|
2212
|
+
}
|
|
2213
|
+
}
|
|
2214
|
+
}
|
|
2215
|
+
}
|
|
2216
|
+
return GGML_STATUS_SUCCESS;
|
|
2217
|
+
}
|
|
2218
|
+
|
|
2219
|
+
// Interface (vtable) for the meta backend, which fans work out to several
// simple backends. Entries are positional and must follow the field order of
// ggml_backend_i.
// The meta backend implements these entries:
//   - name and free
//   - asynchronous tensor set/get
//   - synchronize
//   - graph_compute
// All other entries (2D tensor transfers, cross-backend copy, graph plans,
// events, graph optimization) are nullptr, i.e. not provided by this backend.
static const ggml_backend_i ggml_backend_meta_i = {
    /* .get_name                = */ ggml_backend_meta_get_name,
    /* .free                    = */ ggml_backend_meta_free,
    /* .set_tensor_async        = */ ggml_backend_meta_set_tensor_async,
    /* .get_tensor_async        = */ ggml_backend_meta_get_tensor_async,
    /* .set_tensor_2d_async     = */ nullptr,
    /* .get_tensor_2d_async     = */ nullptr,
    /* .cpy_tensor_async        = */ nullptr,
    /* .synchronize             = */ ggml_backend_meta_synchronize,
    /* .graph_plan_create       = */ nullptr,
    /* .graph_plan_free         = */ nullptr,
    /* .graph_plan_update       = */ nullptr,
    /* .graph_plan_compute      = */ nullptr,
    /* .graph_compute           = */ ggml_backend_meta_graph_compute,
    /* .event_record            = */ nullptr,
    /* .event_wait              = */ nullptr,
    /* .graph_optimize          = */ nullptr,
};
|
|
2237
|
+
|
|
2238
|
+
// Returns true when `backend` is non-null and is a meta backend.
//
// Meta backends are tagged with ggml_backend_meta_guid() when they are created
// (see ggml_backend_meta_device_init_backend). This function identifies them by
// that GUID rather than by comparing interface function pointers. The GUID is
// the identity field ggml_backend carries for this purpose. Function-pointer
// identity is an implementation detail of how the interface table was filled in.
bool ggml_backend_is_meta(ggml_backend_t backend) {
    return backend != nullptr && ggml_guid_matches(backend->guid, ggml_backend_meta_guid());
}
|
|
2241
|
+
|
|
2242
|
+
// Creates a meta backend on device `dev`.
//
// A new ggml_backend_meta_context is built from `dev` and the opaque `params`
// string and stored as the backend's context. The backend is tagged with the
// meta GUID and the meta interface table. The context is presumably released by
// the interface's free callback (ggml_backend_meta_free) — confirm.
static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params) {
    ggml_backend_meta_context * ctx = new ggml_backend_meta_context(dev, params);

    return new ggml_backend {
        /* .guid    = */ ggml_backend_meta_guid(),
        /* .iface   = */ ggml_backend_meta_i,
        /* .device  = */ dev,
        /* .context = */ ctx,
    };
}
|
|
2252
|
+
|
|
2253
|
+
// Returns the number of simple backends wrapped by `meta_backend`, i.e. the
// number of entries in its backend_configs.
// `meta_backend` must be a meta backend; otherwise a GGML_ASSERT fails.
size_t ggml_backend_meta_n_backends(ggml_backend_t meta_backend) {
    GGML_ASSERT(ggml_backend_is_meta(meta_backend));

    const auto * ctx = static_cast<const ggml_backend_meta_context *>(meta_backend->context);
    const size_t n_configs = ctx->backend_configs.size();
    return n_configs;
}
|
|
2258
|
+
|
|
2259
|
+
// Returns the simple backend at position `index` inside `meta_backend`.
//
// Preconditions (both enforced with GGML_ASSERT):
//   - `meta_backend` is a meta backend;
//   - `index` < ggml_backend_meta_n_backends(meta_backend).
// Without the bounds check, an out-of-range index would index
// backend_configs out of bounds, which is undefined behaviour.
ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index) {
    GGML_ASSERT(ggml_backend_is_meta(meta_backend));
    const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context;
    GGML_ASSERT(index < backend_ctx->backend_configs.size());
    return backend_ctx->backend_configs[index].backend;
}
|