whispercpp 1.3.5 → 1.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/LICENSE +1 -1
- data/README.md +133 -3
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -7
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +56 -46
- data/ext/ruby_whisper.h +165 -2
- data/ext/ruby_whisper_context.c +297 -126
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -66
- data/ext/ruby_whisper_segment.c +6 -7
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +46 -16
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +24 -19
- data/ext/sources/examples/cli/cli.cpp +51 -9
- data/ext/sources/examples/common-ggml.cpp +4 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +213 -163
- data/ext/sources/ggml/CMakeLists.txt +29 -15
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +73 -11
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +8 -3
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +155 -16
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +25 -5
- data/ext/sources/ggml/src/ggml-alloc.c +9 -10
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
- data/ext/sources/ggml/src/ggml-common.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
- data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
- data/ext/sources/ggml/src/ggml-impl.h +68 -1
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +385 -119
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
- data/ext/sources/ggml/src/ggml.c +268 -52
- data/ext/sources/ggml/src/gguf.cpp +377 -47
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +62 -40
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +445 -55
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_context_params.rb +82 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +44 -6
- data/whispercpp.gemspec +2 -2
- metadata +426 -280
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
- data/ext/sources/examples/talk-llama/llama-context.h +0 -360
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
- data/ext/sources/examples/talk-llama/llama-model.h +0 -544
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
- data/ext/sources/examples/talk-llama/llama.h +0 -1540
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -569
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
|
@@ -77,6 +77,14 @@ static inline float dot(float x, float y) {
|
|
|
77
77
|
return x*y;
|
|
78
78
|
}
|
|
79
79
|
|
|
80
|
+
static inline float sum(float x) {
|
|
81
|
+
return x;
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
static inline float sum(float4 x) {
|
|
85
|
+
return x[0] + x[1] + x[2] + x[3];
|
|
86
|
+
}
|
|
87
|
+
|
|
80
88
|
// NOTE: this is not dequantizing - we are simply fitting the template
|
|
81
89
|
template <typename type4x4>
|
|
82
90
|
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
|
@@ -110,6 +118,56 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg
|
|
|
110
118
|
}
|
|
111
119
|
#endif
|
|
112
120
|
|
|
121
|
+
template <typename type4x4>
|
|
122
|
+
void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
|
|
123
|
+
device const uint8_t * qs = xb->qs;
|
|
124
|
+
const float d = xb->d;
|
|
125
|
+
const float neg_d = -d;
|
|
126
|
+
|
|
127
|
+
const int byte_offset = il * 2; // il*16 bits = il*2 bytes
|
|
128
|
+
const uint8_t b0 = qs[byte_offset];
|
|
129
|
+
const uint8_t b1 = qs[byte_offset + 1];
|
|
130
|
+
|
|
131
|
+
float4x4 reg_f;
|
|
132
|
+
|
|
133
|
+
reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01));
|
|
134
|
+
reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02));
|
|
135
|
+
reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04));
|
|
136
|
+
reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08));
|
|
137
|
+
reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10));
|
|
138
|
+
reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20));
|
|
139
|
+
reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40));
|
|
140
|
+
reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80));
|
|
141
|
+
|
|
142
|
+
reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01));
|
|
143
|
+
reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02));
|
|
144
|
+
reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04));
|
|
145
|
+
reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08));
|
|
146
|
+
reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10));
|
|
147
|
+
reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20));
|
|
148
|
+
reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40));
|
|
149
|
+
reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80));
|
|
150
|
+
|
|
151
|
+
reg = (type4x4) reg_f;
|
|
152
|
+
}
|
|
153
|
+
|
|
154
|
+
template <typename type4>
|
|
155
|
+
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
|
|
156
|
+
const float d = xb->d;
|
|
157
|
+
const float neg_d = -d;
|
|
158
|
+
const int base = il * 4;
|
|
159
|
+
const uint8_t byte = xb->qs[base / 8];
|
|
160
|
+
const int s = base % 8;
|
|
161
|
+
|
|
162
|
+
float4 reg_f;
|
|
163
|
+
reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1));
|
|
164
|
+
reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1));
|
|
165
|
+
reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1));
|
|
166
|
+
reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1));
|
|
167
|
+
|
|
168
|
+
reg = (type4) reg_f;
|
|
169
|
+
}
|
|
170
|
+
|
|
113
171
|
template <typename type4x4>
|
|
114
172
|
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
|
|
115
173
|
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
|
@@ -144,6 +202,23 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
|
|
|
144
202
|
}
|
|
145
203
|
}
|
|
146
204
|
|
|
205
|
+
void quantize_q1_0(device const float * src, device block_q1_0 & dst) {
|
|
206
|
+
float sum_abs = 0.0f;
|
|
207
|
+
for (int j = 0; j < QK1_0; j++) {
|
|
208
|
+
sum_abs += fabs(src[j]);
|
|
209
|
+
}
|
|
210
|
+
dst.d = sum_abs / QK1_0;
|
|
211
|
+
|
|
212
|
+
for (int j = 0; j < QK1_0 / 8; j++) {
|
|
213
|
+
dst.qs[j] = 0;
|
|
214
|
+
}
|
|
215
|
+
for (int j = 0; j < QK1_0; j++) {
|
|
216
|
+
if (src[j] >= 0.0f) {
|
|
217
|
+
dst.qs[j / 8] |= (1 << (j % 8));
|
|
218
|
+
}
|
|
219
|
+
}
|
|
220
|
+
}
|
|
221
|
+
|
|
147
222
|
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
|
|
148
223
|
#pragma METAL fp math_mode(safe)
|
|
149
224
|
float amax = 0.0f; // absolute max
|
|
@@ -895,753 +970,459 @@ enum ggml_sort_order {
|
|
|
895
970
|
GGML_SORT_ORDER_DESC,
|
|
896
971
|
};
|
|
897
972
|
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
kernel void kernel_add_fuse_impl(
|
|
903
|
-
constant ggml_metal_kargs_bin & args,
|
|
904
|
-
device const char * src0,
|
|
905
|
-
device const char * src1,
|
|
906
|
-
device char * dst,
|
|
907
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
908
|
-
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
909
|
-
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
910
|
-
const int i03 = tgpig.z;
|
|
911
|
-
const int i02 = tgpig.y;
|
|
912
|
-
const int i01 = tgpig.x;
|
|
973
|
+
constant float GELU_COEF_A = 0.044715f;
|
|
974
|
+
constant float GELU_QUICK_COEF = -1.702f;
|
|
975
|
+
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
|
976
|
+
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
|
|
913
977
|
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
978
|
+
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
|
|
979
|
+
// ref: https://www.johndcook.com/blog/python_erf/
|
|
980
|
+
constant float p_erf = 0.3275911f;
|
|
981
|
+
constant float a1_erf = 0.254829592f;
|
|
982
|
+
constant float a2_erf = -0.284496736f;
|
|
983
|
+
constant float a3_erf = 1.421413741f;
|
|
984
|
+
constant float a4_erf = -1.453152027f;
|
|
985
|
+
constant float a5_erf = 1.061405429f;
|
|
917
986
|
|
|
918
|
-
|
|
919
|
-
|
|
987
|
+
template<typename T>
|
|
988
|
+
inline T erf_approx(T x) {
|
|
989
|
+
T sign_x = sign(x);
|
|
990
|
+
x = fabs(x);
|
|
991
|
+
T t = 1.0f / (1.0f + p_erf * x);
|
|
992
|
+
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
|
|
993
|
+
return sign_x * y;
|
|
994
|
+
}
|
|
920
995
|
|
|
921
|
-
|
|
922
|
-
for (short j = 0; j < F; ++j) {
|
|
923
|
-
src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
|
|
924
|
-
}
|
|
996
|
+
template<typename T> T elu_approx(T x);
|
|
925
997
|
|
|
926
|
-
|
|
927
|
-
|
|
998
|
+
template<> inline float elu_approx<float>(float x) {
|
|
999
|
+
return (x > 0.f) ? x : (exp(x) - 1);
|
|
1000
|
+
}
|
|
928
1001
|
|
|
929
|
-
|
|
1002
|
+
template<> inline float4 elu_approx<float4>(float4 x) {
|
|
1003
|
+
float4 res;
|
|
930
1004
|
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
1005
|
+
res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
|
|
1006
|
+
res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
|
|
1007
|
+
res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
|
|
1008
|
+
res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
|
|
935
1009
|
|
|
936
|
-
|
|
937
|
-
}
|
|
1010
|
+
return res;
|
|
938
1011
|
}
|
|
939
1012
|
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
|
|
943
|
-
template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
|
|
944
|
-
template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
|
|
945
|
-
template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
|
|
946
|
-
template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
|
|
947
|
-
template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
|
|
948
|
-
template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
|
|
949
|
-
template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
|
|
1013
|
+
constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
|
|
1014
|
+
constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
|
|
950
1015
|
|
|
951
|
-
|
|
952
|
-
|
|
1016
|
+
template <typename T0, typename T, typename TC>
|
|
1017
|
+
kernel void kernel_unary_impl(
|
|
1018
|
+
constant ggml_metal_kargs_unary & args,
|
|
953
1019
|
device const char * src0,
|
|
954
|
-
device const char * src1,
|
|
955
1020
|
device char * dst,
|
|
956
1021
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
957
1022
|
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
958
1023
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
const int i01 = tgpig.x;
|
|
1024
|
+
#define FC_OP FC_unary_op
|
|
1025
|
+
#define FC_CNT FC_unary_cnt
|
|
962
1026
|
|
|
963
|
-
const
|
|
964
|
-
|
|
965
|
-
const int i11 = i01%args.ne11;
|
|
1027
|
+
device const T0 * src0_ptr;
|
|
1028
|
+
device T * dst_ptr;
|
|
966
1029
|
|
|
967
|
-
|
|
968
|
-
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
|
|
969
|
-
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
|
1030
|
+
int i0;
|
|
970
1031
|
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
|
|
974
|
-
}
|
|
975
|
-
}
|
|
1032
|
+
if (FC_CNT) {
|
|
1033
|
+
i0 = tgpig.x;
|
|
976
1034
|
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
device const char * src0,
|
|
980
|
-
device const char * src1,
|
|
981
|
-
device char * dst,
|
|
982
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
983
|
-
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
984
|
-
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
985
|
-
const int i03 = tgpig.z;
|
|
986
|
-
const int i02 = tgpig.y;
|
|
987
|
-
const int i01 = tgpig.x;
|
|
988
|
-
|
|
989
|
-
const int i13 = i03%args.ne13;
|
|
990
|
-
const int i12 = i02%args.ne12;
|
|
991
|
-
const int i11 = i01%args.ne11;
|
|
992
|
-
|
|
993
|
-
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
|
|
994
|
-
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
|
|
995
|
-
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
|
996
|
-
|
|
997
|
-
if (args.ne10 == 1) {
|
|
998
|
-
const float x = *((device float *)(src1_ptr));
|
|
999
|
-
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
1000
|
-
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
|
|
1001
|
-
}
|
|
1035
|
+
src0_ptr = (device const T0 *) (src0);
|
|
1036
|
+
dst_ptr = (device T *) (dst);
|
|
1002
1037
|
} else {
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
}
|
|
1008
|
-
}
|
|
1038
|
+
const int i03 = tgpig.z;
|
|
1039
|
+
const int i02 = tgpig.y;
|
|
1040
|
+
const int k0 = tgpig.x/args.ne01;
|
|
1041
|
+
const int i01 = tgpig.x - k0*args.ne01;
|
|
1009
1042
|
|
|
1010
|
-
|
|
1011
|
-
constant ggml_metal_kargs_bin & args,
|
|
1012
|
-
device const char * src0,
|
|
1013
|
-
device const char * src1,
|
|
1014
|
-
device char * dst,
|
|
1015
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1016
|
-
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
1017
|
-
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1018
|
-
const int i03 = tgpig.z;
|
|
1019
|
-
const int i02 = tgpig.y;
|
|
1020
|
-
const int i01 = tgpig.x;
|
|
1043
|
+
i0 = k0*ntg.x + tpitg.x;
|
|
1021
1044
|
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1045
|
+
src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
|
1046
|
+
dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 );
|
|
1047
|
+
}
|
|
1025
1048
|
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
|
1049
|
+
{
|
|
1050
|
+
//threadgroup_barrier(mem_flags::mem_none);
|
|
1029
1051
|
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
}
|
|
1035
|
-
} else {
|
|
1036
|
-
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
1037
|
-
const int i10 = i0%args.ne10;
|
|
1038
|
-
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
|
|
1052
|
+
if (!FC_CNT) {
|
|
1053
|
+
if (i0 >= args.ne0) {
|
|
1054
|
+
return;
|
|
1055
|
+
}
|
|
1039
1056
|
}
|
|
1040
|
-
}
|
|
1041
|
-
}
|
|
1042
1057
|
|
|
1043
|
-
|
|
1044
|
-
constant ggml_metal_kargs_add_id & args,
|
|
1045
|
-
device const char * src0,
|
|
1046
|
-
device const char * src1,
|
|
1047
|
-
device const char * src2,
|
|
1048
|
-
device char * dst,
|
|
1049
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1050
|
-
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
1051
|
-
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1052
|
-
const int i1 = tgpig.x;
|
|
1053
|
-
const int i2 = tgpig.y;
|
|
1058
|
+
const TC x = (TC) src0_ptr[i0];
|
|
1054
1059
|
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
const size_t nb2 = args.ne1 * nb1;
|
|
1060
|
+
if (FC_OP == OP_UNARY_NUM_SCALE) {
|
|
1061
|
+
dst_ptr[i0] = (T) (args.scale * x + args.bias);
|
|
1062
|
+
}
|
|
1059
1063
|
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1064
|
+
if (FC_OP == OP_UNARY_NUM_FILL) {
|
|
1065
|
+
dst_ptr[i0] = (T) args.val;
|
|
1066
|
+
}
|
|
1063
1067
|
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
}
|
|
1068
|
+
if (FC_OP == OP_UNARY_NUM_CLAMP) {
|
|
1069
|
+
dst_ptr[i0] = (T) clamp(x, args.min, args.max);
|
|
1070
|
+
}
|
|
1068
1071
|
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
device const char * src0,
|
|
1073
|
-
device char * dst,
|
|
1074
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1075
|
-
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
1076
|
-
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1077
|
-
const int i3 = tgpig.z;
|
|
1078
|
-
const int i2 = tgpig.y;
|
|
1079
|
-
const int i1 = tgpig.x;
|
|
1072
|
+
if (FC_OP == OP_UNARY_NUM_SQR) {
|
|
1073
|
+
dst_ptr[i0] = (T) (x * x);
|
|
1074
|
+
}
|
|
1080
1075
|
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
|
|
1076
|
+
if (FC_OP == OP_UNARY_NUM_SQRT) {
|
|
1077
|
+
dst_ptr[i0] = (T) sqrt(x);
|
|
1078
|
+
}
|
|
1084
1079
|
|
|
1085
|
-
|
|
1086
|
-
|
|
1080
|
+
if (FC_OP == OP_UNARY_NUM_SIN) {
|
|
1081
|
+
dst_ptr[i0] = (T) sin(x);
|
|
1082
|
+
}
|
|
1087
1083
|
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
}
|
|
1092
|
-
}
|
|
1084
|
+
if (FC_OP == OP_UNARY_NUM_COS) {
|
|
1085
|
+
dst_ptr[i0] = (T) cos(x);
|
|
1086
|
+
}
|
|
1093
1087
|
|
|
1094
|
-
|
|
1088
|
+
if (FC_OP == OP_UNARY_NUM_LOG) {
|
|
1089
|
+
dst_ptr[i0] = (T) log(x);
|
|
1090
|
+
}
|
|
1095
1091
|
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
|
|
1099
|
-
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
|
|
1092
|
+
if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
|
|
1093
|
+
dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
|
|
1094
|
+
}
|
|
1100
1095
|
|
|
1101
|
-
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
kernel void kernel_add_row_c4_fuse_impl(
|
|
1105
|
-
constant ggml_metal_kargs_bin & args,
|
|
1106
|
-
device const char * src0,
|
|
1107
|
-
device const char * src1,
|
|
1108
|
-
device char * dst,
|
|
1109
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1110
|
-
const uint nb = args.ne00/4;
|
|
1111
|
-
const uint i = tpig % nb;
|
|
1096
|
+
if (FC_OP == OP_UNARY_NUM_TANH) {
|
|
1097
|
+
dst_ptr[i0] = (T) precise::tanh(x);
|
|
1098
|
+
}
|
|
1112
1099
|
|
|
1113
|
-
|
|
1114
|
-
|
|
1100
|
+
if (FC_OP == OP_UNARY_NUM_RELU) {
|
|
1101
|
+
dst_ptr[i0] = (T) fmax(0, x);
|
|
1102
|
+
}
|
|
1115
1103
|
|
|
1116
|
-
|
|
1104
|
+
if (FC_OP == OP_UNARY_NUM_SIGMOID) {
|
|
1105
|
+
dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
|
|
1106
|
+
}
|
|
1117
1107
|
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
}
|
|
1108
|
+
if (FC_OP == OP_UNARY_NUM_GELU) {
|
|
1109
|
+
dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
|
|
1110
|
+
}
|
|
1122
1111
|
|
|
1123
|
-
|
|
1124
|
-
|
|
1112
|
+
if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
|
|
1113
|
+
dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
|
|
1114
|
+
}
|
|
1125
1115
|
|
|
1126
|
-
|
|
1116
|
+
if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
|
|
1117
|
+
dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
|
|
1118
|
+
}
|
|
1127
1119
|
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
|
|
1132
|
-
template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
|
|
1133
|
-
template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
|
|
1134
|
-
template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
|
|
1135
|
-
template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
|
|
1120
|
+
if (FC_OP == OP_UNARY_NUM_SILU) {
|
|
1121
|
+
dst_ptr[i0] = (T) (x / (1 + exp(-x)));
|
|
1122
|
+
}
|
|
1136
1123
|
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
|
|
1140
|
-
device const char * src0,
|
|
1141
|
-
device const char * src1,
|
|
1142
|
-
device char * dst,
|
|
1143
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1124
|
+
if (FC_OP == OP_UNARY_NUM_ELU) {
|
|
1125
|
+
dst_ptr[i0] = (T) elu_approx(x);
|
|
1126
|
+
}
|
|
1144
1127
|
|
|
1145
|
-
|
|
1146
|
-
|
|
1128
|
+
if (FC_OP == OP_UNARY_NUM_NEG) {
|
|
1129
|
+
dst_ptr[i0] = (T) -x;
|
|
1130
|
+
}
|
|
1147
1131
|
|
|
1148
|
-
|
|
1149
|
-
|
|
1132
|
+
if (FC_OP == OP_UNARY_NUM_ABS) {
|
|
1133
|
+
dst_ptr[i0] = (T) fabs(x);
|
|
1134
|
+
}
|
|
1150
1135
|
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
}
|
|
1136
|
+
if (FC_OP == OP_UNARY_NUM_SGN) {
|
|
1137
|
+
dst_ptr[i0] = T(x > 0) - T(x < 0);
|
|
1138
|
+
}
|
|
1155
1139
|
|
|
1156
|
-
|
|
1140
|
+
if (FC_OP == OP_UNARY_NUM_STEP) {
|
|
1141
|
+
dst_ptr[i0] = T(x > 0);
|
|
1142
|
+
}
|
|
1157
1143
|
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
}
|
|
1144
|
+
if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
|
|
1145
|
+
dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
|
|
1146
|
+
}
|
|
1162
1147
|
|
|
1163
|
-
|
|
1164
|
-
|
|
1148
|
+
if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
|
|
1149
|
+
dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
|
|
1150
|
+
}
|
|
1165
1151
|
|
|
1166
|
-
|
|
1152
|
+
if (FC_OP == OP_UNARY_NUM_EXP) {
|
|
1153
|
+
dst_ptr[i0] = (T) exp(x);
|
|
1154
|
+
}
|
|
1167
1155
|
|
|
1168
|
-
|
|
1156
|
+
if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
|
|
1157
|
+
dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
|
|
1158
|
+
}
|
|
1169
1159
|
|
|
1170
|
-
|
|
1171
|
-
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
device const char * src1,
|
|
1175
|
-
device char * dst,
|
|
1176
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1160
|
+
if (FC_OP == OP_UNARY_NUM_EXPM1) {
|
|
1161
|
+
// TODO: precise implementation
|
|
1162
|
+
dst_ptr[i0] = (T) (exp(x) - 1);
|
|
1163
|
+
}
|
|
1177
1164
|
|
|
1178
|
-
|
|
1179
|
-
|
|
1165
|
+
if (FC_OP == OP_UNARY_NUM_FLOOR) {
|
|
1166
|
+
dst_ptr[i0] = (T) floor(x);
|
|
1167
|
+
}
|
|
1180
1168
|
|
|
1181
|
-
|
|
1182
|
-
|
|
1169
|
+
if (FC_OP == OP_UNARY_NUM_CEIL) {
|
|
1170
|
+
dst_ptr[i0] = (T) ceil(x);
|
|
1171
|
+
}
|
|
1183
1172
|
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
|
-
}
|
|
1173
|
+
if (FC_OP == OP_UNARY_NUM_ROUND) {
|
|
1174
|
+
dst_ptr[i0] = (T) round(x);
|
|
1175
|
+
}
|
|
1188
1176
|
|
|
1189
|
-
|
|
1177
|
+
if (FC_OP == OP_UNARY_NUM_TRUNC) {
|
|
1178
|
+
dst_ptr[i0] = (T) trunc(x);
|
|
1179
|
+
}
|
|
1190
1180
|
|
|
1191
|
-
|
|
1192
|
-
|
|
1193
|
-
|
|
1181
|
+
if (FC_OP == OP_UNARY_NUM_XIELU) {
|
|
1182
|
+
const TC xi = x;
|
|
1183
|
+
const TC gate = TC(xi > TC(0.0f));
|
|
1184
|
+
const TC clamped = fmin(xi, TC(args.val));
|
|
1185
|
+
const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi;
|
|
1186
|
+
const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi;
|
|
1187
|
+
dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg);
|
|
1188
|
+
}
|
|
1194
1189
|
}
|
|
1195
1190
|
|
|
1196
|
-
|
|
1191
|
+
#undef FC_OP
|
|
1192
|
+
#undef FC_CNT
|
|
1197
1193
|
}
|
|
1198
1194
|
|
|
1199
|
-
typedef decltype(
|
|
1195
|
+
typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
|
|
1200
1196
|
|
|
1201
|
-
template [[host_name("
|
|
1197
|
+
template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>;
|
|
1198
|
+
template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;
|
|
1199
|
+
template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>;
|
|
1200
|
+
template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>;
|
|
1202
1201
|
|
|
1203
|
-
|
|
1204
|
-
|
|
1202
|
+
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
|
|
1203
|
+
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
|
|
1204
|
+
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
|
|
1205
|
+
constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]];
|
|
1206
|
+
constant bool FC_bin_cb [[function_constant(FC_BIN + 3)]];
|
|
1207
|
+
|
|
1208
|
+
template <typename T0, typename T1, typename T>
|
|
1209
|
+
kernel void kernel_bin_fuse_impl(
|
|
1205
1210
|
constant ggml_metal_kargs_bin & args,
|
|
1206
1211
|
device const char * src0,
|
|
1207
1212
|
device const char * src1,
|
|
1208
1213
|
device char * dst,
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
|
|
1214
|
-
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
device const float4 * src1_row[F];
|
|
1218
|
-
for (short j = 0; j < F; ++j) {
|
|
1219
|
-
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
|
|
1220
|
-
}
|
|
1221
|
-
|
|
1222
|
-
float4 res = src0_row[tpig];
|
|
1223
|
-
|
|
1224
|
-
#pragma unroll(F)
|
|
1225
|
-
for (short j = 0; j < F; ++j) {
|
|
1226
|
-
res /= src1_row[j][i];
|
|
1227
|
-
}
|
|
1228
|
-
|
|
1229
|
-
dst_row[tpig] = res;
|
|
1230
|
-
}
|
|
1231
|
-
|
|
1232
|
-
typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
|
|
1233
|
-
|
|
1234
|
-
template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
|
|
1235
|
-
|
|
1236
|
-
kernel void kernel_scale_f32(
|
|
1237
|
-
constant ggml_metal_kargs_scale & args,
|
|
1238
|
-
device const float * src0,
|
|
1239
|
-
device float * dst,
|
|
1240
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1241
|
-
dst[tpig] = src0[tpig] * args.scale + args.bias;
|
|
1242
|
-
}
|
|
1243
|
-
|
|
1244
|
-
kernel void kernel_scale_f32_4(
|
|
1245
|
-
constant ggml_metal_kargs_scale & args,
|
|
1246
|
-
device const float4 * src0,
|
|
1247
|
-
device float4 * dst,
|
|
1248
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1249
|
-
dst[tpig] = src0[tpig] * args.scale + args.bias;
|
|
1250
|
-
}
|
|
1251
|
-
|
|
1252
|
-
kernel void kernel_fill_f32(
|
|
1253
|
-
constant ggml_metal_kargs_fill & args,
|
|
1254
|
-
device const float * src0,
|
|
1255
|
-
device float * dst,
|
|
1256
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1257
|
-
dst[tpig] = args.val;
|
|
1258
|
-
}
|
|
1259
|
-
|
|
1260
|
-
kernel void kernel_fill_f32_4(
|
|
1261
|
-
constant ggml_metal_kargs_fill & args,
|
|
1262
|
-
device const float4 * src0,
|
|
1263
|
-
device float4 * dst,
|
|
1264
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1265
|
-
dst[tpig] = args.val;
|
|
1266
|
-
}
|
|
1267
|
-
|
|
1268
|
-
kernel void kernel_clamp_f32(
|
|
1269
|
-
constant ggml_metal_kargs_clamp & args,
|
|
1270
|
-
device const float * src0,
|
|
1271
|
-
device float * dst,
|
|
1272
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1273
|
-
dst[tpig] = clamp(src0[tpig], args.min, args.max);
|
|
1274
|
-
}
|
|
1275
|
-
|
|
1276
|
-
kernel void kernel_clamp_f32_4(
|
|
1277
|
-
constant ggml_metal_kargs_clamp & args,
|
|
1278
|
-
device const float4 * src0,
|
|
1279
|
-
device float4 * dst,
|
|
1280
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1281
|
-
dst[tpig] = clamp(src0[tpig], args.min, args.max);
|
|
1282
|
-
}
|
|
1283
|
-
|
|
1284
|
-
kernel void kernel_relu_f32(
|
|
1285
|
-
device const float * src0,
|
|
1286
|
-
device float * dst,
|
|
1287
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1288
|
-
dst[tpig] = max(0.0f, src0[tpig]);
|
|
1289
|
-
}
|
|
1290
|
-
|
|
1291
|
-
kernel void kernel_relu_f32_4(
|
|
1292
|
-
device const float4 * src0,
|
|
1293
|
-
device float4 * dst,
|
|
1294
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1295
|
-
dst[tpig] = max(0.0f, src0[tpig]);
|
|
1296
|
-
}
|
|
1297
|
-
|
|
1298
|
-
kernel void kernel_sigmoid_f32(
|
|
1299
|
-
device const float * src0,
|
|
1300
|
-
device float * dst,
|
|
1301
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1302
|
-
dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
|
|
1303
|
-
}
|
|
1304
|
-
|
|
1305
|
-
kernel void kernel_sigmoid_f32_4(
|
|
1306
|
-
device const float4 * src0,
|
|
1307
|
-
device float4 * dst,
|
|
1308
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1309
|
-
dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
|
|
1310
|
-
}
|
|
1311
|
-
|
|
1312
|
-
kernel void kernel_tanh_f32(
|
|
1313
|
-
device const float * src0,
|
|
1314
|
-
device float * dst,
|
|
1315
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1316
|
-
dst[tpig] = precise::tanh(src0[tpig]);
|
|
1317
|
-
}
|
|
1318
|
-
|
|
1319
|
-
kernel void kernel_tanh_f32_4(
|
|
1320
|
-
device const float4 * src0,
|
|
1321
|
-
device float4 * dst,
|
|
1322
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1323
|
-
dst[tpig] = precise::tanh(src0[tpig]);
|
|
1324
|
-
}
|
|
1325
|
-
|
|
1326
|
-
constant float GELU_COEF_A = 0.044715f;
|
|
1327
|
-
constant float GELU_QUICK_COEF = -1.702f;
|
|
1328
|
-
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
|
1329
|
-
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
|
|
1330
|
-
|
|
1331
|
-
kernel void kernel_gelu_f32(
|
|
1332
|
-
device const float * src0,
|
|
1333
|
-
device float * dst,
|
|
1334
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1335
|
-
device const float & x = src0[tpig];
|
|
1336
|
-
|
|
1337
|
-
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
|
1338
|
-
}
|
|
1339
|
-
|
|
1340
|
-
kernel void kernel_gelu_f32_4(
|
|
1341
|
-
device const float4 * src0,
|
|
1342
|
-
device float4 * dst,
|
|
1343
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1344
|
-
device const float4 & x = src0[tpig];
|
|
1345
|
-
|
|
1346
|
-
// BEWARE !!!
|
|
1347
|
-
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
|
1348
|
-
// This was observed with Falcon 7B and 40B models
|
|
1349
|
-
//
|
|
1350
|
-
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
|
1351
|
-
}
|
|
1214
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1215
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
1216
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1217
|
+
#define FC_OP FC_bin_op
|
|
1218
|
+
#define FC_F FC_bin_f
|
|
1219
|
+
#define FC_RB FC_bin_rb
|
|
1220
|
+
#define FC_CB FC_bin_cb
|
|
1352
1221
|
|
|
1353
|
-
|
|
1354
|
-
|
|
1355
|
-
|
|
1356
|
-
|
|
1357
|
-
device const float & x = src0[tpig];
|
|
1222
|
+
if (FC_RB) {
|
|
1223
|
+
// row broadcast
|
|
1224
|
+
const uint i0 = tgpig.y*args.ne00 + tgpig.x;
|
|
1225
|
+
const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x;
|
|
1358
1226
|
|
|
1359
|
-
|
|
1360
|
-
|
|
1227
|
+
device const T0 * src0_row = (device const T0 *) (src0);
|
|
1228
|
+
device T * dst_row = (device T *) (dst);
|
|
1361
1229
|
|
|
1362
|
-
|
|
1363
|
-
|
|
1364
|
-
device float4 * dst,
|
|
1365
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1366
|
-
device const float4 & x = src0[tpig];
|
|
1230
|
+
if (FC_F == 1) {
|
|
1231
|
+
device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);
|
|
1367
1232
|
|
|
1368
|
-
|
|
1369
|
-
|
|
1233
|
+
if (FC_OP == 0) {
|
|
1234
|
+
dst_row[i0] = src0_row[i0] + src1_row[i1];
|
|
1235
|
+
}
|
|
1370
1236
|
|
|
1371
|
-
|
|
1372
|
-
|
|
1373
|
-
|
|
1374
|
-
constant float a1_erf = 0.254829592f;
|
|
1375
|
-
constant float a2_erf = -0.284496736f;
|
|
1376
|
-
constant float a3_erf = 1.421413741f;
|
|
1377
|
-
constant float a4_erf = -1.453152027f;
|
|
1378
|
-
constant float a5_erf = 1.061405429f;
|
|
1237
|
+
if (FC_OP == 1) {
|
|
1238
|
+
dst_row[i0] = src0_row[i0] - src1_row[i1];
|
|
1239
|
+
}
|
|
1379
1240
|
|
|
1380
|
-
|
|
1381
|
-
|
|
1382
|
-
|
|
1383
|
-
x = fabs(x);
|
|
1384
|
-
T t = 1.0f / (1.0f + p_erf * x);
|
|
1385
|
-
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
|
|
1386
|
-
return sign_x * y;
|
|
1387
|
-
}
|
|
1241
|
+
if (FC_OP == 2) {
|
|
1242
|
+
dst_row[i0] = src0_row[i0] * src1_row[i1];
|
|
1243
|
+
}
|
|
1388
1244
|
|
|
1389
|
-
|
|
1390
|
-
|
|
1391
|
-
|
|
1392
|
-
|
|
1393
|
-
|
|
1245
|
+
if (FC_OP == 3) {
|
|
1246
|
+
dst_row[i0] = src0_row[i0] / src1_row[i1];
|
|
1247
|
+
}
|
|
1248
|
+
} else {
|
|
1249
|
+
T0 res = src0_row[i0];
|
|
1394
1250
|
|
|
1395
|
-
|
|
1396
|
-
|
|
1251
|
+
if (FC_OP == 0) {
|
|
1252
|
+
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
|
1253
|
+
res += ((device const T1 *) (src1 + args.o1[j]))[i1];
|
|
1254
|
+
}
|
|
1255
|
+
}
|
|
1397
1256
|
|
|
1398
|
-
|
|
1399
|
-
|
|
1400
|
-
|
|
1401
|
-
|
|
1402
|
-
|
|
1257
|
+
if (FC_OP == 1) {
|
|
1258
|
+
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
|
1259
|
+
res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
|
|
1260
|
+
}
|
|
1261
|
+
}
|
|
1403
1262
|
|
|
1404
|
-
|
|
1405
|
-
|
|
1263
|
+
if (FC_OP == 2) {
|
|
1264
|
+
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
|
1265
|
+
res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
|
|
1266
|
+
}
|
|
1267
|
+
}
|
|
1406
1268
|
|
|
1407
|
-
|
|
1408
|
-
|
|
1409
|
-
|
|
1410
|
-
|
|
1411
|
-
|
|
1412
|
-
dst[tpig] = x / (1.0f + exp(-x));
|
|
1413
|
-
}
|
|
1269
|
+
if (FC_OP == 3) {
|
|
1270
|
+
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
|
1271
|
+
res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
|
|
1272
|
+
}
|
|
1273
|
+
}
|
|
1414
1274
|
|
|
1415
|
-
|
|
1416
|
-
|
|
1417
|
-
|
|
1418
|
-
|
|
1419
|
-
|
|
1420
|
-
|
|
1421
|
-
}
|
|
1275
|
+
dst_row[i0] = res;
|
|
1276
|
+
}
|
|
1277
|
+
} else {
|
|
1278
|
+
const int i03 = tgpig.z;
|
|
1279
|
+
const int i02 = tgpig.y;
|
|
1280
|
+
const int i01 = tgpig.x;
|
|
1422
1281
|
|
|
1423
|
-
|
|
1424
|
-
|
|
1425
|
-
|
|
1426
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1427
|
-
const float x = src0[tpig];
|
|
1428
|
-
dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
|
|
1429
|
-
}
|
|
1282
|
+
if (i01 >= args.ne01) {
|
|
1283
|
+
return;
|
|
1284
|
+
}
|
|
1430
1285
|
|
|
1431
|
-
|
|
1432
|
-
|
|
1433
|
-
|
|
1434
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1435
|
-
const float4 x = src0[tpig];
|
|
1436
|
-
dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
|
|
1437
|
-
dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
|
|
1438
|
-
dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
|
|
1439
|
-
dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
|
|
1440
|
-
}
|
|
1286
|
+
const int i13 = i03%args.ne13;
|
|
1287
|
+
const int i12 = i02%args.ne12;
|
|
1288
|
+
const int i11 = i01%args.ne11;
|
|
1441
1289
|
|
|
1442
|
-
|
|
1443
|
-
device
|
|
1444
|
-
device float * dst,
|
|
1445
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1446
|
-
dst[tpig] = src0[tpig] * src0[tpig];
|
|
1447
|
-
}
|
|
1290
|
+
device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
|
|
1291
|
+
device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
|
|
1448
1292
|
|
|
1449
|
-
|
|
1450
|
-
|
|
1451
|
-
device float4 * dst,
|
|
1452
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1453
|
-
dst[tpig] = src0[tpig] * src0[tpig];
|
|
1454
|
-
}
|
|
1293
|
+
if (FC_F == 1) {
|
|
1294
|
+
device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
|
|
1455
1295
|
|
|
1456
|
-
|
|
1457
|
-
|
|
1458
|
-
device float * dst,
|
|
1459
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1460
|
-
dst[tpig] = sqrt(src0[tpig]);
|
|
1461
|
-
}
|
|
1296
|
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
1297
|
+
const int i10 = FC_CB ? i0%args.ne10 : i0;
|
|
1462
1298
|
|
|
1463
|
-
|
|
1464
|
-
|
|
1465
|
-
|
|
1466
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1467
|
-
dst[tpig] = sqrt(src0[tpig]);
|
|
1468
|
-
}
|
|
1299
|
+
if (FC_OP == 0) {
|
|
1300
|
+
dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
|
|
1301
|
+
}
|
|
1469
1302
|
|
|
1470
|
-
|
|
1471
|
-
|
|
1472
|
-
|
|
1473
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1474
|
-
dst[tpig] = sin(src0[tpig]);
|
|
1475
|
-
}
|
|
1303
|
+
if (FC_OP == 1) {
|
|
1304
|
+
dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
|
|
1305
|
+
}
|
|
1476
1306
|
|
|
1477
|
-
|
|
1478
|
-
|
|
1479
|
-
|
|
1480
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1481
|
-
dst[tpig] = sin(src0[tpig]);
|
|
1482
|
-
}
|
|
1307
|
+
if (FC_OP == 2) {
|
|
1308
|
+
dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
|
|
1309
|
+
}
|
|
1483
1310
|
|
|
1484
|
-
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
|
-
|
|
1488
|
-
|
|
1489
|
-
|
|
1311
|
+
if (FC_OP == 3) {
|
|
1312
|
+
dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
|
|
1313
|
+
}
|
|
1314
|
+
}
|
|
1315
|
+
} else {
|
|
1316
|
+
device const T1 * src1_ptr[8];
|
|
1317
|
+
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
|
1318
|
+
src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
|
|
1319
|
+
}
|
|
1490
1320
|
|
|
1491
|
-
|
|
1492
|
-
|
|
1493
|
-
device float4 * dst,
|
|
1494
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1495
|
-
dst[tpig] = cos(src0[tpig]);
|
|
1496
|
-
}
|
|
1321
|
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
1322
|
+
const int i10 = FC_CB ? i0%args.ne10 : i0;
|
|
1497
1323
|
|
|
1498
|
-
|
|
1499
|
-
device const float * src0,
|
|
1500
|
-
device float * dst,
|
|
1501
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1502
|
-
dst[tpig] = log(src0[tpig]);
|
|
1503
|
-
}
|
|
1324
|
+
T res = src0_ptr[i0];
|
|
1504
1325
|
|
|
1505
|
-
|
|
1506
|
-
|
|
1507
|
-
|
|
1508
|
-
|
|
1509
|
-
|
|
1510
|
-
}
|
|
1326
|
+
if (FC_OP == 0) {
|
|
1327
|
+
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
|
1328
|
+
res += src1_ptr[j][i10];
|
|
1329
|
+
}
|
|
1330
|
+
}
|
|
1511
1331
|
|
|
1512
|
-
|
|
1513
|
-
|
|
1514
|
-
|
|
1515
|
-
|
|
1516
|
-
|
|
1517
|
-
}
|
|
1332
|
+
if (FC_OP == 1) {
|
|
1333
|
+
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
|
1334
|
+
res -= src1_ptr[j][i10];
|
|
1335
|
+
}
|
|
1336
|
+
}
|
|
1518
1337
|
|
|
1519
|
-
|
|
1520
|
-
|
|
1521
|
-
|
|
1522
|
-
|
|
1523
|
-
|
|
1524
|
-
}
|
|
1338
|
+
if (FC_OP == 2) {
|
|
1339
|
+
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
|
1340
|
+
res *= src1_ptr[j][i10];
|
|
1341
|
+
}
|
|
1342
|
+
}
|
|
1525
1343
|
|
|
1526
|
-
|
|
1527
|
-
|
|
1528
|
-
|
|
1529
|
-
|
|
1530
|
-
|
|
1531
|
-
}
|
|
1344
|
+
if (FC_OP == 3) {
|
|
1345
|
+
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
|
1346
|
+
res /= src1_ptr[j][i10];
|
|
1347
|
+
}
|
|
1348
|
+
}
|
|
1532
1349
|
|
|
1533
|
-
|
|
1534
|
-
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
dst[tpig] = fabs(src0[tpig]);
|
|
1538
|
-
}
|
|
1350
|
+
dst_ptr[i0] = res;
|
|
1351
|
+
}
|
|
1352
|
+
}
|
|
1353
|
+
}
|
|
1539
1354
|
|
|
1540
|
-
|
|
1541
|
-
|
|
1542
|
-
|
|
1543
|
-
|
|
1544
|
-
dst[tpig] = sign(src0[tpig]);
|
|
1355
|
+
#undef FC_OP
|
|
1356
|
+
#undef FC_F
|
|
1357
|
+
#undef FC_RB
|
|
1358
|
+
#undef FC_CB
|
|
1545
1359
|
}
|
|
1546
1360
|
|
|
1547
|
-
|
|
1548
|
-
device const float4 * src0,
|
|
1549
|
-
device float4 * dst,
|
|
1550
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1551
|
-
dst[tpig] = sign(src0[tpig]);
|
|
1552
|
-
}
|
|
1361
|
+
typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
|
|
1553
1362
|
|
|
1554
|
-
kernel
|
|
1555
|
-
|
|
1556
|
-
device float * dst,
|
|
1557
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1558
|
-
dst[tpig] = step(0.0f, src0[tpig]);
|
|
1559
|
-
}
|
|
1363
|
+
template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float, float, float>;
|
|
1364
|
+
template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>;
|
|
1560
1365
|
|
|
1561
|
-
kernel void
|
|
1562
|
-
|
|
1563
|
-
device
|
|
1564
|
-
|
|
1565
|
-
|
|
1566
|
-
|
|
1366
|
+
kernel void kernel_add_id(
|
|
1367
|
+
constant ggml_metal_kargs_add_id & args,
|
|
1368
|
+
device const char * src0,
|
|
1369
|
+
device const char * src1,
|
|
1370
|
+
device const char * src2,
|
|
1371
|
+
device char * dst,
|
|
1372
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1373
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
1374
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1375
|
+
const int i1 = tgpig.x;
|
|
1376
|
+
const int i2 = tgpig.y;
|
|
1567
1377
|
|
|
1568
|
-
|
|
1569
|
-
device const float * src0,
|
|
1570
|
-
device float * dst,
|
|
1571
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1572
|
-
const float x = src0[tpig];
|
|
1573
|
-
dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
|
1574
|
-
}
|
|
1378
|
+
const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
|
|
1575
1379
|
|
|
1576
|
-
|
|
1577
|
-
|
|
1578
|
-
device float4 * dst,
|
|
1579
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1580
|
-
const float4 x = src0[tpig];
|
|
1581
|
-
dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
|
1582
|
-
}
|
|
1380
|
+
const size_t nb1 = args.ne0 * sizeof(float);
|
|
1381
|
+
const size_t nb2 = args.ne1 * nb1;
|
|
1583
1382
|
|
|
1584
|
-
|
|
1585
|
-
|
|
1586
|
-
|
|
1587
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1588
|
-
const float x = src0[tpig];
|
|
1589
|
-
dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
|
1590
|
-
}
|
|
1383
|
+
device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
|
|
1384
|
+
device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
|
|
1385
|
+
device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
|
|
1591
1386
|
|
|
1592
|
-
|
|
1593
|
-
|
|
1594
|
-
|
|
1595
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1596
|
-
const float4 x = src0[tpig];
|
|
1597
|
-
dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
|
1387
|
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
1388
|
+
dst_row[i0] = src0_row[i0] + src1_row[i0];
|
|
1389
|
+
}
|
|
1598
1390
|
}
|
|
1599
1391
|
|
|
1600
|
-
|
|
1601
|
-
|
|
1602
|
-
|
|
1603
|
-
|
|
1604
|
-
|
|
1605
|
-
|
|
1392
|
+
template<typename T>
|
|
1393
|
+
kernel void kernel_repeat(
|
|
1394
|
+
constant ggml_metal_kargs_repeat & args,
|
|
1395
|
+
device const char * src0,
|
|
1396
|
+
device char * dst,
|
|
1397
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1398
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
1399
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1400
|
+
const int i3 = tgpig.z;
|
|
1401
|
+
const int i2 = tgpig.y;
|
|
1402
|
+
const int i1 = tgpig.x;
|
|
1606
1403
|
|
|
1607
|
-
|
|
1608
|
-
|
|
1609
|
-
|
|
1610
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1611
|
-
dst[tpig] = exp(src0[tpig]);
|
|
1612
|
-
}
|
|
1404
|
+
const int i03 = i3%args.ne03;
|
|
1405
|
+
const int i02 = i2%args.ne02;
|
|
1406
|
+
const int i01 = i1%args.ne01;
|
|
1613
1407
|
|
|
1614
|
-
|
|
1615
|
-
|
|
1616
|
-
device float * dst,
|
|
1617
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1618
|
-
device const float & x = src0[tpig];
|
|
1619
|
-
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
|
|
1620
|
-
}
|
|
1408
|
+
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
|
|
1409
|
+
device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1;
|
|
1621
1410
|
|
|
1622
|
-
|
|
1623
|
-
|
|
1624
|
-
device
|
|
1625
|
-
|
|
1626
|
-
device const float4 & x = src0[tpig];
|
|
1627
|
-
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
|
|
1411
|
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
1412
|
+
const int i00 = i0%args.ne00;
|
|
1413
|
+
*((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
|
|
1414
|
+
}
|
|
1628
1415
|
}
|
|
1629
1416
|
|
|
1630
|
-
|
|
1631
|
-
device const float * src0,
|
|
1632
|
-
device float * dst,
|
|
1633
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1634
|
-
dst[tpig] = exp(src0[tpig]) - 1.0f;
|
|
1635
|
-
}
|
|
1417
|
+
typedef decltype(kernel_repeat<float>) kernel_repeat_t;
|
|
1636
1418
|
|
|
1637
|
-
kernel
|
|
1638
|
-
|
|
1639
|
-
|
|
1640
|
-
|
|
1641
|
-
dst[tpig] = exp(src0[tpig]) - 1.0f;
|
|
1642
|
-
}
|
|
1419
|
+
template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
|
|
1420
|
+
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
|
|
1421
|
+
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
|
|
1422
|
+
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
|
|
1643
1423
|
|
|
1644
|
-
|
|
1424
|
+
template<typename T>
|
|
1425
|
+
kernel void kernel_reglu(
|
|
1645
1426
|
constant ggml_metal_kargs_glu & args,
|
|
1646
1427
|
device const char * src0,
|
|
1647
1428
|
device const char * src1,
|
|
@@ -1649,19 +1430,25 @@ kernel void kernel_reglu_f32(
|
|
|
1649
1430
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
1650
1431
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
1651
1432
|
uint ntg[[threads_per_threadgroup]]) {
|
|
1652
|
-
device const
|
|
1653
|
-
device const
|
|
1654
|
-
device
|
|
1433
|
+
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1434
|
+
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1435
|
+
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
|
1655
1436
|
|
|
1656
1437
|
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1657
1438
|
const float x0 = src0_row[i0];
|
|
1658
1439
|
const float x1 = src1_row[i0];
|
|
1659
1440
|
|
|
1660
|
-
dst_row[i0] = x0*x1*(x0 > 0.0f);
|
|
1441
|
+
dst_row[i0] = (T)(x0*x1*(x0 > 0.0f));
|
|
1661
1442
|
}
|
|
1662
1443
|
}
|
|
1663
1444
|
|
|
1664
|
-
|
|
1445
|
+
typedef decltype(kernel_reglu<float>) kernel_reglu_t;
|
|
1446
|
+
|
|
1447
|
+
template [[host_name("kernel_reglu_f32")]] kernel kernel_reglu_t kernel_reglu<float>;
|
|
1448
|
+
template [[host_name("kernel_reglu_f16")]] kernel kernel_reglu_t kernel_reglu<half>;
|
|
1449
|
+
|
|
1450
|
+
template<typename T>
|
|
1451
|
+
kernel void kernel_geglu(
|
|
1665
1452
|
constant ggml_metal_kargs_glu & args,
|
|
1666
1453
|
device const char * src0,
|
|
1667
1454
|
device const char * src1,
|
|
@@ -1669,9 +1456,9 @@ kernel void kernel_geglu_f32(
|
|
|
1669
1456
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
1670
1457
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
1671
1458
|
uint ntg[[threads_per_threadgroup]]) {
|
|
1672
|
-
device const
|
|
1673
|
-
device const
|
|
1674
|
-
device
|
|
1459
|
+
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1460
|
+
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1461
|
+
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
|
1675
1462
|
|
|
1676
1463
|
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1677
1464
|
const float x0 = src0_row[i0];
|
|
@@ -1679,11 +1466,17 @@ kernel void kernel_geglu_f32(
|
|
|
1679
1466
|
|
|
1680
1467
|
const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
|
|
1681
1468
|
|
|
1682
|
-
dst_row[i0] = gelu*x1;
|
|
1469
|
+
dst_row[i0] = (T)(gelu*x1);
|
|
1683
1470
|
}
|
|
1684
1471
|
}
|
|
1685
1472
|
|
|
1686
|
-
|
|
1473
|
+
typedef decltype(kernel_geglu<float>) kernel_geglu_t;
|
|
1474
|
+
|
|
1475
|
+
template [[host_name("kernel_geglu_f32")]] kernel kernel_geglu_t kernel_geglu<float>;
|
|
1476
|
+
template [[host_name("kernel_geglu_f16")]] kernel kernel_geglu_t kernel_geglu<half>;
|
|
1477
|
+
|
|
1478
|
+
template<typename T>
|
|
1479
|
+
kernel void kernel_swiglu(
|
|
1687
1480
|
constant ggml_metal_kargs_glu & args,
|
|
1688
1481
|
device const char * src0,
|
|
1689
1482
|
device const char * src1,
|
|
@@ -1691,9 +1484,9 @@ kernel void kernel_swiglu_f32(
|
|
|
1691
1484
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
1692
1485
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
1693
1486
|
uint ntg[[threads_per_threadgroup]]) {
|
|
1694
|
-
device const
|
|
1695
|
-
device const
|
|
1696
|
-
device
|
|
1487
|
+
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1488
|
+
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1489
|
+
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
|
1697
1490
|
|
|
1698
1491
|
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1699
1492
|
const float x0 = src0_row[i0];
|
|
@@ -1701,11 +1494,17 @@ kernel void kernel_swiglu_f32(
|
|
|
1701
1494
|
|
|
1702
1495
|
const float silu = x0 / (1.0f + exp(-x0));
|
|
1703
1496
|
|
|
1704
|
-
dst_row[i0] = silu*x1;
|
|
1497
|
+
dst_row[i0] = (T)(silu*x1);
|
|
1705
1498
|
}
|
|
1706
1499
|
}
|
|
1707
1500
|
|
|
1708
|
-
|
|
1501
|
+
typedef decltype(kernel_swiglu<float>) kernel_swiglu_t;
|
|
1502
|
+
|
|
1503
|
+
template [[host_name("kernel_swiglu_f32")]] kernel kernel_swiglu_t kernel_swiglu<float>;
|
|
1504
|
+
template [[host_name("kernel_swiglu_f16")]] kernel kernel_swiglu_t kernel_swiglu<half>;
|
|
1505
|
+
|
|
1506
|
+
template<typename T>
|
|
1507
|
+
kernel void kernel_swiglu_oai(
|
|
1709
1508
|
constant ggml_metal_kargs_glu & args,
|
|
1710
1509
|
device const char * src0,
|
|
1711
1510
|
device const char * src1,
|
|
@@ -1713,9 +1512,9 @@ kernel void kernel_swiglu_oai_f32(
|
|
|
1713
1512
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
1714
1513
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
1715
1514
|
uint ntg[[threads_per_threadgroup]]) {
|
|
1716
|
-
device const
|
|
1717
|
-
device const
|
|
1718
|
-
device
|
|
1515
|
+
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1516
|
+
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1517
|
+
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
|
1719
1518
|
|
|
1720
1519
|
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1721
1520
|
float x0 = src0_row[i0];
|
|
@@ -1727,11 +1526,17 @@ kernel void kernel_swiglu_oai_f32(
|
|
|
1727
1526
|
float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
|
|
1728
1527
|
out_glu = out_glu * (1.0f + x1);
|
|
1729
1528
|
|
|
1730
|
-
dst_row[i0] = out_glu;
|
|
1529
|
+
dst_row[i0] = (T)out_glu;
|
|
1731
1530
|
}
|
|
1732
1531
|
}
|
|
1733
1532
|
|
|
1734
|
-
|
|
1533
|
+
typedef decltype(kernel_swiglu_oai<float>) kernel_swiglu_oai_t;
|
|
1534
|
+
|
|
1535
|
+
template [[host_name("kernel_swiglu_oai_f32")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<float>;
|
|
1536
|
+
template [[host_name("kernel_swiglu_oai_f16")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<half>;
|
|
1537
|
+
|
|
1538
|
+
template<typename T>
|
|
1539
|
+
kernel void kernel_geglu_erf(
|
|
1735
1540
|
constant ggml_metal_kargs_glu & args,
|
|
1736
1541
|
device const char * src0,
|
|
1737
1542
|
device const char * src1,
|
|
@@ -1739,9 +1544,9 @@ kernel void kernel_geglu_erf_f32(
|
|
|
1739
1544
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
1740
1545
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
1741
1546
|
uint ntg[[threads_per_threadgroup]]) {
|
|
1742
|
-
device const
|
|
1743
|
-
device const
|
|
1744
|
-
device
|
|
1547
|
+
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1548
|
+
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1549
|
+
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
|
1745
1550
|
|
|
1746
1551
|
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1747
1552
|
const float x0 = src0_row[i0];
|
|
@@ -1749,11 +1554,17 @@ kernel void kernel_geglu_erf_f32(
|
|
|
1749
1554
|
|
|
1750
1555
|
const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
|
|
1751
1556
|
|
|
1752
|
-
dst_row[i0] = gelu_erf*x1;
|
|
1557
|
+
dst_row[i0] = (T)(gelu_erf*x1);
|
|
1753
1558
|
}
|
|
1754
1559
|
}
|
|
1755
1560
|
|
|
1756
|
-
|
|
1561
|
+
typedef decltype(kernel_geglu_erf<float>) kernel_geglu_erf_t;
|
|
1562
|
+
|
|
1563
|
+
template [[host_name("kernel_geglu_erf_f32")]] kernel kernel_geglu_erf_t kernel_geglu_erf<float>;
|
|
1564
|
+
template [[host_name("kernel_geglu_erf_f16")]] kernel kernel_geglu_erf_t kernel_geglu_erf<half>;
|
|
1565
|
+
|
|
1566
|
+
template<typename T>
|
|
1567
|
+
kernel void kernel_geglu_quick(
|
|
1757
1568
|
constant ggml_metal_kargs_glu & args,
|
|
1758
1569
|
device const char * src0,
|
|
1759
1570
|
device const char * src1,
|
|
@@ -1761,9 +1572,9 @@ kernel void kernel_geglu_quick_f32(
|
|
|
1761
1572
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
1762
1573
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
1763
1574
|
uint ntg[[threads_per_threadgroup]]) {
|
|
1764
|
-
device const
|
|
1765
|
-
device const
|
|
1766
|
-
device
|
|
1575
|
+
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1576
|
+
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1577
|
+
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
|
1767
1578
|
|
|
1768
1579
|
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1769
1580
|
const float x0 = src0_row[i0];
|
|
@@ -1771,10 +1582,15 @@ kernel void kernel_geglu_quick_f32(
|
|
|
1771
1582
|
|
|
1772
1583
|
const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
|
|
1773
1584
|
|
|
1774
|
-
dst_row[i0] = gelu_quick*x1;
|
|
1585
|
+
dst_row[i0] = (T)(gelu_quick*x1);
|
|
1775
1586
|
}
|
|
1776
1587
|
}
|
|
1777
1588
|
|
|
1589
|
+
typedef decltype(kernel_geglu_quick<float>) kernel_geglu_quick_t;
|
|
1590
|
+
|
|
1591
|
+
template [[host_name("kernel_geglu_quick_f32")]] kernel kernel_geglu_quick_t kernel_geglu_quick<float>;
|
|
1592
|
+
template [[host_name("kernel_geglu_quick_f16")]] kernel kernel_geglu_quick_t kernel_geglu_quick<half>;
|
|
1593
|
+
|
|
1778
1594
|
kernel void kernel_op_sum_f32(
|
|
1779
1595
|
constant ggml_metal_kargs_sum & args,
|
|
1780
1596
|
device const float * src0,
|
|
@@ -1824,33 +1640,35 @@ kernel void kernel_op_sum_f32(
|
|
|
1824
1640
|
}
|
|
1825
1641
|
}
|
|
1826
1642
|
|
|
1827
|
-
|
|
1828
|
-
|
|
1643
|
+
constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
|
|
1644
|
+
|
|
1645
|
+
template <typename T0, typename T>
|
|
1646
|
+
kernel void kernel_sum_rows_impl(
|
|
1829
1647
|
constant ggml_metal_kargs_sum_rows & args,
|
|
1830
|
-
device const
|
|
1831
|
-
device
|
|
1832
|
-
threadgroup
|
|
1648
|
+
device const char * src0,
|
|
1649
|
+
device char * dst,
|
|
1650
|
+
threadgroup char * shmem [[threadgroup(0)]],
|
|
1833
1651
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1834
1652
|
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
1835
1653
|
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
|
1836
1654
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
1837
1655
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1838
|
-
|
|
1839
|
-
int64_t i2 = tgpig.y;
|
|
1840
|
-
int64_t i1 = tgpig.x;
|
|
1656
|
+
#define FC_OP FC_sum_rows_op
|
|
1841
1657
|
|
|
1842
|
-
|
|
1843
|
-
|
|
1844
|
-
|
|
1658
|
+
const int i3 = tgpig.z;
|
|
1659
|
+
const int i2 = tgpig.y;
|
|
1660
|
+
const int i1 = tgpig.x;
|
|
1661
|
+
|
|
1662
|
+
threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
|
|
1845
1663
|
|
|
1846
1664
|
if (sgitg == 0) {
|
|
1847
|
-
|
|
1665
|
+
shmem_t[tiisg] = 0.0f;
|
|
1848
1666
|
}
|
|
1849
1667
|
|
|
1850
|
-
device const
|
|
1851
|
-
device
|
|
1668
|
+
device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
|
1669
|
+
device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
|
1852
1670
|
|
|
1853
|
-
|
|
1671
|
+
T0 sumf = T0(0.0f);
|
|
1854
1672
|
|
|
1855
1673
|
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
|
1856
1674
|
sumf += src_row[i0];
|
|
@@ -1861,23 +1679,33 @@ kernel void kernel_sum_rows(
|
|
|
1861
1679
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1862
1680
|
|
|
1863
1681
|
if (tiisg == 0) {
|
|
1864
|
-
|
|
1682
|
+
shmem_t[sgitg] = sumf;
|
|
1865
1683
|
}
|
|
1866
1684
|
|
|
1867
1685
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1868
1686
|
|
|
1869
|
-
sumf =
|
|
1687
|
+
sumf = shmem_t[tiisg];
|
|
1870
1688
|
sumf = simd_sum(sumf);
|
|
1871
1689
|
|
|
1872
1690
|
if (tpitg.x == 0) {
|
|
1873
|
-
|
|
1691
|
+
if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
|
|
1692
|
+
if (is_same<float4, T0>::value) {
|
|
1693
|
+
dst_row[0] = sum(sumf) / (4*args.ne00);
|
|
1694
|
+
} else {
|
|
1695
|
+
dst_row[0] = sum(sumf) / args.ne00;
|
|
1696
|
+
}
|
|
1697
|
+
} else {
|
|
1698
|
+
dst_row[0] = sum(sumf);
|
|
1699
|
+
}
|
|
1874
1700
|
}
|
|
1701
|
+
|
|
1702
|
+
#undef FC_OP
|
|
1875
1703
|
}
|
|
1876
1704
|
|
|
1877
|
-
typedef decltype(
|
|
1705
|
+
typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
|
|
1878
1706
|
|
|
1879
|
-
template [[host_name("
|
|
1880
|
-
template [[host_name("
|
|
1707
|
+
template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>;
|
|
1708
|
+
template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
|
|
1881
1709
|
|
|
1882
1710
|
template<typename T>
|
|
1883
1711
|
kernel void kernel_cumsum_blk(
|
|
@@ -2737,6 +2565,329 @@ kernel void kernel_rwkv_wkv7_f32(
|
|
|
2737
2565
|
}
|
|
2738
2566
|
}
|
|
2739
2567
|
|
|
2568
|
+
constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
|
|
2569
|
+
constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
|
|
2570
|
+
constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]];
|
|
2571
|
+
|
|
2572
|
+
#if 1
|
|
2573
|
+
template<short NSG>
|
|
2574
|
+
kernel void kernel_gated_delta_net_impl(
|
|
2575
|
+
constant ggml_metal_kargs_gated_delta_net & args,
|
|
2576
|
+
device const char * q,
|
|
2577
|
+
device const char * k,
|
|
2578
|
+
device const char * v,
|
|
2579
|
+
device const char * g,
|
|
2580
|
+
device const char * b,
|
|
2581
|
+
device const char * s,
|
|
2582
|
+
device char * dst,
|
|
2583
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2584
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
2585
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
2586
|
+
#define S_v FC_gated_delta_net_ne20
|
|
2587
|
+
#define G FC_gated_delta_net_ne30
|
|
2588
|
+
#define K FC_gated_delta_net_K
|
|
2589
|
+
|
|
2590
|
+
const uint tx = tpitg.x;
|
|
2591
|
+
const uint ty = tpitg.y;
|
|
2592
|
+
|
|
2593
|
+
const uint i23 = tgpig.z; // B (n_seqs)
|
|
2594
|
+
const uint i21 = tgpig.y; // H (head)
|
|
2595
|
+
const uint i20 = tgpig.x*NSG + ty; // row within S_v
|
|
2596
|
+
|
|
2597
|
+
const uint i01 = i21 % args.ne01;
|
|
2598
|
+
const uint i11 = i21 % args.ne11;
|
|
2599
|
+
|
|
2600
|
+
const float scale = 1.0f / sqrt((float)S_v);
|
|
2601
|
+
|
|
2602
|
+
// input state layout [S_v, S_v, H, n_seqs] (s0 only): per-seq stride is H*D.
|
|
2603
|
+
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
|
|
2604
|
+
const uint state_in_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
|
2605
|
+
device const float * s_ptr = (device const float *) (s) + state_in_base;
|
|
2606
|
+
|
|
2607
|
+
float ls[NSG];
|
|
2608
|
+
|
|
2609
|
+
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
|
2610
|
+
const short is = tx*NSG + j;
|
|
2611
|
+
ls[j] = s_ptr[is];
|
|
2612
|
+
}
|
|
2613
|
+
|
|
2614
|
+
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
|
|
2615
|
+
|
|
2616
|
+
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
|
|
2617
|
+
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
|
|
2618
|
+
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
|
|
2619
|
+
|
|
2620
|
+
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
|
|
2621
|
+
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
|
|
2622
|
+
|
|
2623
|
+
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
|
|
2624
|
+
// When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned.
|
|
2625
|
+
|
|
2626
|
+
// output state base offset: after attention scores
|
|
2627
|
+
const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23;
|
|
2628
|
+
// output state per-slot size: S_v * S_v * H * n_seqs
|
|
2629
|
+
const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23;
|
|
2630
|
+
// per-(seq,head) offset within a slot
|
|
2631
|
+
const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
|
2632
|
+
|
|
2633
|
+
for (short t = 0; t < args.ne22; t++) {
|
|
2634
|
+
float s_k = 0.0f;
|
|
2635
|
+
|
|
2636
|
+
if (G == 1) {
|
|
2637
|
+
const float g_exp = exp(g_ptr[0]);
|
|
2638
|
+
|
|
2639
|
+
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
|
2640
|
+
const short is = tx*NSG + j;
|
|
2641
|
+
ls[j] *= g_exp;
|
|
2642
|
+
|
|
2643
|
+
s_k += ls[j]*k_ptr[is];
|
|
2644
|
+
}
|
|
2645
|
+
} else {
|
|
2646
|
+
// KDA
|
|
2647
|
+
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
|
2648
|
+
const short is = tx*NSG + j;
|
|
2649
|
+
ls[j] *= exp(g_ptr[is]);
|
|
2650
|
+
|
|
2651
|
+
s_k += ls[j]*k_ptr[is];
|
|
2652
|
+
}
|
|
2653
|
+
}
|
|
2654
|
+
|
|
2655
|
+
s_k = simd_sum(s_k);
|
|
2656
|
+
|
|
2657
|
+
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
|
|
2658
|
+
|
|
2659
|
+
float y = 0.0f;
|
|
2660
|
+
|
|
2661
|
+
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
|
2662
|
+
const short is = tx*NSG + j;
|
|
2663
|
+
ls[j] += k_ptr[is]*d;
|
|
2664
|
+
|
|
2665
|
+
y += ls[j]*q_ptr[is];
|
|
2666
|
+
}
|
|
2667
|
+
|
|
2668
|
+
y = simd_sum(y);
|
|
2669
|
+
|
|
2670
|
+
if (tx == 0) {
|
|
2671
|
+
dst_attn[t*args.ne21*S_v] = y*scale;
|
|
2672
|
+
}
|
|
2673
|
+
|
|
2674
|
+
q_ptr += args.ns02;
|
|
2675
|
+
k_ptr += args.ns12;
|
|
2676
|
+
v_ptr += args.ns22;
|
|
2677
|
+
|
|
2678
|
+
b_ptr += args.ne21;
|
|
2679
|
+
g_ptr += args.ne21*G;
|
|
2680
|
+
|
|
2681
|
+
if (K > 1) {
|
|
2682
|
+
const int target_slot = (int)args.ne22 - 1 - (int)t;
|
|
2683
|
+
if (target_slot >= 0 && target_slot < (int)K) {
|
|
2684
|
+
device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
|
|
2685
|
+
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
|
2686
|
+
const short is = tx*NSG + j;
|
|
2687
|
+
dst_state[is] = ls[j];
|
|
2688
|
+
}
|
|
2689
|
+
}
|
|
2690
|
+
}
|
|
2691
|
+
}
|
|
2692
|
+
|
|
2693
|
+
if (K == 1) {
|
|
2694
|
+
device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
|
|
2695
|
+
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
|
2696
|
+
const short is = tx*NSG + j;
|
|
2697
|
+
dst_state[is] = ls[j];
|
|
2698
|
+
}
|
|
2699
|
+
}
|
|
2700
|
+
|
|
2701
|
+
#undef S_v
|
|
2702
|
+
#undef G
|
|
2703
|
+
#undef K
|
|
2704
|
+
}
|
|
2705
|
+
|
|
2706
|
+
// Host-visible entry points of the gated delta-net kernel defined above.
// All specializations share the signature of the <4> specialization.
using kernel_gated_delta_net_t = decltype(kernel_gated_delta_net_impl<4>);

template [[host_name("kernel_gated_delta_net_f32_1")]]
kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>;
template [[host_name("kernel_gated_delta_net_f32_2")]]
kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>;
template [[host_name("kernel_gated_delta_net_f32_4")]]
kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>;
|
|
2711
|
+
|
|
2712
|
+
#else
|
|
2713
|
+
// A simplified, vectorized variant of the kernel above. Each thread keeps its
// NSG-element slice of the recurrent state in registers and views it as one
// T vector (float/float2/float4), so the per-token updates become vector ops.
// No performance improvement was measured, so the version above stays active.
//
// Per token t, with the thread's state slice ls:
//   ls  *= exp(g)                 (scalar gate when G == 1, per-element otherwise)
//   s_k  = sum(ls * k)            (reduced across the simdgroup)
//   d    = (v[i20] - s_k) * b
//   ls  += k * d
//   y    = sum(ls * q) * 1/sqrt(S_v)
//
// NOTE(review): unlike the active version above, this variant has no K
// (multi-slot state snapshot) output. It also stores the final state with an
// S_v stride (dst_state[is*S_v], base + i20), whereas the active version
// writes it contiguously (dst_state[is], base + i20*S_v). Confirm the
// expected state layout before re-enabling this path.
template<typename T, short NSG>
kernel void kernel_gated_delta_net_impl(
        constant ggml_metal_kargs_gated_delta_net & args,
        device const char * q,
        device const char * k,
        device const char * v,
        device const char * g,
        device const char * b,
        device const char * s,
        device       char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
#define S_v FC_gated_delta_net_ne20
#define G   FC_gated_delta_net_ne30

    const uint tx = tpitg.x;
    const uint ty = tpitg.y;

    const uint i23 = tgpig.z; // B
    const uint i21 = tgpig.y; // H
    const uint i20 = tgpig.x*NSG + ty;

    // q/k head index wraps modulo their own head counts
    const uint i01 = i21 % args.ne01;
    const uint i11 = i21 % args.ne11;

    const float scale = 1.0f / sqrt((float)S_v);

    device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;

    // thread-local state slice: elements is = tx*NSG + j, j in [0, NSG)
    float lsf[NSG];

    FOR_UNROLL (short j = 0; j < NSG; j++) {
        const short is = tx*NSG + j;
        lsf[j] = s_ptr[is*S_v];
    }

    // view the NSG floats as a single T vector
    thread T * ls = (thread T *) (lsf);

    device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;

    device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
    device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
    device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);

    device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
    device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;

    for (short t = 0; t < args.ne22; t++) {
        device const T * qt_ptr = (device const T *) (q_ptr);
        device const T * kt_ptr = (device const T *) (k_ptr);
        device const T * gt_ptr = (device const T *) (g_ptr);

        if (G == 1) {
            *ls *= exp(g_ptr[0]);
        } else {
            // KDA: one gate value per state element
            *ls *= exp(gt_ptr[tx]);
        }

        const float s_k = simd_sum(dot(*ls, kt_ptr[tx]));

        const float d = (v_ptr[i20] - s_k)*b_ptr[0];

        *ls += kt_ptr[tx]*d;

        const float y = simd_sum(dot(*ls, qt_ptr[tx]));

        if (tx == 0) {
            *dst_attn = y*scale;
        }

        q_ptr += args.ns02;
        k_ptr += args.ns12;
        v_ptr += args.ns22;

        b_ptr += args.ne21;
        g_ptr += args.ne21*G;

        dst_attn += args.ne21*S_v;
    }

    // the final state is stored after the attention output in dst
    device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;

    FOR_UNROLL (short j = 0; j < NSG; j++) {
        const short is = tx*NSG + j;
        dst_state[is*S_v] = lsf[j];
    }

#undef S_v
#undef G
}

typedef decltype(kernel_gated_delta_net_impl<float4, 4>) kernel_gated_delta_net_t;

template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float,  1>;
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float2, 2>;
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float4, 4>;
|
|
2815
|
+
#endif
|
|
2816
|
+
|
|
2817
|
+
constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
constant short FC_solve_tri_n   [[function_constant(FC_SOLVE_TRI + 1)]];
constant short FC_solve_tri_k   [[function_constant(FC_SOLVE_TRI + 2)]];

// Forward substitution for a triangular system A*X = B (A is N x N).
//
// Each simdgroup solves one column i01 of X; B (src1) and X (dst) are read and
// written with a row stride of K floats. Rows of A are staged NSG at a time
// into threadgroup memory, one row per simdgroup, each padded to
// NP = PAD2(N, NW) floats. Every simdgroup then walks the staged rows in order:
//     x[r] = (b[r] - sum_{c<r} A[r][c]*x[c]) / A[r][r]
// Only the strictly-lower part and the diagonal of A affect the result.
//
// Requires shmem to hold at least NSG*NP floats (presumably sized by the host
// from the same function constants — confirm).
kernel void kernel_solve_tri_f32(
        constant ggml_metal_kargs_solve_tri & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        ushort3 tgpig[[threadgroup_position_in_grid]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    constexpr short NW = N_SIMDWIDTH;

    const short NSG = FC_solve_tri_nsg;
    const short N   = FC_solve_tri_n;
    const short K   = FC_solve_tri_k;
    const short NP  = PAD2(N, NW);

    const int32_t i03 = tgpig.z;
    const int32_t i02 = tgpig.y;
    const int32_t i01 = tgpig.x*NSG + sgitg;

    threadgroup float * sh0 = (threadgroup float *) shmem;

    device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N;
    device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01;
    device       float * dst_ptr  = (device       float *)(dst  + i02 * args.nb2  + i03 * args.nb3)  + i01;

    for (short rr = 0; rr < N; rr += NSG) {
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // stage row (rr + sgitg) of A. Padding columns [N, NP) and rows past
        // the end of the matrix are zero-filled instead of read, so the loads
        // never leave the src0 matrix.
        {
            threadgroup float * sh0_cur = sh0 + sgitg*NP;

            const bool row_in_range = rr + sgitg < N;

            for (short t = 0; t*NW < N; ++t) {
                const short idx = t*NW + tiisg;
                sh0_cur[idx] = (row_in_range && idx < N) ? src0_ptr[idx] : 0.0f;
            }

            src0_ptr += NSG*N;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // every simdgroup must keep reaching the barriers above, so columns
        // past the end skip the solve rather than returning early
        if (i01 >= args.ne10) {
            continue;
        }

        for (short ir = 0; ir < NSG && rr + ir < N; ++ir) {
            const short r = rr + ir;

            threadgroup float * sh0_cur = sh0 + ir*NP;

            float sum = 0.0f;

            for (short t = 0; t*NW < r; ++t) {
                const short idx = t*NW + tiisg;
                // x[idx] for idx >= r is not solved yet and may lie past row N.
                // Skip the load instead of multiplying it by 0: a NaN/Inf left
                // in the uninitialized dst buffer would survive a 0 multiply
                // and poison the sum.
                if (idx < r) {
                    sum += sh0_cur[idx] * dst_ptr[idx*K];
                }
            }

            sum = simd_sum(sum);

            if (tiisg == 0) {
                const float diag = sh0_cur[r];

                dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag;
            }
        }
    }
}
|
|
2890
|
+
|
|
2740
2891
|
kernel void kernel_argmax_f32(
|
|
2741
2892
|
constant ggml_metal_kargs_argmax & args,
|
|
2742
2893
|
device const char * src0,
|
|
@@ -2970,26 +3121,32 @@ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_f
|
|
|
2970
3121
|
// float4 instantiations of the fused RMS-norm kernel
template [[host_name("kernel_rms_norm_mul_f32_4")]]
kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32_4")]]
kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
|
|
2972
3123
|
|
|
2973
|
-
|
|
3124
|
+
template <typename T0, typename T>
|
|
3125
|
+
kernel void kernel_l2_norm_impl(
|
|
2974
3126
|
constant ggml_metal_kargs_l2_norm & args,
|
|
2975
3127
|
device const char * src0,
|
|
2976
3128
|
device char * dst,
|
|
2977
3129
|
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
|
2978
|
-
|
|
2979
|
-
|
|
2980
|
-
ushort
|
|
2981
|
-
ushort
|
|
2982
|
-
|
|
3130
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3131
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
3132
|
+
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
|
3133
|
+
ushort tiisg[[thread_index_in_simdgroup]],
|
|
3134
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
3135
|
+
const int i03 = tgpig.z;
|
|
3136
|
+
const int i02 = tgpig.y;
|
|
3137
|
+
const int i01 = tgpig.x;
|
|
3138
|
+
|
|
2983
3139
|
if (sgitg == 0) {
|
|
2984
3140
|
shmem_f32[tiisg] = 0.0f;
|
|
2985
3141
|
}
|
|
2986
3142
|
|
|
2987
|
-
device const
|
|
3143
|
+
device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
|
3144
|
+
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
|
|
2988
3145
|
|
|
2989
3146
|
float sumf = 0.0f;
|
|
2990
3147
|
|
|
2991
3148
|
// parallel sum
|
|
2992
|
-
for (int i00 = tpitg; i00 < args.
|
|
3149
|
+
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
|
|
2993
3150
|
sumf += dot(x[i00], x[i00]);
|
|
2994
3151
|
}
|
|
2995
3152
|
sumf = simd_sum(sumf);
|
|
@@ -3005,14 +3162,18 @@ kernel void kernel_l2_norm_f32(
|
|
|
3005
3162
|
sumf = shmem_f32[tiisg];
|
|
3006
3163
|
sumf = simd_sum(sumf);
|
|
3007
3164
|
|
|
3008
|
-
const float scale = 1.0f/sqrt(
|
|
3165
|
+
const float scale = 1.0f/max(sqrt(sumf), args.eps);
|
|
3009
3166
|
|
|
3010
|
-
|
|
3011
|
-
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
|
3167
|
+
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
|
|
3012
3168
|
y[i00] = x[i00] * scale;
|
|
3013
3169
|
}
|
|
3014
3170
|
}
|
|
3015
3171
|
|
|
3172
|
+
// Host-visible L2-norm entry points: a scalar and a float4-vectorized variant.
using kernel_l2_norm_t = decltype(kernel_l2_norm_impl<float, float>);

template [[host_name("kernel_l2_norm_f32_f32")]]
kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
template [[host_name("kernel_l2_norm_f32_f32_4")]]
kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
|
|
3176
|
+
|
|
3016
3177
|
kernel void kernel_group_norm_f32(
|
|
3017
3178
|
constant ggml_metal_kargs_group_norm & args,
|
|
3018
3179
|
device const float * src0,
|
|
@@ -3094,6 +3255,35 @@ kernel void kernel_group_norm_f32(
|
|
|
3094
3255
|
}
|
|
3095
3256
|
}
|
|
3096
3257
|
|
|
3258
|
+
// Q1_0 partial dot product: 16 weights of one block against yl[0..15].
// The weights are the 16 bits read from the two bytes at qs + il/8. A set bit
// counts as +1 and a clear bit as -1, scaled by the block's d:
//     dot = d * (sum_{bit=1} yl[i] - sum_{bit=0} yl[i])
//         = d * (2 * sum_{bit=1} yl[i] - sumy),   sumy = sum of yl[0..15]
inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
    device const uint8_t * bytes = qb_curr->qs + il / 8;

    // pack both bytes into one mask so that bit i selects yl[i]
    const uint mask = uint(bytes[0]) | (uint(bytes[1]) << 8);

    float pos = 0.0f;
    for (short i = 0; i < 16; ++i) {
        pos += select(0.0f, yl[i], bool((mask >> i) & 1u));
    }

    return qb_curr->d * (2.0f * pos - sumy);
}
|
|
3286
|
+
|
|
3097
3287
|
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
|
3098
3288
|
// il indicates where the q4 quants begin (0 or QK4_0/4)
|
|
3099
3289
|
// we assume that the yl's have been multiplied with the appropriate scale factor
|
|
@@ -3226,6 +3416,9 @@ static inline void helper_mv_reduce_and_write(
|
|
|
3226
3416
|
|
|
3227
3417
|
// Function constants shared by the mul_mv kernels.
constant short FC_mul_mv_nsg   [[function_constant(FC_MUL_MV + 0)]]; // simdgroups per threadgroup (NSG)
constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]];
constant short FC_mul_mv_ne12  [[function_constant(FC_MUL_MV + 2)]]; // splits grid z into i12 = im % ne12, i13 = im / ne12
constant short FC_mul_mv_r2    [[function_constant(FC_MUL_MV + 3)]]; // src0 dim-2 index is i12 / r2
constant short FC_mul_mv_r3    [[function_constant(FC_MUL_MV + 4)]]; // src0 dim-3 index is i13 / r3
|
|
3229
3422
|
|
|
3230
3423
|
template<typename block_q_type, short NR0, typename args_t>
|
|
3231
3424
|
void mul_vec_q_n_f32_impl(
|
|
@@ -3249,72 +3442,151 @@ void mul_vec_q_n_f32_impl(
|
|
|
3249
3442
|
const int r1 = tgpig.y;
|
|
3250
3443
|
const int im = tgpig.z;
|
|
3251
3444
|
|
|
3252
|
-
const uint i12 = im%
|
|
3253
|
-
const uint i13 = im/
|
|
3445
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
3446
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
3254
3447
|
|
|
3255
|
-
//const uint64_t offset0 = r0*args.nb01 + (i12/
|
|
3448
|
+
//const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3256
3449
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
3257
3450
|
|
|
3258
3451
|
//device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
|
|
3259
3452
|
device const float * y = (device const float *) (src1 + offset1);
|
|
3260
3453
|
|
|
3261
|
-
// pointers to src0 rows
|
|
3262
|
-
device const block_q_type * ax[NR0];
|
|
3263
|
-
FOR_UNROLL (int row = 0; row < NR0; ++row) {
|
|
3264
|
-
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/
|
|
3454
|
+
// pointers to src0 rows
|
|
3455
|
+
device const block_q_type * ax[NR0];
|
|
3456
|
+
FOR_UNROLL (int row = 0; row < NR0; ++row) {
|
|
3457
|
+
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3458
|
+
|
|
3459
|
+
ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
|
|
3460
|
+
}
|
|
3461
|
+
|
|
3462
|
+
float sumf[NR0] = {0.f};
|
|
3463
|
+
|
|
3464
|
+
const short ix = (tiisg/(NW/NQ));
|
|
3465
|
+
const short il = (tiisg%(NW/NQ))*8;
|
|
3466
|
+
|
|
3467
|
+
//const int ib0 = sgitg*NQ + ix;
|
|
3468
|
+
const int ib0 = ix;
|
|
3469
|
+
|
|
3470
|
+
float yl[16]; // src1 vector cache
|
|
3471
|
+
|
|
3472
|
+
//device const float * yb = y + ix*QK4_0 + il;
|
|
3473
|
+
device const float * yb = y + ib0*QK4_0 + il;
|
|
3474
|
+
|
|
3475
|
+
// each thread in a SIMD group deals with half a block.
|
|
3476
|
+
//for (int ib = ib0; ib < nb; ib += NSG*NQ) {
|
|
3477
|
+
for (int ib = ib0; ib < nb; ib += NQ) {
|
|
3478
|
+
float sumy[2] = { 0.f, 0.f };
|
|
3479
|
+
|
|
3480
|
+
FOR_UNROLL (short i = 0; i < 8; i += 2) {
|
|
3481
|
+
sumy[0] += yb[i + 0] + yb[i + 1];
|
|
3482
|
+
yl[i + 0] = yb[i + 0];
|
|
3483
|
+
yl[i + 1] = yb[i + 1]/256.f;
|
|
3484
|
+
|
|
3485
|
+
sumy[1] += yb[i + 16] + yb[i + 17];
|
|
3486
|
+
yl[i + 8] = yb[i + 16]/16.f;
|
|
3487
|
+
yl[i + 9] = yb[i + 17]/4096.f;
|
|
3488
|
+
}
|
|
3489
|
+
|
|
3490
|
+
FOR_UNROLL (short row = 0; row < NR0; row++) {
|
|
3491
|
+
sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
|
|
3492
|
+
}
|
|
3493
|
+
|
|
3494
|
+
yb += QK4_0 * 16;
|
|
3495
|
+
//yb += NSG*NQ*QK4_0;
|
|
3496
|
+
}
|
|
3497
|
+
|
|
3498
|
+
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
|
3499
|
+
|
|
3500
|
+
//helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
|
3501
|
+
|
|
3502
|
+
for (int row = 0; row < NR0; ++row) {
|
|
3503
|
+
const float tot = simd_sum(sumf[row]);
|
|
3504
|
+
|
|
3505
|
+
if (tiisg == 0 && r0 + row < args.ne01) {
|
|
3506
|
+
dst_f32[r0 + row] = tot;
|
|
3507
|
+
}
|
|
3508
|
+
}
|
|
3509
|
+
}
|
|
3510
|
+
|
|
3511
|
+
// Matrix-vector product: Q1_0 weights (src0) times f32 activations (src1).
//
// Work split (grid x = row groups, y = src1 row r1, z = batch index im):
//   - each simdgroup owns nr0 consecutive src0 rows starting at row0;
//   - lanes are grouped 8 per Q1_0 block. Each lane handles 16 values at
//     offset bit0 inside its block (8 x 16 values, presumably QK1_0 == 128);
//   - lane group `blk` starts at block blk and advances N_SIMDWIDTH/8 blocks
//     per iteration;
//   - per-row partials are reduced with simd_sum, and lane 0 writes rows < ne01.
// shmem is not used: the reduction stays inside the simdgroup.
template<int nr0, typename args_t>
void kernel_mul_mv_q1_0_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    constexpr short lanes_per_block = 8;                            // lanes sharing one Q1_0 block
    constexpr short blocks_per_step = N_SIMDWIDTH/lanes_per_block;  // blocks covered per iteration

    const short nsg = FC_mul_mv_nsg;

    const int n_blocks = args.ne00/QK1_0;

    const int tg_x = tgpig.x;
    const int r1   = tgpig.y;
    const int im   = tgpig.z;

    const int row0 = (tg_x * nsg + sgitg) * nr0;

    const uint i12 = im%FC_mul_mv_ne12;
    const uint i13 = im/FC_mul_mv_ne12;

    // activations: row r1 of batch (i12, i13)
    const uint64_t src1_off = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13;
    device const float * y = (device const float *) (src1 + src1_off);

    // byte offset of this batch's src0 matrix, shared by all rows
    const uint64_t src0_mat = (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;

    device const block_q1_0 * xrow[nr0];
    for (int row = 0; row < nr0; ++row) {
        xrow[row] = (device const block_q1_0 *) ((device char *) src0 + ((row0 + row)*args.nb01 + src0_mat));
    }

    float yl[16];
    float acc[nr0] = {0.f};

    const short blk  = tiisg/lanes_per_block;       // first block of this lane
    const short bit0 = (tiisg%lanes_per_block)*16;  // first of the lane's 16 values in a block

    device const float * ycur = y + blk*QK1_0 + bit0;

    for (int ib = blk; ib < n_blocks; ib += blocks_per_step) {
        FOR_UNROLL (short i = 0; i < 16; i++) {
            yl[i] = ycur[i];
        }

        float ysum = 0.f;
        FOR_UNROLL (short i = 0; i < 16; i++) {
            ysum += yl[i];
        }

        FOR_UNROLL (short row = 0; row < nr0; row++) {
            acc[row] += block_q_n_dot_y(xrow[row] + ib, ysum, yl, bit0);
        }

        ycur += QK1_0*blocks_per_step;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0; ++row) {
        // all lanes take part in the reduction; only lane 0 stores
        const float total = simd_sum(acc[row]);
        const int   r     = row0 + row;

        if (tiisg == 0 && r < args.ne01) {
            dst_f32[r] = total;
        }
    }
}

// Kernel entry point for Q1_0 x f32 mat-vec; forwards to the impl with
// N_R0_Q1_0 rows per simdgroup.
[[host_name("kernel_mul_mv_q1_0_f32")]]
kernel void kernel_mul_mv_q1_0_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    // the impl needs no threadgroup scratch memory
    threadgroup char * no_shmem = nullptr;

    kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, no_shmem, tgpig, tiisg, sgitg);
}
|
|
3589
|
+
|
|
3318
3590
|
kernel void kernel_mul_mv_q4_0_f32(
|
|
3319
3591
|
constant ggml_metal_kargs_mul_mv & args,
|
|
3320
3592
|
device const char * src0,
|
|
@@ -3384,10 +3656,10 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|
|
3384
3656
|
const int r1 = tgpig.y;
|
|
3385
3657
|
const int im = tgpig.z;
|
|
3386
3658
|
|
|
3387
|
-
const uint i12 = im%
|
|
3388
|
-
const uint i13 = im/
|
|
3659
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
3660
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
3389
3661
|
|
|
3390
|
-
//const uint64_t offset0 = r0*args.nb01 + (i12/
|
|
3662
|
+
//const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3391
3663
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
3392
3664
|
|
|
3393
3665
|
//device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
|
|
@@ -3396,7 +3668,7 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|
|
3396
3668
|
// pointers to src0 rows
|
|
3397
3669
|
device const block_q8_0 * ax[NR0];
|
|
3398
3670
|
FOR_UNROLL (short row = 0; row < NR0; ++row) {
|
|
3399
|
-
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/
|
|
3671
|
+
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3400
3672
|
|
|
3401
3673
|
ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
|
|
3402
3674
|
}
|
|
@@ -3476,10 +3748,10 @@ void kernel_mul_mv_ext_q4_f32_impl(
|
|
|
3476
3748
|
const int i11 = tgpig.y*r1ptg;
|
|
3477
3749
|
const int i1m = tgpig.z;
|
|
3478
3750
|
|
|
3479
|
-
const int i12 = i1m%
|
|
3480
|
-
const int i13 = i1m/
|
|
3751
|
+
const int i12 = i1m%FC_mul_mv_ne12;
|
|
3752
|
+
const int i13 = i1m/FC_mul_mv_ne12;
|
|
3481
3753
|
|
|
3482
|
-
const uint64_t offset0 = i01*args.nb01 + (i12/
|
|
3754
|
+
const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3483
3755
|
const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
3484
3756
|
|
|
3485
3757
|
device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
|
|
@@ -3579,10 +3851,10 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
|
|
|
3579
3851
|
const int i11 = tgpig.y*r1ptg;
|
|
3580
3852
|
const int i1m = tgpig.z;
|
|
3581
3853
|
|
|
3582
|
-
const int i12 = i1m%
|
|
3583
|
-
const int i13 = i1m/
|
|
3854
|
+
const int i12 = i1m%FC_mul_mv_ne12;
|
|
3855
|
+
const int i13 = i1m/FC_mul_mv_ne12;
|
|
3584
3856
|
|
|
3585
|
-
const uint64_t offset0 = i01*args.nb01 + (i12/
|
|
3857
|
+
const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3586
3858
|
const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
3587
3859
|
|
|
3588
3860
|
device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
|
|
@@ -3700,6 +3972,18 @@ template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4
|
|
|
3700
3972
|
// --- f16 ---
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]]
kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]]
kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>;

// --- bf16: only built when GGML_METAL_HAS_BF16 is defined ---
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_2")]]
kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, bfloat4, 4, dequantize_bf16_t4>;
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_3")]]
kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, bfloat4, 4, dequantize_bf16_t4>;
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]]
kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, bfloat4, 4, dequantize_bf16_t4>;
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]]
kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
#endif

// --- q1_0 ---
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]]
kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>;
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]]
kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>;
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]]
kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>;
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]]
kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>;

// --- q4_0 ---
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]]
kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]]
kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]]
kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
|
|
@@ -3750,6 +4034,16 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4
|
|
|
3750
4034
|
// --- q6_K ---
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]]
kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]]
kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;

// --- q2_K ---
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_2")]]
kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q2_K, 256, dequantize_q2_K>;
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_3")]]
kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q2_K, 256, dequantize_q2_K>;
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_4")]]
kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q2_K, 256, dequantize_q2_K>;
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_5")]]
kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q2_K, 256, dequantize_q2_K>;

// --- q3_K ---
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_2")]]
kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q3_K, 256, dequantize_q3_K>;
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_3")]]
kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q3_K, 256, dequantize_q3_K>;
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_4")]]
kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q3_K, 256, dequantize_q3_K>;
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_5")]]
kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q3_K, 256, dequantize_q3_K>;
|
|
4046
|
+
|
|
3753
4047
|
template<typename T0, typename T1, short NR0, typename args_t>
|
|
3754
4048
|
void kernel_mul_mv_t_t_impl(
|
|
3755
4049
|
args_t args,
|
|
@@ -3772,10 +4066,10 @@ void kernel_mul_mv_t_t_impl(
|
|
|
3772
4066
|
const int r1 = tgpig.y;
|
|
3773
4067
|
const int im = tgpig.z;
|
|
3774
4068
|
|
|
3775
|
-
const uint i12 = im%
|
|
3776
|
-
const uint i13 = im/
|
|
4069
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
4070
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
3777
4071
|
|
|
3778
|
-
//const uint64_t offset0 = r0*args.nb01 + (i12/
|
|
4072
|
+
//const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3779
4073
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
3780
4074
|
|
|
3781
4075
|
//device const T0 * x = (device const T0 *) (src0 + offset0);
|
|
@@ -3784,7 +4078,7 @@ void kernel_mul_mv_t_t_impl(
|
|
|
3784
4078
|
// pointers to src0 rows
|
|
3785
4079
|
device const T0 * ax [NR0];
|
|
3786
4080
|
FOR_UNROLL (short row = 0; row < NR0; ++row) {
|
|
3787
|
-
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/
|
|
4081
|
+
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3788
4082
|
|
|
3789
4083
|
ax[row] = (device const T0 *) ((device char *) src0 + offset0);
|
|
3790
4084
|
}
|
|
@@ -3894,10 +4188,10 @@ void kernel_mul_mv_t_t_4_impl(
|
|
|
3894
4188
|
const int r1 = tgpig.y;
|
|
3895
4189
|
const int im = tgpig.z;
|
|
3896
4190
|
|
|
3897
|
-
const uint i12 = im%
|
|
3898
|
-
const uint i13 = im/
|
|
4191
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
4192
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
3899
4193
|
|
|
3900
|
-
//const uint64_t offset0 = r0*args.nb01 + (i12/
|
|
4194
|
+
//const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3901
4195
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
3902
4196
|
|
|
3903
4197
|
device const T1 * y = (device const T1 *) (src1 + offset1);
|
|
@@ -3907,7 +4201,7 @@ void kernel_mul_mv_t_t_4_impl(
|
|
|
3907
4201
|
device const T0 * ax [NR0];
|
|
3908
4202
|
device const T04 * ax4[NR0];
|
|
3909
4203
|
FOR_UNROLL (short row = 0; row < NR0; ++row) {
|
|
3910
|
-
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/
|
|
4204
|
+
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
3911
4205
|
|
|
3912
4206
|
ax [row] = (device const T0 *) ((device char *) src0 + offset0);
|
|
3913
4207
|
ax4[row] = (device const T04 *) ((device char *) src0 + offset0);
|
|
@@ -4011,10 +4305,10 @@ void kernel_mul_mv_t_t_short_impl(
|
|
|
4011
4305
|
return;
|
|
4012
4306
|
}
|
|
4013
4307
|
|
|
4014
|
-
const uint i12 = im%
|
|
4015
|
-
const uint i13 = im/
|
|
4308
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
4309
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
4016
4310
|
|
|
4017
|
-
const uint64_t offset0 = r0*args.nb01 + (i12/
|
|
4311
|
+
const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
4018
4312
|
|
|
4019
4313
|
device const T0 * x = (device const T0 *) (src0 + offset0);
|
|
4020
4314
|
|
|
@@ -4437,59 +4731,59 @@ kernel void kernel_im2col(
|
|
|
4437
4731
|
// im2col entry points, one per template element type
template [[host_name("kernel_im2col_f32")]]
kernel im2col_t kernel_im2col<float>;
template [[host_name("kernel_im2col_f16")]]
kernel im2col_t kernel_im2col<half>;
|
|
4439
4733
|
|
|
4440
|
-
// TODO:
|
|
4441
|
-
|
|
4442
|
-
|
|
4443
|
-
|
|
4444
|
-
|
|
4445
|
-
|
|
4446
|
-
|
|
4447
|
-
|
|
4448
|
-
|
|
4449
|
-
|
|
4450
|
-
|
|
4451
|
-
|
|
4452
|
-
|
|
4453
|
-
|
|
4454
|
-
|
|
4455
|
-
|
|
4456
|
-
|
|
4457
|
-
|
|
4458
|
-
|
|
4459
|
-
|
|
4460
|
-
|
|
4461
|
-
|
|
4462
|
-
|
|
4463
|
-
|
|
4464
|
-
|
|
4465
|
-
|
|
4466
|
-
|
|
4467
|
-
|
|
4468
|
-
|
|
4469
|
-
|
|
4470
|
-
|
|
4471
|
-
|
|
4472
|
-
|
|
4473
|
-
|
|
4474
|
-
|
|
4475
|
-
|
|
4476
|
-
|
|
4477
|
-
|
|
4478
|
-
|
|
4479
|
-
|
|
4480
|
-
|
|
4481
|
-
|
|
4482
|
-
|
|
4483
|
-
|
|
4484
|
-
|
|
4485
|
-
|
|
4486
|
-
|
|
4487
|
-
|
|
4488
|
-
|
|
4489
|
-
|
|
4490
|
-
|
|
4491
|
-
|
|
4492
|
-
|
|
4734
|
+
// TODO: optimize
|
|
4735
|
+
typedef void (im2col_ext_t)(
|
|
4736
|
+
constant ggml_metal_kargs_im2col & args,
|
|
4737
|
+
device const float * x,
|
|
4738
|
+
device char * dst,
|
|
4739
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4740
|
+
uint3 tgpg[[threadgroups_per_grid]],
|
|
4741
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
4742
|
+
uint3 ntg[[threads_per_threadgroup]]);
|
|
4743
|
+
|
|
4744
|
+
template <typename T>
|
|
4745
|
+
kernel void kernel_im2col_ext(
|
|
4746
|
+
constant ggml_metal_kargs_im2col & args,
|
|
4747
|
+
device const float * x,
|
|
4748
|
+
device char * dst,
|
|
4749
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4750
|
+
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
|
|
4751
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
4752
|
+
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
|
|
4753
|
+
const int64_t KHW = (int64_t)args.KHW;
|
|
4754
|
+
|
|
4755
|
+
const int64_t d = tgpig[0] / args.CHW;
|
|
4756
|
+
const int64_t chw = tgpig[0] % args.CHW;
|
|
4757
|
+
const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
|
|
4758
|
+
const int64_t HW = tgpig[0] % KHW;
|
|
4759
|
+
|
|
4760
|
+
const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
|
|
4761
|
+
if (tpitg_0 >= args.N) {
|
|
4762
|
+
return;
|
|
4763
|
+
}
|
|
4764
|
+
|
|
4765
|
+
const int64_t tpitg_1 = HW / args.KW;
|
|
4766
|
+
const int64_t tpitg_2 = HW % args.KW;
|
|
4767
|
+
|
|
4768
|
+
const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
|
|
4769
|
+
const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
|
|
4770
|
+
|
|
4771
|
+
const int64_t offset_dst =
|
|
4772
|
+
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
|
|
4773
|
+
(tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
|
|
4774
|
+
|
|
4775
|
+
device T * pdst = (device T *) (dst);
|
|
4776
|
+
|
|
4777
|
+
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
|
|
4778
|
+
pdst[offset_dst] = 0.0f;
|
|
4779
|
+
} else {
|
|
4780
|
+
const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
|
|
4781
|
+
pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
|
|
4782
|
+
}
|
|
4783
|
+
}
|
|
4784
|
+
|
|
4785
|
+
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
|
|
4786
|
+
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
|
|
4493
4787
|
|
|
4494
4788
|
template <typename TK>
|
|
4495
4789
|
kernel void kernel_conv_2d(
|
|
@@ -4622,15 +4916,32 @@ kernel void kernel_conv_transpose_1d(
|
|
|
4622
4916
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4623
4917
|
uint3 tgpg[[threadgroups_per_grid]]) {
|
|
4624
4918
|
|
|
4625
|
-
|
|
4919
|
+
// For output position j on the time axis, only input positions
|
|
4920
|
+
// i such that i*s0 <= j < i*s0 + K
|
|
4921
|
+
// contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)]
|
|
4922
|
+
// intersected with [0, IL-1]. That's at most ceil(K/s0) values
|
|
4923
|
+
// (typically 2 for stride==K/2 transposed convs).
|
|
4924
|
+
const int32_t j = tgpig[0];
|
|
4925
|
+
const int32_t s0 = args.s0;
|
|
4926
|
+
const int32_t K = args.K;
|
|
4927
|
+
const int32_t IL = args.IL;
|
|
4928
|
+
|
|
4929
|
+
int32_t i_min;
|
|
4930
|
+
{
|
|
4931
|
+
int32_t a = j - K + 1;
|
|
4932
|
+
i_min = a <= 0 ? 0 : (a + s0 - 1) / s0; // ceil(a/s0) for a>0
|
|
4933
|
+
}
|
|
4934
|
+
int32_t i_max = j / s0;
|
|
4935
|
+
if (i_max > IL - 1) i_max = IL - 1;
|
|
4626
4936
|
|
|
4627
|
-
|
|
4628
|
-
|
|
4629
|
-
|
|
4937
|
+
float v = 0.0f;
|
|
4938
|
+
if (i_min <= i_max) {
|
|
4939
|
+
for (int64_t c = 0; c < args.IC; c++) {
|
|
4940
|
+
const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
|
|
4941
|
+
const int32_t input_offset = c * IL;
|
|
4630
4942
|
|
|
4631
|
-
|
|
4632
|
-
|
|
4633
|
-
v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i];
|
|
4943
|
+
for (int32_t i = i_min; i <= i_max; i++) {
|
|
4944
|
+
v += float(src0[kernel_offset + j - i * s0]) * src1[input_offset + i];
|
|
4634
4945
|
}
|
|
4635
4946
|
}
|
|
4636
4947
|
}
|
|
@@ -4749,7 +5060,9 @@ kernel void kernel_conv_transpose_2d<half>(
|
|
|
4749
5060
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
4750
5061
|
uint3 ntg[[threads_per_threadgroup]]);
|
|
4751
5062
|
|
|
4752
|
-
|
|
5063
|
+
constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]];
|
|
5064
|
+
|
|
5065
|
+
kernel void kernel_upscale_nearest_f32(
|
|
4753
5066
|
constant ggml_metal_kargs_upscale & args,
|
|
4754
5067
|
device const char * src0,
|
|
4755
5068
|
device char * dst,
|
|
@@ -4775,8 +5088,12 @@ kernel void kernel_upscale_f32(
|
|
|
4775
5088
|
}
|
|
4776
5089
|
}
|
|
4777
5090
|
|
|
4778
|
-
|
|
4779
|
-
|
|
5091
|
+
static inline float bilinear_tri(float x) {
|
|
5092
|
+
return MAX(0.0f, 1.0f - fabs(x));
|
|
5093
|
+
}
|
|
5094
|
+
|
|
5095
|
+
kernel void kernel_upscale_bilinear_f32(
|
|
5096
|
+
constant ggml_metal_kargs_upscale & args,
|
|
4780
5097
|
device const char * src0,
|
|
4781
5098
|
device char * dst,
|
|
4782
5099
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
@@ -4787,30 +5104,306 @@ kernel void kernel_pad_f32(
|
|
|
4787
5104
|
const int64_t i2 = tgpig.y;
|
|
4788
5105
|
const int64_t i1 = tgpig.x;
|
|
4789
5106
|
|
|
4790
|
-
const int64_t i03 = i3;
|
|
4791
|
-
const int64_t i02 = i2;
|
|
4792
|
-
const int64_t i01 = i1;
|
|
5107
|
+
const int64_t i03 = i3 / args.sf3;
|
|
5108
|
+
const int64_t i02 = i2 / args.sf2;
|
|
4793
5109
|
|
|
4794
|
-
|
|
4795
|
-
|
|
5110
|
+
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
|
|
5111
|
+
const int64_t i01 = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01)));
|
|
5112
|
+
const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1));
|
|
5113
|
+
const float fd1 = MAX(0.0f, MIN(1.0f, f01 - (float)i01));
|
|
5114
|
+
|
|
5115
|
+
src0 += i03*args.nb03 + i02*args.nb02;
|
|
5116
|
+
|
|
5117
|
+
device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
|
5118
|
+
|
|
5119
|
+
if (FC_upscale_aa) {
|
|
5120
|
+
const float support0 = MAX(1.0f, 1.0f / args.sf0);
|
|
5121
|
+
const float invscale0 = 1.0f / support0;
|
|
5122
|
+
const float support1 = MAX(1.0f, 1.0f / args.sf1);
|
|
5123
|
+
const float invscale1 = 1.0f / support1;
|
|
4796
5124
|
|
|
4797
|
-
if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
|
|
4798
5125
|
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
4799
|
-
|
|
4800
|
-
|
|
4801
|
-
|
|
4802
|
-
|
|
5126
|
+
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
|
|
5127
|
+
|
|
5128
|
+
int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs));
|
|
5129
|
+
int64_t x_max = MIN(args.ne00, (int64_t)ceil (f00 + support0 + args.poffs));
|
|
5130
|
+
|
|
5131
|
+
int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs));
|
|
5132
|
+
int64_t y_max = MIN(args.ne01, (int64_t)ceil (f01 + support1 + args.poffs));
|
|
5133
|
+
|
|
5134
|
+
float sum = 0.0f;
|
|
5135
|
+
float wsum = 0.0f;
|
|
5136
|
+
|
|
5137
|
+
for (int64_t sy = y_min; sy < y_max; ++sy) {
|
|
5138
|
+
const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1);
|
|
5139
|
+
for (int64_t sx = x_min; sx < x_max; ++sx) {
|
|
5140
|
+
const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
|
|
5141
|
+
const float w = wx * wy;
|
|
5142
|
+
device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
|
|
5143
|
+
sum += (*src_ptr) * w;
|
|
5144
|
+
wsum += w;
|
|
5145
|
+
}
|
|
4803
5146
|
}
|
|
5147
|
+
|
|
5148
|
+
const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f;
|
|
5149
|
+
dst_ptr[i0] = v;
|
|
4804
5150
|
}
|
|
5151
|
+
} else {
|
|
5152
|
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
5153
|
+
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
|
|
5154
|
+
const int64_t i00 = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00)));
|
|
5155
|
+
const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1));
|
|
5156
|
+
const float fd0 = MAX(0.0f, MIN(1.0f, f00 - (float)i00));
|
|
4805
5157
|
|
|
4806
|
-
|
|
5158
|
+
device const float * src00 = (device const float *)(src0 + i01*args.nb01 + i00*args.nb00);
|
|
5159
|
+
device const float * src10 = (device const float *)(src0 + i01*args.nb01 + i00p*args.nb00);
|
|
5160
|
+
device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00);
|
|
5161
|
+
device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00);
|
|
5162
|
+
|
|
5163
|
+
const float v =
|
|
5164
|
+
(*src00) * (1.0f - fd0) * (1.0f - fd1) +
|
|
5165
|
+
(*src10) * fd0 * (1.0f - fd1) +
|
|
5166
|
+
(*src01) * (1.0f - fd0) * fd1 +
|
|
5167
|
+
(*src11) * fd0 * fd1;
|
|
5168
|
+
|
|
5169
|
+
dst_ptr[i0] = v;
|
|
5170
|
+
}
|
|
5171
|
+
}
|
|
5172
|
+
}
|
|
5173
|
+
|
|
5174
|
+
template <typename T>
|
|
5175
|
+
kernel void kernel_conv_3d(
|
|
5176
|
+
constant ggml_metal_kargs_conv_3d & args,
|
|
5177
|
+
device const char * src0, // Weights [IC * OC, KD, KH, KW]
|
|
5178
|
+
device const char * src1, // Inputs [IC * N, ID, IH, IW]
|
|
5179
|
+
device char * dst, // Outputs [OC * N, OD, OH, OW]
|
|
5180
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5181
|
+
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
|
5182
|
+
|
|
5183
|
+
// 1. Un-flatten the spatial dimension from Grid X
|
|
5184
|
+
int64_t spatial_idx = tgpig.x * 32 + tpitg.x;
|
|
5185
|
+
|
|
5186
|
+
if (spatial_idx >= args.OW * args.OH * args.OD) {
|
|
5187
|
+
return; // Thread falls outside the spatial volume
|
|
5188
|
+
}
|
|
5189
|
+
|
|
5190
|
+
int64_t od = spatial_idx / (args.OW * args.OH);
|
|
5191
|
+
int64_t oh = (spatial_idx / args.OW) % args.OH;
|
|
5192
|
+
int64_t ow = spatial_idx % args.OW;
|
|
5193
|
+
|
|
5194
|
+
// 2. Map Y to Channels, Z to Batch
|
|
5195
|
+
int64_t oc = tgpig.y;
|
|
5196
|
+
int64_t batch_idx = tgpig.z;
|
|
5197
|
+
|
|
5198
|
+
// 3. Calculate anchor coordinates in the Input volume
|
|
5199
|
+
int64_t i_w_base = ow * args.s0 - args.p0;
|
|
5200
|
+
int64_t i_h_base = oh * args.s1 - args.p1;
|
|
5201
|
+
int64_t i_d_base = od * args.s2 - args.p2;
|
|
5202
|
+
|
|
5203
|
+
float sum = 0.0f;
|
|
5204
|
+
|
|
5205
|
+
// 4. Gather Loop (Iterate over Input Channels -> Depth -> Height -> Width)
|
|
5206
|
+
for (int64_t ic = 0; ic < args.IC; ++ic) {
|
|
5207
|
+
|
|
5208
|
+
// ggml packs batch and channel together in the 4th dimension
|
|
5209
|
+
int64_t src_cn_idx = batch_idx * args.IC + ic;
|
|
5210
|
+
int64_t w_cn_idx = oc * args.IC + ic;
|
|
5211
|
+
|
|
5212
|
+
for (int64_t kz = 0; kz < args.KD; ++kz) {
|
|
5213
|
+
int64_t id = i_d_base + kz * args.d2;
|
|
5214
|
+
if (id < 0 || id >= args.ID) continue; // Boundary check (Padding)
|
|
5215
|
+
|
|
5216
|
+
for (int64_t ky = 0; ky < args.KH; ++ky) {
|
|
5217
|
+
int64_t ih = i_h_base + ky * args.d1;
|
|
5218
|
+
if (ih < 0 || ih >= args.IH) continue;
|
|
5219
|
+
|
|
5220
|
+
for (int64_t kx = 0; kx < args.KW; ++kx) {
|
|
5221
|
+
int64_t iw = i_w_base + kx * args.d0;
|
|
5222
|
+
if (iw < 0 || iw >= args.IW) continue;
|
|
5223
|
+
|
|
5224
|
+
// Convert multi-dimensional coordinates to flat byte offsets
|
|
5225
|
+
int64_t w_idx = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn_idx*args.nb03;
|
|
5226
|
+
int64_t i_idx = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn_idx*args.nb13;
|
|
5227
|
+
|
|
5228
|
+
// Dereference memory and cast weights to f32 if they were f16
|
|
5229
|
+
float w_val = (float)*(device const T*)((device const char*)src0 + w_idx);
|
|
5230
|
+
float i_val = *(device const float*)((device const char*)src1 + i_idx);
|
|
5231
|
+
|
|
5232
|
+
sum += w_val * i_val;
|
|
5233
|
+
}
|
|
5234
|
+
}
|
|
5235
|
+
}
|
|
5236
|
+
}
|
|
5237
|
+
|
|
5238
|
+
// 5. Write the accumulated value out to RAM
|
|
5239
|
+
int64_t dst_cn_idx = batch_idx * args.OC + oc;
|
|
5240
|
+
int64_t d_idx = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn_idx*args.nb3;
|
|
5241
|
+
|
|
5242
|
+
*(device float*)(dst + d_idx) = sum;
|
|
5243
|
+
}
|
|
5244
|
+
|
|
5245
|
+
// Explicit instantiations so the JIT compiler can find them by name
|
|
5246
|
+
template [[host_name("kernel_conv_3d_f32_f32")]]
|
|
5247
|
+
kernel void kernel_conv_3d<float>(
|
|
5248
|
+
constant ggml_metal_kargs_conv_3d & args,
|
|
5249
|
+
device const char * src0,
|
|
5250
|
+
device const char * src1,
|
|
5251
|
+
device char * dst,
|
|
5252
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5253
|
+
uint3 tpitg[[thread_position_in_threadgroup]]);
|
|
5254
|
+
|
|
5255
|
+
// Explicit instantiation for f16 weights
|
|
5256
|
+
template [[host_name("kernel_conv_3d_f16_f32")]]
|
|
5257
|
+
kernel void kernel_conv_3d<half>(
|
|
5258
|
+
constant ggml_metal_kargs_conv_3d & args,
|
|
5259
|
+
device const char * src0,
|
|
5260
|
+
device const char * src1,
|
|
5261
|
+
device char * dst,
|
|
5262
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5263
|
+
uint3 tpitg[[thread_position_in_threadgroup]]);
|
|
5264
|
+
|
|
5265
|
+
|
|
5266
|
+
static inline float bicubic_weight1(float x) {
|
|
5267
|
+
const float a = -0.75f;
|
|
5268
|
+
return ((a + 2) * x - (a + 3)) * x * x + 1;
|
|
5269
|
+
}
|
|
5270
|
+
|
|
5271
|
+
static inline float bicubic_weight2(float x) {
|
|
5272
|
+
const float a = -0.75f;
|
|
5273
|
+
return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
|
|
5274
|
+
}
|
|
5275
|
+
|
|
5276
|
+
kernel void kernel_upscale_bicubic_f32(
|
|
5277
|
+
constant ggml_metal_kargs_upscale & args,
|
|
5278
|
+
device const char * src0,
|
|
5279
|
+
device char * dst,
|
|
5280
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5281
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
5282
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
5283
|
+
|
|
5284
|
+
const int64_t i3 = tgpig.z;
|
|
5285
|
+
const int64_t i2 = tgpig.y;
|
|
5286
|
+
const int64_t i1 = tgpig.x;
|
|
5287
|
+
|
|
5288
|
+
const int64_t i03 = i3 / args.sf3;
|
|
5289
|
+
const int64_t i02 = i2 / args.sf2;
|
|
5290
|
+
|
|
5291
|
+
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
|
|
5292
|
+
const int64_t i01 = (int64_t)floor(f01);
|
|
5293
|
+
const float fd1 = f01 - (float)i01;
|
|
5294
|
+
|
|
5295
|
+
const float w_y0 = bicubic_weight2(fd1 + 1.0f);
|
|
5296
|
+
const float w_y1 = bicubic_weight1(fd1);
|
|
5297
|
+
const float w_y2 = bicubic_weight1(1.0f - fd1);
|
|
5298
|
+
const float w_y3 = bicubic_weight2(2.0f - fd1);
|
|
5299
|
+
|
|
5300
|
+
const device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02;
|
|
5301
|
+
|
|
5302
|
+
device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1);
|
|
5303
|
+
|
|
5304
|
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
5305
|
+
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
|
|
5306
|
+
const int64_t i00 = (int64_t)floor(f00);
|
|
5307
|
+
const float fd0 = f00 - (float)i00;
|
|
5308
|
+
|
|
5309
|
+
const float w_x0 = bicubic_weight2(fd0 + 1.0f);
|
|
5310
|
+
const float w_x1 = bicubic_weight1(fd0);
|
|
5311
|
+
const float w_x2 = bicubic_weight1(1.0f - fd0);
|
|
5312
|
+
const float w_x3 = bicubic_weight2(2.0f - fd0);
|
|
5313
|
+
|
|
5314
|
+
float sum = 0.0f;
|
|
5315
|
+
|
|
5316
|
+
for (int dy = -1; dy <= 2; ++dy) {
|
|
5317
|
+
const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy));
|
|
5318
|
+
const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3;
|
|
5319
|
+
|
|
5320
|
+
for (int dx = -1; dx <= 2; ++dx) {
|
|
5321
|
+
const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
|
|
5322
|
+
const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
|
|
5323
|
+
|
|
5324
|
+
device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
|
|
5325
|
+
sum += (*src_ptr) * wx * wy;
|
|
5326
|
+
}
|
|
5327
|
+
}
|
|
5328
|
+
|
|
5329
|
+
dst_ptr[i0] = sum;
|
|
4807
5330
|
}
|
|
5331
|
+
}
|
|
5332
|
+
|
|
5333
|
+
kernel void kernel_roll_f32(
|
|
5334
|
+
constant ggml_metal_kargs_roll & args,
|
|
5335
|
+
device const char * src0,
|
|
5336
|
+
device char * dst,
|
|
5337
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5338
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
5339
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
5340
|
+
|
|
5341
|
+
const int64_t i3 = tgpig.z;
|
|
5342
|
+
const int64_t i2 = tgpig.y;
|
|
5343
|
+
const int64_t i1 = tgpig.x;
|
|
5344
|
+
|
|
5345
|
+
device const float * src0_ptr = (device const float *) src0;
|
|
5346
|
+
device float * dst_ptr = (device float *) dst;
|
|
4808
5347
|
|
|
4809
5348
|
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
4810
|
-
|
|
5349
|
+
// apply shifts and wrap around
|
|
5350
|
+
int64_t i00 = i0 - args.s0;
|
|
5351
|
+
int64_t i01 = i1 - args.s1;
|
|
5352
|
+
int64_t i02 = i2 - args.s2;
|
|
5353
|
+
int64_t i03 = i3 - args.s3;
|
|
5354
|
+
|
|
5355
|
+
if (i00 < 0) { i00 += args.ne00; } else if (i00 >= args.ne00) { i00 -= args.ne00; }
|
|
5356
|
+
if (i01 < 0) { i01 += args.ne01; } else if (i01 >= args.ne01) { i01 -= args.ne01; }
|
|
5357
|
+
if (i02 < 0) { i02 += args.ne02; } else if (i02 >= args.ne02) { i02 -= args.ne02; }
|
|
5358
|
+
if (i03 < 0) { i03 += args.ne03; } else if (i03 >= args.ne03) { i03 -= args.ne03; }
|
|
5359
|
+
|
|
5360
|
+
int64_t src_idx = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00 + i00;
|
|
5361
|
+
int64_t dst_idx = i3 *args.ne2 *args.ne1 *args.ne0 + i2 *args.ne1 *args.ne0 + i1 *args.ne0 + i0;
|
|
5362
|
+
|
|
5363
|
+
dst_ptr[dst_idx] = src0_ptr[src_idx];
|
|
5364
|
+
}
|
|
5365
|
+
}
|
|
5366
|
+
|
|
5367
|
+
template <typename T>
|
|
5368
|
+
kernel void kernel_pad_impl(
|
|
5369
|
+
constant ggml_metal_kargs_pad & args,
|
|
5370
|
+
device const char * src0,
|
|
5371
|
+
device char * dst,
|
|
5372
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5373
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
5374
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
5375
|
+
const int32_t i3 = tgpig.z;
|
|
5376
|
+
const int32_t i2 = tgpig.y;
|
|
5377
|
+
const int32_t k0 = tgpig.x/args.ne1;
|
|
5378
|
+
const int32_t i1 = tgpig.x - k0*args.ne1;
|
|
5379
|
+
|
|
5380
|
+
const int32_t i03 = i3;
|
|
5381
|
+
const int32_t i02 = i2;
|
|
5382
|
+
const int32_t i01 = i1;
|
|
5383
|
+
|
|
5384
|
+
device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
|
5385
|
+
device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
|
5386
|
+
|
|
5387
|
+
for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) {
|
|
5388
|
+
const int32_t i0 = k0*1024 + tpitg.x + l0;
|
|
5389
|
+
if (i0 >= args.ne0) {
|
|
5390
|
+
break;
|
|
5391
|
+
}
|
|
5392
|
+
|
|
5393
|
+
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
|
|
5394
|
+
dst_ptr[i0] = src0_ptr[i0];
|
|
5395
|
+
} else {
|
|
5396
|
+
dst_ptr[i0] = 0.0f;
|
|
5397
|
+
}
|
|
4811
5398
|
}
|
|
4812
5399
|
}
|
|
4813
5400
|
|
|
5401
|
+
typedef decltype(kernel_pad_impl<float>) kernel_pad_t;
|
|
5402
|
+
|
|
5403
|
+
template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl<float>;
|
|
5404
|
+
template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl<float4>;
|
|
5405
|
+
|
|
5406
|
+
// TODO: this is slow - optimize
|
|
4814
5407
|
kernel void kernel_pad_reflect_1d_f32(
|
|
4815
5408
|
constant ggml_metal_kargs_pad_reflect_1d & args,
|
|
4816
5409
|
device const char * src0,
|
|
@@ -5114,24 +5707,6 @@ kernel void kernel_argsort_merge_f32_i32(
|
|
|
5114
5707
|
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
|
|
5115
5708
|
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
|
|
5116
5709
|
|
|
5117
|
-
kernel void kernel_leaky_relu_f32(
|
|
5118
|
-
constant ggml_metal_kargs_leaky_relu & args,
|
|
5119
|
-
device const float * src0,
|
|
5120
|
-
device float * dst,
|
|
5121
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
5122
|
-
const float x = src0[tpig];
|
|
5123
|
-
dst[tpig] = x > 0.0f ? x : x * args.slope;
|
|
5124
|
-
}
|
|
5125
|
-
|
|
5126
|
-
kernel void kernel_leaky_relu_f32_4(
|
|
5127
|
-
constant ggml_metal_kargs_leaky_relu & args,
|
|
5128
|
-
device const float4 * src0,
|
|
5129
|
-
device float4 * dst,
|
|
5130
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
5131
|
-
const float4 x = src0[tpig];
|
|
5132
|
-
dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
|
|
5133
|
-
}
|
|
5134
|
-
|
|
5135
5710
|
constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
|
|
5136
5711
|
|
|
5137
5712
|
constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
|
|
@@ -5208,6 +5783,7 @@ constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_E
|
|
|
5208
5783
|
// scan the blocks of the mask that are not masked
|
|
5209
5784
|
// 0 - masked (i.e. full of -INF, skip)
|
|
5210
5785
|
// 1 - not masked (i.e. at least one element of the mask is not -INF)
|
|
5786
|
+
// 2 - all zero
|
|
5211
5787
|
kernel void kernel_flash_attn_ext_blk(
|
|
5212
5788
|
constant ggml_metal_kargs_flash_attn_ext_blk & args,
|
|
5213
5789
|
device const char * mask,
|
|
@@ -5229,27 +5805,29 @@ kernel void kernel_flash_attn_ext_blk(
|
|
|
5229
5805
|
|
|
5230
5806
|
device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
|
|
5231
5807
|
|
|
5232
|
-
// fast route
|
|
5233
|
-
if (res == 0) {
|
|
5234
|
-
if (simd_max(*mask_src) > -MAXHALF/2) {
|
|
5235
|
-
res = 1;
|
|
5236
|
-
}
|
|
5237
|
-
}
|
|
5238
|
-
|
|
5239
5808
|
// detailed check of the elements of the block
|
|
5240
5809
|
if ((C > NW || Q > 1) && res == 0) {
|
|
5241
|
-
half
|
|
5810
|
+
half mmin = MAXHALF;
|
|
5811
|
+
half mmax = -MAXHALF;
|
|
5242
5812
|
|
|
5243
5813
|
FOR_UNROLL (short j = 0; j < Q; ++j) {
|
|
5244
5814
|
FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
|
|
5245
|
-
|
|
5815
|
+
mmin = min(mmin, mask_src[ii*NW]);
|
|
5816
|
+
mmax = max(mmax, mask_src[ii*NW]);
|
|
5246
5817
|
}
|
|
5247
5818
|
|
|
5248
5819
|
mask_src += args.nb31/2;
|
|
5249
5820
|
}
|
|
5250
5821
|
|
|
5251
|
-
|
|
5252
|
-
|
|
5822
|
+
mmin = simd_min(mmin);
|
|
5823
|
+
mmax = simd_max(mmax);
|
|
5824
|
+
|
|
5825
|
+
if (mmax > -MAXHALF) {
|
|
5826
|
+
if (mmin == 0.0 && mmax == 0.0) {
|
|
5827
|
+
res = 2;
|
|
5828
|
+
} else {
|
|
5829
|
+
res = 1;
|
|
5830
|
+
}
|
|
5253
5831
|
}
|
|
5254
5832
|
}
|
|
5255
5833
|
|
|
@@ -5491,9 +6069,13 @@ void kernel_flash_attn_ext_impl(
|
|
|
5491
6069
|
ic = 0;
|
|
5492
6070
|
}
|
|
5493
6071
|
|
|
6072
|
+
char blk_cur = 1;
|
|
6073
|
+
|
|
5494
6074
|
// read the mask into shared mem
|
|
5495
6075
|
if (FC_flash_attn_ext_has_mask) {
|
|
5496
|
-
|
|
6076
|
+
blk_cur = blk[ic0];
|
|
6077
|
+
|
|
6078
|
+
if (blk_cur == 0) {
|
|
5497
6079
|
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
|
5498
6080
|
pm2[jj] += NW;
|
|
5499
6081
|
}
|
|
@@ -5501,16 +6083,22 @@ void kernel_flash_attn_ext_impl(
|
|
|
5501
6083
|
continue;
|
|
5502
6084
|
}
|
|
5503
6085
|
|
|
5504
|
-
|
|
5505
|
-
|
|
6086
|
+
if (blk_cur == 1) {
|
|
6087
|
+
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
|
6088
|
+
const short j = jj*NSG + sgitg;
|
|
5506
6089
|
|
|
5507
|
-
|
|
5508
|
-
|
|
5509
|
-
|
|
5510
|
-
|
|
5511
|
-
|
|
6090
|
+
if (FC_flash_attn_ext_bc_mask) {
|
|
6091
|
+
sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
|
|
6092
|
+
} else {
|
|
6093
|
+
sm2[j*SH + tiisg] = pm2[jj][tiisg];
|
|
6094
|
+
}
|
|
5512
6095
|
|
|
5513
|
-
|
|
6096
|
+
pm2[jj] += NW;
|
|
6097
|
+
}
|
|
6098
|
+
} else if (blk_cur == 2) {
|
|
6099
|
+
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
|
6100
|
+
pm2[jj] += NW;
|
|
6101
|
+
}
|
|
5514
6102
|
}
|
|
5515
6103
|
|
|
5516
6104
|
#if 0
|
|
@@ -5552,9 +6140,7 @@ void kernel_flash_attn_ext_impl(
|
|
|
5552
6140
|
|
|
5553
6141
|
constexpr short NC = (C/8)/NSG;
|
|
5554
6142
|
|
|
5555
|
-
|
|
5556
|
-
#pragma unroll (DK <= 64 ? NC : 1)
|
|
5557
|
-
for (short cc = 0; cc < NC; ++cc) {
|
|
6143
|
+
FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
|
|
5558
6144
|
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
|
|
5559
6145
|
|
|
5560
6146
|
if (DK % 16 != 0) {
|
|
@@ -5575,7 +6161,9 @@ void kernel_flash_attn_ext_impl(
|
|
|
5575
6161
|
k8x8_t mk[2];
|
|
5576
6162
|
q8x8_t mq[2];
|
|
5577
6163
|
|
|
5578
|
-
|
|
6164
|
+
// note: too much unroll can tank the performance for large heads
|
|
6165
|
+
#pragma unroll (MIN(DK8/2, 4*NSG))
|
|
6166
|
+
for (short i = 0; i < DK8/2; ++i) {
|
|
5579
6167
|
simdgroup_barrier(mem_flags::mem_none);
|
|
5580
6168
|
|
|
5581
6169
|
simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
|
|
@@ -5675,10 +6263,12 @@ void kernel_flash_attn_ext_impl(
|
|
|
5675
6263
|
}
|
|
5676
6264
|
|
|
5677
6265
|
// mqk = mqk + slope*mask
|
|
5678
|
-
if (
|
|
5679
|
-
|
|
5680
|
-
|
|
5681
|
-
|
|
6266
|
+
if (blk_cur != 2) {
|
|
6267
|
+
if (FC_flash_attn_ext_has_bias) {
|
|
6268
|
+
s2 += s2_t(sm2[j*SH + tiisg])*slope;
|
|
6269
|
+
} else {
|
|
6270
|
+
s2 += s2_t(sm2[j*SH + tiisg]);
|
|
6271
|
+
}
|
|
5682
6272
|
}
|
|
5683
6273
|
|
|
5684
6274
|
M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));
|
|
@@ -5749,7 +6339,9 @@ void kernel_flash_attn_ext_impl(
|
|
|
5749
6339
|
pv += 8*NS20;
|
|
5750
6340
|
}
|
|
5751
6341
|
} else {
|
|
5752
|
-
|
|
6342
|
+
constexpr short NC = (C/8)/2;
|
|
6343
|
+
|
|
6344
|
+
FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
|
|
5753
6345
|
s8x8_t vs[2];
|
|
5754
6346
|
|
|
5755
6347
|
simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
|
|
@@ -5929,7 +6521,7 @@ template<
|
|
|
5929
6521
|
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
|
|
5930
6522
|
short DK, // K head size
|
|
5931
6523
|
short DV, // V head size
|
|
5932
|
-
short Q =
|
|
6524
|
+
short Q = OP_FLASH_ATTN_EXT_NQPSG, // queries per threadgroup
|
|
5933
6525
|
short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
|
|
5934
6526
|
kernel void kernel_flash_attn_ext(
|
|
5935
6527
|
constant ggml_metal_kargs_flash_attn_ext & args,
|
|
@@ -5952,6 +6544,7 @@ kernel void kernel_flash_attn_ext(
|
|
|
5952
6544
|
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
|
5953
6545
|
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
|
5954
6546
|
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
|
6547
|
+
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
|
5955
6548
|
}
|
|
5956
6549
|
#undef FWD_TMPL
|
|
5957
6550
|
#undef FWD_ARGS
|
|
@@ -6001,6 +6594,8 @@ template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_at
|
|
|
6001
6594
|
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
|
|
6002
6595
|
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
|
|
6003
6596
|
template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
|
|
6597
|
+
template [[host_name("kernel_flash_attn_ext_f32_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 320, 256>;
|
|
6598
|
+
template [[host_name("kernel_flash_attn_ext_f32_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 512, 512>;
|
|
6004
6599
|
template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
|
|
6005
6600
|
|
|
6006
6601
|
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
|
|
@@ -6015,6 +6610,8 @@ template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_at
|
|
|
6015
6610
|
template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
|
|
6016
6611
|
template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
|
|
6017
6612
|
template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
|
|
6613
|
+
template [[host_name("kernel_flash_attn_ext_f16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 320, 256>;
|
|
6614
|
+
template [[host_name("kernel_flash_attn_ext_f16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 512, 512>;
|
|
6018
6615
|
template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
|
|
6019
6616
|
|
|
6020
6617
|
#if defined(GGML_METAL_HAS_BF16)
|
|
@@ -6030,6 +6627,8 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_at
|
|
|
6030
6627
|
template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
|
|
6031
6628
|
template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
|
|
6032
6629
|
template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
|
|
6630
|
+
template [[host_name("kernel_flash_attn_ext_bf16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 320, 256>;
|
|
6631
|
+
template [[host_name("kernel_flash_attn_ext_bf16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 512, 512>;
|
|
6033
6632
|
template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
|
|
6034
6633
|
#endif
|
|
6035
6634
|
|
|
@@ -6045,6 +6644,8 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_at
|
|
|
6045
6644
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
|
|
6046
6645
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
|
|
6047
6646
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
|
|
6647
|
+
template [[host_name("kernel_flash_attn_ext_q4_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 320, 256>;
|
|
6648
|
+
template [[host_name("kernel_flash_attn_ext_q4_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 512, 512>;
|
|
6048
6649
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
|
|
6049
6650
|
|
|
6050
6651
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
|
|
@@ -6059,6 +6660,8 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_at
|
|
|
6059
6660
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
|
|
6060
6661
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
|
|
6061
6662
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
|
|
6663
|
+
template [[host_name("kernel_flash_attn_ext_q4_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 320, 256>;
|
|
6664
|
+
template [[host_name("kernel_flash_attn_ext_q4_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 512, 512>;
|
|
6062
6665
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
|
|
6063
6666
|
|
|
6064
6667
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
|
|
@@ -6073,6 +6676,8 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_at
|
|
|
6073
6676
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
|
|
6074
6677
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
|
|
6075
6678
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
|
|
6679
|
+
template [[host_name("kernel_flash_attn_ext_q5_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 320, 256>;
|
|
6680
|
+
template [[host_name("kernel_flash_attn_ext_q5_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 512, 512>;
|
|
6076
6681
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
|
|
6077
6682
|
|
|
6078
6683
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
|
|
@@ -6087,6 +6692,8 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_at
|
|
|
6087
6692
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
|
|
6088
6693
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
|
|
6089
6694
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
|
|
6695
|
+
template [[host_name("kernel_flash_attn_ext_q5_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 320, 256>;
|
|
6696
|
+
template [[host_name("kernel_flash_attn_ext_q5_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 512, 512>;
|
|
6090
6697
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
|
|
6091
6698
|
|
|
6092
6699
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
|
|
@@ -6101,6 +6708,8 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_at
|
|
|
6101
6708
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
|
|
6102
6709
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
|
|
6103
6710
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
|
|
6711
|
+
template [[host_name("kernel_flash_attn_ext_q8_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 320, 256>;
|
|
6712
|
+
template [[host_name("kernel_flash_attn_ext_q8_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 512, 512>;
|
|
6104
6713
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
|
|
6105
6714
|
|
|
6106
6715
|
#undef FA_TYPES
|
|
@@ -6138,11 +6747,10 @@ template<
|
|
|
6138
6747
|
void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
|
|
6139
6748
|
short DK, // K head size
|
|
6140
6749
|
short DV, // V head size
|
|
6141
|
-
short NE,
|
|
6142
|
-
short Q,
|
|
6143
|
-
short C
|
|
6144
|
-
|
|
6145
|
-
void kernel_flash_attn_ext_vec_impl(
|
|
6750
|
+
short NE = 4, // head elements per thread
|
|
6751
|
+
short Q = OP_FLASH_ATTN_EXT_VEC_NQPSG, // queries per threadgroup
|
|
6752
|
+
short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
|
|
6753
|
+
kernel void kernel_flash_attn_ext_vec(
|
|
6146
6754
|
constant ggml_metal_kargs_flash_attn_ext_vec & args,
|
|
6147
6755
|
device const char * q,
|
|
6148
6756
|
device const char * k,
|
|
@@ -6159,6 +6767,7 @@ void kernel_flash_attn_ext_vec_impl(
|
|
|
6159
6767
|
static_assert(DV % 32 == 0, "DV must be divisible by 32");
|
|
6160
6768
|
|
|
6161
6769
|
#define NWG (FC_flash_attn_ext_vec_nwg)
|
|
6770
|
+
#define NSG (FC_flash_attn_ext_vec_nsg)
|
|
6162
6771
|
|
|
6163
6772
|
#define NS10 (FC_flash_attn_ext_vec_ns10)
|
|
6164
6773
|
#define NS20 (FC_flash_attn_ext_vec_ns20)
|
|
@@ -6185,14 +6794,14 @@ void kernel_flash_attn_ext_vec_impl(
|
|
|
6185
6794
|
static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
|
|
6186
6795
|
static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
|
|
6187
6796
|
|
|
6188
|
-
|
|
6797
|
+
//const short T = PK + NSG*SH; // shared memory size per query in (half)
|
|
6189
6798
|
|
|
6190
|
-
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 +
|
|
6191
|
-
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +
|
|
6192
|
-
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH +
|
|
6193
|
-
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH +
|
|
6194
|
-
threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C +
|
|
6195
|
-
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV +
|
|
6799
|
+
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
|
|
6800
|
+
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
|
|
6801
|
+
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + NSG*PK); // scratch buffer for attention
|
|
6802
|
+
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + NSG*PK); // same as above but in s4_t
|
|
6803
|
+
threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask
|
|
6804
|
+
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + NSG*PK + NSG*SH); // scratch buffer for the results
|
|
6196
6805
|
|
|
6197
6806
|
// store the result for all queries in shared memory (the O matrix from the paper)
|
|
6198
6807
|
so4 += tiisg;
|
|
@@ -6210,11 +6819,13 @@ void kernel_flash_attn_ext_vec_impl(
|
|
|
6210
6819
|
// load heads from Q to shared memory
|
|
6211
6820
|
device const float4 * q4 = (device const float4 *) ((device const char *) q);
|
|
6212
6821
|
|
|
6213
|
-
|
|
6214
|
-
|
|
6215
|
-
|
|
6216
|
-
|
|
6217
|
-
|
|
6822
|
+
if (iq1 < args.ne01) {
|
|
6823
|
+
for (short i = tiisg; i < PK4; i += NW) {
|
|
6824
|
+
if (i < DK4) {
|
|
6825
|
+
sq4[i] = (q4_t) q4[i];
|
|
6826
|
+
} else {
|
|
6827
|
+
sq4[i] = (q4_t) 0.0f;
|
|
6828
|
+
}
|
|
6218
6829
|
}
|
|
6219
6830
|
}
|
|
6220
6831
|
|
|
@@ -6292,7 +6903,7 @@ void kernel_flash_attn_ext_vec_impl(
|
|
|
6292
6903
|
}
|
|
6293
6904
|
|
|
6294
6905
|
// skip -INF blocks
|
|
6295
|
-
if (simd_max(sm[tiisg])
|
|
6906
|
+
if (simd_max(sm[tiisg]) <= -MAXHALF) {
|
|
6296
6907
|
continue;
|
|
6297
6908
|
}
|
|
6298
6909
|
|
|
@@ -6566,57 +7177,11 @@ void kernel_flash_attn_ext_vec_impl(
|
|
|
6566
7177
|
}
|
|
6567
7178
|
|
|
6568
7179
|
#undef NWG
|
|
7180
|
+
#undef NSG
|
|
6569
7181
|
#undef NS10
|
|
6570
7182
|
#undef NS20
|
|
6571
7183
|
}
|
|
6572
7184
|
|
|
6573
|
-
template<
|
|
6574
|
-
typename q4_t, // query types in shared memory
|
|
6575
|
-
typename k4_t, // key types in shared memory
|
|
6576
|
-
typename v4_t, // value types in shared memory
|
|
6577
|
-
typename qk_t, // Q*K types
|
|
6578
|
-
typename s_t, // soft-max types
|
|
6579
|
-
typename s4_t,
|
|
6580
|
-
typename o4_t, // attention accumulation types
|
|
6581
|
-
typename kd4_t, // key type in device memory
|
|
6582
|
-
short nl_k,
|
|
6583
|
-
void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
|
|
6584
|
-
typename vd4_t, // value type in device memory
|
|
6585
|
-
short nl_v,
|
|
6586
|
-
void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
|
|
6587
|
-
short DK, // K head size
|
|
6588
|
-
short DV, // V head size
|
|
6589
|
-
short NE = 4, // head elements per thread
|
|
6590
|
-
short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup
|
|
6591
|
-
short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
|
|
6592
|
-
kernel void kernel_flash_attn_ext_vec(
|
|
6593
|
-
constant ggml_metal_kargs_flash_attn_ext_vec & args,
|
|
6594
|
-
device const char * q,
|
|
6595
|
-
device const char * k,
|
|
6596
|
-
device const char * v,
|
|
6597
|
-
device const char * mask,
|
|
6598
|
-
device const char * sinks,
|
|
6599
|
-
device const char * pad,
|
|
6600
|
-
device char * dst,
|
|
6601
|
-
threadgroup half * shmem_f16 [[threadgroup(0)]],
|
|
6602
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6603
|
-
ushort tiisg[[thread_index_in_simdgroup]],
|
|
6604
|
-
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
6605
|
-
#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
|
|
6606
|
-
#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
|
|
6607
|
-
switch (FC_flash_attn_ext_vec_nsg) {
|
|
6608
|
-
// note: disabled cases to reduce library load time
|
|
6609
|
-
case 1: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
|
6610
|
-
case 2: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
|
6611
|
-
case 4: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
|
6612
|
-
//case 8: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
|
6613
|
-
//case 16: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 16>(FWD_ARGS); break;
|
|
6614
|
-
//case 32: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 32>(FWD_ARGS); break;
|
|
6615
|
-
}
|
|
6616
|
-
#undef FWD_TMPL
|
|
6617
|
-
#undef FWD_ARGS
|
|
6618
|
-
}
|
|
6619
|
-
|
|
6620
7185
|
// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
|
|
6621
7186
|
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
|
|
6622
7187
|
//
|
|
@@ -6715,6 +7280,28 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flas
|
|
|
6715
7280
|
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
|
|
6716
7281
|
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
|
|
6717
7282
|
|
|
7283
|
+
template [[host_name("kernel_flash_attn_ext_vec_f32_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 320, 256, 2>;
|
|
7284
|
+
template [[host_name("kernel_flash_attn_ext_vec_f16_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 320, 256, 2>;
|
|
7285
|
+
#if defined(GGML_METAL_HAS_BF16)
|
|
7286
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 320, 256, 2>;
|
|
7287
|
+
#endif
|
|
7288
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 320, 256, 2>;
|
|
7289
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 320, 256, 2>;
|
|
7290
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 320, 256, 2>;
|
|
7291
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 320, 256, 2>;
|
|
7292
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 320, 256, 2>;
|
|
7293
|
+
|
|
7294
|
+
template [[host_name("kernel_flash_attn_ext_vec_f32_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 512, 512, 1>;
|
|
7295
|
+
template [[host_name("kernel_flash_attn_ext_vec_f16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 512, 512, 1>;
|
|
7296
|
+
#if defined(GGML_METAL_HAS_BF16)
|
|
7297
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 512, 512, 1>;
|
|
7298
|
+
#endif
|
|
7299
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 512, 512, 1>;
|
|
7300
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 512, 512, 1>;
|
|
7301
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 512, 512, 1>;
|
|
7302
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 512, 512, 1>;
|
|
7303
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 512, 512, 1>;
|
|
7304
|
+
|
|
6718
7305
|
template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
|
|
6719
7306
|
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
|
|
6720
7307
|
#if defined(GGML_METAL_HAS_BF16)
|
|
@@ -6780,23 +7367,27 @@ kernel void kernel_cpy_t_t(
|
|
|
6780
7367
|
device const char * src0,
|
|
6781
7368
|
device char * dst,
|
|
6782
7369
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6783
|
-
|
|
7370
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
6784
7371
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
6785
|
-
const
|
|
6786
|
-
const
|
|
6787
|
-
const
|
|
6788
|
-
const
|
|
7372
|
+
const int32_t i03 = tgpig[2];
|
|
7373
|
+
const int32_t i02 = tgpig[1];
|
|
7374
|
+
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
|
|
7375
|
+
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
|
7376
|
+
|
|
7377
|
+
if (i01 >= args.ne01) {
|
|
7378
|
+
return;
|
|
7379
|
+
}
|
|
6789
7380
|
|
|
6790
7381
|
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
|
6791
7382
|
|
|
6792
|
-
const
|
|
6793
|
-
const
|
|
6794
|
-
const
|
|
6795
|
-
const
|
|
7383
|
+
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
|
|
7384
|
+
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
|
|
7385
|
+
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
|
|
7386
|
+
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
|
|
6796
7387
|
|
|
6797
7388
|
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
|
6798
7389
|
|
|
6799
|
-
for (
|
|
7390
|
+
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) {
|
|
6800
7391
|
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
6801
7392
|
dst_data[i00] = (T1) src[0];
|
|
6802
7393
|
break;
|
|
@@ -6828,23 +7419,27 @@ kernel void kernel_cpy_f32_q(
|
|
|
6828
7419
|
device const char * src0,
|
|
6829
7420
|
device char * dst,
|
|
6830
7421
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6831
|
-
|
|
7422
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
6832
7423
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
6833
|
-
const
|
|
6834
|
-
const
|
|
6835
|
-
const
|
|
6836
|
-
const
|
|
7424
|
+
const int32_t i03 = tgpig[2];
|
|
7425
|
+
const int32_t i02 = tgpig[1];
|
|
7426
|
+
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
|
|
7427
|
+
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
|
7428
|
+
|
|
7429
|
+
if (i01 >= args.ne01) {
|
|
7430
|
+
return;
|
|
7431
|
+
}
|
|
6837
7432
|
|
|
6838
7433
|
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
|
6839
7434
|
|
|
6840
|
-
const
|
|
6841
|
-
const
|
|
6842
|
-
const
|
|
6843
|
-
const
|
|
7435
|
+
const int32_t i3 = n / (args.ne2*args.ne1*args.ne0);
|
|
7436
|
+
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
|
|
7437
|
+
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
|
|
7438
|
+
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
|
|
6844
7439
|
|
|
6845
7440
|
device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
|
6846
7441
|
|
|
6847
|
-
for (
|
|
7442
|
+
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
|
|
6848
7443
|
device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
|
|
6849
7444
|
|
|
6850
7445
|
quantize_func(src, dst_data[i00]);
|
|
@@ -6856,6 +7451,7 @@ kernel void kernel_cpy_f32_q(
|
|
|
6856
7451
|
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
|
|
6857
7452
|
|
|
6858
7453
|
template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
|
|
7454
|
+
template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>;
|
|
6859
7455
|
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
|
|
6860
7456
|
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
|
|
6861
7457
|
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
|
|
@@ -6868,24 +7464,28 @@ kernel void kernel_cpy_q_f32(
|
|
|
6868
7464
|
device const char * src0,
|
|
6869
7465
|
device char * dst,
|
|
6870
7466
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6871
|
-
|
|
7467
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
6872
7468
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
6873
|
-
const
|
|
6874
|
-
const
|
|
6875
|
-
const
|
|
6876
|
-
const
|
|
7469
|
+
const int32_t i03 = tgpig[2];
|
|
7470
|
+
const int32_t i02 = tgpig[1];
|
|
7471
|
+
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
|
|
7472
|
+
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
|
7473
|
+
|
|
7474
|
+
if (i01 >= args.ne01) {
|
|
7475
|
+
return;
|
|
7476
|
+
}
|
|
6877
7477
|
|
|
6878
7478
|
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
|
6879
7479
|
|
|
6880
|
-
const
|
|
6881
|
-
const
|
|
6882
|
-
const
|
|
6883
|
-
const
|
|
7480
|
+
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
|
|
7481
|
+
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
|
|
7482
|
+
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
|
|
7483
|
+
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
|
|
6884
7484
|
|
|
6885
7485
|
device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
|
6886
7486
|
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
|
6887
7487
|
|
|
6888
|
-
for (
|
|
7488
|
+
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
|
|
6889
7489
|
T4x4 temp;
|
|
6890
7490
|
dequantize_func(src_data + i00/nl, i00%nl, temp);
|
|
6891
7491
|
dst_data[i00] = temp;
|
|
@@ -6896,12 +7496,14 @@ kernel void kernel_cpy_q_f32(
|
|
|
6896
7496
|
|
|
6897
7497
|
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
|
|
6898
7498
|
|
|
7499
|
+
template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8, dequantize_q1_0>;
|
|
6899
7500
|
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
|
|
6900
7501
|
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
|
|
6901
7502
|
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
|
|
6902
7503
|
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
|
|
6903
7504
|
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
|
|
6904
7505
|
|
|
7506
|
+
template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8, dequantize_q1_0>;
|
|
6905
7507
|
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
|
|
6906
7508
|
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
|
|
6907
7509
|
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
|
|
@@ -6919,7 +7521,11 @@ kernel void kernel_concat(
|
|
|
6919
7521
|
|
|
6920
7522
|
const int i3 = tgpig.z;
|
|
6921
7523
|
const int i2 = tgpig.y;
|
|
6922
|
-
const int i1 = tgpig.x;
|
|
7524
|
+
const int i1 = ntg.y == 1 ? tgpig.x : tgpig.x*ntg.y + tpitg.y;
|
|
7525
|
+
|
|
7526
|
+
if (i1 >= args.ne1) {
|
|
7527
|
+
return;
|
|
7528
|
+
}
|
|
6923
7529
|
|
|
6924
7530
|
int o[4] = {0, 0, 0, 0};
|
|
6925
7531
|
o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
|
|
@@ -6959,10 +7565,10 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
|
6959
7565
|
|
|
6960
7566
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
6961
7567
|
|
|
6962
|
-
const uint i12 = im%
|
|
6963
|
-
const uint i13 = im/
|
|
7568
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
7569
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
6964
7570
|
|
|
6965
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
7571
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
6966
7572
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
6967
7573
|
|
|
6968
7574
|
device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
|
|
@@ -7064,10 +7670,10 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
|
7064
7670
|
|
|
7065
7671
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
7066
7672
|
|
|
7067
|
-
const uint i12 = im%
|
|
7068
|
-
const uint i13 = im/
|
|
7673
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
7674
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
7069
7675
|
|
|
7070
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
7676
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
7071
7677
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
7072
7678
|
|
|
7073
7679
|
device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);
|
|
@@ -7238,10 +7844,10 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
|
7238
7844
|
|
|
7239
7845
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
7240
7846
|
|
|
7241
|
-
const uint i12 = im%
|
|
7242
|
-
const uint i13 = im/
|
|
7847
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
7848
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
7243
7849
|
|
|
7244
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
7850
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
7245
7851
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
7246
7852
|
|
|
7247
7853
|
device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
|
|
@@ -7350,10 +7956,10 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
|
7350
7956
|
|
|
7351
7957
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
7352
7958
|
|
|
7353
|
-
const uint i12 = im%
|
|
7354
|
-
const uint i13 = im/
|
|
7959
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
7960
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
7355
7961
|
|
|
7356
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
7962
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
7357
7963
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
7358
7964
|
|
|
7359
7965
|
device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
|
|
@@ -7486,10 +8092,10 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
|
7486
8092
|
|
|
7487
8093
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
7488
8094
|
|
|
7489
|
-
const uint i12 = im%
|
|
7490
|
-
const uint i13 = im/
|
|
8095
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
8096
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
7491
8097
|
|
|
7492
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
8098
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
7493
8099
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
7494
8100
|
|
|
7495
8101
|
device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
|
|
@@ -7591,10 +8197,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
|
7591
8197
|
|
|
7592
8198
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
7593
8199
|
|
|
7594
|
-
const uint i12 = im%
|
|
7595
|
-
const uint i13 = im/
|
|
8200
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
8201
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
7596
8202
|
|
|
7597
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
8203
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
7598
8204
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
7599
8205
|
|
|
7600
8206
|
device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
|
|
@@ -7699,10 +8305,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
|
7699
8305
|
|
|
7700
8306
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
7701
8307
|
|
|
7702
|
-
const uint i12 = im%
|
|
7703
|
-
const uint i13 = im/
|
|
8308
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
8309
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
7704
8310
|
|
|
7705
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
8311
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
7706
8312
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
7707
8313
|
|
|
7708
8314
|
device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
|
|
@@ -7818,10 +8424,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
|
7818
8424
|
|
|
7819
8425
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
7820
8426
|
|
|
7821
|
-
const uint i12 = im%
|
|
7822
|
-
const uint i13 = im/
|
|
8427
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
8428
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
7823
8429
|
|
|
7824
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
8430
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
7825
8431
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
7826
8432
|
|
|
7827
8433
|
device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
|
|
@@ -7930,10 +8536,10 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
|
7930
8536
|
|
|
7931
8537
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
7932
8538
|
|
|
7933
|
-
const uint i12 = im%
|
|
7934
|
-
const uint i13 = im/
|
|
8539
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
8540
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
7935
8541
|
|
|
7936
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
8542
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
7937
8543
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
7938
8544
|
|
|
7939
8545
|
device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
|
|
@@ -8042,10 +8648,10 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|
|
8042
8648
|
|
|
8043
8649
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
8044
8650
|
|
|
8045
|
-
const uint i12 = im%
|
|
8046
|
-
const uint i13 = im/
|
|
8651
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
8652
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
8047
8653
|
|
|
8048
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
8654
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
8049
8655
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
8050
8656
|
|
|
8051
8657
|
device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
|
|
@@ -8155,10 +8761,10 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
|
8155
8761
|
|
|
8156
8762
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
8157
8763
|
|
|
8158
|
-
const uint i12 = im%
|
|
8159
|
-
const uint i13 = im/
|
|
8764
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
8765
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
8160
8766
|
|
|
8161
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
8767
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
8162
8768
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
8163
8769
|
|
|
8164
8770
|
device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
|
|
@@ -8254,10 +8860,10 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
|
8254
8860
|
|
|
8255
8861
|
const int first_row = (r0 * NSG + sgitg) * nr0;
|
|
8256
8862
|
|
|
8257
|
-
const uint i12 = im%
|
|
8258
|
-
const uint i13 = im/
|
|
8863
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
8864
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
8259
8865
|
|
|
8260
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
8866
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
8261
8867
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
8262
8868
|
|
|
8263
8869
|
device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
|
|
@@ -8363,10 +8969,10 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
|
8363
8969
|
|
|
8364
8970
|
const int first_row = (r0 * NSG + sgitg) * NR0;
|
|
8365
8971
|
|
|
8366
|
-
const uint i12 = im%
|
|
8367
|
-
const uint i13 = im/
|
|
8972
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
8973
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
8368
8974
|
|
|
8369
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
8975
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
8370
8976
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
8371
8977
|
|
|
8372
8978
|
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
|
|
@@ -8472,10 +9078,10 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
|
8472
9078
|
const int im = tgpig.z;
|
|
8473
9079
|
const int first_row = (r0 * NSG + sgitg) * NR0;
|
|
8474
9080
|
|
|
8475
|
-
const uint i12 = im%
|
|
8476
|
-
const uint i13 = im/
|
|
9081
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
9082
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
8477
9083
|
|
|
8478
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
9084
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
8479
9085
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
8480
9086
|
|
|
8481
9087
|
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
|
|
@@ -8583,10 +9189,10 @@ void kernel_mul_mv_mxfp4_f32_impl(
|
|
|
8583
9189
|
|
|
8584
9190
|
const int first_row = (r0 * NSG + sgitg) * NR0;
|
|
8585
9191
|
|
|
8586
|
-
const uint i12 = im%
|
|
8587
|
-
const uint i13 = im/
|
|
9192
|
+
const uint i12 = im%FC_mul_mv_ne12;
|
|
9193
|
+
const uint i13 = im/FC_mul_mv_ne12;
|
|
8588
9194
|
|
|
8589
|
-
const uint64_t offset0 = first_row*args.nb01 + (i12/
|
|
9195
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
|
|
8590
9196
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
8591
9197
|
|
|
8592
9198
|
device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
|
|
@@ -8779,11 +9385,165 @@ kernel void kernel_set_rows_f(
|
|
|
8779
9385
|
}
|
|
8780
9386
|
}
|
|
8781
9387
|
|
|
9388
|
+
kernel void kernel_diag_f32(
|
|
9389
|
+
constant ggml_metal_kargs_diag & args,
|
|
9390
|
+
device const char * src0,
|
|
9391
|
+
device char * dst,
|
|
9392
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
9393
|
+
ushort tiitg[[thread_index_in_threadgroup]]) {
|
|
9394
|
+
constexpr short NW = N_SIMDWIDTH;
|
|
9395
|
+
|
|
9396
|
+
const int32_t i3 = tgpig.z;
|
|
9397
|
+
const int32_t i2 = tgpig.y;
|
|
9398
|
+
const int32_t i1 = tgpig.x;
|
|
9399
|
+
|
|
9400
|
+
device const float * src0_ptr = (device const float *)(src0 + i2*args.nb02 + i3*args.nb03);
|
|
9401
|
+
device float * dst_ptr = (device float *)(dst + i1*args.nb01 + i2*args.nb2 + i3*args.nb3);
|
|
9402
|
+
|
|
9403
|
+
for (int i0 = tiitg; i0 < args.ne0; i0 += NW) {
|
|
9404
|
+
dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
|
|
9405
|
+
}
|
|
9406
|
+
}
|
|
9407
|
+
|
|
8782
9408
|
constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
|
|
8783
9409
|
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
|
|
9410
|
+
constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]];
|
|
9411
|
+
constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]];
|
|
9412
|
+
constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]];
|
|
9413
|
+
constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]];
|
|
8784
9414
|
|
|
8785
9415
|
// each block_q contains 16*nl weights
|
|
8786
|
-
|
|
9416
|
+
#ifdef GGML_METAL_HAS_TENSOR
|
|
9417
|
+
template<
|
|
9418
|
+
typename SA, typename SA_4x4, typename SA_8x8,
|
|
9419
|
+
typename SB, typename SB_2x4, typename SB_8x8,
|
|
9420
|
+
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &),
|
|
9421
|
+
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
|
9422
|
+
kernel void kernel_mul_mm(
|
|
9423
|
+
constant ggml_metal_kargs_mul_mm & args,
|
|
9424
|
+
device const char * srcA,
|
|
9425
|
+
device const char * srcB,
|
|
9426
|
+
device char * dst,
|
|
9427
|
+
threadgroup char * shmem [[threadgroup(0)]],
|
|
9428
|
+
uint3 tgpig [[threadgroup_position_in_grid]],
|
|
9429
|
+
ushort tiitg [[thread_index_in_threadgroup]],
|
|
9430
|
+
ushort sgitg [[simdgroup_index_in_threadgroup]]) {
|
|
9431
|
+
(void) sgitg;
|
|
9432
|
+
|
|
9433
|
+
// Matrix dimensions: A(M,K) x B(K,N) -> C(M,N)
|
|
9434
|
+
const int K = args.ne00;
|
|
9435
|
+
const int M = args.ne0;
|
|
9436
|
+
const int N = args.ne1;
|
|
9437
|
+
|
|
9438
|
+
// Batch dimension handling
|
|
9439
|
+
const int im = tgpig.z;
|
|
9440
|
+
const int i12 = im % FC_mul_mm_ne12;
|
|
9441
|
+
const int i13 = im / FC_mul_mm_ne12;
|
|
9442
|
+
|
|
9443
|
+
// Batch offsets for srcA and srcB
|
|
9444
|
+
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
|
|
9445
|
+
|
|
9446
|
+
// Tile dimensions
|
|
9447
|
+
constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X;
|
|
9448
|
+
constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y;
|
|
9449
|
+
|
|
9450
|
+
// Tile offsets in output matrix
|
|
9451
|
+
const int ra = tgpig.y * NRA;
|
|
9452
|
+
const int rb = tgpig.x * NRB;
|
|
9453
|
+
|
|
9454
|
+
// Threadgroup memory for dequantized A tile only
|
|
9455
|
+
threadgroup SA * sa = (threadgroup SA *)(shmem);
|
|
9456
|
+
|
|
9457
|
+
// Work-item count for A loading
|
|
9458
|
+
constexpr int A_WORK_ITEMS = NRA * N_MM_NK;
|
|
9459
|
+
constexpr int NUM_THREADS = N_SIMDWIDTH * N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y;
|
|
9460
|
+
|
|
9461
|
+
// tA wraps threadgroup memory
|
|
9462
|
+
auto tA = tensor(sa, dextents<int32_t, 2>(N_MM_NK_TOTAL, NRA));
|
|
9463
|
+
|
|
9464
|
+
// tB wraps device memory directly
|
|
9465
|
+
device T1 * ptrB = (device T1 *)(srcB + args.nb12*i12 + args.nb13*i13);
|
|
9466
|
+
const int strideB = args.nb11 / sizeof(T1);
|
|
9467
|
+
auto tB = tensor(ptrB, dextents<int32_t, 2>(K, N), array<int, 2>({1, strideB}));
|
|
9468
|
+
|
|
9469
|
+
// Configure matmul operation
|
|
9470
|
+
mpp::tensor_ops::matmul2d<
|
|
9471
|
+
mpp::tensor_ops::matmul2d_descriptor(
|
|
9472
|
+
NRB, NRA, N_MM_NK_TOTAL, false, true, true,
|
|
9473
|
+
mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
|
|
9474
|
+
execution_simdgroups<N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y>> mm;
|
|
9475
|
+
|
|
9476
|
+
auto cT = mm.get_destination_cooperative_tensor<decltype(tB), decltype(tA), float>();
|
|
9477
|
+
|
|
9478
|
+
// Accumulate partial results over K dimension
|
|
9479
|
+
for (int loop_k = 0; loop_k < K; loop_k += N_MM_NK_TOTAL) {
|
|
9480
|
+
// === PHASE 1: Dequantization of A into threadgroup memory ===
|
|
9481
|
+
for (int work = tiitg; work < A_WORK_ITEMS; work += NUM_THREADS) {
|
|
9482
|
+
const int row = work / N_MM_NK;
|
|
9483
|
+
const int k_chunk = work % N_MM_NK;
|
|
9484
|
+
const int k_pos = loop_k + k_chunk * 16;
|
|
9485
|
+
const short k_base = k_chunk * 16;
|
|
9486
|
+
|
|
9487
|
+
// Bounds check: skip device read if row is out of matrix bounds
|
|
9488
|
+
if (ra + row < M) {
|
|
9489
|
+
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
|
9490
|
+
// Element-wise reads when K is not aligned (nb01 not aligned for half4x4/float4x4).
|
|
9491
|
+
// MSL spec Table 2.5: half4x4 requires 8-byte alignment. When K is odd,
|
|
9492
|
+
// nb01 = K*2 is not 8-byte aligned, so odd-row pointers are misaligned.
|
|
9493
|
+
// Mirrors the legacy kernel's existing guard.
|
|
9494
|
+
device const T0 * row_ptr = (device const T0 *)(srcA + args.nb01 * (ra + row) + offset0);
|
|
9495
|
+
|
|
9496
|
+
FOR_UNROLL (short i = 0; i < 16; i++) {
|
|
9497
|
+
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? (SA) row_ptr[k_pos + i] : (SA)0;
|
|
9498
|
+
}
|
|
9499
|
+
} else {
|
|
9500
|
+
const int block_idx = k_pos / (16 * nl);
|
|
9501
|
+
const short il = (k_pos / 16) % nl;
|
|
9502
|
+
|
|
9503
|
+
device const block_q * row_ptr = (device const block_q *)(srcA + args.nb01 * (ra + row) + offset0);
|
|
9504
|
+
|
|
9505
|
+
SA_4x4 temp_a;
|
|
9506
|
+
dequantize_func(row_ptr + block_idx, il, temp_a);
|
|
9507
|
+
|
|
9508
|
+
FOR_UNROLL (short i = 0; i < 16; i++) {
|
|
9509
|
+
// Zero-pad A for K positions beyond valid range (handles partial K iterations)
|
|
9510
|
+
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? temp_a[i/4][i%4] : (SA)0;
|
|
9511
|
+
}
|
|
9512
|
+
}
|
|
9513
|
+
} else {
|
|
9514
|
+
// Zero-pad rows beyond matrix bounds
|
|
9515
|
+
FOR_UNROLL (short i = 0; i < 16; i++) {
|
|
9516
|
+
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (SA)0;
|
|
9517
|
+
}
|
|
9518
|
+
}
|
|
9519
|
+
}
|
|
9520
|
+
|
|
9521
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
9522
|
+
|
|
9523
|
+
// === PHASE 2: Tensor matmul ===
|
|
9524
|
+
auto mA = tA.slice(0, 0);
|
|
9525
|
+
auto mB = tB.slice(loop_k, rb);
|
|
9526
|
+
|
|
9527
|
+
mm.run(mB, mA, cT);
|
|
9528
|
+
|
|
9529
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
9530
|
+
}
|
|
9531
|
+
|
|
9532
|
+
// Store result tile to output matrix (with batch offset)
|
|
9533
|
+
// cT.store handles bounds checking via tD's extents (M, N)
|
|
9534
|
+
device float * dstBatch = (device float *)dst + im * N * M;
|
|
9535
|
+
|
|
9536
|
+
auto tD = tensor(dstBatch, dextents<int32_t, 2>(M, N), array<int, 2>({1, M}));
|
|
9537
|
+
cT.store(tD.slice(ra, rb));
|
|
9538
|
+
}
|
|
9539
|
+
|
|
9540
|
+
#else
|
|
9541
|
+
|
|
9542
|
+
template<
|
|
9543
|
+
typename S0, typename S0_4x4, typename S0_8x8,
|
|
9544
|
+
typename S1, typename S1_2x4, typename S1_8x8,
|
|
9545
|
+
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &),
|
|
9546
|
+
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
|
8787
9547
|
kernel void kernel_mul_mm(
|
|
8788
9548
|
constant ggml_metal_kargs_mul_mm & args,
|
|
8789
9549
|
device const char * src0,
|
|
@@ -8797,8 +9557,6 @@ kernel void kernel_mul_mm(
|
|
|
8797
9557
|
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
|
8798
9558
|
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
|
8799
9559
|
|
|
8800
|
-
threadgroup float * sc = (threadgroup float *)(shmem);
|
|
8801
|
-
|
|
8802
9560
|
constexpr int NR0 = 64;
|
|
8803
9561
|
constexpr int NR1 = 32;
|
|
8804
9562
|
|
|
@@ -8822,10 +9580,10 @@ kernel void kernel_mul_mm(
|
|
|
8822
9580
|
|
|
8823
9581
|
short il = il0;
|
|
8824
9582
|
|
|
8825
|
-
const int i12 = im%
|
|
8826
|
-
const int i13 = im/
|
|
9583
|
+
const int i12 = im % FC_mul_mm_ne12;
|
|
9584
|
+
const int i13 = im / FC_mul_mm_ne12;
|
|
8827
9585
|
|
|
8828
|
-
const uint64_t offset0 = (i12/
|
|
9586
|
+
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
|
|
8829
9587
|
const short offset1 = il0/nl;
|
|
8830
9588
|
|
|
8831
9589
|
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
|
|
@@ -8838,7 +9596,6 @@ kernel void kernel_mul_mm(
|
|
|
8838
9596
|
+ args.nb11*(r1 + lr1)
|
|
8839
9597
|
+ args.nb10*iy);
|
|
8840
9598
|
|
|
8841
|
-
#ifndef GGML_METAL_HAS_TENSOR
|
|
8842
9599
|
S0_8x8 ma[4];
|
|
8843
9600
|
S1_8x8 mb[2];
|
|
8844
9601
|
|
|
@@ -8847,19 +9604,8 @@ kernel void kernel_mul_mm(
|
|
|
8847
9604
|
for (short i = 0; i < 8; i++){
|
|
8848
9605
|
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
|
8849
9606
|
}
|
|
8850
|
-
#else
|
|
8851
|
-
auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
|
|
8852
|
-
auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
|
|
8853
|
-
|
|
8854
|
-
mpp::tensor_ops::matmul2d<
|
|
8855
|
-
mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
|
|
8856
|
-
execution_simdgroups<4>> mm;
|
|
8857
|
-
|
|
8858
|
-
auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
|
|
8859
|
-
#endif
|
|
8860
9607
|
|
|
8861
9608
|
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
|
|
8862
|
-
#ifndef GGML_METAL_HAS_TENSOR
|
|
8863
9609
|
// load data and store to threadgroup memory
|
|
8864
9610
|
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
|
8865
9611
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
@@ -8920,8 +9666,8 @@ kernel void kernel_mul_mm(
|
|
|
8920
9666
|
const short sx = (tiitg%NL1);
|
|
8921
9667
|
const short sy = (tiitg/NL1)/8;
|
|
8922
9668
|
|
|
8923
|
-
|
|
8924
|
-
|
|
9669
|
+
//const short dx = sx;
|
|
9670
|
+
//const short dy = sy;
|
|
8925
9671
|
|
|
8926
9672
|
const short ly = (tiitg/NL1)%8;
|
|
8927
9673
|
|
|
@@ -8929,66 +9675,6 @@ kernel void kernel_mul_mm(
|
|
|
8929
9675
|
|
|
8930
9676
|
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
|
|
8931
9677
|
}
|
|
8932
|
-
#else
|
|
8933
|
-
// load data and store to threadgroup memory
|
|
8934
|
-
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
|
8935
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
8936
|
-
|
|
8937
|
-
// no need for dequantization
|
|
8938
|
-
for (short i = 0; i < 16; i++) {
|
|
8939
|
-
const short sx = 2*il0 + i/8;
|
|
8940
|
-
const short sy = (tiitg/NL0)/8;
|
|
8941
|
-
|
|
8942
|
-
const short lx = i%8;
|
|
8943
|
-
const short ly = (tiitg/NL0)%8;
|
|
8944
|
-
//const short lx = (tiitg/NL0)%8;
|
|
8945
|
-
//const short ly = i%8;
|
|
8946
|
-
|
|
8947
|
-
*(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
|
|
8948
|
-
}
|
|
8949
|
-
} else {
|
|
8950
|
-
S0_4x4 temp_a;
|
|
8951
|
-
dequantize_func(x, il, temp_a);
|
|
8952
|
-
|
|
8953
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
8954
|
-
|
|
8955
|
-
FOR_UNROLL (short i = 0; i < 16; i++) {
|
|
8956
|
-
const short sx = 2*il0 + i/8;
|
|
8957
|
-
const short sy = (tiitg/NL0)/8;
|
|
8958
|
-
|
|
8959
|
-
const short lx = i%8;
|
|
8960
|
-
const short ly = (tiitg/NL0)%8;
|
|
8961
|
-
//const short lx = (tiitg/NL0)%8;
|
|
8962
|
-
//const short ly = i%8;
|
|
8963
|
-
|
|
8964
|
-
*(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
|
|
8965
|
-
}
|
|
8966
|
-
}
|
|
8967
|
-
|
|
8968
|
-
if (FC_mul_mm_bc_inp) {
|
|
8969
|
-
for (short i = 0; i < 8; ++i) {
|
|
8970
|
-
const short sx = (tiitg%NL1);
|
|
8971
|
-
const short sy = (tiitg/NL1)/8;
|
|
8972
|
-
|
|
8973
|
-
const short lx = i;
|
|
8974
|
-
const short ly = (tiitg/NL1)%8;
|
|
8975
|
-
//const short lx = (tiitg/NL1)%8;
|
|
8976
|
-
//const short ly = i;
|
|
8977
|
-
|
|
8978
|
-
*(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
|
|
8979
|
-
}
|
|
8980
|
-
} else {
|
|
8981
|
-
const short sx = (tiitg%NL1);
|
|
8982
|
-
const short sy = (tiitg/NL1)/8;
|
|
8983
|
-
|
|
8984
|
-
//const short lx = i;
|
|
8985
|
-
const short ly = (tiitg/NL1)%8;
|
|
8986
|
-
//const short lx = (tiitg/NL1)%8;
|
|
8987
|
-
//const short ly = i;
|
|
8988
|
-
|
|
8989
|
-
*(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
|
|
8990
|
-
}
|
|
8991
|
-
#endif
|
|
8992
9678
|
|
|
8993
9679
|
il = (il + 2 < nl) ? il + 2 : il % 2;
|
|
8994
9680
|
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
|
@@ -8997,7 +9683,6 @@ kernel void kernel_mul_mm(
|
|
|
8997
9683
|
|
|
8998
9684
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
8999
9685
|
|
|
9000
|
-
#ifndef GGML_METAL_HAS_TENSOR
|
|
9001
9686
|
// load matrices from threadgroup memory and conduct outer products
|
|
9002
9687
|
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
|
|
9003
9688
|
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
|
|
@@ -9024,24 +9709,10 @@ kernel void kernel_mul_mm(
|
|
|
9024
9709
|
lsma += 8*64;
|
|
9025
9710
|
lsmb += 4*64;
|
|
9026
9711
|
}
|
|
9027
|
-
#else
|
|
9028
|
-
auto sA = tA.slice(0, 0);
|
|
9029
|
-
auto sB = tB.slice(0, 0);
|
|
9030
|
-
|
|
9031
|
-
mm.run(sB, sA, cT);
|
|
9032
|
-
#endif
|
|
9033
9712
|
}
|
|
9034
9713
|
|
|
9035
9714
|
if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
|
|
9036
9715
|
// if no bounds checks on the output are needed, we can directly write to device memory
|
|
9037
|
-
#ifdef GGML_METAL_HAS_TENSOR
|
|
9038
|
-
device float * C = (device float *) dst +
|
|
9039
|
-
r0 + \
|
|
9040
|
-
r1 * args.ne0 + im*args.ne1*args.ne0;
|
|
9041
|
-
|
|
9042
|
-
auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1));
|
|
9043
|
-
cT.store(tC);
|
|
9044
|
-
#else
|
|
9045
9716
|
device float * C = (device float *) dst +
|
|
9046
9717
|
(r0 + 32*(sgitg & 1)) + \
|
|
9047
9718
|
(r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
|
|
@@ -9049,21 +9720,15 @@ kernel void kernel_mul_mm(
|
|
|
9049
9720
|
for (short i = 0; i < 8; i++) {
|
|
9050
9721
|
simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
|
|
9051
9722
|
}
|
|
9052
|
-
#endif
|
|
9053
9723
|
} else {
|
|
9054
9724
|
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
|
9055
9725
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
9056
9726
|
|
|
9057
9727
|
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
|
|
9058
9728
|
|
|
9059
|
-
#ifdef GGML_METAL_HAS_TENSOR
|
|
9060
|
-
auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
|
|
9061
|
-
cT.store(tC);
|
|
9062
|
-
#else
|
|
9063
9729
|
for (short i = 0; i < 8; i++) {
|
|
9064
9730
|
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
|
|
9065
9731
|
}
|
|
9066
|
-
#endif
|
|
9067
9732
|
|
|
9068
9733
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
9069
9734
|
|
|
@@ -9089,6 +9754,8 @@ kernel void kernel_mul_mm(
|
|
|
9089
9754
|
}
|
|
9090
9755
|
}
|
|
9091
9756
|
|
|
9757
|
+
#endif // GGML_METAL_HAS_TENSOR
|
|
9758
|
+
|
|
9092
9759
|
template<short ne20> // n_expert_used
|
|
9093
9760
|
kernel void kernel_mul_mm_id_map0(
|
|
9094
9761
|
constant ggml_metal_kargs_mul_mm_id_map0 & args,
|
|
@@ -9153,6 +9820,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
|
|
|
9153
9820
|
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
|
|
9154
9821
|
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
|
|
9155
9822
|
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
|
|
9823
|
+
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
|
|
9156
9824
|
|
|
9157
9825
|
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
|
9158
9826
|
kernel void kernel_mul_mm_id(
|
|
@@ -9170,7 +9838,9 @@ kernel void kernel_mul_mm_id(
|
|
|
9170
9838
|
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
|
9171
9839
|
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
|
9172
9840
|
|
|
9841
|
+
#ifdef GGML_METAL_HAS_TENSOR
|
|
9173
9842
|
threadgroup float * sc = (threadgroup float *)(shmem);
|
|
9843
|
+
#endif
|
|
9174
9844
|
|
|
9175
9845
|
constexpr int NR0 = 64;
|
|
9176
9846
|
constexpr int NR1 = 32;
|
|
@@ -9261,7 +9931,7 @@ kernel void kernel_mul_mm_id(
|
|
|
9261
9931
|
|
|
9262
9932
|
const short ib = 8*sx + sy;
|
|
9263
9933
|
|
|
9264
|
-
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
|
|
9934
|
+
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? (S0) *((device T0 *) x + i) : (S0) 0;
|
|
9265
9935
|
}
|
|
9266
9936
|
} else {
|
|
9267
9937
|
S0_4x4 temp_a;
|
|
@@ -9305,8 +9975,8 @@ kernel void kernel_mul_mm_id(
|
|
|
9305
9975
|
const short sx = (tiitg%NL1);
|
|
9306
9976
|
const short sy = (tiitg/NL1)/8;
|
|
9307
9977
|
|
|
9308
|
-
|
|
9309
|
-
|
|
9978
|
+
//const short dx = sx;
|
|
9979
|
+
//const short dy = sy;
|
|
9310
9980
|
|
|
9311
9981
|
const short ly = (tiitg/NL1)%8;
|
|
9312
9982
|
|
|
@@ -9474,6 +10144,7 @@ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_ro
|
|
|
9474
10144
|
|
|
9475
10145
|
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
|
9476
10146
|
|
|
10147
|
+
template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8, dequantize_q1_0>;
|
|
9477
10148
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
|
|
9478
10149
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
|
|
9479
10150
|
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
|
|
@@ -9536,6 +10207,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m
|
|
|
9536
10207
|
#if defined(GGML_METAL_HAS_BF16)
|
|
9537
10208
|
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
|
9538
10209
|
#endif
|
|
10210
|
+
template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
|
|
9539
10211
|
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
|
9540
10212
|
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
|
9541
10213
|
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
|
@@ -9559,6 +10231,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
|
|
|
9559
10231
|
|
|
9560
10232
|
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
|
9561
10233
|
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
|
10234
|
+
template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
|
|
9562
10235
|
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
|
9563
10236
|
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
|
9564
10237
|
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
|
@@ -9591,6 +10264,7 @@ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_m
|
|
|
9591
10264
|
#if defined(GGML_METAL_HAS_BF16)
|
|
9592
10265
|
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
|
9593
10266
|
#endif
|
|
10267
|
+
template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
|
|
9594
10268
|
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
|
9595
10269
|
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
|
9596
10270
|
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
|
@@ -9614,6 +10288,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m
|
|
|
9614
10288
|
|
|
9615
10289
|
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
|
9616
10290
|
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
|
10291
|
+
template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
|
|
9617
10292
|
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
|
9618
10293
|
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
|
9619
10294
|
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
|
@@ -9768,6 +10443,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4
|
|
|
9768
10443
|
|
|
9769
10444
|
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
|
|
9770
10445
|
|
|
10446
|
+
template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0>>>;
|
|
9771
10447
|
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
|
|
9772
10448
|
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
|
|
9773
10449
|
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
|
|
@@ -9869,6 +10545,74 @@ kernel void kernel_pool_2d_avg_f32(
|
|
|
9869
10545
|
o_ptr[cur_oh * args.OW + cur_ow] = res;
|
|
9870
10546
|
}
|
|
9871
10547
|
|
|
10548
|
+
|
|
10549
|
+
// 1-D max pooling along the innermost (row) dimension, F32 -> F32.
//
// Launch layout: one thread per output element. The flat thread id is split
// into (r, col) with r = gid / OW and col = gid % OW; threads with
// gid >= args.np do nothing.
// NOTE(review): presumably the host sets np == (number of rows) * OW — confirm.
//
// Output (r, col) is the max of src[r*IW + ix] over the window
// ix in [col*s0 - p0, col*s0 - p0 + k0). Window positions outside [0, IW)
// (padding) are skipped, not treated as zero. If every position is padding,
// the result stays -INFINITY.
kernel void kernel_pool_1d_max_f32(
        constant ggml_metal_kargs_pool_1d & args,
        device const float * src,
        device       float * dst,
        uint gid [[thread_position_in_grid]]) {
    // Tail threads past the last output element.
    if (gid >= args.np) {
        return;
    }

    // Decompose the flat thread id into (row, output column).
    const int col = (int)gid % args.OW;
    const int r   = (int)gid / args.OW;

    // First input index covered by this window; negative when p0 shifts it into the padding.
    const int start = col * args.s0 - args.p0;

    // Element offsets of this row in the input and output tensors.
    const int in_base  = r * args.IW;
    const int out_base = r * args.OW;

    float best = -INFINITY;

    for (int k = 0; k < args.k0; ++k) {
        const int ix = start + k;
        // Only in-bounds positions contribute; padded positions are ignored.
        if (ix >= 0 && ix < args.IW) {
            const float x = src[in_base + ix];
            best = max(best, x);
        }
    }

    dst[out_base + col] = best;
}
|
|
10581
|
+
|
|
10582
|
+
// 1-D average pooling along the innermost (row) dimension, F32 -> F32.
//
// Launch layout: one thread per output element. The flat thread id is split
// into (r, col) with r = gid / OW and col = gid % OW; threads with
// gid >= args.np do nothing.
// NOTE(review): presumably the host sets np == (number of rows) * OW — confirm.
//
// Output (r, col) is the mean of src[r*IW + ix] over the in-bounds part of
// the window ix in [col*s0 - p0, col*s0 - p0 + k0). The divisor is the
// number of in-bounds positions, so padding does not dilute the average.
// A window made up entirely of padding produces 0.0f.
kernel void kernel_pool_1d_avg_f32(
        constant ggml_metal_kargs_pool_1d & args,
        device const float * src,
        device       float * dst,
        uint gid [[thread_position_in_grid]]) {
    // Tail threads past the last output element.
    if (gid >= args.np) {
        return;
    }

    // Decompose the flat thread id into (row, output column).
    const int col = (int)gid % args.OW;
    const int r   = (int)gid / args.OW;

    // First input index covered by this window; negative when p0 shifts it into the padding.
    const int start = col * args.s0 - args.p0;

    // Element offsets of this row in the input and output tensors.
    const int in_base  = r * args.IW;
    const int out_base = r * args.OW;

    float sum   = 0.0f;
    int   valid = 0;

    for (int k = 0; k < args.k0; ++k) {
        const int ix = start + k;
        // Only in-bounds positions are summed and counted.
        if (ix >= 0 && ix < args.IW) {
            sum   += src[in_base + ix];
            valid += 1;
        }
    }

    // Guard the all-padding window against a division by zero.
    float result = 0.0f;
    if (valid > 0) {
        result = sum / (float)valid;
    }

    dst[out_base + col] = result;
}
|
|
10615
|
+
|
|
9872
10616
|
kernel void kernel_opt_step_adamw_f32(
|
|
9873
10617
|
constant ggml_metal_kargs_opt_step_adamw & args,
|
|
9874
10618
|
device float * x,
|
|
@@ -9919,7 +10663,7 @@ kernel void kernel_opt_step_sgd_f32(
|
|
|
9919
10663
|
|
|
9920
10664
|
template<typename T>
|
|
9921
10665
|
kernel void kernel_memset(
|
|
9922
|
-
constant
|
|
10666
|
+
constant ggml_metal_kargs_memset & args,
|
|
9923
10667
|
device T * dst,
|
|
9924
10668
|
uint tpig[[thread_position_in_grid]]) {
|
|
9925
10669
|
dst[tpig] = args.val;
|