whispercpp 1.3.1 → 1.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +4 -3
- data/README.md +92 -31
- data/Rakefile +26 -7
- data/ext/.gitignore +5 -7
- data/ext/dependencies.rb +61 -0
- data/ext/extconf.rb +21 -198
- data/ext/options.rb +221 -0
- data/ext/ruby_whisper.c +159 -0
- data/ext/ruby_whisper.h +17 -2
- data/ext/ruby_whisper_context.c +641 -0
- data/ext/ruby_whisper_error.c +52 -0
- data/ext/ruby_whisper_model.c +232 -0
- data/ext/ruby_whisper_params.c +1301 -0
- data/ext/ruby_whisper_segment.c +143 -0
- data/ext/ruby_whisper_transcribe.cpp +87 -0
- data/ext/ruby_whisper_vad_params.c +288 -0
- data/ext/sources/.dockerignore +3 -0
- data/ext/sources/.github/workflows/bindings-ruby.yml +21 -0
- data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
- data/ext/sources/CMakeLists.txt +251 -0
- data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
- data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
- data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
- data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
- data/ext/sources/bindings/javascript/package.json +26 -0
- data/ext/sources/bindings/javascript/whisper.js +19 -0
- data/ext/sources/build-xcframework.sh +547 -0
- data/ext/sources/ci/run.sh +336 -0
- data/ext/sources/close-issue.yml +28 -0
- data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
- data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
- data/ext/sources/cmake/build-info.cmake +60 -0
- data/ext/sources/cmake/git-vars.cmake +22 -0
- data/ext/sources/cmake/whisper-config.cmake.in +65 -0
- data/ext/sources/cmake/whisper.pc.in +10 -0
- data/ext/sources/examples/CMakeLists.txt +124 -0
- data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +37 -0
- data/ext/sources/examples/addon.node/addon.cpp +438 -0
- data/ext/sources/examples/addon.node/index.js +54 -0
- data/ext/sources/examples/addon.node/package.json +16 -0
- data/ext/sources/examples/bench/CMakeLists.txt +8 -0
- data/ext/sources/examples/bench/bench.cpp +175 -0
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
- data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
- data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
- data/ext/sources/examples/cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/cli/cli.cpp +1294 -0
- data/ext/sources/examples/coi-serviceworker.js +146 -0
- data/ext/sources/examples/command/CMakeLists.txt +10 -0
- data/ext/sources/examples/command/command.cpp +776 -0
- data/ext/sources/examples/command/commands.txt +9 -0
- data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
- data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
- data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
- data/ext/sources/examples/common-ggml.cpp +238 -0
- data/ext/sources/examples/common-ggml.h +18 -0
- data/ext/sources/examples/common-sdl.cpp +227 -0
- data/ext/sources/examples/common-sdl.h +49 -0
- data/ext/sources/examples/common-whisper.cpp +168 -0
- data/ext/sources/examples/common-whisper.h +24 -0
- data/ext/sources/examples/common.cpp +675 -0
- data/ext/sources/examples/common.h +322 -0
- data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
- data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
- data/ext/sources/examples/generate-karaoke.sh +57 -0
- data/ext/sources/examples/grammar-parser.cpp +423 -0
- data/ext/sources/examples/grammar-parser.h +29 -0
- data/ext/sources/examples/helpers.js +191 -0
- data/ext/sources/examples/json.hpp +24596 -0
- data/ext/sources/examples/livestream.sh +112 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
- data/ext/sources/examples/lsp/lsp.cpp +467 -0
- data/ext/sources/examples/lsp/whisper.vim +362 -0
- data/ext/sources/examples/miniaudio.h +93468 -0
- data/ext/sources/examples/python/test_whisper_processor.py +7 -0
- data/ext/sources/examples/python/whisper_processor.py +54 -0
- data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
- data/ext/sources/examples/quantize/quantize.cpp +223 -0
- data/ext/sources/examples/server/CMakeLists.txt +12 -0
- data/ext/sources/examples/server/bench.js +29 -0
- data/ext/sources/examples/server/httplib.h +10497 -0
- data/ext/sources/examples/server/server.cpp +1091 -0
- data/ext/sources/examples/server.py +115 -0
- data/ext/sources/examples/stb_vorbis.c +5584 -0
- data/ext/sources/examples/stream/CMakeLists.txt +10 -0
- data/ext/sources/examples/stream/stream.cpp +429 -0
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
- data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
- data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
- data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
- data/ext/sources/examples/sycl/build.sh +22 -0
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
- data/ext/sources/examples/sycl/run-whisper.sh +17 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +40 -0
- data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
- data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +1746 -0
- data/ext/sources/examples/talk-llama/llama-arch.h +437 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +374 -0
- data/ext/sources/examples/talk-llama/llama-batch.h +89 -0
- data/ext/sources/examples/talk-llama/llama-chat.cpp +663 -0
- data/ext/sources/examples/talk-llama/llama-chat.h +58 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +2676 -0
- data/ext/sources/examples/talk-llama/llama-context.h +276 -0
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
- data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
- data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
- data/ext/sources/examples/talk-llama/llama-graph.cpp +1618 -0
- data/ext/sources/examples/talk-llama/llama-graph.h +640 -0
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +95 -0
- data/ext/sources/examples/talk-llama/llama-hparams.h +190 -0
- data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
- data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
- data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
- data/ext/sources/examples/talk-llama/llama-io.h +35 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2739 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +502 -0
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +379 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +32 -0
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
- data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1138 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +281 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +13814 -0
- data/ext/sources/examples/talk-llama/llama-model.h +425 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +966 -0
- data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
- data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +3340 -0
- data/ext/sources/examples/talk-llama/llama-vocab.h +131 -0
- data/ext/sources/examples/talk-llama/llama.cpp +354 -0
- data/ext/sources/examples/talk-llama/llama.h +1377 -0
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
- data/ext/sources/examples/talk-llama/speak +40 -0
- data/ext/sources/examples/talk-llama/speak.bat +1 -0
- data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
- data/ext/sources/examples/talk-llama/talk-llama.cpp +808 -0
- data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
- data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +849 -0
- data/ext/sources/examples/talk-llama/unicode.h +66 -0
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +143 -0
- data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
- data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +249 -0
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
- data/ext/sources/ggml/CMakeLists.txt +390 -0
- data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
- data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
- data/ext/sources/ggml/cmake/common.cmake +26 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
- data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
- data/ext/{ggml → sources/ggml}/include/ggml-backend.h +9 -7
- data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
- data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +9 -1
- data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
- data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
- data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
- data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
- data/ext/{ggml → sources/ggml}/include/ggml.h +182 -265
- data/ext/sources/ggml/include/gguf.h +202 -0
- data/ext/sources/ggml/src/CMakeLists.txt +346 -0
- data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
- data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +87 -53
- data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +26 -14
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +74 -0
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
- data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
- data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +135 -1
- data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +564 -146
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
- data/ext/{ggml → sources/ggml}/src/ggml-common.h +12 -8
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +504 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +2 -1
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +72 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/cpu-feats-x86.cpp +5 -1
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +6431 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +163 -41
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.c +4029 -1117
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3510 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +67 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3544 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +8903 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.h +110 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +252 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +818 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
- data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
- data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
- data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
- data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +828 -0
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
- data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +730 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +26 -0
- data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
- data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
- data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
- data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1471 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +634 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3505 -0
- data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
- data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +336 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
- data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +153 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
- data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +289 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +59 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
- data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
- data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +131 -0
- data/ext/{ggml → sources/ggml}/src/ggml-impl.h +64 -19
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +120 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +622 -0
- data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2178 -1064
- data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +1575 -1218
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
- data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +96 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +5124 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
- data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
- data/ext/{ggml → sources/ggml}/src/ggml-quants.c +114 -120
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
- data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +480 -73
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +345 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +589 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +32 -33
- data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +4 -2
- data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/convert.cpp +104 -28
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +700 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +791 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +156 -17
- data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2957 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1511 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +75 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +99 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1004 -1240
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
- data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +0 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmvq.cpp +261 -166
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/norm.cpp +204 -81
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
- data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +83 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +361 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +35 -25
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1215 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +293 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +196 -0
- data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
- data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3130 -1087
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +52 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
- data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
- data/ext/{ggml → sources/ggml}/src/ggml.c +676 -1820
- data/ext/sources/ggml/src/gguf.cpp +1330 -0
- data/ext/{include → sources/include}/whisper.h +68 -2
- data/ext/sources/src/CMakeLists.txt +143 -0
- data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
- data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +35 -10
- data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
- data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +28 -3
- data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
- data/ext/sources/src/whisper-arch.h +197 -0
- data/ext/{src → sources/src}/whisper.cpp +1905 -374
- data/ext/sources/tests/CMakeLists.txt +105 -0
- data/ext/sources/tests/earnings21/eval.mk +58 -0
- data/ext/sources/tests/earnings21/eval.py +68 -0
- data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
- data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
- data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
- data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
- data/ext/sources/tests/earnings21/requirements.txt +6 -0
- data/ext/sources/tests/en-0-ref.txt +1 -0
- data/ext/sources/tests/en-1-ref.txt +1 -0
- data/ext/sources/tests/en-2-ref.txt +1 -0
- data/ext/sources/tests/es-0-ref.txt +1 -0
- data/ext/sources/tests/librispeech/eval.mk +39 -0
- data/ext/sources/tests/librispeech/eval.py +47 -0
- data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
- data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
- data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
- data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
- data/ext/sources/tests/librispeech/requirements.txt +6 -0
- data/ext/sources/tests/run-tests.sh +130 -0
- data/ext/sources/tests/test-c.c +3 -0
- data/ext/sources/tests/test-vad-full.cpp +54 -0
- data/ext/sources/tests/test-vad.cpp +83 -0
- data/ext/sources/tests/test-whisper.js +58 -0
- data/extsources.rb +33 -5
- data/lib/whisper/model/uri.rb +149 -128
- data/sig/whisper.rbs +480 -0
- data/tests/helper.rb +28 -0
- data/tests/test_callback.rb +45 -3
- data/tests/test_error.rb +2 -2
- data/tests/test_model.rb +38 -0
- data/tests/test_package.rb +18 -3
- data/tests/test_params.rb +145 -8
- data/tests/test_segment.rb +10 -19
- data/tests/test_vad.rb +19 -0
- data/tests/test_vad_params.rb +103 -0
- data/tests/test_whisper.rb +37 -37
- data/whispercpp.gemspec +5 -4
- metadata +766 -111
- data/ext/cpu.mk +0 -9
- data/ext/examples/dr_wav.h +0 -8815
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
- data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
- data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
- data/ext/metal-embed.mk +0 -17
- data/ext/metal.mk +0 -6
- data/ext/ruby_whisper.cpp +0 -1909
- data/ext/scripts/get-flags.mk +0 -38
- data/lib/whisper.rb +0 -2
- /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-aarch64.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
- /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
- /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
- /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
@@ -19,7 +19,17 @@
|
|
19
19
|
// max number of MTLCommandBuffer used to submit a graph for processing
|
20
20
|
#define GGML_METAL_MAX_COMMAND_BUFFERS 8
|
21
21
|
|
22
|
-
#
|
22
|
+
#ifndef TARGET_OS_VISION
|
23
|
+
#define TARGET_OS_VISION 0
|
24
|
+
#endif
|
25
|
+
|
26
|
+
// create residency sets only on macOS >= 15.0
|
27
|
+
#if !TARGET_CPU_X86_64 && TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \
|
28
|
+
TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \
|
29
|
+
TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \
|
30
|
+
TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000
|
31
|
+
#define GGML_METAL_HAS_RESIDENCY_SETS 1
|
32
|
+
#endif
|
23
33
|
|
24
34
|
// globals
|
25
35
|
|
@@ -34,11 +44,13 @@ static struct ggml_backend_device g_ggml_backend_metal_device;
|
|
34
44
|
// note: assumes single GPU device - the default one
|
35
45
|
// TODO: support multiple GPU devices
|
36
46
|
static struct ggml_backend_metal_device_context {
|
37
|
-
id<MTLDevice>
|
38
|
-
int
|
47
|
+
id<MTLDevice> mtl_device;
|
48
|
+
int mtl_device_ref_count;
|
49
|
+
id<MTLLibrary> mtl_library;
|
39
50
|
|
40
51
|
bool has_simdgroup_reduction;
|
41
52
|
bool has_simdgroup_mm;
|
53
|
+
bool has_residency_sets;
|
42
54
|
bool has_bfloat;
|
43
55
|
bool use_bfloat;
|
44
56
|
|
@@ -46,8 +58,10 @@ static struct ggml_backend_metal_device_context {
|
|
46
58
|
} g_ggml_ctx_dev_main = {
|
47
59
|
/*.mtl_device =*/ nil,
|
48
60
|
/*.mtl_device_ref_count =*/ 0,
|
61
|
+
/*.mtl_library =*/ nil,
|
49
62
|
/*.has_simdgroup_reduction =*/ false,
|
50
63
|
/*.has_simdgroup_mm =*/ false,
|
64
|
+
/*.has_residency_sets =*/ false,
|
51
65
|
/*.has_bfloat =*/ false,
|
52
66
|
/*.use_bfloat =*/ false,
|
53
67
|
/*.name =*/ "",
|
@@ -59,12 +73,18 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
|
|
59
73
|
|
60
74
|
if (ctx->mtl_device == nil) {
|
61
75
|
ctx->mtl_device = MTLCreateSystemDefaultDevice();
|
76
|
+
}
|
62
77
|
|
78
|
+
if (ctx->mtl_device) {
|
63
79
|
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
64
80
|
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
65
81
|
|
66
82
|
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
67
83
|
|
84
|
+
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
|
85
|
+
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL;
|
86
|
+
#endif
|
87
|
+
|
68
88
|
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
69
89
|
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
|
70
90
|
|
@@ -90,8 +110,15 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
|
|
90
110
|
ctx->mtl_device_ref_count--;
|
91
111
|
|
92
112
|
if (ctx->mtl_device_ref_count == 0) {
|
93
|
-
|
94
|
-
|
113
|
+
if (ctx->mtl_library) {
|
114
|
+
[ctx->mtl_library release];
|
115
|
+
ctx->mtl_library = nil;
|
116
|
+
}
|
117
|
+
|
118
|
+
if (ctx->mtl_device) {
|
119
|
+
[ctx->mtl_device release];
|
120
|
+
ctx->mtl_device = nil;
|
121
|
+
}
|
95
122
|
}
|
96
123
|
}
|
97
124
|
|
@@ -122,6 +149,8 @@ enum ggml_metal_kernel_type {
|
|
122
149
|
GGML_METAL_KERNEL_TYPE_SIGMOID,
|
123
150
|
GGML_METAL_KERNEL_TYPE_GELU,
|
124
151
|
GGML_METAL_KERNEL_TYPE_GELU_4,
|
152
|
+
GGML_METAL_KERNEL_TYPE_GELU_ERF,
|
153
|
+
GGML_METAL_KERNEL_TYPE_GELU_ERF_4,
|
125
154
|
GGML_METAL_KERNEL_TYPE_GELU_QUICK,
|
126
155
|
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
127
156
|
GGML_METAL_KERNEL_TYPE_SILU,
|
@@ -157,10 +186,13 @@ enum ggml_metal_kernel_type {
|
|
157
186
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
158
187
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
159
188
|
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
189
|
+
GGML_METAL_KERNEL_TYPE_L2_NORM,
|
160
190
|
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
161
191
|
GGML_METAL_KERNEL_TYPE_NORM,
|
162
192
|
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
163
193
|
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
194
|
+
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
|
195
|
+
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
|
164
196
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
165
197
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
166
198
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
@@ -276,30 +308,36 @@ enum ggml_metal_kernel_type {
|
|
276
308
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
|
277
309
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
|
278
310
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
311
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
|
312
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
|
313
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
|
314
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
|
315
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
|
316
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16,
|
317
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16,
|
318
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
|
319
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
|
320
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
|
321
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
|
322
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
|
323
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
|
324
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16,
|
325
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16,
|
326
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16,
|
327
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16,
|
328
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16,
|
329
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16,
|
330
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16,
|
331
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16,
|
332
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16,
|
333
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16,
|
334
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
|
301
335
|
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
|
302
336
|
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
|
337
|
+
GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
|
338
|
+
GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
|
339
|
+
GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
|
340
|
+
GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
|
303
341
|
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
|
304
342
|
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
|
305
343
|
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
@@ -321,43 +359,78 @@ enum ggml_metal_kernel_type {
|
|
321
359
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
322
360
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
323
361
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
362
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
|
363
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
|
324
364
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
365
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
|
325
366
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
|
326
367
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
|
327
368
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
|
328
369
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
|
329
370
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
|
371
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
|
372
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
|
330
373
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
|
374
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
|
331
375
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
332
376
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
333
377
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
334
378
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
|
335
379
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
|
380
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
|
381
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
|
336
382
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
|
383
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
|
337
384
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
|
338
385
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
|
339
386
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
|
340
387
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
|
341
388
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
|
389
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
|
390
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
|
342
391
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
|
392
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
|
343
393
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
|
344
394
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
|
345
395
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
|
346
396
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
|
347
397
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
|
398
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
|
399
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
|
348
400
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
|
401
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
|
349
402
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
|
350
403
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
|
351
404
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
|
352
405
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
|
353
406
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
|
407
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
|
408
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
|
354
409
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
|
410
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
|
355
411
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
|
356
412
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
|
357
413
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
|
358
414
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
|
359
415
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
|
416
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
|
417
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
|
360
418
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
419
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
|
420
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
|
421
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
|
422
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
|
423
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64,
|
424
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64,
|
425
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64,
|
426
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64,
|
427
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
|
428
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
|
429
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
|
430
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96,
|
431
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96,
|
432
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96,
|
433
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96,
|
361
434
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
362
435
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
|
363
436
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
|
@@ -365,6 +438,20 @@ enum ggml_metal_kernel_type {
|
|
365
438
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
|
366
439
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
|
367
440
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
|
441
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192,
|
442
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192,
|
443
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192,
|
444
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192,
|
445
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192,
|
446
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192,
|
447
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192,
|
448
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128,
|
449
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128,
|
450
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128,
|
451
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128,
|
452
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128,
|
453
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128,
|
454
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128,
|
368
455
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
|
369
456
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
|
370
457
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
|
@@ -372,6 +459,13 @@ enum ggml_metal_kernel_type {
|
|
372
459
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
|
373
460
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
|
374
461
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
|
462
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,
|
463
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,
|
464
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,
|
465
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,
|
466
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
|
467
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
|
468
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
|
375
469
|
GGML_METAL_KERNEL_TYPE_SET_I32,
|
376
470
|
GGML_METAL_KERNEL_TYPE_SET_F32,
|
377
471
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
@@ -387,11 +481,22 @@ enum ggml_metal_kernel_type {
|
|
387
481
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
|
388
482
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
|
389
483
|
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
|
484
|
+
GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
|
485
|
+
GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,
|
486
|
+
GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
|
487
|
+
GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,
|
488
|
+
GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
|
489
|
+
GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,
|
490
|
+
GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
|
491
|
+
GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,
|
492
|
+
GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
|
493
|
+
GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
|
390
494
|
GGML_METAL_KERNEL_TYPE_CONCAT,
|
391
495
|
GGML_METAL_KERNEL_TYPE_SQR,
|
392
496
|
GGML_METAL_KERNEL_TYPE_SQRT,
|
393
497
|
GGML_METAL_KERNEL_TYPE_SIN,
|
394
498
|
GGML_METAL_KERNEL_TYPE_COS,
|
499
|
+
GGML_METAL_KERNEL_TYPE_NEG,
|
395
500
|
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
396
501
|
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
397
502
|
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
@@ -400,7 +505,264 @@ enum ggml_metal_kernel_type {
|
|
400
505
|
GGML_METAL_KERNEL_TYPE_COUNT
|
401
506
|
};
|
402
507
|
|
508
|
+
//
|
509
|
+
// ggml_metal_heap
|
510
|
+
//
|
511
|
+
|
512
|
+
// Bookkeeping for one MTLHeap that is used as a simple bump allocator for temporary buffers.
struct ggml_metal_heap {
    // number of consecutive ggml_metal_heap_reset calls in which the heap held no buffers
    int n_unused;

    // total number of buffer allocations in this heap across all computes
    int64_t n_alloc;

    // current offset in the heap; rewound to 0 by ggml_metal_heap_reset and
    // ggml_metal_mem_pool_clear so the memory can be reused
    size_t offs;

    // the underlying Metal heap that backs every buffer placed in it
    id<MTLHeap> obj;

    // the currently allocated MTLBuffer objects in this heap; each entry also carries the
    // +1 reference returned by newBufferWithLength:options:offset:, dropped in ggml_metal_heap_reset
    NSMutableArray * bufs;
};
|
527
|
+
|
528
|
+
// Create a private-storage placement MTLHeap of `size` bytes together with its bookkeeping.
//
// Returns a calloc'd struct that the caller releases with ggml_metal_heap_free, or NULL when
// either the struct or the Metal heap cannot be allocated. No resources leak on failure.
static struct ggml_metal_heap * ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
    struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap));
    if (heap == NULL) {
        GGML_LOG_ERROR("%s: error: failed to allocate heap struct\n", __func__);
        return NULL;
    }

    MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
    desc.storageMode  = MTLStorageModePrivate;
    desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
    desc.type         = MTLHeapTypePlacement;
    desc.size         = size;

    heap->n_unused = 0;
    heap->n_alloc  = 0;

    heap->obj = [device newHeapWithDescriptor:desc];

    // the descriptor is only needed to create the heap - release it on every path
    [desc release];

    if (!heap->obj) {
        GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);

        free(heap);

        // the function returns a pointer, so signal failure with NULL (not `false`)
        return NULL;
    }

    heap->bufs = [[NSMutableArray alloc] init];

    return heap;
}
|
555
|
+
|
556
|
+
// Return every buffer placed in `heap` and rewind its placement cursor.
//
// Also updates the idle counter: it is cleared when the heap held buffers since the last
// reset, and incremented otherwise. Finally the heap is marked volatile so the OS may
// reclaim its memory until the next allocation marks it non-volatile again.
static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) {
    NSMutableArray * live = heap->bufs;

    // rewind the bump cursor so the whole heap can be handed out again
    heap->offs = 0;

    // track consecutive resets in which this heap served no allocations
    heap->n_unused = ([live count] == 0) ? heap->n_unused + 1 : 0;

    // drop the +1 reference that newBufferWithLength:options:offset: gave us for each buffer;
    // the array still holds its own reference until removeAllObjects below
    for (NSUInteger k = 0; k < [live count]; ++k) {
        id<MTLBuffer> buf = [live objectAtIndex:k];
        [buf release];
    }
    [live removeAllObjects];

    // allow the OS to reuse this memory while the heap is idle
    // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
    [heap->obj setPurgeableState:MTLPurgeableStateVolatile];
}
|
575
|
+
|
576
|
+
// Destroy a heap created by ggml_metal_heap_init. Passing NULL is a no-op.
//
// Any buffers still placed in the heap are released first (via ggml_metal_heap_reset),
// then the MTLHeap, the buffer list and the struct itself.
static void ggml_metal_heap_free(struct ggml_metal_heap * heap) {
    if (heap != NULL) {
        ggml_metal_heap_reset(heap);

        [heap->obj  release];
        [heap->bufs release];

        free(heap);
    }
}
|
588
|
+
|
589
|
+
// Minimal Objective-C box around a `struct ggml_metal_heap *`, so that heaps can be
// stored in an NSMutableArray (which only holds objects).
//
// The box does not own `data` (the property is `assign` and there is no -dealloc):
// owners call ggml_metal_heap_free(box.data) explicitly before releasing the box.
@interface ggml_metal_heap_ptr : NSObject
@property (nonatomic, assign) struct ggml_metal_heap * data;
@end

@implementation ggml_metal_heap_ptr
@end
|
597
|
+
|
598
|
+
//
|
599
|
+
// ggml_metal_mem_pool
|
600
|
+
//
|
601
|
+
|
602
|
+
// A growable set of MTLHeaps from which temporary buffers are sub-allocated.
struct ggml_metal_mem_pool {
    // device on which new heaps are created (assigned by the owner after init)
    id<MTLDevice> device;

    // running count of heaps ever created, including ones already evicted
    int n_heaps;

    // live heaps, each wrapped in a ggml_metal_heap_ptr
    NSMutableArray * heaps;

    // scratch list of indices into `heaps` scheduled for eviction during a reset
    NSMutableArray * heaps_to_remove;
};
|
610
|
+
|
611
|
+
// Allocate an empty memory pool. `device` is left nil (zeroed by calloc);
// the caller is expected to assign it before the first allocation.
static struct ggml_metal_mem_pool * ggml_metal_mem_pool_init(void) {
    struct ggml_metal_mem_pool * pool = calloc(1, sizeof(*pool));

    pool->n_heaps         = 0;
    pool->heaps           = [[NSMutableArray alloc] init];
    pool->heaps_to_remove = [[NSMutableArray alloc] init];

    return pool;
}
|
621
|
+
|
622
|
+
// Tear down a pool: log per-heap statistics, free every heap and its wrapper, then
// release the pool's arrays and the pool struct itself.
static void ggml_metal_mem_pool_free(struct ggml_metal_mem_pool * mem_pool) {
    NSMutableArray * heaps = mem_pool->heaps;

    GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [heaps count], mem_pool->n_heaps);

    // bytes held by all heaps vs. bytes held by heaps that still had live buffers
    size_t size_all = 0;
    size_t size_cur = 0;

    const NSUInteger n = [heaps count];
    for (NSUInteger k = 0; k < n; ++k) {
        ggml_metal_heap_ptr    * ptr = [heaps objectAtIndex:k];
        struct ggml_metal_heap * h   = ptr.data;

        const size_t     h_size = [h->obj size];
        const NSUInteger n_buf  = [h->bufs count];

        GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) h);
        GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, h->n_alloc);
        GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, h->n_unused);
        GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, h_size / 1024.0 / 1024.0);
        GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, n_buf);

        size_all += h_size;
        if (n_buf > 0) {
            size_cur += h_size;
        }

        // the wrapper does not own the heap, so free it explicitly before dropping the wrapper
        ggml_metal_heap_free(h);
        [ptr release];
    }

    [heaps release];
    [mem_pool->heaps_to_remove release];

    if (size_all > 0) {
        GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0);
        GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0);
    }

    free(mem_pool);
}
|
653
|
+
|
654
|
+
// Reset every heap in the pool (releasing its buffers), then evict heaps that have been
// idle for 128 or more consecutive resets.
static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) {
    NSMutableArray * heaps  = mem_pool->heaps;
    NSMutableArray * doomed = mem_pool->heaps_to_remove;

    // pass 1: reset each heap and record the indices of the long-idle ones
    NSUInteger idx = 0;
    for (ggml_metal_heap_ptr * ptr in heaps) {
        struct ggml_metal_heap * h = ptr.data;
        ggml_metal_heap_reset(h);

        if (h->n_unused >= 128) {
            [doomed addObject:@(idx)];
        }
        ++idx;
    }

    // pass 2: evict back to front so the smaller recorded indices stay valid
    NSUInteger k = [doomed count];
    while (k > 0) {
        --k;

        NSUInteger victim = [[doomed objectAtIndex:k] intValue];
        ggml_metal_heap_ptr * ptr = [heaps objectAtIndex:victim];

        ggml_metal_heap_free(ptr.data);

        [heaps removeObjectAtIndex:victim];
        [ptr release];
    }

    [doomed removeAllObjects];
}
|
687
|
+
|
688
|
+
// Rewind the placement cursor of every heap in the pool. Unlike
// ggml_metal_mem_pool_reset, buffers already handed out are NOT released here.
static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
    NSArray * heaps = mem_pool->heaps;

    for (NSUInteger k = 0; k < [heaps count]; ++k) {
        ggml_metal_heap_ptr * ptr = [heaps objectAtIndex:k];
        ptr.data->offs = 0;
    }
}
|
693
|
+
|
694
|
+
// Sub-allocate a private-storage MTLBuffer of at least `size` bytes from the pool.
//
// The request is rounded up to a multiple of 256 bytes. Existing heaps are tried first
// (bump allocation at `offs`). If none has room, a new heap that fits the request is
// created and appended to the pool.
//
// The returned buffer is also tracked in its heap's `bufs` list and is released by the
// next ggml_metal_heap_reset. Returns nil on failure without leaking any partially
// created heap or wrapper.
static id<MTLBuffer> ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) {
    const size_t alignment = 256;

    const size_t size_aligned = GGML_PAD(size, alignment);

    // try one of the existing heaps
    for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
        struct ggml_metal_heap * heap = ptr.data;
        if (heap->offs + size_aligned <= [heap->obj size]) {
            // if this is the first buffer in the heap for the current command buffer, tell the OS that
            // it cannot free the memory used by the heap
            // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
            if ([heap->bufs count] == 0) {
                [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
            }

            id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
            if (buf == nil) {
                GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
                return nil;
            }

            heap->n_alloc++;
            heap->offs += size_aligned;

            [heap->bufs addObject:buf];

            return buf;
        }
    }

    // create a new heap that can fit this buffer
    // NOTE(review): the heap is sized exactly `size_aligned`; if Metal needs more for the
    // placement (see heapBufferSizeAndAlignWithLength:options:) the buffer creation below
    // fails. That path is now leak-free, but the sizing may deserve a follow-up.
    struct ggml_metal_heap * heap = ggml_metal_heap_init(mem_pool->device, size_aligned);
    if (heap == NULL) {
        GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned);
        return nil;
    }

    //GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]);

    ggml_metal_heap_reset(heap);

    [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
    id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
    if (buf == nil) {
        GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
        // the heap was never added to the pool - free it here instead of leaking it
        ggml_metal_heap_free(heap);
        return nil;
    }

    heap->n_alloc++;
    heap->offs += size_aligned;

    [heap->bufs addObject:buf];

    // wrap the heap only once everything succeeded; the pool's array keeps the wrapper,
    // and the +1 reference from `new` is dropped when the heap is evicted or the pool is freed
    ggml_metal_heap_ptr * heap_ptr = [ggml_metal_heap_ptr new];
    heap_ptr.data = heap;

    [mem_pool->heaps addObject:heap_ptr];
    mem_pool->n_heaps++;

    return buf;
}
|
756
|
+
|
757
|
+
// A Metal command buffer paired with the pool it draws temporary buffers from.
struct ggml_metal_command_buffer {
    id<MTLCommandBuffer> obj;                 // the command buffer itself (nil until created)
    struct ggml_metal_mem_pool * mem_pool;    // per-command-buffer scratch memory for the compute
};
|
763
|
+
|
403
764
|
struct ggml_backend_metal_context {
|
765
|
+
id<MTLDevice> device;
|
404
766
|
id<MTLCommandQueue> queue;
|
405
767
|
|
406
768
|
dispatch_queue_t d_queue;
|
@@ -425,7 +787,7 @@ struct ggml_backend_metal_context {
|
|
425
787
|
void (^encode_async)(size_t ith);
|
426
788
|
|
427
789
|
// n_cb command buffers + 1 used by the main thread
|
428
|
-
|
790
|
+
struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
|
429
791
|
|
430
792
|
// abort ggml_metal_graph_compute if callback returns true
|
431
793
|
ggml_abort_callback abort_callback;
|
@@ -437,11 +799,13 @@ struct ggml_backend_metal_context {
|
|
437
799
|
// for now it is easier to work in a separate file
|
438
800
|
// static NSString * const msl_library_source = @"see metal.metal";
|
439
801
|
|
802
|
+
#if !GGML_METAL_EMBED_LIBRARY
// Empty marker class. When the Metal library is not embedded,
// [NSBundle bundleForClass:[GGMLMetalClass class]] uses it to find the bundle that
// contains this code, so default.metallib / ggml-metal.metal can be located (NSBundle path hack).
@interface GGMLMetalClass : NSObject
@end

@implementation GGMLMetalClass
@end
#endif // !GGML_METAL_EMBED_LIBRARY
|
445
809
|
|
446
810
|
static void * ggml_metal_host_malloc(size_t n) {
|
447
811
|
void * data = NULL;
|
@@ -463,159 +827,176 @@ static void * ggml_metal_host_malloc(size_t n) {
|
|
463
827
|
return data;
|
464
828
|
}
|
465
829
|
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
#endif
|
477
|
-
|
478
|
-
// init context
|
479
|
-
struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
|
480
|
-
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
|
830
|
+
// load library
|
831
|
+
//
|
832
|
+
// - first check if the library is embedded
|
833
|
+
// - then check if the library is in the bundle
|
834
|
+
// - if not found, load the source and compile it
|
835
|
+
// - if that fails, return NULL
|
836
|
+
static id<MTLLibrary> ggml_metal_load_library(id<MTLDevice> device, bool use_bfloat) {
|
837
|
+
id<MTLLibrary> metal_library = nil;
|
838
|
+
NSError * error = nil;
|
839
|
+
NSString * src = nil;
|
481
840
|
|
482
|
-
|
483
|
-
GGML_LOG_INFO("%s:
|
841
|
+
#if GGML_METAL_EMBED_LIBRARY
|
842
|
+
GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
|
484
843
|
|
485
|
-
|
486
|
-
|
844
|
+
extern const char ggml_metallib_start[];
|
845
|
+
extern const char ggml_metallib_end[];
|
487
846
|
|
488
|
-
|
847
|
+
src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
|
489
848
|
|
490
|
-
// load library
|
491
|
-
//
|
492
|
-
// - first check if the library is embedded
|
493
|
-
// - then check if the library is in the bundle
|
494
|
-
// - if not found, load the source and compile it
|
495
|
-
// - if that fails, return NULL
|
496
|
-
{
|
497
|
-
NSBundle * bundle = nil;
|
498
|
-
#ifdef SWIFT_PACKAGE
|
499
|
-
bundle = SWIFTPM_MODULE_BUNDLE;
|
500
849
|
#else
|
501
|
-
bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
502
|
-
#endif
|
503
|
-
|
504
|
-
NSError * error = nil;
|
505
850
|
|
506
|
-
#
|
507
|
-
|
851
|
+
#ifdef SWIFT_PACKAGE
|
852
|
+
NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
|
508
853
|
#else
|
509
|
-
|
854
|
+
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
510
855
|
#endif
|
511
856
|
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
}
|
857
|
+
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
|
858
|
+
if (path_lib == nil) {
|
859
|
+
// Try to find the resource in the directory where the current binary located.
|
860
|
+
NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
|
861
|
+
NSString * bin_dir = [current_binary stringByDeletingLastPathComponent];
|
862
|
+
NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
|
863
|
+
if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
|
864
|
+
GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]);
|
865
|
+
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error];
|
866
|
+
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
|
867
|
+
// Optionally, if this is a symlink, try to resolve it.
|
868
|
+
default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error];
|
869
|
+
if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) {
|
870
|
+
// It is a relative path, adding the binary directory as directory prefix.
|
871
|
+
default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]];
|
872
|
+
}
|
873
|
+
if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
|
874
|
+
// Link to the resource could not be resolved.
|
875
|
+
default_metallib_path = nil;
|
876
|
+
} else {
|
877
|
+
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]);
|
534
878
|
}
|
535
|
-
} else {
|
536
|
-
// The resource couldn't be found in the binary's directory.
|
537
|
-
default_metallib_path = nil;
|
538
879
|
}
|
539
|
-
|
880
|
+
} else {
|
881
|
+
// The resource couldn't be found in the binary's directory.
|
882
|
+
default_metallib_path = nil;
|
540
883
|
}
|
884
|
+
path_lib = default_metallib_path;
|
885
|
+
}
|
541
886
|
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
887
|
+
if (path_lib != nil) {
|
888
|
+
// pre-compiled library found
|
889
|
+
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
|
890
|
+
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
|
546
891
|
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
892
|
+
metal_library = [device newLibraryWithURL:libURL error:&error];
|
893
|
+
if (error) {
|
894
|
+
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
895
|
+
return NULL;
|
896
|
+
}
|
897
|
+
} else {
|
898
|
+
GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
|
899
|
+
|
900
|
+
NSString * path_source;
|
901
|
+
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
|
902
|
+
|
903
|
+
GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
|
904
|
+
|
905
|
+
if (path_resource) {
|
906
|
+
path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
|
552
907
|
} else {
|
553
|
-
|
554
|
-
|
908
|
+
path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
909
|
+
}
|
555
910
|
|
556
|
-
|
557
|
-
|
911
|
+
if (path_source == nil) {
|
912
|
+
GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
|
913
|
+
path_source = @"ggml-metal.metal";
|
914
|
+
}
|
558
915
|
|
559
|
-
|
560
|
-
#else
|
561
|
-
GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
|
916
|
+
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
|
562
917
|
|
563
|
-
|
564
|
-
|
918
|
+
src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
|
919
|
+
if (error) {
|
920
|
+
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
921
|
+
return NULL;
|
922
|
+
}
|
923
|
+
}
|
924
|
+
#endif
|
565
925
|
|
566
|
-
|
926
|
+
if (!metal_library) {
|
927
|
+
@autoreleasepool {
|
928
|
+
// dictionary of preprocessor macros
|
929
|
+
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
567
930
|
|
568
|
-
if (
|
569
|
-
|
570
|
-
} else {
|
571
|
-
path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
931
|
+
if (use_bfloat) {
|
932
|
+
[prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"];
|
572
933
|
}
|
573
934
|
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
935
|
+
#if GGML_METAL_EMBED_LIBRARY
|
936
|
+
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
|
937
|
+
#endif
|
938
|
+
|
939
|
+
MTLCompileOptions * options = [MTLCompileOptions new];
|
940
|
+
options.preprocessorMacros = prep;
|
578
941
|
|
579
|
-
|
942
|
+
//[options setFastMathEnabled:false];
|
580
943
|
|
581
|
-
|
944
|
+
metal_library = [device newLibraryWithSource:src options:options error:&error];
|
582
945
|
if (error) {
|
583
946
|
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
584
947
|
return NULL;
|
585
948
|
}
|
949
|
+
|
950
|
+
#if !__has_feature(objc_arc)
|
951
|
+
[options release];
|
952
|
+
#endif
|
953
|
+
}
|
954
|
+
}
|
955
|
+
|
956
|
+
#if GGML_METAL_EMBED_LIBRARY
|
957
|
+
[src release];
|
586
958
|
#endif // GGML_METAL_EMBED_LIBRARY
|
587
959
|
|
588
|
-
|
589
|
-
|
590
|
-
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
960
|
+
return metal_library;
|
961
|
+
}
|
591
962
|
|
592
|
-
|
593
|
-
|
594
|
-
}
|
963
|
+
static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
|
964
|
+
GGML_LOG_INFO("%s: allocating\n", __func__);
|
595
965
|
|
596
|
-
#if
|
597
|
-
|
966
|
+
#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
|
967
|
+
// Show all the Metal device instances in the system
|
968
|
+
NSArray * devices = MTLCopyAllDevices();
|
969
|
+
for (id<MTLDevice> device in devices) {
|
970
|
+
GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
|
971
|
+
}
|
972
|
+
[devices release]; // since it was created by a *Copy* C method
|
598
973
|
#endif
|
599
974
|
|
600
|
-
|
601
|
-
|
975
|
+
// init context
|
976
|
+
struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
|
977
|
+
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
|
602
978
|
|
603
|
-
|
979
|
+
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
604
980
|
|
605
|
-
|
606
|
-
if (error) {
|
607
|
-
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
608
|
-
return NULL;
|
609
|
-
}
|
981
|
+
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
610
982
|
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
983
|
+
ctx->device = device;
|
984
|
+
ctx->queue = [device newCommandQueue];
|
985
|
+
if (ctx->queue == nil) {
|
986
|
+
GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
|
987
|
+
return NULL;
|
988
|
+
}
|
989
|
+
|
990
|
+
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
991
|
+
|
992
|
+
// load library
|
993
|
+
if (ctx_dev->mtl_library == nil) {
|
994
|
+
ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
|
995
|
+
}
|
996
|
+
id<MTLLibrary> metal_library = ctx_dev->mtl_library;
|
997
|
+
if (metal_library == nil) {
|
998
|
+
GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
|
999
|
+
return NULL;
|
619
1000
|
}
|
620
1001
|
|
621
1002
|
// print MTL GPU family:
|
@@ -649,6 +1030,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
649
1030
|
|
650
1031
|
GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false");
|
651
1032
|
GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false");
|
1033
|
+
GGML_LOG_INFO("%s: has residency sets = %s\n", __func__, ctx_dev->has_residency_sets ? "true" : "false");
|
652
1034
|
GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
|
653
1035
|
GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
|
654
1036
|
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
|
@@ -660,7 +1042,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
660
1042
|
ctx->gf = nil;
|
661
1043
|
ctx->encode_async = nil;
|
662
1044
|
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
663
|
-
ctx->
|
1045
|
+
ctx->cmd_bufs[i].obj = nil;
|
1046
|
+
|
1047
|
+
ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
|
1048
|
+
ctx->cmd_bufs[i].mem_pool->device = device;
|
664
1049
|
}
|
665
1050
|
|
666
1051
|
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
@@ -688,7 +1073,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
688
1073
|
[metal_function release]; \
|
689
1074
|
if (error) { \
|
690
1075
|
GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
691
|
-
[metal_library release]; \
|
692
1076
|
return NULL; \
|
693
1077
|
} \
|
694
1078
|
} else { \
|
@@ -701,304 +1085,380 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
701
1085
|
|
702
1086
|
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
703
1087
|
|
704
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,
|
705
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,
|
706
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB,
|
707
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW,
|
708
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,
|
709
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,
|
710
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,
|
711
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
712
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32,
|
713
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16,
|
714
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32,
|
715
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16,
|
716
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE,
|
717
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,
|
718
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP,
|
719
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH,
|
720
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,
|
721
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID,
|
722
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,
|
723
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4,
|
724
|
-
GGML_METAL_ADD_KERNEL(
|
725
|
-
GGML_METAL_ADD_KERNEL(
|
726
|
-
GGML_METAL_ADD_KERNEL(
|
727
|
-
GGML_METAL_ADD_KERNEL(
|
728
|
-
GGML_METAL_ADD_KERNEL(
|
729
|
-
GGML_METAL_ADD_KERNEL(
|
730
|
-
GGML_METAL_ADD_KERNEL(
|
731
|
-
GGML_METAL_ADD_KERNEL(
|
732
|
-
GGML_METAL_ADD_KERNEL(
|
733
|
-
GGML_METAL_ADD_KERNEL(
|
734
|
-
GGML_METAL_ADD_KERNEL(
|
735
|
-
GGML_METAL_ADD_KERNEL(
|
736
|
-
GGML_METAL_ADD_KERNEL(
|
737
|
-
GGML_METAL_ADD_KERNEL(
|
738
|
-
GGML_METAL_ADD_KERNEL(
|
739
|
-
GGML_METAL_ADD_KERNEL(
|
740
|
-
GGML_METAL_ADD_KERNEL(
|
741
|
-
GGML_METAL_ADD_KERNEL(
|
742
|
-
GGML_METAL_ADD_KERNEL(
|
743
|
-
GGML_METAL_ADD_KERNEL(
|
744
|
-
GGML_METAL_ADD_KERNEL(
|
745
|
-
GGML_METAL_ADD_KERNEL(
|
746
|
-
GGML_METAL_ADD_KERNEL(
|
747
|
-
GGML_METAL_ADD_KERNEL(
|
748
|
-
GGML_METAL_ADD_KERNEL(
|
749
|
-
GGML_METAL_ADD_KERNEL(
|
750
|
-
GGML_METAL_ADD_KERNEL(
|
751
|
-
GGML_METAL_ADD_KERNEL(
|
752
|
-
GGML_METAL_ADD_KERNEL(
|
753
|
-
GGML_METAL_ADD_KERNEL(
|
754
|
-
GGML_METAL_ADD_KERNEL(
|
755
|
-
GGML_METAL_ADD_KERNEL(
|
756
|
-
GGML_METAL_ADD_KERNEL(
|
757
|
-
GGML_METAL_ADD_KERNEL(
|
758
|
-
GGML_METAL_ADD_KERNEL(
|
759
|
-
GGML_METAL_ADD_KERNEL(
|
760
|
-
GGML_METAL_ADD_KERNEL(
|
761
|
-
GGML_METAL_ADD_KERNEL(
|
762
|
-
GGML_METAL_ADD_KERNEL(
|
763
|
-
GGML_METAL_ADD_KERNEL(
|
764
|
-
GGML_METAL_ADD_KERNEL(
|
765
|
-
GGML_METAL_ADD_KERNEL(
|
766
|
-
GGML_METAL_ADD_KERNEL(
|
767
|
-
GGML_METAL_ADD_KERNEL(
|
768
|
-
GGML_METAL_ADD_KERNEL(
|
769
|
-
GGML_METAL_ADD_KERNEL(
|
770
|
-
GGML_METAL_ADD_KERNEL(
|
771
|
-
GGML_METAL_ADD_KERNEL(
|
772
|
-
GGML_METAL_ADD_KERNEL(
|
773
|
-
GGML_METAL_ADD_KERNEL(
|
774
|
-
GGML_METAL_ADD_KERNEL(
|
775
|
-
GGML_METAL_ADD_KERNEL(
|
776
|
-
GGML_METAL_ADD_KERNEL(
|
777
|
-
GGML_METAL_ADD_KERNEL(
|
778
|
-
GGML_METAL_ADD_KERNEL(
|
779
|
-
GGML_METAL_ADD_KERNEL(
|
780
|
-
GGML_METAL_ADD_KERNEL(
|
781
|
-
GGML_METAL_ADD_KERNEL(
|
782
|
-
GGML_METAL_ADD_KERNEL(
|
783
|
-
GGML_METAL_ADD_KERNEL(
|
784
|
-
GGML_METAL_ADD_KERNEL(
|
785
|
-
GGML_METAL_ADD_KERNEL(
|
786
|
-
GGML_METAL_ADD_KERNEL(
|
787
|
-
GGML_METAL_ADD_KERNEL(
|
788
|
-
GGML_METAL_ADD_KERNEL(
|
789
|
-
GGML_METAL_ADD_KERNEL(
|
790
|
-
GGML_METAL_ADD_KERNEL(
|
791
|
-
GGML_METAL_ADD_KERNEL(
|
792
|
-
GGML_METAL_ADD_KERNEL(
|
793
|
-
GGML_METAL_ADD_KERNEL(
|
794
|
-
GGML_METAL_ADD_KERNEL(
|
795
|
-
GGML_METAL_ADD_KERNEL(
|
796
|
-
GGML_METAL_ADD_KERNEL(
|
797
|
-
GGML_METAL_ADD_KERNEL(
|
798
|
-
GGML_METAL_ADD_KERNEL(
|
799
|
-
GGML_METAL_ADD_KERNEL(
|
800
|
-
GGML_METAL_ADD_KERNEL(
|
801
|
-
GGML_METAL_ADD_KERNEL(
|
802
|
-
GGML_METAL_ADD_KERNEL(
|
803
|
-
GGML_METAL_ADD_KERNEL(
|
804
|
-
GGML_METAL_ADD_KERNEL(
|
805
|
-
GGML_METAL_ADD_KERNEL(
|
806
|
-
GGML_METAL_ADD_KERNEL(
|
807
|
-
GGML_METAL_ADD_KERNEL(
|
808
|
-
GGML_METAL_ADD_KERNEL(
|
809
|
-
GGML_METAL_ADD_KERNEL(
|
810
|
-
GGML_METAL_ADD_KERNEL(
|
811
|
-
GGML_METAL_ADD_KERNEL(
|
812
|
-
GGML_METAL_ADD_KERNEL(
|
813
|
-
GGML_METAL_ADD_KERNEL(
|
814
|
-
GGML_METAL_ADD_KERNEL(
|
815
|
-
GGML_METAL_ADD_KERNEL(
|
816
|
-
GGML_METAL_ADD_KERNEL(
|
817
|
-
GGML_METAL_ADD_KERNEL(
|
818
|
-
GGML_METAL_ADD_KERNEL(
|
819
|
-
GGML_METAL_ADD_KERNEL(
|
820
|
-
GGML_METAL_ADD_KERNEL(
|
821
|
-
GGML_METAL_ADD_KERNEL(
|
822
|
-
GGML_METAL_ADD_KERNEL(
|
823
|
-
GGML_METAL_ADD_KERNEL(
|
824
|
-
GGML_METAL_ADD_KERNEL(
|
825
|
-
GGML_METAL_ADD_KERNEL(
|
826
|
-
GGML_METAL_ADD_KERNEL(
|
827
|
-
GGML_METAL_ADD_KERNEL(
|
828
|
-
GGML_METAL_ADD_KERNEL(
|
829
|
-
GGML_METAL_ADD_KERNEL(
|
830
|
-
GGML_METAL_ADD_KERNEL(
|
831
|
-
GGML_METAL_ADD_KERNEL(
|
832
|
-
GGML_METAL_ADD_KERNEL(
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
GGML_METAL_ADD_KERNEL(
|
837
|
-
GGML_METAL_ADD_KERNEL(
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
GGML_METAL_ADD_KERNEL(
|
842
|
-
GGML_METAL_ADD_KERNEL(
|
843
|
-
GGML_METAL_ADD_KERNEL(
|
844
|
-
GGML_METAL_ADD_KERNEL(
|
845
|
-
GGML_METAL_ADD_KERNEL(
|
846
|
-
GGML_METAL_ADD_KERNEL(
|
847
|
-
GGML_METAL_ADD_KERNEL(
|
848
|
-
GGML_METAL_ADD_KERNEL(
|
849
|
-
GGML_METAL_ADD_KERNEL(
|
850
|
-
GGML_METAL_ADD_KERNEL(
|
851
|
-
GGML_METAL_ADD_KERNEL(
|
852
|
-
GGML_METAL_ADD_KERNEL(
|
853
|
-
GGML_METAL_ADD_KERNEL(
|
854
|
-
GGML_METAL_ADD_KERNEL(
|
855
|
-
GGML_METAL_ADD_KERNEL(
|
856
|
-
GGML_METAL_ADD_KERNEL(
|
857
|
-
GGML_METAL_ADD_KERNEL(
|
858
|
-
GGML_METAL_ADD_KERNEL(
|
859
|
-
GGML_METAL_ADD_KERNEL(
|
860
|
-
GGML_METAL_ADD_KERNEL(
|
861
|
-
GGML_METAL_ADD_KERNEL(
|
862
|
-
GGML_METAL_ADD_KERNEL(
|
863
|
-
GGML_METAL_ADD_KERNEL(
|
864
|
-
GGML_METAL_ADD_KERNEL(
|
865
|
-
GGML_METAL_ADD_KERNEL(
|
866
|
-
GGML_METAL_ADD_KERNEL(
|
867
|
-
GGML_METAL_ADD_KERNEL(
|
868
|
-
GGML_METAL_ADD_KERNEL(
|
869
|
-
GGML_METAL_ADD_KERNEL(
|
870
|
-
GGML_METAL_ADD_KERNEL(
|
871
|
-
GGML_METAL_ADD_KERNEL(
|
872
|
-
GGML_METAL_ADD_KERNEL(
|
873
|
-
GGML_METAL_ADD_KERNEL(
|
874
|
-
GGML_METAL_ADD_KERNEL(
|
875
|
-
GGML_METAL_ADD_KERNEL(
|
876
|
-
GGML_METAL_ADD_KERNEL(
|
877
|
-
GGML_METAL_ADD_KERNEL(
|
878
|
-
GGML_METAL_ADD_KERNEL(
|
879
|
-
GGML_METAL_ADD_KERNEL(
|
880
|
-
GGML_METAL_ADD_KERNEL(
|
881
|
-
GGML_METAL_ADD_KERNEL(
|
882
|
-
GGML_METAL_ADD_KERNEL(
|
883
|
-
GGML_METAL_ADD_KERNEL(
|
884
|
-
GGML_METAL_ADD_KERNEL(
|
885
|
-
GGML_METAL_ADD_KERNEL(
|
886
|
-
GGML_METAL_ADD_KERNEL(
|
887
|
-
GGML_METAL_ADD_KERNEL(
|
888
|
-
GGML_METAL_ADD_KERNEL(
|
889
|
-
GGML_METAL_ADD_KERNEL(
|
890
|
-
GGML_METAL_ADD_KERNEL(
|
891
|
-
GGML_METAL_ADD_KERNEL(
|
892
|
-
GGML_METAL_ADD_KERNEL(
|
893
|
-
GGML_METAL_ADD_KERNEL(
|
894
|
-
GGML_METAL_ADD_KERNEL(
|
895
|
-
GGML_METAL_ADD_KERNEL(
|
896
|
-
GGML_METAL_ADD_KERNEL(
|
897
|
-
GGML_METAL_ADD_KERNEL(
|
898
|
-
GGML_METAL_ADD_KERNEL(
|
899
|
-
GGML_METAL_ADD_KERNEL(
|
900
|
-
GGML_METAL_ADD_KERNEL(
|
901
|
-
GGML_METAL_ADD_KERNEL(
|
902
|
-
GGML_METAL_ADD_KERNEL(
|
903
|
-
GGML_METAL_ADD_KERNEL(
|
904
|
-
GGML_METAL_ADD_KERNEL(
|
905
|
-
GGML_METAL_ADD_KERNEL(
|
906
|
-
GGML_METAL_ADD_KERNEL(
|
907
|
-
GGML_METAL_ADD_KERNEL(
|
908
|
-
GGML_METAL_ADD_KERNEL(
|
909
|
-
GGML_METAL_ADD_KERNEL(
|
910
|
-
GGML_METAL_ADD_KERNEL(
|
911
|
-
GGML_METAL_ADD_KERNEL(
|
912
|
-
GGML_METAL_ADD_KERNEL(
|
913
|
-
GGML_METAL_ADD_KERNEL(
|
914
|
-
GGML_METAL_ADD_KERNEL(
|
915
|
-
GGML_METAL_ADD_KERNEL(
|
916
|
-
GGML_METAL_ADD_KERNEL(
|
917
|
-
GGML_METAL_ADD_KERNEL(
|
918
|
-
GGML_METAL_ADD_KERNEL(
|
919
|
-
GGML_METAL_ADD_KERNEL(
|
920
|
-
GGML_METAL_ADD_KERNEL(
|
921
|
-
GGML_METAL_ADD_KERNEL(
|
922
|
-
GGML_METAL_ADD_KERNEL(
|
923
|
-
GGML_METAL_ADD_KERNEL(
|
924
|
-
GGML_METAL_ADD_KERNEL(
|
925
|
-
GGML_METAL_ADD_KERNEL(
|
926
|
-
GGML_METAL_ADD_KERNEL(
|
927
|
-
GGML_METAL_ADD_KERNEL(
|
928
|
-
GGML_METAL_ADD_KERNEL(
|
929
|
-
GGML_METAL_ADD_KERNEL(
|
930
|
-
GGML_METAL_ADD_KERNEL(
|
931
|
-
GGML_METAL_ADD_KERNEL(
|
932
|
-
GGML_METAL_ADD_KERNEL(
|
933
|
-
GGML_METAL_ADD_KERNEL(
|
934
|
-
GGML_METAL_ADD_KERNEL(
|
935
|
-
GGML_METAL_ADD_KERNEL(
|
936
|
-
GGML_METAL_ADD_KERNEL(
|
937
|
-
GGML_METAL_ADD_KERNEL(
|
938
|
-
GGML_METAL_ADD_KERNEL(
|
939
|
-
GGML_METAL_ADD_KERNEL(
|
940
|
-
GGML_METAL_ADD_KERNEL(
|
941
|
-
GGML_METAL_ADD_KERNEL(
|
942
|
-
GGML_METAL_ADD_KERNEL(
|
943
|
-
GGML_METAL_ADD_KERNEL(
|
944
|
-
GGML_METAL_ADD_KERNEL(
|
945
|
-
GGML_METAL_ADD_KERNEL(
|
946
|
-
GGML_METAL_ADD_KERNEL(
|
947
|
-
GGML_METAL_ADD_KERNEL(
|
948
|
-
GGML_METAL_ADD_KERNEL(
|
949
|
-
GGML_METAL_ADD_KERNEL(
|
950
|
-
GGML_METAL_ADD_KERNEL(
|
951
|
-
GGML_METAL_ADD_KERNEL(
|
952
|
-
GGML_METAL_ADD_KERNEL(
|
953
|
-
GGML_METAL_ADD_KERNEL(
|
954
|
-
GGML_METAL_ADD_KERNEL(
|
955
|
-
GGML_METAL_ADD_KERNEL(
|
956
|
-
GGML_METAL_ADD_KERNEL(
|
957
|
-
GGML_METAL_ADD_KERNEL(
|
958
|
-
GGML_METAL_ADD_KERNEL(
|
959
|
-
GGML_METAL_ADD_KERNEL(
|
960
|
-
GGML_METAL_ADD_KERNEL(
|
961
|
-
GGML_METAL_ADD_KERNEL(
|
962
|
-
GGML_METAL_ADD_KERNEL(
|
963
|
-
GGML_METAL_ADD_KERNEL(
|
964
|
-
GGML_METAL_ADD_KERNEL(
|
965
|
-
GGML_METAL_ADD_KERNEL(
|
966
|
-
GGML_METAL_ADD_KERNEL(
|
967
|
-
GGML_METAL_ADD_KERNEL(
|
968
|
-
GGML_METAL_ADD_KERNEL(
|
969
|
-
GGML_METAL_ADD_KERNEL(
|
970
|
-
GGML_METAL_ADD_KERNEL(
|
971
|
-
GGML_METAL_ADD_KERNEL(
|
972
|
-
GGML_METAL_ADD_KERNEL(
|
973
|
-
GGML_METAL_ADD_KERNEL(
|
974
|
-
GGML_METAL_ADD_KERNEL(
|
975
|
-
GGML_METAL_ADD_KERNEL(
|
976
|
-
GGML_METAL_ADD_KERNEL(
|
977
|
-
GGML_METAL_ADD_KERNEL(
|
978
|
-
GGML_METAL_ADD_KERNEL(
|
979
|
-
GGML_METAL_ADD_KERNEL(
|
980
|
-
GGML_METAL_ADD_KERNEL(
|
981
|
-
GGML_METAL_ADD_KERNEL(
|
982
|
-
GGML_METAL_ADD_KERNEL(
|
983
|
-
GGML_METAL_ADD_KERNEL(
|
984
|
-
GGML_METAL_ADD_KERNEL(
|
985
|
-
GGML_METAL_ADD_KERNEL(
|
986
|
-
GGML_METAL_ADD_KERNEL(
|
987
|
-
GGML_METAL_ADD_KERNEL(
|
988
|
-
GGML_METAL_ADD_KERNEL(
|
989
|
-
GGML_METAL_ADD_KERNEL(
|
990
|
-
GGML_METAL_ADD_KERNEL(
|
991
|
-
GGML_METAL_ADD_KERNEL(
|
992
|
-
GGML_METAL_ADD_KERNEL(
|
993
|
-
GGML_METAL_ADD_KERNEL(
|
994
|
-
GGML_METAL_ADD_KERNEL(
|
995
|
-
GGML_METAL_ADD_KERNEL(
|
996
|
-
GGML_METAL_ADD_KERNEL(
|
997
|
-
GGML_METAL_ADD_KERNEL(
|
1088
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
1089
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
|
1090
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
|
1091
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
|
1092
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
1093
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
1094
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
1095
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
1096
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
|
1097
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
|
1098
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
|
1099
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
|
1100
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
1101
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
1102
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
1103
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
1104
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
1105
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
|
1106
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
1107
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
|
1108
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF, gelu_erf, true);
|
1109
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF_4, gelu_erf_4, true);
|
1110
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
|
1111
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
1112
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
1113
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
1114
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
|
1115
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
|
1116
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
|
1117
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
|
1118
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction);
|
1119
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
1120
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
|
1121
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
|
1122
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
|
1123
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat);
|
1124
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
|
1125
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
|
1126
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
|
1127
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
|
1128
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
|
1129
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
|
1130
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
|
1131
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
|
1132
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
|
1133
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
|
1134
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
1135
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
1136
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
|
1137
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
|
1138
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
|
1139
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
|
1140
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
|
1141
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
1142
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
1143
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
1144
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
1145
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
|
1146
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
1147
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
1148
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
1149
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
1150
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
|
1151
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
|
1152
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
1153
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
|
1154
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
|
1155
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
|
1156
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
|
1157
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
|
1158
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
|
1159
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
|
1160
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
|
1161
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction);
|
1162
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction);
|
1163
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
|
1164
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
|
1165
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
|
1166
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
|
1167
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
|
1168
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
|
1169
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
|
1170
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
|
1171
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
|
1172
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
|
1173
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
|
1174
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
|
1175
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
|
1176
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
|
1177
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
|
1178
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
|
1179
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
|
1180
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
|
1181
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
|
1182
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
|
1183
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
|
1184
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
|
1185
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
|
1186
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
|
1187
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
|
1188
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
|
1189
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
|
1190
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
|
1191
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
|
1192
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
|
1193
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction);
|
1194
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction);
|
1195
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction);
|
1196
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction);
|
1197
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction);
|
1198
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction);
|
1199
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction);
|
1200
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction);
|
1201
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction);
|
1202
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction);
|
1203
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction);
|
1204
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction);
|
1205
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction);
|
1206
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
|
1207
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
|
1208
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
|
1209
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction);
|
1210
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction);
|
1211
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction);
|
1212
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction);
|
1213
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction);
|
1214
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction);
|
1215
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction);
|
1216
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction);
|
1217
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction);
|
1218
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction);
|
1219
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction);
|
1220
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction);
|
1221
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction);
|
1222
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction);
|
1223
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction);
|
1224
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction);
|
1225
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat);
|
1226
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction);
|
1227
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction);
|
1228
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
|
1229
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
|
1230
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
|
1231
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
|
1232
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
|
1233
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
|
1234
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction);
|
1235
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction);
|
1236
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction);
|
1237
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction);
|
1238
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction);
|
1239
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction);
|
1240
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction);
|
1241
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction);
|
1242
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction);
|
1243
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction);
|
1244
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
|
1245
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
|
1246
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
|
1247
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat);
|
1248
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm);
|
1249
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm);
|
1250
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
|
1251
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
|
1252
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
|
1253
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
|
1254
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
|
1255
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
|
1256
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm);
|
1257
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm);
|
1258
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm);
|
1259
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm);
|
1260
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm);
|
1261
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm);
|
1262
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm);
|
1263
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm);
|
1264
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
|
1265
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
|
1266
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
|
1267
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
|
1268
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
|
1269
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
|
1270
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
|
1271
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
|
1272
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, mul_mm_id_q4_0_f16, has_simdgroup_mm);
|
1273
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, mul_mm_id_q4_1_f16, has_simdgroup_mm);
|
1274
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
|
1275
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm);
|
1276
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm);
|
1277
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm);
|
1278
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm);
|
1279
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm);
|
1280
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, mul_mm_id_q5_K_f16, has_simdgroup_mm);
|
1281
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, mul_mm_id_q6_K_f16, has_simdgroup_mm);
|
1282
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, mul_mm_id_iq2_xxs_f16, has_simdgroup_mm);
|
1283
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, mul_mm_id_iq2_xs_f16, has_simdgroup_mm);
|
1284
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, mul_mm_id_iq3_xxs_f16, has_simdgroup_mm);
|
1285
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, mul_mm_id_iq3_s_f16, has_simdgroup_mm);
|
1286
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, mul_mm_id_iq2_s_f16, has_simdgroup_mm);
|
1287
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, mul_mm_id_iq1_s_f16, has_simdgroup_mm);
|
1288
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm);
|
1289
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm);
|
1290
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
|
1291
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
1292
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
1293
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true);
|
1294
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true);
|
1295
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true);
|
1296
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true);
|
1297
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
1298
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
|
1299
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
1300
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
1301
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
|
1302
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
|
1303
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true);
|
1304
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
|
1305
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
1306
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
1307
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true);
|
1308
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
1309
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
|
1310
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
1311
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
1312
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
1313
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
|
1314
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
|
1315
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
|
1316
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
|
1317
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
|
1318
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm);
|
1319
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
|
1320
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
|
1321
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
|
1322
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
|
1323
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
|
1324
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
|
1325
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat);
|
1326
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat);
|
1327
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat);
|
1328
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
|
1329
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
|
1330
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
|
1331
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
|
1332
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
|
1333
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
|
1334
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm);
|
1335
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm);
|
1336
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm);
|
1337
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
|
1338
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
|
1339
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
|
1340
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
|
1341
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
|
1342
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
|
1343
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm);
|
1344
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm);
|
1345
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm);
|
1346
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
|
1347
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
|
1348
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
|
1349
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
|
1350
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
|
1351
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
|
1352
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm);
|
1353
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm);
|
1354
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm);
|
1355
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
|
1356
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
|
1357
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
|
1358
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
|
1359
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
|
1360
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
|
1361
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm);
|
1362
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm);
|
1363
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm);
|
1364
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
|
1365
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
|
1366
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
|
1367
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
|
1368
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
|
1369
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
|
1370
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm);
|
1371
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
|
1372
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm);
|
1373
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
|
1374
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
|
1375
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
|
1376
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
|
1377
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
|
1378
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
|
1379
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction);
|
1380
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction);
|
1381
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction);
|
1382
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction);
|
1383
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
|
1384
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
|
1385
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
|
1386
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction);
|
1387
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction);
|
1388
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction);
|
1389
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction);
|
1390
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
|
1391
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
|
1392
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
|
1393
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
|
1394
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
|
1395
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
|
1396
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
|
1397
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, flash_attn_ext_vec_f16_h192, has_simdgroup_reduction);
|
1398
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, flash_attn_ext_vec_bf16_h192, has_simdgroup_reduction && use_bfloat);
|
1399
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, flash_attn_ext_vec_q4_0_h192, has_simdgroup_reduction);
|
1400
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, flash_attn_ext_vec_q4_1_h192, has_simdgroup_reduction);
|
1401
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, flash_attn_ext_vec_q5_0_h192, has_simdgroup_reduction);
|
1402
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, flash_attn_ext_vec_q5_1_h192, has_simdgroup_reduction);
|
1403
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, flash_attn_ext_vec_q8_0_h192, has_simdgroup_reduction);
|
1404
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, flash_attn_ext_vec_f16_hk192_hv128, has_simdgroup_reduction);
|
1405
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, flash_attn_ext_vec_bf16_hk192_hv128, has_simdgroup_reduction && use_bfloat);
|
1406
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, flash_attn_ext_vec_q4_0_hk192_hv128, has_simdgroup_reduction);
|
1407
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, flash_attn_ext_vec_q4_1_hk192_hv128, has_simdgroup_reduction);
|
1408
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, flash_attn_ext_vec_q5_0_hk192_hv128, has_simdgroup_reduction);
|
1409
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, flash_attn_ext_vec_q5_1_hk192_hv128, has_simdgroup_reduction);
|
1410
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, flash_attn_ext_vec_q8_0_hk192_hv128, has_simdgroup_reduction);
|
1411
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
|
1412
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat);
|
1413
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
|
1414
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
|
1415
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
|
1416
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
|
1417
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
|
1418
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction);
|
1419
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat);
|
1420
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction);
|
1421
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction);
|
1422
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
|
1423
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
|
1424
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
|
1425
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
|
1426
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
|
1427
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
1428
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
1429
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
|
1430
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
1431
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
1432
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
|
1433
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
|
1434
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
1435
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
1436
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
1437
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
1438
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
1439
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
|
1440
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true);
|
1441
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true);
|
1442
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true);
|
1443
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true);
|
1444
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true);
|
1445
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true);
|
1446
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true);
|
1447
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true);
|
1448
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true);
|
1449
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true);
|
1450
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
1451
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
1452
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
|
1453
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
1454
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
1455
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
1456
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
1457
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
1458
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
1459
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
998
1460
|
}
|
999
1461
|
|
1000
|
-
[metal_library release];
|
1001
|
-
|
1002
1462
|
return ctx;
|
1003
1463
|
}
|
1004
1464
|
|
@@ -1013,6 +1473,12 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
|
|
1013
1473
|
|
1014
1474
|
[ctx->queue release];
|
1015
1475
|
|
1476
|
+
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
1477
|
+
// ctx->cmd_bufs[i].obj is auto released
|
1478
|
+
|
1479
|
+
ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
|
1480
|
+
}
|
1481
|
+
|
1016
1482
|
dispatch_release(ctx->d_queue);
|
1017
1483
|
|
1018
1484
|
free(ctx);
|
@@ -1035,8 +1501,70 @@ struct ggml_backend_metal_buffer_context {
|
|
1035
1501
|
// multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
|
1036
1502
|
int n_buffers;
|
1037
1503
|
struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
|
1504
|
+
|
1505
|
+
// optional MTLResidencySet
|
1506
|
+
id rset;
|
1038
1507
|
};
|
1039
1508
|
|
1509
|
+
// rset init
|
1510
|
+
static bool ggml_backend_metal_buffer_rset_init(
|
1511
|
+
struct ggml_backend_metal_buffer_context * ctx,
|
1512
|
+
struct ggml_backend_metal_device_context * ctx_dev,
|
1513
|
+
id<MTLDevice> device) {
|
1514
|
+
ctx->rset = nil;
|
1515
|
+
|
1516
|
+
if (!ctx_dev->has_residency_sets) {
|
1517
|
+
return true;
|
1518
|
+
}
|
1519
|
+
|
1520
|
+
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
|
1521
|
+
if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
|
1522
|
+
MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init];
|
1523
|
+
desc.label = @"ggml_backend_metal";
|
1524
|
+
desc.initialCapacity = ctx->n_buffers;
|
1525
|
+
|
1526
|
+
NSError * error;
|
1527
|
+
ctx->rset = [device newResidencySetWithDescriptor:desc error:&error];
|
1528
|
+
if (error) {
|
1529
|
+
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
1530
|
+
[desc release];
|
1531
|
+
return false;
|
1532
|
+
}
|
1533
|
+
|
1534
|
+
[desc release];
|
1535
|
+
|
1536
|
+
for (int i = 0; i < ctx->n_buffers; i++) {
|
1537
|
+
[ctx->rset addAllocation:ctx->buffers[i].metal];
|
1538
|
+
}
|
1539
|
+
|
1540
|
+
[ctx->rset commit];
|
1541
|
+
[ctx->rset requestResidency];
|
1542
|
+
|
1543
|
+
return true;
|
1544
|
+
}
|
1545
|
+
#else
|
1546
|
+
GGML_UNUSED(ctx_dev);
|
1547
|
+
GGML_UNUSED(device);
|
1548
|
+
#endif
|
1549
|
+
|
1550
|
+
return true;
|
1551
|
+
}
|
1552
|
+
|
1553
|
+
// rset free
|
1554
|
+
static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer_context * ctx) {
|
1555
|
+
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
|
1556
|
+
if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
|
1557
|
+
if (ctx->rset) {
|
1558
|
+
[ctx->rset endResidency];
|
1559
|
+
[ctx->rset removeAllAllocations];
|
1560
|
+
[ctx->rset release];
|
1561
|
+
}
|
1562
|
+
}
|
1563
|
+
#else
|
1564
|
+
GGML_UNUSED(ctx);
|
1565
|
+
#endif
|
1566
|
+
}
|
1567
|
+
|
1040
1568
|
// finds the Metal buffer that contains the tensor data on the GPU device
|
1041
1569
|
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
|
1042
1570
|
// Metal buffer based on the host memory pointer
|
@@ -1089,10 +1617,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
1089
1617
|
case GGML_UNARY_OP_RELU:
|
1090
1618
|
case GGML_UNARY_OP_SIGMOID:
|
1091
1619
|
case GGML_UNARY_OP_GELU:
|
1620
|
+
case GGML_UNARY_OP_GELU_ERF:
|
1092
1621
|
case GGML_UNARY_OP_GELU_QUICK:
|
1093
1622
|
case GGML_UNARY_OP_SILU:
|
1094
1623
|
case GGML_UNARY_OP_ELU:
|
1095
|
-
|
1624
|
+
case GGML_UNARY_OP_NEG:
|
1625
|
+
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
1096
1626
|
default:
|
1097
1627
|
return false;
|
1098
1628
|
}
|
@@ -1102,61 +1632,73 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
1102
1632
|
case GGML_OP_TRANSPOSE:
|
1103
1633
|
case GGML_OP_PERMUTE:
|
1104
1634
|
case GGML_OP_CONCAT:
|
1635
|
+
return true;
|
1105
1636
|
case GGML_OP_ADD:
|
1106
1637
|
case GGML_OP_SUB:
|
1107
|
-
case GGML_OP_ACC:
|
1108
1638
|
case GGML_OP_MUL:
|
1109
1639
|
case GGML_OP_DIV:
|
1640
|
+
return op->src[0]->type == GGML_TYPE_F32;
|
1641
|
+
case GGML_OP_ACC:
|
1110
1642
|
case GGML_OP_REPEAT:
|
1111
1643
|
case GGML_OP_SCALE:
|
1112
|
-
case GGML_OP_CLAMP:
|
1113
1644
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
1114
1645
|
return true;
|
1646
|
+
case GGML_OP_CLAMP:
|
1647
|
+
return op->src[0]->type == GGML_TYPE_F32;
|
1115
1648
|
case GGML_OP_SQR:
|
1116
1649
|
case GGML_OP_SQRT:
|
1117
1650
|
case GGML_OP_SIN:
|
1118
1651
|
case GGML_OP_COS:
|
1119
|
-
return ggml_is_contiguous(op->src[0]);
|
1652
|
+
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
1653
|
+
case GGML_OP_LOG:
|
1654
|
+
return false; // TODO: implement
|
1120
1655
|
case GGML_OP_SUM_ROWS:
|
1121
1656
|
case GGML_OP_SOFT_MAX:
|
1122
1657
|
case GGML_OP_GROUP_NORM:
|
1123
|
-
return has_simdgroup_reduction;
|
1658
|
+
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
1124
1659
|
case GGML_OP_RMS_NORM:
|
1125
|
-
|
1660
|
+
case GGML_OP_L2_NORM:
|
1661
|
+
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
1126
1662
|
case GGML_OP_ARGMAX:
|
1127
|
-
case GGML_OP_NORM:
|
1128
1663
|
return true;
|
1664
|
+
case GGML_OP_NORM:
|
1665
|
+
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
1129
1666
|
case GGML_OP_ROPE:
|
1130
|
-
|
1131
|
-
const int mode = ((const int32_t *) op->op_params)[2];
|
1132
|
-
if (mode & GGML_ROPE_TYPE_MROPE) {
|
1133
|
-
return false;
|
1134
|
-
}
|
1135
|
-
if (mode & GGML_ROPE_TYPE_VISION) {
|
1136
|
-
return false;
|
1137
|
-
}
|
1138
|
-
return true;
|
1139
|
-
}
|
1667
|
+
return true;
|
1140
1668
|
case GGML_OP_IM2COL:
|
1141
1669
|
return op->src[0]->type == GGML_TYPE_F16;
|
1142
1670
|
case GGML_OP_POOL_1D:
|
1143
1671
|
return false;
|
1144
|
-
case GGML_OP_POOL_2D:
|
1145
1672
|
case GGML_OP_UPSCALE:
|
1673
|
+
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
1674
|
+
case GGML_OP_POOL_2D:
|
1146
1675
|
case GGML_OP_PAD:
|
1147
1676
|
case GGML_OP_PAD_REFLECT_1D:
|
1148
|
-
case GGML_OP_ARANGE:
|
1149
1677
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
1150
1678
|
case GGML_OP_ARGSORT:
|
1151
1679
|
case GGML_OP_LEAKY_RELU:
|
1680
|
+
return op->src[0]->type == GGML_TYPE_F32;
|
1681
|
+
case GGML_OP_ARANGE:
|
1152
1682
|
return true;
|
1153
1683
|
case GGML_OP_FLASH_ATTN_EXT:
|
1684
|
+
if (op->src[0]->ne[0] == 32) {
|
1685
|
+
// head size == 32 (e.g. bert-bge-small)
|
1686
|
+
// TODO: not sure if it is worth adding kernels for this size
|
1687
|
+
return false;
|
1688
|
+
}
|
1689
|
+
if (op->src[0]->ne[0] == 576) {
|
1690
|
+
// DeepSeek sizes
|
1691
|
+
// TODO: disabled for now, until optmized
|
1692
|
+
return false;
|
1693
|
+
}
|
1154
1694
|
if (op->src[1]->type != op->src[2]->type) {
|
1155
1695
|
return false;
|
1156
1696
|
}
|
1157
1697
|
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
1158
1698
|
case GGML_OP_SSM_CONV:
|
1159
1699
|
case GGML_OP_SSM_SCAN:
|
1700
|
+
case GGML_OP_RWKV_WKV6:
|
1701
|
+
case GGML_OP_RWKV_WKV7:
|
1160
1702
|
return true;
|
1161
1703
|
case GGML_OP_MUL_MAT:
|
1162
1704
|
case GGML_OP_MUL_MAT_ID:
|
@@ -1198,6 +1740,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
1198
1740
|
default:
|
1199
1741
|
return false;
|
1200
1742
|
}
|
1743
|
+
case GGML_TYPE_Q4_0:
|
1744
|
+
case GGML_TYPE_Q4_1:
|
1745
|
+
case GGML_TYPE_Q5_0:
|
1746
|
+
case GGML_TYPE_Q5_1:
|
1747
|
+
case GGML_TYPE_Q8_0:
|
1748
|
+
switch (op->type) {
|
1749
|
+
case GGML_TYPE_F32:
|
1750
|
+
case GGML_TYPE_F16:
|
1751
|
+
return true;
|
1752
|
+
default:
|
1753
|
+
return false;
|
1754
|
+
}
|
1201
1755
|
default:
|
1202
1756
|
return false;
|
1203
1757
|
};
|
@@ -1222,10 +1776,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
1222
1776
|
}
|
1223
1777
|
}
|
1224
1778
|
|
1225
|
-
static
|
1779
|
+
static bool ggml_metal_encode_node(
|
1226
1780
|
ggml_backend_t backend,
|
1227
1781
|
int idx,
|
1228
|
-
id<MTLComputeCommandEncoder> encoder
|
1782
|
+
id<MTLComputeCommandEncoder> encoder,
|
1783
|
+
struct ggml_metal_mem_pool * mem_pool) {
|
1229
1784
|
struct ggml_backend_metal_context * ctx = backend->context;
|
1230
1785
|
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
1231
1786
|
|
@@ -1241,7 +1796,7 @@ static void ggml_metal_encode_node(
|
|
1241
1796
|
struct ggml_tensor * dst = node;
|
1242
1797
|
|
1243
1798
|
if (ggml_is_empty(dst)) {
|
1244
|
-
return;
|
1799
|
+
return true;
|
1245
1800
|
}
|
1246
1801
|
|
1247
1802
|
switch (dst->op) {
|
@@ -1252,7 +1807,7 @@ static void ggml_metal_encode_node(
|
|
1252
1807
|
case GGML_OP_PERMUTE:
|
1253
1808
|
{
|
1254
1809
|
// noop -> next node
|
1255
|
-
} return;
|
1810
|
+
} return true;
|
1256
1811
|
default:
|
1257
1812
|
{
|
1258
1813
|
} break;
|
@@ -1263,6 +1818,8 @@ static void ggml_metal_encode_node(
|
|
1263
1818
|
GGML_ABORT("unsupported op");
|
1264
1819
|
}
|
1265
1820
|
|
1821
|
+
ggml_metal_mem_pool_clear(mem_pool);
|
1822
|
+
|
1266
1823
|
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
1267
1824
|
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
1268
1825
|
const int64_t ne02 = src0 ? src0->ne[2] : 0;
|
@@ -1699,6 +2256,25 @@ static void ggml_metal_encode_node(
|
|
1699
2256
|
|
1700
2257
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1701
2258
|
} break;
|
2259
|
+
case GGML_UNARY_OP_GELU_ERF:
|
2260
|
+
{
|
2261
|
+
int64_t n = ggml_nelements(dst);
|
2262
|
+
|
2263
|
+
id<MTLComputePipelineState> pipeline = nil;
|
2264
|
+
|
2265
|
+
if (n % 4 == 0) {
|
2266
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF_4].pipeline;
|
2267
|
+
n /= 4;
|
2268
|
+
} else {
|
2269
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF].pipeline;
|
2270
|
+
}
|
2271
|
+
|
2272
|
+
[encoder setComputePipelineState:pipeline];
|
2273
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2274
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2275
|
+
|
2276
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
2277
|
+
} break;
|
1702
2278
|
case GGML_UNARY_OP_GELU_QUICK:
|
1703
2279
|
{
|
1704
2280
|
int64_t n = ggml_nelements(dst);
|
@@ -1749,6 +2325,18 @@ static void ggml_metal_encode_node(
|
|
1749
2325
|
|
1750
2326
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1751
2327
|
} break;
|
2328
|
+
case GGML_UNARY_OP_NEG:
|
2329
|
+
{
|
2330
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline;
|
2331
|
+
|
2332
|
+
[encoder setComputePipelineState:pipeline];
|
2333
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2334
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2335
|
+
|
2336
|
+
const int64_t n = ggml_nelements(dst);
|
2337
|
+
|
2338
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
2339
|
+
} break;
|
1752
2340
|
default:
|
1753
2341
|
{
|
1754
2342
|
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
@@ -1817,34 +2405,38 @@ static void ggml_metal_encode_node(
|
|
1817
2405
|
|
1818
2406
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
1819
2407
|
|
1820
|
-
|
2408
|
+
|
2409
|
+
ggml_metal_kargs_sum_rows args = {
|
2410
|
+
/*.ne00 =*/ ne00,
|
2411
|
+
/*.ne01 =*/ ne01,
|
2412
|
+
/*.ne02 =*/ ne02,
|
2413
|
+
/*.ne03 =*/ ne03,
|
2414
|
+
/*.nb00 =*/ nb00,
|
2415
|
+
/*.nb01 =*/ nb01,
|
2416
|
+
/*.nb02 =*/ nb02,
|
2417
|
+
/*.nb03 =*/ nb03,
|
2418
|
+
/*.ne10 =*/ ne10,
|
2419
|
+
/*.ne11 =*/ ne11,
|
2420
|
+
/*.ne12 =*/ ne12,
|
2421
|
+
/*.ne13 =*/ ne13,
|
2422
|
+
/*.nb10 =*/ nb10,
|
2423
|
+
/*.nb11 =*/ nb11,
|
2424
|
+
/*.nb12 =*/ nb12,
|
2425
|
+
/*.nb13 =*/ nb13,
|
2426
|
+
/*.ne0 =*/ ne0,
|
2427
|
+
/*.ne1 =*/ ne1,
|
2428
|
+
/*.ne2 =*/ ne2,
|
2429
|
+
/*.ne3 =*/ ne3,
|
2430
|
+
/*.nb0 =*/ nb0,
|
2431
|
+
/*.nb1 =*/ nb1,
|
2432
|
+
/*.nb2 =*/ nb2,
|
2433
|
+
/*.nb3 =*/ nb3,
|
2434
|
+
};
|
2435
|
+
|
1821
2436
|
[encoder setComputePipelineState:pipeline];
|
1822
2437
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1823
2438
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1824
|
-
[encoder setBytes:&
|
1825
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
1826
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
1827
|
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
1828
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
1829
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
1830
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
1831
|
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
1832
|
-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
1833
|
-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
1834
|
-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
1835
|
-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
1836
|
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
1837
|
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
1838
|
-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
1839
|
-
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
|
1840
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
|
1841
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
|
1842
|
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
|
1843
|
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
|
1844
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
|
1845
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
|
1846
|
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
|
1847
|
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
|
2439
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
1848
2440
|
|
1849
2441
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1850
2442
|
} break;
|
@@ -1893,24 +2485,76 @@ static void ggml_metal_encode_node(
|
|
1893
2485
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
1894
2486
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
1895
2487
|
|
1896
|
-
|
1897
|
-
|
2488
|
+
// use this branch to test the ggml_metal_mem_pool functionality
|
2489
|
+
#if 0
|
2490
|
+
// cpy to tmp buffer in MTLHeap
|
2491
|
+
|
2492
|
+
id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
|
2493
|
+
if (!h_src0) {
|
2494
|
+
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
|
2495
|
+
return false;
|
2496
|
+
}
|
2497
|
+
|
2498
|
+
offs_src0 = 0;
|
2499
|
+
|
2500
|
+
ggml_metal_kargs_cpy args_cpy = {
|
2501
|
+
/*.ne00 =*/ ne00,
|
2502
|
+
/*.ne01 =*/ ne01,
|
2503
|
+
/*.ne02 =*/ ne02,
|
2504
|
+
/*.ne03 =*/ ne03,
|
2505
|
+
/*.nb00 =*/ nb00,
|
2506
|
+
/*.nb01 =*/ nb01,
|
2507
|
+
/*.nb02 =*/ nb02,
|
2508
|
+
/*.nb03 =*/ nb03,
|
2509
|
+
/*.ne0 =*/ ne00,
|
2510
|
+
/*.ne1 =*/ ne01,
|
2511
|
+
/*.ne2 =*/ ne02,
|
2512
|
+
/*.ne3 =*/ ne03,
|
2513
|
+
/*.nb0 =*/ nb00,
|
2514
|
+
/*.nb1 =*/ nb01,
|
2515
|
+
/*.nb2 =*/ nb02,
|
2516
|
+
/*.nb3 =*/ nb03,
|
2517
|
+
};
|
2518
|
+
|
2519
|
+
if (src0->type == GGML_TYPE_F16) {
|
2520
|
+
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
|
2521
|
+
} else {
|
2522
|
+
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
|
2523
|
+
}
|
2524
|
+
[encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
|
2525
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
2526
|
+
[encoder setBuffer:h_src0 offset:0 atIndex:2];
|
2527
|
+
|
2528
|
+
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
2529
|
+
int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
|
2530
|
+
|
2531
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
|
2532
|
+
|
2533
|
+
#else
|
2534
|
+
id<MTLBuffer> h_src0 = id_src0;
|
2535
|
+
#endif
|
2536
|
+
// softmax
|
2537
|
+
|
2538
|
+
ggml_metal_kargs_soft_max args = {
|
2539
|
+
/*.ne00 =*/ ne00,
|
2540
|
+
/*.ne01 =*/ ne01,
|
2541
|
+
/*.ne02 =*/ ne02,
|
2542
|
+
/*.scale =*/ scale,
|
2543
|
+
/*.max_bias =*/ max_bias,
|
2544
|
+
/*.m0 =*/ m0,
|
2545
|
+
/*.m1 =*/ m1,
|
2546
|
+
/*.n_head_log2 =*/ n_head_log2,
|
2547
|
+
};
|
2548
|
+
|
1898
2549
|
[encoder setComputePipelineState:pipeline];
|
1899
|
-
[encoder setBuffer:
|
2550
|
+
[encoder setBuffer:h_src0 offset:offs_src0 atIndex:0];
|
1900
2551
|
if (id_src1) {
|
1901
|
-
[encoder setBuffer:id_src1 offset:offs_src1
|
2552
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1902
2553
|
} else {
|
1903
|
-
[encoder setBuffer:
|
2554
|
+
[encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
|
1904
2555
|
}
|
1905
|
-
[encoder setBuffer:id_dst
|
1906
|
-
[encoder setBytes:&
|
1907
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
1908
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
1909
|
-
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
1910
|
-
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
|
1911
|
-
[encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
|
1912
|
-
[encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
|
1913
|
-
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
|
2556
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
2557
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
1914
2558
|
|
1915
2559
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
1916
2560
|
|
@@ -1928,13 +2572,16 @@ static void ggml_metal_encode_node(
|
|
1928
2572
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
|
1929
2573
|
}
|
1930
2574
|
|
1931
|
-
|
2575
|
+
ggml_metal_kargs_diag_mask_inf args = {
|
2576
|
+
/*.ne00 =*/ ne00,
|
2577
|
+
/*.ne01 =*/ ne01,
|
2578
|
+
/*.n_past =*/ n_past,
|
2579
|
+
};
|
2580
|
+
|
1932
2581
|
[encoder setComputePipelineState:pipeline];
|
1933
2582
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1934
2583
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1935
|
-
[encoder setBytes:&
|
1936
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
1937
|
-
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
2584
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
1938
2585
|
|
1939
2586
|
if (ne00%8 == 0) {
|
1940
2587
|
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
@@ -1953,27 +2600,30 @@ static void ggml_metal_encode_node(
|
|
1953
2600
|
|
1954
2601
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
|
1955
2602
|
|
1956
|
-
|
2603
|
+
ggml_metal_kargs_ssm_conv args = {
|
2604
|
+
/*.ne00 =*/ ne00,
|
2605
|
+
/*.ne01 =*/ ne01,
|
2606
|
+
/*.ne02 =*/ ne02,
|
2607
|
+
/*.nb00 =*/ nb00,
|
2608
|
+
/*.nb01 =*/ nb01,
|
2609
|
+
/*.nb02 =*/ nb02,
|
2610
|
+
/*.ne10 =*/ ne10,
|
2611
|
+
/*.ne11 =*/ ne11,
|
2612
|
+
/*.nb10 =*/ nb10,
|
2613
|
+
/*.nb11 =*/ nb11,
|
2614
|
+
/*.ne0 =*/ ne0,
|
2615
|
+
/*.ne1 =*/ ne1,
|
2616
|
+
/*.ne2 =*/ ne2,
|
2617
|
+
/*.nb0 =*/ nb0,
|
2618
|
+
/*.nb1 =*/ nb1,
|
2619
|
+
/*.nb2 =*/ nb2,
|
2620
|
+
};
|
2621
|
+
|
1957
2622
|
[encoder setComputePipelineState:pipeline];
|
1958
2623
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1959
2624
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1960
2625
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1961
|
-
[encoder setBytes:&
|
1962
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
1963
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
1964
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
1965
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
1966
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
1967
|
-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
|
1968
|
-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
|
1969
|
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
|
1970
|
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
|
1971
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
1972
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
1973
|
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15];
|
1974
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16];
|
1975
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17];
|
1976
|
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18];
|
2626
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
1977
2627
|
|
1978
2628
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1979
2629
|
} break;
|
@@ -2024,7 +2674,31 @@ static void ggml_metal_encode_node(
|
|
2024
2674
|
|
2025
2675
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
|
2026
2676
|
|
2027
|
-
|
2677
|
+
ggml_metal_kargs_ssm_scan args = {
|
2678
|
+
/*.d_state =*/ d_state,
|
2679
|
+
/*.d_inner =*/ d_inner,
|
2680
|
+
/*.n_seq_tokens =*/ n_seq_tokens,
|
2681
|
+
/*.n_seqs =*/ n_seqs,
|
2682
|
+
/*.nb00 =*/ nb00,
|
2683
|
+
/*.nb01 =*/ nb01,
|
2684
|
+
/*.nb02 =*/ nb02,
|
2685
|
+
/*.nb10 =*/ nb10,
|
2686
|
+
/*.nb11 =*/ nb11,
|
2687
|
+
/*.nb12 =*/ nb12,
|
2688
|
+
/*.nb13 =*/ nb13,
|
2689
|
+
/*.nb20 =*/ nb20,
|
2690
|
+
/*.nb21 =*/ nb21,
|
2691
|
+
/*.nb22 =*/ nb22,
|
2692
|
+
/*.nb30 =*/ nb30,
|
2693
|
+
/*.nb31 =*/ nb31,
|
2694
|
+
/*.nb40 =*/ nb40,
|
2695
|
+
/*.nb41 =*/ nb41,
|
2696
|
+
/*.nb42 =*/ nb42,
|
2697
|
+
/*.nb50 =*/ nb50,
|
2698
|
+
/*.nb51 =*/ nb51,
|
2699
|
+
/*.nb52 =*/ nb52,
|
2700
|
+
};
|
2701
|
+
|
2028
2702
|
[encoder setComputePipelineState:pipeline];
|
2029
2703
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2030
2704
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
@@ -2033,33 +2707,87 @@ static void ggml_metal_encode_node(
|
|
2033
2707
|
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
2034
2708
|
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
2035
2709
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
2036
|
-
|
2037
|
-
[encoder setBytes:&d_state length:sizeof(d_state) atIndex:7];
|
2038
|
-
[encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8];
|
2039
|
-
[encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
|
2040
|
-
[encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10];
|
2041
|
-
|
2042
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
|
2043
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
|
2044
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
|
2045
|
-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
2046
|
-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
2047
|
-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
2048
|
-
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
|
2049
|
-
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
|
2050
|
-
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
|
2051
|
-
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
|
2052
|
-
[encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
|
2053
|
-
[encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
|
2054
|
-
[encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
|
2055
|
-
[encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
|
2056
|
-
[encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
|
2057
|
-
[encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
|
2058
|
-
[encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
|
2059
|
-
[encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
|
2710
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:7];
|
2060
2711
|
|
2061
2712
|
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
2062
2713
|
} break;
|
2714
|
+
case GGML_OP_RWKV_WKV6:
|
2715
|
+
{
|
2716
|
+
const int64_t B = dst->src[5]->ne[1];
|
2717
|
+
const int64_t T = dst->src[0]->ne[2];
|
2718
|
+
const int64_t C = dst->ne[0];
|
2719
|
+
const int64_t H = dst->src[0]->ne[1];
|
2720
|
+
|
2721
|
+
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
2722
|
+
GGML_ASSERT(C % H == 0);
|
2723
|
+
GGML_ASSERT(C / H == 64);
|
2724
|
+
|
2725
|
+
size_t offs_src3 = 0;
|
2726
|
+
size_t offs_src4 = 0;
|
2727
|
+
size_t offs_src5 = 0;
|
2728
|
+
|
2729
|
+
id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
|
2730
|
+
id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
|
2731
|
+
id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
|
2732
|
+
|
2733
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
|
2734
|
+
|
2735
|
+
[encoder setComputePipelineState:pipeline];
|
2736
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2737
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
2738
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
2739
|
+
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
2740
|
+
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
2741
|
+
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
2742
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
2743
|
+
|
2744
|
+
[encoder setBytes:&B length:sizeof(B) atIndex:7];
|
2745
|
+
[encoder setBytes:&T length:sizeof(T) atIndex:8];
|
2746
|
+
[encoder setBytes:&C length:sizeof(C) atIndex:9];
|
2747
|
+
[encoder setBytes:&H length:sizeof(H) atIndex:10];
|
2748
|
+
|
2749
|
+
[encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
|
2750
|
+
} break;
|
2751
|
+
case GGML_OP_RWKV_WKV7:
|
2752
|
+
{
|
2753
|
+
const int64_t B = dst->src[6]->ne[1];
|
2754
|
+
const int64_t T = dst->src[0]->ne[2];
|
2755
|
+
const int64_t C = dst->ne[0];
|
2756
|
+
const int64_t H = dst->src[0]->ne[1];
|
2757
|
+
|
2758
|
+
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
|
2759
|
+
GGML_ASSERT(C % H == 0);
|
2760
|
+
GGML_ASSERT(C / H == 64);
|
2761
|
+
|
2762
|
+
size_t offs_src3 = 0;
|
2763
|
+
size_t offs_src4 = 0;
|
2764
|
+
size_t offs_src5 = 0;
|
2765
|
+
size_t offs_src6 = 0;
|
2766
|
+
|
2767
|
+
id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
|
2768
|
+
id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
|
2769
|
+
id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
|
2770
|
+
id<MTLBuffer> id_src6 = dst->src[6] ? ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil;
|
2771
|
+
|
2772
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline;
|
2773
|
+
|
2774
|
+
[encoder setComputePipelineState:pipeline];
|
2775
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2776
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
2777
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
2778
|
+
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
2779
|
+
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
2780
|
+
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
2781
|
+
[encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
|
2782
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];
|
2783
|
+
|
2784
|
+
[encoder setBytes:&B length:sizeof(B) atIndex:8];
|
2785
|
+
[encoder setBytes:&T length:sizeof(T) atIndex:9];
|
2786
|
+
[encoder setBytes:&C length:sizeof(C) atIndex:10];
|
2787
|
+
[encoder setBytes:&H length:sizeof(H) atIndex:11];
|
2788
|
+
|
2789
|
+
[encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
|
2790
|
+
} break;
|
2063
2791
|
case GGML_OP_MUL_MAT:
|
2064
2792
|
{
|
2065
2793
|
GGML_ASSERT(ne00 == ne10);
|
@@ -2067,8 +2795,8 @@ static void ggml_metal_encode_node(
|
|
2067
2795
|
GGML_ASSERT(ne12 % ne02 == 0);
|
2068
2796
|
GGML_ASSERT(ne13 % ne03 == 0);
|
2069
2797
|
|
2070
|
-
const
|
2071
|
-
const
|
2798
|
+
const uint32_t r2 = ne12/ne02;
|
2799
|
+
const uint32_t r3 = ne13/ne03;
|
2072
2800
|
|
2073
2801
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
2074
2802
|
// to the matrix-vector kernel
|
@@ -2317,173 +3045,182 @@ static void ggml_metal_encode_node(
|
|
2317
3045
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
2318
3046
|
|
2319
3047
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
2320
|
-
[encoder dispatchThreadgroups:MTLSizeMake(
|
3048
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
2321
3049
|
} else {
|
2322
|
-
int nth0 = 32;
|
2323
|
-
int nth1 = 1;
|
2324
|
-
int nrows = 1;
|
2325
|
-
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
2326
|
-
|
2327
3050
|
id<MTLComputePipelineState> pipeline = nil;
|
2328
3051
|
|
3052
|
+
int nsg = 0; // number of simdgroups
|
3053
|
+
int nr0 = 0; // number of src0 rows per simdgroup
|
3054
|
+
int nr1 = 1; // number of src1 rows per threadgroup
|
3055
|
+
|
3056
|
+
size_t smem = 0; // shared memory
|
3057
|
+
|
2329
3058
|
// use custom matrix x vector kernel
|
2330
3059
|
switch (src0t) {
|
2331
3060
|
case GGML_TYPE_F32:
|
2332
3061
|
{
|
2333
3062
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
3063
|
+
nsg = 1;
|
3064
|
+
nr0 = 1;
|
3065
|
+
nr1 = 4;
|
2334
3066
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
|
2335
|
-
nrows = 4;
|
2336
3067
|
} break;
|
2337
3068
|
case GGML_TYPE_F16:
|
2338
3069
|
{
|
2339
|
-
|
2340
|
-
|
3070
|
+
nsg = 1;
|
3071
|
+
nr0 = 1;
|
2341
3072
|
if (src1t == GGML_TYPE_F32) {
|
2342
3073
|
if (ne11 * ne12 < 4) {
|
2343
3074
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
2344
3075
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
2345
3076
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
2346
|
-
|
3077
|
+
nr1 = ne11;
|
2347
3078
|
} else {
|
2348
3079
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
|
2349
|
-
|
3080
|
+
nr1 = 4;
|
2350
3081
|
}
|
2351
3082
|
} else {
|
2352
3083
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
|
2353
|
-
|
3084
|
+
nr1 = 4;
|
2354
3085
|
}
|
2355
3086
|
} break;
|
2356
3087
|
case GGML_TYPE_BF16:
|
2357
3088
|
{
|
2358
|
-
|
2359
|
-
|
3089
|
+
nsg = 1;
|
3090
|
+
nr0 = 1;
|
2360
3091
|
if (src1t == GGML_TYPE_F32) {
|
2361
3092
|
if (ne11 * ne12 < 4) {
|
2362
3093
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
|
2363
3094
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
2364
3095
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
|
2365
|
-
|
3096
|
+
nr1 = ne11;
|
2366
3097
|
} else {
|
2367
3098
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
|
2368
|
-
|
3099
|
+
nr1 = 4;
|
2369
3100
|
}
|
2370
3101
|
} else {
|
2371
3102
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
|
2372
|
-
|
3103
|
+
nr1 = 4;
|
2373
3104
|
}
|
2374
3105
|
} break;
|
2375
3106
|
case GGML_TYPE_Q4_0:
|
2376
3107
|
{
|
2377
|
-
|
2378
|
-
|
3108
|
+
nsg = N_SG_Q4_0;
|
3109
|
+
nr0 = N_R0_Q4_0;
|
2379
3110
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
|
2380
3111
|
} break;
|
2381
3112
|
case GGML_TYPE_Q4_1:
|
2382
3113
|
{
|
2383
|
-
|
2384
|
-
|
3114
|
+
nsg = N_SG_Q4_1;
|
3115
|
+
nr0 = N_R0_Q4_1;
|
2385
3116
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
|
2386
3117
|
} break;
|
2387
3118
|
case GGML_TYPE_Q5_0:
|
2388
3119
|
{
|
2389
|
-
|
2390
|
-
|
3120
|
+
nsg = N_SG_Q5_0;
|
3121
|
+
nr0 = N_R0_Q5_0;
|
2391
3122
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
|
2392
3123
|
} break;
|
2393
3124
|
case GGML_TYPE_Q5_1:
|
2394
3125
|
{
|
2395
|
-
|
2396
|
-
|
3126
|
+
nsg = N_SG_Q5_1;
|
3127
|
+
nr0 = N_R0_Q5_1;
|
2397
3128
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
|
2398
3129
|
} break;
|
2399
3130
|
case GGML_TYPE_Q8_0:
|
2400
3131
|
{
|
2401
|
-
|
2402
|
-
|
3132
|
+
nsg = N_SG_Q8_0;
|
3133
|
+
nr0 = N_R0_Q8_0;
|
2403
3134
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
|
2404
3135
|
} break;
|
2405
3136
|
case GGML_TYPE_Q2_K:
|
2406
3137
|
{
|
2407
|
-
|
2408
|
-
|
3138
|
+
nsg = N_SG_Q2_K;
|
3139
|
+
nr0 = N_R0_Q2_K;
|
2409
3140
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
|
2410
3141
|
} break;
|
2411
3142
|
case GGML_TYPE_Q3_K:
|
2412
3143
|
{
|
2413
|
-
|
2414
|
-
|
3144
|
+
nsg = N_SG_Q3_K;
|
3145
|
+
nr0 = N_R0_Q3_K;
|
2415
3146
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
|
2416
3147
|
} break;
|
2417
3148
|
case GGML_TYPE_Q4_K:
|
2418
3149
|
{
|
2419
|
-
|
2420
|
-
|
3150
|
+
nsg = N_SG_Q4_K;
|
3151
|
+
nr0 = N_R0_Q4_K;
|
2421
3152
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
|
2422
3153
|
} break;
|
2423
3154
|
case GGML_TYPE_Q5_K:
|
2424
3155
|
{
|
2425
|
-
|
2426
|
-
|
3156
|
+
nsg = N_SG_Q5_K;
|
3157
|
+
nr0 = N_R0_Q5_K;
|
2427
3158
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
|
2428
3159
|
} break;
|
2429
3160
|
case GGML_TYPE_Q6_K:
|
2430
3161
|
{
|
2431
|
-
|
2432
|
-
|
3162
|
+
nsg = N_SG_Q6_K;
|
3163
|
+
nr0 = N_R0_Q6_K;
|
2433
3164
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
|
2434
3165
|
} break;
|
2435
3166
|
case GGML_TYPE_IQ2_XXS:
|
2436
3167
|
{
|
2437
|
-
|
2438
|
-
|
3168
|
+
nsg = N_SG_IQ2_XXS;
|
3169
|
+
nr0 = N_R0_IQ2_XXS;
|
3170
|
+
smem = 256*8+128;
|
2439
3171
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
|
2440
3172
|
} break;
|
2441
3173
|
case GGML_TYPE_IQ2_XS:
|
2442
3174
|
{
|
2443
|
-
|
2444
|
-
|
3175
|
+
nsg = N_SG_IQ2_XS;
|
3176
|
+
nr0 = N_R0_IQ2_XS;
|
3177
|
+
smem = 512*8+128;
|
2445
3178
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
|
2446
3179
|
} break;
|
2447
3180
|
case GGML_TYPE_IQ3_XXS:
|
2448
3181
|
{
|
2449
|
-
|
2450
|
-
|
3182
|
+
nsg = N_SG_IQ3_XXS;
|
3183
|
+
nr0 = N_R0_IQ3_XXS;
|
3184
|
+
smem = 256*4+128;
|
2451
3185
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
2452
3186
|
} break;
|
2453
3187
|
case GGML_TYPE_IQ3_S:
|
2454
3188
|
{
|
2455
|
-
|
2456
|
-
|
3189
|
+
nsg = N_SG_IQ3_S;
|
3190
|
+
nr0 = N_R0_IQ3_S;
|
3191
|
+
smem = 512*4;
|
2457
3192
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
|
2458
3193
|
} break;
|
2459
3194
|
case GGML_TYPE_IQ2_S:
|
2460
3195
|
{
|
2461
|
-
|
2462
|
-
|
3196
|
+
nsg = N_SG_IQ2_S;
|
3197
|
+
nr0 = N_R0_IQ2_S;
|
2463
3198
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
|
2464
3199
|
} break;
|
2465
3200
|
case GGML_TYPE_IQ1_S:
|
2466
3201
|
{
|
2467
|
-
|
2468
|
-
|
3202
|
+
nsg = N_SG_IQ1_S;
|
3203
|
+
nr0 = N_R0_IQ1_S;
|
2469
3204
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
|
2470
3205
|
} break;
|
2471
3206
|
case GGML_TYPE_IQ1_M:
|
2472
3207
|
{
|
2473
|
-
|
2474
|
-
|
3208
|
+
nsg = N_SG_IQ1_M;
|
3209
|
+
nr0 = N_R0_IQ1_M;
|
2475
3210
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
|
2476
3211
|
} break;
|
2477
3212
|
case GGML_TYPE_IQ4_NL:
|
2478
3213
|
{
|
2479
|
-
|
2480
|
-
|
3214
|
+
nsg = N_SG_IQ4_NL;
|
3215
|
+
nr0 = N_R0_IQ4_NL;
|
3216
|
+
smem = 32*sizeof(float);
|
2481
3217
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
|
2482
3218
|
} break;
|
2483
3219
|
case GGML_TYPE_IQ4_XS:
|
2484
3220
|
{
|
2485
|
-
|
2486
|
-
|
3221
|
+
nsg = N_SG_IQ4_XS;
|
3222
|
+
nr0 = N_R0_IQ4_XS;
|
3223
|
+
smem = 32*sizeof(float);
|
2487
3224
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
|
2488
3225
|
} break;
|
2489
3226
|
default:
|
@@ -2520,47 +3257,14 @@ static void ggml_metal_encode_node(
|
|
2520
3257
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
2521
3258
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
2522
3259
|
|
2523
|
-
if (
|
2524
|
-
|
2525
|
-
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
|
2526
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2527
|
-
}
|
2528
|
-
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
2529
|
-
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
2530
|
-
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
2531
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2532
|
-
}
|
2533
|
-
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
|
2534
|
-
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
2535
|
-
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
2536
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2537
|
-
}
|
2538
|
-
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
|
2539
|
-
const int mem_size = 32*sizeof(float);
|
2540
|
-
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
2541
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2542
|
-
}
|
2543
|
-
else if (src0t == GGML_TYPE_Q4_K) {
|
2544
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2545
|
-
}
|
2546
|
-
else if (src0t == GGML_TYPE_Q3_K) {
|
2547
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2548
|
-
}
|
2549
|
-
else if (src0t == GGML_TYPE_Q5_K) {
|
2550
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2551
|
-
}
|
2552
|
-
else if (src0t == GGML_TYPE_Q6_K) {
|
2553
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2554
|
-
} else {
|
2555
|
-
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
2556
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
3260
|
+
if (smem > 0) {
|
3261
|
+
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
2557
3262
|
}
|
3263
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
2558
3264
|
}
|
2559
3265
|
} break;
|
2560
3266
|
case GGML_OP_MUL_MAT_ID:
|
2561
3267
|
{
|
2562
|
-
const int n_as = src0->ne[2];
|
2563
|
-
|
2564
3268
|
// src2 = ids
|
2565
3269
|
const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
|
2566
3270
|
|
@@ -2574,26 +3278,22 @@ static void ggml_metal_encode_node(
|
|
2574
3278
|
GGML_ASSERT(ne03 == 1);
|
2575
3279
|
GGML_ASSERT(ne13 == 1);
|
2576
3280
|
|
3281
|
+
const uint32_t r2 = 1;
|
3282
|
+
const uint32_t r3 = 1;
|
3283
|
+
|
2577
3284
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
2578
3285
|
// to the matrix-vector kernel
|
2579
3286
|
// ne20 = n_used_experts
|
2580
|
-
// ne21 = n_rows
|
2581
|
-
const int
|
2582
|
-
const int dst_rows_min = n_as;
|
2583
|
-
const int dst_rows_max = (device.maxThreadgroupMemoryLength - 32 - 8192)/4;
|
2584
|
-
|
2585
|
-
// max size of the rowids array in the kernel shared buffer
|
2586
|
-
GGML_ASSERT(dst_rows <= dst_rows_max);
|
3287
|
+
// ne21 = n_rows (batch size)
|
3288
|
+
const int ne21_mm_id_min = 32;
|
2587
3289
|
|
2588
3290
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
2589
3291
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
2590
|
-
// !!!
|
2591
|
-
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
|
2592
|
-
// indirect matrix multiplication
|
2593
|
-
// !!!
|
2594
3292
|
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
2595
3293
|
ne00 % 32 == 0 && ne00 >= 64 &&
|
2596
|
-
|
3294
|
+
(ne21 >= ne21_mm_id_min)) {
|
3295
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
3296
|
+
|
2597
3297
|
// some Metal matrix data types require aligned pointers
|
2598
3298
|
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
2599
3299
|
switch (src0->type) {
|
@@ -2603,203 +3303,319 @@ static void ggml_metal_encode_node(
|
|
2603
3303
|
default: break;
|
2604
3304
|
}
|
2605
3305
|
|
2606
|
-
|
3306
|
+
const int64_t neh10 = ne10; // n_embd
|
3307
|
+
const int64_t neh11 = ne21; // n_tokens
|
3308
|
+
const int64_t neh12 = ne02; // n_expert
|
2607
3309
|
|
2608
|
-
|
2609
|
-
|
2610
|
-
|
2611
|
-
|
2612
|
-
|
2613
|
-
|
2614
|
-
|
2615
|
-
|
2616
|
-
|
2617
|
-
|
2618
|
-
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
|
2619
|
-
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
|
2620
|
-
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
|
2621
|
-
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
|
2622
|
-
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
|
2623
|
-
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
|
2624
|
-
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
|
2625
|
-
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
|
2626
|
-
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
|
2627
|
-
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
|
2628
|
-
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
|
2629
|
-
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
|
2630
|
-
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
|
2631
|
-
default: GGML_ABORT("MUL_MAT_ID not implemented");
|
3310
|
+
const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16);
|
3311
|
+
const uint64_t nbh11 = nbh10*neh10;
|
3312
|
+
const uint64_t nbh12 = nbh11*neh11;
|
3313
|
+
const uint64_t nbh13 = nbh12*neh12;
|
3314
|
+
|
3315
|
+
const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12;
|
3316
|
+
id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
|
3317
|
+
if (!h_src1) {
|
3318
|
+
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
|
3319
|
+
return false;
|
2632
3320
|
}
|
2633
3321
|
|
2634
|
-
|
2635
|
-
|
2636
|
-
|
2637
|
-
|
2638
|
-
|
2639
|
-
|
2640
|
-
|
2641
|
-
|
2642
|
-
|
2643
|
-
|
2644
|
-
|
2645
|
-
|
2646
|
-
|
2647
|
-
|
2648
|
-
|
2649
|
-
|
2650
|
-
|
3322
|
+
const int64_t neh0 = ne0;
|
3323
|
+
const int64_t neh1 = ne21;
|
3324
|
+
const int64_t neh2 = ne02;
|
3325
|
+
|
3326
|
+
const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32);
|
3327
|
+
const uint64_t nbh1 = nbh0*neh0;
|
3328
|
+
const uint64_t nbh2 = nbh1*neh1;
|
3329
|
+
//const uint64_t nbh3 = nbh2*neh2;
|
3330
|
+
|
3331
|
+
const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2;
|
3332
|
+
id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
|
3333
|
+
if (!h_dst) {
|
3334
|
+
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
|
3335
|
+
return false;
|
3336
|
+
}
|
3337
|
+
|
3338
|
+
// tokens per expert
|
3339
|
+
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
|
3340
|
+
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
|
3341
|
+
if (!h_tpe) {
|
3342
|
+
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
|
3343
|
+
return false;
|
3344
|
+
}
|
3345
|
+
|
3346
|
+
// id map
|
3347
|
+
// [n_expert_used, n_tokens]
|
3348
|
+
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21;
|
3349
|
+
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
|
3350
|
+
if (!h_ids) {
|
3351
|
+
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
|
3352
|
+
return false;
|
3353
|
+
}
|
3354
|
+
|
3355
|
+
{
|
3356
|
+
const int nth = MIN(1024, ne10/4);
|
3357
|
+
|
3358
|
+
ggml_metal_kargs_mul_mm_id_map0 args = {
|
3359
|
+
ne10,
|
3360
|
+
ne11, // n_expert_used (bcast)
|
3361
|
+
nb11,
|
3362
|
+
nb12,
|
3363
|
+
neh11, // n_tokens
|
3364
|
+
nbh11,
|
3365
|
+
ne20, // n_expert_used
|
3366
|
+
nb21,
|
3367
|
+
};
|
3368
|
+
|
3369
|
+
id<MTLComputePipelineState> pipeline = nil;
|
3370
|
+
|
3371
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
|
3372
|
+
|
3373
|
+
[encoder setComputePipelineState:pipeline];
|
3374
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
3375
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
3376
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
3377
|
+
[encoder setBuffer: h_src1 offset:0 atIndex:3];
|
3378
|
+
[encoder setBuffer: h_tpe offset:0 atIndex:4];
|
3379
|
+
[encoder setBuffer: h_ids offset:0 atIndex:5];
|
3380
|
+
|
3381
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
3382
|
+
}
|
3383
|
+
|
3384
|
+
{
|
3385
|
+
id<MTLComputePipelineState> pipeline = nil;
|
3386
|
+
|
3387
|
+
switch (src0->type) {
|
3388
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16 ].pipeline; break;
|
3389
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16 ].pipeline; break;
|
3390
|
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16 ].pipeline; break;
|
3391
|
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16 ].pipeline; break;
|
3392
|
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16 ].pipeline; break;
|
3393
|
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
|
3394
|
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
|
3395
|
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
|
3396
|
+
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
|
3397
|
+
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
|
3398
|
+
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;
|
3399
|
+
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16 ].pipeline; break;
|
3400
|
+
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16 ].pipeline; break;
|
3401
|
+
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16].pipeline; break;
|
3402
|
+
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16 ].pipeline; break;
|
3403
|
+
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16].pipeline; break;
|
3404
|
+
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16 ].pipeline; break;
|
3405
|
+
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16 ].pipeline; break;
|
3406
|
+
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16 ].pipeline; break;
|
3407
|
+
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break;
|
3408
|
+
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break;
|
3409
|
+
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break;
|
3410
|
+
default: GGML_ABORT("MUL_MAT_ID not implemented");
|
3411
|
+
}
|
2651
3412
|
|
2652
|
-
|
2653
|
-
|
2654
|
-
|
2655
|
-
|
2656
|
-
|
2657
|
-
|
3413
|
+
ggml_metal_kargs_mul_mm_id args = {
|
3414
|
+
/*.ne00 =*/ ne00,
|
3415
|
+
/*.ne02 =*/ ne02,
|
3416
|
+
/*.nb01 =*/ nb01,
|
3417
|
+
/*.nb02 =*/ nb02,
|
3418
|
+
/*.nb03 =*/ nb03,
|
3419
|
+
/*.neh12 =*/ neh12,
|
3420
|
+
/*.nbh10 =*/ nbh10,
|
3421
|
+
/*.nbh11 =*/ nbh11,
|
3422
|
+
/*.nbh12 =*/ nbh12,
|
3423
|
+
/*.nbh13 =*/ nbh13,
|
3424
|
+
/*.neh0 =*/ neh0,
|
3425
|
+
/*.neh1 =*/ neh1,
|
3426
|
+
/*.r2 =*/ r2,
|
3427
|
+
/*.r3 =*/ r3,
|
3428
|
+
};
|
3429
|
+
|
3430
|
+
[encoder setComputePipelineState:pipeline];
|
3431
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
3432
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
3433
|
+
[encoder setBuffer: h_src1 offset:0 atIndex:2];
|
3434
|
+
[encoder setBuffer: h_tpe offset:0 atIndex:3];
|
3435
|
+
[encoder setBuffer: h_dst offset:0 atIndex:4];
|
3436
|
+
|
3437
|
+
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
3438
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
3439
|
+
}
|
2658
3440
|
|
2659
|
-
|
3441
|
+
{
|
3442
|
+
GGML_ASSERT(ne0 % 4 == 0);
|
2660
3443
|
|
2661
|
-
|
2662
|
-
} else {
|
2663
|
-
int nth0 = 32;
|
2664
|
-
int nth1 = 1;
|
2665
|
-
int nrows = 1;
|
2666
|
-
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
3444
|
+
const int nth = MIN(1024, ne0/4);
|
2667
3445
|
|
3446
|
+
ggml_metal_kargs_mul_mm_id_map1 args = {
|
3447
|
+
ne20, // n_expert_used
|
3448
|
+
neh0,
|
3449
|
+
neh1,
|
3450
|
+
nbh1,
|
3451
|
+
nbh2,
|
3452
|
+
ne0,
|
3453
|
+
nb1,
|
3454
|
+
nb2,
|
3455
|
+
};
|
3456
|
+
|
3457
|
+
id<MTLComputePipelineState> pipeline = nil;
|
3458
|
+
|
3459
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
|
3460
|
+
|
3461
|
+
[encoder setComputePipelineState:pipeline];
|
3462
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
3463
|
+
[encoder setBuffer: h_dst offset:0 atIndex:1];
|
3464
|
+
[encoder setBuffer: h_ids offset:0 atIndex:2];
|
3465
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
3466
|
+
|
3467
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
3468
|
+
}
|
3469
|
+
} else {
|
2668
3470
|
id<MTLComputePipelineState> pipeline = nil;
|
2669
3471
|
|
3472
|
+
int nsg = 0; // number of simdgroups
|
3473
|
+
int nr0 = 0; // number of src0 rows per simdgroup
|
3474
|
+
int nr1 = 1; // number of src1 rows per threadgroup
|
3475
|
+
|
3476
|
+
size_t smem = 0; // shared memory
|
3477
|
+
|
2670
3478
|
// use custom matrix x vector kernel
|
2671
3479
|
switch (src0t) {
|
2672
3480
|
case GGML_TYPE_F32:
|
2673
3481
|
{
|
2674
3482
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
3483
|
+
nsg = 1;
|
3484
|
+
nr0 = 1;
|
2675
3485
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
|
2676
3486
|
} break;
|
2677
3487
|
case GGML_TYPE_F16:
|
2678
3488
|
{
|
2679
3489
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
2680
|
-
|
2681
|
-
|
3490
|
+
nsg = 1;
|
3491
|
+
nr0 = 1;
|
2682
3492
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
|
2683
3493
|
} break;
|
2684
3494
|
case GGML_TYPE_BF16:
|
2685
3495
|
{
|
2686
3496
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
2687
|
-
|
2688
|
-
|
3497
|
+
nsg = 1;
|
3498
|
+
nr0 = 1;
|
2689
3499
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
|
2690
3500
|
} break;
|
2691
3501
|
case GGML_TYPE_Q4_0:
|
2692
3502
|
{
|
2693
|
-
|
2694
|
-
|
3503
|
+
nsg = N_SG_Q4_0;
|
3504
|
+
nr0 = N_R0_Q4_0;
|
2695
3505
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
|
2696
3506
|
} break;
|
2697
3507
|
case GGML_TYPE_Q4_1:
|
2698
3508
|
{
|
2699
|
-
|
2700
|
-
|
3509
|
+
nsg = N_SG_Q4_1;
|
3510
|
+
nr0 = N_R0_Q4_1;
|
2701
3511
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
|
2702
3512
|
} break;
|
2703
3513
|
case GGML_TYPE_Q5_0:
|
2704
3514
|
{
|
2705
|
-
|
2706
|
-
|
3515
|
+
nsg = N_SG_Q5_0;
|
3516
|
+
nr0 = N_R0_Q5_0;
|
2707
3517
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
|
2708
3518
|
} break;
|
2709
3519
|
case GGML_TYPE_Q5_1:
|
2710
3520
|
{
|
2711
|
-
|
2712
|
-
|
3521
|
+
nsg = N_SG_Q5_1;
|
3522
|
+
nr0 = N_R0_Q5_1;
|
2713
3523
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
|
2714
3524
|
} break;
|
2715
3525
|
case GGML_TYPE_Q8_0:
|
2716
3526
|
{
|
2717
|
-
|
2718
|
-
|
3527
|
+
nsg = N_SG_Q8_0;
|
3528
|
+
nr0 = N_R0_Q8_0;
|
2719
3529
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
|
2720
3530
|
} break;
|
2721
3531
|
case GGML_TYPE_Q2_K:
|
2722
3532
|
{
|
2723
|
-
|
2724
|
-
|
3533
|
+
nsg = N_SG_Q2_K;
|
3534
|
+
nr0 = N_R0_Q2_K;
|
2725
3535
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
|
2726
3536
|
} break;
|
2727
3537
|
case GGML_TYPE_Q3_K:
|
2728
3538
|
{
|
2729
|
-
|
2730
|
-
|
3539
|
+
nsg = N_SG_Q3_K;
|
3540
|
+
nr0 = N_R0_Q3_K;
|
2731
3541
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
|
2732
3542
|
} break;
|
2733
3543
|
case GGML_TYPE_Q4_K:
|
2734
3544
|
{
|
2735
|
-
|
2736
|
-
|
3545
|
+
nsg = N_SG_Q4_K;
|
3546
|
+
nr0 = N_R0_Q4_K;
|
2737
3547
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
|
2738
3548
|
} break;
|
2739
3549
|
case GGML_TYPE_Q5_K:
|
2740
3550
|
{
|
2741
|
-
|
2742
|
-
|
3551
|
+
nsg = N_SG_Q5_K;
|
3552
|
+
nr0 = N_R0_Q5_K;
|
2743
3553
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
|
2744
3554
|
} break;
|
2745
3555
|
case GGML_TYPE_Q6_K:
|
2746
3556
|
{
|
2747
|
-
|
2748
|
-
|
3557
|
+
nsg = N_SG_Q6_K;
|
3558
|
+
nr0 = N_R0_Q6_K;
|
2749
3559
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
|
2750
3560
|
} break;
|
2751
3561
|
case GGML_TYPE_IQ2_XXS:
|
2752
3562
|
{
|
2753
|
-
|
2754
|
-
|
3563
|
+
nsg = N_SG_IQ2_XXS;
|
3564
|
+
nr0 = N_R0_IQ2_XXS;
|
3565
|
+
smem = 256*8+128;
|
2755
3566
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
|
2756
3567
|
} break;
|
2757
3568
|
case GGML_TYPE_IQ2_XS:
|
2758
3569
|
{
|
2759
|
-
|
2760
|
-
|
3570
|
+
nsg = N_SG_IQ2_XS;
|
3571
|
+
nr0 = N_R0_IQ2_XS;
|
3572
|
+
smem = 512*8+128;
|
2761
3573
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
|
2762
3574
|
} break;
|
2763
3575
|
case GGML_TYPE_IQ3_XXS:
|
2764
3576
|
{
|
2765
|
-
|
2766
|
-
|
3577
|
+
nsg = N_SG_IQ3_XXS;
|
3578
|
+
nr0 = N_R0_IQ3_XXS;
|
3579
|
+
smem = 256*4+128;
|
2767
3580
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
|
2768
3581
|
} break;
|
2769
3582
|
case GGML_TYPE_IQ3_S:
|
2770
3583
|
{
|
2771
|
-
|
2772
|
-
|
3584
|
+
nsg = N_SG_IQ3_S;
|
3585
|
+
nr0 = N_R0_IQ3_S;
|
3586
|
+
smem = 512*4;
|
2773
3587
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
|
2774
3588
|
} break;
|
2775
3589
|
case GGML_TYPE_IQ2_S:
|
2776
3590
|
{
|
2777
|
-
|
2778
|
-
|
3591
|
+
nsg = N_SG_IQ2_S;
|
3592
|
+
nr0 = N_R0_IQ2_S;
|
2779
3593
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
|
2780
3594
|
} break;
|
2781
3595
|
case GGML_TYPE_IQ1_S:
|
2782
3596
|
{
|
2783
|
-
|
2784
|
-
|
3597
|
+
nsg = N_SG_IQ1_S;
|
3598
|
+
nr0 = N_R0_IQ1_S;
|
2785
3599
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
|
2786
3600
|
} break;
|
2787
3601
|
case GGML_TYPE_IQ1_M:
|
2788
3602
|
{
|
2789
|
-
|
2790
|
-
|
3603
|
+
nsg = N_SG_IQ1_M;
|
3604
|
+
nr0 = N_R0_IQ1_M;
|
2791
3605
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
|
2792
3606
|
} break;
|
2793
3607
|
case GGML_TYPE_IQ4_NL:
|
2794
3608
|
{
|
2795
|
-
|
2796
|
-
|
3609
|
+
nsg = N_SG_IQ4_NL;
|
3610
|
+
nr0 = N_R0_IQ4_NL;
|
3611
|
+
smem = 32*sizeof(float);
|
2797
3612
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
|
2798
3613
|
} break;
|
2799
3614
|
case GGML_TYPE_IQ4_XS:
|
2800
3615
|
{
|
2801
|
-
|
2802
|
-
|
3616
|
+
nsg = N_SG_IQ4_XS;
|
3617
|
+
nr0 = N_R0_IQ4_XS;
|
3618
|
+
smem = 32*sizeof(float);
|
2803
3619
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
2804
3620
|
} break;
|
2805
3621
|
default:
|
@@ -2810,7 +3626,7 @@ static void ggml_metal_encode_node(
|
|
2810
3626
|
};
|
2811
3627
|
|
2812
3628
|
if (ggml_is_quantized(src0t)) {
|
2813
|
-
GGML_ASSERT(ne00 >=
|
3629
|
+
GGML_ASSERT(ne00 >= nsg*nr0);
|
2814
3630
|
}
|
2815
3631
|
|
2816
3632
|
ggml_metal_kargs_mul_mv_id args = {
|
@@ -2843,43 +3659,12 @@ static void ggml_metal_encode_node(
|
|
2843
3659
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
|
2844
3660
|
|
2845
3661
|
const int64_t _ne1 = 1;
|
2846
|
-
const
|
3662
|
+
const int64_t ne123 = ne20*ne21;
|
2847
3663
|
|
2848
|
-
if (
|
2849
|
-
|
2850
|
-
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
|
2851
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2852
|
-
}
|
2853
|
-
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
2854
|
-
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
2855
|
-
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
2856
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2857
|
-
}
|
2858
|
-
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
|
2859
|
-
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
2860
|
-
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
2861
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2862
|
-
}
|
2863
|
-
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
|
2864
|
-
const int mem_size = 32*sizeof(float);
|
2865
|
-
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
2866
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2867
|
-
}
|
2868
|
-
else if (src0t == GGML_TYPE_Q4_K) {
|
2869
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2870
|
-
}
|
2871
|
-
else if (src0t == GGML_TYPE_Q3_K) {
|
2872
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2873
|
-
}
|
2874
|
-
else if (src0t == GGML_TYPE_Q5_K) {
|
2875
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2876
|
-
}
|
2877
|
-
else if (src0t == GGML_TYPE_Q6_K) {
|
2878
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2879
|
-
} else {
|
2880
|
-
const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
|
2881
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
3664
|
+
if (smem > 0) {
|
3665
|
+
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
2882
3666
|
}
|
3667
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
2883
3668
|
}
|
2884
3669
|
} break;
|
2885
3670
|
case GGML_OP_GET_ROWS:
|
@@ -2913,19 +3698,22 @@ static void ggml_metal_encode_node(
|
|
2913
3698
|
default: GGML_ABORT("not implemented");
|
2914
3699
|
}
|
2915
3700
|
|
2916
|
-
|
3701
|
+
ggml_metal_kargs_get_rows args = {
|
3702
|
+
/*.ne00 =*/ ne00,
|
3703
|
+
/*.nb01 =*/ nb01,
|
3704
|
+
/*.nb02 =*/ nb02,
|
3705
|
+
/*.ne10 =*/ ne10,
|
3706
|
+
/*.nb10 =*/ nb10,
|
3707
|
+
/*.nb11 =*/ nb11,
|
3708
|
+
/*.nb1 =*/ nb1,
|
3709
|
+
/*.nb2 =*/ nb2,
|
3710
|
+
};
|
3711
|
+
|
2917
3712
|
[encoder setComputePipelineState:pipeline];
|
2918
3713
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2919
3714
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
2920
3715
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
2921
|
-
[encoder setBytes:&
|
2922
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
2923
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
|
2924
|
-
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
|
2925
|
-
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
|
2926
|
-
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
|
2927
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
|
2928
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
|
3716
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
2929
3717
|
|
2930
3718
|
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
2931
3719
|
} break;
|
@@ -2963,6 +3751,42 @@ static void ggml_metal_encode_node(
|
|
2963
3751
|
|
2964
3752
|
const int64_t nrows = ggml_nrows(src0);
|
2965
3753
|
|
3754
|
+
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
3755
|
+
} break;
|
3756
|
+
case GGML_OP_L2_NORM:
|
3757
|
+
{
|
3758
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
3759
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
3760
|
+
|
3761
|
+
float eps;
|
3762
|
+
memcpy(&eps, dst->op_params, sizeof(float));
|
3763
|
+
|
3764
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline;
|
3765
|
+
|
3766
|
+
int nth = 32; // SIMD width
|
3767
|
+
|
3768
|
+
while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
3769
|
+
nth *= 2;
|
3770
|
+
}
|
3771
|
+
|
3772
|
+
nth = MIN(nth, ne00/4);
|
3773
|
+
|
3774
|
+
ggml_metal_kargs_l2_norm args = {
|
3775
|
+
/*.ne00 =*/ ne00,
|
3776
|
+
/*.ne00_4 =*/ ne00/4,
|
3777
|
+
/*.nb01 =*/ nb01,
|
3778
|
+
/*.eps =*/ eps,
|
3779
|
+
};
|
3780
|
+
|
3781
|
+
[encoder setComputePipelineState:pipeline];
|
3782
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
3783
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
3784
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
3785
|
+
|
3786
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
3787
|
+
|
3788
|
+
const int64_t nrows = ggml_nrows(src0);
|
3789
|
+
|
2966
3790
|
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2967
3791
|
} break;
|
2968
3792
|
case GGML_OP_GROUP_NORM:
|
@@ -2982,18 +3806,21 @@ static void ggml_metal_encode_node(
|
|
2982
3806
|
|
2983
3807
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
|
2984
3808
|
|
2985
|
-
|
3809
|
+
ggml_metal_kargs_group_norm args = {
|
3810
|
+
/*.ne00 =*/ ne00,
|
3811
|
+
/*.ne01 =*/ ne01,
|
3812
|
+
/*.ne02 =*/ ne02,
|
3813
|
+
/*.nb00 =*/ nb00,
|
3814
|
+
/*.nb01 =*/ nb01,
|
3815
|
+
/*.nb02 =*/ nb02,
|
3816
|
+
/*.n_groups =*/ n_groups,
|
3817
|
+
/*.eps =*/ eps,
|
3818
|
+
};
|
3819
|
+
|
2986
3820
|
[encoder setComputePipelineState:pipeline];
|
2987
3821
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2988
3822
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2989
|
-
[encoder setBytes:&
|
2990
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
2991
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
2992
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
|
2993
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
|
2994
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
|
2995
|
-
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
|
2996
|
-
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
|
3823
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
2997
3824
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
2998
3825
|
|
2999
3826
|
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
@@ -3036,6 +3863,7 @@ static void ggml_metal_encode_node(
|
|
3036
3863
|
} break;
|
3037
3864
|
case GGML_OP_ROPE:
|
3038
3865
|
{
|
3866
|
+
|
3039
3867
|
// make sure we have one or more position id(ne10) per token(ne02)
|
3040
3868
|
GGML_ASSERT(ne10 % ne02 == 0);
|
3041
3869
|
GGML_ASSERT(ne10 >= ne02);
|
@@ -3062,20 +3890,42 @@ static void ggml_metal_encode_node(
|
|
3062
3890
|
memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
|
3063
3891
|
memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
|
3064
3892
|
|
3065
|
-
const bool is_neox
|
3893
|
+
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
3894
|
+
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
3895
|
+
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
3896
|
+
|
3897
|
+
// mrope
|
3898
|
+
const int sect_0 = ((const int32_t *) dst->op_params)[11];
|
3899
|
+
const int sect_1 = ((const int32_t *) dst->op_params)[12];
|
3900
|
+
const int sect_2 = ((const int32_t *) dst->op_params)[13];
|
3901
|
+
const int sect_3 = ((const int32_t *) dst->op_params)[14];
|
3066
3902
|
|
3067
3903
|
id<MTLComputePipelineState> pipeline = nil;
|
3068
3904
|
|
3069
|
-
if (
|
3905
|
+
if (is_neox) {
|
3070
3906
|
switch (src0->type) {
|
3071
|
-
case GGML_TYPE_F32: pipeline = ctx->kernels[
|
3072
|
-
case GGML_TYPE_F16: pipeline = ctx->kernels[
|
3907
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
|
3908
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
|
3909
|
+
default: GGML_ABORT("fatal error");
|
3910
|
+
};
|
3911
|
+
} else if (is_mrope && !is_vision) {
|
3912
|
+
GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
|
3913
|
+
switch (src0->type) {
|
3914
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
|
3915
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
|
3916
|
+
default: GGML_ABORT("fatal error");
|
3917
|
+
};
|
3918
|
+
} else if (is_vision) {
|
3919
|
+
GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
|
3920
|
+
switch (src0->type) {
|
3921
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
|
3922
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
|
3073
3923
|
default: GGML_ABORT("fatal error");
|
3074
3924
|
};
|
3075
3925
|
} else {
|
3076
3926
|
switch (src0->type) {
|
3077
|
-
case GGML_TYPE_F32: pipeline = ctx->kernels[
|
3078
|
-
case GGML_TYPE_F16: pipeline = ctx->kernels[
|
3927
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
|
3928
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
|
3079
3929
|
default: GGML_ABORT("fatal error");
|
3080
3930
|
};
|
3081
3931
|
}
|
@@ -3106,6 +3956,10 @@ static void ggml_metal_encode_node(
|
|
3106
3956
|
/*.attn_factor =*/ attn_factor,
|
3107
3957
|
/*.beta_fast =*/ beta_fast,
|
3108
3958
|
/*.beta_slow =*/ beta_slow,
|
3959
|
+
/* sect_0 =*/ sect_0,
|
3960
|
+
/* sect_1 =*/ sect_1,
|
3961
|
+
/* sect_2 =*/ sect_2,
|
3962
|
+
/* sect_3 =*/ sect_3,
|
3109
3963
|
};
|
3110
3964
|
|
3111
3965
|
[encoder setComputePipelineState:pipeline];
|
@@ -3151,8 +4005,8 @@ static void ggml_metal_encode_node(
|
|
3151
4005
|
|
3152
4006
|
const int32_t CHW = IC * KH * KW;
|
3153
4007
|
|
3154
|
-
const
|
3155
|
-
const
|
4008
|
+
const uint64_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
|
4009
|
+
const uint64_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
3156
4010
|
|
3157
4011
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
|
3158
4012
|
|
@@ -3174,27 +4028,30 @@ static void ggml_metal_encode_node(
|
|
3174
4028
|
default: GGML_ABORT("fatal error");
|
3175
4029
|
};
|
3176
4030
|
|
3177
|
-
|
4031
|
+
ggml_metal_kargs_im2col args = {
|
4032
|
+
/*.ofs0 =*/ ofs0,
|
4033
|
+
/*.ofs1 =*/ ofs1,
|
4034
|
+
/*.IW =*/ IW,
|
4035
|
+
/*.IH =*/ IH,
|
4036
|
+
/*.CHW =*/ CHW,
|
4037
|
+
/*.s0 =*/ s0,
|
4038
|
+
/*.s1 =*/ s1,
|
4039
|
+
/*.p0 =*/ p0,
|
4040
|
+
/*.p1 =*/ p1,
|
4041
|
+
/*.d0 =*/ d0,
|
4042
|
+
/*.d1 =*/ d1,
|
4043
|
+
/*.N =*/ N,
|
4044
|
+
/*.KH =*/ KH,
|
4045
|
+
/*.KW =*/ KW,
|
4046
|
+
/*.KHW =*/ KH * KW,
|
4047
|
+
};
|
4048
|
+
|
3178
4049
|
[encoder setComputePipelineState:pipeline];
|
3179
4050
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
3180
4051
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
3181
|
-
[encoder setBytes:&
|
3182
|
-
[encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
|
3183
|
-
[encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
|
3184
|
-
[encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
|
3185
|
-
[encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
|
3186
|
-
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
|
3187
|
-
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
|
3188
|
-
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
|
3189
|
-
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
|
3190
|
-
[encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
|
3191
|
-
[encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
|
4052
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
3192
4053
|
|
3193
4054
|
if (is_gt_mttpt) {
|
3194
|
-
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
|
3195
|
-
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
|
3196
|
-
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
|
3197
|
-
|
3198
4055
|
const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
|
3199
4056
|
|
3200
4057
|
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
|
@@ -3234,16 +4091,20 @@ static void ggml_metal_encode_node(
|
|
3234
4091
|
default: GGML_ABORT("fatal error");
|
3235
4092
|
};
|
3236
4093
|
|
4094
|
+
ggml_metal_kargs_conv_transpose_1d args = {
|
4095
|
+
/*.IC =*/ IC,
|
4096
|
+
/*.IL =*/ IL,
|
4097
|
+
/*.K =*/ K,
|
4098
|
+
/*.s0 =*/ s0,
|
4099
|
+
/*.nb0 =*/ nb0,
|
4100
|
+
/*.nb1 =*/ nb1,
|
4101
|
+
};
|
4102
|
+
|
3237
4103
|
[encoder setComputePipelineState:pipeline];
|
3238
4104
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
3239
4105
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
3240
4106
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
3241
|
-
[encoder setBytes:&
|
3242
|
-
[encoder setBytes:&IL length:sizeof( int32_t) atIndex:4];
|
3243
|
-
[encoder setBytes:&K length:sizeof( int32_t) atIndex:5];
|
3244
|
-
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6];
|
3245
|
-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7];
|
3246
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8];
|
4107
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
3247
4108
|
|
3248
4109
|
[encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
3249
4110
|
} break;
|
@@ -3258,30 +4119,33 @@ static void ggml_metal_encode_node(
|
|
3258
4119
|
|
3259
4120
|
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
|
3260
4121
|
|
3261
|
-
|
4122
|
+
ggml_metal_kargs_upscale args = {
|
4123
|
+
/*.ne00 =*/ ne00,
|
4124
|
+
/*.ne01 =*/ ne01,
|
4125
|
+
/*.ne02 =*/ ne02,
|
4126
|
+
/*.ne03 =*/ ne03,
|
4127
|
+
/*.nb00 =*/ nb00,
|
4128
|
+
/*.nb01 =*/ nb01,
|
4129
|
+
/*.nb02 =*/ nb02,
|
4130
|
+
/*.nb03 =*/ nb03,
|
4131
|
+
/*.ne0 =*/ ne0,
|
4132
|
+
/*.ne1 =*/ ne1,
|
4133
|
+
/*.ne2 =*/ ne2,
|
4134
|
+
/*.ne3 =*/ ne3,
|
4135
|
+
/*.nb0 =*/ nb0,
|
4136
|
+
/*.nb1 =*/ nb1,
|
4137
|
+
/*.nb2 =*/ nb2,
|
4138
|
+
/*.nb3 =*/ nb3,
|
4139
|
+
/*.sf0 =*/ sf0,
|
4140
|
+
/*.sf1 =*/ sf1,
|
4141
|
+
/*.sf2 =*/ sf2,
|
4142
|
+
/*.sf3 =*/ sf3
|
4143
|
+
};
|
4144
|
+
|
3262
4145
|
[encoder setComputePipelineState:pipeline];
|
3263
4146
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
3264
4147
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
3265
|
-
[encoder setBytes:&
|
3266
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
3267
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
3268
|
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
3269
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
3270
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
3271
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
3272
|
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
3273
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
3274
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
3275
|
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
3276
|
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
3277
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
3278
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
3279
|
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
3280
|
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
3281
|
-
[encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
|
3282
|
-
[encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
|
3283
|
-
[encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
|
3284
|
-
[encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
|
4148
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
3285
4149
|
|
3286
4150
|
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
3287
4151
|
|
@@ -3293,26 +4157,29 @@ static void ggml_metal_encode_node(
|
|
3293
4157
|
|
3294
4158
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
|
3295
4159
|
|
3296
|
-
|
4160
|
+
ggml_metal_kargs_pad args = {
|
4161
|
+
/*.ne00 =*/ ne00,
|
4162
|
+
/*.ne01 =*/ ne01,
|
4163
|
+
/*.ne02 =*/ ne02,
|
4164
|
+
/*.ne03 =*/ ne03,
|
4165
|
+
/*.nb00 =*/ nb00,
|
4166
|
+
/*.nb01 =*/ nb01,
|
4167
|
+
/*.nb02 =*/ nb02,
|
4168
|
+
/*.nb03 =*/ nb03,
|
4169
|
+
/*.ne0 =*/ ne0,
|
4170
|
+
/*.ne1 =*/ ne1,
|
4171
|
+
/*.ne2 =*/ ne2,
|
4172
|
+
/*.ne3 =*/ ne3,
|
4173
|
+
/*.nb0 =*/ nb0,
|
4174
|
+
/*.nb1 =*/ nb1,
|
4175
|
+
/*.nb2 =*/ nb2,
|
4176
|
+
/*.nb3 =*/ nb3
|
4177
|
+
};
|
4178
|
+
|
3297
4179
|
[encoder setComputePipelineState:pipeline];
|
3298
4180
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
3299
4181
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
3300
|
-
[encoder setBytes:&
|
3301
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
3302
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
3303
|
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
3304
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
3305
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
3306
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
3307
|
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
3308
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
3309
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
3310
|
-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
3311
|
-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
3312
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
3313
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
3314
|
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
3315
|
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
4182
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
3316
4183
|
|
3317
4184
|
const int nth = MIN(1024, ne0);
|
3318
4185
|
|
@@ -3327,24 +4194,31 @@ static void ggml_metal_encode_node(
|
|
3327
4194
|
|
3328
4195
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
|
3329
4196
|
|
4197
|
+
ggml_metal_kargs_pad_reflect_1d args = {
|
4198
|
+
/*.ne00 =*/ ne00,
|
4199
|
+
/*.ne01 =*/ ne01,
|
4200
|
+
/*.ne02 =*/ ne02,
|
4201
|
+
/*.ne03 =*/ ne03,
|
4202
|
+
/*.nb00 =*/ nb00,
|
4203
|
+
/*.nb01 =*/ nb01,
|
4204
|
+
/*.nb02 =*/ nb02,
|
4205
|
+
/*.nb03 =*/ nb03,
|
4206
|
+
/*.ne0 =*/ ne0,
|
4207
|
+
/*.ne1 =*/ ne1,
|
4208
|
+
/*.ne2 =*/ ne2,
|
4209
|
+
/*.ne3 =*/ ne3,
|
4210
|
+
/*.nb0 =*/ nb0,
|
4211
|
+
/*.nb1 =*/ nb1,
|
4212
|
+
/*.nb2 =*/ nb2,
|
4213
|
+
/*.nb3 =*/ nb3,
|
4214
|
+
/*.p0 =*/ p0,
|
4215
|
+
/*.p1 =*/ p1
|
4216
|
+
};
|
4217
|
+
|
3330
4218
|
[encoder setComputePipelineState:pipeline];
|
3331
4219
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
3332
4220
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
3333
|
-
[encoder setBytes:&
|
3334
|
-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
3335
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
3336
|
-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
3337
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:6];
|
3338
|
-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
3339
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
3340
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
3341
|
-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
3342
|
-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:11];
|
3343
|
-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:12];
|
3344
|
-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:13];
|
3345
|
-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:14];
|
3346
|
-
[encoder setBytes:&p0 length:sizeof(p0) atIndex:15];
|
3347
|
-
[encoder setBytes:&p1 length:sizeof(p1) atIndex:16];
|
4221
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
3348
4222
|
|
3349
4223
|
const int nth = MIN(1024, ne0);
|
3350
4224
|
|
@@ -3362,12 +4236,15 @@ static void ggml_metal_encode_node(
|
|
3362
4236
|
|
3363
4237
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
|
3364
4238
|
|
3365
|
-
|
4239
|
+
ggml_metal_kargs_arange args = {
|
4240
|
+
/*.ne0 =*/ ne0,
|
4241
|
+
/*.start =*/ start,
|
4242
|
+
/*.step =*/ step
|
4243
|
+
};
|
4244
|
+
|
3366
4245
|
[encoder setComputePipelineState:pipeline];
|
3367
|
-
[encoder setBuffer:id_dst offset:offs_dst
|
3368
|
-
[encoder setBytes:&
|
3369
|
-
[encoder setBytes:&start length:sizeof(start) atIndex:2];
|
3370
|
-
[encoder setBytes:&step length:sizeof(step) atIndex:3];
|
4246
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
|
4247
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:1];
|
3371
4248
|
|
3372
4249
|
const int nth = MIN(1024, ne0);
|
3373
4250
|
|
@@ -3384,13 +4261,16 @@ static void ggml_metal_encode_node(
|
|
3384
4261
|
|
3385
4262
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
|
3386
4263
|
|
3387
|
-
|
4264
|
+
ggml_metal_kargs_timestep_embedding args = {
|
4265
|
+
/*.nb1 =*/ nb1,
|
4266
|
+
/*.dim =*/ dim,
|
4267
|
+
/*.max_period =*/ max_period
|
4268
|
+
};
|
4269
|
+
|
3388
4270
|
[encoder setComputePipelineState:pipeline];
|
3389
4271
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
3390
4272
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
3391
|
-
[encoder setBytes:&
|
3392
|
-
[encoder setBytes:&dim length:sizeof(dim) atIndex:3];
|
3393
|
-
[encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
|
4273
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
3394
4274
|
|
3395
4275
|
const int nth = MIN(1024, half);
|
3396
4276
|
|
@@ -3423,12 +4303,15 @@ static void ggml_metal_encode_node(
|
|
3423
4303
|
default: GGML_ABORT("fatal error");
|
3424
4304
|
};
|
3425
4305
|
|
3426
|
-
|
4306
|
+
ggml_metal_kargs_argsort args = {
|
4307
|
+
/*.ncols =*/ ne00,
|
4308
|
+
/*.ncols_pad =*/ ne00_padded
|
4309
|
+
};
|
4310
|
+
|
3427
4311
|
[encoder setComputePipelineState:pipeline];
|
3428
|
-
[encoder setBuffer:id_src0
|
3429
|
-
[encoder setBuffer:id_dst
|
3430
|
-
[encoder setBytes:&
|
3431
|
-
[encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
|
4312
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
4313
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
4314
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
3432
4315
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
3433
4316
|
|
3434
4317
|
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
|
@@ -3442,11 +4325,14 @@ static void ggml_metal_encode_node(
|
|
3442
4325
|
|
3443
4326
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
|
3444
4327
|
|
3445
|
-
|
4328
|
+
ggml_metal_kargs_leaky_relu args = {
|
4329
|
+
/*.slope =*/ slope
|
4330
|
+
};
|
4331
|
+
|
3446
4332
|
[encoder setComputePipelineState:pipeline];
|
3447
4333
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
3448
4334
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
3449
|
-
[encoder setBytes:&
|
4335
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
3450
4336
|
|
3451
4337
|
const int64_t n = ggml_nelements(dst);
|
3452
4338
|
|
@@ -3460,7 +4346,9 @@ static void ggml_metal_encode_node(
|
|
3460
4346
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
3461
4347
|
GGML_ASSERT(src1->type == src2->type);
|
3462
4348
|
|
3463
|
-
GGML_ASSERT(ggml_are_same_shape (src1, src2));
|
4349
|
+
//GGML_ASSERT(ggml_are_same_shape (src1, src2));
|
4350
|
+
GGML_ASSERT(ne11 == ne21);
|
4351
|
+
GGML_ASSERT(ne12 == ne22);
|
3464
4352
|
|
3465
4353
|
struct ggml_tensor * src3 = node->src[3];
|
3466
4354
|
|
@@ -3507,125 +4395,175 @@ static void ggml_metal_encode_node(
|
|
3507
4395
|
|
3508
4396
|
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
|
3509
4397
|
// for now avoiding mainly to keep the number of templates/kernels a bit lower
|
3510
|
-
|
4398
|
+
// these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
|
4399
|
+
if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
|
3511
4400
|
switch (src1->type) {
|
3512
4401
|
case GGML_TYPE_F16:
|
3513
4402
|
{
|
3514
|
-
|
3515
|
-
|
3516
|
-
|
3517
|
-
|
3518
|
-
|
3519
|
-
|
3520
|
-
|
3521
|
-
|
3522
|
-
|
3523
|
-
|
3524
|
-
|
3525
|
-
|
3526
|
-
|
4403
|
+
if (ne00 == 192 && ne20 == 128) {
|
4404
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
|
4405
|
+
} else if (ne00 == 576 && ne20 == 512) {
|
4406
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
|
4407
|
+
} else {
|
4408
|
+
switch (ne00) {
|
4409
|
+
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
4410
|
+
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
4411
|
+
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
4412
|
+
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
4413
|
+
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
4414
|
+
case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192].pipeline; break;
|
4415
|
+
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
4416
|
+
default:
|
4417
|
+
{
|
4418
|
+
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
4419
|
+
GGML_LOG_ERROR("add template specialization for this size\n");
|
4420
|
+
GGML_ABORT("add template specialization for this size");
|
4421
|
+
}
|
4422
|
+
}
|
3527
4423
|
}
|
3528
4424
|
} break;
|
3529
4425
|
case GGML_TYPE_BF16:
|
3530
4426
|
{
|
3531
|
-
|
3532
|
-
|
3533
|
-
|
3534
|
-
|
3535
|
-
|
3536
|
-
|
3537
|
-
|
3538
|
-
|
3539
|
-
|
3540
|
-
|
3541
|
-
|
3542
|
-
|
3543
|
-
|
4427
|
+
if (ne00 == 192 && ne20 == 128) {
|
4428
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
|
4429
|
+
} else if (ne00 == 576 && ne20 == 512) {
|
4430
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
|
4431
|
+
} else {
|
4432
|
+
switch (ne00) {
|
4433
|
+
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
|
4434
|
+
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
|
4435
|
+
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
|
4436
|
+
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
|
4437
|
+
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
|
4438
|
+
case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192].pipeline; break;
|
4439
|
+
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
|
4440
|
+
default:
|
4441
|
+
{
|
4442
|
+
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
4443
|
+
GGML_LOG_ERROR("add template specialization for this size\n");
|
4444
|
+
GGML_ABORT("add template specialization for this size");
|
4445
|
+
}
|
4446
|
+
}
|
3544
4447
|
}
|
3545
4448
|
} break;
|
3546
4449
|
case GGML_TYPE_Q4_0:
|
3547
4450
|
{
|
3548
|
-
|
3549
|
-
|
3550
|
-
|
3551
|
-
|
3552
|
-
|
3553
|
-
|
3554
|
-
|
3555
|
-
|
3556
|
-
|
3557
|
-
|
3558
|
-
|
3559
|
-
|
3560
|
-
|
4451
|
+
if (ne00 == 192 && ne20 == 128) {
|
4452
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
|
4453
|
+
} else if (ne00 == 576 && ne20 == 512) {
|
4454
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
|
4455
|
+
} else {
|
4456
|
+
switch (ne00) {
|
4457
|
+
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
|
4458
|
+
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
|
4459
|
+
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
|
4460
|
+
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
|
4461
|
+
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
|
4462
|
+
case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192].pipeline; break;
|
4463
|
+
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
|
4464
|
+
default:
|
4465
|
+
{
|
4466
|
+
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
4467
|
+
GGML_LOG_ERROR("add template specialization for this size\n");
|
4468
|
+
GGML_ABORT("add template specialization for this size");
|
4469
|
+
}
|
4470
|
+
}
|
3561
4471
|
}
|
3562
4472
|
} break;
|
3563
4473
|
case GGML_TYPE_Q4_1:
|
3564
4474
|
{
|
3565
|
-
|
3566
|
-
|
3567
|
-
|
3568
|
-
|
3569
|
-
|
3570
|
-
|
3571
|
-
|
3572
|
-
|
3573
|
-
|
3574
|
-
|
3575
|
-
|
3576
|
-
|
3577
|
-
|
4475
|
+
if (ne00 == 192 && ne20 == 128) {
|
4476
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
|
4477
|
+
} else if (ne00 == 576 && ne20 == 512) {
|
4478
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
|
4479
|
+
} else {
|
4480
|
+
switch (ne00) {
|
4481
|
+
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
|
4482
|
+
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
|
4483
|
+
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
|
4484
|
+
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
|
4485
|
+
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
|
4486
|
+
case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192].pipeline; break;
|
4487
|
+
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
|
4488
|
+
default:
|
4489
|
+
{
|
4490
|
+
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
4491
|
+
GGML_LOG_ERROR("add template specialization for this size\n");
|
4492
|
+
GGML_ABORT("add template specialization for this size");
|
4493
|
+
}
|
4494
|
+
}
|
3578
4495
|
}
|
3579
4496
|
} break;
|
3580
4497
|
case GGML_TYPE_Q5_0:
|
3581
4498
|
{
|
3582
|
-
|
3583
|
-
|
3584
|
-
|
3585
|
-
|
3586
|
-
|
3587
|
-
|
3588
|
-
|
3589
|
-
|
3590
|
-
|
3591
|
-
|
3592
|
-
|
3593
|
-
|
3594
|
-
|
4499
|
+
if (ne00 == 192 && ne20 == 128) {
|
4500
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
|
4501
|
+
} else if (ne00 == 576 && ne20 == 512) {
|
4502
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
|
4503
|
+
} else {
|
4504
|
+
switch (ne00) {
|
4505
|
+
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
|
4506
|
+
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
|
4507
|
+
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
|
4508
|
+
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
|
4509
|
+
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
|
4510
|
+
case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192].pipeline; break;
|
4511
|
+
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
|
4512
|
+
default:
|
4513
|
+
{
|
4514
|
+
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
4515
|
+
GGML_LOG_ERROR("add template specialization for this size\n");
|
4516
|
+
GGML_ABORT("add template specialization for this size");
|
4517
|
+
}
|
4518
|
+
}
|
3595
4519
|
}
|
3596
4520
|
} break;
|
3597
4521
|
case GGML_TYPE_Q5_1:
|
3598
4522
|
{
|
3599
|
-
|
3600
|
-
|
3601
|
-
|
3602
|
-
|
3603
|
-
|
3604
|
-
|
3605
|
-
|
3606
|
-
|
3607
|
-
|
3608
|
-
|
3609
|
-
|
3610
|
-
|
3611
|
-
|
4523
|
+
if (ne00 == 192 && ne20 == 128) {
|
4524
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
|
4525
|
+
} else if (ne00 == 576 && ne20 == 512) {
|
4526
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
|
4527
|
+
} else {
|
4528
|
+
switch (ne00) {
|
4529
|
+
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
|
4530
|
+
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
|
4531
|
+
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
|
4532
|
+
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
|
4533
|
+
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
|
4534
|
+
case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192].pipeline; break;
|
4535
|
+
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
|
4536
|
+
default:
|
4537
|
+
{
|
4538
|
+
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
4539
|
+
GGML_LOG_ERROR("add template specialization for this size\n");
|
4540
|
+
GGML_ABORT("add template specialization for this size");
|
4541
|
+
}
|
4542
|
+
}
|
3612
4543
|
}
|
3613
4544
|
} break;
|
3614
4545
|
case GGML_TYPE_Q8_0:
|
3615
4546
|
{
|
3616
|
-
|
3617
|
-
|
3618
|
-
|
3619
|
-
|
3620
|
-
|
3621
|
-
|
3622
|
-
|
3623
|
-
|
3624
|
-
|
3625
|
-
|
3626
|
-
|
3627
|
-
|
3628
|
-
|
4547
|
+
if (ne00 == 192 && ne20 == 128) {
|
4548
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
|
4549
|
+
} else if (ne00 == 576 && ne20 == 512) {
|
4550
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
|
4551
|
+
} else {
|
4552
|
+
switch (ne00) {
|
4553
|
+
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
|
4554
|
+
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
|
4555
|
+
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
|
4556
|
+
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
|
4557
|
+
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
|
4558
|
+
case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192].pipeline; break;
|
4559
|
+
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
|
4560
|
+
default:
|
4561
|
+
{
|
4562
|
+
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
4563
|
+
GGML_LOG_ERROR("add template specialization for this size\n");
|
4564
|
+
GGML_ABORT("add template specialization for this size");
|
4565
|
+
}
|
4566
|
+
}
|
3629
4567
|
}
|
3630
4568
|
} break;
|
3631
4569
|
default:
|
@@ -3639,6 +4577,42 @@ static void ggml_metal_encode_node(
|
|
3639
4577
|
use_vec_kernel = true;
|
3640
4578
|
|
3641
4579
|
switch (ne00) {
|
4580
|
+
case 64:
|
4581
|
+
{
|
4582
|
+
switch (src1->type) {
|
4583
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break;
|
4584
|
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break;
|
4585
|
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break;
|
4586
|
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break;
|
4587
|
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break;
|
4588
|
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break;
|
4589
|
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break;
|
4590
|
+
default:
|
4591
|
+
{
|
4592
|
+
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
4593
|
+
GGML_LOG_ERROR("add template specialization for this type\n");
|
4594
|
+
GGML_ABORT("add template specialization for this type");
|
4595
|
+
}
|
4596
|
+
}
|
4597
|
+
} break;
|
4598
|
+
case 96:
|
4599
|
+
{
|
4600
|
+
switch (src1->type) {
|
4601
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break;
|
4602
|
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break;
|
4603
|
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break;
|
4604
|
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break;
|
4605
|
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break;
|
4606
|
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break;
|
4607
|
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break;
|
4608
|
+
default:
|
4609
|
+
{
|
4610
|
+
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
4611
|
+
GGML_LOG_ERROR("add template specialization for this type\n");
|
4612
|
+
GGML_ABORT("add template specialization for this type");
|
4613
|
+
}
|
4614
|
+
}
|
4615
|
+
} break;
|
3642
4616
|
case 128:
|
3643
4617
|
{
|
3644
4618
|
switch (src1->type) {
|
@@ -3657,6 +4631,42 @@ static void ggml_metal_encode_node(
|
|
3657
4631
|
}
|
3658
4632
|
}
|
3659
4633
|
} break;
|
4634
|
+
case 192:
|
4635
|
+
{
|
4636
|
+
if (ne20 == 128) {
|
4637
|
+
switch (src1->type) {
|
4638
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline; break;
|
4639
|
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128].pipeline; break;
|
4640
|
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128].pipeline; break;
|
4641
|
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128].pipeline; break;
|
4642
|
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128].pipeline; break;
|
4643
|
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128].pipeline; break;
|
4644
|
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline; break;
|
4645
|
+
default:
|
4646
|
+
{
|
4647
|
+
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
4648
|
+
GGML_LOG_ERROR("add template specialization for this type\n");
|
4649
|
+
GGML_ABORT("add template specialization for this type");
|
4650
|
+
}
|
4651
|
+
}
|
4652
|
+
} else {
|
4653
|
+
switch (src1->type) {
|
4654
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192].pipeline; break;
|
4655
|
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192].pipeline; break;
|
4656
|
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192].pipeline; break;
|
4657
|
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192].pipeline; break;
|
4658
|
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192].pipeline; break;
|
4659
|
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192].pipeline; break;
|
4660
|
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192].pipeline; break;
|
4661
|
+
default:
|
4662
|
+
{
|
4663
|
+
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
4664
|
+
GGML_LOG_ERROR("add template specialization for this type\n");
|
4665
|
+
GGML_ABORT("add template specialization for this type");
|
4666
|
+
}
|
4667
|
+
}
|
4668
|
+
}
|
4669
|
+
} break;
|
3660
4670
|
case 256:
|
3661
4671
|
{
|
3662
4672
|
switch (src1->type) {
|
@@ -3675,12 +4685,36 @@ static void ggml_metal_encode_node(
|
|
3675
4685
|
}
|
3676
4686
|
}
|
3677
4687
|
} break;
|
4688
|
+
case 576:
|
4689
|
+
{
|
4690
|
+
if (ne20 == 512) {
|
4691
|
+
switch (src1->type) {
|
4692
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break;
|
4693
|
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break;
|
4694
|
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break;
|
4695
|
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break;
|
4696
|
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break;
|
4697
|
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break;
|
4698
|
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break;
|
4699
|
+
default:
|
4700
|
+
{
|
4701
|
+
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
4702
|
+
GGML_LOG_ERROR("add template specialization for this type\n");
|
4703
|
+
GGML_ABORT("add template specialization for this type");
|
4704
|
+
}
|
4705
|
+
}
|
4706
|
+
} else {
|
4707
|
+
GGML_LOG_ERROR("unsupported size: %lld\n", ne20);
|
4708
|
+
GGML_LOG_ERROR("add template specialization for this size\n");
|
4709
|
+
GGML_ABORT("add template specialization for this size");
|
4710
|
+
}
|
4711
|
+
} break;
|
3678
4712
|
default:
|
3679
|
-
|
3680
|
-
|
3681
|
-
|
3682
|
-
|
3683
|
-
|
4713
|
+
{
|
4714
|
+
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
4715
|
+
GGML_LOG_ERROR("add template specialization for this size\n");
|
4716
|
+
GGML_ABORT("add template specialization for this size");
|
4717
|
+
}
|
3684
4718
|
}
|
3685
4719
|
}
|
3686
4720
|
|
@@ -3694,9 +4728,12 @@ static void ggml_metal_encode_node(
|
|
3694
4728
|
/*.ne11 =*/ ne11,
|
3695
4729
|
/*.ne_12_2 =*/ ne12,
|
3696
4730
|
/*.ne_12_3 =*/ ne13,
|
3697
|
-
/*.
|
3698
|
-
/*.
|
3699
|
-
/*.
|
4731
|
+
/*.nb11 =*/ nb11,
|
4732
|
+
/*.nb12 =*/ nb12,
|
4733
|
+
/*.nb13 =*/ nb13,
|
4734
|
+
/*.nb21 =*/ nb21,
|
4735
|
+
/*.nb22 =*/ nb22,
|
4736
|
+
/*.nb23 =*/ nb23,
|
3700
4737
|
/*.nb31 =*/ nb31,
|
3701
4738
|
/*.ne1 =*/ ne1,
|
3702
4739
|
/*.ne2 =*/ ne2,
|
@@ -3775,10 +4812,9 @@ static void ggml_metal_encode_node(
|
|
3775
4812
|
// ne00*(nsg)
|
3776
4813
|
// each simdgroup has a full f16 head vector in shared mem to accumulate results
|
3777
4814
|
//
|
3778
|
-
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 +
|
4815
|
+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
|
3779
4816
|
|
3780
4817
|
int64_t nsgmax = 2;
|
3781
|
-
|
3782
4818
|
while (true) {
|
3783
4819
|
const size_t smem = FATTN_SMEM(nsgmax);
|
3784
4820
|
if (smem > device.maxThreadgroupMemoryLength) {
|
@@ -3810,10 +4846,6 @@ static void ggml_metal_encode_node(
|
|
3810
4846
|
case GGML_OP_CPY:
|
3811
4847
|
case GGML_OP_CONT:
|
3812
4848
|
{
|
3813
|
-
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
3814
|
-
|
3815
|
-
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
|
3816
|
-
|
3817
4849
|
id<MTLComputePipelineState> pipeline = nil;
|
3818
4850
|
|
3819
4851
|
switch (src0t) {
|
@@ -3847,7 +4879,47 @@ static void ggml_metal_encode_node(
|
|
3847
4879
|
switch (dstt) {
|
3848
4880
|
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
|
3849
4881
|
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
|
3850
|
-
default:
|
4882
|
+
default: GGML_ABORT("not implemented");
|
4883
|
+
};
|
4884
|
+
} break;
|
4885
|
+
case GGML_TYPE_Q4_0:
|
4886
|
+
{
|
4887
|
+
switch (dstt) {
|
4888
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
|
4889
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break;
|
4890
|
+
default: GGML_ABORT("not implemented");
|
4891
|
+
};
|
4892
|
+
} break;
|
4893
|
+
case GGML_TYPE_Q4_1:
|
4894
|
+
{
|
4895
|
+
switch (dstt) {
|
4896
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
|
4897
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break;
|
4898
|
+
default: GGML_ABORT("not implemented");
|
4899
|
+
};
|
4900
|
+
} break;
|
4901
|
+
case GGML_TYPE_Q5_0:
|
4902
|
+
{
|
4903
|
+
switch (dstt) {
|
4904
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
|
4905
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break;
|
4906
|
+
default: GGML_ABORT("not implemented");
|
4907
|
+
};
|
4908
|
+
} break;
|
4909
|
+
case GGML_TYPE_Q5_1:
|
4910
|
+
{
|
4911
|
+
switch (dstt) {
|
4912
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
|
4913
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break;
|
4914
|
+
default: GGML_ABORT("not implemented");
|
4915
|
+
};
|
4916
|
+
} break;
|
4917
|
+
case GGML_TYPE_Q8_0:
|
4918
|
+
{
|
4919
|
+
switch (dstt) {
|
4920
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
|
4921
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break;
|
4922
|
+
default: GGML_ABORT("not implemented");
|
3851
4923
|
};
|
3852
4924
|
} break;
|
3853
4925
|
default: GGML_ABORT("not implemented");
|
@@ -3877,7 +4949,11 @@ static void ggml_metal_encode_node(
|
|
3877
4949
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
3878
4950
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
3879
4951
|
|
4952
|
+
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
4953
|
+
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
|
4954
|
+
|
3880
4955
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
4956
|
+
|
3881
4957
|
} break;
|
3882
4958
|
case GGML_OP_SET:
|
3883
4959
|
{
|
@@ -3982,21 +5058,24 @@ static void ggml_metal_encode_node(
|
|
3982
5058
|
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
|
3983
5059
|
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
|
3984
5060
|
|
3985
|
-
|
5061
|
+
ggml_metal_kargs_pool_2d args_pool_2d = {
|
5062
|
+
/* .k0 = */ k0,
|
5063
|
+
/* .k1 = */ k1,
|
5064
|
+
/* .s0 = */ s0,
|
5065
|
+
/* .s1 = */ s1,
|
5066
|
+
/* .p0 = */ p0,
|
5067
|
+
/* .p1 = */ p1,
|
5068
|
+
/* .IH = */ IH,
|
5069
|
+
/* .IW = */ IW,
|
5070
|
+
/* .OH = */ OH,
|
5071
|
+
/* .OW = */ OW,
|
5072
|
+
/* .parallel_elements = */ parallel_elements
|
5073
|
+
};
|
5074
|
+
|
3986
5075
|
[encoder setComputePipelineState:pipeline];
|
3987
|
-
[encoder setBuffer:id_src0 offset:offs_src0
|
3988
|
-
[encoder setBuffer:id_dst offset:offs_dst
|
3989
|
-
[encoder setBytes:&
|
3990
|
-
[encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
|
3991
|
-
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
|
3992
|
-
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
|
3993
|
-
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
|
3994
|
-
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
|
3995
|
-
[encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
|
3996
|
-
[encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
|
3997
|
-
[encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
|
3998
|
-
[encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
|
3999
|
-
[encoder setBytes:¶llel_elements length:sizeof(int64_t) atIndex:12];
|
5076
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
5077
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
5078
|
+
[encoder setBytes:&args_pool_2d length:sizeof(args_pool_2d) atIndex:2];
|
4000
5079
|
|
4001
5080
|
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
4002
5081
|
} break;
|
@@ -4031,6 +5110,8 @@ static void ggml_metal_encode_node(
|
|
4031
5110
|
GGML_ABORT("fatal error");
|
4032
5111
|
}
|
4033
5112
|
}
|
5113
|
+
|
5114
|
+
return true;
|
4034
5115
|
}
|
4035
5116
|
|
4036
5117
|
static enum ggml_status ggml_metal_graph_compute(
|
@@ -4084,25 +5165,25 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
4084
5165
|
}
|
4085
5166
|
|
4086
5167
|
// the main thread commits the first few commands immediately
|
4087
|
-
//
|
5168
|
+
// cmd_buf[n_cb]
|
4088
5169
|
{
|
4089
|
-
id<MTLCommandBuffer>
|
4090
|
-
ctx->
|
5170
|
+
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
|
5171
|
+
ctx->cmd_bufs[n_cb].obj = cmd_buf;
|
4091
5172
|
|
4092
|
-
[
|
5173
|
+
[cmd_buf enqueue];
|
4093
5174
|
ctx->encode_async(n_cb);
|
4094
5175
|
}
|
4095
5176
|
|
4096
5177
|
// prepare the rest of the command buffers asynchronously
|
4097
|
-
//
|
5178
|
+
// cmd_buf[0.. n_cb)
|
4098
5179
|
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
4099
|
-
id<MTLCommandBuffer>
|
4100
|
-
ctx->
|
5180
|
+
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
|
5181
|
+
ctx->cmd_bufs[cb_idx].obj = cmd_buf;
|
4101
5182
|
|
4102
5183
|
// always enqueue the first two command buffers
|
4103
5184
|
// enqueue all of the command buffers if we don't need to abort
|
4104
5185
|
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
4105
|
-
[
|
5186
|
+
[cmd_buf enqueue];
|
4106
5187
|
}
|
4107
5188
|
}
|
4108
5189
|
|
@@ -4111,14 +5192,14 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
4111
5192
|
// wait for completion and check status of each command buffer
|
4112
5193
|
// needed to detect if the device ran out-of-memory for example (#1881)
|
4113
5194
|
{
|
4114
|
-
id<MTLCommandBuffer>
|
4115
|
-
[
|
5195
|
+
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
|
5196
|
+
[cmd_buf waitUntilCompleted];
|
4116
5197
|
|
4117
|
-
MTLCommandBufferStatus status = [
|
5198
|
+
MTLCommandBufferStatus status = [cmd_buf status];
|
4118
5199
|
if (status != MTLCommandBufferStatusCompleted) {
|
4119
5200
|
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
|
4120
5201
|
if (status == MTLCommandBufferStatusError) {
|
4121
|
-
GGML_LOG_INFO("error: %s\n", [[
|
5202
|
+
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
|
4122
5203
|
}
|
4123
5204
|
|
4124
5205
|
return GGML_STATUS_FAILED;
|
@@ -4126,20 +5207,20 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
4126
5207
|
}
|
4127
5208
|
|
4128
5209
|
for (int i = 0; i < n_cb; ++i) {
|
4129
|
-
id<MTLCommandBuffer>
|
4130
|
-
[
|
5210
|
+
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
|
5211
|
+
[cmd_buf waitUntilCompleted];
|
4131
5212
|
|
4132
|
-
MTLCommandBufferStatus status = [
|
5213
|
+
MTLCommandBufferStatus status = [cmd_buf status];
|
4133
5214
|
if (status != MTLCommandBufferStatusCompleted) {
|
4134
5215
|
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
4135
5216
|
if (status == MTLCommandBufferStatusError) {
|
4136
|
-
GGML_LOG_INFO("error: %s\n", [[
|
5217
|
+
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
|
4137
5218
|
}
|
4138
5219
|
|
4139
5220
|
return GGML_STATUS_FAILED;
|
4140
5221
|
}
|
4141
5222
|
|
4142
|
-
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->
|
5223
|
+
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
|
4143
5224
|
if (!next_buffer) {
|
4144
5225
|
continue;
|
4145
5226
|
}
|
@@ -4176,6 +5257,8 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
|
|
4176
5257
|
for (int i = 0; i < ctx->n_buffers; i++) {
|
4177
5258
|
[ctx->buffers[i].metal release];
|
4178
5259
|
}
|
5260
|
+
|
5261
|
+
ggml_backend_metal_buffer_rset_free(ctx);
|
4179
5262
|
ggml_backend_metal_device_rel(buffer->buft->device->context);
|
4180
5263
|
|
4181
5264
|
if (ctx->owned) {
|
@@ -4198,19 +5281,19 @@ static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
|
|
4198
5281
|
static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
4199
5282
|
memset((char *)tensor->data + offset, value, size);
|
4200
5283
|
|
4201
|
-
|
5284
|
+
GGML_UNUSED(buffer);
|
4202
5285
|
}
|
4203
5286
|
|
4204
5287
|
static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
4205
5288
|
memcpy((char *)tensor->data + offset, data, size);
|
4206
5289
|
|
4207
|
-
|
5290
|
+
GGML_UNUSED(buffer);
|
4208
5291
|
}
|
4209
5292
|
|
4210
5293
|
static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
4211
5294
|
memcpy(data, (const char *)tensor->data + offset, size);
|
4212
5295
|
|
4213
|
-
|
5296
|
+
GGML_UNUSED(buffer);
|
4214
5297
|
}
|
4215
5298
|
|
4216
5299
|
static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
|
@@ -4220,7 +5303,7 @@ static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, c
|
|
4220
5303
|
}
|
4221
5304
|
return false;
|
4222
5305
|
|
4223
|
-
|
5306
|
+
GGML_UNUSED(buffer);
|
4224
5307
|
}
|
4225
5308
|
|
4226
5309
|
static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
@@ -4246,7 +5329,7 @@ static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
|
|
4246
5329
|
static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
|
4247
5330
|
return "Metal";
|
4248
5331
|
|
4249
|
-
|
5332
|
+
GGML_UNUSED(buft);
|
4250
5333
|
}
|
4251
5334
|
|
4252
5335
|
static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
|
@@ -4270,8 +5353,8 @@ static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t s
|
|
4270
5353
|
}
|
4271
5354
|
#endif
|
4272
5355
|
#endif
|
4273
|
-
|
4274
|
-
|
5356
|
+
GGML_UNUSED(device);
|
5357
|
+
GGML_UNUSED(size_aligned);
|
4275
5358
|
}
|
4276
5359
|
|
4277
5360
|
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
@@ -4284,7 +5367,8 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
4284
5367
|
size_aligned += (size_page - (size_aligned % size_page));
|
4285
5368
|
}
|
4286
5369
|
|
4287
|
-
|
5370
|
+
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
|
5371
|
+
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
4288
5372
|
|
4289
5373
|
ctx->all_data = ggml_metal_host_malloc(size_aligned);
|
4290
5374
|
ctx->all_size = size_aligned;
|
@@ -4307,7 +5391,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
4307
5391
|
if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
|
4308
5392
|
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
|
4309
5393
|
free(ctx);
|
4310
|
-
ggml_backend_metal_device_rel(
|
5394
|
+
ggml_backend_metal_device_rel(ctx_dev);
|
5395
|
+
return NULL;
|
5396
|
+
}
|
5397
|
+
|
5398
|
+
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
5399
|
+
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
5400
|
+
free(ctx);
|
5401
|
+
ggml_backend_metal_device_rel(ctx_dev);
|
4311
5402
|
return NULL;
|
4312
5403
|
}
|
4313
5404
|
|
@@ -4318,7 +5409,7 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
4318
5409
|
|
4319
5410
|
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
4320
5411
|
return 32;
|
4321
|
-
|
5412
|
+
GGML_UNUSED(buft);
|
4322
5413
|
}
|
4323
5414
|
|
4324
5415
|
static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
|
@@ -4328,13 +5419,13 @@ static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_ty
|
|
4328
5419
|
|
4329
5420
|
return max_size;
|
4330
5421
|
|
4331
|
-
|
5422
|
+
GGML_UNUSED(buft);
|
4332
5423
|
}
|
4333
5424
|
|
4334
5425
|
static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
4335
5426
|
return true;
|
4336
5427
|
|
4337
|
-
|
5428
|
+
GGML_UNUSED(buft);
|
4338
5429
|
}
|
4339
5430
|
|
4340
5431
|
ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
@@ -4357,7 +5448,7 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
|
4357
5448
|
static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) {
|
4358
5449
|
return "Metal_Mapped";
|
4359
5450
|
|
4360
|
-
|
5451
|
+
GGML_UNUSED(buft);
|
4361
5452
|
}
|
4362
5453
|
|
4363
5454
|
static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) {
|
@@ -4400,7 +5491,8 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
|
4400
5491
|
size_aligned += (size_page - (size_aligned % size_page));
|
4401
5492
|
}
|
4402
5493
|
|
4403
|
-
|
5494
|
+
struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
|
5495
|
+
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
4404
5496
|
|
4405
5497
|
// the buffer fits into the max buffer size allowed by the device
|
4406
5498
|
if (size_aligned <= device.maxBufferLength) {
|
@@ -4453,6 +5545,13 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
|
4453
5545
|
}
|
4454
5546
|
}
|
4455
5547
|
|
5548
|
+
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
5549
|
+
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
5550
|
+
free(ctx);
|
5551
|
+
ggml_backend_metal_device_rel(ctx_dev);
|
5552
|
+
return NULL;
|
5553
|
+
}
|
5554
|
+
|
4456
5555
|
return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
|
4457
5556
|
}
|
4458
5557
|
|
@@ -4461,7 +5560,7 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
|
4461
5560
|
static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
4462
5561
|
return "Metal";
|
4463
5562
|
|
4464
|
-
|
5563
|
+
GGML_UNUSED(backend);
|
4465
5564
|
}
|
4466
5565
|
|
4467
5566
|
static void ggml_backend_metal_free(ggml_backend_t backend) {
|
@@ -4504,8 +5603,9 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
|
4504
5603
|
|
4505
5604
|
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
|
4506
5605
|
|
4507
|
-
id<MTLCommandBuffer>
|
4508
|
-
|
5606
|
+
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
|
5607
|
+
|
5608
|
+
id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
|
4509
5609
|
|
4510
5610
|
int node_start = 0;
|
4511
5611
|
int node_end = n_nodes_0;
|
@@ -4517,22 +5617,29 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
|
4517
5617
|
|
4518
5618
|
const bool should_capture = ctx->capture_next_compute;
|
4519
5619
|
|
5620
|
+
struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
|
5621
|
+
ggml_metal_mem_pool_reset(mem_pool);
|
5622
|
+
|
4520
5623
|
for (int idx = node_start; idx < node_end; ++idx) {
|
4521
5624
|
if (should_capture) {
|
4522
5625
|
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
|
4523
5626
|
}
|
4524
5627
|
|
4525
|
-
ggml_metal_encode_node(backend, idx, encoder);
|
5628
|
+
const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
|
4526
5629
|
|
4527
5630
|
if (should_capture) {
|
4528
5631
|
[encoder popDebugGroup];
|
4529
5632
|
}
|
5633
|
+
|
5634
|
+
if (!res) {
|
5635
|
+
break;
|
5636
|
+
}
|
4530
5637
|
}
|
4531
5638
|
|
4532
5639
|
[encoder endEncoding];
|
4533
5640
|
|
4534
5641
|
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
4535
|
-
[
|
5642
|
+
[cmd_buf commit];
|
4536
5643
|
}
|
4537
5644
|
});
|
4538
5645
|
}
|
@@ -4766,6 +5873,13 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
|
|
4766
5873
|
}
|
4767
5874
|
}
|
4768
5875
|
|
5876
|
+
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
5877
|
+
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
5878
|
+
free(ctx);
|
5879
|
+
ggml_backend_metal_device_rel(ctx_dev);
|
5880
|
+
return NULL;
|
5881
|
+
}
|
5882
|
+
|
4769
5883
|
return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
|
4770
5884
|
}
|
4771
5885
|
|
@@ -4779,7 +5893,7 @@ static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml
|
|
4779
5893
|
return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
|
4780
5894
|
buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
|
4781
5895
|
|
4782
|
-
|
5896
|
+
GGML_UNUSED(dev);
|
4783
5897
|
}
|
4784
5898
|
|
4785
5899
|
static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
|