whispercpp 1.3.6 → 1.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. Its contents are provided for informational purposes only and reflect the changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/README.md +38 -5
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -8
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +36 -42
- data/ext/ruby_whisper.h +135 -0
- data/ext/ruby_whisper_context.c +107 -28
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -65
- data/ext/ruby_whisper_segment.c +6 -6
- data/ext/ruby_whisper_transcribe.cpp +42 -15
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +1 -1
- data/ext/sources/examples/cli/cli.cpp +43 -9
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +199 -163
- data/ext/sources/ggml/CMakeLists.txt +21 -13
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +72 -10
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-rpc.h +3 -3
- data/ext/sources/ggml/include/ggml.h +101 -9
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +22 -5
- data/ext/sources/ggml/src/ggml-alloc.c +5 -1
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
- data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
- data/ext/sources/ggml/src/ggml-impl.h +6 -1
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
- data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +289 -114
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
- data/ext/sources/ggml/src/ggml.c +110 -28
- data/ext/sources/ggml/src/gguf.cpp +173 -28
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +56 -12
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +411 -62
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +24 -6
- data/whispercpp.gemspec +2 -2
- metadata +215 -281
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
- data/ext/sources/examples/talk-llama/llama-context.h +0 -359
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
- data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
- data/ext/sources/examples/talk-llama/llama-model.h +0 -597
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
- data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
- data/ext/sources/examples/talk-llama/llama.h +0 -1573
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -704
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
- /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
|
@@ -118,6 +118,56 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg
|
|
|
118
118
|
}
|
|
119
119
|
#endif
|
|
120
120
|
|
|
121
|
+
template <typename type4x4>
|
|
122
|
+
void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
|
|
123
|
+
device const uint8_t * qs = xb->qs;
|
|
124
|
+
const float d = xb->d;
|
|
125
|
+
const float neg_d = -d;
|
|
126
|
+
|
|
127
|
+
const int byte_offset = il * 2; // il*16 bits = il*2 bytes
|
|
128
|
+
const uint8_t b0 = qs[byte_offset];
|
|
129
|
+
const uint8_t b1 = qs[byte_offset + 1];
|
|
130
|
+
|
|
131
|
+
float4x4 reg_f;
|
|
132
|
+
|
|
133
|
+
reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01));
|
|
134
|
+
reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02));
|
|
135
|
+
reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04));
|
|
136
|
+
reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08));
|
|
137
|
+
reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10));
|
|
138
|
+
reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20));
|
|
139
|
+
reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40));
|
|
140
|
+
reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80));
|
|
141
|
+
|
|
142
|
+
reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01));
|
|
143
|
+
reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02));
|
|
144
|
+
reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04));
|
|
145
|
+
reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08));
|
|
146
|
+
reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10));
|
|
147
|
+
reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20));
|
|
148
|
+
reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40));
|
|
149
|
+
reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80));
|
|
150
|
+
|
|
151
|
+
reg = (type4x4) reg_f;
|
|
152
|
+
}
|
|
153
|
+
|
|
154
|
+
template <typename type4>
|
|
155
|
+
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
|
|
156
|
+
const float d = xb->d;
|
|
157
|
+
const float neg_d = -d;
|
|
158
|
+
const int base = il * 4;
|
|
159
|
+
const uint8_t byte = xb->qs[base / 8];
|
|
160
|
+
const int s = base % 8;
|
|
161
|
+
|
|
162
|
+
float4 reg_f;
|
|
163
|
+
reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1));
|
|
164
|
+
reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1));
|
|
165
|
+
reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1));
|
|
166
|
+
reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1));
|
|
167
|
+
|
|
168
|
+
reg = (type4) reg_f;
|
|
169
|
+
}
|
|
170
|
+
|
|
121
171
|
template <typename type4x4>
|
|
122
172
|
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
|
|
123
173
|
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
|
@@ -152,6 +202,23 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
|
|
|
152
202
|
}
|
|
153
203
|
}
|
|
154
204
|
|
|
205
|
+
void quantize_q1_0(device const float * src, device block_q1_0 & dst) {
|
|
206
|
+
float sum_abs = 0.0f;
|
|
207
|
+
for (int j = 0; j < QK1_0; j++) {
|
|
208
|
+
sum_abs += fabs(src[j]);
|
|
209
|
+
}
|
|
210
|
+
dst.d = sum_abs / QK1_0;
|
|
211
|
+
|
|
212
|
+
for (int j = 0; j < QK1_0 / 8; j++) {
|
|
213
|
+
dst.qs[j] = 0;
|
|
214
|
+
}
|
|
215
|
+
for (int j = 0; j < QK1_0; j++) {
|
|
216
|
+
if (src[j] >= 0.0f) {
|
|
217
|
+
dst.qs[j / 8] |= (1 << (j % 8));
|
|
218
|
+
}
|
|
219
|
+
}
|
|
220
|
+
}
|
|
221
|
+
|
|
155
222
|
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
|
|
156
223
|
#pragma METAL fp math_mode(safe)
|
|
157
224
|
float amax = 0.0f; // absolute max
|
|
@@ -1094,6 +1161,31 @@ kernel void kernel_unary_impl(
|
|
|
1094
1161
|
// TODO: precise implementation
|
|
1095
1162
|
dst_ptr[i0] = (T) (exp(x) - 1);
|
|
1096
1163
|
}
|
|
1164
|
+
|
|
1165
|
+
if (FC_OP == OP_UNARY_NUM_FLOOR) {
|
|
1166
|
+
dst_ptr[i0] = (T) floor(x);
|
|
1167
|
+
}
|
|
1168
|
+
|
|
1169
|
+
if (FC_OP == OP_UNARY_NUM_CEIL) {
|
|
1170
|
+
dst_ptr[i0] = (T) ceil(x);
|
|
1171
|
+
}
|
|
1172
|
+
|
|
1173
|
+
if (FC_OP == OP_UNARY_NUM_ROUND) {
|
|
1174
|
+
dst_ptr[i0] = (T) round(x);
|
|
1175
|
+
}
|
|
1176
|
+
|
|
1177
|
+
if (FC_OP == OP_UNARY_NUM_TRUNC) {
|
|
1178
|
+
dst_ptr[i0] = (T) trunc(x);
|
|
1179
|
+
}
|
|
1180
|
+
|
|
1181
|
+
if (FC_OP == OP_UNARY_NUM_XIELU) {
|
|
1182
|
+
const TC xi = x;
|
|
1183
|
+
const TC gate = TC(xi > TC(0.0f));
|
|
1184
|
+
const TC clamped = fmin(xi, TC(args.val));
|
|
1185
|
+
const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi;
|
|
1186
|
+
const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi;
|
|
1187
|
+
dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg);
|
|
1188
|
+
}
|
|
1097
1189
|
}
|
|
1098
1190
|
|
|
1099
1191
|
#undef FC_OP
|
|
@@ -1329,7 +1421,8 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat
|
|
|
1329
1421
|
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
|
|
1330
1422
|
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
|
|
1331
1423
|
|
|
1332
|
-
|
|
1424
|
+
template<typename T>
|
|
1425
|
+
kernel void kernel_reglu(
|
|
1333
1426
|
constant ggml_metal_kargs_glu & args,
|
|
1334
1427
|
device const char * src0,
|
|
1335
1428
|
device const char * src1,
|
|
@@ -1337,19 +1430,25 @@ kernel void kernel_reglu_f32(
|
|
|
1337
1430
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
1338
1431
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
1339
1432
|
uint ntg[[threads_per_threadgroup]]) {
|
|
1340
|
-
device const
|
|
1341
|
-
device const
|
|
1342
|
-
device
|
|
1433
|
+
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1434
|
+
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1435
|
+
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
|
1343
1436
|
|
|
1344
1437
|
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1345
1438
|
const float x0 = src0_row[i0];
|
|
1346
1439
|
const float x1 = src1_row[i0];
|
|
1347
1440
|
|
|
1348
|
-
dst_row[i0] = x0*x1*(x0 > 0.0f);
|
|
1441
|
+
dst_row[i0] = (T)(x0*x1*(x0 > 0.0f));
|
|
1349
1442
|
}
|
|
1350
1443
|
}
|
|
1351
1444
|
|
|
1352
|
-
|
|
1445
|
+
typedef decltype(kernel_reglu<float>) kernel_reglu_t;
|
|
1446
|
+
|
|
1447
|
+
template [[host_name("kernel_reglu_f32")]] kernel kernel_reglu_t kernel_reglu<float>;
|
|
1448
|
+
template [[host_name("kernel_reglu_f16")]] kernel kernel_reglu_t kernel_reglu<half>;
|
|
1449
|
+
|
|
1450
|
+
template<typename T>
|
|
1451
|
+
kernel void kernel_geglu(
|
|
1353
1452
|
constant ggml_metal_kargs_glu & args,
|
|
1354
1453
|
device const char * src0,
|
|
1355
1454
|
device const char * src1,
|
|
@@ -1357,9 +1456,9 @@ kernel void kernel_geglu_f32(
|
|
|
1357
1456
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
1358
1457
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
1359
1458
|
uint ntg[[threads_per_threadgroup]]) {
|
|
1360
|
-
device const
|
|
1361
|
-
device const
|
|
1362
|
-
device
|
|
1459
|
+
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1460
|
+
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1461
|
+
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
|
1363
1462
|
|
|
1364
1463
|
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1365
1464
|
const float x0 = src0_row[i0];
|
|
@@ -1367,11 +1466,17 @@ kernel void kernel_geglu_f32(
|
|
|
1367
1466
|
|
|
1368
1467
|
const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
|
|
1369
1468
|
|
|
1370
|
-
dst_row[i0] = gelu*x1;
|
|
1469
|
+
dst_row[i0] = (T)(gelu*x1);
|
|
1371
1470
|
}
|
|
1372
1471
|
}
|
|
1373
1472
|
|
|
1374
|
-
|
|
1473
|
+
typedef decltype(kernel_geglu<float>) kernel_geglu_t;
|
|
1474
|
+
|
|
1475
|
+
template [[host_name("kernel_geglu_f32")]] kernel kernel_geglu_t kernel_geglu<float>;
|
|
1476
|
+
template [[host_name("kernel_geglu_f16")]] kernel kernel_geglu_t kernel_geglu<half>;
|
|
1477
|
+
|
|
1478
|
+
template<typename T>
|
|
1479
|
+
kernel void kernel_swiglu(
|
|
1375
1480
|
constant ggml_metal_kargs_glu & args,
|
|
1376
1481
|
device const char * src0,
|
|
1377
1482
|
device const char * src1,
|
|
@@ -1379,9 +1484,9 @@ kernel void kernel_swiglu_f32(
|
|
|
1379
1484
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
1380
1485
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
1381
1486
|
uint ntg[[threads_per_threadgroup]]) {
|
|
1382
|
-
device const
|
|
1383
|
-
device const
|
|
1384
|
-
device
|
|
1487
|
+
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1488
|
+
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1489
|
+
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
|
1385
1490
|
|
|
1386
1491
|
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1387
1492
|
const float x0 = src0_row[i0];
|
|
@@ -1389,11 +1494,17 @@ kernel void kernel_swiglu_f32(
|
|
|
1389
1494
|
|
|
1390
1495
|
const float silu = x0 / (1.0f + exp(-x0));
|
|
1391
1496
|
|
|
1392
|
-
dst_row[i0] = silu*x1;
|
|
1497
|
+
dst_row[i0] = (T)(silu*x1);
|
|
1393
1498
|
}
|
|
1394
1499
|
}
|
|
1395
1500
|
|
|
1396
|
-
|
|
1501
|
+
typedef decltype(kernel_swiglu<float>) kernel_swiglu_t;
|
|
1502
|
+
|
|
1503
|
+
template [[host_name("kernel_swiglu_f32")]] kernel kernel_swiglu_t kernel_swiglu<float>;
|
|
1504
|
+
template [[host_name("kernel_swiglu_f16")]] kernel kernel_swiglu_t kernel_swiglu<half>;
|
|
1505
|
+
|
|
1506
|
+
template<typename T>
|
|
1507
|
+
kernel void kernel_swiglu_oai(
|
|
1397
1508
|
constant ggml_metal_kargs_glu & args,
|
|
1398
1509
|
device const char * src0,
|
|
1399
1510
|
device const char * src1,
|
|
@@ -1401,9 +1512,9 @@ kernel void kernel_swiglu_oai_f32(
|
|
|
1401
1512
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
1402
1513
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
1403
1514
|
uint ntg[[threads_per_threadgroup]]) {
|
|
1404
|
-
device const
|
|
1405
|
-
device const
|
|
1406
|
-
device
|
|
1515
|
+
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1516
|
+
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1517
|
+
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
|
1407
1518
|
|
|
1408
1519
|
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1409
1520
|
float x0 = src0_row[i0];
|
|
@@ -1415,11 +1526,17 @@ kernel void kernel_swiglu_oai_f32(
|
|
|
1415
1526
|
float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
|
|
1416
1527
|
out_glu = out_glu * (1.0f + x1);
|
|
1417
1528
|
|
|
1418
|
-
dst_row[i0] = out_glu;
|
|
1529
|
+
dst_row[i0] = (T)out_glu;
|
|
1419
1530
|
}
|
|
1420
1531
|
}
|
|
1421
1532
|
|
|
1422
|
-
|
|
1533
|
+
typedef decltype(kernel_swiglu_oai<float>) kernel_swiglu_oai_t;
|
|
1534
|
+
|
|
1535
|
+
template [[host_name("kernel_swiglu_oai_f32")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<float>;
|
|
1536
|
+
template [[host_name("kernel_swiglu_oai_f16")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<half>;
|
|
1537
|
+
|
|
1538
|
+
template<typename T>
|
|
1539
|
+
kernel void kernel_geglu_erf(
|
|
1423
1540
|
constant ggml_metal_kargs_glu & args,
|
|
1424
1541
|
device const char * src0,
|
|
1425
1542
|
device const char * src1,
|
|
@@ -1427,9 +1544,9 @@ kernel void kernel_geglu_erf_f32(
|
|
|
1427
1544
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
1428
1545
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
1429
1546
|
uint ntg[[threads_per_threadgroup]]) {
|
|
1430
|
-
device const
|
|
1431
|
-
device const
|
|
1432
|
-
device
|
|
1547
|
+
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1548
|
+
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1549
|
+
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
|
1433
1550
|
|
|
1434
1551
|
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1435
1552
|
const float x0 = src0_row[i0];
|
|
@@ -1437,11 +1554,17 @@ kernel void kernel_geglu_erf_f32(
|
|
|
1437
1554
|
|
|
1438
1555
|
const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
|
|
1439
1556
|
|
|
1440
|
-
dst_row[i0] = gelu_erf*x1;
|
|
1557
|
+
dst_row[i0] = (T)(gelu_erf*x1);
|
|
1441
1558
|
}
|
|
1442
1559
|
}
|
|
1443
1560
|
|
|
1444
|
-
|
|
1561
|
+
typedef decltype(kernel_geglu_erf<float>) kernel_geglu_erf_t;
|
|
1562
|
+
|
|
1563
|
+
template [[host_name("kernel_geglu_erf_f32")]] kernel kernel_geglu_erf_t kernel_geglu_erf<float>;
|
|
1564
|
+
template [[host_name("kernel_geglu_erf_f16")]] kernel kernel_geglu_erf_t kernel_geglu_erf<half>;
|
|
1565
|
+
|
|
1566
|
+
template<typename T>
|
|
1567
|
+
kernel void kernel_geglu_quick(
|
|
1445
1568
|
constant ggml_metal_kargs_glu & args,
|
|
1446
1569
|
device const char * src0,
|
|
1447
1570
|
device const char * src1,
|
|
@@ -1449,9 +1572,9 @@ kernel void kernel_geglu_quick_f32(
|
|
|
1449
1572
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
1450
1573
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
1451
1574
|
uint ntg[[threads_per_threadgroup]]) {
|
|
1452
|
-
device const
|
|
1453
|
-
device const
|
|
1454
|
-
device
|
|
1575
|
+
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1576
|
+
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1577
|
+
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
|
1455
1578
|
|
|
1456
1579
|
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1457
1580
|
const float x0 = src0_row[i0];
|
|
@@ -1459,10 +1582,15 @@ kernel void kernel_geglu_quick_f32(
|
|
|
1459
1582
|
|
|
1460
1583
|
const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
|
|
1461
1584
|
|
|
1462
|
-
dst_row[i0] = gelu_quick*x1;
|
|
1585
|
+
dst_row[i0] = (T)(gelu_quick*x1);
|
|
1463
1586
|
}
|
|
1464
1587
|
}
|
|
1465
1588
|
|
|
1589
|
+
typedef decltype(kernel_geglu_quick<float>) kernel_geglu_quick_t;
|
|
1590
|
+
|
|
1591
|
+
template [[host_name("kernel_geglu_quick_f32")]] kernel kernel_geglu_quick_t kernel_geglu_quick<float>;
|
|
1592
|
+
template [[host_name("kernel_geglu_quick_f16")]] kernel kernel_geglu_quick_t kernel_geglu_quick<half>;
|
|
1593
|
+
|
|
1466
1594
|
kernel void kernel_op_sum_f32(
|
|
1467
1595
|
constant ggml_metal_kargs_sum & args,
|
|
1468
1596
|
device const float * src0,
|
|
@@ -2439,6 +2567,7 @@ kernel void kernel_rwkv_wkv7_f32(
|
|
|
2439
2567
|
|
|
2440
2568
|
constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
|
|
2441
2569
|
constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
|
|
2570
|
+
constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]];
|
|
2442
2571
|
|
|
2443
2572
|
#if 1
|
|
2444
2573
|
template<short NSG>
|
|
@@ -2456,21 +2585,24 @@ kernel void kernel_gated_delta_net_impl(
|
|
|
2456
2585
|
uint3 ntg[[threads_per_threadgroup]]) {
|
|
2457
2586
|
#define S_v FC_gated_delta_net_ne20
|
|
2458
2587
|
#define G FC_gated_delta_net_ne30
|
|
2588
|
+
#define K FC_gated_delta_net_K
|
|
2459
2589
|
|
|
2460
2590
|
const uint tx = tpitg.x;
|
|
2461
2591
|
const uint ty = tpitg.y;
|
|
2462
2592
|
|
|
2463
|
-
const uint i23 = tgpig.z; // B
|
|
2464
|
-
const uint i21 = tgpig.y; // H
|
|
2465
|
-
const uint i20 = tgpig.x*NSG + ty;
|
|
2593
|
+
const uint i23 = tgpig.z; // B (n_seqs)
|
|
2594
|
+
const uint i21 = tgpig.y; // H (head)
|
|
2595
|
+
const uint i20 = tgpig.x*NSG + ty; // row within S_v
|
|
2466
2596
|
|
|
2467
2597
|
const uint i01 = i21 % args.ne01;
|
|
2468
2598
|
const uint i11 = i21 % args.ne11;
|
|
2469
2599
|
|
|
2470
2600
|
const float scale = 1.0f / sqrt((float)S_v);
|
|
2471
2601
|
|
|
2602
|
+
// input state layout [S_v, S_v, H, n_seqs] (s0 only): per-seq stride is H*D.
|
|
2472
2603
|
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
|
|
2473
|
-
|
|
2604
|
+
const uint state_in_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
|
2605
|
+
device const float * s_ptr = (device const float *) (s) + state_in_base;
|
|
2474
2606
|
|
|
2475
2607
|
float ls[NSG];
|
|
2476
2608
|
|
|
@@ -2488,6 +2620,16 @@ kernel void kernel_gated_delta_net_impl(
|
|
|
2488
2620
|
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
|
|
2489
2621
|
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
|
|
2490
2622
|
|
|
2623
|
+
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
|
|
2624
|
+
// When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned.
|
|
2625
|
+
|
|
2626
|
+
// output state base offset: after attention scores
|
|
2627
|
+
const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23;
|
|
2628
|
+
// output state per-slot size: S_v * S_v * H * n_seqs
|
|
2629
|
+
const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23;
|
|
2630
|
+
// per-(seq,head) offset within a slot
|
|
2631
|
+
const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
|
2632
|
+
|
|
2491
2633
|
for (short t = 0; t < args.ne22; t++) {
|
|
2492
2634
|
float s_k = 0.0f;
|
|
2493
2635
|
|
|
@@ -2535,17 +2677,30 @@ kernel void kernel_gated_delta_net_impl(
|
|
|
2535
2677
|
|
|
2536
2678
|
b_ptr += args.ne21;
|
|
2537
2679
|
g_ptr += args.ne21*G;
|
|
2538
|
-
}
|
|
2539
2680
|
|
|
2540
|
-
|
|
2681
|
+
if (K > 1) {
|
|
2682
|
+
const int target_slot = (int)args.ne22 - 1 - (int)t;
|
|
2683
|
+
if (target_slot >= 0 && target_slot < (int)K) {
|
|
2684
|
+
device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
|
|
2685
|
+
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
|
2686
|
+
const short is = tx*NSG + j;
|
|
2687
|
+
dst_state[is] = ls[j];
|
|
2688
|
+
}
|
|
2689
|
+
}
|
|
2690
|
+
}
|
|
2691
|
+
}
|
|
2541
2692
|
|
|
2542
|
-
|
|
2543
|
-
|
|
2544
|
-
|
|
2693
|
+
if (K == 1) {
|
|
2694
|
+
device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
|
|
2695
|
+
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
|
2696
|
+
const short is = tx*NSG + j;
|
|
2697
|
+
dst_state[is] = ls[j];
|
|
2698
|
+
}
|
|
2545
2699
|
}
|
|
2546
2700
|
|
|
2547
2701
|
#undef S_v
|
|
2548
2702
|
#undef G
|
|
2703
|
+
#undef K
|
|
2549
2704
|
}
|
|
2550
2705
|
|
|
2551
2706
|
typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
|
|
@@ -3100,6 +3255,35 @@ kernel void kernel_group_norm_f32(
|
|
|
3100
3255
|
}
|
|
3101
3256
|
}
|
|
3102
3257
|
|
|
3258
|
+
// Q1_0 dot product: dot = d * (2 * Σ(yl[i] where bit=1) - sumy)
|
|
3259
|
+
inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
|
|
3260
|
+
device const uint8_t * qs = qb_curr->qs + il / 8;
|
|
3261
|
+
const uint8_t b0 = qs[0];
|
|
3262
|
+
const uint8_t b1 = qs[1];
|
|
3263
|
+
|
|
3264
|
+
float acc = 0.0f;
|
|
3265
|
+
|
|
3266
|
+
acc += select(0.0f, yl[ 0], bool(b0 & 0x01));
|
|
3267
|
+
acc += select(0.0f, yl[ 1], bool(b0 & 0x02));
|
|
3268
|
+
acc += select(0.0f, yl[ 2], bool(b0 & 0x04));
|
|
3269
|
+
acc += select(0.0f, yl[ 3], bool(b0 & 0x08));
|
|
3270
|
+
acc += select(0.0f, yl[ 4], bool(b0 & 0x10));
|
|
3271
|
+
acc += select(0.0f, yl[ 5], bool(b0 & 0x20));
|
|
3272
|
+
acc += select(0.0f, yl[ 6], bool(b0 & 0x40));
|
|
3273
|
+
acc += select(0.0f, yl[ 7], bool(b0 & 0x80));
|
|
3274
|
+
|
|
3275
|
+
acc += select(0.0f, yl[ 8], bool(b1 & 0x01));
|
|
3276
|
+
acc += select(0.0f, yl[ 9], bool(b1 & 0x02));
|
|
3277
|
+
acc += select(0.0f, yl[10], bool(b1 & 0x04));
|
|
3278
|
+
acc += select(0.0f, yl[11], bool(b1 & 0x08));
|
|
3279
|
+
acc += select(0.0f, yl[12], bool(b1 & 0x10));
|
|
3280
|
+
acc += select(0.0f, yl[13], bool(b1 & 0x20));
|
|
3281
|
+
acc += select(0.0f, yl[14], bool(b1 & 0x40));
|
|
3282
|
+
acc += select(0.0f, yl[15], bool(b1 & 0x80));
|
|
3283
|
+
|
|
3284
|
+
return qb_curr->d * (2.0f * acc - sumy);
|
|
3285
|
+
}
|
|
3286
|
+
|
|
3103
3287
|
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
|
3104
3288
|
// il indicates where the q4 quants begin (0 or QK4_0/4)
|
|
3105
3289
|
// we assume that the yl's have been multiplied with the appropriate scale factor
|
|
@@ -3232,6 +3416,9 @@ static inline void helper_mv_reduce_and_write(
|
|
|
3232
3416
|
|
|
3233
3417
|
constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]];
|
|
3234
3418
|
constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]];
|
|
3419
|
+
constant short FC_mul_mv_ne12 [[function_constant(FC_MUL_MV + 2)]];
|
|
3420
|
+
constant short FC_mul_mv_r2 [[function_constant(FC_MUL_MV + 3)]];
|
|
3421
|
+
constant short FC_mul_mv_r3 [[function_constant(FC_MUL_MV + 4)]];
|
|
3235
3422
|
|
|
3236
3423
|
template<typename block_q_type, short NR0, typename args_t>
|
|
3237
3424
|
void mul_vec_q_n_f32_impl(
|
|
@@ -3255,10 +3442,10 @@ void mul_vec_q_n_f32_impl(
|
|
|
3255
3442
|
const int r1 = tgpig.y;
|
|
3256
3443
|
const int im = tgpig.z;
|
|
3257
3444
|
|
|
3258
|
-
const uint i12 = im%
|
|
3259
|
-
const uint i13 = im/
|
|
3445
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
3446
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
3260
3447
|
|
|
3261
|
-
//const uint64_t offset0 = r0*args.nb01 + (i12/
|
|
3448
|
+
//const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3262
3449
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
3263
3450
|
|
|
3264
3451
|
//device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
|
|
@@ -3267,7 +3454,7 @@ void mul_vec_q_n_f32_impl(
|
|
|
3267
3454
|
// pointers to src0 rows
|
|
3268
3455
|
device const block_q_type * ax[NR0];
|
|
3269
3456
|
FOR_UNROLL (int row = 0; row < NR0; ++row) {
|
|
3270
|
-
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/
|
|
3457
|
+
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3271
3458
|
|
|
3272
3459
|
ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
|
|
3273
3460
|
}
|
|
@@ -3321,6 +3508,85 @@ void mul_vec_q_n_f32_impl(
|
|
|
3321
3508
|
}
|
|
3322
3509
|
}
|
|
3323
3510
|
|
|
3511
|
+
template<int nr0, typename args_t>
|
|
3512
|
+
void kernel_mul_mv_q1_0_f32_impl(
|
|
3513
|
+
args_t args,
|
|
3514
|
+
device const char * src0,
|
|
3515
|
+
device const char * src1,
|
|
3516
|
+
device char * dst,
|
|
3517
|
+
threadgroup char * shmem,
|
|
3518
|
+
uint3 tgpig,
|
|
3519
|
+
ushort tiisg,
|
|
3520
|
+
ushort sgitg) {
|
|
3521
|
+
const short NSG = FC_mul_mv_nsg;
|
|
3522
|
+
|
|
3523
|
+
const int nb = args.ne00/QK1_0;
|
|
3524
|
+
|
|
3525
|
+
const int r0 = tgpig.x;
|
|
3526
|
+
const int r1 = tgpig.y;
|
|
3527
|
+
const int im = tgpig.z;
|
|
3528
|
+
|
|
3529
|
+
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
3530
|
+
|
|
3531
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
3532
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
3533
|
+
|
|
3534
|
+
const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13;
|
|
3535
|
+
|
|
3536
|
+
device const float * y = (device const float *) (src1 + offset1);
|
|
3537
|
+
|
|
3538
|
+
device const block_q1_0 * ax[nr0];
|
|
3539
|
+
for (int row = 0; row < nr0; ++row) {
|
|
3540
|
+
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3541
|
+
ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0);
|
|
3542
|
+
}
|
|
3543
|
+
|
|
3544
|
+
float yl[16];
|
|
3545
|
+
float sumf[nr0] = {0.f};
|
|
3546
|
+
|
|
3547
|
+
const short ix = (tiisg/8);
|
|
3548
|
+
const short il = (tiisg%8)*16;
|
|
3549
|
+
|
|
3550
|
+
device const float * yb = y + ix*QK1_0 + il;
|
|
3551
|
+
|
|
3552
|
+
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
|
|
3553
|
+
float sumy = 0.f;
|
|
3554
|
+
|
|
3555
|
+
FOR_UNROLL (short i = 0; i < 16; i++) {
|
|
3556
|
+
yl[i] = yb[i];
|
|
3557
|
+
sumy += yb[i];
|
|
3558
|
+
}
|
|
3559
|
+
|
|
3560
|
+
FOR_UNROLL (short row = 0; row < nr0; row++) {
|
|
3561
|
+
sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il);
|
|
3562
|
+
}
|
|
3563
|
+
|
|
3564
|
+
yb += QK1_0 * (N_SIMDWIDTH/8);
|
|
3565
|
+
}
|
|
3566
|
+
|
|
3567
|
+
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
|
3568
|
+
|
|
3569
|
+
for (int row = 0; row < nr0; ++row) {
|
|
3570
|
+
const float tot = simd_sum(sumf[row]);
|
|
3571
|
+
|
|
3572
|
+
if (tiisg == 0 && first_row + row < args.ne01) {
|
|
3573
|
+
dst_f32[first_row + row] = tot;
|
|
3574
|
+
}
|
|
3575
|
+
}
|
|
3576
|
+
}
|
|
3577
|
+
|
|
3578
|
+
[[host_name("kernel_mul_mv_q1_0_f32")]]
|
|
3579
|
+
kernel void kernel_mul_mv_q1_0_f32(
|
|
3580
|
+
constant ggml_metal_kargs_mul_mv & args,
|
|
3581
|
+
device const char * src0,
|
|
3582
|
+
device const char * src1,
|
|
3583
|
+
device char * dst,
|
|
3584
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3585
|
+
ushort tiisg[[thread_index_in_simdgroup]],
|
|
3586
|
+
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3587
|
+
kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
|
3588
|
+
}
|
|
3589
|
+
|
|
3324
3590
|
kernel void kernel_mul_mv_q4_0_f32(
|
|
3325
3591
|
constant ggml_metal_kargs_mul_mv & args,
|
|
3326
3592
|
device const char * src0,
|
|
@@ -3390,10 +3656,10 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|
|
3390
3656
|
const int r1 = tgpig.y;
|
|
3391
3657
|
const int im = tgpig.z;
|
|
3392
3658
|
|
|
3393
|
-
const uint i12 = im%
|
|
3394
|
-
const uint i13 = im/
|
|
3659
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
3660
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
3395
3661
|
|
|
3396
|
-
//const uint64_t offset0 = r0*args.nb01 + (i12/
|
|
3662
|
+
//const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3397
3663
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
3398
3664
|
|
|
3399
3665
|
//device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
|
|
@@ -3402,7 +3668,7 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|
|
3402
3668
|
// pointers to src0 rows
|
|
3403
3669
|
device const block_q8_0 * ax[NR0];
|
|
3404
3670
|
FOR_UNROLL (short row = 0; row < NR0; ++row) {
|
|
3405
|
-
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/
|
|
3671
|
+
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3406
3672
|
|
|
3407
3673
|
ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
|
|
3408
3674
|
}
|
|
@@ -3482,10 +3748,10 @@ void kernel_mul_mv_ext_q4_f32_impl(
|
|
|
3482
3748
|
const int i11 = tgpig.y*r1ptg;
|
|
3483
3749
|
const int i1m = tgpig.z;
|
|
3484
3750
|
|
|
3485
|
-
const int i12 = i1m%
|
|
3486
|
-
const int i13 = i1m/
|
|
3751
|
+
const int i12 = i1m%FC_mul_mv_ne12;
|
|
3752
|
+
const int i13 = i1m/FC_mul_mv_ne12;
|
|
3487
3753
|
|
|
3488
|
-
const uint64_t offset0 = i01*args.nb01 + (i12/
|
|
3754
|
+
const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3489
3755
|
const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
3490
3756
|
|
|
3491
3757
|
device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
|
|
@@ -3585,10 +3851,10 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
|
|
|
3585
3851
|
const int i11 = tgpig.y*r1ptg;
|
|
3586
3852
|
const int i1m = tgpig.z;
|
|
3587
3853
|
|
|
3588
|
-
const int i12 = i1m%
|
|
3589
|
-
const int i13 = i1m/
|
|
3854
|
+
const int i12 = i1m%FC_mul_mv_ne12;
|
|
3855
|
+
const int i13 = i1m/FC_mul_mv_ne12;
|
|
3590
3856
|
|
|
3591
|
-
const uint64_t offset0 = i01*args.nb01 + (i12/
|
|
3857
|
+
const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3592
3858
|
const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
3593
3859
|
|
|
3594
3860
|
device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
|
|
@@ -3713,6 +3979,11 @@ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4
|
|
|
3713
3979
|
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
|
|
3714
3980
|
#endif
|
|
3715
3981
|
|
|
3982
|
+
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>;
|
|
3983
|
+
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>;
|
|
3984
|
+
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>;
|
|
3985
|
+
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>;
|
|
3986
|
+
|
|
3716
3987
|
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
|
|
3717
3988
|
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
|
|
3718
3989
|
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
|
|
@@ -3795,10 +4066,10 @@ void kernel_mul_mv_t_t_impl(
|
|
|
3795
4066
|
const int r1 = tgpig.y;
|
|
3796
4067
|
const int im = tgpig.z;
|
|
3797
4068
|
|
|
3798
|
-
const uint i12 = im%
|
|
3799
|
-
const uint i13 = im/
|
|
4069
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
4070
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
3800
4071
|
|
|
3801
|
-
//const uint64_t offset0 = r0*args.nb01 + (i12/
|
|
4072
|
+
//const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3802
4073
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
3803
4074
|
|
|
3804
4075
|
//device const T0 * x = (device const T0 *) (src0 + offset0);
|
|
@@ -3807,7 +4078,7 @@ void kernel_mul_mv_t_t_impl(
|
|
|
3807
4078
|
// pointers to src0 rows
|
|
3808
4079
|
device const T0 * ax [NR0];
|
|
3809
4080
|
FOR_UNROLL (short row = 0; row < NR0; ++row) {
|
|
3810
|
-
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/
|
|
4081
|
+
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3811
4082
|
|
|
3812
4083
|
ax[row] = (device const T0 *) ((device char *) src0 + offset0);
|
|
3813
4084
|
}
|
|
@@ -3917,10 +4188,10 @@ void kernel_mul_mv_t_t_4_impl(
|
|
|
3917
4188
|
const int r1 = tgpig.y;
|
|
3918
4189
|
const int im = tgpig.z;
|
|
3919
4190
|
|
|
3920
|
-
const uint i12 = im%
|
|
3921
|
-
const uint i13 = im/
|
|
4191
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
4192
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
3922
4193
|
|
|
3923
|
-
//const uint64_t offset0 = r0*args.nb01 + (i12/
|
|
4194
|
+
//const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3924
4195
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
3925
4196
|
|
|
3926
4197
|
device const T1 * y = (device const T1 *) (src1 + offset1);
|
|
@@ -3930,7 +4201,7 @@ void kernel_mul_mv_t_t_4_impl(
|
|
|
3930
4201
|
device const T0 * ax [NR0];
|
|
3931
4202
|
device const T04 * ax4[NR0];
|
|
3932
4203
|
FOR_UNROLL (short row = 0; row < NR0; ++row) {
|
|
3933
|
-
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/
|
|
4204
|
+
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3934
4205
|
|
|
3935
4206
|
ax [row] = (device const T0 *) ((device char *) src0 + offset0);
|
|
3936
4207
|
ax4[row] = (device const T04 *) ((device char *) src0 + offset0);
|
|
@@ -4034,10 +4305,10 @@ void kernel_mul_mv_t_t_short_impl(
|
|
|
4034
4305
|
return;
|
|
4035
4306
|
}
|
|
4036
4307
|
|
|
4037
|
-
const uint i12 = im%
|
|
4038
|
-
const uint i13 = im/
|
|
4308
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
4309
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
4039
4310
|
|
|
4040
|
-
const uint64_t offset0 = r0*args.nb01 + (i12/
|
|
4311
|
+
const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
4041
4312
|
|
|
4042
4313
|
device const T0 * x = (device const T0 *) (src0 + offset0);
|
|
4043
4314
|
|
|
@@ -4460,59 +4731,59 @@ kernel void kernel_im2col(
|
|
|
4460
4731
|
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
|
|
4461
4732
|
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
|
|
4462
4733
|
|
|
4463
|
-
// TODO:
|
|
4464
|
-
|
|
4465
|
-
|
|
4466
|
-
|
|
4467
|
-
|
|
4468
|
-
|
|
4469
|
-
|
|
4470
|
-
|
|
4471
|
-
|
|
4472
|
-
|
|
4473
|
-
|
|
4474
|
-
|
|
4475
|
-
|
|
4476
|
-
|
|
4477
|
-
|
|
4478
|
-
|
|
4479
|
-
|
|
4480
|
-
|
|
4481
|
-
|
|
4482
|
-
|
|
4483
|
-
|
|
4484
|
-
|
|
4485
|
-
|
|
4486
|
-
|
|
4487
|
-
|
|
4488
|
-
|
|
4489
|
-
|
|
4490
|
-
|
|
4491
|
-
|
|
4492
|
-
|
|
4493
|
-
|
|
4494
|
-
|
|
4495
|
-
|
|
4496
|
-
|
|
4497
|
-
|
|
4498
|
-
|
|
4499
|
-
|
|
4500
|
-
|
|
4501
|
-
|
|
4502
|
-
|
|
4503
|
-
|
|
4504
|
-
|
|
4505
|
-
|
|
4506
|
-
|
|
4507
|
-
|
|
4508
|
-
|
|
4509
|
-
|
|
4510
|
-
|
|
4511
|
-
|
|
4512
|
-
|
|
4513
|
-
|
|
4514
|
-
|
|
4515
|
-
|
|
4734
|
+
// TODO: optimize
|
|
4735
|
+
typedef void (im2col_ext_t)(
|
|
4736
|
+
constant ggml_metal_kargs_im2col & args,
|
|
4737
|
+
device const float * x,
|
|
4738
|
+
device char * dst,
|
|
4739
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4740
|
+
uint3 tgpg[[threadgroups_per_grid]],
|
|
4741
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
4742
|
+
uint3 ntg[[threads_per_threadgroup]]);
|
|
4743
|
+
|
|
4744
|
+
template <typename T>
|
|
4745
|
+
kernel void kernel_im2col_ext(
|
|
4746
|
+
constant ggml_metal_kargs_im2col & args,
|
|
4747
|
+
device const float * x,
|
|
4748
|
+
device char * dst,
|
|
4749
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4750
|
+
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
|
|
4751
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
4752
|
+
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
|
|
4753
|
+
const int64_t KHW = (int64_t)args.KHW;
|
|
4754
|
+
|
|
4755
|
+
const int64_t d = tgpig[0] / args.CHW;
|
|
4756
|
+
const int64_t chw = tgpig[0] % args.CHW;
|
|
4757
|
+
const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
|
|
4758
|
+
const int64_t HW = tgpig[0] % KHW;
|
|
4759
|
+
|
|
4760
|
+
const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
|
|
4761
|
+
if (tpitg_0 >= args.N) {
|
|
4762
|
+
return;
|
|
4763
|
+
}
|
|
4764
|
+
|
|
4765
|
+
const int64_t tpitg_1 = HW / args.KW;
|
|
4766
|
+
const int64_t tpitg_2 = HW % args.KW;
|
|
4767
|
+
|
|
4768
|
+
const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
|
|
4769
|
+
const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
|
|
4770
|
+
|
|
4771
|
+
const int64_t offset_dst =
|
|
4772
|
+
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
|
|
4773
|
+
(tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
|
|
4774
|
+
|
|
4775
|
+
device T * pdst = (device T *) (dst);
|
|
4776
|
+
|
|
4777
|
+
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
|
|
4778
|
+
pdst[offset_dst] = 0.0f;
|
|
4779
|
+
} else {
|
|
4780
|
+
const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
|
|
4781
|
+
pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
|
|
4782
|
+
}
|
|
4783
|
+
}
|
|
4784
|
+
|
|
4785
|
+
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
|
|
4786
|
+
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
|
|
4516
4787
|
|
|
4517
4788
|
template <typename TK>
|
|
4518
4789
|
kernel void kernel_conv_2d(
|
|
@@ -4645,15 +4916,32 @@ kernel void kernel_conv_transpose_1d(
|
|
|
4645
4916
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4646
4917
|
uint3 tgpg[[threadgroups_per_grid]]) {
|
|
4647
4918
|
|
|
4648
|
-
|
|
4919
|
+
// For output position j on the time axis, only input positions
|
|
4920
|
+
// i such that i*s0 <= j < i*s0 + K
|
|
4921
|
+
// contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)]
|
|
4922
|
+
// intersected with [0, IL-1]. That's at most ceil(K/s0) values
|
|
4923
|
+
// (typically 2 for stride==K/2 transposed convs).
|
|
4924
|
+
const int32_t j = tgpig[0];
|
|
4925
|
+
const int32_t s0 = args.s0;
|
|
4926
|
+
const int32_t K = args.K;
|
|
4927
|
+
const int32_t IL = args.IL;
|
|
4928
|
+
|
|
4929
|
+
int32_t i_min;
|
|
4930
|
+
{
|
|
4931
|
+
int32_t a = j - K + 1;
|
|
4932
|
+
i_min = a <= 0 ? 0 : (a + s0 - 1) / s0; // ceil(a/s0) for a>0
|
|
4933
|
+
}
|
|
4934
|
+
int32_t i_max = j / s0;
|
|
4935
|
+
if (i_max > IL - 1) i_max = IL - 1;
|
|
4649
4936
|
|
|
4650
|
-
|
|
4651
|
-
|
|
4652
|
-
|
|
4937
|
+
float v = 0.0f;
|
|
4938
|
+
if (i_min <= i_max) {
|
|
4939
|
+
for (int64_t c = 0; c < args.IC; c++) {
|
|
4940
|
+
const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
|
|
4941
|
+
const int32_t input_offset = c * IL;
|
|
4653
4942
|
|
|
4654
|
-
|
|
4655
|
-
|
|
4656
|
-
v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i];
|
|
4943
|
+
for (int32_t i = i_min; i <= i_max; i++) {
|
|
4944
|
+
v += float(src0[kernel_offset + j - i * s0]) * src1[input_offset + i];
|
|
4657
4945
|
}
|
|
4658
4946
|
}
|
|
4659
4947
|
}
|
|
@@ -4851,7 +5139,7 @@ kernel void kernel_upscale_bilinear_f32(
|
|
|
4851
5139
|
for (int64_t sx = x_min; sx < x_max; ++sx) {
|
|
4852
5140
|
const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
|
|
4853
5141
|
const float w = wx * wy;
|
|
4854
|
-
|
|
5142
|
+
device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
|
|
4855
5143
|
sum += (*src_ptr) * w;
|
|
4856
5144
|
wsum += w;
|
|
4857
5145
|
}
|
|
@@ -4883,6 +5171,98 @@ kernel void kernel_upscale_bilinear_f32(
|
|
|
4883
5171
|
}
|
|
4884
5172
|
}
|
|
4885
5173
|
|
|
5174
|
+
template <typename T>
|
|
5175
|
+
kernel void kernel_conv_3d(
|
|
5176
|
+
constant ggml_metal_kargs_conv_3d & args,
|
|
5177
|
+
device const char * src0, // Weights [IC * OC, KD, KH, KW]
|
|
5178
|
+
device const char * src1, // Inputs [IC * N, ID, IH, IW]
|
|
5179
|
+
device char * dst, // Outputs [OC * N, OD, OH, OW]
|
|
5180
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5181
|
+
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
|
5182
|
+
|
|
5183
|
+
// 1. Un-flatten the spatial dimension from Grid X
|
|
5184
|
+
int64_t spatial_idx = tgpig.x * 32 + tpitg.x;
|
|
5185
|
+
|
|
5186
|
+
if (spatial_idx >= args.OW * args.OH * args.OD) {
|
|
5187
|
+
return; // Thread falls outside the spatial volume
|
|
5188
|
+
}
|
|
5189
|
+
|
|
5190
|
+
int64_t od = spatial_idx / (args.OW * args.OH);
|
|
5191
|
+
int64_t oh = (spatial_idx / args.OW) % args.OH;
|
|
5192
|
+
int64_t ow = spatial_idx % args.OW;
|
|
5193
|
+
|
|
5194
|
+
// 2. Map Y to Channels, Z to Batch
|
|
5195
|
+
int64_t oc = tgpig.y;
|
|
5196
|
+
int64_t batch_idx = tgpig.z;
|
|
5197
|
+
|
|
5198
|
+
// 3. Calculate anchor coordinates in the Input volume
|
|
5199
|
+
int64_t i_w_base = ow * args.s0 - args.p0;
|
|
5200
|
+
int64_t i_h_base = oh * args.s1 - args.p1;
|
|
5201
|
+
int64_t i_d_base = od * args.s2 - args.p2;
|
|
5202
|
+
|
|
5203
|
+
float sum = 0.0f;
|
|
5204
|
+
|
|
5205
|
+
// 4. Gather Loop (Iterate over Input Channels -> Depth -> Height -> Width)
|
|
5206
|
+
for (int64_t ic = 0; ic < args.IC; ++ic) {
|
|
5207
|
+
|
|
5208
|
+
// ggml packs batch and channel together in the 4th dimension
|
|
5209
|
+
int64_t src_cn_idx = batch_idx * args.IC + ic;
|
|
5210
|
+
int64_t w_cn_idx = oc * args.IC + ic;
|
|
5211
|
+
|
|
5212
|
+
for (int64_t kz = 0; kz < args.KD; ++kz) {
|
|
5213
|
+
int64_t id = i_d_base + kz * args.d2;
|
|
5214
|
+
if (id < 0 || id >= args.ID) continue; // Boundary check (Padding)
|
|
5215
|
+
|
|
5216
|
+
for (int64_t ky = 0; ky < args.KH; ++ky) {
|
|
5217
|
+
int64_t ih = i_h_base + ky * args.d1;
|
|
5218
|
+
if (ih < 0 || ih >= args.IH) continue;
|
|
5219
|
+
|
|
5220
|
+
for (int64_t kx = 0; kx < args.KW; ++kx) {
|
|
5221
|
+
int64_t iw = i_w_base + kx * args.d0;
|
|
5222
|
+
if (iw < 0 || iw >= args.IW) continue;
|
|
5223
|
+
|
|
5224
|
+
// Convert multi-dimensional coordinates to flat byte offsets
|
|
5225
|
+
int64_t w_idx = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn_idx*args.nb03;
|
|
5226
|
+
int64_t i_idx = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn_idx*args.nb13;
|
|
5227
|
+
|
|
5228
|
+
// Dereference memory and cast weights to f32 if they were f16
|
|
5229
|
+
float w_val = (float)*(device const T*)((device const char*)src0 + w_idx);
|
|
5230
|
+
float i_val = *(device const float*)((device const char*)src1 + i_idx);
|
|
5231
|
+
|
|
5232
|
+
sum += w_val * i_val;
|
|
5233
|
+
}
|
|
5234
|
+
}
|
|
5235
|
+
}
|
|
5236
|
+
}
|
|
5237
|
+
|
|
5238
|
+
// 5. Write the accumulated value out to RAM
|
|
5239
|
+
int64_t dst_cn_idx = batch_idx * args.OC + oc;
|
|
5240
|
+
int64_t d_idx = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn_idx*args.nb3;
|
|
5241
|
+
|
|
5242
|
+
*(device float*)(dst + d_idx) = sum;
|
|
5243
|
+
}
|
|
5244
|
+
|
|
5245
|
+
// Explicit instantiations so the JIT compiler can find them by name
|
|
5246
|
+
template [[host_name("kernel_conv_3d_f32_f32")]]
|
|
5247
|
+
kernel void kernel_conv_3d<float>(
|
|
5248
|
+
constant ggml_metal_kargs_conv_3d & args,
|
|
5249
|
+
device const char * src0,
|
|
5250
|
+
device const char * src1,
|
|
5251
|
+
device char * dst,
|
|
5252
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5253
|
+
uint3 tpitg[[thread_position_in_threadgroup]]);
|
|
5254
|
+
|
|
5255
|
+
// Explicit instantiation for f16 weights
|
|
5256
|
+
template [[host_name("kernel_conv_3d_f16_f32")]]
|
|
5257
|
+
kernel void kernel_conv_3d<half>(
|
|
5258
|
+
constant ggml_metal_kargs_conv_3d & args,
|
|
5259
|
+
device const char * src0,
|
|
5260
|
+
device const char * src1,
|
|
5261
|
+
device char * dst,
|
|
5262
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5263
|
+
uint3 tpitg[[thread_position_in_threadgroup]]);
|
|
5264
|
+
|
|
5265
|
+
|
|
4886
5266
|
static inline float bicubic_weight1(float x) {
|
|
4887
5267
|
const float a = -0.75f;
|
|
4888
5268
|
return ((a + 2) * x - (a + 3)) * x * x + 1;
|
|
@@ -4941,7 +5321,7 @@ kernel void kernel_upscale_bicubic_f32(
|
|
|
4941
5321
|
const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
|
|
4942
5322
|
const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
|
|
4943
5323
|
|
|
4944
|
-
|
|
5324
|
+
device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
|
|
4945
5325
|
sum += (*src_ptr) * wx * wy;
|
|
4946
5326
|
}
|
|
4947
5327
|
}
|
|
@@ -4950,8 +5330,8 @@ kernel void kernel_upscale_bicubic_f32(
|
|
|
4950
5330
|
}
|
|
4951
5331
|
}
|
|
4952
5332
|
|
|
4953
|
-
kernel void
|
|
4954
|
-
constant
|
|
5333
|
+
kernel void kernel_roll_f32(
|
|
5334
|
+
constant ggml_metal_kargs_roll & args,
|
|
4955
5335
|
device const char * src0,
|
|
4956
5336
|
device char * dst,
|
|
4957
5337
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
@@ -4962,30 +5342,68 @@ kernel void kernel_pad_f32(
|
|
|
4962
5342
|
const int64_t i2 = tgpig.y;
|
|
4963
5343
|
const int64_t i1 = tgpig.x;
|
|
4964
5344
|
|
|
4965
|
-
const
|
|
4966
|
-
|
|
4967
|
-
const int64_t i01 = i1;
|
|
5345
|
+
device const float * src0_ptr = (device const float *) src0;
|
|
5346
|
+
device float * dst_ptr = (device float *) dst;
|
|
4968
5347
|
|
|
4969
|
-
|
|
4970
|
-
|
|
5348
|
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
5349
|
+
// apply shifts and wrap around
|
|
5350
|
+
int64_t i00 = i0 - args.s0;
|
|
5351
|
+
int64_t i01 = i1 - args.s1;
|
|
5352
|
+
int64_t i02 = i2 - args.s2;
|
|
5353
|
+
int64_t i03 = i3 - args.s3;
|
|
4971
5354
|
|
|
4972
|
-
|
|
4973
|
-
|
|
4974
|
-
|
|
4975
|
-
|
|
4976
|
-
} else {
|
|
4977
|
-
dst_ptr[i0] = 0.0f;
|
|
4978
|
-
}
|
|
4979
|
-
}
|
|
5355
|
+
if (i00 < 0) { i00 += args.ne00; } else if (i00 >= args.ne00) { i00 -= args.ne00; }
|
|
5356
|
+
if (i01 < 0) { i01 += args.ne01; } else if (i01 >= args.ne01) { i01 -= args.ne01; }
|
|
5357
|
+
if (i02 < 0) { i02 += args.ne02; } else if (i02 >= args.ne02) { i02 -= args.ne02; }
|
|
5358
|
+
if (i03 < 0) { i03 += args.ne03; } else if (i03 >= args.ne03) { i03 -= args.ne03; }
|
|
4980
5359
|
|
|
4981
|
-
|
|
5360
|
+
int64_t src_idx = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00 + i00;
|
|
5361
|
+
int64_t dst_idx = i3 *args.ne2 *args.ne1 *args.ne0 + i2 *args.ne1 *args.ne0 + i1 *args.ne0 + i0;
|
|
5362
|
+
|
|
5363
|
+
dst_ptr[dst_idx] = src0_ptr[src_idx];
|
|
4982
5364
|
}
|
|
5365
|
+
}
|
|
4983
5366
|
|
|
4984
|
-
|
|
4985
|
-
|
|
5367
|
+
template <typename T>
|
|
5368
|
+
kernel void kernel_pad_impl(
|
|
5369
|
+
constant ggml_metal_kargs_pad & args,
|
|
5370
|
+
device const char * src0,
|
|
5371
|
+
device char * dst,
|
|
5372
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5373
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
5374
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
5375
|
+
const int32_t i3 = tgpig.z;
|
|
5376
|
+
const int32_t i2 = tgpig.y;
|
|
5377
|
+
const int32_t k0 = tgpig.x/args.ne1;
|
|
5378
|
+
const int32_t i1 = tgpig.x - k0*args.ne1;
|
|
5379
|
+
|
|
5380
|
+
const int32_t i03 = i3;
|
|
5381
|
+
const int32_t i02 = i2;
|
|
5382
|
+
const int32_t i01 = i1;
|
|
5383
|
+
|
|
5384
|
+
device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
|
5385
|
+
device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
|
5386
|
+
|
|
5387
|
+
for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) {
|
|
5388
|
+
const int32_t i0 = k0*1024 + tpitg.x + l0;
|
|
5389
|
+
if (i0 >= args.ne0) {
|
|
5390
|
+
break;
|
|
5391
|
+
}
|
|
5392
|
+
|
|
5393
|
+
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
|
|
5394
|
+
dst_ptr[i0] = src0_ptr[i0];
|
|
5395
|
+
} else {
|
|
5396
|
+
dst_ptr[i0] = 0.0f;
|
|
5397
|
+
}
|
|
4986
5398
|
}
|
|
4987
5399
|
}
|
|
4988
5400
|
|
|
5401
|
+
typedef decltype(kernel_pad_impl<float>) kernel_pad_t;
|
|
5402
|
+
|
|
5403
|
+
template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl<float>;
|
|
5404
|
+
template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl<float4>;
|
|
5405
|
+
|
|
5406
|
+
// TODO: this is slow - optimize
|
|
4989
5407
|
kernel void kernel_pad_reflect_1d_f32(
|
|
4990
5408
|
constant ggml_metal_kargs_pad_reflect_1d & args,
|
|
4991
5409
|
device const char * src0,
|
|
@@ -6177,6 +6595,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_at
|
|
|
6177
6595
|
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
|
|
6178
6596
|
template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
|
|
6179
6597
|
template [[host_name("kernel_flash_attn_ext_f32_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 320, 256>;
|
|
6598
|
+
template [[host_name("kernel_flash_attn_ext_f32_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 512, 512>;
|
|
6180
6599
|
template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
|
|
6181
6600
|
|
|
6182
6601
|
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
|
|
@@ -6192,6 +6611,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_at
|
|
|
6192
6611
|
template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
|
|
6193
6612
|
template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
|
|
6194
6613
|
template [[host_name("kernel_flash_attn_ext_f16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 320, 256>;
|
|
6614
|
+
template [[host_name("kernel_flash_attn_ext_f16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 512, 512>;
|
|
6195
6615
|
template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
|
|
6196
6616
|
|
|
6197
6617
|
#if defined(GGML_METAL_HAS_BF16)
|
|
@@ -6208,6 +6628,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_at
|
|
|
6208
6628
|
template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
|
|
6209
6629
|
template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
|
|
6210
6630
|
template [[host_name("kernel_flash_attn_ext_bf16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 320, 256>;
|
|
6631
|
+
template [[host_name("kernel_flash_attn_ext_bf16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 512, 512>;
|
|
6211
6632
|
template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
|
|
6212
6633
|
#endif
|
|
6213
6634
|
|
|
@@ -6224,6 +6645,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_at
|
|
|
6224
6645
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
|
|
6225
6646
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
|
|
6226
6647
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 320, 256>;
|
|
6648
|
+
template [[host_name("kernel_flash_attn_ext_q4_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 512, 512>;
|
|
6227
6649
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
|
|
6228
6650
|
|
|
6229
6651
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
|
|
@@ -6239,6 +6661,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_at
|
|
|
6239
6661
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
|
|
6240
6662
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
|
|
6241
6663
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 320, 256>;
|
|
6664
|
+
template [[host_name("kernel_flash_attn_ext_q4_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 512, 512>;
|
|
6242
6665
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
|
|
6243
6666
|
|
|
6244
6667
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
|
|
@@ -6254,6 +6677,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_at
|
|
|
6254
6677
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
|
|
6255
6678
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
|
|
6256
6679
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 320, 256>;
|
|
6680
|
+
template [[host_name("kernel_flash_attn_ext_q5_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 512, 512>;
|
|
6257
6681
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
|
|
6258
6682
|
|
|
6259
6683
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
|
|
@@ -6269,6 +6693,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_at
|
|
|
6269
6693
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
|
|
6270
6694
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
|
|
6271
6695
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 320, 256>;
|
|
6696
|
+
template [[host_name("kernel_flash_attn_ext_q5_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 512, 512>;
|
|
6272
6697
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
|
|
6273
6698
|
|
|
6274
6699
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
|
|
@@ -6284,6 +6709,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_at
|
|
|
6284
6709
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
|
|
6285
6710
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
|
|
6286
6711
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 320, 256>;
|
|
6712
|
+
template [[host_name("kernel_flash_attn_ext_q8_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 512, 512>;
|
|
6287
6713
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
|
|
6288
6714
|
|
|
6289
6715
|
#undef FA_TYPES
|
|
@@ -6865,6 +7291,17 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk320_dv256")]] kernel flas
|
|
|
6865
7291
|
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 320, 256, 2>;
|
|
6866
7292
|
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 320, 256, 2>;
|
|
6867
7293
|
|
|
7294
|
+
template [[host_name("kernel_flash_attn_ext_vec_f32_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 512, 512, 1>;
|
|
7295
|
+
template [[host_name("kernel_flash_attn_ext_vec_f16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 512, 512, 1>;
|
|
7296
|
+
#if defined(GGML_METAL_HAS_BF16)
|
|
7297
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 512, 512, 1>;
|
|
7298
|
+
#endif
|
|
7299
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 512, 512, 1>;
|
|
7300
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 512, 512, 1>;
|
|
7301
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 512, 512, 1>;
|
|
7302
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 512, 512, 1>;
|
|
7303
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 512, 512, 1>;
|
|
7304
|
+
|
|
6868
7305
|
template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
|
|
6869
7306
|
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
|
|
6870
7307
|
#if defined(GGML_METAL_HAS_BF16)
|
|
@@ -6930,23 +7367,27 @@ kernel void kernel_cpy_t_t(
|
|
|
6930
7367
|
device const char * src0,
|
|
6931
7368
|
device char * dst,
|
|
6932
7369
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6933
|
-
|
|
7370
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
6934
7371
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
6935
|
-
const
|
|
6936
|
-
const
|
|
6937
|
-
const
|
|
6938
|
-
const
|
|
7372
|
+
const int32_t i03 = tgpig[2];
|
|
7373
|
+
const int32_t i02 = tgpig[1];
|
|
7374
|
+
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
|
|
7375
|
+
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
|
7376
|
+
|
|
7377
|
+
if (i01 >= args.ne01) {
|
|
7378
|
+
return;
|
|
7379
|
+
}
|
|
6939
7380
|
|
|
6940
7381
|
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
|
6941
7382
|
|
|
6942
|
-
const
|
|
6943
|
-
const
|
|
6944
|
-
const
|
|
6945
|
-
const
|
|
7383
|
+
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
|
|
7384
|
+
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
|
|
7385
|
+
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
|
|
7386
|
+
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
|
|
6946
7387
|
|
|
6947
7388
|
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
|
6948
7389
|
|
|
6949
|
-
for (
|
|
7390
|
+
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) {
|
|
6950
7391
|
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
6951
7392
|
dst_data[i00] = (T1) src[0];
|
|
6952
7393
|
break;
|
|
@@ -6978,23 +7419,27 @@ kernel void kernel_cpy_f32_q(
|
|
|
6978
7419
|
device const char * src0,
|
|
6979
7420
|
device char * dst,
|
|
6980
7421
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6981
|
-
|
|
7422
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
6982
7423
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
6983
|
-
const
|
|
6984
|
-
const
|
|
6985
|
-
const
|
|
6986
|
-
const
|
|
7424
|
+
const int32_t i03 = tgpig[2];
|
|
7425
|
+
const int32_t i02 = tgpig[1];
|
|
7426
|
+
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
|
|
7427
|
+
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
|
7428
|
+
|
|
7429
|
+
if (i01 >= args.ne01) {
|
|
7430
|
+
return;
|
|
7431
|
+
}
|
|
6987
7432
|
|
|
6988
7433
|
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
|
6989
7434
|
|
|
6990
|
-
const
|
|
6991
|
-
const
|
|
6992
|
-
const
|
|
6993
|
-
const
|
|
7435
|
+
const int32_t i3 = n / (args.ne2*args.ne1*args.ne0);
|
|
7436
|
+
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
|
|
7437
|
+
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
|
|
7438
|
+
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
|
|
6994
7439
|
|
|
6995
7440
|
device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
|
6996
7441
|
|
|
6997
|
-
for (
|
|
7442
|
+
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
|
|
6998
7443
|
device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
|
|
6999
7444
|
|
|
7000
7445
|
quantize_func(src, dst_data[i00]);
|
|
@@ -7006,6 +7451,7 @@ kernel void kernel_cpy_f32_q(
|
|
|
7006
7451
|
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
|
|
7007
7452
|
|
|
7008
7453
|
template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
|
|
7454
|
+
template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>;
|
|
7009
7455
|
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
|
|
7010
7456
|
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
|
|
7011
7457
|
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
|
|
@@ -7018,24 +7464,28 @@ kernel void kernel_cpy_q_f32(
|
|
|
7018
7464
|
device const char * src0,
|
|
7019
7465
|
device char * dst,
|
|
7020
7466
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
7021
|
-
|
|
7467
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
7022
7468
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
7023
|
-
const
|
|
7024
|
-
const
|
|
7025
|
-
const
|
|
7026
|
-
const
|
|
7469
|
+
const int32_t i03 = tgpig[2];
|
|
7470
|
+
const int32_t i02 = tgpig[1];
|
|
7471
|
+
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
|
|
7472
|
+
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
|
7473
|
+
|
|
7474
|
+
if (i01 >= args.ne01) {
|
|
7475
|
+
return;
|
|
7476
|
+
}
|
|
7027
7477
|
|
|
7028
7478
|
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
|
7029
7479
|
|
|
7030
|
-
const
|
|
7031
|
-
const
|
|
7032
|
-
const
|
|
7033
|
-
const
|
|
7480
|
+
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
|
|
7481
|
+
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
|
|
7482
|
+
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
|
|
7483
|
+
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
|
|
7034
7484
|
|
|
7035
7485
|
device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
|
7036
7486
|
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
|
7037
7487
|
|
|
7038
|
-
for (
|
|
7488
|
+
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
|
|
7039
7489
|
T4x4 temp;
|
|
7040
7490
|
dequantize_func(src_data + i00/nl, i00%nl, temp);
|
|
7041
7491
|
dst_data[i00] = temp;
|
|
@@ -7046,12 +7496,14 @@ kernel void kernel_cpy_q_f32(
|
|
|
7046
7496
|
|
|
7047
7497
|
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
|
|
7048
7498
|
|
|
7499
|
+
template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8, dequantize_q1_0>;
|
|
7049
7500
|
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
|
|
7050
7501
|
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
|
|
7051
7502
|
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
|
|
7052
7503
|
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
|
|
7053
7504
|
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
|
|
7054
7505
|
|
|
7506
|
+
template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8, dequantize_q1_0>;
|
|
7055
7507
|
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
|
|
7056
7508
|
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
|
|
7057
7509
|
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
|
|
@@ -7069,7 +7521,11 @@ kernel void kernel_concat(
|
|
|
7069
7521
|
|
|
7070
7522
|
const int i3 = tgpig.z;
|
|
7071
7523
|
const int i2 = tgpig.y;
|
|
7072
|
-
const int i1 = tgpig.x;
|
|
7524
|
+
const int i1 = ntg.y == 1 ? tgpig.x : tgpig.x*ntg.y + tpitg.y;
|
|
7525
|
+
|
|
7526
|
+
if (i1 >= args.ne1) {
|
|
7527
|
+
return;
|
|
7528
|
+
}
|
|
7073
7529
|
|
|
7074
7530
|
int o[4] = {0, 0, 0, 0};
|
|
7075
7531
|
o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
|
|
@@ -7109,10 +7565,10 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
|
7109
7565
|
|
|
7110
7566
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
7111
7567
|
|
|
7112
|
-
const uint i12 = im%
|
|
7113
|
-
const uint i13 = im/
|
|
7568
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
7569
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
7114
7570
|
|
|
7115
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
7571
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
7116
7572
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
7117
7573
|
|
|
7118
7574
|
device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
|
|
@@ -7214,10 +7670,10 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
|
7214
7670
|
|
|
7215
7671
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
7216
7672
|
|
|
7217
|
-
const uint i12 = im%
|
|
7218
|
-
const uint i13 = im/
|
|
7673
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
7674
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
7219
7675
|
|
|
7220
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
7676
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
7221
7677
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
7222
7678
|
|
|
7223
7679
|
device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);
|
|
@@ -7388,10 +7844,10 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
|
7388
7844
|
|
|
7389
7845
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
7390
7846
|
|
|
7391
|
-
const uint i12 = im%
|
|
7392
|
-
const uint i13 = im/
|
|
7847
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
7848
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
7393
7849
|
|
|
7394
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
7850
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
7395
7851
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
7396
7852
|
|
|
7397
7853
|
device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
|
|
@@ -7500,10 +7956,10 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
|
7500
7956
|
|
|
7501
7957
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
7502
7958
|
|
|
7503
|
-
const uint i12 = im%
|
|
7504
|
-
const uint i13 = im/
|
|
7959
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
7960
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
7505
7961
|
|
|
7506
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
7962
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
7507
7963
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
7508
7964
|
|
|
7509
7965
|
device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
|
|
@@ -7636,10 +8092,10 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
|
7636
8092
|
|
|
7637
8093
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
7638
8094
|
|
|
7639
|
-
const uint i12 = im%
|
|
7640
|
-
const uint i13 = im/
|
|
8095
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
8096
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
7641
8097
|
|
|
7642
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
8098
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
7643
8099
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
7644
8100
|
|
|
7645
8101
|
device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
|
|
@@ -7741,10 +8197,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
|
7741
8197
|
|
|
7742
8198
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
7743
8199
|
|
|
7744
|
-
const uint i12 = im%
|
|
7745
|
-
const uint i13 = im/
|
|
8200
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
8201
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
7746
8202
|
|
|
7747
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
8203
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
7748
8204
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
7749
8205
|
|
|
7750
8206
|
device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
|
|
@@ -7849,10 +8305,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
|
7849
8305
|
|
|
7850
8306
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
7851
8307
|
|
|
7852
|
-
const uint i12 = im%
|
|
7853
|
-
const uint i13 = im/
|
|
8308
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
8309
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
7854
8310
|
|
|
7855
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
8311
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
7856
8312
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
7857
8313
|
|
|
7858
8314
|
device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
|
|
@@ -7968,10 +8424,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
|
7968
8424
|
|
|
7969
8425
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
7970
8426
|
|
|
7971
|
-
const uint i12 = im%
|
|
7972
|
-
const uint i13 = im/
|
|
8427
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
8428
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
7973
8429
|
|
|
7974
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
8430
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
7975
8431
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
7976
8432
|
|
|
7977
8433
|
device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
|
|
@@ -8080,10 +8536,10 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
|
8080
8536
|
|
|
8081
8537
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
8082
8538
|
|
|
8083
|
-
const uint i12 = im%
|
|
8084
|
-
const uint i13 = im/
|
|
8539
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
8540
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
8085
8541
|
|
|
8086
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
8542
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
8087
8543
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
8088
8544
|
|
|
8089
8545
|
device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
|
|
@@ -8192,10 +8648,10 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|
|
8192
8648
|
|
|
8193
8649
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
8194
8650
|
|
|
8195
|
-
const uint i12 = im%
|
|
8196
|
-
const uint i13 = im/
|
|
8651
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
8652
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
8197
8653
|
|
|
8198
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
8654
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
8199
8655
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
8200
8656
|
|
|
8201
8657
|
device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
|
|
@@ -8305,10 +8761,10 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
|
8305
8761
|
|
|
8306
8762
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
8307
8763
|
|
|
8308
|
-
const uint i12 = im%
|
|
8309
|
-
const uint i13 = im/
|
|
8764
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
8765
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
8310
8766
|
|
|
8311
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
8767
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
8312
8768
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
8313
8769
|
|
|
8314
8770
|
device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
|
|
@@ -8404,10 +8860,10 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
|
8404
8860
|
|
|
8405
8861
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
8406
8862
|
|
|
8407
|
-
const uint i12 = im%
|
|
8408
|
-
const uint i13 = im/
|
|
8863
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
8864
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
8409
8865
|
|
|
8410
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
8866
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
8411
8867
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
8412
8868
|
|
|
8413
8869
|
device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
|
|
@@ -8513,10 +8969,10 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
|
8513
8969
|
|
|
8514
8970
|
const int first_row = (r0 * NSG + sgitg) * NR0;
|
|
8515
8971
|
|
|
8516
|
-
const uint i12 = im%
|
|
8517
|
-
const uint i13 = im/
|
|
8972
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
8973
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
8518
8974
|
|
|
8519
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
8975
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
8520
8976
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
8521
8977
|
|
|
8522
8978
|
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
|
|
@@ -8622,10 +9078,10 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
|
8622
9078
|
const int im = tgpig.z;
|
|
8623
9079
|
const int first_row = (r0 * NSG + sgitg) * NR0;
|
|
8624
9080
|
|
|
8625
|
-
const uint i12 = im%
|
|
8626
|
-
const uint i13 = im/
|
|
9081
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
9082
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
8627
9083
|
|
|
8628
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
9084
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
8629
9085
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
8630
9086
|
|
|
8631
9087
|
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
|
|
@@ -8733,10 +9189,10 @@ void kernel_mul_mv_mxfp4_f32_impl(
|
|
|
8733
9189
|
|
|
8734
9190
|
const int first_row = (r0 * NSG + sgitg) * NR0;
|
|
8735
9191
|
|
|
8736
|
-
const uint i12 = im%
|
|
8737
|
-
const uint i13 = im/
|
|
9192
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
9193
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
8738
9194
|
|
|
8739
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
9195
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
8740
9196
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
8741
9197
|
|
|
8742
9198
|
device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
|
|
@@ -8951,9 +9407,143 @@ kernel void kernel_diag_f32(
|
|
|
8951
9407
|
|
|
8952
9408
|
constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
|
|
8953
9409
|
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
|
|
9410
|
+
constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]];
|
|
9411
|
+
constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]];
|
|
9412
|
+
constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]];
|
|
9413
|
+
constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]];
|
|
8954
9414
|
|
|
8955
9415
|
// each block_q contains 16*nl weights
|
|
8956
|
-
|
|
9416
|
+
#ifdef GGML_METAL_HAS_TENSOR
|
|
9417
|
+
template<
|
|
9418
|
+
typename SA, typename SA_4x4, typename SA_8x8,
|
|
9419
|
+
typename SB, typename SB_2x4, typename SB_8x8,
|
|
9420
|
+
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &),
|
|
9421
|
+
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
|
9422
|
+
kernel void kernel_mul_mm(
|
|
9423
|
+
constant ggml_metal_kargs_mul_mm & args,
|
|
9424
|
+
device const char * srcA,
|
|
9425
|
+
device const char * srcB,
|
|
9426
|
+
device char * dst,
|
|
9427
|
+
threadgroup char * shmem [[threadgroup(0)]],
|
|
9428
|
+
uint3 tgpig [[threadgroup_position_in_grid]],
|
|
9429
|
+
ushort tiitg [[thread_index_in_threadgroup]],
|
|
9430
|
+
ushort sgitg [[simdgroup_index_in_threadgroup]]) {
|
|
9431
|
+
(void) sgitg;
|
|
9432
|
+
|
|
9433
|
+
// Matrix dimensions: A(M,K) x B(K,N) -> C(M,N)
|
|
9434
|
+
const int K = args.ne00;
|
|
9435
|
+
const int M = args.ne0;
|
|
9436
|
+
const int N = args.ne1;
|
|
9437
|
+
|
|
9438
|
+
// Batch dimension handling
|
|
9439
|
+
const int im = tgpig.z;
|
|
9440
|
+
const int i12 = im % FC_mul_mm_ne12;
|
|
9441
|
+
const int i13 = im / FC_mul_mm_ne12;
|
|
9442
|
+
|
|
9443
|
+
// Batch offsets for srcA and srcB
|
|
9444
|
+
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
|
|
9445
|
+
|
|
9446
|
+
// Tile dimensions
|
|
9447
|
+
constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X;
|
|
9448
|
+
constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y;
|
|
9449
|
+
|
|
9450
|
+
// Tile offsets in output matrix
|
|
9451
|
+
const int ra = tgpig.y * NRA;
|
|
9452
|
+
const int rb = tgpig.x * NRB;
|
|
9453
|
+
|
|
9454
|
+
// Threadgroup memory for dequantized A tile only
|
|
9455
|
+
threadgroup SA * sa = (threadgroup SA *)(shmem);
|
|
9456
|
+
|
|
9457
|
+
// Work-item count for A loading
|
|
9458
|
+
constexpr int A_WORK_ITEMS = NRA * N_MM_NK;
|
|
9459
|
+
constexpr int NUM_THREADS = N_SIMDWIDTH * N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y;
|
|
9460
|
+
|
|
9461
|
+
// tA wraps threadgroup memory
|
|
9462
|
+
auto tA = tensor(sa, dextents<int32_t, 2>(N_MM_NK_TOTAL, NRA));
|
|
9463
|
+
|
|
9464
|
+
// tB wraps device memory directly
|
|
9465
|
+
device T1 * ptrB = (device T1 *)(srcB + args.nb12*i12 + args.nb13*i13);
|
|
9466
|
+
const int strideB = args.nb11 / sizeof(T1);
|
|
9467
|
+
auto tB = tensor(ptrB, dextents<int32_t, 2>(K, N), array<int, 2>({1, strideB}));
|
|
9468
|
+
|
|
9469
|
+
// Configure matmul operation
|
|
9470
|
+
mpp::tensor_ops::matmul2d<
|
|
9471
|
+
mpp::tensor_ops::matmul2d_descriptor(
|
|
9472
|
+
NRB, NRA, N_MM_NK_TOTAL, false, true, true,
|
|
9473
|
+
mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
|
|
9474
|
+
execution_simdgroups<N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y>> mm;
|
|
9475
|
+
|
|
9476
|
+
auto cT = mm.get_destination_cooperative_tensor<decltype(tB), decltype(tA), float>();
|
|
9477
|
+
|
|
9478
|
+
// Accumulate partial results over K dimension
|
|
9479
|
+
for (int loop_k = 0; loop_k < K; loop_k += N_MM_NK_TOTAL) {
|
|
9480
|
+
// === PHASE 1: Dequantization of A into threadgroup memory ===
|
|
9481
|
+
for (int work = tiitg; work < A_WORK_ITEMS; work += NUM_THREADS) {
|
|
9482
|
+
const int row = work / N_MM_NK;
|
|
9483
|
+
const int k_chunk = work % N_MM_NK;
|
|
9484
|
+
const int k_pos = loop_k + k_chunk * 16;
|
|
9485
|
+
const short k_base = k_chunk * 16;
|
|
9486
|
+
|
|
9487
|
+
// Bounds check: skip device read if row is out of matrix bounds
|
|
9488
|
+
if (ra + row < M) {
|
|
9489
|
+
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
|
9490
|
+
// Element-wise reads when K is not aligned (nb01 not aligned for half4x4/float4x4).
|
|
9491
|
+
// MSL spec Table 2.5: half4x4 requires 8-byte alignment. When K is odd,
|
|
9492
|
+
// nb01 = K*2 is not 8-byte aligned, so odd-row pointers are misaligned.
|
|
9493
|
+
// Mirrors the legacy kernel's existing guard.
|
|
9494
|
+
device const T0 * row_ptr = (device const T0 *)(srcA + args.nb01 * (ra + row) + offset0);
|
|
9495
|
+
|
|
9496
|
+
FOR_UNROLL (short i = 0; i < 16; i++) {
|
|
9497
|
+
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? (SA) row_ptr[k_pos + i] : (SA)0;
|
|
9498
|
+
}
|
|
9499
|
+
} else {
|
|
9500
|
+
const int block_idx = k_pos / (16 * nl);
|
|
9501
|
+
const short il = (k_pos / 16) % nl;
|
|
9502
|
+
|
|
9503
|
+
device const block_q * row_ptr = (device const block_q *)(srcA + args.nb01 * (ra + row) + offset0);
|
|
9504
|
+
|
|
9505
|
+
SA_4x4 temp_a;
|
|
9506
|
+
dequantize_func(row_ptr + block_idx, il, temp_a);
|
|
9507
|
+
|
|
9508
|
+
FOR_UNROLL (short i = 0; i < 16; i++) {
|
|
9509
|
+
// Zero-pad A for K positions beyond valid range (handles partial K iterations)
|
|
9510
|
+
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? temp_a[i/4][i%4] : (SA)0;
|
|
9511
|
+
}
|
|
9512
|
+
}
|
|
9513
|
+
} else {
|
|
9514
|
+
// Zero-pad rows beyond matrix bounds
|
|
9515
|
+
FOR_UNROLL (short i = 0; i < 16; i++) {
|
|
9516
|
+
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (SA)0;
|
|
9517
|
+
}
|
|
9518
|
+
}
|
|
9519
|
+
}
|
|
9520
|
+
|
|
9521
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
9522
|
+
|
|
9523
|
+
// === PHASE 2: Tensor matmul ===
|
|
9524
|
+
auto mA = tA.slice(0, 0);
|
|
9525
|
+
auto mB = tB.slice(loop_k, rb);
|
|
9526
|
+
|
|
9527
|
+
mm.run(mB, mA, cT);
|
|
9528
|
+
|
|
9529
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
9530
|
+
}
|
|
9531
|
+
|
|
9532
|
+
// Store result tile to output matrix (with batch offset)
|
|
9533
|
+
// cT.store handles bounds checking via tD's extents (M, N)
|
|
9534
|
+
device float * dstBatch = (device float *)dst + im * N * M;
|
|
9535
|
+
|
|
9536
|
+
auto tD = tensor(dstBatch, dextents<int32_t, 2>(M, N), array<int, 2>({1, M}));
|
|
9537
|
+
cT.store(tD.slice(ra, rb));
|
|
9538
|
+
}
|
|
9539
|
+
|
|
9540
|
+
#else
|
|
9541
|
+
|
|
9542
|
+
template<
|
|
9543
|
+
typename S0, typename S0_4x4, typename S0_8x8,
|
|
9544
|
+
typename S1, typename S1_2x4, typename S1_8x8,
|
|
9545
|
+
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &),
|
|
9546
|
+
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
|
8957
9547
|
kernel void kernel_mul_mm(
|
|
8958
9548
|
constant ggml_metal_kargs_mul_mm & args,
|
|
8959
9549
|
device const char * src0,
|
|
@@ -8967,10 +9557,6 @@ kernel void kernel_mul_mm(
|
|
|
8967
9557
|
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
|
8968
9558
|
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
|
8969
9559
|
|
|
8970
|
-
#ifdef GGML_METAL_HAS_TENSOR
|
|
8971
|
-
threadgroup float * sc = (threadgroup float *)(shmem);
|
|
8972
|
-
#endif
|
|
8973
|
-
|
|
8974
9560
|
constexpr int NR0 = 64;
|
|
8975
9561
|
constexpr int NR1 = 32;
|
|
8976
9562
|
|
|
@@ -8994,10 +9580,10 @@ kernel void kernel_mul_mm(
|
|
|
8994
9580
|
|
|
8995
9581
|
short il = il0;
|
|
8996
9582
|
|
|
8997
|
-
const int i12 = im%
|
|
8998
|
-
const int i13 = im/
|
|
9583
|
+
const int i12 = im % FC_mul_mm_ne12;
|
|
9584
|
+
const int i13 = im / FC_mul_mm_ne12;
|
|
8999
9585
|
|
|
9000
|
-
const uint64_t offset0 = (i12/
|
|
9586
|
+
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
|
|
9001
9587
|
const short offset1 = il0/nl;
|
|
9002
9588
|
|
|
9003
9589
|
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
|
|
@@ -9010,7 +9596,6 @@ kernel void kernel_mul_mm(
|
|
|
9010
9596
|
+ args.nb11*(r1 + lr1)
|
|
9011
9597
|
+ args.nb10*iy);
|
|
9012
9598
|
|
|
9013
|
-
#ifndef GGML_METAL_HAS_TENSOR
|
|
9014
9599
|
S0_8x8 ma[4];
|
|
9015
9600
|
S1_8x8 mb[2];
|
|
9016
9601
|
|
|
@@ -9019,19 +9604,8 @@ kernel void kernel_mul_mm(
|
|
|
9019
9604
|
for (short i = 0; i < 8; i++){
|
|
9020
9605
|
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
|
9021
9606
|
}
|
|
9022
|
-
#else
|
|
9023
|
-
auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
|
|
9024
|
-
auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
|
|
9025
|
-
|
|
9026
|
-
mpp::tensor_ops::matmul2d<
|
|
9027
|
-
mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
|
|
9028
|
-
execution_simdgroups<4>> mm;
|
|
9029
|
-
|
|
9030
|
-
auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
|
|
9031
|
-
#endif
|
|
9032
9607
|
|
|
9033
9608
|
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
|
|
9034
|
-
#ifndef GGML_METAL_HAS_TENSOR
|
|
9035
9609
|
// load data and store to threadgroup memory
|
|
9036
9610
|
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
|
9037
9611
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
@@ -9101,66 +9675,6 @@ kernel void kernel_mul_mm(
|
|
|
9101
9675
|
|
|
9102
9676
|
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
|
|
9103
9677
|
}
|
|
9104
|
-
#else
|
|
9105
|
-
// load data and store to threadgroup memory
|
|
9106
|
-
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
|
9107
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
9108
|
-
|
|
9109
|
-
// no need for dequantization
|
|
9110
|
-
for (short i = 0; i < 16; i++) {
|
|
9111
|
-
const short sx = 2*il0 + i/8;
|
|
9112
|
-
const short sy = (tiitg/NL0)/8;
|
|
9113
|
-
|
|
9114
|
-
const short lx = i%8;
|
|
9115
|
-
const short ly = (tiitg/NL0)%8;
|
|
9116
|
-
//const short lx = (tiitg/NL0)%8;
|
|
9117
|
-
//const short ly = i%8;
|
|
9118
|
-
|
|
9119
|
-
*(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
|
|
9120
|
-
}
|
|
9121
|
-
} else {
|
|
9122
|
-
S0_4x4 temp_a;
|
|
9123
|
-
dequantize_func(x, il, temp_a);
|
|
9124
|
-
|
|
9125
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
9126
|
-
|
|
9127
|
-
FOR_UNROLL (short i = 0; i < 16; i++) {
|
|
9128
|
-
const short sx = 2*il0 + i/8;
|
|
9129
|
-
const short sy = (tiitg/NL0)/8;
|
|
9130
|
-
|
|
9131
|
-
const short lx = i%8;
|
|
9132
|
-
const short ly = (tiitg/NL0)%8;
|
|
9133
|
-
//const short lx = (tiitg/NL0)%8;
|
|
9134
|
-
//const short ly = i%8;
|
|
9135
|
-
|
|
9136
|
-
*(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
|
|
9137
|
-
}
|
|
9138
|
-
}
|
|
9139
|
-
|
|
9140
|
-
if (FC_mul_mm_bc_inp) {
|
|
9141
|
-
for (short i = 0; i < 8; ++i) {
|
|
9142
|
-
const short sx = (tiitg%NL1);
|
|
9143
|
-
const short sy = (tiitg/NL1)/8;
|
|
9144
|
-
|
|
9145
|
-
const short lx = i;
|
|
9146
|
-
const short ly = (tiitg/NL1)%8;
|
|
9147
|
-
//const short lx = (tiitg/NL1)%8;
|
|
9148
|
-
//const short ly = i;
|
|
9149
|
-
|
|
9150
|
-
*(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
|
|
9151
|
-
}
|
|
9152
|
-
} else {
|
|
9153
|
-
const short sx = (tiitg%NL1);
|
|
9154
|
-
const short sy = (tiitg/NL1)/8;
|
|
9155
|
-
|
|
9156
|
-
//const short lx = i;
|
|
9157
|
-
const short ly = (tiitg/NL1)%8;
|
|
9158
|
-
//const short lx = (tiitg/NL1)%8;
|
|
9159
|
-
//const short ly = i;
|
|
9160
|
-
|
|
9161
|
-
*(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
|
|
9162
|
-
}
|
|
9163
|
-
#endif
|
|
9164
9678
|
|
|
9165
9679
|
il = (il + 2 < nl) ? il + 2 : il % 2;
|
|
9166
9680
|
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
|
@@ -9169,7 +9683,6 @@ kernel void kernel_mul_mm(
|
|
|
9169
9683
|
|
|
9170
9684
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
9171
9685
|
|
|
9172
|
-
#ifndef GGML_METAL_HAS_TENSOR
|
|
9173
9686
|
// load matrices from threadgroup memory and conduct outer products
|
|
9174
9687
|
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
|
|
9175
9688
|
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
|
|
@@ -9196,24 +9709,10 @@ kernel void kernel_mul_mm(
|
|
|
9196
9709
|
lsma += 8*64;
|
|
9197
9710
|
lsmb += 4*64;
|
|
9198
9711
|
}
|
|
9199
|
-
#else
|
|
9200
|
-
auto sA = tA.slice(0, 0);
|
|
9201
|
-
auto sB = tB.slice(0, 0);
|
|
9202
|
-
|
|
9203
|
-
mm.run(sB, sA, cT);
|
|
9204
|
-
#endif
|
|
9205
9712
|
}
|
|
9206
9713
|
|
|
9207
9714
|
if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
|
|
9208
9715
|
// if no bounds checks on the output are needed, we can directly write to device memory
|
|
9209
|
-
#ifdef GGML_METAL_HAS_TENSOR
|
|
9210
|
-
device float * C = (device float *) dst +
|
|
9211
|
-
r0 + \
|
|
9212
|
-
r1 * args.ne0 + im*args.ne1*args.ne0;
|
|
9213
|
-
|
|
9214
|
-
auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1));
|
|
9215
|
-
cT.store(tC);
|
|
9216
|
-
#else
|
|
9217
9716
|
device float * C = (device float *) dst +
|
|
9218
9717
|
(r0 + 32*(sgitg & 1)) + \
|
|
9219
9718
|
(r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
|
|
@@ -9221,21 +9720,15 @@ kernel void kernel_mul_mm(
|
|
|
9221
9720
|
for (short i = 0; i < 8; i++) {
|
|
9222
9721
|
simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
|
|
9223
9722
|
}
|
|
9224
|
-
#endif
|
|
9225
9723
|
} else {
|
|
9226
9724
|
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
|
9227
9725
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
9228
9726
|
|
|
9229
9727
|
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
|
|
9230
9728
|
|
|
9231
|
-
#ifdef GGML_METAL_HAS_TENSOR
|
|
9232
|
-
auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
|
|
9233
|
-
cT.store(tC);
|
|
9234
|
-
#else
|
|
9235
9729
|
for (short i = 0; i < 8; i++) {
|
|
9236
9730
|
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
|
|
9237
9731
|
}
|
|
9238
|
-
#endif
|
|
9239
9732
|
|
|
9240
9733
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
9241
9734
|
|
|
@@ -9261,6 +9754,8 @@ kernel void kernel_mul_mm(
|
|
|
9261
9754
|
}
|
|
9262
9755
|
}
|
|
9263
9756
|
|
|
9757
|
+
#endif // GGML_METAL_HAS_TENSOR
|
|
9758
|
+
|
|
9264
9759
|
template<short ne20> // n_expert_used
|
|
9265
9760
|
kernel void kernel_mul_mm_id_map0(
|
|
9266
9761
|
constant ggml_metal_kargs_mul_mm_id_map0 & args,
|
|
@@ -9436,7 +9931,7 @@ kernel void kernel_mul_mm_id(
|
|
|
9436
9931
|
|
|
9437
9932
|
const short ib = 8*sx + sy;
|
|
9438
9933
|
|
|
9439
|
-
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
|
|
9934
|
+
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? (S0) *((device T0 *) x + i) : (S0) 0;
|
|
9440
9935
|
}
|
|
9441
9936
|
} else {
|
|
9442
9937
|
S0_4x4 temp_a;
|
|
@@ -9649,6 +10144,7 @@ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_ro
|
|
|
9649
10144
|
|
|
9650
10145
|
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
|
9651
10146
|
|
|
10147
|
+
template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8, dequantize_q1_0>;
|
|
9652
10148
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
|
|
9653
10149
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
|
|
9654
10150
|
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
|
|
@@ -9711,6 +10207,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m
|
|
|
9711
10207
|
#if defined(GGML_METAL_HAS_BF16)
|
|
9712
10208
|
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
|
9713
10209
|
#endif
|
|
10210
|
+
template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
|
|
9714
10211
|
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
|
9715
10212
|
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
|
9716
10213
|
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
|
@@ -9734,6 +10231,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
|
|
|
9734
10231
|
|
|
9735
10232
|
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
|
9736
10233
|
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
|
10234
|
+
template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
|
|
9737
10235
|
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
|
9738
10236
|
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
|
9739
10237
|
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
|
@@ -9766,6 +10264,7 @@ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_m
|
|
|
9766
10264
|
#if defined(GGML_METAL_HAS_BF16)
|
|
9767
10265
|
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
|
9768
10266
|
#endif
|
|
10267
|
+
template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
|
|
9769
10268
|
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
|
9770
10269
|
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
|
9771
10270
|
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
|
@@ -9789,6 +10288,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m
|
|
|
9789
10288
|
|
|
9790
10289
|
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
|
9791
10290
|
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
|
10291
|
+
template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
|
|
9792
10292
|
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
|
9793
10293
|
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
|
9794
10294
|
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
|
@@ -9943,6 +10443,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4
|
|
|
9943
10443
|
|
|
9944
10444
|
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
|
|
9945
10445
|
|
|
10446
|
+
template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0>>>;
|
|
9946
10447
|
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
|
|
9947
10448
|
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
|
|
9948
10449
|
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
|