whispercpp 1.3.5 → 1.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/LICENSE +1 -1
- data/README.md +133 -3
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -7
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +56 -46
- data/ext/ruby_whisper.h +165 -2
- data/ext/ruby_whisper_context.c +297 -126
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -66
- data/ext/ruby_whisper_segment.c +6 -7
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +46 -16
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +24 -19
- data/ext/sources/examples/cli/cli.cpp +51 -9
- data/ext/sources/examples/common-ggml.cpp +4 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +213 -163
- data/ext/sources/ggml/CMakeLists.txt +29 -15
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +73 -11
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +8 -3
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +155 -16
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +25 -5
- data/ext/sources/ggml/src/ggml-alloc.c +9 -10
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
- data/ext/sources/ggml/src/ggml-common.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
- data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
- data/ext/sources/ggml/src/ggml-impl.h +68 -1
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +385 -119
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
- data/ext/sources/ggml/src/ggml.c +268 -52
- data/ext/sources/ggml/src/gguf.cpp +377 -47
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +62 -40
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +445 -55
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_context_params.rb +82 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +44 -6
- data/whispercpp.gemspec +2 -2
- metadata +426 -280
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
- data/ext/sources/examples/talk-llama/llama-context.h +0 -360
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
- data/ext/sources/examples/talk-llama/llama-model.h +0 -544
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
- data/ext/sources/examples/talk-llama/llama.h +0 -1540
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -569
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
|
@@ -203,6 +203,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
203
203
|
GGML_ABORT("unsupported op");
|
|
204
204
|
}
|
|
205
205
|
|
|
206
|
+
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
|
207
|
+
return 1;
|
|
208
|
+
}
|
|
209
|
+
|
|
206
210
|
int n_fuse = 1;
|
|
207
211
|
|
|
208
212
|
// check if the current node can run concurrently with other nodes before it
|
|
@@ -283,17 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
283
287
|
n_fuse = ggml_metal_op_acc(ctx, idx);
|
|
284
288
|
} break;
|
|
285
289
|
case GGML_OP_SCALE:
|
|
286
|
-
{
|
|
287
|
-
n_fuse = ggml_metal_op_scale(ctx, idx);
|
|
288
|
-
} break;
|
|
289
290
|
case GGML_OP_FILL:
|
|
290
|
-
{
|
|
291
|
-
n_fuse = ggml_metal_op_fill(ctx, idx);
|
|
292
|
-
} break;
|
|
293
291
|
case GGML_OP_CLAMP:
|
|
294
|
-
|
|
295
|
-
n_fuse = ggml_metal_op_clamp(ctx, idx);
|
|
296
|
-
} break;
|
|
292
|
+
case GGML_OP_LEAKY_RELU:
|
|
297
293
|
case GGML_OP_SQR:
|
|
298
294
|
case GGML_OP_SQRT:
|
|
299
295
|
case GGML_OP_SIN:
|
|
@@ -337,6 +333,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
337
333
|
{
|
|
338
334
|
n_fuse = ggml_metal_op_rwkv(ctx, idx);
|
|
339
335
|
} break;
|
|
336
|
+
case GGML_OP_GATED_DELTA_NET:
|
|
337
|
+
{
|
|
338
|
+
n_fuse = ggml_metal_op_gated_delta_net(ctx, idx);
|
|
339
|
+
} break;
|
|
340
|
+
case GGML_OP_SOLVE_TRI:
|
|
341
|
+
{
|
|
342
|
+
n_fuse = ggml_metal_op_solve_tri(ctx, idx);
|
|
343
|
+
} break;
|
|
340
344
|
case GGML_OP_MUL_MAT:
|
|
341
345
|
{
|
|
342
346
|
n_fuse = ggml_metal_op_mul_mat(ctx, idx);
|
|
@@ -353,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
353
357
|
{
|
|
354
358
|
n_fuse = ggml_metal_op_set_rows(ctx, idx);
|
|
355
359
|
} break;
|
|
360
|
+
case GGML_OP_DIAG:
|
|
361
|
+
{
|
|
362
|
+
n_fuse = ggml_metal_op_diag(ctx, idx);
|
|
363
|
+
} break;
|
|
356
364
|
case GGML_OP_L2_NORM:
|
|
357
365
|
{
|
|
358
366
|
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
|
|
@@ -386,6 +394,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
386
394
|
{
|
|
387
395
|
n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
|
|
388
396
|
} break;
|
|
397
|
+
case GGML_OP_CONV_3D:
|
|
398
|
+
{
|
|
399
|
+
n_fuse = ggml_metal_op_conv_3d(ctx, idx);
|
|
400
|
+
} break;
|
|
389
401
|
case GGML_OP_UPSCALE:
|
|
390
402
|
{
|
|
391
403
|
n_fuse = ggml_metal_op_upscale(ctx, idx);
|
|
@@ -398,6 +410,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
398
410
|
{
|
|
399
411
|
n_fuse = ggml_metal_op_pad_reflect_1d(ctx, idx);
|
|
400
412
|
} break;
|
|
413
|
+
case GGML_OP_ROLL:
|
|
414
|
+
{
|
|
415
|
+
n_fuse = ggml_metal_op_roll(ctx, idx);
|
|
416
|
+
} break;
|
|
401
417
|
case GGML_OP_ARANGE:
|
|
402
418
|
{
|
|
403
419
|
n_fuse = ggml_metal_op_arange(ctx, idx);
|
|
@@ -414,10 +430,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
414
430
|
{
|
|
415
431
|
n_fuse = ggml_metal_op_top_k(ctx, idx);
|
|
416
432
|
} break;
|
|
417
|
-
case GGML_OP_LEAKY_RELU:
|
|
418
|
-
{
|
|
419
|
-
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
|
|
420
|
-
} break;
|
|
421
433
|
case GGML_OP_TRI:
|
|
422
434
|
{
|
|
423
435
|
n_fuse = ggml_metal_op_tri(ctx, idx);
|
|
@@ -426,12 +438,20 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
426
438
|
{
|
|
427
439
|
n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
|
|
428
440
|
} break;
|
|
441
|
+
case GGML_OP_SET:
|
|
442
|
+
{
|
|
443
|
+
n_fuse = ggml_metal_op_set(ctx, idx);
|
|
444
|
+
} break;
|
|
429
445
|
case GGML_OP_DUP:
|
|
430
446
|
case GGML_OP_CPY:
|
|
431
447
|
case GGML_OP_CONT:
|
|
432
448
|
{
|
|
433
449
|
n_fuse = ggml_metal_op_cpy(ctx, idx);
|
|
434
450
|
} break;
|
|
451
|
+
case GGML_OP_POOL_1D:
|
|
452
|
+
{
|
|
453
|
+
n_fuse = ggml_metal_op_pool_1d(ctx, idx);
|
|
454
|
+
} break;
|
|
435
455
|
case GGML_OP_POOL_2D:
|
|
436
456
|
{
|
|
437
457
|
n_fuse = ggml_metal_op_pool_2d(ctx, idx);
|
|
@@ -544,9 +564,20 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
|
|
|
544
564
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
545
565
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
|
546
566
|
|
|
547
|
-
|
|
567
|
+
int nth = std::min(256, ne0);
|
|
548
568
|
|
|
549
|
-
|
|
569
|
+
// when rows are small, we can batch them together in a single threadgroup
|
|
570
|
+
int nrptg = 1;
|
|
571
|
+
if (nth < 256) {
|
|
572
|
+
nrptg = std::min((256 + nth - 1) / nth, ne1);
|
|
573
|
+
if (nrptg * nth > 256) {
|
|
574
|
+
nrptg = 256 / nth;
|
|
575
|
+
}
|
|
576
|
+
}
|
|
577
|
+
|
|
578
|
+
const int nw0 = (ne1 + nrptg - 1) / nrptg;
|
|
579
|
+
|
|
580
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, nw0, ne2, ne3, nth, nrptg, 1);
|
|
550
581
|
|
|
551
582
|
return 1;
|
|
552
583
|
}
|
|
@@ -612,8 +643,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
612
643
|
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
|
613
644
|
GGML_ASSERT(op->type == GGML_TYPE_F32);
|
|
614
645
|
|
|
615
|
-
GGML_ASSERT(
|
|
616
|
-
GGML_ASSERT(
|
|
646
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
647
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
|
|
617
648
|
|
|
618
649
|
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
|
|
619
650
|
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
|
|
@@ -623,7 +654,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
623
654
|
const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
|
|
624
655
|
|
|
625
656
|
if (!inplace) {
|
|
626
|
-
// run a
|
|
657
|
+
// run a separate kernel to cpy src->dst
|
|
627
658
|
// not sure how to avoid this
|
|
628
659
|
// TODO: make a simpler cpy_bytes kernel
|
|
629
660
|
|
|
@@ -663,10 +694,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
663
694
|
}
|
|
664
695
|
|
|
665
696
|
ggml_metal_kargs_bin args = {
|
|
666
|
-
/*.ne00 =*/
|
|
667
|
-
/*.ne01 =*/
|
|
668
|
-
/*.ne02 =*/
|
|
669
|
-
/*.ne03 =*/
|
|
697
|
+
/*.ne00 =*/ ne10,
|
|
698
|
+
/*.ne01 =*/ ne11,
|
|
699
|
+
/*.ne02 =*/ ne12,
|
|
700
|
+
/*.ne03 =*/ ne13,
|
|
670
701
|
/*.nb00 =*/ nb00,
|
|
671
702
|
/*.nb01 =*/ pnb1,
|
|
672
703
|
/*.nb02 =*/ pnb2,
|
|
@@ -679,10 +710,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
679
710
|
/*.nb11 =*/ nb11,
|
|
680
711
|
/*.nb12 =*/ nb12,
|
|
681
712
|
/*.nb13 =*/ nb13,
|
|
682
|
-
/*.ne0 =*/
|
|
683
|
-
/*.ne1 =*/
|
|
684
|
-
/*.ne2 =*/
|
|
685
|
-
/*.ne3 =*/
|
|
713
|
+
/*.ne0 =*/ ne10,
|
|
714
|
+
/*.ne1 =*/ ne11,
|
|
715
|
+
/*.ne2 =*/ ne12,
|
|
716
|
+
/*.ne3 =*/ ne13,
|
|
686
717
|
/*.nb0 =*/ nb0,
|
|
687
718
|
/*.nb1 =*/ pnb1,
|
|
688
719
|
/*.nb2 =*/ pnb2,
|
|
@@ -691,7 +722,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
691
722
|
/*.o1 =*/ { 0 },
|
|
692
723
|
};
|
|
693
724
|
|
|
694
|
-
auto pipeline =
|
|
725
|
+
auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD);
|
|
695
726
|
|
|
696
727
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
697
728
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -699,53 +730,20 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
699
730
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
700
731
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
|
701
732
|
|
|
702
|
-
const int
|
|
703
|
-
|
|
704
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
|
|
733
|
+
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
705
734
|
|
|
706
|
-
|
|
707
|
-
}
|
|
708
|
-
|
|
709
|
-
int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
|
|
710
|
-
ggml_tensor * op = ctx->node(idx);
|
|
711
|
-
|
|
712
|
-
ggml_metal_library_t lib = ctx->lib;
|
|
713
|
-
ggml_metal_encoder_t enc = ctx->enc;
|
|
714
|
-
|
|
715
|
-
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
716
|
-
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
717
|
-
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
718
|
-
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
719
|
-
|
|
720
|
-
float scale;
|
|
721
|
-
float bias;
|
|
722
|
-
memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float));
|
|
723
|
-
memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float));
|
|
724
|
-
|
|
725
|
-
ggml_metal_kargs_scale args = {
|
|
726
|
-
/*.scale =*/ scale,
|
|
727
|
-
/*.bias =*/ bias,
|
|
728
|
-
};
|
|
729
|
-
|
|
730
|
-
int64_t n = ggml_nelements(op);
|
|
735
|
+
int nth = 1;
|
|
731
736
|
|
|
732
|
-
|
|
733
|
-
|
|
737
|
+
while (2*nth < args.ne0 && nth < nth_max) {
|
|
738
|
+
nth *= 2;
|
|
734
739
|
}
|
|
735
740
|
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
739
|
-
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
740
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
741
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
742
|
-
|
|
743
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
741
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
|
|
744
742
|
|
|
745
743
|
return 1;
|
|
746
744
|
}
|
|
747
745
|
|
|
748
|
-
int
|
|
746
|
+
int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
|
749
747
|
ggml_tensor * op = ctx->node(idx);
|
|
750
748
|
|
|
751
749
|
ggml_metal_library_t lib = ctx->lib;
|
|
@@ -756,94 +754,85 @@ int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
|
|
|
756
754
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
757
755
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
758
756
|
|
|
759
|
-
|
|
757
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
760
758
|
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
};
|
|
759
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
760
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
764
761
|
|
|
765
|
-
|
|
762
|
+
ggml_metal_kargs_unary args = {
|
|
763
|
+
/*.ne00 =*/ ne00,
|
|
764
|
+
/*.ne01 =*/ ne01,
|
|
765
|
+
/*.ne02 =*/ ne02,
|
|
766
|
+
/*.ne03 =*/ ne03,
|
|
767
|
+
/*.nb00 =*/ nb00,
|
|
768
|
+
/*.nb01 =*/ nb01,
|
|
769
|
+
/*.nb02 =*/ nb02,
|
|
770
|
+
/*.nb03 =*/ nb03,
|
|
771
|
+
/*.ne0 =*/ ne0,
|
|
772
|
+
/*.ne1 =*/ ne1,
|
|
773
|
+
/*.ne2 =*/ ne2,
|
|
774
|
+
/*.ne3 =*/ ne3,
|
|
775
|
+
/*.nb0 =*/ nb0,
|
|
776
|
+
/*.nb1 =*/ nb1,
|
|
777
|
+
/*.nb2 =*/ nb2,
|
|
778
|
+
/*.nb3 =*/ nb3,
|
|
779
|
+
/*.slope =*/ 0.0,
|
|
780
|
+
/*.scale =*/ 0.0,
|
|
781
|
+
/*.bias =*/ 0.0,
|
|
782
|
+
/*.val =*/ 0.0,
|
|
783
|
+
/*.min =*/ 0.0,
|
|
784
|
+
/*.max =*/ 0.0,
|
|
785
|
+
};
|
|
766
786
|
|
|
767
|
-
if (
|
|
768
|
-
|
|
787
|
+
if (op->op == GGML_OP_LEAKY_RELU) {
|
|
788
|
+
args.slope = ggml_get_op_params_f32(op, 0);
|
|
769
789
|
}
|
|
770
790
|
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
776
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
777
|
-
|
|
778
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
779
|
-
|
|
780
|
-
return 1;
|
|
781
|
-
}
|
|
782
|
-
|
|
783
|
-
int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
|
|
784
|
-
ggml_tensor * op = ctx->node(idx);
|
|
785
|
-
|
|
786
|
-
ggml_metal_library_t lib = ctx->lib;
|
|
787
|
-
ggml_metal_encoder_t enc = ctx->enc;
|
|
788
|
-
|
|
789
|
-
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
790
|
-
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
791
|
-
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
792
|
-
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
793
|
-
|
|
794
|
-
float min;
|
|
795
|
-
float max;
|
|
796
|
-
memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float));
|
|
797
|
-
memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float));
|
|
791
|
+
if (op->op == GGML_OP_SCALE) {
|
|
792
|
+
args.scale = ggml_get_op_params_f32(op, 0);
|
|
793
|
+
args.bias = ggml_get_op_params_f32(op, 1);
|
|
794
|
+
}
|
|
798
795
|
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
};
|
|
796
|
+
if (op->op == GGML_OP_FILL) {
|
|
797
|
+
args.val = ggml_get_op_params_f32(op, 0);
|
|
798
|
+
}
|
|
803
799
|
|
|
804
|
-
|
|
800
|
+
if (op->op == GGML_OP_CLAMP) {
|
|
801
|
+
args.min = ggml_get_op_params_f32(op, 0);
|
|
802
|
+
args.max = ggml_get_op_params_f32(op, 1);
|
|
803
|
+
}
|
|
805
804
|
|
|
806
|
-
if (
|
|
807
|
-
|
|
805
|
+
if (op->op == GGML_OP_UNARY && ggml_get_unary_op(op) == GGML_UNARY_OP_XIELU) {
|
|
806
|
+
args.slope = ggml_get_op_params_f32(op, 1); // alpha_n
|
|
807
|
+
args.scale = ggml_get_op_params_f32(op, 2); // alpha_p
|
|
808
|
+
args.bias = ggml_get_op_params_f32(op, 3); // beta
|
|
809
|
+
args.val = ggml_get_op_params_f32(op, 4); // eps
|
|
808
810
|
}
|
|
809
811
|
|
|
810
812
|
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
|
811
813
|
|
|
814
|
+
if (pipeline.c4) {
|
|
815
|
+
args.ne00 = ne00/4;
|
|
816
|
+
args.ne0 = ne0/4;
|
|
817
|
+
}
|
|
818
|
+
|
|
812
819
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
813
820
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
814
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
815
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
816
|
-
|
|
817
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
818
|
-
|
|
819
|
-
return 1;
|
|
820
|
-
}
|
|
821
|
-
|
|
822
|
-
int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
|
823
|
-
ggml_tensor * op = ctx->node(idx);
|
|
824
|
-
|
|
825
|
-
ggml_metal_library_t lib = ctx->lib;
|
|
826
|
-
ggml_metal_encoder_t enc = ctx->enc;
|
|
821
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
822
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
827
823
|
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
831
|
-
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
824
|
+
if (pipeline.cnt) {
|
|
825
|
+
const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
|
|
832
826
|
|
|
833
|
-
|
|
827
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
828
|
+
} else {
|
|
829
|
+
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
830
|
+
const int nth = MIN(args.ne00, nth_max);
|
|
831
|
+
const int nk0 = (args.ne00 + nth - 1)/nth;
|
|
834
832
|
|
|
835
|
-
|
|
836
|
-
n /= 4;
|
|
833
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
|
|
837
834
|
}
|
|
838
835
|
|
|
839
|
-
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
|
840
|
-
|
|
841
|
-
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
842
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
|
|
843
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
|
|
844
|
-
|
|
845
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
846
|
-
|
|
847
836
|
return 1;
|
|
848
837
|
}
|
|
849
838
|
|
|
@@ -953,6 +942,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
|
|
953
942
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
954
943
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
955
944
|
|
|
945
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
946
|
+
|
|
947
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
948
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
949
|
+
|
|
956
950
|
ggml_metal_kargs_sum_rows args = {
|
|
957
951
|
/*.ne00 =*/ ne00,
|
|
958
952
|
/*.ne01 =*/ ne01,
|
|
@@ -974,21 +968,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
|
|
974
968
|
|
|
975
969
|
auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
|
|
976
970
|
|
|
971
|
+
if (pipeline.c4) {
|
|
972
|
+
args.ne00 = ne00/4;
|
|
973
|
+
args.ne0 = ne0/4;
|
|
974
|
+
}
|
|
975
|
+
|
|
977
976
|
int nth = 32; // SIMD width
|
|
978
977
|
|
|
979
|
-
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
978
|
+
while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
980
979
|
nth *= 2;
|
|
981
980
|
}
|
|
982
981
|
|
|
983
982
|
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
984
|
-
nth = std::min(nth, ne00);
|
|
983
|
+
nth = std::min(nth, (int) args.ne00);
|
|
985
984
|
|
|
986
985
|
const size_t smem = pipeline.smem;
|
|
987
986
|
|
|
988
987
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
989
988
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
990
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
991
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
989
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
990
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
992
991
|
|
|
993
992
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
994
993
|
|
|
@@ -1247,6 +1246,48 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
|
|
|
1247
1246
|
return 1;
|
|
1248
1247
|
}
|
|
1249
1248
|
|
|
1249
|
+
int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) {
|
|
1250
|
+
ggml_tensor * op = ctx->node(idx);
|
|
1251
|
+
|
|
1252
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
1253
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
1254
|
+
|
|
1255
|
+
GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
|
|
1256
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1257
|
+
GGML_TENSOR_LOCALS(int32_t, ne, op, ne);
|
|
1258
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1259
|
+
|
|
1260
|
+
ggml_metal_kargs_diag args = {
|
|
1261
|
+
/*.ne00 =*/ne00,
|
|
1262
|
+
/*.ne01 =*/ne01,
|
|
1263
|
+
/*.ne02 =*/ne02,
|
|
1264
|
+
/*.ne03 =*/ne03,
|
|
1265
|
+
/*.nb00 =*/nb00,
|
|
1266
|
+
/*.nb01 =*/nb01,
|
|
1267
|
+
/*.nb02 =*/nb02,
|
|
1268
|
+
/*.nb03 =*/nb03,
|
|
1269
|
+
/*.ne0 =*/ne0,
|
|
1270
|
+
/*.ne1 =*/ne1,
|
|
1271
|
+
/*.ne2 =*/ne2,
|
|
1272
|
+
/*.ne3 =*/ne3,
|
|
1273
|
+
/*.nb0 =*/nb0,
|
|
1274
|
+
/*.nb1 =*/nb1,
|
|
1275
|
+
/*.nb2 =*/nb2,
|
|
1276
|
+
/*.nb3 =*/nb3,
|
|
1277
|
+
};
|
|
1278
|
+
|
|
1279
|
+
auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op);
|
|
1280
|
+
|
|
1281
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1282
|
+
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
|
1283
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1284
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 2);
|
|
1285
|
+
|
|
1286
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1);
|
|
1287
|
+
|
|
1288
|
+
return 1;
|
|
1289
|
+
}
|
|
1290
|
+
|
|
1250
1291
|
int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
|
1251
1292
|
ggml_tensor * op = ctx->node(idx);
|
|
1252
1293
|
|
|
@@ -1524,27 +1565,287 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
|
|
|
1524
1565
|
const int64_t C = op->ne[0];
|
|
1525
1566
|
const int64_t H = op->src[0]->ne[1];
|
|
1526
1567
|
|
|
1527
|
-
auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
|
|
1568
|
+
auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
|
|
1569
|
+
|
|
1570
|
+
int ida = 0;
|
|
1571
|
+
|
|
1572
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1573
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
|
|
1574
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
|
|
1575
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
|
|
1576
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
|
|
1577
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
|
|
1578
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++);
|
|
1579
|
+
if (op->op == GGML_OP_RWKV_WKV7) {
|
|
1580
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), ida++);
|
|
1581
|
+
}
|
|
1582
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++);
|
|
1583
|
+
ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++);
|
|
1584
|
+
ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++);
|
|
1585
|
+
ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++);
|
|
1586
|
+
ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++);
|
|
1587
|
+
|
|
1588
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1);
|
|
1589
|
+
|
|
1590
|
+
return 1;
|
|
1591
|
+
}
|
|
1592
|
+
|
|
1593
|
+
int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) {
|
|
1594
|
+
ggml_tensor * op = ctx->node(idx);
|
|
1595
|
+
|
|
1596
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
1597
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
1598
|
+
|
|
1599
|
+
|
|
1600
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1601
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1602
|
+
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
1603
|
+
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
1604
|
+
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
|
1605
|
+
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
|
1606
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1607
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1608
|
+
|
|
1609
|
+
auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op);
|
|
1610
|
+
|
|
1611
|
+
int ida = 0;
|
|
1612
|
+
|
|
1613
|
+
ggml_metal_kargs_gated_delta_net args = {
|
|
1614
|
+
/*.ne00 =*/ ne00,
|
|
1615
|
+
/*.ne01 =*/ ne01,
|
|
1616
|
+
/*.ne02 =*/ ne02,
|
|
1617
|
+
/*.ne03 =*/ ne03,
|
|
1618
|
+
/*.nb00 =*/ nb00,
|
|
1619
|
+
/*.nb01 =*/ nb01,
|
|
1620
|
+
/*.nb02 =*/ nb02,
|
|
1621
|
+
/*.nb03 =*/ nb03,
|
|
1622
|
+
/*.ne10 =*/ ne10,
|
|
1623
|
+
/*.ne11 =*/ ne11,
|
|
1624
|
+
/*.ne12 =*/ ne12,
|
|
1625
|
+
/*.ne13 =*/ ne13,
|
|
1626
|
+
/*.nb10 =*/ nb10,
|
|
1627
|
+
/*.nb11 =*/ nb11,
|
|
1628
|
+
/*.nb12 =*/ nb12,
|
|
1629
|
+
/*.nb13 =*/ nb13,
|
|
1630
|
+
/*.ne20 =*/ ne20,
|
|
1631
|
+
/*.ne21 =*/ ne21,
|
|
1632
|
+
/*.ne22 =*/ ne22,
|
|
1633
|
+
/*.ne23 =*/ ne23,
|
|
1634
|
+
/*.nb20 =*/ nb20,
|
|
1635
|
+
/*.nb21 =*/ nb21,
|
|
1636
|
+
/*.nb22 =*/ nb22,
|
|
1637
|
+
/*.nb23 =*/ nb23,
|
|
1638
|
+
/*.ns02 =*/ (int32_t) (nb02/sizeof(float)),
|
|
1639
|
+
/*.ns12 =*/ (int32_t) (nb12/sizeof(float)),
|
|
1640
|
+
/*.ns22 =*/ (int32_t) (nb22/sizeof(float)),
|
|
1641
|
+
/*.ne0 =*/ ne0,
|
|
1642
|
+
/*.ne1 =*/ ne1,
|
|
1643
|
+
/*.ne2 =*/ ne2,
|
|
1644
|
+
/*.ne3 =*/ ne3,
|
|
1645
|
+
/*.nb0 =*/ nb0,
|
|
1646
|
+
/*.nb1 =*/ nb1,
|
|
1647
|
+
/*.nb2 =*/ nb2,
|
|
1648
|
+
/*.nb3 =*/ nb3,
|
|
1649
|
+
};
|
|
1650
|
+
|
|
1651
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1652
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
|
|
1653
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q
|
|
1654
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k
|
|
1655
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v
|
|
1656
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate
|
|
1657
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta
|
|
1658
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state
|
|
1659
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); // dst
|
|
1660
|
+
|
|
1661
|
+
const int nsg = pipeline.nsg;
|
|
1662
|
+
|
|
1663
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1);
|
|
1664
|
+
|
|
1665
|
+
return 1;
|
|
1666
|
+
}
|
|
1667
|
+
|
|
1668
|
+
int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
|
|
1669
|
+
ggml_tensor * op = ctx->node(idx);
|
|
1670
|
+
|
|
1671
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
1672
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
1673
|
+
|
|
1674
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1675
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1676
|
+
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
1677
|
+
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
1678
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1679
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1680
|
+
|
|
1681
|
+
ggml_metal_kargs_solve_tri args = {
|
|
1682
|
+
/*.ne00 =*/ ne00,
|
|
1683
|
+
/*.ne01 =*/ ne01,
|
|
1684
|
+
/*.ne02 =*/ ne02,
|
|
1685
|
+
/*.ne03 =*/ ne03,
|
|
1686
|
+
/*.nb00 =*/ nb00,
|
|
1687
|
+
/*.nb01 =*/ nb01,
|
|
1688
|
+
/*.nb02 =*/ nb02,
|
|
1689
|
+
/*.nb03 =*/ nb03,
|
|
1690
|
+
/*.ne10 =*/ ne10,
|
|
1691
|
+
/*.ne11 =*/ ne11,
|
|
1692
|
+
/*.ne12 =*/ ne12,
|
|
1693
|
+
/*.ne13 =*/ ne13,
|
|
1694
|
+
/*.nb10 =*/ nb10,
|
|
1695
|
+
/*.nb11 =*/ nb11,
|
|
1696
|
+
/*.nb12 =*/ nb12,
|
|
1697
|
+
/*.nb13 =*/ nb13,
|
|
1698
|
+
/*.ne0 =*/ ne0,
|
|
1699
|
+
/*.ne1 =*/ ne1,
|
|
1700
|
+
/*.ne2 =*/ ne2,
|
|
1701
|
+
/*.ne3 =*/ ne3,
|
|
1702
|
+
/*.nb0 =*/ nb0,
|
|
1703
|
+
/*.nb1 =*/ nb1,
|
|
1704
|
+
/*.nb2 =*/ nb2,
|
|
1705
|
+
/*.nb3 =*/ nb3,
|
|
1706
|
+
};
|
|
1707
|
+
|
|
1708
|
+
auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
|
|
1709
|
+
|
|
1710
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1711
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1712
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1713
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
1714
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
|
1715
|
+
|
|
1716
|
+
const int nsg = pipeline.nsg;
|
|
1717
|
+
|
|
1718
|
+
ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0);
|
|
1719
|
+
|
|
1720
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1);
|
|
1721
|
+
|
|
1722
|
+
return 1;
|
|
1723
|
+
}
|
|
1724
|
+
|
|
1725
|
+
int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
|
|
1726
|
+
ggml_tensor * op = ctx->node(idx);
|
|
1727
|
+
|
|
1728
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
1729
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
1730
|
+
|
|
1731
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1732
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1733
|
+
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
1734
|
+
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
1735
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1736
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1737
|
+
|
|
1738
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
1739
|
+
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
|
|
1740
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
1741
|
+
|
|
1742
|
+
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
|
|
1743
|
+
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
|
|
1744
|
+
const size_t pnb3 = ((const int32_t *) op->op_params)[2];
|
|
1745
|
+
const size_t offs = ((const int32_t *) op->op_params)[3];
|
|
1746
|
+
|
|
1747
|
+
const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
|
|
1748
|
+
|
|
1749
|
+
if (!inplace) {
|
|
1750
|
+
// run a separate kernel to cpy src->dst
|
|
1751
|
+
// not sure how to avoid this
|
|
1752
|
+
// TODO: make a simpler cpy_bytes kernel
|
|
1753
|
+
|
|
1754
|
+
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
|
|
1755
|
+
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
|
1756
|
+
|
|
1757
|
+
ggml_metal_kargs_cpy args = {
|
|
1758
|
+
/*.nk0 =*/ ne00,
|
|
1759
|
+
/*.ne00 =*/ ne00,
|
|
1760
|
+
/*.ne01 =*/ ne01,
|
|
1761
|
+
/*.ne02 =*/ ne02,
|
|
1762
|
+
/*.ne03 =*/ ne03,
|
|
1763
|
+
/*.nb00 =*/ nb00,
|
|
1764
|
+
/*.nb01 =*/ nb01,
|
|
1765
|
+
/*.nb02 =*/ nb02,
|
|
1766
|
+
/*.nb03 =*/ nb03,
|
|
1767
|
+
/*.ne0 =*/ ne0,
|
|
1768
|
+
/*.ne1 =*/ ne1,
|
|
1769
|
+
/*.ne2 =*/ ne2,
|
|
1770
|
+
/*.ne3 =*/ ne3,
|
|
1771
|
+
/*.nb0 =*/ nb0,
|
|
1772
|
+
/*.nb1 =*/ nb1,
|
|
1773
|
+
/*.nb2 =*/ nb2,
|
|
1774
|
+
/*.nb3 =*/ nb3,
|
|
1775
|
+
};
|
|
1776
|
+
|
|
1777
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1778
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1779
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
1780
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
1781
|
+
|
|
1782
|
+
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
|
|
1783
|
+
|
|
1784
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
|
1785
|
+
|
|
1786
|
+
ggml_metal_op_concurrency_reset(ctx);
|
|
1787
|
+
}
|
|
1788
|
+
|
|
1789
|
+
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
|
|
1790
|
+
|
|
1791
|
+
GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
|
|
1792
|
+
|
|
1793
|
+
int64_t nk0 = ne10;
|
|
1794
|
+
if (ggml_is_quantized(op->src[1]->type)) {
|
|
1795
|
+
nk0 = ne10/16;
|
|
1796
|
+
} else if (ggml_is_quantized(op->type)) {
|
|
1797
|
+
nk0 = ne10/ggml_blck_size(op->type);
|
|
1798
|
+
}
|
|
1799
|
+
|
|
1800
|
+
int nth = std::min<int>(nk0*ne11, 256);
|
|
1801
|
+
|
|
1802
|
+
// when rows are small, we can batch them together in a single threadgroup
|
|
1803
|
+
int nrptg = 1;
|
|
1804
|
+
|
|
1805
|
+
// TODO: relax this constraint in the future
|
|
1806
|
+
if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
|
|
1807
|
+
if (nth > nk0) {
|
|
1808
|
+
nrptg = (nth + nk0 - 1)/nk0;
|
|
1809
|
+
nth = nk0;
|
|
1810
|
+
|
|
1811
|
+
if (nrptg*nth > 256) {
|
|
1812
|
+
nrptg--;
|
|
1813
|
+
}
|
|
1814
|
+
}
|
|
1815
|
+
}
|
|
1816
|
+
|
|
1817
|
+
nth = std::min<int>(nth, nk0);
|
|
1818
|
+
|
|
1819
|
+
ggml_metal_kargs_cpy args = {
|
|
1820
|
+
/*.nk0 =*/ nk0,
|
|
1821
|
+
/*.ne00 =*/ ne10,
|
|
1822
|
+
/*.ne01 =*/ ne11,
|
|
1823
|
+
/*.ne02 =*/ ne12,
|
|
1824
|
+
/*.ne03 =*/ ne13,
|
|
1825
|
+
/*.nb00 =*/ nb10,
|
|
1826
|
+
/*.nb01 =*/ nb11,
|
|
1827
|
+
/*.nb02 =*/ nb12,
|
|
1828
|
+
/*.nb03 =*/ nb13,
|
|
1829
|
+
/*.ne0 =*/ ne10,
|
|
1830
|
+
/*.ne1 =*/ ne11,
|
|
1831
|
+
/*.ne2 =*/ ne12,
|
|
1832
|
+
/*.ne3 =*/ ne13,
|
|
1833
|
+
/*.nb0 =*/ ggml_element_size(op),
|
|
1834
|
+
/*.nb1 =*/ pnb1,
|
|
1835
|
+
/*.nb2 =*/ pnb2,
|
|
1836
|
+
/*.nb3 =*/ pnb3,
|
|
1837
|
+
};
|
|
1838
|
+
|
|
1839
|
+
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
|
|
1528
1840
|
|
|
1529
|
-
|
|
1841
|
+
bid_dst.offs += offs;
|
|
1530
1842
|
|
|
1531
1843
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1532
|
-
|
|
1533
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
1534
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
1535
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
|
|
1536
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
|
|
1537
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++);
|
|
1538
|
-
if (op->op == GGML_OP_RWKV_WKV7) {
|
|
1539
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), ida++);
|
|
1540
|
-
}
|
|
1541
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++);
|
|
1542
|
-
ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++);
|
|
1543
|
-
ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++);
|
|
1544
|
-
ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++);
|
|
1545
|
-
ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++);
|
|
1844
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1845
|
+
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
|
|
1846
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
1546
1847
|
|
|
1547
|
-
ggml_metal_encoder_dispatch_threadgroups(enc,
|
|
1848
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
|
|
1548
1849
|
|
|
1549
1850
|
return 1;
|
|
1550
1851
|
}
|
|
@@ -1571,7 +1872,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
|
|
1571
1872
|
nk0 = ne00/ggml_blck_size(op->type);
|
|
1572
1873
|
}
|
|
1573
1874
|
|
|
1574
|
-
int nth = std::min<int>(nk0,
|
|
1875
|
+
int nth = std::min<int>(nk0*ne01, 256);
|
|
1575
1876
|
|
|
1576
1877
|
// when rows are small, we can batch them together in a single threadgroup
|
|
1577
1878
|
int nrptg = 1;
|
|
@@ -1582,7 +1883,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
|
|
1582
1883
|
nrptg = (nth + nk0 - 1)/nk0;
|
|
1583
1884
|
nth = nk0;
|
|
1584
1885
|
|
|
1585
|
-
if (nrptg*nth >
|
|
1886
|
+
if (nrptg*nth > 256) {
|
|
1586
1887
|
nrptg--;
|
|
1587
1888
|
}
|
|
1588
1889
|
}
|
|
@@ -1622,6 +1923,54 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
|
|
1622
1923
|
return 1;
|
|
1623
1924
|
}
|
|
1624
1925
|
|
|
1926
|
+
int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
|
|
1927
|
+
ggml_tensor * op = ctx->node(idx);
|
|
1928
|
+
|
|
1929
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
1930
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
1931
|
+
|
|
1932
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1933
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1934
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1935
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1936
|
+
|
|
1937
|
+
const int32_t * opts = op->op_params;
|
|
1938
|
+
ggml_op_pool op_pool = (ggml_op_pool) opts[0];
|
|
1939
|
+
|
|
1940
|
+
const int32_t k0 = opts[1];
|
|
1941
|
+
const int32_t s0 = opts[2];
|
|
1942
|
+
const int32_t p0 = opts[3];
|
|
1943
|
+
|
|
1944
|
+
const int64_t IW = op->src[0]->ne[0];
|
|
1945
|
+
const int64_t OW = op->ne[0];
|
|
1946
|
+
|
|
1947
|
+
const int64_t np = ggml_nelements(op);
|
|
1948
|
+
|
|
1949
|
+
ggml_metal_kargs_pool_1d args_pool_1d = {
|
|
1950
|
+
/* .k0 = */ k0,
|
|
1951
|
+
/* .s0 = */ s0,
|
|
1952
|
+
/* .p0 = */ p0,
|
|
1953
|
+
/* .IW = */ IW,
|
|
1954
|
+
/* .OW = */ OW,
|
|
1955
|
+
/* .np = */ np
|
|
1956
|
+
};
|
|
1957
|
+
|
|
1958
|
+
auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);
|
|
1959
|
+
|
|
1960
|
+
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
|
|
1961
|
+
const int ntg = (np + nth - 1) / nth;
|
|
1962
|
+
|
|
1963
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1964
|
+
ggml_metal_encoder_set_bytes (enc, &args_pool_1d, sizeof(args_pool_1d), 0);
|
|
1965
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1966
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
1967
|
+
|
|
1968
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
|
|
1969
|
+
|
|
1970
|
+
return 1;
|
|
1971
|
+
}
|
|
1972
|
+
|
|
1973
|
+
|
|
1625
1974
|
int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
|
|
1626
1975
|
ggml_tensor * op = ctx->node(idx);
|
|
1627
1976
|
|
|
@@ -1717,6 +2066,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1717
2066
|
(
|
|
1718
2067
|
op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
|
|
1719
2068
|
op->src[0]->type == GGML_TYPE_F16 ||
|
|
2069
|
+
op->src[0]->type == GGML_TYPE_BF16 ||
|
|
2070
|
+
op->src[0]->type == GGML_TYPE_Q1_0 ||
|
|
1720
2071
|
op->src[0]->type == GGML_TYPE_Q4_0 ||
|
|
1721
2072
|
op->src[0]->type == GGML_TYPE_Q4_1 ||
|
|
1722
2073
|
op->src[0]->type == GGML_TYPE_Q5_0 ||
|
|
@@ -1731,6 +2082,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1731
2082
|
op->src[0]->type == GGML_TYPE_Q4_K ||
|
|
1732
2083
|
op->src[0]->type == GGML_TYPE_Q5_K ||
|
|
1733
2084
|
op->src[0]->type == GGML_TYPE_Q6_K ||
|
|
2085
|
+
op->src[0]->type == GGML_TYPE_Q2_K ||
|
|
2086
|
+
op->src[0]->type == GGML_TYPE_Q3_K ||
|
|
1734
2087
|
false) && (ne11 >= 4 && ne11 <= 8)
|
|
1735
2088
|
)
|
|
1736
2089
|
)
|
|
@@ -1759,7 +2112,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1759
2112
|
const int16_t r0ptg = nypsg*nsg; // num src0 rows per threadgroup
|
|
1760
2113
|
int16_t r1ptg = 4; // num src1 rows per threadgroup
|
|
1761
2114
|
|
|
1762
|
-
// note: not sure how optimal are those across all different hardware. there might be
|
|
2115
|
+
// note: not sure how optimal are those across all different hardware. there might be something cleverer
|
|
1763
2116
|
switch (ne11) {
|
|
1764
2117
|
case 2:
|
|
1765
2118
|
r1ptg = 2; break;
|
|
@@ -1776,7 +2129,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1776
2129
|
GGML_ABORT("unsupported ne11");
|
|
1777
2130
|
};
|
|
1778
2131
|
|
|
1779
|
-
auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op
|
|
2132
|
+
auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op, nsg, nxpsg, r1ptg);
|
|
1780
2133
|
|
|
1781
2134
|
ggml_metal_kargs_mul_mv_ext args = {
|
|
1782
2135
|
/*.ne00 =*/ ne00,
|
|
@@ -1851,7 +2204,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1851
2204
|
const size_t smem = pipeline.smem;
|
|
1852
2205
|
|
|
1853
2206
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
1854
|
-
|
|
2207
|
+
|
|
2208
|
+
const int nr0 = pipeline.nr0;
|
|
2209
|
+
const int nr1 = pipeline.nr1;
|
|
2210
|
+
const int nsg = pipeline.nsg;
|
|
2211
|
+
|
|
2212
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + nr1 - 1) / nr1), ((ne01 + nr0 - 1) / nr0), ne12 * ne13, 32, nsg, 1);
|
|
1855
2213
|
} else {
|
|
1856
2214
|
auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
|
|
1857
2215
|
|
|
@@ -2239,7 +2597,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
|
|
|
2239
2597
|
// return res;
|
|
2240
2598
|
//}
|
|
2241
2599
|
|
|
2242
|
-
const int nqptg = is_vec ?
|
|
2600
|
+
const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG;
|
|
2243
2601
|
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
|
|
2244
2602
|
|
|
2245
2603
|
const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
|
|
@@ -2355,7 +2713,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2355
2713
|
|
|
2356
2714
|
if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
|
2357
2715
|
// half8x8 kernel
|
|
2358
|
-
const int nqptg =
|
|
2716
|
+
const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup
|
|
2359
2717
|
const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
|
|
2360
2718
|
|
|
2361
2719
|
GGML_ASSERT(nqptg <= 32);
|
|
@@ -2464,7 +2822,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2464
2822
|
|
|
2465
2823
|
// simdgroups per threadgroup (a.k.a. warps)
|
|
2466
2824
|
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
|
2467
|
-
int32_t nsg = 4;
|
|
2825
|
+
int32_t nsg = ne00 >= 512 ? 8 : 4;
|
|
2468
2826
|
|
|
2469
2827
|
const size_t smem = FATTN_SMEM(nsg);
|
|
2470
2828
|
|
|
@@ -2522,9 +2880,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2522
2880
|
#undef FATTN_SMEM
|
|
2523
2881
|
} else {
|
|
2524
2882
|
// half4x4 kernel
|
|
2525
|
-
const int nqptg =
|
|
2883
|
+
const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup
|
|
2526
2884
|
const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
|
|
2527
|
-
const int
|
|
2885
|
+
const int nhptg = 1; // heads per threadgroup
|
|
2528
2886
|
|
|
2529
2887
|
GGML_ASSERT(nqptg <= 32);
|
|
2530
2888
|
GGML_ASSERT(nqptg % 1 == 0);
|
|
@@ -2576,6 +2934,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2576
2934
|
ggml_metal_op_concurrency_reset(ctx);
|
|
2577
2935
|
}
|
|
2578
2936
|
|
|
2937
|
+
// note: for simplicity assume the K is larger or equal than V
|
|
2938
|
+
GGML_ASSERT(ne10 >= ne20);
|
|
2939
|
+
|
|
2579
2940
|
// ne00 + 2*ncpsg*(nsg)
|
|
2580
2941
|
// for each query, we load it as f16 in shared memory (ne00)
|
|
2581
2942
|
// and store the soft_max values and the mask
|
|
@@ -2583,28 +2944,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2583
2944
|
// ne20*(nsg)
|
|
2584
2945
|
// each simdgroup has a full f32 head vector in shared mem to accumulate results
|
|
2585
2946
|
//
|
|
2586
|
-
#define FATTN_SMEM(nsg) (GGML_PAD((
|
|
2587
|
-
|
|
2588
|
-
int64_t nsgmax = 2;
|
|
2589
|
-
while (true) {
|
|
2590
|
-
const size_t smem = FATTN_SMEM(nsgmax);
|
|
2591
|
-
// avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
|
|
2592
|
-
if (smem > props_dev->max_theadgroup_memory_size/2) {
|
|
2593
|
-
break;
|
|
2594
|
-
}
|
|
2595
|
-
nsgmax *= 2;
|
|
2596
|
-
}
|
|
2597
|
-
nsgmax /= 2;
|
|
2598
|
-
|
|
2599
|
-
// simdgroups per threadgroup (a.k.a. warps)
|
|
2600
|
-
//const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
|
|
2601
|
-
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32)));
|
|
2947
|
+
#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16))
|
|
2602
2948
|
|
|
2603
2949
|
int64_t nsg = 1;
|
|
2604
|
-
while (nsg <= nsgt) {
|
|
2605
|
-
nsg *= 2;
|
|
2606
|
-
}
|
|
2607
|
-
nsg /= 2;
|
|
2608
2950
|
|
|
2609
2951
|
// workgroups
|
|
2610
2952
|
// each workgroup handles nsg*nkpsg cache values
|
|
@@ -2617,7 +2959,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2617
2959
|
} else {
|
|
2618
2960
|
nwg = 32;
|
|
2619
2961
|
nsg = 1;
|
|
2620
|
-
while (2*nwg*nsg*
|
|
2962
|
+
while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) {
|
|
2621
2963
|
nsg *= 2;
|
|
2622
2964
|
}
|
|
2623
2965
|
}
|
|
@@ -2683,7 +3025,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2683
3025
|
|
|
2684
3026
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
2685
3027
|
|
|
2686
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
|
|
3028
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
|
|
2687
3029
|
} else {
|
|
2688
3030
|
// sanity checks
|
|
2689
3031
|
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
|
|
@@ -2696,7 +3038,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2696
3038
|
ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
|
|
2697
3039
|
|
|
2698
3040
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
2699
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
|
|
3041
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
|
|
2700
3042
|
|
|
2701
3043
|
// sync the 2 kernels
|
|
2702
3044
|
ggml_metal_op_concurrency_reset(ctx);
|
|
@@ -2748,8 +3090,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|
|
2748
3090
|
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
2749
3091
|
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
|
|
2750
3092
|
|
|
2751
|
-
bool bcast_row = false;
|
|
2752
|
-
|
|
2753
3093
|
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
2754
3094
|
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
|
|
2755
3095
|
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
@@ -2843,18 +3183,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|
|
2843
3183
|
|
|
2844
3184
|
struct ggml_metal_pipeline_with_params pipeline;
|
|
2845
3185
|
|
|
2846
|
-
|
|
2847
|
-
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
2848
|
-
|
|
2849
|
-
// src1 is a row
|
|
2850
|
-
GGML_ASSERT(ne11 == 1);
|
|
2851
|
-
|
|
2852
|
-
pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true);
|
|
2853
|
-
|
|
2854
|
-
bcast_row = true;
|
|
2855
|
-
} else {
|
|
2856
|
-
pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false);
|
|
2857
|
-
}
|
|
3186
|
+
pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse);
|
|
2858
3187
|
|
|
2859
3188
|
if (n_fuse > 1) {
|
|
2860
3189
|
bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
|
|
@@ -2868,20 +3197,26 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|
|
2868
3197
|
}
|
|
2869
3198
|
}
|
|
2870
3199
|
|
|
3200
|
+
if (pipeline.c4) {
|
|
3201
|
+
args.ne00 = ne00/4;
|
|
3202
|
+
args.ne10 = ne10/4;
|
|
3203
|
+
args.ne0 = ne0/4;
|
|
3204
|
+
}
|
|
3205
|
+
|
|
2871
3206
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2872
3207
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
2873
3208
|
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
2874
3209
|
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
|
|
2875
3210
|
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
|
|
2876
3211
|
|
|
2877
|
-
if (
|
|
2878
|
-
|
|
2879
|
-
|
|
2880
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
3212
|
+
if (pipeline.cnt) {
|
|
3213
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, args.ne0, ggml_nrows(op), 1, 1, 1, 1);
|
|
2881
3214
|
} else {
|
|
2882
|
-
int
|
|
3215
|
+
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
2883
3216
|
|
|
2884
|
-
|
|
3217
|
+
int nth = 1;
|
|
3218
|
+
|
|
3219
|
+
while (2*nth < args.ne0 && nth < nth_max) {
|
|
2885
3220
|
nth *= 2;
|
|
2886
3221
|
}
|
|
2887
3222
|
|
|
@@ -2902,39 +3237,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
2902
3237
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2903
3238
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2904
3239
|
|
|
3240
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
3241
|
+
|
|
3242
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
3243
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
3244
|
+
|
|
2905
3245
|
float eps;
|
|
2906
3246
|
memcpy(&eps, op->op_params, sizeof(float));
|
|
2907
3247
|
|
|
2908
|
-
int nth = 32; // SIMD width
|
|
2909
|
-
|
|
2910
3248
|
ggml_metal_kargs_l2_norm args = {
|
|
2911
|
-
/*.ne00
|
|
2912
|
-
/*.
|
|
2913
|
-
/*.
|
|
2914
|
-
/*.
|
|
3249
|
+
/*.ne00 =*/ ne00,
|
|
3250
|
+
/*.ne01 =*/ ne01,
|
|
3251
|
+
/*.ne02 =*/ ne02,
|
|
3252
|
+
/*.ne03 =*/ ne03,
|
|
3253
|
+
/*.nb00 =*/ nb00,
|
|
3254
|
+
/*.nb01 =*/ nb01,
|
|
3255
|
+
/*.nb02 =*/ nb02,
|
|
3256
|
+
/*.nb03 =*/ nb03,
|
|
3257
|
+
/*.ne0 =*/ ne0,
|
|
3258
|
+
/*.ne1 =*/ ne1,
|
|
3259
|
+
/*.ne2 =*/ ne2,
|
|
3260
|
+
/*.ne3 =*/ ne3,
|
|
3261
|
+
/*.nb0 =*/ nb0,
|
|
3262
|
+
/*.nb1 =*/ nb1,
|
|
3263
|
+
/*.nb2 =*/ nb2,
|
|
3264
|
+
/*.nb3 =*/ nb3,
|
|
3265
|
+
/*.eps =*/ eps,
|
|
2915
3266
|
};
|
|
2916
3267
|
|
|
2917
3268
|
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
|
|
2918
3269
|
|
|
2919
|
-
|
|
3270
|
+
if (pipeline.c4) {
|
|
3271
|
+
args.ne00 = ne00/4;
|
|
3272
|
+
args.ne0 = ne0/4;
|
|
3273
|
+
}
|
|
3274
|
+
|
|
3275
|
+
int nth = 32; // SIMD width
|
|
3276
|
+
|
|
3277
|
+
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
2920
3278
|
nth *= 2;
|
|
2921
3279
|
}
|
|
2922
3280
|
|
|
2923
3281
|
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
2924
|
-
nth = std::min(nth, ne00/4);
|
|
2925
3282
|
|
|
2926
3283
|
const size_t smem = pipeline.smem;
|
|
2927
3284
|
|
|
2928
|
-
const int64_t nrows = ggml_nrows(op->src[0]);
|
|
2929
|
-
|
|
2930
3285
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2931
3286
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
2932
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
2933
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
3287
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
3288
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
2934
3289
|
|
|
2935
3290
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
2936
3291
|
|
|
2937
|
-
ggml_metal_encoder_dispatch_threadgroups(enc,
|
|
3292
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
|
2938
3293
|
|
|
2939
3294
|
return 1;
|
|
2940
3295
|
}
|
|
@@ -3280,16 +3635,26 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
|
|
|
3280
3635
|
|
|
3281
3636
|
auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
|
|
3282
3637
|
|
|
3283
|
-
|
|
3638
|
+
if (KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
3639
|
+
const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N);
|
|
3284
3640
|
|
|
3285
|
-
|
|
3641
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3642
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
3643
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
|
|
3644
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
3286
3645
|
|
|
3287
|
-
|
|
3288
|
-
|
|
3289
|
-
|
|
3290
|
-
|
|
3646
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
|
|
3647
|
+
} else {
|
|
3648
|
+
const uint64_t n_threads = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), N);
|
|
3649
|
+
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
|
|
3650
|
+
|
|
3651
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3652
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
3653
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
|
|
3654
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
3291
3655
|
|
|
3292
|
-
|
|
3656
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, quotient * CHW, OH, OW, n_threads, 1, 1);
|
|
3657
|
+
}
|
|
3293
3658
|
|
|
3294
3659
|
return 1;
|
|
3295
3660
|
}
|
|
@@ -3372,6 +3737,77 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
|
|
|
3372
3737
|
return 1;
|
|
3373
3738
|
}
|
|
3374
3739
|
|
|
3740
|
+
int ggml_metal_op_conv_3d(ggml_metal_op_t ctx, int idx) {
|
|
3741
|
+
ggml_tensor * op = ctx->node(idx);
|
|
3742
|
+
|
|
3743
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
3744
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
3745
|
+
|
|
3746
|
+
// 1. Extract standard dimensions and byte strides
|
|
3747
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3748
|
+
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
3749
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3750
|
+
|
|
3751
|
+
// 2. Extract hyperparams from op_params
|
|
3752
|
+
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
|
3753
|
+
const int32_t s1 = ((const int32_t *)(op->op_params))[1];
|
|
3754
|
+
const int32_t s2 = ((const int32_t *)(op->op_params))[2];
|
|
3755
|
+
const int32_t p0 = ((const int32_t *)(op->op_params))[3];
|
|
3756
|
+
const int32_t p1 = ((const int32_t *)(op->op_params))[4];
|
|
3757
|
+
const int32_t p2 = ((const int32_t *)(op->op_params))[5];
|
|
3758
|
+
const int32_t d0 = ((const int32_t *)(op->op_params))[6];
|
|
3759
|
+
const int32_t d1 = ((const int32_t *)(op->op_params))[7];
|
|
3760
|
+
const int32_t d2 = ((const int32_t *)(op->op_params))[8];
|
|
3761
|
+
const int32_t IC = ((const int32_t *)(op->op_params))[9];
|
|
3762
|
+
const int32_t N = ((const int32_t *)(op->op_params))[10];
|
|
3763
|
+
const int32_t OC = ((const int32_t *)(op->op_params))[11];
|
|
3764
|
+
|
|
3765
|
+
// 3. Build the parameter struct using the macro-generated variables
|
|
3766
|
+
ggml_metal_kargs_conv_3d args = {
|
|
3767
|
+
/*.IW =*/ (int32_t)op->src[1]->ne[0],
|
|
3768
|
+
/*.IH =*/ (int32_t)op->src[1]->ne[1],
|
|
3769
|
+
/*.ID =*/ (int32_t)op->src[1]->ne[2],
|
|
3770
|
+
/*.OW =*/ (int32_t)op->ne[0],
|
|
3771
|
+
/*.OH =*/ (int32_t)op->ne[1],
|
|
3772
|
+
/*.OD =*/ (int32_t)op->ne[2],
|
|
3773
|
+
/*.KW =*/ (int32_t)op->src[0]->ne[0],
|
|
3774
|
+
/*.KH =*/ (int32_t)op->src[0]->ne[1],
|
|
3775
|
+
/*.KD =*/ (int32_t)op->src[0]->ne[2],
|
|
3776
|
+
s0, s1, s2,
|
|
3777
|
+
p0, p1, p2,
|
|
3778
|
+
d0, d1, d2,
|
|
3779
|
+
IC, N, OC,
|
|
3780
|
+
nb00, nb01, nb02, nb03, // Weight strides
|
|
3781
|
+
nb10, nb11, nb12, nb13, // Input strides
|
|
3782
|
+
nb0, nb1, nb2, nb3 // Output strides
|
|
3783
|
+
};
|
|
3784
|
+
|
|
3785
|
+
// 4. Fetch the JIT pipeline
|
|
3786
|
+
auto pipeline = ggml_metal_library_get_pipeline_conv_3d(lib, op);
|
|
3787
|
+
|
|
3788
|
+
// 5. Grid mapping
|
|
3789
|
+
int nth0 = 32; // Standard SIMD width for Apple Silicon
|
|
3790
|
+
int nth1 = 1;
|
|
3791
|
+
int nth2 = 1;
|
|
3792
|
+
|
|
3793
|
+
int64_t spatial_volume = args.OW * args.OH * args.OD;
|
|
3794
|
+
|
|
3795
|
+
int ntg0 = (spatial_volume + nth0 - 1) / nth0;
|
|
3796
|
+
int ntg1 = args.OC;
|
|
3797
|
+
int ntg2 = args.N;
|
|
3798
|
+
|
|
3799
|
+
// 6. Bind and Dispatch via the ggml C wrapper
|
|
3800
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3801
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
3802
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
3803
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
3804
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
|
3805
|
+
|
|
3806
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ntg0, ntg1, ntg2, nth0, nth1, nth2);
|
|
3807
|
+
|
|
3808
|
+
return 1;
|
|
3809
|
+
}
|
|
3810
|
+
|
|
3375
3811
|
int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
|
|
3376
3812
|
ggml_tensor * op = ctx->node(idx);
|
|
3377
3813
|
|
|
@@ -3484,12 +3920,76 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
|
|
|
3484
3920
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3485
3921
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3486
3922
|
|
|
3487
|
-
|
|
3488
|
-
|
|
3489
|
-
|
|
3490
|
-
|
|
3923
|
+
float sf0 = (float)ne0/op->src[0]->ne[0];
|
|
3924
|
+
float sf1 = (float)ne1/op->src[0]->ne[1];
|
|
3925
|
+
float sf2 = (float)ne2/op->src[0]->ne[2];
|
|
3926
|
+
float sf3 = (float)ne3/op->src[0]->ne[3];
|
|
3927
|
+
|
|
3928
|
+
const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
|
|
3929
|
+
|
|
3930
|
+
float poffs = 0.5f;
|
|
3931
|
+
|
|
3932
|
+
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
|
3933
|
+
poffs = 0.0f;
|
|
3934
|
+
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
|
|
3935
|
+
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
|
|
3936
|
+
}
|
|
3491
3937
|
|
|
3492
3938
|
ggml_metal_kargs_upscale args = {
|
|
3939
|
+
/*.ne00 =*/ ne00,
|
|
3940
|
+
/*.ne01 =*/ ne01,
|
|
3941
|
+
/*.ne02 =*/ ne02,
|
|
3942
|
+
/*.ne03 =*/ ne03,
|
|
3943
|
+
/*.nb00 =*/ nb00,
|
|
3944
|
+
/*.nb01 =*/ nb01,
|
|
3945
|
+
/*.nb02 =*/ nb02,
|
|
3946
|
+
/*.nb03 =*/ nb03,
|
|
3947
|
+
/*.ne0 =*/ ne0,
|
|
3948
|
+
/*.ne1 =*/ ne1,
|
|
3949
|
+
/*.ne2 =*/ ne2,
|
|
3950
|
+
/*.ne3 =*/ ne3,
|
|
3951
|
+
/*.nb0 =*/ nb0,
|
|
3952
|
+
/*.nb1 =*/ nb1,
|
|
3953
|
+
/*.nb2 =*/ nb2,
|
|
3954
|
+
/*.nb3 =*/ nb3,
|
|
3955
|
+
/*.sf0 =*/ sf0,
|
|
3956
|
+
/*.sf1 =*/ sf1,
|
|
3957
|
+
/*.sf2 =*/ sf2,
|
|
3958
|
+
/*.sf3 =*/ sf3,
|
|
3959
|
+
/*.poffs =*/ poffs,
|
|
3960
|
+
};
|
|
3961
|
+
|
|
3962
|
+
auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
|
|
3963
|
+
|
|
3964
|
+
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
|
|
3965
|
+
|
|
3966
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3967
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
3968
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
3969
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
3970
|
+
|
|
3971
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
|
|
3972
|
+
|
|
3973
|
+
return 1;
|
|
3974
|
+
}
|
|
3975
|
+
|
|
3976
|
+
int ggml_metal_op_roll(ggml_metal_op_t ctx, int idx) {
|
|
3977
|
+
ggml_tensor * op = ctx->node(idx);
|
|
3978
|
+
|
|
3979
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
3980
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
3981
|
+
|
|
3982
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
3983
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3984
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3985
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3986
|
+
|
|
3987
|
+
const int32_t s0 = ggml_get_op_params_i32(op, 0);
|
|
3988
|
+
const int32_t s1 = ggml_get_op_params_i32(op, 1);
|
|
3989
|
+
const int32_t s2 = ggml_get_op_params_i32(op, 2);
|
|
3990
|
+
const int32_t s3 = ggml_get_op_params_i32(op, 3);
|
|
3991
|
+
|
|
3992
|
+
ggml_metal_kargs_roll args = {
|
|
3493
3993
|
/*.ne00 =*/ ne00,
|
|
3494
3994
|
/*.ne01 =*/ ne01,
|
|
3495
3995
|
/*.ne02 =*/ ne02,
|
|
@@ -3498,23 +3998,23 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
|
|
|
3498
3998
|
/*.nb01 =*/ nb01,
|
|
3499
3999
|
/*.nb02 =*/ nb02,
|
|
3500
4000
|
/*.nb03 =*/ nb03,
|
|
3501
|
-
/*.ne0
|
|
3502
|
-
/*.ne1
|
|
3503
|
-
/*.ne2
|
|
3504
|
-
/*.ne3
|
|
3505
|
-
/*.nb0
|
|
3506
|
-
/*.nb1
|
|
3507
|
-
/*.nb2
|
|
3508
|
-
/*.nb3
|
|
3509
|
-
/*.
|
|
3510
|
-
/*.
|
|
3511
|
-
/*.
|
|
3512
|
-
/*.
|
|
4001
|
+
/*.ne0 =*/ ne0,
|
|
4002
|
+
/*.ne1 =*/ ne1,
|
|
4003
|
+
/*.ne2 =*/ ne2,
|
|
4004
|
+
/*.ne3 =*/ ne3,
|
|
4005
|
+
/*.nb0 =*/ nb0,
|
|
4006
|
+
/*.nb1 =*/ nb1,
|
|
4007
|
+
/*.nb2 =*/ nb2,
|
|
4008
|
+
/*.nb3 =*/ nb3,
|
|
4009
|
+
/*.s0 =*/ s0,
|
|
4010
|
+
/*.s1 =*/ s1,
|
|
4011
|
+
/*.s2 =*/ s2,
|
|
4012
|
+
/*.s3 =*/ s3
|
|
3513
4013
|
};
|
|
3514
4014
|
|
|
3515
|
-
auto pipeline =
|
|
4015
|
+
auto pipeline = ggml_metal_library_get_pipeline_roll(lib, op);
|
|
3516
4016
|
|
|
3517
|
-
const int nth = std::min(
|
|
4017
|
+
const int nth = std::min(1024, ne0);
|
|
3518
4018
|
|
|
3519
4019
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3520
4020
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -3558,14 +4058,21 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
|
|
|
3558
4058
|
|
|
3559
4059
|
auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
|
|
3560
4060
|
|
|
3561
|
-
|
|
4061
|
+
if (pipeline.c4) {
|
|
4062
|
+
args.ne00 = ne00/4;
|
|
4063
|
+
args.ne0 = ne0/4;
|
|
4064
|
+
}
|
|
4065
|
+
|
|
4066
|
+
const int nth_max = MIN(64, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
4067
|
+
const int nth = MIN(args.ne0, nth_max);
|
|
4068
|
+
const int nk0 = (args.ne0 + 1024 - 1)/1024; // note: 1024 is hardcoded in the kernel!
|
|
3562
4069
|
|
|
3563
4070
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3564
4071
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
3565
4072
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
3566
4073
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
3567
4074
|
|
|
3568
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
|
|
4075
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne1, ne2, ne3, nth, 1, 1);
|
|
3569
4076
|
|
|
3570
4077
|
return 1;
|
|
3571
4078
|
}
|
|
@@ -3942,42 +4449,6 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
|
|
|
3942
4449
|
return 1;
|
|
3943
4450
|
}
|
|
3944
4451
|
|
|
3945
|
-
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
|
3946
|
-
ggml_tensor * op = ctx->node(idx);
|
|
3947
|
-
|
|
3948
|
-
ggml_metal_library_t lib = ctx->lib;
|
|
3949
|
-
ggml_metal_encoder_t enc = ctx->enc;
|
|
3950
|
-
|
|
3951
|
-
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
3952
|
-
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3953
|
-
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3954
|
-
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3955
|
-
|
|
3956
|
-
float slope;
|
|
3957
|
-
memcpy(&slope, op->op_params, sizeof(float));
|
|
3958
|
-
|
|
3959
|
-
ggml_metal_kargs_leaky_relu args = {
|
|
3960
|
-
/*.slope =*/ slope
|
|
3961
|
-
};
|
|
3962
|
-
|
|
3963
|
-
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
|
3964
|
-
|
|
3965
|
-
int64_t n = ggml_nelements(op);
|
|
3966
|
-
|
|
3967
|
-
if (n % 4 == 0) {
|
|
3968
|
-
n /= 4;
|
|
3969
|
-
}
|
|
3970
|
-
|
|
3971
|
-
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3972
|
-
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
3973
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
3974
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
3975
|
-
|
|
3976
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
3977
|
-
|
|
3978
|
-
return 1;
|
|
3979
|
-
}
|
|
3980
|
-
|
|
3981
4452
|
int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
|
|
3982
4453
|
ggml_tensor * op = ctx->node(idx);
|
|
3983
4454
|
|